Statistiques
| Révision :

root / ase / calculators / jacapo / validate.py @ 19

Historique | Voir | Annoter | Télécharger (8,76 ko)

1
import os
2
import numpy as np
3
'''
4
input validation module
5

6
provides functions to validate all input variables to the Jacapo calculator.
7
'''
8

    
9
###########################################3
10
### Validation functions
11
##########################################
12
def valid_int(x):
13
    return isinstance(x, int)
14

    
15
def valid_float(x):
16
    return isinstance(x, float)
17

    
18
def valid_int_or_float(x):
19
    return (isinstance(x, int) or isinstance(x, float))
20

    
21
def valid_boolean(x):
22
    return isinstance(x, bool)
23

    
24
def valid_str(x):
25
    return isinstance(x, str)
26

    
27
def valid_atoms(x):
28
    import ase
29
    return isinstance(x, ase.Atoms)
30

    
31
def valid_pw(x):
32
    return (valid_int_or_float(x) and x>0 and x<2000)
33

    
34
def valid_dw(x):
35
    return (valid_int_or_float(x) and x>0 and x<2000)
36

    
37
def valid_xc(x):
38
    return (x in ['PW91', 'PBE', 'revPBE', 'RPBE', 'VWN'])
39

    
40
def valid_nbands(x):
41
    return valid_int(x)
42

    
43
def valid_ft(x):
44
    return(valid_float, x)
45

    
46
def valid_spinpol(x):
47
    return valid_boolean(x)
48

    
49
def valid_fixmagmom(x):
50
    return valid_float(x)
51

    
52
def valid_symmetry(x):
53
    return valid_boolean(x)
54

    
55
def valid_calculate_stress(x):
56
    return valid_boolean(x)
57

    
58
def valid_kpts(x):
59
    if isinstance(x, str):
60
        return x in ['cc-6-1x1',
61
                     'cc-12-2x3',
62
                     'cc-18-sq3xsq3',
63
                     'cc-18-1x1',
64
                     'cc-54-sq3xsq3',
65
                     'cc-54-1x1',
66
                     'cc-162-1x1']
67
    x = np.array(x)
68
    #empty arg is no good
69
    if x.shape == ():
70
        return False
71
    #monkhorst-pack
72
    elif x.shape == (3,) and ((x.dtype == 'int32') or (x.dtype == 'int64')):
73
        return True
74
    #user-defined list
75
    elif x.shape[1] == 3 and (str(x.dtype))[0:7] == 'float64':
76
        return True
77
    else:
78
        return False
79

    
80
def valid_dipole(x):
81
    if valid_boolean(x):
82
        return True
83
    #dictionary passed in. we need to check the keys
84
    valid_keys = {'status':valid_boolean,
85
                  'mixpar':valid_float,
86
                  'initval':valid_float,
87
                  'adddipfield':valid_float,
88
                  'position':valid_float}
89
    for key in x:
90
        if key not in valid_keys:
91
            return False
92
        else:
93
            if x[key] is not None:
94
                if not valid_keys[key](x[key]):
95
                    return False
96
    return True
97

    
98
def valid_nc(x):
99
    #todo check for read/write access?
100
    return valid_str(x)
101

    
102
def valid_status(x):
103
    return valid_str(x)
104

    
105
def valid_pseudopotentials(x):
106
    #todo check that keys are symbols or numbers
107
    #todo check that psp files exist
108

    
109
    DACAPOPATH = os.environ.get('DACAPOPATH', None)
110
    if DACAPOPATH is None:
111
        raise Exception, 'No $DACAPOPATH found. please set it in .cshrc or .bashrc'
112

    
113
    from ase.data import chemical_symbols
114
    for key in x:
115
        if valid_str(key):
116
            if key not in chemical_symbols:
117
                return False
118
        elif not (valid_int(key) and key > 0 and key < 112):
119
            return False
120

    
121
        #now check for existence of psp files
122
        psp = x[key]
123
        if not (os.path.exists(psp)
124
                or os.path.exists(os.path.join(DACAPOPATH, psp))):
125
            return False
126
    return True
127

    
128
def valid_extracharge(x):
129
    return valid_float(x)
130

    
131
def valid_extpot(x):
132
    grids = get_fftgrid()
133
    if (x.shape == np.array(grids['soft'])).all():
134
        return True
135
    else:
136
        return False
137

    
138
def valid_ascii_debug(x):
139
    return (x in ['Off', 'MediumLevel', 'HighLevel'])
140

    
141
def valid_ncoutput(x):
142
    if x is None:
143
        return
144
    valid_keys = ['wf', 'cd', 'efp', 'esp']
145

    
146
    for key in x:
147
        if key not in valid_keys:
148
            return False
149
        else:
150
            if x[key] not in ['Yes', 'No']:
151
                return False
152
    return True
153

    
154
def valid_ados(x):
155
    if x is None:
156
        return
157
    valid_keys = ['energywindow',
158
                  'energywidth',
159
                  'npoints',
160
                  'cutoff']
161
    for key in x:
