Statistiques
| Branche: | Tag: | Révision :

dockonsurf / modules / dos_input.py @ a44ad3c2

Historique | Voir | Annoter | Télécharger (38,61 ko)

1
"""Functions to deal with DockOnSurf input files.
2

3
Functions
4
try_command:Tries to run a command and logs its exceptions (expected and not).
5
str2lst: Converts a string of integers, and groups of them, to a list of lists.
6
check_expect_val: Checks whether the value of an option has an adequate value.
7
read_input: Sets up the calculation by reading the parameters from input file.
8
get_run_type: Gets 'run_type' value and checks that its value is acceptable.
9
get_code: Gets 'code' value and checks that its value is acceptable.
10
get_batch_q_sys: Gets 'batch_q_sys' value and checks that its value is
11
acceptable.
12
get_relaunch_err: Gets 'relaunch_err' value and checks that its value is
13
acceptable.
14
get_max_jobs: Gets 'max_jobs' value and checks that its value is acceptable.
15
get_special_atoms: Gets 'special_atoms' value and checks that its value is
16
acceptable.
17
get_isol_inp_file: Gets 'isol_inp_file' value and checks that its value is
18
acceptable.
19
get_cluster_magns: Gets 'cluster_magns' value and checks that its value is
20
acceptable.
21
get_num_conformers: Gets 'num_conformers' value and checks that its value is
22
acceptable.
23
get_num_prom_cand: Gets 'num_prom_cand' value and checks that its value is
24
acceptable.
25
get_iso_rmsd: Gets 'iso_rmsd' value and checks that its value is acceptable.
26
get_pre_opt: Gets 'pre_opt' value and checks that its value is acceptable.
27
get_screen_inp_file: Gets 'screen_inp_file' value and checks that its value is
28
acceptable.
29
get_sites: Gets 'sites' value and checks that its value is acceptable.
30
get_molec_ctrs: Gets 'molec_ctrs' value and checks that its value is
31
acceptable.
32
get_try_disso: Gets 'try_disso' value and checks that its value is acceptable.
33
get_pts_per_angle: Gets 'pts_per_angle' value and checks that its value is
34
acceptable.
35
get_coll_thrsld: Gets 'coll_thrsld' value and checks that its value is
36
acceptable.
37
get_screen_rmsd: Gets 'screen_rmsd' value and checks that its value is
38
acceptable.
39
get_coll_bottom_z: Gets 'coll_bottom_z' value and checks that its value is
40
acceptable.
41
get_refine_inp_file: Gets 'refine_inp_file' value and checks that its value is
42
acceptable.
43
get_energy_cutoff: Gets 'energy_cutoff' value and checks that its value is
44
acceptable.
45
"""
46
import os.path
47
import logging
48
from configparser import ConfigParser, NoSectionError, NoOptionError, \
49
    MissingSectionHeaderError, DuplicateOptionError
50
import numpy as np
51
from modules.utilities import try_command
52

    
53
logger = logging.getLogger('DockOnSurf')
54

    
55
dos_inp = ConfigParser(inline_comment_prefixes='#',
56
                       empty_lines_in_values=False)
57

    
58
new_answers = {'n': False, 'none': False, 'nay': False,
59
               'y': True, '': True, 'aye': True, 'sure': True}
60
for answer, val in new_answers.items():
61
    dos_inp.BOOLEAN_STATES[answer] = val  # TODO Check value 0
62
turn_false_answers = [answer for answer in dos_inp.BOOLEAN_STATES
63
                      if dos_inp.BOOLEAN_STATES[answer] is False]
64
turn_true_answers = [answer for answer in dos_inp.BOOLEAN_STATES
65
                     if dos_inp.BOOLEAN_STATES[answer]]
66

    
67
no_sect_err = "Section '%s' not found on input file"
68
no_opt_err = "Option '%s' not found on section '%s'"
69
num_error = "'%s' value must be a %s"
70

    
71

    
72
# Auxilary functions
73

    
74
def str2lst(cmplx_str, func=int):  # TODO: enable deeper level of nested lists
75
    # TODO Treat all-enclosing parenthesis as a list instead of list of lists.
76
    """Converts a string of integers, and groups of them, to a list.
77

78
    Keyword arguments:
79
    @param cmplx_str: str, string of integers (or floats) and groups of them
80
    enclosed by parentheses-like characters.
81
    - Group enclosers: '()' '[]' and '{}'.
82
    - Separators: ',' ';' and ' '.
83
    - Nested groups are not allowed: '3 ((6 7) 8) 4'.
84
    @param func: either to use int or float
85

86
    @return list, list of integers (or floats), or list of integers (or floats)
87
    in the case they were grouped. First, the singlets are placed, and then the
88
    groups in input order.
89

90
    eg. '128,(135 138;141] 87 {45, 68}' -> [128, 87, [135, 138, 141], [45, 68]]
91
    """
92

    
93
    # Checks
94
    error_msg = "Function argument should be a str, sequence of " \
95
                "numbers separated by ',' ';' or ' '." \
96
                "\nThey can be grouped in parentheses-like enclosers: '()', " \
97
                "'[]' or {}. Nested groups are not allowed. \n" \
98
                "eg. 128,(135 138;141) 87 {45, 68}"
99
    cmplx_str = try_command(cmplx_str.replace, [(AttributeError, error_msg)],
100
                            ',', ' ')
101

    
102
    cmplx_str = cmplx_str.replace(';', ' ').replace('[', '(').replace(
103
        ']', ')').replace('{', '(').replace('}', ')')
