Statistiques
| Révision :

root / ase / visualize / primiplotter.py @ 3

Historique | Voir | Annoter | Télécharger (34,75 ko)

1
"""An experimental package for making plots during a simulation.
2

3
A PrimiPlotter can plot a list of atoms on one or more output devices.
4
"""
5

    
6
from numpy import *
7
from ase.visualize.colortable import color_table
8
import ase.data
9
import sys, os, time, weakref
10

    
11
class PrimiPlotterBase:
12
    "Base class for PrimiPlotter and Povrayplotter."
13
    #def set_dimensions(self, dims):
14
    #    "Set the size of the canvas (a 2-tuple)."
15
    #    self.dims = dims
16
        
17
    def set_rotation(self, rotation):
18
        "Set the rotation angles (in degrees)."
19
        self.angles[:] = array(rotation) * (pi/180)
20
        
21
    def set_radii(self, radii):
22
        """Set the atomic radii.  Give an array or a single number."""
23
        self.radius = radii
24

    
25
    def set_colors(self, colors):
26
        """Explicitly set the colors of the atoms."""
27
        self.colors = colors
28

    
29
    def set_color_function(self, colors):
30
        """Set a color function, to be used to color the atoms."""
31
        if callable(colors):
32
            self.colorfunction = colors
33
        else:
34
            raise TypeError, "The color function is not callable."
35

    
36
    def set_invisible(self, inv):
37
        """Choose invisible atoms."""
38
        self.invisible = inv
39

    
40
    def set_invisibility_function(self, invfunc):
41
        """Set an invisibility function."""
42
        if callable(invfunc):
43
            self.invisibilityfunction = invfunc
44
        else:
45
            raise TypeError, "The invisibility function is not callable."
46

    
47
    def set_cut(self, xmin=None, xmax=None, ymin=None, ymax=None,
48
               zmin=None, zmax=None):
49
        self.cut = {"xmin":xmin, "xmax":xmax, "ymin":ymin, "ymax":ymax,
50
                    "zmin":zmin, "zmax":zmax}
51
    
52
    def update(self, newatoms = None):
53
        """Cause a plot (respecting the interval setting).
54

55
        update causes a plot to be made.  If the interval variable was
56
        specified when the plotter was create, it will only produce a
57
        plot with that interval.  update takes an optional argument,
58
        newatoms, which can be used to replace the list of atoms with
59
        a new one.
60
        """
61
        if newatoms is not None:
62
            self.atoms = newatoms
63
        if self.skipnext <= 0:
64
            self.plot()
65
            self.skipnext = self.interval
66
        self.skipnext -= 1
67
        
68
    def set_log(self, log):
69
        """Sets a file for logging.
70

71
        log may be an open file or a filename.
72
        """
73
        if hasattr(log, "write"):
74
            self.logfile = log
75
            self.ownlogfile = False
76
        else:
77
            self.logfile = open(log, "w")
78
            self.ownlogfile = True
79

    
80
    def log(self, message):
81
        """logs a message to the file set by set_log."""
82
        if self.logfile is not None:
83
            self.logfile.write(message+"\n")
84
            self.logfile.flush()
85
        self._verb(message)
86
        
87
    def _verb(self, txt):
88
        if self.verbose:
89
            sys.stderr.write(txt+"\n")
90
    
91
    def _starttimer(self):
92
        self.starttime = time.time()
93

    
94
    def _stoptimer(self):
95
        elapsedtime = time.time() - self.starttime
96
        self.totaltime = self.totaltime + elapsedtime
97
        print "plotting time %s sec (total %s sec)" % (elapsedtime,
98
                                                       self.totaltime)
99

    
100
    def _getpositions(self):
101
        return self.atoms.get_positions()
102

    
103
    def _getradii(self):
104
        if self.radius is not None:
105
            if hasattr(self.radius, "shape"):
106
                return self.radius   # User has specified an array
107
            else:
108
                return self.radius * ones(len(self.atoms), float)
109
        # No radii specified.  Try getting them from the atoms.
110
        try:
111
            return self.atoms.get_atomic_radii()
112
        except AttributeError:
113
            try:
114
                z = self._getatomicnumbers()
115
            except AttributeError:
116
                pass
117
            else:
118
                return ase.data.covalent_radii[z]
119
        # No radius available.  Defaulting to 1.0
120
        return ones(len(self.atoms), float)
121

    
122
    def _getatomicnumbers(self):
123
        return self.atoms.get_atomic_numbers()
124
    
125
    def _getcolors(self):
126
        # Try any explicitly given colors
127
        if self.colors is not None:
128
            if type(self.colors) == type({}):
129
                self.log("Explicit colors dictionary")
130
                return _colorsfromdict(self.colors,
131
                                       asarray(self.atoms.get_tags(),int))
132
            else:
133
                self.log("Explicit colors")
134
                return self.colors
135
        # Try the color function, if given
136
        if self.colorfunction is not None:
