"""Functions to deal with DockOnSurf input files.

Functions
try_command:Tries to run a command and logs its exceptions (expected and not).
str2lst: Converts a string of integers, and groups of them, to a list of lists.
check_expect_val: Checks whether the value of an option has an adequate value.
read_input: Sets up the calculation by reading the parameters from input file.
get_run_type: Gets 'run_type' value and checks that its value is acceptable.
get_code: Gets 'code' value and checks that its value is acceptable.
get_batch_q_sys: Gets 'batch_q_sys' value and checks that its value is
acceptable.
get_relaunch_err: Gets 'relaunch_err' value and checks that its value is
acceptable.
get_max_qw: Gets 'max_qw' value and checks that its value is acceptable.
get_special_atoms: Gets 'special_atoms' value and checks that its value is
acceptable.
get_isol_inp_file: Gets 'isol_inp_file' value and checks that its value is
acceptable.
get_cluster_magns: Gets 'cluster_magns' value and checks that its value is
acceptable.
get_num_conformers: Gets 'num_conformers' value and checks that its value is
acceptable.
get_num_prom_cand: Gets 'num_prom_cand' value and checks that its value is
acceptable.
get_iso_rmsd: Gets 'iso_rmsd' value and checks that its value is acceptable.
get_min_confs: Gets 'min_confs' value and checks that its value is acceptable.
get_screen_inp_file: Gets 'screen_inp_file' value and checks that its value is
acceptable.
get_sites: Gets 'sites' value and checks that its value is acceptable.
get_molec_ads_ctrs: Gets 'molec_ads_ctrs' value and checks that its value is
acceptable.
get_try_disso: Gets 'try_disso' value and checks that its value is acceptable.
get_pts_per_angle: Gets 'pts_per_angle' value and checks that its value is
acceptable.
get_coll_thrsld: Gets 'coll_thrsld' value and checks that its value is
acceptable.
get_screen_rmsd: Gets 'screen_rmsd' value and checks that its value is
acceptable.
get_coll_bottom_z: Gets 'coll_bottom_z' value and checks that its value is
acceptable.
get_refine_inp_file: Gets 'refine_inp_file' value and checks that its value is
acceptable.
get_energy_cutoff: Gets 'energy_cutoff' value and checks that its value is
acceptable.
"""
import os.path
import logging
from configparser import ConfigParser, NoSectionError, NoOptionError, \
    MissingSectionHeaderError, DuplicateOptionError

logger = logging.getLogger('DockOnSurf')

dos_inp = ConfigParser(inline_comment_prefixes='#',
                       empty_lines_in_values=False)

new_answers = {'n': False, 'none': False, 'nay': False,
               'y': True, 'sí': True, 'aye': True, 'sure': True}
for answer, val in new_answers.items():
    dos_inp.BOOLEAN_STATES[answer] = val
turn_false_answers = [answer for answer in dos_inp.BOOLEAN_STATES
                      if dos_inp.BOOLEAN_STATES[answer] is False]

no_sect_err = "Section '%s' not found on input file"
no_opt_err = "Option '%s' not found on section '%s'"
num_error = "'%s' value must be a %s"
unexp_error = "An unexpected error occurred"


def try_command(command, expct_error_types: list, *args, **kwargs):
    """Try to run a command and record exceptions (expected and not) on a log.
    
    @param command: method or function, the command to be executed.
    @param expct_error_types: tuple of tuples, every inner tuple is supposed to
    contain an exception type (eg. ValueError, TypeError, etc.) to be caught and
    a message to print in the log and on the screen explaining the exception.
    Error types that are not allow to be called with a custom message as only
    error argument are not supported.
    The outer tuple encloses all couples of error types and their relative
    messages.
    *args and **kwargs: arguments and keyword-arguments of the command to be
    executed.
    When trying to run 'command' with its args and kwargs, if an exception
    present on the 'error_types' occurs, its relative error message is recorded
    on the log and a same type exception is raised with the custom message.
    """

    err = False
    try:
        return_val = command(*args, **kwargs)
    except Exception as e:
        for expct_err in expct_error_types:
            if isinstance(e, expct_err[0]):
                logger.error(expct_err[1])
                err = expct_err[0](expct_err[1])
                break
        else:
            logger.exception(unexp_error)
            err = e
    else:
        err = False
        return return_val
    finally:
        if isinstance(err, BaseException):
            raise err