104

    
105
    try_command(list, [(ValueError, error_msg)], map(func, cmplx_str.replace(
106
        ')', '').replace('(', '').split()))
107

    
108
    deepness = 0
109
    for el in cmplx_str.split():
110
        if '(' in el:
111
            deepness += 1
112
        if ')' in el:
113
            deepness += -1
114
        if deepness > 1 or deepness < 0:
115
            logger.error(error_msg)
116
            raise ValueError(error_msg)
117

    
118
    init_list = cmplx_str.split()
119
    start_group = []
120
    end_group = []
121
    for i, element in enumerate(init_list):
122
        if '(' in element:
123
            start_group.append(i)
124
            init_list[i] = element.replace('(', '')
125
        if ')' in element:
126
            end_group.append(i)
127
            init_list[i] = element.replace(')', '')
128

    
129
    init_list = list(map(func, init_list))
130

    
131
    new_list = []
132
    for start_el, end_el in zip(start_group, end_group):
133
        new_list.append(init_list[start_el:end_el + 1])
134

    
135
    for v in new_list:
136
        for el in v:
137
            init_list.remove(el)
138
    return init_list + new_list
139

    
140

    
141
def check_expect_val(value, expect_vals, err_msg=None):
142
    """Checks whether an option lies within its expected values.
143

144
    Keyword arguments:
145
    @param value: The variable to check if its value lies within the expected
146
    ones
147
    @param expect_vals: list, list of values allowed for the present option.
148
    @param err_msg: The error message to be prompted in both log and screen.
149
    @raise ValueError: if the value is not among the expected ones.
150
    @return True if the value is among the expected ones.
151
    """
152
    if err_msg is None:
153
        err_msg = f"'{value}' is not an adequate value.\n" \
154
                  f"Adequate values: {expect_vals}"
155
    if not any([exp_val == value for exp_val in expect_vals]):
156
        logger.error(err_msg)
157
        raise ValueError(err_msg)
158

    
159
    return True
160

    
161

    
162
def check_inp_file(inp_files, code):
163
    if code == 'cp2k':
164
        from pycp2k import CP2K
165
        if not isinstance(inp_files, str):
166
            err_msg = "When using CP2K, only one input file is allowed"
167
            logger.error(err_msg)
168
            ValueError(err_msg)
169
        elif not os.path.isfile(inp_files):
170
            err_msg = f"Input file {inp_files} was not found."
171
            logger.error(err_msg)
172
            raise FileNotFoundError(err_msg)
173
        cp2k = CP2K()
174
        try_command(cp2k.parse,
175
                    [(UnboundLocalError, "Invalid CP2K input file")], inp_files)
176
    elif code == "vasp":
177
        mand_files = ["INCAR", "KPOINTS", "POTCAR"]
178
        # Check that it inp_files is a list of file paths
179
        if not isinstance(inp_files, list) and all(isinstance(inp_file, str)
180
                                                   for inp_file in inp_files):
181
            err_msg = "'inp_files' should be a list of file names/paths"
182
            logger.error(err_msg)
183
            ValueError(err_msg)
184
        # Check that all mandatory files are defined once and just once.
185
        elif [[mand_file in inp_file for inp_file in inp_files].count(True)
186
              for mand_file in mand_files].count(1) != len(mand_files):
187
            err_msg = f"Each of the mandatory files {mand_files} must be " \
188
                      f"defined once and just once."
189
            logger.error(err_msg)
190
            raise FileNotFoundError(err_msg)
191
        # Check that the defined files exist
192
        elif any(not os.path.isfile(inp_file) for inp_file in inp_files):
193
            err_msg = f"At least one of the mandatory files {mand_files} was " \
194
                      "not found."
195
            logger.error(err_msg)
196
            raise FileNotFoundError(err_msg)
197
        # Check that mandatory files are actual vasp files.
198
        else:
199
            from pymatgen.io.vasp.inputs import Incar, Kpoints, Potcar
200
            for inp_file in inp_files:
201
                file_name = inp_file.split("/")[-1]
202
                if not any(mand_file in file_name for mand_file in mand_files):
203
                    continue
204
                file_type = ""
205
                for mand_file in mand_files:
206
                    if mand_file in inp_file:
207
                        file_type = mand_file
208
                err = False
209
                err_msg = f"'{inp_file}' is not a valid {file_name} file."
210
                try:
211
                    eval(file_type.capitalize()).from_file(inp_file)
212
                except ValueError:
213
                    logger.error(err_msg)
214
                    err = ValueError(err_msg)
215
                except IndexError:
216
                    logger.error(err_msg)
217
                    err = IndexError(err_msg)
218
                else:
219
                    if file_name == "INCAR":
220
                        Incar.from_file("INCAR").check_params()
221
                finally:
222
                    if isinstance(err, BaseException):
223
                        raise err
224

    
225

    
226
# Global
227

    
228
def get_run_type():
229
    isolated, screening, refinement = (False, False, False)
230
    run_type_vals = ['isolated', 'screening', 'refinement', 'adsorption',
231
                     'full']
232
    run_types = dos_inp.get('Global', 'run_type').split()
233
    for run_type in run_types:
234
        check_expect_val(run_type.lower(), run_type_vals)
235
        if 'isol' in run_type.lower():
236
            isolated = True
237
        if 'screen' in run_type.lower():
238
            screening = True
239
        if 'refine' in run_type.lower():
240
            refinement = True
241
        if 'adsor' in run_type.lower():
242
            screening, refinement = (True, True)
243
        if 'full' in run_type.lower():
244
            isolated, screening, refinement = (True, True, True)
245

    
246
    return isolated, screening, refinement