137
            self.log("Calling color function.")
138
            return self.colorfunction(self.atoms)
139
        # Maybe the atoms know their own colors
140
        try:
141
            c = self.atoms.get_colors()
142
        except AttributeError:
143
            c = None
144
        if c is not None:
145
            if type(c) == type({}):
146
                self.log("Color dictionary from atoms.get_colors()")
147
                return _colorsfromdict(c, asarray(self.atoms.get_tags(),int))
148
            else:
149
                self.log("Colors from atoms.get_colors()")
150
                return c
151
        # Default to white atoms
152
        self.log("No colors: using white")
153
        return ones(len(self.atoms), float)
154

    
155
    def _getinvisible(self):
156
        if self.invisible is not None:
157
            inv = self.invisible
158
        else:
159
            inv = zeros(len(self.atoms))
160
        if self.invisibilityfunction:
161
            inv = logical_or(inv, self.invisibilityfunction(self.atoms))
162
        r = self._getpositions()
163
        if len(r) > len(inv):
164
            # This will happen in parallel simulations due to ghost atoms.
165
            # They are invisible.  Hmm, this may cause trouble.
166
            i2 = ones(len(r))
167
            i2[:len(inv)] = inv
168
            inv = i2
169
            del i2
170
        if self.cut["xmin"] is not None:
171
            inv = logical_or(inv, less(r[:,0], self.cut["xmin"]))
172
        if self.cut["xmax"] is not None:
173
            inv = logical_or(inv, greater(r[:,0], self.cut["xmax"]))
174
        if self.cut["ymin"] is not None:
175
            inv = logical_or(inv, less(r[:,1], self.cut["ymin"]))
176
        if self.cut["ymax"] is not None:
177
            inv = logical_or(inv, greater(r[:,1], self.cut["ymax"]))
178
        if self.cut["zmin"] is not None:
179
            inv = logical_or(inv, less(r[:,2], self.cut["zmin"]))
180
        if self.cut["zmax"] is not None:
181
            inv = logical_or(inv, greater(r[:,2], self.cut["zmax"]))
182
        return inv        
183

    
184
    def __del__(self):
185
        if self.ownlogfile:
186
            self.logfile.close()
187
            
188
class PrimiPlotter(PrimiPlotterBase):
189
    """Primitive PostScript-based plots during a simulation.
190

191
    The PrimiPlotter plots atoms during simulations, extracting the
192
    relevant information from the list of atoms.  It is created using
193
    the list of atoms as an argument to the constructor.  Then one or
194
    more output devices must be attached using set_output(device).  The
195
    list of supported output devices is at the end.
196

197
    The atoms are plotted as circles.  The system is first rotated
198
    using the angles specified by set_rotation([vx, vy, vz]).  The
199
    rotation is vx degrees around the x axis (positive from the y
200
    toward the z axis), then vy degrees around the y axis (from x
201
    toward z), then vz degrees around the z axis (from x toward y).
202
    The rotation matrix is the same as the one used by RasMol.
203

204
    Per default, the system is scaled so it fits within the canvas
205
    (autoscale mode).  Autoscale mode is enabled and disables using
206
    autoscale("on") or autoscale("off").  A manual scale factor can be
207
    set with set_scale(scale), this implies autoscale("off").  The
208
    scale factor (from the last autoscale event or from set_scale) can
209
    be obtained with get_scale().  Finally, an explicit autoscaling can
210
    be triggered with autoscale("now"), this is mainly useful before
211
    calling get_scale or before disabling further autoscaling.
212
    Finally, a relative scaling factor can be set with
213
    SetRelativeScaling(), it is multiplied to the usual scale factor
214
    (from autoscale or from set_scale).  This is probably only useful in
215
    connection with autoscaling.
216

217
    The radii of the atoms are obtained from the first of the following
218
    methods which work:
219
    
220
    1.  If the radii are specified using PrimiPlotter.set_radii(r),
221
        they are used.  Must be an array, or a single number.
222

223
    2.  If the atoms has a get_atomic_radii() method, it is used.  This is
224
        unlikely.
225

226
    3.  If the atoms has a get_atomic_numbers() method, the
227
        corresponding covalent radii are extracted from the
228
        ASE.ChemicalElements module.
229

230
    4.  If all else fails, the radius is set to 1.0 Angstrom.
231

232
    The atoms are colored using the first of the following methods
233
    which work.
234

235
    1.  If colors are explicitly set using PrimiPlotter.set_colors(),
236
        they are used.
237

238
    2.  If these colors are specified as a dictionary, the tags
239
        (from atoms.get_tags()) are used as an index into the
240
        dictionary to get the actual colors of the atoms.
241

242
    3.  If a color function has been set using
243
        PrimiPlotter.set_color_function(), it is called with the atoms
244
        as an argument, and is expected to return an array of colors.
245

246
    4.  If the atoms have a get_colors() method, it is used to get the
247
        colors.
248

249
    5.  If these colors are specified as a dictionary, the tags
250
        (from atoms.get_tags()) are used as an index into the
251
        dictionary to get the actual colors of the atoms.
252

253
    6.  If all else fails, the atoms will be white.
254

255
    The colors are specified as an array of colors, one color per
256
    atom.  Each color is either a real number from 0.0 to 1.0,
257
    specifying a grayscale (0.0 = black, 1.0 = white), or an array of
258
    three numbers from 0.0 to 1.0, specifying RGB values.  The colors
259
    of all atoms are thus a Numerical Python N-vector or a 3xN matrix.
260

261
    In cases 1a and 3a above, the keys of the dictionary are integers,
262
    and the values are either numbers (grayscales) or 3-vectors (RGB
263
    values), or strings with X11 color names, which are then
264
    translated to RGB values.  Only in case 1a and 3a are strings
265
    recognized as colors.
266

267
    Some atoms may be invisible, and thus left out of the plot.
268
    Invisible atoms are determined from the following algorithm.
269
    Unlike the radius or the coloring, all points below are tried and
270
    if an atom is invisible by any criterion, it is left out of the plot.
271

272
    1.  All atoms are visible.
273
    
274
    2.  If PrimiPlotter.set_invisible() has be used to specify invisible
275
        atoms, any atoms for which the value is non-zero becomes invisible.
276

277
    3.  If an invisiblility function has been set with
278
        PrimiPlotter.set_invisibility_function(), it is called with the
279
        atoms as argument.  It is expected to return an integer per
280
        atom, any non-zero value makes that atom invisible.
281

282
    4.  If a cut has been specified using set_cut, any atom outside the
283
        cut is made invisible.
284

285
    Note that invisible atoms are still included in the algorithm for
286
    positioning and scaling the plot.
287

288
    
289
    The following output devices are implemented.
290
    
291
    PostScriptFile(prefix):  Create PS files names prefix0000.ps etc.
292

293
    PnmFile(prefix):  Similar, but makes PNM files.
294

295
    GifFile(prefix):  Similar, but makes GIF files.
296

297
    JpegFile(prefix):  Similar, but makes JPEG files.
298

299
    X11Window():  Show the plot in an X11 window using ghostscript.
300

301
    Output devices writing to files take an extra optional argument to
302
    the constructor, compress, specifying if the output file should be
303
    gzipped.  This is not allowed for some (already compressed) file
304
    formats.
305

306
    Instead of a filename prefix, a filename containing a % can be
307
    used.  In that case the filename is expected to expand to a real
308
    filename when used with the Python string formatting operator (%)
309
    with the frame number as argument.  Avoid generating spaces in the
310
    file names: use e.g. %03d instead of %3d.  
311
    """
