Statistiques
| Révision :

root / ase / calculators / dacapo.py @ 19

Historique | Voir | Annoter | Télécharger (6,67 ko)

1
import numpy as np
2

    
3
from ase.old import OldASEListOfAtomsWrapper
4

    
5
try:
6
    import Numeric as num
7
except ImportError:
8
    pass
9

    
10
def np2num(a, typecode=None):
11
    if num.__version__ > '23.8':
12
        return num.array(a, typecode)
13
    if typecode is None:
14
        typecode = num.Float
15
    b = num.fromstring(a.tostring(), typecode)
16
    b.shape = a.shape
17
    return b
18

    
19
def restart(filename, **kwargs):
20
    calc = Dacapo(filename, **kwargs)
21
    atoms = calc.get_atoms()
22
    return atoms, calc
23

    
24
class Dacapo:
25
    def __init__(self, filename=None, stay_alive=False, stress=False,
26
                 **kwargs):
27

    
28
        self.kwargs = kwargs
29
        self.stay_alive = stay_alive
30
        self.stress = stress
31
        
32
        if filename is not None:
33
            from Dacapo import Dacapo
34
            self.loa = Dacapo.ReadAtoms(filename, **kwargs)
35
            self.calc = self.loa.GetCalculator()
36
        else:
37
            self.loa = None
38
            self.calc = None
39

    
40
        self.pps = []
41
        
42
    def set_pp(self, Z, path):
43
        self.pps.append((Z, path))
44

    
45
    def set_txt(self, txt):
46
        if self.calc is None:
47
            self.kwargs['txtout'] = txt
48
        else:
49
            self.calc.SetTxtFile(txt)
50

    
51
    def set_nc(self, nc):
52
        if self.calc is None:
53
            self.kwargs['out'] = nc
54
        else:
55
            self.calc.SetNetCDFFile(nc)
56

    
57
    def update(self, atoms):
58
        from Dacapo import Dacapo
59
        if self.calc is None:
60
            if 'nbands' not in self.kwargs:
61
                n = sum([valence[atom.symbol] for atom in atoms])
62
                self.kwargs['nbands'] = int(n * 0.65) + 4
63

    
64
            magmoms = atoms.get_initial_magnetic_moments()
65
            if magmoms.any():
66
                self.kwargs['spinpol'] = True
67

    
68
            self.calc = Dacapo(**self.kwargs)
69

    
70
            if self.stay_alive:
71
                self.calc.StayAliveOn()
72
            else:
73
                self.calc.StayAliveOff()
74

    
75
            if self.stress:
76
                self.calc.CalculateStress()
77

    
78
            for Z, path in self.pps:
79
                self.calc.SetPseudoPotential(Z, path)
80

    
81
        if self.loa is None:
82
            from ASE import Atom, ListOfAtoms
83
            numbers = atoms.get_atomic_numbers()
84
            positions = atoms.get_positions()
85
            magmoms = atoms.get_initial_magnetic_moments()
86
            self.loa = ListOfAtoms([Atom(Z=numbers[a],
87
                                         position=positions[a],
88
                                         magmom=magmoms[a])
89
                                    for a in range(len(atoms))],
90
                                   cell=np2num(atoms.get_cell()),
91
                                   periodic=tuple(atoms.get_pbc()))
92
            self.loa.SetCalculator(self.calc)
93
        else:
94
            self.loa.SetCartesianPositions(np2num(atoms.get_positions()))
95
            self.loa.SetUnitCell(np2num(atoms.get_cell()), fix=True)
96
            
97
    def get_atoms(self):
98
        atoms = OldASEListOfAtomsWrapper(self.loa).copy()
99
        atoms.set_calculator(self)
100
        return atoms
101
    
102
    def get_potential_energy(self, atoms):
103
        self.update(atoms)
104
        return self.calc.GetPotentialEnergy()
105

    
106
    def get_forces(self, atoms):
107
        self.update(atoms)
108
        return np.array(self.calc.GetCartesianForces())
109

    
110
    def get_stress(self, atoms):
111
        self.update(atoms)
112
        stress = np.array(self.calc.GetStress())
113
        if stress.ndim == 2:
114
            return stress.ravel()[[0, 4, 8, 5, 2, 1]]
115
        else:
116
            return stress
117

    
118
    def calculation_required(self, atoms, quantities):
119
        if self.calc is None:
120
            return True