247

    
248

    
249
def get_code():
250
    code_vals = ['cp2k', 'vasp']
251
    check_expect_val(dos_inp.get('Global', 'code').lower(), code_vals)
252
    code = dos_inp.get('Global', 'code').lower()
253
    return code
254

    
255

    
256
def get_batch_q_sys():
257
    batch_q_sys_vals = ['sge', 'lsf', 'irene', 'local'] + turn_false_answers
258
    check_expect_val(dos_inp.get('Global', 'batch_q_sys').lower(),
259
                     batch_q_sys_vals)
260
    batch_q_sys = dos_inp.get('Global', 'batch_q_sys').lower()
261
    if batch_q_sys.lower() in turn_false_answers:
262
        return False
263
    else:
264
        return batch_q_sys
265

    
266

    
267
def get_pbc_cell():
268
    err_msg = "'pbc_cell' must be either 3 vectors of size 3 or False."
269
    pbc_cell_str = dos_inp.get('Global', 'pbc_cell', fallback="False")
270
    if pbc_cell_str.lower() in turn_false_answers:
271
        return False
272
    else:
273
        pbc_cell = np.array(try_command(str2lst, [(ValueError, err_msg)],
274
                                        pbc_cell_str, float))
275
        if pbc_cell.shape != (3, 3):
276
            logger.error(err_msg)
277
            raise ValueError(err_msg)
278
        if np.linalg.det(pbc_cell) == 0.0:
279
            err_msg = "The volume of the defined cell is 0"
280
            logger.error(err_msg)
281
            raise ValueError(err_msg)
282
        return pbc_cell
283

    
284

    
285
def get_subm_script():
286
    subm_script = dos_inp.get('Global', 'subm_script')
287
    if not os.path.isfile(subm_script):
288
        logger.error(f'File {subm_script} not found.')
289
        raise FileNotFoundError(f'File {subm_script} not found')
290
    return subm_script
291

    
292

    
293
def get_project_name():
294
    project_name = dos_inp.get('Global', 'project_name', fallback='')
295
    return project_name
296

    
297

    
298
def get_relaunch_err():
299
    relaunch_err_vals = ['geo_not_conv']
300
    relaunch_err = dos_inp.get('Global', 'relaunch_err',
301
                               fallback="False")
302
    if relaunch_err.lower() in turn_false_answers:
303
        return False
304
    else:
305
        check_expect_val(relaunch_err.lower(), relaunch_err_vals)
306
    return relaunch_err
307

    
308

    
309
def get_max_jobs():
310
    import re
311
    err_msg = "'max_jobs' must be a list of, number plus 'p', 'q' or 'r', or " \
312
              "a combination of them without repeating letters.\n" \
313
              "eg: '2r 3p 4pr', '5q' or '3r 3p'"
314
    max_jobs_str = dos_inp.get('Global', 'max_jobs', fallback="inf").lower()
315
    str_vals = ["r", "p", "q", "rp", "rq", "pr", "qr"]
316
    max_jobs = {"r": np.inf, "p": np.inf, "rp": np.inf}
317
    if "inf" == max_jobs_str:
318
        return {"r": np.inf, "p": np.inf, "rp": np.inf}
319
    # Iterate over the number of requirements:
320
    for req in max_jobs_str.split():
321
        # Split numbers from letters into a list
322
        req_parts = re.findall(r'[a-z]+|\d+', req)
323
        if len(req_parts) != 2:
324
            logger.error(err_msg)
325
            raise ValueError(err_msg)
326
        if req_parts[0].isdecimal():
327
            req_parts[1] = req_parts[1].replace('q', 'p').replace('pr', 'rp')
328
            if req_parts[1] in str_vals and max_jobs[req_parts[1]] == np.inf:
329
                max_jobs[req_parts[1]] = int(req_parts[0])
330
        elif req_parts[1].isdecimal():
331
            req_parts[0] = req_parts[0].replace('q', 'p').replace('pr', 'rp')
332
            if req_parts[0] in str_vals and max_jobs[req_parts[0]] == np.inf:
333
                max_jobs[req_parts[0]] = int(req_parts[1])
334
        else:
335
            logger.error(err_msg)
336
            raise ValueError(err_msg)
337

    
338
    return max_jobs
339

    
340

    
341
def get_special_atoms():
342
    from ase.data import chemical_symbols
343

    
344
    spec_at_err = '\'special_atoms\' does not have an adequate format.\n' \
345
                  'Adequate format: (Fe1 Fe) (O1 O)'
346
    special_atoms = dos_inp.get('Global', 'special_atoms', fallback="False")
347
    if special_atoms.lower() in turn_false_answers:
348
        special_atoms = []
349
    else:
350
        # Converts the string into a list of tuples
351
        lst_tple = [tuple(pair.replace("(", "").split()) for pair in
352
                    special_atoms.split(")")[:-1]]
353
        if len(lst_tple) == 0:
354
            logger.error(spec_at_err)
355
            raise ValueError(spec_at_err)
356
        for i, tup in enumerate(lst_tple):
357
            if not isinstance(tup, tuple) or len(tup) != 2:
358
                logger.error(spec_at_err)
359
                raise ValueError(spec_at_err)
360
            if tup[1].capitalize() not in chemical_symbols:
361
                elem_err = "The second element of the couple should be an " \
362
                           "actual element of the periodic table"
363
                logger.error(elem_err)
364
                raise ValueError(elem_err)
365
            if tup[0].capitalize() in chemical_symbols:
366
                elem_err = "The first element of the couple is already an " \
367
                           "actual element of the periodic table, "
368
                logger.error(elem_err)
