root / ase / visualize / fieldplotter.py @ 1
Historique | Voir | Annoter | Télécharger (10,86 ko)
1 |
"""plotting fields defined on atoms during a simulation."""
|
---|---|
2 |
|
3 |
from ase.visualize.primiplotter import PostScriptFile, PnmFile, GifFile, JpegFile, X11Window |
4 |
from ase.visualize.primiplotter import PrimiPlotter as _PrimiPlotter |
5 |
import numpy |
6 |
import time |
7 |
|
8 |
class FieldPlotter(_PrimiPlotter): |
9 |
def __init__(self, atoms, datasource=None, verbose=0, timing=0, |
10 |
interval=1, initframe=0): |
11 |
_PrimiPlotter.__init__(self, atoms, verbose=verbose, timing=timing,
|
12 |
interval=interval, initframe=initframe) |
13 |
self.datasource = datasource
|
14 |
self.dims = (100,100) |
15 |
self.set_plot_plane("xy") |
16 |
self.set_data_range("plot") |
17 |
self.set_background(0.0) |
18 |
self.set_red_yellow_colors()
|
19 |
|
20 |
def set_plot_plane(self, plane): |
21 |
"""Set the plotting plane to xy, xz or yz (default: xy)"""
|
22 |
if plane in ("xy", "xz", "yz"): |
23 |
self.plane = plane
|
24 |
else:
|
25 |
raise ValueError, "The argument to plotPlane must be 'xy', 'xz' or 'yz'." |
26 |
|
27 |
def set_data_range(self, range1, range2=None): |
28 |
"""Set the range of the data used when coloring.
|
29 |
|
30 |
This function sets the range of data values mapped unto colors
|
31 |
in the final plot.
|
32 |
|
33 |
Three possibilities:
|
34 |
|
35 |
'data': Autoscale using the data on visible atoms.
|
36 |
The range goes from the lowest to the highest
|
37 |
value present on the atoms. If only a few atoms
|
38 |
have extreme values, the entire color range may not
|
39 |
be used on the plot, as many values may be averaged
|
40 |
on each point in the plot.
|
41 |
|
42 |
'plot': Autoscale using the data on the plot. Unlike 'data'
|
43 |
this guarantees that the entire color range is used.
|
44 |
|
45 |
min, max: Use the range [min, max]
|
46 |
|
47 |
"""
|
48 |
if (range1 == "data" or range1 == "plot") and range2 == None: |
49 |
self.autorange = range1
|
50 |
elif range2 != None: |
51 |
self.autorange = None |
52 |
self.range = (range1, range2)
|
53 |
else:
|
54 |
raise ValueError, "Illegal argument(s) to set_data_range" |
55 |
|
56 |
def set_background(self, value): |
57 |
"""Set the data value of the background. See also set_background_color
|
58 |
|
59 |
Set the value of the background (parts of the plot without atoms) to
|
60 |
a specific value, or to 'min' or 'max' representing the minimal or
|
61 |
maximal data values on the atoms.
|
62 |
|
63 |
Calling set_background cancels previous calls to set_background_color.
|
64 |
"""
|
65 |
self.background = value
|
66 |
self.backgroundcolor = None |
67 |
|
68 |
def set_background_color(self, color): |
69 |
"""Set the background color. See also set_background.
|
70 |
|
71 |
Set the background color. Use a single value in the range [0, 1[
|
72 |
for gray values, or a tuple of three such values as an RGB color.
|
73 |
|
74 |
Calling set_background_color cancels previous calls to set_background.
|
75 |
"""
|
76 |
self.background = None |
77 |
self.backgroundcolor = color
|
78 |
|
79 |
def set_red_yellow_colors(self, reverse=False): |
80 |
"""Set colors to Black-Red-Yellow-White (a.k.a. STM colors)"""
|
81 |
self.set_colors([(0.0, 0, 0, 0), |
82 |
(0.33, 1, 0, 0), |
83 |
(0.66, 1, 1, 0), |
84 |
(1.0, 1, 1, 1)], |
85 |
reverse) |
86 |
|
87 |
def set_black_white_colors(self, reverse=False): |
88 |
"""Set the color to Black-White (greyscale)"""
|
89 |
self.set_colors([(0.0, 0), (1.0, 1)], reverse) |
90 |
|
91 |
def set_colors(self, colors, reverse=False): |
92 |
colors = numpy.array(colors, numpy.float) |
93 |
if len(colors.shape) != 2: |
94 |
raise ValueError, "Colors must be a 2D array." |
95 |
if reverse:
|
96 |
colors[:,0] = 1 - colors[:,0] |
97 |
colors = numpy.array(colors[::-1,:])
|
98 |
#print colors
|
99 |
if colors[0,0] != 0.0 or colors[-1,0] != 1.0: |
100 |
raise ValueError, "First row must define the value 0 and last row must define the value 1" |
101 |
if colors.shape[1] == 2: |
102 |
self.colormode = 1 |
103 |
elif colors.shape[1] == 4: |
104 |
self.colormode = 3 |
105 |
else:
|
106 |
raise ValueError, "Color specification must be Nx2 (grey) or Nx4 (rgb) matrix." |
107 |
self.colorfunction = InterpolatingFunction(colors[:,0], colors[:,1:]) |
108 |
|
109 |
def plot(self, data=None): |
110 |
"""Create a plot now. Does not respect the interval timer.
|
111 |
|
112 |
This method makes a plot unconditionally. It does not look at
|
113 |
the interval variable, nor is this plot taken into account in
|
114 |
the counting done by the update() method if an interval
|
115 |
variable was specified.
|
116 |
|
117 |
If data is specified, it must be an array of numbers with the
|
118 |
same length as the atoms. That data will then be plotted. If
|
119 |
no data is given, the data source specified when creating the
|
120 |
plotter is used.
|
121 |
|
122 |
"""
|
123 |
if self.timing: |
124 |
self._starttimer()
|
125 |
self.log("FieldPlotter: Starting plot at " |
126 |
+ time.strftime("%a, %d %b %Y %H:%M:%S"))
|
127 |
if data is None: |
128 |
data = self.datasource()
|
129 |
if len(data) != len(self.atoms): |
130 |
raise ValueError, ("Data has wrong length: %d instead of %d." |
131 |
% (len(data), len(self.atoms))) |
132 |
|
133 |
invisible = self._getinvisible()
|
134 |
coords = self._rotate(self._getpositions()) |
135 |
radii = self._getradii()
|
136 |
if self.autoscale: |
137 |
self._autoscale(coords,radii)
|
138 |
scale = self.scale * self.relativescale |
139 |
coords = scale * coords |
140 |
center = self._getcenter(coords)
|
141 |
offset = numpy.array(self.dims + (0.0,))/2.0 - center |
142 |
coords = coords + offset |
143 |
radii = radii * scale |
144 |
self.log("Scale is %f and size is (%d, %d)" |
145 |
% (scale, self.dims[0], self.dims[1])) |
146 |
self.log("Physical size of plot is %f Angstrom times %f Angstrom" |
147 |
% (self.dims[0] / scale, self.dims[1] / scale)) |
148 |
|
149 |
# Remove invisible atoms
|
150 |
selector = numpy.logical_not(invisible) |
151 |
coords = numpy.compress(selector, coords, 0)
|
152 |
radii = numpy.compress(selector, radii) |
153 |
data = numpy.compress(selector, data) |
154 |
|
155 |
self.log("plotting data in the range [%f,%f]" % |
156 |
(data.min(), data.max())) |
157 |
# Now create the output array
|
158 |
sumarray = numpy.zeros(self.dims, numpy.float)
|
159 |
weight = numpy.zeros(self.dims)
|
160 |
|
161 |
# Loop over all atoms, and plot them
|
162 |
nmiss = 0
|
163 |
if self.plane == "xy": |
164 |
xy = coords[:,:2]
|
165 |
elif self.plane == "xz": |
166 |
xy = coords[:,::2]
|
167 |
elif self.plane == "yz": |
168 |
xy = coords[:,1:]
|
169 |
else:
|
170 |
raise RuntimeError, "self.plane is bogus: "+str(self.plane) |
171 |
assert xy.shape[1] == 2 |
172 |
|
173 |
self.log("plotting %d atoms on %d * %d (= %d) grid" % |
174 |
(len(xy), sumarray.shape[0], sumarray.shape[1], |
175 |
len(sumarray.flat)))
|
176 |
|
177 |
xy = xy.astype(numpy.int) |
178 |
for i in xrange(len(xy)): |
179 |
(x, y) = xy[i] |
180 |
d = data[i] |
181 |
if (x >= 0 and x < self.dims[0] and y >= 0 and y < self.dims[1]): |
182 |
sumarray[x,y] += d |
183 |
weight[x,y] += 1
|
184 |
else:
|
185 |
nmiss += 1
|
186 |
print "... %d atoms fell outside plot." % (nmiss,) |
187 |
|
188 |
datamap = self._makedatamap(sumarray, weight, data.min(), data.max())
|
189 |
self.log("Range of data map: [%f, %f]" % |
190 |
(datamap.min(), datamap.max())) |
191 |
plot = self._makeplotmap(datamap, weight)
|
192 |
#self.log("Range of plot: [%f, %f]" %
|
193 |
# (min(plot.flat), max(plot.flat)))
|
194 |
examinplot = plot[:] |
195 |
examinplot.shape = (plot.shape[0] * plot.shape[1],) + plot.shape[2:] |
196 |
self.log("Range of plot: %s -> %s" % |
197 |
(str(examinplot.min(0)), str(examinplot.max(0)))) |
198 |
del examinplot
|
199 |
for device in self.outputdevice: |
200 |
device.inform_about_scale(scale) |
201 |
device.plotArray(self.n, numpy.swapaxes(plot,0,1)) |
202 |
self.n = self.n + 1 |
203 |
self.log("FieldPlotter: Finished plotting at " |
204 |
+ time.strftime("%a, %d %b %Y %H:%M:%S"))
|
205 |
self.log("\n\n") |
206 |
|
207 |
|
208 |
def _makedatamap(self, sumarray, weight, minimum, maximum): |
209 |
background = numpy.equal(weight, 0)
|
210 |
print "Number of background points:", sum(background.flat) |
211 |
datamap = sumarray / numpy.where(background, 1, weight)
|
212 |
|
213 |
if self.background is not None: |
214 |
if self.background == "min": |
215 |
bg = minimum |
216 |
elif self.background == "max": |
217 |
bg = maximum |
218 |
else:
|
219 |
bg = self.background
|
220 |
datamap = numpy.where(background, bg, datamap) |
221 |
|
222 |
if self.autorange == "data": |
223 |
datamap = (datamap - minimum) / (maximum - minimum) |
224 |
self.log("Autorange using data. Data range is [%f, %f]" |
225 |
% (minimum, maximum)) |
226 |
elif self.autorange == "plot": |
227 |
ma = numpy.where(background, minimum, datamap).max() |
228 |
mi = numpy.where(background, maximum, datamap).min() |
229 |
datamap = (datamap - mi) / (ma - mi) |
230 |
self.log("Autorange using plot. Data range is [%f, %f]" |
231 |
% (mi, ma)) |
232 |
else:
|
233 |
assert self.autorange == None |
234 |
datamap = (datamap - self.range[0]) / (self.range[1] |
235 |
- self.range[0]) |
236 |
datamap = numpy.clip(datamap, 0.0, 1.0) |
237 |
self.log("Data range specified by user: [%f, %f]" % self.range) |
238 |
datamap = numpy.where(background, bg, datamap) |
239 |
assert datamap.min() >= 0 and datamap.max() <= 1.0 |
240 |
|
241 |
return datamap
|
242 |
|
243 |
def _makeplotmap(self, datamap, weight): |
244 |
plot = numpy.zeros(self.dims + (self.colormode,), numpy.float) |
245 |
for i in range(self.dims[0]): |
246 |
for j in range(self.dims[1]): |
247 |
if self.backgroundcolor is not None and weight[i,j] == 0: |
248 |
plot[i,j,:] = self.backgroundcolor
|
249 |
else:
|
250 |
x = datamap[i,j] |
251 |
plot[i,j,:] = self.colorfunction(x)
|
252 |
return plot
|
253 |
|
254 |
class InterpolatingFunction: |
255 |
def __init__(self, xpoints, ypoints): |
256 |
if len(xpoints) != len(ypoints): |
257 |
raise ValueError, "Length of x and y arrays should be the same." |
258 |
idx = xpoints.argsort() |
259 |
self.xpoints = xpoints[idx]
|
260 |
self.ypoints = ypoints[idx]
|
261 |
def __call__(self, x): |
262 |
n = self.xpoints.searchsorted(x)
|
263 |
if n == 0: |
264 |
return self.ypoints[0] |
265 |
if n == len(self.xpoints): |
266 |
return self.xpoints[-1] |
267 |
x0 = self.xpoints[n-1] |
268 |
x1 = self.xpoints[n]
|
269 |
y0 = self.ypoints[n-1] |
270 |
y1 = self.ypoints[n]
|
271 |
return y0 + (y1 - y0) / (x1 - x0) * (x - x0)
|
272 |
|