Statistiques
| Révision :

root / ase / utils / molecule_test.py @ 7

Historique | Voir | Annoter | Télécharger (9,49 ko)

1
#!/usr/bin/env python
2

    
3
"""This module defines extensible classes for running tests on molecules.
4

5
Use this to compare different calculators, XC functionals and so on by
6
calculating e.g. atomization energies and bond lengths across the
7
g2 database of small molecules.
8
"""
9

    
10
import os
11
import sys
12
import traceback
13

    
14
import numpy as np
15

    
16
from ase import PickleTrajectory, read
17
from ase.calculators.emt import EMT
18
from ase.data.molecules import molecule, atoms as g2_atoms, g1
19

    
20

    
21
class BatchTest:
    """Contains logic for looping over tests and file management.

    The wrapped ``test`` object must provide ``name``, ``dir`` and
    ``exceptions`` attributes and the methods ``get_filename``, ``setup``,
    ``run`` and ``retrieve``.  Status messages are written to ``self.txt``.
    """

    def __init__(self, test):
        """Wrap ``test``; status output goes to ``sys.stdout``."""
        self.test = test
        self.txt = sys.stdout # ?

    def run_single_test(self, formula):
        """Run the test on one formula unless its result file exists.

        An empty result file is created before the calculation starts and
        serves as the "already done" flag.  It is not removed when the
        calculation fails, so a failed formula is reported as skipped on
        the next invocation.

        Only exception types listed in ``self.test.exceptions`` are caught
        (and reported with a traceback); anything else propagates.
        """
        # Status line is completed by ' Skipped.', ' OK!' or ' Failed!'.
        self.txt.write('%s %s ...' % (self.test.name, formula))
        self.txt.flush()
        filename = self.test.get_filename(formula)
        if os.path.exists(filename):
            self.txt.write(' Skipped.\n')
            return
        try:
            open(filename, 'w').close() # Empty file
            system, calc = self.test.setup(formula)
            self.test.run(formula, system, filename)
            self.txt.write(' OK!\n')
            self.txt.flush()
        except self.test.exceptions:
            self.txt.write(' Failed!\n')
            traceback.print_exc(file=self.txt)
            self.txt.write('\n')
            self.txt.flush()

    def run(self, formulas):
        """Run a batch of tests.

        This will invoke the test method on each formula, printing
        status to stdout.

        Those formulas that already have test result files will
        be skipped."""

        # Create directories (including intermediate ones) if necessary
        directory = self.test.dir
        if directory and not os.path.isdir(directory):
            try:
                os.makedirs(directory)
            except OSError:
                # Something else may have created it in the meantime.
                if not os.path.isdir(directory):
                    raise

        for formula in formulas:
            self.run_single_test(formula)

    def collect(self, formulas, verbose=False):
        """Yield ``(formula, results)`` for previous calculations.

        Formulas whose results cannot be retrieved (IOError, RuntimeError
        or TypeError) are left out; with ``verbose`` the error and its
        traceback are written to ``self.txt``.
        """
        for formula in formulas:
            # Defined up front so the error report below never refers to
            # an unbound or stale name if get_filename itself fails.
            filename = None
            try:
                filename = self.test.get_filename(formula)
                results = self.test.retrieve(formula, filename)
                if verbose:
                    self.txt.write('Loaded: %s %s\n' % (formula, filename))
                yield formula, results
            except (IOError, RuntimeError, TypeError):
                # XXX which errors should we actually catch?
                if verbose:
                    self.txt.write('Error: %s [%s]\n' % (formula, filename))
                    traceback.print_exc(file=self.txt)