369
                raise ValueError(elem_err)
370
            for j, tup2 in enumerate(lst_tple):
371
                if j <= i:
372
                    continue
373
                if tup2[0] == tup[0]:
374
                    label_err = f'You have specified the label {tup[0]} to ' \
375
                                f'more than one special atom'
376
                    logger.error(label_err)
377
                    raise ValueError(label_err)
378
        special_atoms = lst_tple
379
    return special_atoms
380

    
381

    
382
# Isolated
383

    
384
def get_isol_inp_file(code):  # TODO allow spaces in path names
385
    inp_file_lst = dos_inp.get('Isolated', 'isol_inp_file').split()
386
    check_inp_file(inp_file_lst[0] if len(inp_file_lst) == 1 else inp_file_lst,
387
                   code)
388
    return inp_file_lst[0] if len(inp_file_lst) == 1 else inp_file_lst
389

    
390

    
391
def get_molec_file():
392
    molec_file = dos_inp.get('Isolated', 'molec_file')
393
    if not os.path.isfile(molec_file):
394
        logger.error(f'File {molec_file} not found.')
395
        raise FileNotFoundError(f'File {molec_file} not found')
396
    return molec_file
397

    
398

    
399
def get_num_conformers():
400
    err_msg = num_error % ('num_conformers', 'positive integer')
401
    num_conformers = try_command(dos_inp.getint, [(ValueError, err_msg)],
402
                                 'Isolated', 'num_conformers', fallback=100)
403
    if num_conformers < 1:
404
        logger.error(err_msg)
405
        raise ValueError(err_msg)
406
    return num_conformers
407

    
408

    
409
def get_pre_opt():
410
    pre_opt_vals = ['uff', 'mmff'] + turn_false_answers
411
    check_expect_val(dos_inp.get('Isolated', 'pre_opt').lower(), pre_opt_vals)
412
    pre_opt = dos_inp.get('Isolated', 'pre_opt').lower()
413
    if pre_opt in turn_false_answers:
414
        return False
415
    else:
416
        return pre_opt
417

    
418

    
419
# Screening
420

    
421
def get_screen_inp_file(code):  # TODO allow spaces in path names
422
    inp_file_lst = dos_inp.get('Screening', 'screen_inp_file').split()
423
    check_inp_file(inp_file_lst[0] if len(inp_file_lst) == 1 else inp_file_lst,
424
                   code)
425
    return inp_file_lst[0] if len(inp_file_lst) == 1 else inp_file_lst
426

    
427

    
428
def get_surf_file():
429
    surf_file = dos_inp.get('Screening', 'surf_file')
430
    if not os.path.isfile(surf_file):
431
        logger.error(f'File {surf_file} not found.')
432
        raise FileNotFoundError(f'File {surf_file} not found')
433
    return surf_file
434

    
435

    
436
def get_sites():
437
    err_msg = 'The value of sites should be a list of atom numbers ' \
438
              '(ie. positive integers) or groups of atom numbers ' \
439
              'grouped by parentheses-like enclosers. \n' \
440
              'eg. 128,(135 138;141) 87 {45, 68}'
441
    # Convert the string into a list of lists
442
    sites = try_command(str2lst,
443
                        [(ValueError, err_msg), (AttributeError, err_msg)],
444
                        dos_inp.get('Screening', 'sites'))
445
    # Check all elements of the list (of lists) are positive integers
446
    for site in sites:
447
        if type(site) is list:
448
            for atom in site:
449
                if atom < 0:
450
                    logger.error(err_msg)
451
                    raise ValueError(err_msg)
452
        elif type(site) is int:
453
            if site < 0:
454
                logger.error(err_msg)
455
                raise ValueError(err_msg)
456
        else:
457
            logger.error(err_msg)
458
            raise ValueError(err_msg)
459

    
460
    return sites
461

    
462

    
463
def get_surf_ctrs2():
464
    err_msg = 'The value of surf_ctrs2 should be a list of atom numbers ' \
465
              '(ie. positive integers) or groups of atom numbers ' \
466
              'grouped by parentheses-like enclosers. \n' \
467
              'eg. 128,(135 138;141) 87 {45, 68}'
468
    # Convert the string into a list of lists
469
    surf_ctrs2 = try_command(str2lst,
470
                             [(ValueError, err_msg), (AttributeError, err_msg)],
471
                             dos_inp.get('Screening', 'surf_ctrs2'))
472
    # Check all elements of the list (of lists) are positive integers
473
    for ctr in surf_ctrs2:
474
        if type(ctr) is list:
475
            for atom in ctr:
476
                if atom < 0:
477
                    logger.error(err_msg)
478
                    raise ValueError(err_msg)
479
        elif type(ctr) is int:
480
            if ctr < 0:
481
                logger.error(err_msg)
482
                raise ValueError(err_msg)
483
        else:
484
            logger.error(err_msg)
485
            raise ValueError(err_msg)
486

    
487
    return surf_ctrs2
488

    
489

    
490
def get_molec_ctrs():
491
    err_msg = 'The value of molec_ctrs should be a list of atom' \
492
              ' numbers (ie. positive integers) or groups of atom ' \
493
              'numbers enclosed by parentheses-like characters. \n' \
494
              'eg. 128,(135 138;141) 87 {45, 68}'
495
    # Convert the string into a list of lists
496
    molec_ctrs = try_command(str2lst,
497
                             [(ValueError, err_msg),
498
                              (AttributeError, err_msg)],
499
                             dos_inp.get('Screening', 'molec_ctrs'))
500
    # Check all elements of the list (of lists) are positive integers
