"""Module for the conversion between atomic coordinates files and objects

functions:
confs_to_mol_list: Converts the conformers inside a rdkit mol object to a list
    of separate mol objects.
rdkit_mol_to_ase_atoms: Converts a rdkit mol object into ase Atoms object.
adapt_format: Converts the coordinate files into a required library object type.
read_coords: Reads the atomic coordinates resulting from finished calculations.
"""

import logging

import rdkit.Chem.AllChem as Chem

logger = logging.getLogger('DockOnSurf')


def confs_to_mol_list(mol: Chem.rdchem.Mol, idx_lst=None):
    """Splits the conformers of a rdkit mol object into separate mol objects.

    Every selected conformer is written to a mol block and parsed back
    (keeping hydrogens) so that each resulting mol object is independent.

    @param mol: rdkit mol object containing at least one conformer.
    @param idx_lst: iterable of conformer indices to be considered. If not
        passed, all conformers are considered.
    @return: list of separate mol objects, one per considered conformer.
    """
    if idx_lst is None:
        idx_lst = list(range(mol.GetNumConformers()))

    separate_mols = []
    for conf_id in idx_lst:
        # Indices are cast to int before being handed to rdkit as confId.
        mol_block = Chem.MolToMolBlock(mol, confId=int(conf_id))
        separate_mols.append(Chem.MolFromMolBlock(mol_block, removeHs=False))
    return separate_mols


def rdkit_mol_to_ase_atoms(mol: Chem.rdchem.Mol):
    """Builds an ase Atoms object out of a rdkit mol object.

    The chemical symbols are taken from the atoms of the mol object and the
    positions from its conformer with id 0. A warning is logged when the mol
    object carries more than one conformer.

    @param mol: rdkit mol object containing only one conformer.
    @return ase.Atoms: ase Atoms object with the same coordinates.
    """
    from ase import Atoms

    multi_conf_warning = ('A mol object with multiple conformers is parsed, '
                          'converting to Atoms only the first conformer.')
    if mol.GetNumConformers() > 1:
        logger.warning(multi_conf_warning)

    elements = []
    for atom in mol.GetAtoms():
        elements.append(atom.GetSymbol())
    coordinates = mol.GetConformer(0).GetPositions()
    return Atoms(symbols=elements, positions=coordinates)


def adapt_format(requirement, coord_file):
    """Converts the coordinate files into a required library object type.

    Depending on the library required to use and the file type, it converts the
    coordinate file into a library-workable object.
    @param requirement: str, the library for which the conversion should be
    made. Accepted values: 'ase', 'rdkit'.
    @param coord_file: str, path to the coordinates file aiming to convert.
    Accepted file types: 'xyz', 'mol'.
    @return: an object the required library can work with.
    @raise NotImplementedError: if the requested library or the detected file
    format is not among the accepted values.
    """
    import ase.io
    from ase.io.formats import filetype

    req_vals = ['rdkit', 'ase']
    file_type_vals = ['xyz', 'mol']

    # Detect the file format a single time and reuse it for the validation,
    # the error message and the dispatching below.
    file_type = filetype(coord_file)

    lib_err = f"The conversion to the '{requirement}' library object type" \
              f" has not yet been implemented"
    conv_info = f"Converted {coord_file} to {requirement} object type"
    fil_type_err = f'The {file_type} file format is not supported'

    if requirement not in req_vals:
        logger.error(lib_err)
        raise NotImplementedError(lib_err)

    if file_type not in file_type_vals:
        logger.error(fil_type_err)
        raise NotImplementedError(fil_type_err)

    if requirement == 'rdkit':
        if file_type == 'xyz':
            from modules.xyz2mol import xyz2mol
            # xyz files are read through ase and handed to xyz2mol as plain
            # lists of atomic numbers and coordinates.
            ase_atms = ase.io.read(coord_file)
            atomic_nums = ase_atms.get_atomic_numbers().tolist()
            xyz_coordinates = ase_atms.positions.tolist()
            # TODO Add routine to read charge
            rd_mol_obj = xyz2mol(atomic_nums, xyz_coordinates, charge=0)
            logger.debug(conv_info)
            return Chem.AddHs(rd_mol_obj)
        elif file_type == 'mol':
            logger.debug(conv_info)
            return Chem.AddHs(Chem.MolFromMolFile(coord_file, removeHs=False))

    if requirement == 'ase':
        if file_type == 'xyz':
            logger.debug(conv_info)
            return ase.io.read(coord_file)
        elif file_type == 'mol':
            logger.debug(conv_info)
            # mol files are first parsed with rdkit and then converted to ase.
            rd_mol = Chem.AddHs(Chem.MolFromMolFile(coord_file, removeHs=False))
            return rdkit_mol_to_ase_atoms(rd_mol)


def read_coords(code, run_type, req):
    """Reads the atomic coordinates resulting from finished calculations.

    Given a run_type ('isolated', 'screening' or 'refinement') directory
    containing one subdirectory per finished calculation, it goes through
    every subdirectory, picks the files whose name contains the pattern
    associated to the code and converts each of them into an object adequate
    to the required library.

    @param code: the code that produced the calculation results files. For
                 'cp2k' only files containing '-pos-1.xyz' in their name are
                 read; for any other value every file is read.
    @param run_type: the type of calculation (and also the name of the folder)
                     containing the calculation subdirectories.
    @param req: The required library object type to make the list of (eg. rdkit,
                ase)
    @return: list of collection-of-atoms objects. (rdkit.Mol, ase.Atoms, etc.)
    """
    import os

    # The empty pattern is contained in every file name.
    pattern = '-pos-1.xyz' if code == 'cp2k' else ''

    coords = []
    for conf in os.listdir(run_type):
        for fil in os.listdir(f"{run_type}/{conf}"):
            if pattern in fil:
                coords.append(adapt_format(req, f'{run_type}/{conf}/{fil}'))
    return coords


def read_energies(code, run_type):
    """Reads the energies resulting from finished calculations.

    Given a run_type ('isolated', 'screening' or 'refinement') directory
    containing different subdirectories with finished calculations in every
    subdirectory, it reads the final energies of calculations inside each
    subdirectory.

    Only the 'cp2k' code is handled: for any other value an empty array is
    returned.

    @param code: the code that produced the calculation results files.
    @param run_type: the type of calculation (and also the name of the folder)
                     containing the calculation subdirectories.
    @return: numpy array with the energies found.
    """
    import os
    import numpy as np
    from modules.utilities import tail

    energies = []
    if code == 'cp2k':
        pattern = '-pos-1.xyz'
        for conf in os.listdir(run_type):
            for fil in os.listdir(f"{run_type}/{conf}"):
                if pattern not in fil:
                    continue
                # The context manager closes the trajectory file once its
                # last geometry has been extracted.
                with open(f"{run_type}/{conf}/{fil}", 'rb') as traj_fh:
                    # The first line is parsed as the number of atoms; the
                    # last frame then spans that many lines plus two.
                    num_atoms = int(traj_fh.readline().strip())
                    # NOTE(review): assumes tail() returns decoded text (str)
                    # since the lines are searched with a str pattern below
                    # -- confirm against modules.utilities.
                    last_geo = tail(traj_fh, num_atoms + 2).splitlines()
                for line in last_geo:
                    if 'E =' in line:
                        energies.append(float(line.split('E =')[1]))

    return np.array(energies)