def str2lst(cmplx_str):  # TODO: enable deeper level of nested lists
    """Converts a string of integers, and groups of them, to a list.

    Keyword arguments:
    @param cmplx_str: str, string of integers and groups of them enclosed by
    parentheses-like characters.
    - Group enclosers: '()' '[]' and '{}'.
    - Integer separators: ',' ';' and ' '.
    - Nested groups are not allowed: '3 ((6 7) 8) 4'.

    @return list, list of integers, or list of integers in the case they were
    grouped. First, the singlets are placed, and then the groups in input order.

    eg. '128,(135 138;141] 87 {45, 68}' -> [128, 87, [135, 138, 141], [45, 68]]
    """

    # Checks
    error_msg = "Function argument should be a str,sequence of integer " \
                "numbers separated by ',' ';' or ' '." \
                "\nThey can be grouped in parentheses-like enclosers: '()', " \
                "'[]' or {}. Nested groups are not allowed. \n" \
                "eg. 128,(135 138;141) 87 {45, 68}"
    cmplx_str = try_command(cmplx_str.replace, [(AttributeError, error_msg)],
                            ',', ' ')

    cmplx_str = cmplx_str.replace(';', ' ').replace('[', '(').replace(
        ']', ')').replace('{', '(').replace('}', ')')

    try_command(list, [(ValueError, error_msg)], map(int, cmplx_str.replace(
        ')', '').replace('(', '').split()))

    deepness = 0
    for el in cmplx_str.split():
        if '(' in el:
            deepness += 1
        if ')' in el:
            deepness += -1
        if deepness > 1 or deepness < 0:
            logger.error(error_msg)
            raise ValueError(error_msg)

    init_list = cmplx_str.split()
    start_group = []
    end_group = []
    for i, element in enumerate(init_list):
        if '(' in element:
            start_group.append(i)
            init_list[i] = element.replace('(', '')
        if ')' in element:
            end_group.append(i)
            init_list[i] = element.replace(')', '')

    init_list = list(map(int, init_list))

    new_list = []
    for start_el, end_el in zip(start_group, end_group):
        new_list.append(init_list[start_el:end_el + 1])

    for v in new_list:
        for el in v:
            init_list.remove(el)
    return init_list + new_list


def check_expect_val(value, expect_vals):
    """Checks whether an option lies within its expected values.

    Keyword arguments:
    @param value: The variable to check if its value lies within the expected
    ones
    @param expect_vals: list, list of values allowed for the present option.
    @raise ValueError: if the value is not among the expected ones.
    @return True if the value is among the expected ones.
    """
    adeq_val_err = "'%s' is not an adequate value.\n" \
                   "Adequate values: %s"
    if not any([exp_val in value for exp_val in expect_vals]):
        logger.error(adeq_val_err % (value, expect_vals))
        raise ValueError(adeq_val_err % (value, expect_vals))

    return True


def check_inp_file(inp_file, code):
    if code == 'cp2k':
        from pycp2k import CP2K
        cp2k = CP2K()
        try_command(cp2k.parse,
                    [(UnboundLocalError, "Invalid CP2K input file")], inp_file)


def get_run_type():
    isolated, screening, refinement = (False, False, False)
    run_type_vals = ['isolated', 'screening', 'refinement', 'adsorption',
                     'full']
    check_expect_val(dos_inp.get('Global', 'run_type').lower(), run_type_vals)

    run_type = dos_inp.get('Global', 'run_type').lower()
    if 'isolated' in run_type:
        isolated = True
    if 'screening' in run_type:
        screening = True
    if 'refinement' in run_type:
        refinement = True
    if 'adsorption' in run_type:
        screening, refinement = (True, True)
    if 'full' in run_type:
        isolated, screening, refinement = (True, True, True)

    return isolated, screening, refinement