501
    for ctr in molec_ctrs:
502
        if isinstance(ctr, list):
503
            for atom in ctr:
504
                if atom < 0:
505
                    logger.error(err_msg)
506
                    raise ValueError(err_msg)
507
        elif isinstance(ctr, int):
508
            if ctr < 0:
509
                logger.error(err_msg)
510
                raise ValueError(err_msg)
511
        else:
512
            logger.error(err_msg)
513
            raise ValueError(err_msg)
514

    
515
    return molec_ctrs
516

    
517

    
518
def get_molec_ctrs2():
519
    err_msg = 'The value of molec_ctrs2 should be a list of atom ' \
520
              'numbers (ie. positive integers) or groups of atom ' \
521
              'numbers enclosed by parentheses-like characters. \n' \
522
              'eg. 128,(135 138;141) 87 {45, 68}'
523
    # Convert the string into a list of lists
524
    molec_ctrs2 = try_command(str2lst, [(ValueError, err_msg),
525
                                        (AttributeError, err_msg)],
526
                              dos_inp.get('Screening', 'molec_ctrs2'))
527

    
528
    # Check all elements of the list (of lists) are positive integers
529
    for ctr in molec_ctrs2:
530
        if isinstance(ctr, list):
531
            for atom in ctr:
532
                if atom < 0:
533
                    logger.error(err_msg)
534
                    raise ValueError(err_msg)
535
        elif isinstance(ctr, int):
536
            if ctr < 0:
537
                logger.error(err_msg)
538
                raise ValueError(err_msg)
539
        else:
540
            logger.error(err_msg)
541
            raise ValueError(err_msg)
542

    
543
    return molec_ctrs2
544

    
545

    
546
def get_molec_ctrs3():
547
    err_msg = 'The value of molec_ctrs3 should be a list of atom ' \
548
              'numbers (ie. positive integers) or groups of atom ' \
549
              'numbers enclosed by parentheses-like characters. \n' \
550
              'eg. 128,(135 138;141) 87 {45, 68}'
551
    # Convert the string into a list of lists
552
    molec_ctrs3 = try_command(str2lst, [(ValueError, err_msg),
553
                                        (AttributeError, err_msg)],
554
                              dos_inp.get('Screening', 'molec_ctrs3'))
555

    
556
    # Check all elements of the list (of lists) are positive integers
557
    for ctr in molec_ctrs3:
558
        if isinstance(ctr, list):
559
            for atom in ctr:
560
                if atom < 0:
561
                    logger.error(err_msg)
562
                    raise ValueError(err_msg)
563
        elif isinstance(ctr, int):
564
            if ctr < 0:
565
                logger.error(err_msg)
566
                raise ValueError(err_msg)
567
        else:
568
            logger.error(err_msg)
569
            raise ValueError(err_msg)
570

    
571
    return molec_ctrs3
572

    
573

    
574
def get_max_helic_angle():
575
    err_msg = "'max_helic_angle' must be a positive number in degrees"
576
    max_helic_angle = try_command(dos_inp.getfloat, [(ValueError, err_msg)],
577
                                  'Screening', 'max_helic_angle',
578
                                  fallback=180.0)
579
    if max_helic_angle < 0:
580
        logger.error(err_msg)
581
        raise ValueError(err_msg)
582

    
583
    return max_helic_angle
584

    
585

    
586
def get_select_magns():
587
    select_magns_vals = ['energy', 'moi']
588
    select_magns_str = dos_inp.get('Screening', 'select_magns',
589
                                   fallback='moi')
590
    select_magns_str.replace(',', ' ').replace(';', ' ')
591
    select_magns = select_magns_str.split(' ')
592
    select_magns = [m.lower() for m in select_magns]
593
    for m in select_magns:
594
        check_expect_val(m, select_magns_vals)
595
    return select_magns
596

    
597

    
598
def get_confs_per_magn():
599
    err_msg = num_error % ('confs_per_magn', 'positive integer')
600
    confs_per_magn = try_command(dos_inp.getint, [(ValueError, err_msg)],
601
                                 'Screening', 'confs_per_magn', fallback=2)
602
    if confs_per_magn <= 0:
603
        logger.error(err_msg)
604
        raise ValueError(err_msg)
605
    return confs_per_magn
606

    
607

    
608
def get_surf_norm_vect():
609
    err = "'surf_norm_vect' must be a 3 component vector, 'x', 'y' or 'z', " \
610
          "'auto' or 'asann'."
611
    cart_axes = {'x': [1.0, 0.0, 0.0], '-x': [-1.0, 0.0, 0.0],
612
                 'y': [0.0, 1.0, 0.0], '-y': [0.0, -1.0, 0.0],
613
                 'z': [0.0, 0.0, 1.0], '-z': [0.0, 0.0, -1.0]}
614
    surf_norm_vect_str = dos_inp.get('Screening', 'surf_norm_vect',
615
                                     fallback="auto").lower()
616
    if surf_norm_vect_str == "asann" or surf_norm_vect_str == "auto":
617
        return 'auto'
618
    if surf_norm_vect_str in cart_axes:
619
        return np.array(cart_axes[surf_norm_vect_str])
620
    surf_norm_vect = try_command(str2lst, [(ValueError, err)],
621
                                 surf_norm_vect_str, float)
622
    if len(surf_norm_vect) != 3:
623
        logger.error(err)
624
        raise ValueError(err)
625

    
626
    return np.array(surf_norm_vect) / np.linalg.norm(surf_norm_vect)
627

    
628

    
629
def get_adsorption_height():
630
    err_msg = num_error % ('adsorption_height', 'positive number')
