"""Module for the conversion between atomic coordinates files and objects

functions:
confs_to_mol_list: Converts the conformers inside a rdkit mol object to a list
    of separate mol objects.
rdkit_mol_to_ase_atoms: Converts a rdkit mol object into ase Atoms object.
adapt_format: Converts the coordinate files into a required library object type.
read_coords: Reads the atomic coordinates resulting from finished calculations.
"""

import logging

import rdkit.Chem.AllChem as Chem

logger = logging.getLogger('DockOnSurf')


def confs_to_mol_list(mol: Chem.rdchem.Mol, idx_lst=None):
    """Converts the conformers inside a rdkit mol object to a list of
    separate mol objects.

    @param mol: rdkit mol object containing at least one conformer.
    @param idx_lst: list of conformer indices to be considered. If not passed,
        all conformers are considered.
    @return: list of separate mol objects.
    """
    if idx_lst is None:
        idx_lst = list(range(mol.GetNumConformers()))
    return [Chem.MolFromMolBlock(Chem.MolToMolBlock(mol, confId=int(idx)),
                                 removeHs=False) for idx in idx_lst]


def rdkit_mol_to_ase_atoms(mol: Chem.rdchem.Mol):
    """Converts a rdkit mol object into ase Atoms object.
    @param mol: rdkit mol object containing only one conformer.
    @return ase.Atoms: ase Atoms object with the same coordinates.
    """
    from ase import Atoms
    if mol.GetNumConformers() > 1:
        logger.warning('A mol object with multiple conformers is parsed, '
                       'converting to Atoms only the first conformer.')
    symbols = [atm.GetSymbol() for atm in mol.GetAtoms()]
    positions = mol.GetConformer(0).GetPositions()
    return Atoms(symbols=symbols, positions=positions)


def add_special_atoms(symbol_pairs):
    """Allows to use custom elements with symbols not in the periodic table.

    This function adds new chemical elements to be used by ase. Every new custom
    element must have a traditional (present in the periodic table) partner
    from which to obtain all its properties.
    @param symbol_pairs: List of tuples containing the pairs of chemical symbols.
        Every tuple contains a pair of chemical symbols, the first label must be
        the label of the custom element and the second one the symbol of the
        reference one (traditional present on the periodic table).
    @return:
    """
    import numpy as np
    from ase import data
    for i, pair in enumerate(symbol_pairs):
        data.chemical_symbols += [pair[0]]
        z_orig = data.atomic_numbers[pair[1]]
        orig_iupac_mass = data.atomic_masses_iupac2016[z_orig]
        orig_com_mass = data.atomic_masses_common[z_orig]
        data.atomic_numbers[pair[0]] = max(data.atomic_numbers.values()) + 1
        data.atomic_names += [pair[0]]
        data.atomic_masses_iupac2016 = np.append(data.atomic_masses_iupac2016,
                                                 orig_iupac_mass)
        data.atomic_masses = data.atomic_masses_iupac2016
        data.atomic_masses_common = np.append(data.atomic_masses_common,
                                              orig_com_mass)
        data.covalent_radii = np.append(data.covalent_radii,
                                        data.covalent_radii[z_orig])
        data.reference_states += [data.reference_states[z_orig]]
        # TODO Add vdw_radii, gsmm and aml (smaller length)