def get_code():
    code_vals = ['cp2k']
    check_expect_val(dos_inp.get('Global', 'code').lower(), code_vals)
    code = dos_inp.get('Global', 'code').lower()
    return code


def get_batch_q_sys():
    batch_q_sys_vals = ['sge', 'lsf', 'local', 'none']
    check_expect_val(dos_inp.get('Global', 'batch_q_sys').lower(),
                     batch_q_sys_vals)
    batch_q_sys = dos_inp.get('Global', 'batch_q_sys').lower()
    return batch_q_sys


def get_subm_script():
    subm_script = dos_inp.get('Global', 'subm_script', fallback=False)
    if subm_script and not os.path.isfile(subm_script):
        logger.error(f'File {subm_script} not found')
        raise FileNotFoundError(f'File {subm_script} not found')
    return subm_script


def get_project_name():
    project_name = dos_inp.get('Global', 'project_name', fallback='')
    return project_name


def get_relaunch_err():
    relaunch_err_vals = ['geo_not_conv', 'false']
    relaunch_err = dos_inp.get('Global', 'relaunch_err',
                               fallback="False")
    if relaunch_err.lower() in turn_false_answers:
        return False
    else:
        check_expect_val(relaunch_err.lower(), relaunch_err_vals)
    return relaunch_err


def get_max_qw():
    err_msg = num_error % ('max_qw', 'positive integer')
    max_qw = try_command(dos_inp.getint, [(ValueError, err_msg)],
                         'Global', 'max_qw', fallback=3)

    if max_qw < 1:
        logger.error(num_error % ('max_qw', 'positive integer'))
        raise ValueError(num_error % ('max_qw', 'positive integer'))
    return max_qw


def get_special_atoms():
    from ase.data import chemical_symbols

    spec_at_err = '\'special_atoms\' does not have an adequate format.\n' \
                  'Adequate format: (Fe1 Fe) (O1 O)'
    special_atoms = dos_inp.get('Global', 'special_atoms', fallback="False")
    if special_atoms.lower() in turn_false_answers:
        special_atoms = False
    else:
        # Converts the string into a list of tuples
        lst_tple = [tuple(pair.replace("(", "").split()) for pair in
                    special_atoms.split(")")[:-1]]
        if len(lst_tple) == 0:
            logger.error(spec_at_err)
            raise ValueError(spec_at_err)
        for i, tup in enumerate(lst_tple):
            if type(tup) is not tuple or len(tup) != 2:
                logger.error(spec_at_err)
                raise ValueError(spec_at_err)
            if tup[1].capitalize() not in chemical_symbols:
                elem_err = "The second element of the couple should be an " \
                           "actual element of the periodic table"
                logger.error(elem_err)
                raise ValueError(elem_err)
            if tup[0].capitalize() in chemical_symbols:
                elem_err = "The first element of the couple is already an " \
                           "actual element of the periodic table, "
                logger.error(elem_err)
                raise ValueError(elem_err)
            for j, tup2 in enumerate(lst_tple):
                if j <= i:
                    continue
                if tup2[0] == tup[0]:
                    label_err = f'You have specified the label {tup[0]} to ' \
                                f'more than one special atom'
                    logger.error(label_err)
                    raise ValueError(label_err)
        special_atoms = lst_tple
    return special_atoms


def get_isol_inp_file():
    isol_inp_file = dos_inp.get('Isolated', 'isol_inp_file')
    if not os.path.isfile(isol_inp_file):
        logger.error(f'File {isol_inp_file} not found')
        raise FileNotFoundError(f'File {isol_inp_file} not found')
    return isol_inp_file


def get_molec_file():
    molec_file = dos_inp.get('Isolated', 'molec_file')
    if not os.path.isfile(molec_file):
        logger.error(f'File {molec_file} not found')
        raise FileNotFoundError(f'File {molec_file} not found')
    return molec_file