631
    ads_height = try_command(dos_inp.getfloat, [(ValueError, err_msg)],
632
                             'Screening', 'adsorption_height', fallback=2.5)
633
    if ads_height <= 0:
634
        logger.error(err_msg)
635
        raise ValueError(err_msg)
636
    return ads_height
637

    
638

    
639
def get_set_angles():
640
    set_vals = ['euler', 'internal']
641
    check_expect_val(dos_inp.get('Screening', 'set_angles').lower(), set_vals)
642
    set_angles = dos_inp.get('Screening', 'set_angles',
643
                             fallback='euler').lower()
644
    return set_angles
645

    
646

    
647
def get_pts_per_angle():
648
    err_msg = num_error % ('sample_points_per_angle', 'positive integer')
649
    pts_per_angle = try_command(dos_inp.getint,
650
                                [(ValueError, err_msg)],
651
                                'Screening', 'sample_points_per_angle',
652
                                fallback=3)
653
    if pts_per_angle <= 0:
654
        logger.error(err_msg)
655
        raise ValueError(err_msg)
656
    return pts_per_angle
657

    
658

    
659
def get_max_structures():
660
    err_msg = num_error % ('max_structures', 'positive integer')
661
    max_structs = dos_inp.get('Screening', 'max_structures', fallback="False")
662
    if max_structs.lower() in turn_false_answers:
663
        return np.inf
664
    if try_command(int, [(ValueError, err_msg)], max_structs) <= 0:
665
        logger.error(err_msg)
666
        raise ValueError(err_msg)
667
    return int(max_structs)
668

    
669

    
670
def get_coll_thrsld():
671
    err_msg = num_error % ('collision_threshold', 'positive number')
672
    coll_thrsld_str = dos_inp.get('Screening', 'collision_threshold',
673
                                  fallback="False")
674
    if coll_thrsld_str.lower() in turn_false_answers:
675
        return False
676
    coll_thrsld = try_command(float, [(ValueError, err_msg)], coll_thrsld_str)
677

    
678
    if coll_thrsld <= 0:
679
        logger.error(err_msg)
680
        raise ValueError(err_msg)
681

    
682
    return coll_thrsld
683

    
684

    
685
def get_min_coll_height(norm_vect):
686
    err_msg = num_error % ('min_coll_height', 'decimal number')
687
    min_coll_height = dos_inp.get('Screening', 'min_coll_height',
688
                                  fallback="False")
689
    if min_coll_height.lower() in turn_false_answers:
690
        return False
691
    min_coll_height = try_command(float, [(ValueError, err_msg)],
692
                                  min_coll_height)
693
    cart_axes = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0],
694
                 [-1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0]]
695
    err_msg = "'min_coll_height' option is only implemented for " \
696
              "'surf_norm_vect' to be one of the x, y or z axes. "
697
    if not isinstance(norm_vect, str) or norm_vect != 'auto':
698
        check_expect_val(norm_vect.tolist(), cart_axes, err_msg)
699
    return min_coll_height
700

    
701

    
702
def get_exclude_ads_ctr():
703
    err_msg = "exclude_ads_ctr must have a boolean value."
704
    exclude_ads_ctr = try_command(dos_inp.getboolean, [(ValueError, err_msg)],
705
                                  "Screening", "exclude_ads_ctr",
706
                                  fallback=False)
707
    return exclude_ads_ctr
708

    
709

    
710
def get_H_donor(spec_atoms):
711
    from ase.data import chemical_symbols
712
    err_msg = "The value of 'h_donor' must be either False, a chemical symbol "\
713
              "or an atom index"
714
    h_donor_str = dos_inp.get('Screening', 'h_donor', fallback="False")
715
    h_donor = []
716
    if h_donor_str.lower() in turn_false_answers:
717
        return False
718
    err = False
719
    for el in h_donor_str.split():
720
        try:
721
            h_donor.append(int(el))
722
        except ValueError:
723
            if el not in chemical_symbols + [nw_sym for pairs in spec_atoms
724
                                             for nw_sym in pairs]:
725
                err = True
726
            else:
727
                h_donor.append(el)
728
        finally:
729
            if err:
730
                logger.error(err_msg)
731
                ValueError(err_msg)
732
    return h_donor
733

    
734

    
735
def get_H_acceptor(spec_atoms):
736
    from ase.data import chemical_symbols
737
    err_msg = "The value of 'h_acceptor' must be either 'all', a chemical " \
738
              "symbol or an atom index"
739
    h_acceptor_str = dos_inp.get('Screening', 'h_acceptor', fallback="all")
740
    if h_acceptor_str.lower() == "all":
741
        return "all"
742
    h_acceptor = []
743
    err = False
744
    for el in h_acceptor_str.split():
745
        try:
746
            h_acceptor.append(int(el))
747
        except ValueError:
748
            if el not in chemical_symbols + [nw_sym for pairs in spec_atoms
749
                                             for nw_sym in pairs]:
750
                err = True
751
            else:
752
                h_acceptor.append(el)
753
        finally:
754
            if err:
755
                logger.error(err_msg)
756
                ValueError(err_msg)
757
    return h_acceptor
758

    
759

    
760
def get_use_molec_file():
761
    use_molec_file = dos_inp.get('Screening', 'use_molec_file',
762
                                 fallback='False')
763
    if use_molec_file.lower() in turn_false_answers:
764
        return False
765
    if not os.path.isfile(use_molec_file):
766
        logger.error(f'File {use_molec_file} not found.')
767
        raise FileNotFoundError(f'File {use_molec_file} not found')