312
    def __init__(self, atoms, verbose=0, timing=0, interval=1, initframe=0):
313
        """
314

315
        Parameters to the constructor:
316

317
        atoms: The atoms to be plottet.
318

319
        verbose = 0:  Write progress information to stderr.
320

321
        timing = 0:  Collect timing information.
322

323
        interval = 1: If specified, a plot is only made every
324
        interval'th time update() is called.  Deprecated, normally you
325
        should use the interval argument when attaching the plotter to
326
        e.g. the dynamics.
327

328
        initframe = 0: Initial frame number, i.e. the number of the
329
        first plot.
330
        
331
        """
332
        self.atoms = atoms
333
        self.outputdevice = []
334
        self.angles = zeros(3, float)
335
        self.dims = (512, 512)
336
        self.verbose = verbose
337
        self.timing = timing
338
        self.totaltime = 0.0
339
        self.radius = None
340
        self.colors = None
341
        self.colorfunction = None
342
        self.n = initframe
343
        self.interval = interval
344
        self.skipnext = 0 # Number of calls to update before anything happens.
345
        self.a_scale = 1
346
        self.relativescale = 1.0
347
        self.invisible = None
348
        self.invisibilityfunction = None
349
        self.set_cut()   # No cut
350
        self.isparallel = 0
351
        self.logfile = None
352
        self.ownlogfile = False
353
        
354
    def set_output(self, device):
355
        self.outputdevice.append(device)
356
        device.set_dimensions(self.dims)
357
        device.set_owner(weakref.proxy(self))
358

    
359
    def set_dimensions(self, dims):
360
        "Set the size of the canvas (a 2-tuple)."
361
        if self.outputdevice:
362
            raise RuntimeError("Cannot set dimensions after an output device has been specified.")
363
        self.dims = dims
364
        
365
    def autoscale(self, mode):
366
        if mode == "on":
367
            self.a_scale = 1
368
        elif mode == "off":
369
            self.a_scale = 0
370
        elif mode == "now":
371
            coords = self._rotate(self.atoms.get_positions())
372
            radii = self._getradii()
373
            self._autoscale(coords, radii)
374
        else:
375
            raise ValueError, "Unknown autoscale mode: ",+str(mode)
