Statistiques
| Révision :

root / ase / visualize / fieldplotter.py @ 4

Historique | Voir | Annoter | Télécharger (10,86 ko)

1
"""plotting fields defined on atoms during a simulation."""
2

    
3
from ase.visualize.primiplotter import PostScriptFile, PnmFile, GifFile, JpegFile, X11Window
4
from ase.visualize.primiplotter import PrimiPlotter as _PrimiPlotter
5
import numpy
6
import time
7

    
8
class FieldPlotter(_PrimiPlotter):
9
    def __init__(self, atoms, datasource=None, verbose=0, timing=0,
10
                 interval=1, initframe=0):
11
        _PrimiPlotter.__init__(self, atoms, verbose=verbose, timing=timing,
12
                               interval=interval, initframe=initframe)
13
        self.datasource = datasource
14
        self.dims = (100,100)
15
        self.set_plot_plane("xy")
16
        self.set_data_range("plot")
17
        self.set_background(0.0)
18
        self.set_red_yellow_colors()
19
        
20
    def set_plot_plane(self, plane):
21
        """Set the plotting plane to xy, xz or yz (default: xy)"""
22
        if plane in ("xy", "xz", "yz"):
23
            self.plane = plane
24
        else:
25
            raise ValueError, "The argument to plotPlane must be 'xy', 'xz' or 'yz'."
26

    
27
    def set_data_range(self, range1, range2=None):
28
        """Set the range of the data used when coloring.
29

30
        This function sets the range of data values mapped unto colors
31
        in the final plot.
32
        
33
        Three possibilities:
34

35
        'data':        Autoscale using the data on visible atoms.
36
                       The range goes from the lowest to the highest
37
                       value present on the atoms.  If only a few atoms
38
                       have extreme values, the entire color range may not
39
                       be used on the plot, as many values may be averaged
40
                       on each point in the plot.
41

42
        'plot':        Autoscale using the data on the plot.  Unlike 'data'
43
                       this guarantees that the entire color range is used.
44

45
        min, max:      Use the range [min, max]
46
                       
47
        """
48
        if (range1 == "data" or range1 == "plot") and range2 == None:
49
            self.autorange = range1
50
        elif range2 != None:
51
            self.autorange = None
52
            self.range = (range1, range2)
53
        else:
54
            raise ValueError, "Illegal argument(s) to set_data_range"
55

    
56
    def set_background(self, value):
57
        """Set the data value of the background.  See also set_background_color
58

59
        Set the value of the background (parts of the plot without atoms) to
60
        a specific value, or to 'min' or 'max' representing the minimal or
61
        maximal data values on the atoms.
62

63
        Calling set_background cancels previous calls to set_background_color.
64
        """
65
        self.background = value
66
        self.backgroundcolor = None
67

    
68
    def set_background_color(self, color):
69
        """Set the background color.  See also set_background.
70

71
        Set the background color.  Use a single value in the range [0, 1[
72
        for gray values, or a tuple of three such values as an RGB color.
73

74
        Calling set_background_color cancels previous calls to set_background.
75
        """
76
        self.background = None
77
        self.backgroundcolor = color
78

    
79
    def set_red_yellow_colors(self, reverse=False):
80
        """Set colors to Black-Red-Yellow-White (a.k.a. STM colors)"""
81
        self.set_colors([(0.0, 0, 0, 0),
82
                        (0.33, 1, 0, 0),
83
                        (0.66, 1, 1, 0),
84
                        (1.0, 1, 1, 1)],
85
                       reverse)
86
        
87
    def set_black_white_colors(self, reverse=False):
88
        """Set the color to Black-White (greyscale)"""
89
        self.set_colors([(0.0, 0),  (1.0, 1)], reverse)
90

    
91
    def set_colors(self, colors, reverse=False):
92
        colors = numpy.array(colors, numpy.float)
93
        if len(colors.shape) != 2:
94
            raise ValueError, "Colors must be a 2D array."