162
        if key not in valid_keys:
163
            print '%s not in %s' % (key, str(valid_keys))
164
            return False
165
        if key == 'energywindow':
166
            if not len(x['energywindow']) == 2:
167
                print '%s is bad' % key
168
                return False
169
        if key == 'energywidth':
170
            if not valid_float(x['energywidth']):
171
                print key, ' is bad'
172
                return False
173
        elif key == 'npoints':
174
            if not valid_int(x['npoints']):
175
                print key, ' is bad'
176
                return False
177
        elif key == 'cutoff':
178
            if not valid_float(x['cutoff']):
179
                print key, ' is bad'
180
                return False
181
    return True
182

    
183

    
184
def valid_decoupling(x):
185
    if x is None:
186
        return
187
    valid_keys = ['ngaussians', 'ecutoff', 'gausswidth']
188
    for key in x:
189
        if key not in valid_keys:
190
            return False
191
        elif key == 'ngaussians':
192
            if not valid_int(x[key]):
193
                print key
194
                return False
195
        elif key == 'ecutoff':
196
            if not valid_int_or_float(x[key]):
197
                return False
198
        elif key == 'gausswidth':
199
            if not valid_float(x[key]):
200
                print key, x[key]
201
                return False
202
    return True
203

    
204
def valid_external_dipole(x):
205
    if x is None:
206
        return
207
    if valid_float(x):
208
        return True
209

    
210
    valid_keys = ['value', 'position']
211

    
212
    for key in x:
213
        if key not in valid_keys:
214
            return False
215
        if key == 'value':
216
            if not valid_float(x['value']):
217
                return False
218
        elif key == 'position':
219
            if not valid_float(x['position']):
220
                return False
221
    return True
222

    
223
def valid_stay_alive(x):
224
    return valid_boolean(x)
225

    
226
def valid_fftgrid(x):
227
    valid_keys = ['soft', 'hard']
228
    for key in x:
229
        if key not in valid_keys:
230
            return False
231
        if x[key] is None:
232
            continue
233

    
234
        grid = np.array(x[key])
235
        if (grid.shape != (3,) and grid.dtype != 'int32'):
236
            return False
237
    return True
238

    
239
def valid_convergence(x):
240
    valid_keys = ['energy',
241
                  'density',
242
                  'occupation',
243
                  'maxsteps',
244
                  'maxtime']
245
    for key in x:
246
        if key not in valid_keys:
247
            return False
248
        if x[key] is None:
249
            continue
250
        if key == 'energy':
251
            if not valid_float(x[key]):
252
                return False
253
        elif key == 'density':
254
            if not valid_float(x[key]):
255
                return False
256
        elif key == 'occupation':
257
            if not valid_float(x[key]):
258
                return False
259
        elif key == 'maxsteps':
260
            if not valid_int(x[key]):
261
                return False
262
        elif key == 'maxtime':
263
            if not valid_int(x[key]):
264
                return False
265
    return True
266

    
267
def valid_charge_mixing(x):
268
    valid_keys = ['method',
269
                  'mixinghistory',
270
                  'mixingcoeff',
271
                  'precondition',
272
                  'updatecharge']
273

    
274
    for key in x:
275
        if key not in valid_keys:
276
            return False
277
        elif key == 'method':
278
            if x[key] not in ['Pulay']:
279
                return False
280
        elif key == 'mixinghistory':
281
            if not valid_int(x[key]):
282
                return False
283
        elif key == 'mixingcoeff':
284
            if not valid_float(x[key]):
285
                return False
286
        elif key == 'precondition':
287
            if x[key] not in ['Yes', 'No']:
288
                return False
289
        elif key == 'updatecharge':
290
            if x[key] not in ['Yes', 'No']:
291
                return False
292
    return True
293

    
294
def valid_electronic_minimization(x):
295
    valid_keys = ['method', 'diagsperband']
296
    for key in x:
297
        if key not in valid_keys:
298
            return False
299
        elif key == 'method':
300
            if x[key] not in ['resmin',
301
                              'eigsolve',
302
                              'rmm-diis']:
303
                return False
304
        elif key == 'diagsperband':
305
            if not valid_int(x[key]):
306
                return False
307
    return True
308

    
309
def valid_occupationstatistics(x):
310
    return (x in ['FermiDirac', 'MethfesselPaxton'])
311

    
312

    
313
def valid_mdos(x):
314
    return True
315

    
316
def valid_psp(x):
317
    valid_keys = ['sym','psp']
318
    if x is None:
319
        return True
320
    for key in x:
321
        if key not in valid_keys:
322
            return False
323
        if not valid_str(x[key]):
324
            return False
325
        if key == 'sym':
326
            from ase.data import chemical_symbols
327
            if key not in chemical_symbols:
328
                return False
329
        if key == 'psp':
330
            
331
            if os.path.exists(x['psp']):
332
                return True
333

    
334
            if os.path.exists(os.path.join(os.environ['DACAPOPATH'],
335
                                           x['psp'])):
336
                return True
337
            #psp not found
338
            return False