def get_cluster_magns():
    clust_magns_vals = ['energy', 'moi']
    cluster_magns_str = dos_inp.get('Isolated', 'cluster_magns',
                                    fallback='energy')
    cluster_magns_str.replace(',', ' ').replace(';', ' ')
    cluster_magns = cluster_magns_str.split(' ')
    cluster_magns = [m.lower() for m in cluster_magns]
    for m in cluster_magns:
        check_expect_val(m, clust_magns_vals)
    return cluster_magns


def get_num_conformers():
    err_msg = num_error % ('num_conformers', 'positive integer')
    num_conformers = try_command(dos_inp.getint, [(ValueError, err_msg)],
                                 'Isolated', 'num_conformers', fallback=100)
    if num_conformers < 1:
        logger.error(err_msg)
        raise ValueError(err_msg)
    return num_conformers


def get_num_prom_cand():
    err_msg = num_error % ('num_prom_cand', 'positive integer')
    num_prom_cand = try_command(dos_inp.getint, [(ValueError, err_msg)],
                                'Isolated', 'num_prom_cand', fallback=3)
    if num_prom_cand < 1:
        logger.error(err_msg)
        raise ValueError(err_msg)
    return num_prom_cand


def get_iso_rmsd():
    err_msg = num_error % ('iso_rmsd', 'positive decimal number')
    iso_rmsd = try_command(dos_inp.getfloat, [(ValueError, err_msg)],
                           'Isolated', 'iso_rmsd', fallback=0.05)
    if iso_rmsd <= 0.0:
        logger.error(err_msg)
        raise ValueError(err_msg)
    return iso_rmsd


def get_min_confs():
    err_msg = "'min_confs' should be have a boolean value (True or False)"
    min_confs = try_command(dos_inp.getboolean,
                            [(ValueError, err_msg)],
                            'Isolated', 'min_confs', fallback=True)
    return min_confs


def get_screen_inp_file():
    screen_inp_file = dos_inp.get('Screening', 'screen_inp_file')
    if not os.path.isfile(screen_inp_file):
        logger.error(f'File {screen_inp_file} not found')
        raise FileNotFoundError(f'File {screen_inp_file} not found')
    return screen_inp_file


def get_sites():
    err_msg = 'The value of sites should be a list of atom numbers ' \
              '(ie. positive integers) or groups of atom numbers ' \
              'grouped by parentheses-like enclosers. \n' \
              'eg. 128,(135 138;141) 87 {45, 68}'
    # Convert the string into a list of lists
    sites = try_command(str2lst,
                        [(ValueError, err_msg), (AttributeError, err_msg)],
                        dos_inp.get('Screening', 'sites'))
    # Check all elements of the list (of lists) are positive integers
    for site in sites:
        if type(site) is list:
            for atom in site:
                if atom < 0:
                    logger.error(err_msg)
                    raise ValueError(err_msg)
        elif type(site) is int:
            if site < 0:
                logger.error(err_msg)
                raise ValueError(err_msg)
        else:
            logger.error(err_msg)
            raise ValueError(err_msg)

    return sites


def get_molec_ads_ctrs():
    err_msg = 'The value of molec_ads_ctrs should be a list of atom' \
              ' numbers (ie. positive integers) or groups of atom ' \
              'numbers enclosed by parentheses-like characters. \n' \
              'eg. 128,(135 138;141) 87 {45, 68}'
    # Convert the string into a list of lists
    molec_ads_ctrs = try_command(str2lst,
                                 [(ValueError, err_msg),
                                  (AttributeError, err_msg)],
                                 dos_inp.get('Screening', 'molec_ads_ctrs'))
    # Check all elements of the list (of lists) are positive integers
    for ctr in molec_ads_ctrs:
        if type(ctr) is list:
            for atom in ctr:
                if atom < 0:
                    logger.error(err_msg)
                    raise ValueError(err_msg)
        elif type(ctr) is int:
            if ctr < 0:
                logger.error(err_msg)
                raise ValueError(err_msg)
        else:
            logger.error(err_msg)
            raise ValueError(err_msg)

    return molec_ads_ctrs