376

    
377
    def set_scale(self, scale):
378
        self.autoscale("off")
379
        self.scale = scale
380

    
381
    def get_scale(self):
382
        return self.scale
383

    
384
    def set_relative_scale(self, rscale = 1.0):
385
        self.relativescale = rscale
386

    
387
    def plot(self):
388
        """Create a plot now.  Does not respect the interval timer.
389

390
        This method makes a plot unconditionally.  It does not look at
391
        the interval variable, nor is this plot taken into account in
392
        the counting done by the update() method if an interval
393
        variable was specified.
394
        """
395
        if self.timing:
396
            self._starttimer()
397
        self.log("PrimiPlotter: Starting plot at "
398
                 + time.strftime("%a, %d %b %Y %H:%M:%S"))
399
        colors = self._getcolors()
400
        invisible = self._getinvisible()
401
        coords = self._rotate(self._getpositions())
402
        radii = self._getradii()
403
        if self.a_scale:
404
            self._autoscale(coords,radii)
405
        scale = self.scale * self.relativescale
406
        coords = scale * coords
407
        center = self._getcenter(coords)
408
        offset = array(self.dims + (0.0,))/2.0 - center
409
        coords = coords + offset
410
        self.log("Scale is %f and size is (%d, %d)"
411
                 % (scale, self.dims[0], self.dims[1]))
412
        self.log("Physical size of plot is %f Angstrom times %f Angstrom"
413
                 % (self.dims[0] / scale, self.dims[1] / scale))
414

    
415
        self._verb("Sorting.")
416
        order = argsort(coords[:,2])
417
        coords = coords[order]  ### take(coords, order)
418
        radii = radii[order]    ### take(radii, order)
419
        colors = colors[order]  ### take(colors, order)
420
        invisible = invisible[order]  ### take(invisible, order)
421
        if self.isparallel:
422
            id = arange(len(coords))[order] ### take(arange(len(coords)), order)
423
        else:
424
            id = None
425
            
426
        radii = radii * scale
427
        selector = self._computevisibility(coords, radii, invisible, id)
428
        coords = compress(selector, coords, 0)
429
        radii = compress(selector, radii)
430
        colors = compress(selector, colors, 0)
431
        self._makeoutput(scale, coords, radii, colors)
432
        self.log("PrimiPlotter: Finished plotting at "
433
                 + time.strftime("%a, %d %b %Y %H:%M:%S"))
434
        self.log("\n\n")
435
        if self.timing:
436
            self._stoptimer()
437

    
438
    def _computevisibility(self, coords, rad, invisible, id, zoom = 1):
439
        xy = coords[:,:2]
440
        typradius = sum(rad) / len(rad)
441
        if typradius < 4.0:
442
            self.log("Refining visibility check.")
443
            if zoom >= 16:
444
                raise RuntimeError, "Cannot check visibility - too deep recursion."
445
            return self._computevisibility(xy*2, rad*2, invisible, id, zoom*2)
446
        else:
447
            self.log("Visibility(r_typ = %.1f pixels)" % (typradius,))
448
        dims = array(self.dims) * zoom
449
        maxr = int(ceil(max(rad))) + 2
450
        canvas = zeros((dims[0] + 4*maxr, dims[1] + 4*maxr), int8)
451
        # Atoms are only invisible if they are within the canvas, or closer
452
        # to its edge than their radius
453
        visible = (greater(xy[:,0], -rad) * less(xy[:,0], dims[0]+rad)
454
                   * greater(xy[:,1], -rad) * less(xy[:,1], dims[1]+rad)
455
                   * logical_not(invisible))
456
        # Atoms are visible if not hidden behind other atoms
457
        xy = floor(xy + 2*maxr + 0.5).astype(int)
458
        masks = {}
459
        for i in xrange(len(rad)-1, -1, -1):
460
            if (i % 100000) == 0 and i:
461
                self._verb(str(i))
462
            if not visible[i]:
463
                continue
464
            x, y = xy[i]
465
            r = rad[i]
466
            try:
467
                mask, invmask, rn = masks[r]
468
            except KeyError:
469
                rn = int(ceil(r))
470
                nmask = 2*rn+1
471
                mask = (arange(nmask) - rn)**2
472
                mask = less(mask[:,newaxis]+mask[newaxis,:], r*r).astype(int8)
473
                invmask = equal(mask, 0).astype(int8)
474
                masks[r] = (mask, invmask, rn)
475
            window = logical_or(canvas[x-rn:x+rn+1, y-rn:y+rn+1], invmask)
476
            hidden = alltrue(window.flat)
477
            if hidden:
478
                visible[i] = 0
479
            else:
480
                canvas[x-rn:x+rn+1, y-rn:y+rn+1] = logical_or(canvas[x-rn:x+rn+1, y-rn:y+rn+1], mask)
481
        self.log("%d visible, %d hidden out of %d" %
482
                   (sum(visible), len(visible) - sum(visible), len(visible)))