768

    
769
    return use_molec_file
770

    
771

    
772
# Refinement
773

    
774
def get_refine_inp_file(code):
775
    inp_file_lst = dos_inp.get('Refinement', 'refine_inp_file').split()
776
    check_inp_file(inp_file_lst[0] if len(inp_file_lst) == 1 else inp_file_lst,
777
                   code)
778
    return inp_file_lst[0] if len(inp_file_lst) == 1 else inp_file_lst
779

    
780

    
781
def get_energy_cutoff():
782
    err_msg = num_error % ('energy_cutoff', 'positive decimal number')
783
    energy_cutoff = try_command(dos_inp.getfloat,
784
                                [(ValueError, err_msg)],
785
                                'Refinement', 'energy_cutoff', fallback=0.5)
786
    if energy_cutoff < 0:
787
        logger.error(err_msg)
788
        raise ValueError(err_msg)
789
    return energy_cutoff
790

    
791

    
792
# Read input parameters
793

    
794
def read_input(in_file):
795
    from modules.formats import adapt_format
796

    
797
    err_msg = False
798
    try:
799
        dos_inp.read(in_file)
800
    except MissingSectionHeaderError as e:
801
        logger.error('There are options in the input file without a Section '
802
                     'header.')
803
        err_msg = e
804
    except DuplicateOptionError as e:
805
        logger.error('There is an option in the input file that has been '
806
                     'specified more than once.')
807
        err_msg = e
808
    except Exception as e:
809
        err_msg = e
810
    else:
811
        err_msg = False
812
    finally:
813
        if isinstance(err_msg, BaseException):
814
            raise err_msg
815

    
816
    inp_vars = {}
817

    
818
    # Global
819
    if not dos_inp.has_section('Global'):
820
        logger.error(no_sect_err % 'Global')
821
        raise NoSectionError('Global')
822

    
823
    # Mandatory options
824
    # Checks whether the mandatory options 'run_type', 'code', etc. are present.
825
    glob_mand_opts = ['run_type', 'code', 'batch_q_sys']
826
    for opt in glob_mand_opts:
827
        if not dos_inp.has_option('Global', opt):
828
            logger.error(no_opt_err % (opt, 'Global'))
829
            raise NoOptionError(opt, 'Global')
830

    
831
    # Gets which sections are to be carried out
832
    isolated, screening, refinement = get_run_type()
833
    inp_vars['isolated'] = isolated
834
    inp_vars['screening'] = screening
835
    inp_vars['refinement'] = refinement
836
    inp_vars['code'] = get_code()
837
    inp_vars['batch_q_sys'] = get_batch_q_sys()
838

    
839
    # Dependent options:
840
    if inp_vars['batch_q_sys']:
841
        inp_vars['max_jobs'] = get_max_jobs()
842
        if inp_vars['batch_q_sys'] != 'local':
843
            if not dos_inp.has_option('Global', 'subm_script'):
844
                logger.error(no_opt_err % ('subm_script', 'Global'))
845
                raise NoOptionError('subm_script', 'Global')
846
            inp_vars['subm_script'] = get_subm_script()
847

    
848
    # Facultative options (Default/Fallback value present)
849
    inp_vars['pbc_cell'] = get_pbc_cell()
850
    inp_vars['project_name'] = get_project_name()
851
    # inp_vars['relaunch_err'] = get_relaunch_err()
852
    inp_vars['special_atoms'] = get_special_atoms()
853

    
854
    # Isolated
855
    if isolated:
856
        if not dos_inp.has_section('Isolated'):
857
            logger.error(no_sect_err % 'Isolated')
858
            raise NoSectionError('Isolated')
859
        # Mandatory options
860
        # Checks whether the mandatory options are present.
861
        iso_mand_opts = ['isol_inp_file', 'molec_file']
862
        for opt in iso_mand_opts:
863
            if not dos_inp.has_option('Isolated', opt):
864
                logger.error(no_opt_err % (opt, 'Isolated'))
865
                raise NoOptionError(opt, 'Isolated')
866
        inp_vars['isol_inp_file'] = get_isol_inp_file(inp_vars['code'])
867
        inp_vars['molec_file'] = get_molec_file()
868

    
869
        # Checks for PBC
870
        atms = adapt_format('ase', inp_vars['molec_file'],
871
                            inp_vars['special_atoms'])
872
        if inp_vars['code'] == 'vasp' and np.linalg.det(atms.cell) == 0.0 \
873
                and inp_vars['pbc_cell'] is False:
874
            err_msg = "When running calculations with 'VASP', the PBC cell" \
875
                      "should be provided either implicitely inside " \
876
                      "'molec_file' or by setting the 'pbc_cell' option."
877
            logger.error(err_msg)
878
            raise ValueError(err_msg)
879
        elif np.allclose(inp_vars['pbc_cell'], atms.cell):
880
            logger.warning("'molec_file' has an implicit cell defined "
881
                           f"different than 'pbc_cell' ('molec_file'="
882
                           f"{atms.cell}, 'pbc_cell'= {inp_vars['pbc_cell']}). "
883
                           f"'pbc_cell' value will be used.")
884

    
885
        # Facultative options (Default/Fallback value present)
886
        inp_vars['num_conformers'] = get_num_conformers()
887
        inp_vars['pre_opt'] = get_pre_opt()
888

    
889
    # Screening
890
    if screening:
891
        if not dos_inp.has_section('Screening'):
892
            logger.error(no_sect_err % 'Screening')
893
            raise NoSectionError('Screening')
894
        # Mandatory options:
895
        # Checks whether the mandatory options are present.