76

    
77

    
78
class MoleculeTest:
    """Generic class for running various tests on the g2 dataset.

    Usage: instantiate a MoleculeTest subclass with the desired test
    settings and pass it to a BatchTest, whose run() method is invoked on
    the desired formulas.

    This base class does not provide a calculator, a run method or a
    retrieve method: setup_calculator, run and retrieve all raise
    NotImplementedError and must be supplied by subclasses or mixins.
    Most methods can be overridden to provide highly customized
    behaviour.  """

    def __init__(self, name, vacuum=6.0, exceptions=None):
        """Create a molecule test.

        The name parameter will be part of all output files generated
        by this molecule test.  If name contains a '/' character, the
        preceding part will be interpreted as a directory in which to
        put files.

        The vacuum parameter is used to set the cell size.

        A tuple of exception types can be provided which will be
        caught during a batch of calculations.  Types not specified
        will be considered fatal."""

        # Everything before the last path separator ('' if there is none).
        self.dir = os.path.dirname(name)
        self.name = name
        self.vacuum = vacuum
        if exceptions is None:
            exceptions = ()
        self.exceptions = exceptions

    def setup_calculator(self, system, formula):
        """Create a new calculator.

        There is no default calculator: implementations must override
        this method.

        Raises NotImplementedError unless overridden."""
        raise NotImplementedError

    def setup_system(self, formula):
        """Create an Atoms object from the given formula.

        By default this will be loaded from the g2 database, setting
        the cell size by means of the molecule test's vacuum parameter."""
        system = molecule(formula)
        system.center(vacuum=self.vacuum)
        return system

    def setup(self, formula):
        """Build calculator and atoms objects.

        This will invoke the setup_calculator and setup_system methods
        and attach the calculator to the system.

        Returns the tuple (system, calculator)."""
        system = self.setup_system(formula)
        calc = self.setup_calculator(system, formula)
        system.set_calculator(calc)
        return system, calc

    def get_filename(self, formula, extension='traj'):
        """Returns the filename for a test result file.

        Default format is <name>.<formula>.traj

        The test may write other files, but this filename is used as a
        flag denoting whether the calculation has been done
        already."""
        return '.'.join([self.name, formula, extension])

    def run(self, formula, system, filename):
        """Perform the calculation and save the result to filename.

        Raises NotImplementedError unless overridden."""
        raise NotImplementedError

    def retrieve(self, formula, filename):
        """Retrieve results of previous calculation from file.

        There is no default implementation; this method must be
        overridden together with run so that it reads back whatever
        run has written.

        Raises NotImplementedError unless overridden."""
        raise NotImplementedError
157

    
158

    
159
class EnergyTest:
    """Mixin implementing run/retrieve for total energy calculations."""

    def run(self, formula, system, filename):
        """Calculate energy of specified system and save to file."""
        system.get_potential_energy()
        # Handing over an open file object won't create a .bak file.
        # Binary mode, because the trajectory holds pickled data.
        traj = PickleTrajectory(open(filename, 'wb'), 'w')
        try:
            traj.write(system)
        finally:
            traj.close()

    def retrieve(self, formula, filename):
        """Return the potential energy of the system stored in filename."""
        system = read(filename)
        return system.get_potential_energy()

    def calculate_atomization_energies(self, molecular_energies,
                                       atomic_energies):
        """Yield (formula, atomization energy) pairs.

        molecular_energies is an iterable of (formula, energy) pairs and
        atomic_energies is anything accepted by dict() that maps chemical
        symbols to energies.  The atomization energy is the molecular
        energy minus the sum of the energies of the constituent atoms.

        Formulas for which a KeyError occurs (e.g. an atomic energy is
        missing) are skipped silently.
        """
        atomic_energy_dict = dict(atomic_energies)
        for formula, molecular_energy in molecular_energies:
            try:
                symbols = molecule(formula).get_chemical_symbols()
                atomic_total = sum([atomic_energy_dict[s] for s in symbols])
            except KeyError:
                # Best effort: leave out molecules that cannot be resolved.
                continue
            yield formula, molecular_energy - atomic_total