483
        return visible
484
        
485
    def _rotate(self, positions):
486
        self.log("Rotation angles: %f %f %f" % tuple(self.angles))
487
        mat = dot(dot(_rot(self.angles[2], 2),
488
                      _rot(self.angles[1], 1)),
489
                  _rot(self.angles[0]+pi, 0))
490
        return dot(positions, mat)
491

    
492
    def _getcenter(self, coords):
493
        return array((max(coords[:,0]) + min(coords[:,0]),
494
                      max(coords[:,1]) + min(coords[:,1]), 0.0)) / 2.0
495

    
496
    def _autoscale(self, coords, radii):
497
        x = coords[:,0]
498
        y = coords[:,1]
499
        maxradius = max(radii)
500
        deltax = max(x) - min(x) + 2*maxradius
501
        deltay = max(y) - min(y) + 2*maxradius
502
        scalex = self.dims[0] / deltax
503
        scaley = self.dims[1] / deltay
504
        self.scale = 0.95 * min(scalex, scaley)
505
        self.log("Autoscale: %f" % self.scale)
506

    
507
    def _makeoutput(self, scale, coords, radii, colors):
508
        for device in self.outputdevice:
509
            device.inform_about_scale(scale)
510
            device.plot(self.n, coords, radii, colors)
511
        self.n = self.n + 1
512

    
513

    
514
class ParallelPrimiPlotter(PrimiPlotter):
515
    """A version of PrimiPlotter for parallel ASAP simulations.
516

517
    Used like PrimiPlotter, but only the output devices on the master
518
    node are used.  Most of the processing is distributed on the
519
    nodes, but the actual output is only done on the master.  See the
520
    PrimiPlotter docstring for details.
521
    """
522
    def __init__(self, *args, **kwargs):
523
        apply(PrimiPlotter.__init__, (self,)+args, kwargs)
524
        self.isparallel = 1
525
        import Scientific.MPI
526
        self.MPI = Scientific.MPI
527
        self.mpi = Scientific.MPI.world
528
        if self.mpi is None:
529
            raise RuntimeError, "MPI is not available."
530
        self.master = self.mpi.rank == 0
531
        self.mpitag = 42   # Reduce chance of collision with other modules.
532
        
533
    def set_output(self, device):
534
        if self.master:
535
            PrimiPlotter.set_output(self, device)
536

    
537
    def set_log(self, log):
538
        if self.master:
539
            PrimiPlotter.set_log(self, log)
540

    
541
    def _getpositions(self):
542
        realpos = self.atoms.get_positions()
543
        ghostpos = self.atoms.GetGhostCartesianPositions()
544
        self.numberofrealatoms = len(realpos)
545
        self.numberofghostatoms = len(ghostpos)
546
        return concatenate((realpos, ghostpos))
547

    
548
    def _getatomicnumbers(self):
549
        realz = self.atoms.get_atomic_numbers()
550
        ghostz = self.atoms.GetGhostAtomicNumbers()
551
        return concatenate((realz, ghostz))
552

    
553
    def _getradius(self):
554
        r = PrimiPlotter._getradius(self)
555
        if len(r) == self.numberofrealatoms + self.numberofghostatoms:
556
            # Must have calculated radii from atomic numbers
557
            return r
558
        else:
559
            assert len(r) == self.numberofrealatoms
560
            # Heuristic: use minimum r for the ghosts
561
            ghostr = min(r) * ones(self.numberofghostatoms, float)
562
            return concatenate((r, ghostr))
563

    
564
    def _getcenter(self, coords):
565
        # max(x) and min(x) only works for rank-1 arrays in Numeric version 17.
566
        maximal = maximum.reduce(coords[:,0:2])
567
        minimal = minimum.reduce(coords[:,0:2])
568
        recvmax = zeros(2, maximal.typecode())
569
        recvmin = zeros(2, minimal.typecode())
570
        self.mpi.allreduce(maximal, recvmax, self.MPI.max)
571
        self.mpi.allreduce(minimal, recvmin, self.MPI.min)
572
        maxx, maxy = recvmax
573
        minx, miny = recvmin
574
        return array([maxx + minx, maxy + miny, 0.0]) / 2.0
575

    
576
    def _computevisibility(self, xy, rad, invisible, id, zoom = 1):
577
        # Find visible atoms, allowing ghost atoms to hide real atoms.
578
        v = PrimiPlotter._computevisibility(self, xy, rad, invisible, id, zoom)
579
        # Then remove ghost atoms
580
        return v * less(id, self.numberofrealatoms)
581

    
582
    def _autoscale(self, coords, radii):
583
        self._verb("Autoscale")
584
        n = len(self.atoms)
585
        x = coords[:n,0]
586
        y = coords[:n,1]
587
        assert len(x) == len(self.atoms)
588
        maximal = array([max(x), max(y), max(radii[:n])])
589
        minimal = array([min(x), min(y)])
590
        recvmax = zeros(3, maximal.typecode())