896
        screen_mand_opts = ['screen_inp_file', 'surf_file', 'sites',
897
                            'molec_ctrs']
898
        for opt in screen_mand_opts:
899
            if not dos_inp.has_option('Screening', opt):
900
                logger.error(no_opt_err % (opt, 'Screening'))
901
                raise NoOptionError(opt, 'Screening')
902
        inp_vars['screen_inp_file'] = get_screen_inp_file(inp_vars['code'])
903
        inp_vars['surf_file'] = get_surf_file()
904
        inp_vars['sites'] = get_sites()
905
        inp_vars['molec_ctrs'] = get_molec_ctrs()
906

    
907
        # Checks for PBC
908
        atms = adapt_format('ase', inp_vars['surf_file'],
909
                            inp_vars['special_atoms'])
910
        if inp_vars['code'] == 'vasp' and np.linalg.det(atms.cell) == 0.0 \
911
                and inp_vars['pbc_cell'] is False:
912
            err_msg = "When running calculations with 'VASP', the PBC cell" \
913
                      "should be provided either implicitely inside " \
914
                      "'molec_file' or by setting the 'pbc_cell' option."
915
            logger.error(err_msg)
916
            raise ValueError(err_msg)
917
        elif np.allclose(inp_vars['pbc_cell'], atms.cell):
918
            logger.warning("'molec_file' has an implicit cell defined, "
919
                           "different than 'pbc_cell' ('molec_file'="
920
                           f"{atms.cell}, 'pbc_cell'={inp_vars['pbc_cell']}). "
921
                           "'pbc_cell' value will be used.")
922

    
923
        # Facultative options (Default value present)
924
        inp_vars['select_magns'] = get_select_magns()
925
        inp_vars['confs_per_magn'] = get_confs_per_magn()
926
        inp_vars['surf_norm_vect'] = get_surf_norm_vect()
927
        inp_vars['adsorption_height'] = get_adsorption_height()
928
        inp_vars['set_angles'] = get_set_angles()
929
        inp_vars['sample_points_per_angle'] = get_pts_per_angle()
930
        inp_vars['collision_threshold'] = get_coll_thrsld()
931
        inp_vars['min_coll_height'] = get_min_coll_height(
932
            inp_vars['surf_norm_vect'])
933
        if inp_vars['min_coll_height'] is False \
934
                and inp_vars['collision_threshold'] is False:
935
            logger.warning("Collisions are deactivated: Overlapping of "
936
                           "adsorbate and surface is possible")
937
        inp_vars['exclude_ads_ctr'] = get_exclude_ads_ctr()
938
        inp_vars['h_donor'] = get_H_donor(inp_vars['special_atoms'])
939
        inp_vars['max_structures'] = get_max_structures()
940
        inp_vars['use_molec_file'] = get_use_molec_file()
941

    
942
        # Options depending on the value of others
943
        if inp_vars['set_angles'] == "internal":
944
            internal_opts = ['molec_ctrs2', 'molec_ctrs3', 'surf_ctrs2',
945
                             'max_helic_angle']
946
            for opt in internal_opts:
947
                if not dos_inp.has_option('Screening', opt):
948
                    logger.error(no_opt_err % (opt, 'Screening'))
949
                    raise NoOptionError(opt, 'Screening')
950
            inp_vars['max_helic_angle'] = get_max_helic_angle()
951
            inp_vars['molec_ctrs2'] = get_molec_ctrs2()
952
            inp_vars['molec_ctrs3'] = get_molec_ctrs3()
953
            inp_vars['surf_ctrs2'] = get_surf_ctrs2()
954
            if len(inp_vars["molec_ctrs2"]) != len(inp_vars['molec_ctrs']) \
955
                    or len(inp_vars["molec_ctrs3"]) != \
956
                    len(inp_vars['molec_ctrs']) \
957
                    or len(inp_vars['surf_ctrs2']) != len(inp_vars['sites']):
958
                err_msg = "'molec_ctrs' 'molec_ctrs2' and 'molec_ctrs3' must " \
959
                          "have the same number of indices "
960
                logger.error(err_msg)
961
                raise ValueError(err_msg)
962

    
963
        if inp_vars['h_donor'] is False:
964
            inp_vars['h_acceptor'] = False
965
        else:
966
            inp_vars['h_acceptor'] = get_H_acceptor(inp_vars['special_atoms'])
967

    
968
    # Refinement
969
    if refinement:
970
        if not dos_inp.has_section('Refinement'):
971
            logger.error(no_sect_err % 'Refinement')
972
            raise NoSectionError('Refinement')
973
        # Mandatory options
974
        # Checks whether the mandatory options are present.
975
        ref_mand_opts = ['refine_inp_file']
976
        for opt in ref_mand_opts:
977
            if not dos_inp.has_option('Refinement', opt):
978
                logger.error(no_opt_err % (opt, 'Refinement'))
979
                raise NoOptionError(opt, 'Refinement')
980
        inp_vars['refine_inp_file'] = get_refine_inp_file(inp_vars['code'])
981

    
982
        # Facultative options (Default value present)
983
        inp_vars['energy_cutoff'] = get_energy_cutoff()
984
        # end energy_cutoff
985

    
986
    return_vars_str = "\n\t".join([str(key) + ": " + str(value)
987
                                   for key, value in inp_vars.items()])
988
    logger.info(f'Correctly read {in_file} parameters:'
989
                f' \n\n\t{return_vars_str}\n')
990

    
991
    return inp_vars
992

    
993

    
994
if __name__ == "__main__":
995
    import sys
996

    
997
    print(read_input(sys.argv[1]))