185

    
186

    
187
class BondLengthTest:
    """Mixin implementing run/retrieve for dimer bond length fits."""

    def run(self, formula, system, filename):
        """Calculate bond length of a dimer.

        This will write the system to the trajectory file for five atomic
        separations from 96% to 104% of the initial separation, in steps
        of 2%, allowing determination of bond length by fitting.

        Raises ValueError if the system does not consist of two atoms.
        """
        if len(system) != 2:
            raise ValueError('Not a dimer')
        # Binary mode, because the trajectory holds pickled data.
        traj = PickleTrajectory(open(filename, 'wb'), 'w')
        try:
            pos = system.positions
            d = np.linalg.norm(pos[1] - pos[0])
            for x in range(-2, 3):
                system.set_distance(0, 1, d * (1.0 + x * 0.02))
                traj.write(system)
        finally:
            traj.close()

    def retrieve(self, formula, filename):
        """Fit a parabola to the stored energies and locate its extremum.

        Returns the tuple (distances, energies, d0, e0, polynomial) where
        distances and energies are the sampled values, polynomial holds
        the coefficients of the second order fit, d0 is the array of roots
        of its derivative (the fitted bond length) and e0 is the fitted
        energy at d0.
        """
        # Read the images once and reuse them for both arrays.
        images = [a for a in PickleTrajectory(filename, 'r')]
        distances = np.array([np.linalg.norm(a.positions[1] - a.positions[0])
                              for a in images])
        energies = np.array([a.get_potential_energy() for a in images])
        polynomial = np.polyfit(distances, energies, 2) # or maybe 3rd order?
        # With 3rd order it is not always obvious which root is right
        pderiv = np.polyder(polynomial, 1)
        d0 = np.roots(pderiv)
        # Evaluate the fitted polynomial, not the raw energy samples.
        e0 = np.polyval(polynomial, d0)
        return distances, energies, d0, e0, polynomial
216
    
217

    
218
class EMTTest(MoleculeTest):
    """Molecule test that uses the ASE EMT calculator."""

    def setup_calculator(self, system, calculator):
        """Return a new EMT calculator; both arguments are ignored."""
        emt_calculator = EMT()
        return emt_calculator
221

    
222

    
223
class EMTEnergyTest(EnergyTest, EMTTest):
    """Combines EnergyTest's run/retrieve with EMTTest's calculator."""
225

    
226

    
227
class EMTBondLengthTest(BondLengthTest, EMTTest):
    """Combines BondLengthTest's run/retrieve with EMTTest's calculator."""
229

    
230

    
231
def main():
    """Run the EMT energy and bond length tests and print a report.

    Only g1 molecules made up entirely of the listed supported elements
    are used.  Result files are written with the prefixes
    'testfiles/energy' and 'testfiles/bond'; formulas that already have
    a result file are skipped by the batch runner.
    """
    supported_elements = 'Ni, C, Pt, Ag, H, Al, O, N, Au, Pd, Cu'.split(', ')
    formulas = [formula for formula in g1
                if all(symbol in supported_elements
                       for symbol
                       in molecule(formula).get_chemical_symbols())]

    atoms = [symbol for symbol in g2_atoms if symbol in supported_elements]
    dimers = [formula for formula in formulas if len(molecule(formula)) == 2]

    name1 = 'testfiles/energy'
    name2 = 'testfiles/bond'
    test1 = BatchTest(EMTEnergyTest(name1, vacuum=3.0))
    test2 = BatchTest(EMTBondLengthTest(name2, vacuum=3.0))

    # Single-argument print(...) calls behave the same on Python 2 and 3.
    print('Energy test')
    print('-----------')
    test1.run(formulas + atoms)

    print('')
    print('Bond length test')
    print('----------------')
    test2.run(dimers)

    print('')
    print('Atomization energies')
    print('--------------------')
    atomic_energies = dict(test1.collect(atoms))
    molecular_energies = dict(test1.collect(formulas))
    for formula, energy in molecular_energies.items():
        system = molecule(formula)
        atomic = [atomic_energies[s] for s in system.get_chemical_symbols()]
        atomization_energy = energy - sum(atomic)
        print('%s %.02f' % (formula.rjust(10), atomization_energy))

    print('')
    print('Bond lengths')
    print('------------')
    for formula, (d_i, e_i, d0, e0, poly) in test2.collect(dimers):
        system = molecule(formula)
        bref = np.linalg.norm(system.positions[1] - system.positions[0])
        # d0 is the root array of the fit derivative; take its one entry
        # explicitly instead of formatting an array with '%f'.
        bond_length = np.ravel(d0)[0]
        print('%s %6.3f   g2ref = %2.3f'
              % (formula.rjust(10), bond_length, bref))
276

    
277
        
278
# Run the demonstration only when executed as a script, not on import.
if __name__ == '__main__':
    main()