121

    
122
        if atoms != self.get_atoms():
123
            return True
124

    
125
        return False
126
        
127
    def get_number_of_bands(self):
128
        return self.calc.GetNumberOfBands()
129

    
130
    def get_k_point_weights(self):
131
        return np.array(self.calc.GetIBZKPointWeights())
132

    
133
    def get_number_of_spins(self):
134
        return 1 + int(self.calc.GetSpinPolarized())
135

    
136
    def get_eigenvalues(self, kpt=0, spin=0):
137
        return np.array(self.calc.GetEigenvalues(kpt, spin))
138

    
139
    def get_fermi_level(self):
140
        return self.calc.GetFermiLevel()
141

    
142
    def get_number_of_grid_points(self):
143
        return np.array(self.get_pseudo_wave_function(0, 0, 0).shape)
144

    
145
    def get_pseudo_density(self, spin=0):
146
        return np.array(self.calc.GetDensityArray(s))
147
    
148
    def get_pseudo_wave_function(self, band=0, kpt=0, spin=0, pad=True):
149
        kpt_c = self.get_bz_k_points()[kpt]
150
        state = self.calc.GetElectronicStates().GetState(band=band, spin=spin,
151
                                                         kptindex=kpt)
152

    
153
        # Get wf, without bloch phase (Phase = True doesn't do anything!)
154
        wave = state.GetWavefunctionOnGrid(phase=False)
155

    
156
        # Add bloch phase if this is not the Gamma point
157
        if np.all(kpt_c == 0):
158
            return wave
159
        coord = state.GetCoordinates()
160
        phase = coord[0] * kpt_c[0] + coord[1] * kpt_c[1] + coord[2] * kpt_c[2]
161
        return np.array(wave) * np.exp(-2.j * np.pi * phase) # sign! XXX
162

    
163
        #return np.array(self.calc.GetWaveFunctionArray(n, k, s)) # No phase!
164

    
165
    def get_bz_k_points(self):
166
        return np.array(self.calc.GetBZKPoints())
167

    
168
    def get_ibz_k_points(self):
169
        return np.array(self.calc.GetIBZKPoints())
170

    
171
    def get_wannier_localization_matrix(self, nbands, dirG, kpoint,
172
                                        nextkpoint, G_I, spin):
173
        return np.array(self.calc.GetWannierLocalizationMatrix(
174
            G_I=G_I.tolist(), nbands=nbands, dirG=dirG.tolist(),
175
            kpoint=kpoint, nextkpoint=nextkpoint, spin=spin))
176
    
177
    def initial_wannier(self, initialwannier, kpointgrid, fixedstates,
178
                        edf, spin):
179
        # Use initial guess to determine U and C
180
        init = self.calc.InitialWannier(initialwannier, self.atoms,
181
                                        np2num(kpointgrid, num.Int))
182

    
183
        states = self.calc.GetElectronicStates()
184
        waves = [[state.GetWaveFunction()
185
                  for state in states.GetStatesKPoint(k, spin)]
186
                 for k in self.calc.GetIBZKPoints()] 
187

    
188
        init.SetupMMatrix(waves, self.calc.GetBZKPoints())
189
        c, U = init.GetListOfCoefficientsAndRotationMatrices(
190
            (self.calc.GetNumberOfBands(), fixedstates, edf))
191
        U = np.array(U)
192
        for k in range(len(c)):
193
            c[k] = np.array(c[k])
194
        return c, U
195

    
196
valence = {
197
'H':   1,
198
'B':   3,
199
'C':   4,
200
'N':   5,
201
'O':   6,
202
'Li':  1,
203
'Na':  1,
204
'K':   9,
205
'Mg':  8,
206
'Ca': 10,
207
'Sr': 10,
208
'Al':  3,
209
'Ga': 13,
210
'Sc': 11,
211
'Ti': 12,
212
'V':  13,
213
'Cr': 14,
214
'Mn':  7,
215
'Fe':  8,
216
'Co':  9,
217
'Ni': 10,
218
'Cu': 11,
219
'Zn': 12,
220
'Y':  11,
221
'Zr': 12,
222
'Nb': 13,
223
'Mo':  6,
224
'Ru':  8,
225
'Rh':  9,
226
'Pd': 10,
227
'Ag': 11,
228
'Cd': 12,
229
'Ir': 9,
230
'Pt': 10,
231
'Au': 11,
232
}