95
        if reverse:
96
            colors[:,0] = 1 - colors[:,0]
97
            colors = numpy.array(colors[::-1,:])
98
            #print colors
99
        if colors[0,0] != 0.0 or colors[-1,0] != 1.0:
100
            raise ValueError, "First row must define the value 0 and last row must define the value 1"
101
        if colors.shape[1] == 2:
102
            self.colormode = 1
103
        elif colors.shape[1] == 4:
104
            self.colormode = 3
105
        else:
106
            raise ValueError, "Color specification must be Nx2 (grey) or Nx4 (rgb) matrix."
107
        self.colorfunction = InterpolatingFunction(colors[:,0], colors[:,1:])
108
        
109
    def plot(self, data=None):
110
        """Create a plot now.  Does not respect the interval timer.
111

112
        This method makes a plot unconditionally.  It does not look at
113
        the interval variable, nor is this plot taken into account in
114
        the counting done by the update() method if an interval
115
        variable was specified.
116

117
        If data is specified, it must be an array of numbers with the
118
        same length as the atoms.  That data will then be plotted.  If
119
        no data is given, the data source specified when creating the
120
        plotter is used.
121
        
122
        """
123
        if self.timing:
124
            self._starttimer()
125
        self.log("FieldPlotter: Starting plot at "
126
                 + time.strftime("%a, %d %b %Y %H:%M:%S"))
127
        if data is None:
128
            data = self.datasource()
129
        if len(data) != len(self.atoms):
130
            raise ValueError, ("Data has wrong length: %d instead of %d."
131
                               % (len(data), len(self.atoms)))
132
        
133
        invisible = self._getinvisible()
134
        coords = self._rotate(self._getpositions())
135
        radii = self._getradii()
136
        if self.autoscale:
137
            self._autoscale(coords,radii)
138
        scale = self.scale * self.relativescale
139
        coords = scale * coords
140
        center = self._getcenter(coords)
141
        offset = numpy.array(self.dims + (0.0,))/2.0 - center
142
        coords = coords + offset
143
        radii = radii * scale
144
        self.log("Scale is %f and size is (%d, %d)"
145
                 % (scale, self.dims[0], self.dims[1]))
146
        self.log("Physical size of plot is %f Angstrom times %f Angstrom"
147
                 % (self.dims[0] / scale, self.dims[1] / scale))
148

    
149
        # Remove invisible atoms
150
        selector = numpy.logical_not(invisible)
151
        coords = numpy.compress(selector, coords, 0)
152
        radii = numpy.compress(selector, radii)
153
        data = numpy.compress(selector, data)
154

    
155
        self.log("plotting data in the range [%f,%f]" %
156
                   (data.min(), data.max()))
157
        # Now create the output array
158
        sumarray = numpy.zeros(self.dims, numpy.float)
159
        weight = numpy.zeros(self.dims)
160

    
161
        # Loop over all atoms, and plot them
162
        nmiss = 0
163
        if self.plane == "xy":
164
            xy = coords[:,:2]
165
        elif self.plane == "xz":
166
            xy = coords[:,::2]
167
        elif self.plane == "yz":
168
            xy = coords[:,1:]
169
        else:
170
            raise RuntimeError, "self.plane is bogus: "+str(self.plane)
171
        assert xy.shape[1] == 2
172

    
173
        self.log("plotting %d atoms on %d * %d (= %d) grid" %
174
                   (len(xy), sumarray.shape[0], sumarray.shape[1],
175
                    len(sumarray.flat)))
176
                                                            
177
        xy = xy.astype(numpy.int)
178
        for i in xrange(len(xy)):
179
            (x, y) = xy[i]
180
            d = data[i]
181
            if (x >= 0 and x < self.dims[0] and y >= 0 and y < self.dims[1]):
182
                sumarray[x,y] += d
183
                weight[x,y] += 1
184
            else:
185
                nmiss += 1
186
        print "... %d atoms fell outside plot." % (nmiss,)
