Statistiques
| Révision :

root / ase / io / trajectory.py @ 1

Historique | Voir | Annoter | Télécharger (11,7 ko)

1
import os
2
import cPickle as pickle
3

    
4
from ase.calculators.singlepoint import SinglePointCalculator
5
from ase.atoms import Atoms
6
from ase.parallel import rank, barrier
7
from ase.utils import devnull
8

    
9

    
10
class PickleTrajectory:
11
    "Reads/writes Atoms objects into a .traj file."
12
    # Per default, write these quantities
13
    write_energy = True
14
    write_forces = True
15
    write_stress = True
16
    write_magmoms = True
17
    write_momenta = True
18
    
19
    def __init__(self, filename, mode='r', atoms=None, master=None,
20
                 backup=True):
21
        """A PickleTrajectory can be created in read, write or append mode.
22

23
        Parameters:
24

25
        filename:
26
            The name of the parameter file.  Should end in .traj.
27

28
        mode='r':
29
            The mode.
30

31
            'r' is read mode, the file should already exist, and
32
            no atoms argument should be specified.
33

34
            'w' is write mode.  If the file already exists, is it
35
            renamed by appending .bak to the file name.  The atoms
36
            argument specifies the Atoms object to be written to the
37
            file, if not given it must instead be given as an argument
38
            to the write() method.
39

40
            'a' is append mode.  It acts a write mode, except that
41
            data is appended to a preexisting file.
42

43
        atoms=None:
44
            The Atoms object to be written in write or append mode.
45

46
        master=None:
47
            Controls which process does the actual writing. The
48
            default is that process number 0 does this.  If this
49
            argument is given, processes where it is True will write.
50

51
        backup=True:
52
            Use backup=False to disable renaming of an existing file.
53
        """
54
        self.offsets = []
55
        if master is None:
56
            master = (rank == 0)
57
        self.master = master
58
        self.backup = backup
59
        self.set_atoms(atoms)
60
        self.open(filename, mode)
61

    
62
    def open(self, filename, mode):
63
        """Opens the file.
64

65
        For internal use only.
66
        """
67
        self.fd = filename
68
        if mode == 'r':
69
            if isinstance(filename, str):
70
                self.fd = open(filename, 'rb')
71
            self.read_header()
72
        elif mode == 'a':
73
            exists = True
74
            if isinstance(filename, str):
75
                exists = os.path.isfile(filename)
76
                if exists:
77
                    self.fd = open(filename, 'rb')
78
                    self.read_header()
79
                    self.fd.close()
80
                barrier()
81
                if self.master:
82
                    self.fd = open(filename, 'ab+')
83
                else:
84
                    self.fd = devnull
85
        elif mode == 'w':
86
            if self.master:
87
                if isinstance(filename, str):
88
                    if self.backup and os.path.isfile(filename):
89
                        os.rename(filename, filename + '.bak')
90
                    self.fd = open(filename, 'wb')
91
            else:
92
                self.fd = devnull
93
        else:
94
            raise ValueError('mode must be "r", "w" or "a".')
95

    
96
    def set_atoms(self, atoms=None):
97
        """Associate an Atoms object with the trajectory.
98

99
        Mostly for internal use.
100
        """
101
        if atoms is not None and not hasattr(atoms, 'get_positions'):
102
            raise TypeError('"atoms" argument is not an Atoms object.')
103
        self.atoms = atoms
104

    
105
    def read_header(self):
106
        self.fd.seek(0)
107
        try:
108
            if self.fd.read(len('PickleTrajectory')) != 'PickleTrajectory':
109
                raise IOError('This is not a trajectory file!')
110
            d = pickle.load(self.fd)
111
        except EOFError:
112
            raise EOFError('Bad trajectory file.')
113
        self.pbc = d['pbc']
114
        self.numbers = d['numbers']
115
        self.tags = d.get('tags')
116
        self.masses = d.get('masses')
117
        self.constraints = d['constraints']
118
        self.offsets.append(self.fd.tell())