591
        recvmin = zeros(2, minimal.typecode())
592
        self.mpi.allreduce(maximal, recvmax, self.MPI.max)
593
        self.mpi.allreduce(minimal, recvmin, self.MPI.min)
594
        maxx, maxy, maxradius = recvmax
595
        minx, miny = recvmin
596
        deltax = maxx - minx + 2*maxradius
597
        deltay = maxy - miny + 2*maxradius
598
        scalex = self.dims[0] / deltax
599
        scaley = self.dims[1] / deltay
600
        self.scale = 0.95 * min(scalex, scaley)
601
        self.log("Autoscale: %f" % self.scale)
602

    
603
    def _getcolors(self):
604
        col = PrimiPlotter._getcolors(self)
605
        nghost = len(self.atoms.GetGhostCartesianPositions())
606
        newcolshape = (nghost + col.shape[0],) + col.shape[1:]
607
        newcol = zeros(newcolshape, col.typecode())
608
        newcol[:len(col)] = col
609
        return newcol
610
    
611
    def _makeoutput(self, scale, coords, radii, colors):
612
        if len(colors.shape) == 1:
613
            # Greyscales
614
            ncol = 1
615
        else:
616
            ncol = colors.shape[1]  # 1 or 3.
617
            assert ncol == 3  # RGB values
618
        # If one processor says RGB, all must convert
619
        ncolthis = array([ncol])
620
        ncolmax = zeros((1,), ncolthis.typecode())
621
        self.mpi.allreduce(ncolthis, ncolmax, self.MPI.max)
622
        ncolmax = ncolmax[0]
623
        if ncolmax > ncol:
624
            assert ncol == 1
625
            colors = colors[:,newaxis] + zeros(ncolmax)[newaxis,:]
626
            ncol = ncolmax
627
            assert colors.shape == (len(coords), ncol)
628
        # Now send data from slaves to master
629
        data = zeros((len(coords)+1, 4+ncol), float)
630
        data[:-1,:3] = coords
631
        data[:-1,3] = radii
632
        data[-1,-1] = 4+ncol  # Used to communicate shape
633
        if ncol == 1:
634
            data[:-1,4] = colors
635
        else:
636
            data[:-1,4:] = colors
637
        if not self.master:
638
            self.mpi.send(data, 0, self.mpitag)
639
        else:
640
            total = [data[:-1]]  # Last row is the dimensions.
641
            n = len(coords)
642
            colsmin = colsmax = 4+ncol
643
            for proc in range(1, self.mpi.size):
644
                self._verb("Receiving from processor "+str(proc))
645
                fdat = self.mpi.receive(float, proc, self.mpitag)[0]
646
                fdat.shape = (-1, fdat[-1])
647
                fdat = fdat[:-1]  # Last row is the dimensions.
648
                total.append(fdat)
649
                n = n + len(fdat)
650
                if fdat.shape[1] < colsmin:
651
                    colsmin = fdat.shape[1]
652
                if fdat.shape[1] > colsmax:
653
                    colsmax = fdat.shape[1]
654
            self._verb("Merging data")
655
            # Some processors may have only greyscales whereas others
656
            # may have RGB.  That will cause difficulties.
657
            trouble = colsmax != colsmin
658
            data = zeros((n, colsmax), float)
659
            if trouble:
660
                assert data.shape[1] == 7
661
            else:
662
                assert data.shape[1] == 7 or data.shape[1] == 5
663
            i = 0
664
            for d in total:
665
                if not trouble or d.shape[1] == 7:
666
                    data[i:i+len(d)] = d
667
                else:
668
                    assert d.shape[1] == 5
669
                    data[i:i+len(d), :5] = d
670
                    data[i:i+len(d), 5] = d[4]
671
                    data[i:i+len(d), 6] = d[4]
672
                i = i + len(d)
673
            assert i == len(data)
674
            # Now all data is on the master
675
            self._verb("Sorting merged data")
676
            order = argsort(data[:,2])
677
            data = data[order]   ### take(data, order)
678
            coords = data[:,:3]
679
            radii = data[:,3]
680
            if data.shape[1] == 5:
681
                colors = data[:,4]
682
            else:
683
                colors = data[:,4:]
684
            PrimiPlotter._makeoutput(self, scale, coords, radii, colors)
685
    
686
class _PostScriptDevice:
687
    """PostScript based output device."""
688
    offset = (0,0)   # Will be changed by some classes
689
    def __init__(self):
690
        self.scale = 1
691
        self.linewidth = 1
692
        self.outline = 1
693
        
694
    def set_dimensions(self, dims):
695
        self.dims = dims
696

    
697
    def set_owner(self, owner):
698
        self.owner = owner
699
        
700
    def inform_about_scale(self, scale):
701
        self.linewidth = 0.1 * scale
702

    
703
    def set_outline(self, value):
704
        self.outline = value
705
        return self   # Can chain these calls in set_output()