187

    
188
        datamap = self._makedatamap(sumarray, weight, data.min(), data.max())
189
        self.log("Range of data map: [%f, %f]" %
190
                   (datamap.min(), datamap.max()))
191
        plot = self._makeplotmap(datamap, weight)
192
        #self.log("Range of plot: [%f, %f]" %
193
        #           (min(plot.flat), max(plot.flat)))
194
        examinplot = plot[:]
195
        examinplot.shape = (plot.shape[0] * plot.shape[1],) + plot.shape[2:]
196
        self.log("Range of plot: %s -> %s" %
197
                 (str(examinplot.min(0)), str(examinplot.max(0))))
198
        del examinplot
199
        for device in self.outputdevice:
200
            device.inform_about_scale(scale)
201
            device.plotArray(self.n, numpy.swapaxes(plot,0,1))
202
        self.n = self.n + 1
203
        self.log("FieldPlotter: Finished plotting at "
204
                 + time.strftime("%a, %d %b %Y %H:%M:%S"))
205
        self.log("\n\n")
206

    
207
        
208
    def _makedatamap(self, sumarray, weight, minimum, maximum):
209
        background = numpy.equal(weight, 0)
210
        print "Number of background points:", sum(background.flat)
211
        datamap = sumarray / numpy.where(background, 1, weight)
212
        
213
        if self.background is not None:
214
            if self.background == "min":
215
                bg = minimum
216
            elif self.background == "max":
217
                bg = maximum
218
            else:
219
                bg = self.background
220
            datamap = numpy.where(background, bg, datamap)
221
            
222
        if self.autorange == "data":
223
            datamap = (datamap - minimum) / (maximum - minimum)
224
            self.log("Autorange using data.  Data range is [%f, %f]"
225
                     % (minimum, maximum))
226
        elif self.autorange == "plot":
227
            ma = numpy.where(background, minimum, datamap).max()
228
            mi = numpy.where(background, maximum, datamap).min()
229
            datamap = (datamap - mi) / (ma - mi)
230
            self.log("Autorange using plot.  Data range is [%f, %f]"
231
                     % (mi, ma))
232
        else:
233
            assert self.autorange == None
234
            datamap = (datamap - self.range[0]) / (self.range[1]
235
                                                   - self.range[0])
236
            datamap = numpy.clip(datamap, 0.0, 1.0)
237
            self.log("Data range specified by user: [%f, %f]" % self.range)
238
        datamap = numpy.where(background, bg, datamap)
239
        assert datamap.min() >= 0 and datamap.max() <= 1.0
240
        
241
        return datamap
242

    
243
    def _makeplotmap(self, datamap, weight):
244
        plot = numpy.zeros(self.dims + (self.colormode,), numpy.float)
245
        for i in range(self.dims[0]):
246
            for j in range(self.dims[1]):
247
                if self.backgroundcolor is not None and weight[i,j] == 0:
248
                    plot[i,j,:] = self.backgroundcolor
249
                else:
250
                    x = datamap[i,j]
251
                    plot[i,j,:] = self.colorfunction(x)
252
        return plot
253
    
254
class InterpolatingFunction:
255
    def __init__(self, xpoints, ypoints):
256
        if len(xpoints) != len(ypoints):
257
            raise ValueError, "Length of x and y arrays should be the same."
258
        idx = xpoints.argsort()
259
        self.xpoints = xpoints[idx]
260
        self.ypoints = ypoints[idx]
261
    def __call__(self, x):
262
        n = self.xpoints.searchsorted(x)
263
        if n == 0:
264
            return self.ypoints[0]
265
        if n == len(self.xpoints):
266
            return self.xpoints[-1]
267
        x0 = self.xpoints[n-1]
268
        x1 = self.xpoints[n]
269
        y0 = self.ypoints[n-1]
270
        y1 = self.ypoints[n]
271
        return y0 + (y1 - y0) / (x1 - x0) * (x - x0)
272