def get_try_disso():
    err_msg = "try_disso should be have a boolean value (True or False)"
    try_disso = try_command(dos_inp.getboolean,
                            [(ValueError, err_msg)],
                            'Screening', 'try_disso', fallback=False)
    return try_disso


def get_pts_per_angle():
    err_msg = num_error % ('sample_points_per_angle',
                           'positive integer')
    pts_per_angle = try_command(dos_inp.getint,
                                [(ValueError, err_msg)],
                                'Screening', 'sample_points_per_angle',
                                fallback=3)

    return pts_per_angle


def get_coll_thrsld():
    err_msg = num_error % ('collision_threshold',
                           'positive decimal number')

    coll_thrsld = try_command(dos_inp.getfloat,
                              [(ValueError, err_msg)],
                              'Screening', 'collision_threshold', fallback=1.2)
    if coll_thrsld <= 0:
        logger.error(err_msg)
        raise ValueError(err_msg)

    return coll_thrsld


def get_screen_rmsd():
    err_msg = num_error % ('screen_rmsd', 'positive decimal number')
    screen_rmsd = try_command(dos_inp.getfloat,
                              [(ValueError, err_msg)],
                              'Screening', 'screen_rmsd', fallback=0.05)
    if screen_rmsd <= 0:
        logger.error(err_msg)
        raise ValueError(err_msg)

    return screen_rmsd


def get_coll_bottom_z():
    err_msg = num_error % ('collision_bottom_z', 'decimal number')
    coll_bottom_z = dos_inp.get('Screening', 'collision_bottom_z',
                                fallback="False")
    if coll_bottom_z.lower() in turn_false_answers:
        coll_bottom_z = False
    else:
        coll_bottom_z = try_command(float, [(ValueError, err_msg)],
                                    coll_bottom_z)

    return coll_bottom_z


def get_refine_inp_file():
    refine_inp_file = dos_inp.get('Refinement', 'refine_inp_file')
    if not os.path.isfile(refine_inp_file):
        logger.error(f'File {refine_inp_file} not found')
        raise FileNotFoundError(f'File {refine_inp_file} not found')

    return refine_inp_file


def get_energy_cutoff():
    err_msg = num_error % ('energy_cutoff', 'positive decimal number')
    energy_cutoff = try_command(dos_inp.getfloat,
                                [(ValueError, err_msg)],
                                'Refinement', 'energy_cutoff', fallback=0.5)
    if energy_cutoff < 0:
        logger.error(err_msg)
        raise ValueError(err_msg)
    return energy_cutoff