706
        
707
    def plot(self, *args, **kargs):
708
        self.Doplot(self.PSplot, *args, **kargs)
709
        
710
    def plotArray(self, *args, **kargs):
711
        self.Doplot(self.PSplotArray, *args, **kargs)
712
        
713
    def PSplot(self, file, n, coords, r, colors, noshowpage=0):
714
        xy = coords[:,:2]
715
        assert(len(xy) == len(r) and len(xy) == len(colors))
716
        if len(colors.shape) == 1:
717
            gray = 1
718
        else:
719
            gray = 0
720
            assert(colors.shape[1] == 3)
721
        file.write("%!PS-Adobe-2.0\n")
722
        file.write("%%Creator: Primiplot\n")
723
        file.write("%%Pages: 1\n")        
724
        file.write("%%%%BoundingBox: %d %d %d %d\n" %
725
                   (self.offset + (self.offset[0] + self.dims[0],
726
                                   self.offset[1] + self.dims[1])))
727
        file.write("%%EndComments\n")
728
        file.write("\n")
729
        file.write("% Enforce BoundingBox\n")
730
        file.write("%d %d moveto %d 0 rlineto 0 %d rlineto -%d 0 rlineto\n" %
731
                   ((self.offset + self.dims + (self.dims[0],))))
732
        file.write("closepath clip newpath\n\n")
733
        file.write("%f %f scale\n" % (2*(1.0/self.scale,)))
734
        file.write("%d %d translate\n" % (self.scale * self.offset[0],
735
                                          self.scale * self.offset[1]))
736
        file.write("\n")
737
        if gray:
738
            if self.outline:
739
                file.write("/circ { 0 360 arc gsave setgray fill grestore stroke } def\n")
740
            else:
741
                file.write("/circ { 0 360 arc setgray fill } def\n")
742
        else:
743
            if self.outline:
744
                file.write("/circ { 0 360 arc gsave setrgbcolor fill grestore stroke } def\n")
745
            else:
746
                file.write("/circ { 0 360 arc setrgbcolor fill } def\n")
747
        file.write("%f setlinewidth 0.0 setgray\n" %
748
                   (self.linewidth * self.scale,))
749
        
750
        if gray:
751
            data = zeros((len(xy), 4), float)
752
            data[:,0] = colors
753
            data[:,1:3] = (self.scale * xy)
754
            data[:,3] = (self.scale * r)
755
            for point in data:
756
                file.write("%.3f %.2f %.2f %.2f circ\n" % tuple(point))
757
        else:
758
            data = zeros((len(xy), 6), float)
759
            data[:,0:3] = colors
760
            data[:,3:5] = (self.scale * xy)
761
            data[:,5] = (self.scale * r)
762
            for point in data:
763
                file.write("%.3f %.3f %.3f %.2f %.2f %.2f circ\n" % tuple(point))
764
        if not noshowpage:
765
            file.write("showpage\n")
766
            
767
    def PSplotArray(self, file, n, data, noshowpage=0):
768
        assert(len(data.shape) == 3)
769
        assert(data.shape[0] == self.dims[1] and data.shape[1] == self.dims[0])
770
        data = clip((256*data).astype(int), 0, 255)
771
        file.write("%!PS-Adobe-2.0\n")
772
        file.write("%%Creator: Fieldplotter\n")
773
        file.write("%%Pages: 1\n")        
774
        file.write("%%%%BoundingBox: %d %d %d %d\n" %
775
                   (self.offset + (self.offset[0] + self.dims[0],
776
                                   self.offset[1] + self.dims[1])))
777
        file.write("%%EndComments\n")
778
        file.write("\n")
779
        file.write("%d %d translate\n" % self.offset)
780
        file.write("%f %f scale\n" % self.dims)
781
        file.write("\n")
782
        file.write("% String holding a single line\n")
783
        file.write("/pictline %d string def\n" %(data.shape[1]*data.shape[2],))
784
        file.write("\n")
785
        file.write("%d %d 8\n" % self.dims)
786
        file.write("[%d 0 0 %d 0 0]\n" % self.dims)
787
        file.write("{currentfile pictline readhexstring pop}\n")
788
        file.write("false %d colorimage\n" % (data.shape[2],))
789
        file.write("\n")
790
        s = ""
791
        for d in data.flat:
792
            s += ("%02X" % d)
793
            if len(s) >= 72:
794
                file.write(s+"\n")
795
                s = ""
796
        file.write(s+"\n")
797
        file.write("\n")
798
        if not noshowpage:
799
            file.write("showpage\n")
800
            
801
class _PostScriptToFile(_PostScriptDevice):
802
    """Output device for PS files."""
803
    compr_suffix = None
804
    def __init__(self, prefix, compress = 0):
805
        self.compress = compress
806
        if "'" in prefix:
807
            raise ValueError, "Filename may not contain a quote ('): "+prefix
808
        if "%" in prefix:
809
            # Assume the user knows what (s)he is doing