119

    
120
    def write(self, atoms=None):
121
        """Write the atoms to the file.
122

123
        If the atoms argument is not given, the atoms object specified
124
        when creating the trajectory object is used.
125
        """
126
        if atoms is None:
127
            atoms = self.atoms
128

    
129
        if hasattr(atoms, 'interpolate'):
130
            # seems to be a NEB
131
            neb = atoms
132
            try:
133
                neb.get_energies_and_forces(all=True)
134
            except AttributeError:
135
                pass
136
            for image in neb.images:
137
                self.write(image)
138
            return
139

    
140
        if len(self.offsets) == 0:
141
            self.write_header(atoms)
142

    
143
        if atoms.has('momenta'):
144
            momenta = atoms.get_momenta()
145
        else:
146
            momenta = None
147

    
148
        d = {'positions': atoms.get_positions(),
149
             'cell': atoms.get_cell(),
150
             'momenta': momenta}
151

    
152
        if atoms.get_calculator() is not None:
153
            if self.write_energy:
154
                d['energy'] = atoms.get_potential_energy()
155
            if self.write_forces:
156
                assert self.write_energy
157
                try:
158
                    d['forces'] = atoms.get_forces(apply_constraint=False)
159
                except NotImplementedError:
160
                    pass
161
            if self.write_stress:
162
                assert self.write_energy
163
                try:
164
                    d['stress'] = atoms.get_stress()
165
                except NotImplementedError:
166
                    pass
167

    
168
            if self.write_magmoms:
169
                try:
170
                    if atoms.calc.get_spin_polarized():
171
                        d['magmoms'] = atoms.get_magnetic_moments()
172
                except (NotImplementedError, AttributeError):
173
                    pass
174

    
175
        if 'magmoms' not in d and atoms.has('magmoms'):
176
            d['magmoms'] = atoms.get_initial_magnetic_moments()
177
            
178
        if self.master:
179
            pickle.dump(d, self.fd, protocol=-1)
180
        self.fd.flush()
181
        self.offsets.append(self.fd.tell())
182

    
183
    def write_header(self, atoms):
184
        self.fd.write('PickleTrajectory')
185
        if atoms.has('tags'):
186
            tags = atoms.get_tags()
187
        else:
188
            tags = None
189
        if atoms.has('masses'):
190
            masses = atoms.get_masses()
191
        else:
192
            masses = None
193
        d = {'pbc': atoms.get_pbc(),
194
             'numbers': atoms.get_atomic_numbers(),
195
             'tags': tags,
196
             'masses': masses,
197
             'constraints': atoms.constraints}
198
        pickle.dump(d, self.fd, protocol=-1)
199
        self.header_written = True
200
        self.offsets.append(self.fd.tell())
201
        
202
    def close(self):
203
        """Close the trajectory file."""
204
        self.fd.close()
205

    
206
    def __getitem__(self, i=-1):
207
        N = len(self.offsets)
208
        if 0 <= i < N:
209
            self.fd.seek(self.offsets[i])
210
            try:
211
                d = pickle.load(self.fd)
212
            except EOFError:
213
                raise IndexError
214
            if i == N - 1:
215
                self.offsets.append(self.fd.tell())
216
            try:
217
                magmoms = d['magmoms']
218
            except KeyError:
219
                magmoms = None    
220
            atoms = Atoms(positions=d['positions'],
221
                          numbers=self.numbers,
222
                          cell=d['cell'],
223
                          momenta=d['momenta'],
224
                          magmoms=magmoms,
225
                          tags=self.tags,
226
                          masses=self.masses,
227
                          pbc=self.pbc,
228
                          constraint=[c.copy() for c in self.constraints])
229
            if 'energy' in d:
230
                calc = SinglePointCalculator(
231
                    d.get('energy', None), d.get('forces', None),
232
                    d.get('stress', None), magmoms, atoms)
233
                atoms.set_calculator(calc)
234
            return atoms
235

    
236
        if i >= N:
237
            for j in range(N - 1, i + 1):
238
                atoms = self[j]