def adapt_format(requirement, coord_file, spec_atms=tuple()):
    """Converts the coordinate files into a required library object type.

    Depending on the library required to use and the file type, it converts the
    coordinate file into a library-workable object.
    @param requirement: str, the library for which the conversion should be
    made. Accepted values: 'ase', 'rdkit'.
    @param coord_file: str, path to the coordinates file aiming to convert.
    Accepted file tyoes: 'xyz', 'mol'.
    @param spec_atms: List of tuples containing pairs of new/traditional
        chemical symbols.
    @return: an object the required library can work with.
    """
    import ase.io
    from ase.io.formats import filetype

    req_vals = ['rdkit', 'ase']
    file_type_vals = ['xyz', 'mol']
    lib_err = f"The conversion to the '{requirement}' library object type" \
              f" has not yet been implemented"
    conv_info = f"Converted {coord_file} to {requirement} object type"

    fil_type_err = f'The {filetype(coord_file)} file formnat is not supported'

    if requirement not in req_vals:
        logger.error(lib_err)
        raise NotImplementedError(lib_err)

    if filetype(coord_file) not in file_type_vals:
        logger.error(fil_type_err)
        raise NotImplementedError(fil_type_err)

    if requirement == 'rdkit':
        if filetype(coord_file) == 'xyz':
            from modules.xyz2mol import xyz2mol
            ase_atms = ase.io.read(coord_file)
            atomic_nums = ase_atms.get_atomic_numbers().tolist()
            xyz_coordinates = ase_atms.positions.tolist()
            rd_mol_obj = xyz2mol(atomic_nums, xyz_coordinates, charge=0)
            logger.debug(conv_info)
            return Chem.AddHs(rd_mol_obj)
        elif filetype(coord_file) == 'mol':
            logger.debug(conv_info)
            return Chem.AddHs(Chem.MolFromMolFile(coord_file, removeHs=False))

    if requirement == 'ase':
        add_special_atoms(spec_atms)
        if filetype(coord_file) == 'xyz':
            logger.debug(conv_info)
            return ase.io.read(coord_file)
        elif filetype(coord_file) == 'mol':
            logger.debug(conv_info)
            rd_mol = Chem.AddHs(Chem.MolFromMolFile(coord_file, removeHs=False))
            return rdkit_mol_to_ase_atoms(rd_mol)


def read_coords(code, run_type, req, spec_atms=tuple()):
    """Reads the atomic coordinates resulting from finished calculations.

    Given a run_type ('isolated', 'screening' or 'refinement') directory
    containing different subdirectories with finished calculations in every
    subdirectory, it reads, from each subirectory, the final coordinates
    resulting from the calculation and returns a list of objects adequate to the
    required library.

    @param code: the code that produced the calculation results files.
    @param run_type: the type of calculation (and also the name of the folder)
                     containing the calculation subdirectories.
    @param req: The required library object type to make the list of (eg. rdkit,
                ase)
    @param spec_atms: List of tuples containing pairs of new/traditional
        chemical symbols.
    @return: list of collection-of-atoms objects. (rdkit.Mol, ase.Atoms, etc.)
    """
    import os
    # Relate file-name patterns to codes
    if code == 'cp2k':
        pattern = '-pos-1.xyz'
    else:
        pattern = ''

    # Read appropriate files and transform them to adequate object
    atoms_list = []
    for conf in os.listdir(run_type):
        if not os.path.isdir(f'{run_type}/{conf}') or 'conf_' not in conf:
            continue
        for fil in os.listdir(f"{run_type}/{conf}"):
            if pattern not in fil:
                continue
            atoms_list.append(adapt_format(req, f'{run_type}/{conf}/{fil}',
                                           spec_atms))
    return atoms_list


def read_energies(code, run_type):
    """Reads the energies resulting from finished calculations.

    Given a run_type ('isolated', 'screening' or 'refinement') directory
    containing different subdirectories with finished calculations in every
    subdirectory, it reads the final energies of calculations inside each
    subdirectory.

    @param code: the code that produced the calculation results files.
    @param run_type: the type of calculation (and also the name of the folder)
                     containing the calculation subdirectories.
    @return: list of energies
    """
    import os
    import numpy as np
    from modules.utilities import tail

    energies = []
    if code == 'cp2k':
        pattern = '-pos-1.xyz'
        for conf in os.listdir(run_type):
            for fil in os.listdir(f"{run_type}/{conf}"):
                if pattern in fil:
                    traj_fh = open(f"{run_type}/{conf}/{fil}", 'rb')
                    num_atoms = int(traj_fh.readline().strip())
                    last_geo = tail(traj_fh, num_atoms + 2).splitlines()
                    for line in last_geo:
                        if 'E =' in line:
                            energies.append(float(line.split('E =')[1]))

    return np.array(energies)
