Statistiques
| Révision :

root / ase / optimize / test / __init__.py @ 1

Historique | Voir | Annoter | Télécharger (4,33 ko)

1
"""Define a helper function for running tests
2

3
The skeleton for making a new setup is as follows:
4

5
from ase.optimize.test import run_test
6

7
def get_atoms():
8
    return Atoms('H')
9

10
def get_calculator():
11
    return EMT()
12

13
run_test(get_atoms, get_calculator, 'Hydrogen')
14
"""
15
import matplotlib
16
matplotlib.rcParams['backend']="Agg"
17

    
18
from ase.optimize.bfgs import BFGS
19
from ase.optimize.lbfgs import LBFGS, LBFGSLineSearch
20
from ase.optimize.fire import FIRE
21
from ase.optimize.mdmin import MDMin
22
from ase.optimize.sciopt import SciPyFminCG
23
from ase.optimize.sciopt import SciPyFminBFGS
24
from ase.optimize.bfgslinesearch import BFGSLineSearch
25

    
26
from ase.parallel import rank, paropen
27

    
28
import matplotlib.pyplot as pl
29
import numpy as np
30

    
31
import traceback
32

    
33
optimizers = [
34
    'BFGS',
35
    'LBFGS',
36
    'LBFGSLineSearch',
37
    'FIRE',
38
    'MDMin',
39
    'SciPyFminCG',
40
    'SciPyFminBFGS',
41
    'BFGSLineSearch'
42
]
43

    
44
def get_optimizer(optimizer):
45
    if optimizer == 'BFGS': return BFGS
46
    elif optimizer == 'LBFGS': return LBFGS
47
    elif optimizer == 'LBFGSLineSearch': return LBFGSLineSearch
48
    elif optimizer == 'FIRE': return FIRE
49
    elif optimizer == 'MDMin': return MDMin
50
    elif optimizer == 'SciPyFminCG': return SciPyFminCG
51
    elif optimizer == 'SciPyFminBFGS': return SciPyFminBFGS
52
    elif optimizer == 'BFGSLineSearch': return BFGSLineSearch
53

    
54
def run_test(get_atoms, get_calculator, name,
55
             fmax=0.05, steps=100, plot=True):
56

    
57
    plotter = Plotter(name, fmax)
58
    csvwriter = CSVWriter(name)
59
    for optimizer in optimizers:
60
        note = ''
61
        logname = name + '-' + optimizer
62

    
63
        atoms = get_atoms()
64
        atoms.set_calculator(get_calculator())
65
        opt = get_optimizer(optimizer)
66
        relax = opt(atoms, logfile=None)
67
                    #logfile = logname + '.log',
68
                    #trajectory = logname + '.traj')
69

    
70
        obs = DataObserver(atoms)
71
        relax.attach(obs)
72
        try:
73
            relax.run(fmax = fmax, steps = steps)
74
            E = atoms.get_potential_energy()
75

    
76
            if relax.get_number_of_steps() == steps:
77
                note = 'Not converged in %i steps' % steps
78
        except Exception:
79
            traceback.print_exc()
80
            note = 'An exception occurred'
81
            E = np.nan
82

    
83
        nsteps = relax.get_number_of_steps()
84
        if hasattr(relax, 'force_calls'):
85
            fc = relax.force_calls
86
            if rank == 0:
87
                print '%-15s %-15s %3i %8.3f (%3i) %s' % (name, optimizer, nsteps, E, fc, note)
88
        else:
89
            fc = nsteps
90
            if rank == 0:
91
                print '%-15s %-15s %3i %8.3f       %s' % (name, optimizer, nsteps, E, note)
92

    
93
        plotter.plot(optimizer, obs.get_E(), obs.get_fmax())
94
        csvwriter.write(optimizer, nsteps, E, fc, note)
95

    
96
    plotter.save()
97
    csvwriter.finalize()
98

    
99
class Plotter:
100
    def __init__(self, name, fmax):
101
        self.name = name
102
        self.fmax = fmax
103
        if rank == 0: 
104
            self.fig = pl.figure(figsize=[12.0, 9.0])
105
            self.axes0 = self.fig.add_subplot(2, 1, 1)
106
            self.axes1 = self.fig.add_subplot(2, 1, 2)
107

    
108
    def plot(self, optimizer, E, fmax):
109
        if rank == 0:
110
            self.axes0.plot(E, label = optimizer)
111
            self.axes1.plot(fmax)
112

    
113
    def save(self, format='png'):
114
        if rank == 0:
115
            self.axes0.legend()
116
            self.axes0.set_title(self.name)
117
            self.axes0.set_ylabel('E [eV]')
118
            #self.axes0.set_yscale('log')
119

    
120
            self.axes1.set_xlabel('steps')
121
            self.axes1.set_ylabel('fmax [eV/A]')
122
            self.axes1.set_yscale('log')
123
            self.axes1.axhline(self.fmax, color='k', linestyle='--')
124
            self.fig.savefig(self.name + '.' + format)
125

    
126
class CSVWriter:
127
    def __init__(self, name):
128
        self.f = paropen(name + '.csv', 'w')
129

    
130
    def write(self, optimizer, nsteps, E, fc, note=''):
131
        self.f.write(
132
            '%s,%i,%i,%f,%s\n' % (optimizer, nsteps, fc, E, note)
133
        )
134

    
135
    def finalize(self):
136
        self.f.close()
137

    
138
class DataObserver:
139
    def __init__(self, atoms):
140
        self.atoms = atoms
141
        self.E = []
142
        self.fmax = []
143

    
144
    def __call__(self):
145
        self.E.append(self.atoms.get_potential_energy())
146
        self.fmax.append(np.sqrt((self.atoms.get_forces()**2).sum(axis=1)).max())
147

    
148
    def get_E(self):
149
        return np.array(self.E)
150

    
151
    def get_fmax(self):
152
        return np.array(self.fmax)