239
            return atoms
240

    
241
        i = len(self) + i
242
        if i < 0:
243
            raise IndexError('Trajectory index out of range.')
244
        return self[i]
245

    
246
    def __len__(self):
247
        N = len(self.offsets) - 1
248
        while True:
249
            self.fd.seek(self.offsets[N])
250
            try:
251
                pickle.load(self.fd)
252
            except EOFError:
253
                return N
254
            self.offsets.append(self.fd.tell())
255
            N += 1
256

    
257
    def __iter__(self):
258
        del self.offsets[1:]
259
        return self
260

    
261
    def next(self):
262
        try:
263
            return self[len(self.offsets) - 1]
264
        except IndexError:
265
            raise StopIteration
266

    
267
    def guess_offsets(self):
268
        size = os.path.getsize(self.fd.name)
269

    
270
        while True:
271
            self.fd.seek(self.offsets[-1])
272
            try:
273
                pickle.load(self.fd)
274
            except:
275
                raise EOFError('Damaged trajectory file.')
276
            else:
277
                self.offsets.append(self.fd.tell())
278

    
279
            if self.offsets[-1] >= size:
280
                break
281

    
282
            if len(self.offsets) > 2:
283
                step1 = self.offsets[-1] - self.offsets[-2]
284
                step2 = self.offsets[-2] - self.offsets[-3]
285

    
286
                if step1 == step2:
287
                    m = int((size - self.offsets[-1]) / step1) - 1
288

    
289
                    while m > 1:
290
                        self.fd.seek(self.offsets[-1] + m * step1)
291
                        try:
292
                            pickle.load(self.fd)
293
                        except:
294
                            m = m / 2
295
                        else:
296
                            for i in range(m):
297
                                self.offsets.append(self.offsets[-1] + step1)
298
                            m = 0
299

    
300
        return
301

    
302
def read_trajectory(filename, index=-1):
303
    traj = PickleTrajectory(filename, mode='r')
304

    
305
    if isinstance(index, int):
306
        return traj[index]
307
    else:
308
        # Here, we try to read only the configurations we need to read
309
        # and len(traj) should only be called if we need to as it will
310
        # read all configurations!
311

    
312
        # XXX there must be a simpler way?
313
        step = index.step or 1
314
        if step > 0:
315
            start = index.start or 0
316
            if start < 0:
317
                start += len(traj)
318
            stop = index.stop or len(traj)
319
            if stop < 0:
320
                stop += len(traj)
321
        else:
322
            if index.start is None:
323
                start = len(traj) - 1
324
            else:
325
                start = index.start
326
                if start < 0:
327
                    start += len(traj)
328
            if index.stop is None:
329
                stop = -1
330
            else:
331
                stop = index.stop
332
                if stop < 0:
333
                    stop += len(traj)
334
                    
335
        return [traj[i] for i in range(start, stop, step)]
336

    
337
def write_trajectory(filename, images):
338
    """Write image(s) to trajectory.
339

340
    Write also energy, forces, and stress if they are already
341
    calculated."""
342

    
343
    traj = PickleTrajectory(filename, mode='w')
344

    
345
    if not isinstance(images, (list, tuple)):
346
        images = [images]
347
        
348
    for atoms in images:
349
        # Avoid potentially expensive calculations:
350
        calc = atoms.get_calculator()
351
        if calc is not None:
352
            if  hasattr(calc, 'calculation_required'):
353
                if calc.calculation_required(atoms, ['energy']):
354
                    traj.write_energy = False
355
                if calc.calculation_required(atoms, ['forces']):
356
                    traj.write_forces = False
357
                if calc.calculation_required(atoms, ['stress']):
358
                    traj.write_stress = False
359
                if calc.calculation_required(atoms, ['magmoms']):
360
                    traj.write_magmoms = False
361
        else:
362
            traj.write_energy = False
363
            traj.write_forces = False
364
            traj.write_stress = False
365
            traj.write_magmoms = False
366
            
367
        traj.write(atoms)
368
    traj.close()