810
            self.filenames = prefix
811
        else:
812
            self.filenames = prefix + "%04d" + self.suffix
813
            if compress:
814
                if self.compr_suffix is None:
815
                    raise RuntimeError, "Compression not supported."
816
                self.filenames = self.filenames + self.compr_suffix
817
        _PostScriptDevice.__init__(self)
818

    
819
class PostScriptFile(_PostScriptToFile):
820
    suffix = ".ps"
821
    compr_suffix = ".gz"
822
    offset = (50,50)
823
    # Inherits __init__
824

    
825
    def Doplot(self, plotmethod, n, *args, **kargs):
826
        filename = self.filenames % (n,)
827
        self.owner.log("Output to PostScript file "+filename)
828
        if self.compress:
829
            file = os.popen("gzip > '"+filename+"'", "w")
830
        else:
831
            file = open(filename, "w")
832
        apply(plotmethod, (file, n)+args, kargs)
833
        file.close()
834

    
835
class _PS_via_PnmFile(_PostScriptToFile):
836
    gscmd = "gs -q -sDEVICE=pnmraw -sOutputFile=- -dDEVICEWIDTH=%d -dDEVICEHEIGHT=%d - "
837
    # Inherits __init__
838

    
839
    def Doplot(self, plotmethod, n, *args, **kargs):
840
        filename = self.filenames % (n,)
841
        self.owner.log("Output to bitmapped file " + filename)
842
        cmd = self.gscmd + self.converter
843
        if self.compress:
844
            cmd = cmd + "| gzip "
845
            
846
        cmd = (cmd+" > '%s'") % (self.dims[0], self.dims[1], filename)
847
        file = os.popen(cmd, "w")
848
        apply(plotmethod, (file, n)+args, kargs)
849
        file.close()
850

    
851
class PnmFile(_PS_via_PnmFile):
852
    suffix = ".pnm"
853
    compr_suffix = ".gz"
854
    converter = ""
855

    
856
class GifFile(_PS_via_PnmFile):
857
    suffix = ".gif"
858
    converter = "| ppmquant -floyd 256 2>/dev/null | ppmtogif 2>/dev/null"
859

    
860
class JpegFile(_PS_via_PnmFile):
861
    suffix = ".jpeg"
862
    converter = "| ppmtojpeg --smooth=5"
863
    
864
class X11Window(_PostScriptDevice):
865
    """Shows the plot in an X11 window."""
866
    #Inherits __init__
867
    gscmd = "gs -q -sDEVICE=x11 -dDEVICEWIDTH=%d -dDEVICEHEIGHT=%d -r72x72 -"
868
    def Doplot(self, plotmethod, n, *args, **kargs):
869
        self.owner.log("Output to X11 window")
870
        try:
871
            file = self.pipe
872
            self.pipe.write("showpage\n")
873
        except AttributeError:
874
            filename = self.gscmd % tuple(self.dims)
875
            file = os.popen(filename, "w")
876
            self.pipe = file
877
        kargs["noshowpage"] = 1
878
        apply(plotmethod, (file, n)+args, kargs)
879
        file.write("flushpage\n")
880
        file.flush()
881

    
882
# Helper functions
883
def _rot(v, axis):
884
    ax1, ax2 = ((1, 2), (0, 2), (0, 1))[axis]
885
    c, s = cos(v), sin(v)
886
    m = zeros((3,3), float)
887
    m[axis,axis] = 1.0
888
    m[ax1,ax1] = c
889
    m[ax2,ax2] = c
890
    m[ax1,ax2] = s
891
    m[ax2,ax1] = -s
892
    return m
893

    
894
def _colorsfromdict(dict, cls):
895
    """Extract colors from dictionary using cls as key."""
896
    assert(type(dict) == type({}))
897
    # Allow local modifications, to replace strings with rgb values.
898
    dict = dict.copy()  
899
    isgray, isrgb = 0, 0
900
    for k in dict.keys():
901
        v = dict[k]
902
        if type(v) == type("string"):
903
            v = color_table[v]
904
            dict[k] = v
905
        try:
906
            if len(v) == 3:
907
                isrgb = 1 # Assume it is an RGB value
908
                if not hasattr(v, "shape"):
909
                    dict[k] = array(v)   # Convert to array
910
            else:
911
                raise RuntimeError, "Unrecognized color object "+repr(v)
912
        except TypeError:
913
            isgray = 1 # Assume it is a number
914
    if isgray and isrgb:
915
        # Convert all to RGB
916
        for k in dict.keys():
917
            v = dict[k]
918
            if not hasattr(v, "shape"):
919
                dict[k] = v * ones(3, float)
920
    # Now the dictionary is ready
921
    if isrgb:
922
        colors = zeros((len(cls),3), float)
923
    else:
924
        colors = zeros((len(cls),), float)
925
    for i in xrange(len(cls)):
926
        colors[i] = dict[cls[i]]
927
    return colors
928