Statistiques
| Révision :

root / ase / io / eps.py @ 15

Historique | Voir | Annoter | Télécharger (6,04 ko)

1
import time
2
from math import sqrt
3

    
4
import numpy as np
5

    
6
from ase.utils import rotate
7
from ase.data import covalent_radii
8
from ase.data.colors import jmol_colors
9

    
10

    
11
class EPS:
12
    def __init__(self, atoms,
13
                 rotation='', show_unit_cell=False, radii=None,
14
                 bbox=None, colors=None, scale=20):
15
        self.numbers = atoms.get_atomic_numbers()
16
        self.colors = colors
17
        if colors is None:
18
            self.colors = jmol_colors[self.numbers]
19

    
20
        if radii is None:
21
            radii = covalent_radii[self.numbers]
22
        elif type(radii) is float:
23
            radii = covalent_radii[self.numbers] * radii
24
            
25
        natoms = len(atoms)
26

    
27
        if isinstance(rotation, str):
28
            rotation = rotate(rotation)
29

    
30
        A = atoms.get_cell()
31
        if show_unit_cell > 0:
32
            L, T, D = self.cell_to_lines(A)
33
            C = np.empty((2, 2, 2, 3))
34
            for c1 in range(2):
35
                for c2 in range(2):
36
                    for c3 in range(2):
37
                        C[c1, c2, c3] = np.dot([c1, c2, c3], A)
38
            C.shape = (8, 3)
39
            C = np.dot(C, rotation) # Unit cell vertices
40
        else:
41
            L = np.empty((0, 3))
42
            T = None
43
            D = None
44
            C = None
45

    
46
        nlines = len(L)
47

    
48
        X = np.empty((natoms + nlines, 3))
49
        R = atoms.get_positions()
50
        X[:natoms] = R
51
        X[natoms:] = L
52

    
53
        r2 = radii**2
54
        for n in range(nlines):
55
            d = D[T[n]]
56
            if ((((R - L[n] - d)**2).sum(1) < r2) &
57
                (((R - L[n] + d)**2).sum(1) < r2)).any():
58
                T[n] = -1
59

    
60
        X = np.dot(X, rotation)
61
        R = X[:natoms]
62

    
63
        if bbox is None:
64
            X1 = (R - radii[:, None]).min(0) 
65
            X2 = (R + radii[:, None]).max(0) 
66
            if show_unit_cell == 2:
67
                X1 = np.minimum(X1, C.min(0))
68
                X2 = np.maximum(X2, C.max(0))
69
            M = (X1 + X2) / 2
70
            S = 1.05 * (X2 - X1)
71
            w = scale * S[0]
72
            if w > 500:
73
                w = 500
74
                scale = w / S[0]
75
            h = scale * S[1]
76
            offset = np.array([scale * M[0] - w / 2, scale * M[1] - h / 2, 0])
77
        else:
78
            w = (bbox[2] - bbox[0]) * scale
79
            h = (bbox[3] - bbox[1]) * scale
80
            offset = np.array([bbox[0], bbox[1], 0]) * scale
81

    
82
        self.w = w
83
        self.h = h
84
        
85
        X *= scale
86
        X -= offset
87

    
88
        if nlines > 0:
89
            D = np.dot(D, rotation)[:, :2] * scale
90
        
91
        if C is not None:
92
            C *= scale
93
            C -= offset
94

    
95
        A = np.dot(A, rotation)
96
        A *= scale
97

    
98
        self.A = A
99
        self.X = X
100
        self.D = D
101
        self.T = T
102
        self.C = C
103
        self.natoms = natoms
104
        self.d = 2 * scale * radii
105

    
106
    def cell_to_lines(self, A):
107
        nlines = 0
108
        nn = []
109
        for c in range(3):
110
            d = sqrt((A[c]**2).sum())
111
            n = max(2, int(d / 0.3))
112
            nn.append(n)
113
            nlines += 4 * n
114

    
115
        X = np.empty((nlines, 3))
116
        T = np.empty(nlines, int)
117
        D = np.zeros((3, 3))
118

    
119
        n1 = 0
120
        for c in range(3):
121
            n = nn[c]
122
            dd = A[c] / (4 * n - 2)
123
            D[c] = dd
124
            P = np.arange(1, 4 * n + 1, 4)[:, None] * dd
125
            T[n1:] = c
126
            for i, j in [(0, 0), (0, 1), (1, 0), (1, 1)]:
127
                n2 = n1 + n
128
                X[n1:n2] = P + i * A[(c + 1) % 3] + j * A[(c + 2) % 3]
129
                n1 = n2
130

    
131
        return X, T, D
132

    
133
    def write(self, filename):
134
        self.filename = filename
135
        self.write_header()
136
        self.write_body()
137
        self.write_trailer()
138

    
139
    def write_header(self):
140
        import matplotlib
141
        if matplotlib.__version__ <= '0.8':
142
            raise RuntimeError('Your version of matplotlib (%s) is too old' %
143
                               matplotlib.__version__)
144

    
145
        from matplotlib.backends.backend_ps import RendererPS, \
146
             GraphicsContextPS, psDefs
147

    
148
        self.fd = open(self.filename, 'w')
149
        self.fd.write('%!PS-Adobe-3.0 EPSF-3.0\n')
150
        self.fd.write('%%Creator: G2\n')
151
        self.fd.write('%%CreationDate: %s\n' % time.ctime(time.time()))
152
        self.fd.write('%%Orientation: portrait\n')
153
        bbox = (0, 0, self.w, self.h)
154
        self.fd.write('%%%%BoundingBox: %d %d %d %d\n' % bbox)
155
        self.fd.write('%%EndComments\n')
156

    
157
        Ndict = len(psDefs)
158
        self.fd.write('%%BeginProlog\n')
159
        self.fd.write('/mpldict %d dict def\n' % Ndict)
160
        self.fd.write('mpldict begin\n')
161
        for d in psDefs:
162
            d = d.strip()
163
            for l in d.split('\n'):
164
                self.fd.write(l.strip() + '\n')
165
        self.fd.write('%%EndProlog\n')
166

    
167
        self.fd.write('mpldict begin\n')
168
        self.fd.write('%d %d 0 0 clipbox\n' % (self.w, self.h))
169

    
170
        self.renderer = RendererPS(self.w, self.h, self.fd)
171
        
172
    def write_body(self):
173
        try:
174
            from matplotlib.path import Path
175
        except ImportError:
176
            Path = None
177
            from matplotlib.patches import Circle, Polygon
178
        else:
179
            from matplotlib.patches import Circle, PathPatch
180

    
181
        indices = self.X[:, 2].argsort()
182
        for a in indices:
183
            xy = self.X[a, :2]
184
            if a < self.natoms:
185
                circle = Circle(xy, self.d[a] / 2, facecolor=self.colors[a])
186
                circle.draw(self.renderer)
187
            else:
188
                a -= self.natoms
189
                c = self.T[a]
190
                if c != -1:
191
                    hxy = self.D[c]
192
                    if Path is None:
193
                        line = Polygon((xy + hxy, xy - hxy))
194
                    else:
195
                        line = PathPatch(Path((xy + hxy, xy - hxy)))
196
                    line.draw(self.renderer)
197

    
198
    def write_trailer(self):
199
        self.fd.write('end\n')
200
        self.fd.write('showpage\n')
201
        self.fd.close()
202

    
203

    
204
def write_eps(filename, atoms, **parameters):
205
    if isinstance(atoms, list):
206
        assert len(atoms) == 1
207
        atoms = atoms[0]
208
    EPS(atoms, **parameters).write(filename)