def read_input(in_file):
    err = False
    try:
        dos_inp.read(in_file)
    except MissingSectionHeaderError as e:
        logger.error('There are options in the input file without a Section '
                     'header')
        err = e
    except DuplicateOptionError as e:
        logger.error('There is an option in the input file that has been '
                     'specified more than once, possibly due to the lack of a '
                     'Section header')
        err = e
    except Exception as e:
        err = e
    else:
        err = False
    finally:
        if isinstance(err, BaseException):
            raise err

    return_vars = {}

    # Global
    if not dos_inp.has_section('Global'):
        logger.error(no_sect_err % 'Global')
        raise NoSectionError('Global')

    # Mandatory options
    # Checks whether the mandatory options 'run_type', 'code', etc. are present.
    glob_mand_opts = ['run_type', 'code', 'batch_q_sys']
    for opt in glob_mand_opts:
        if not dos_inp.has_option('Global', opt):
            logger.error(no_opt_err % (opt, 'Global'))
            raise NoOptionError(opt, 'Global')

    # Gets which sections are to be carried out
    isolated, screening, refinement = get_run_type()
    return_vars['isolated'] = isolated
    return_vars['screening'] = screening
    return_vars['refinement'] = refinement
    return_vars['code'] = get_code()
    return_vars['batch_q_sys'] = get_batch_q_sys()

    # Dependent options:
    return_vars['subm_script'] = get_subm_script()
    if return_vars['batch_q_sys'] != 'local' and not return_vars['subm_script']:
        sub_err = "'subm_script' must be provided if 'batch_q_sys' is not local"
        logger.error(sub_err)
        raise NoOptionError(opt, 'Global') # TODO Change to ValueError

    # Facultative options (Default/Fallback value present)
    return_vars['project_name'] = get_project_name()
    return_vars['relaunch_err'] = get_relaunch_err()
    return_vars['max_qw'] = get_max_qw()
    return_vars['special_atoms'] = get_special_atoms()

    # Isolated
    if isolated:
        if not dos_inp.has_section('Isolated'):
            logger.error(no_sect_err % 'Isolated')
            raise NoSectionError('Isolated')
        # Mandatory options
        # Checks whether the mandatory options are present.
        iso_mand_opts = ['isol_inp_file', 'molec_file']
        for opt in iso_mand_opts:
            if not dos_inp.has_option('Isolated', opt):
                logger.error(no_opt_err % (opt, 'Isolated'))
                raise NoOptionError(opt, 'Isolated')
        return_vars['isol_inp_file'] = get_isol_inp_file()
        check_inp_file(return_vars['isol_inp_file'], return_vars['code'])
        return_vars['molec_file'] = get_molec_file()

        # Facultative options (Default/Fallback value present)
        return_vars['cluster_magns'] = get_cluster_magns()
        return_vars['num_conformers'] = get_num_conformers()
        # return_vars['num_prom_cand'] = get_num_prom_cand()
        # return_vars['iso_rmsd'] = get_iso_rmsd()
        return_vars['min_confs'] = get_min_confs()

    # Screening
    if screening:
        if not dos_inp.has_section('Screening'):
            logger.error(no_sect_err % 'Screening')
            raise NoSectionError('Screening')
        # Mandatory options:
        # Checks whether the mandatory options are present.
        screen_mand_opts = ['sites', 'molec_ads_ctrs', 'screen_inp_file']
        for opt in screen_mand_opts:
            if not dos_inp.has_option('Screening', opt):
                logger.error(no_opt_err % (opt, 'Screening'))
                raise NoOptionError(opt, 'Screening')
        return_vars['screen_inp_file'] = get_screen_inp_file()
        return_vars['sites'] = get_sites()
        return_vars['molec_ads_ctrs'] = get_molec_ads_ctrs()

        # Facultative options (Default value present)
        return_vars['try_disso'] = get_try_disso()
        return_vars['sample_points_per_angle'] = get_pts_per_angle()
        return_vars['collision_threshold'] = get_coll_thrsld()
        # return_vars['screen_rmsd'] = get_screen_rmsd()
        return_vars['collision_bottom_z'] = get_coll_bottom_z()

    # Refinement
    if refinement:
        if not dos_inp.has_section('Refinement'):
            logger.error(no_sect_err % 'Refinement')
            raise NoSectionError('Refinement')
        # Mandatory options
        # Checks whether the mandatory options are present.
        ref_mand_opts = ['refine_inp_file']
        for opt in ref_mand_opts:
            if not dos_inp.has_option('Refinement', opt):
                logger.error(no_opt_err % (opt, 'Refinement'))
                raise NoOptionError(opt, 'Refinement')
        return_vars['refine_inp_file'] = get_refine_inp_file()

        # Facultative options (Default value present)
        return_vars['energy_cutoff'] = get_energy_cutoff()
        # end energy_cutoff

    return_vars_str = "\n\t".join([str(key) + ": " + str(val)
                                   for key, val in return_vars.items()])
    logger.info(
        f'Correctly read {in_file} parameters: \n\n\t{return_vars_str}\n')

    return return_vars


if __name__ == "__main__":
    import sys

    print(read_input(sys.argv[1]))
