"""Functions to generate the conformers to be adsorbed and the most stable one.

functions:
remove_C_linked_Hs: Removes hydrogens bonded to a carbon atom from a molecule.
gen_confs: Generate a number of conformers in random orientations.
get_rmsd: Gets the rmsd matrix of the conformers in a rdkit mol object.
get_moments_of_inertia: Computes moments of inertia of the given conformers.
mmff_opt_confs: Optimizes the geometry of the given conformers and returns the
    new mol object and the energies of its conformers.
run_isolated: directs the execution of functions to achieve the goal
"""
import os
import logging

import numpy as np
import rdkit.Chem.AllChem as Chem

logger = logging.getLogger('DockOnSurf')


def remove_C_linked_Hs(mol: Chem.rdchem.Mol):
    """Removes hydrogen atoms bonded to a carbon atom from a rdkit mol object.

    @param mol: rdkit mol object of the molecule with hydrogen atoms.
    @return: rdkit mol object of the molecule without hydrogen atoms linked to
    a carbon atom.

    The functions removes the hydrogen atoms bonded to carbon atoms while
    keeping the ones bonded to other atoms or non-bonded at all.
    """

    mol = Chem.RWMol(mol)
    rev_atm_idxs = [atom.GetIdx() for atom in reversed(mol.GetAtoms())]

    for atm_idx in rev_atm_idxs:
        atom = mol.GetAtomWithIdx(atm_idx)
        if atom.GetAtomicNum() != 1:
            continue
        for neigh in atom.GetNeighbors():
            if neigh.GetAtomicNum() == 6:
                mol.RemoveAtom(atom.GetIdx())
    return mol


def gen_confs(mol: Chem.rdchem.Mol, num_confs: int, local_min=True):
    """Generate conformers in random orientations.

    @param mol: rdkit mol object of the molecule to be adsorbed.
    @param num_confs: number of conformers to randomly generate.
    @param local_min: bool: if generated conformers should be a local minimum.
    @return: mol: rdkit mol object containing the different conformers.
             rmsd_mtx: Matrix with the rmsd values of conformers.

    Using the rdkit library, conformers are randomly generated. If structures 
    are required to be local minima, ie. setting the 'local_min' value to 
    True, a geometry optimisation using UFF is performed.
    """
    logger.debug('Generating Conformers.')

    mol = Chem.AddHs(mol)
    Chem.EmbedMultipleConfs(mol, numConfs=num_confs, numThreads=0)
    Chem.AlignMolConformers(mol)
    logger.info(f'Generated {len(mol.GetConformers())} conformers.')
    return mol


def get_moments_of_inertia(mol: Chem.rdchem.Mol):  # TODO Rethink its usage
    """Computes the moments of inertia of the given conformers

    @param mol: rdkit mol object of the relevant molecule.
    @return numpy array 2D: The inner array contains the moments of inertia for
    the three principal axis of a given conformer. They are ordered by its value
    in ascending order. The outer tuple loops over the conformers.
    """
    from rdkit.Chem.Descriptors3D import PMI1, PMI2, PMI3

    return np.array([[PMI(mol, confId=conf) for PMI in (PMI1, PMI2, PMI3)]
                     for conf in range(mol.GetNumConformers())])


def mmff_opt_confs(mol: Chem.rdchem.Mol, max_iters=2000):
    """Optimizes the geometry of the given conformers and returns the new mol
    object and the energies of its conformers.

    @param mol: rdkit mol object of the relevant molecule.
    @param max_iters: Maximum number of geometry optimization iterations. With 0
    a single point energy calculation is performed and only the conformer
    energies are returned.
    @return mol: rdkit mol object of the optimized molecule.
    @return numpy.ndarray: Array with the energies of the optimized conformers.

    The MMFF forcefield is used for the geometry optimization in its rdkit
    implementation. With max_iters value set to 0, a single point energy
    calculation is performed and only the energies are returned. For values
    larger than 0, if the geometry does not converge for a certain conformer,
    the latter is removed from the list of conformers and its energy is not
    included in the returned list.
    """
    from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMoleculeConfs

    init_num_confs = mol.GetNumConformers()
    results = np.array(MMFFOptimizeMoleculeConfs(mol, numThreads=0,
                                                 maxIters=max_iters,
                                                 nonBondedThresh=10))

    # Remove non-converged conformers if optimization is on, ie. maxIters > 0
    # return all conformers if optimization is switched off, ie. maxIters = 0
    if max_iters > 0:
        for i, conv in enumerate(results[:, 0]):
            if conv != 0:
                mol.RemoveConformer(i)
        for i, conf in enumerate(mol.GetConformers()):
            conf.SetId(i)
        if mol.GetNumConformers() < init_num_confs:
            logger.warning(f'MMFF Geometry optimization did not comverge for at'
                           f'least one conformer. Continuing with '
                           f'{mol.GetNumConformers()} converged conformers.')
        logger.info(f'Pre-optimized conformers with MMFF.')
        return mol, np.array([res[1] for res in results if res[0] == 0])
    else:
        logger.info(f'Computed conformers energy with MMFF.')
        return np.array([res[1] for res in results])


def run_isolated(inp_vars):
    """Directs the execution of functions to obtain the conformers to adsorb

    @param inp_vars: Calculation parameters from input file.
    @return:
    """
    from modules.formats import adapt_format, confs_to_mol_list, \
        rdkit_mol_to_ase_atoms
    from modules.clustering import clustering, get_rmsd
    from modules.calculation import run_calc

    logger.info('Carrying out procedures for the isolated molecule.')
    rd_mol = adapt_format('rdkit', inp_vars['molec_file'])
    confs = gen_confs(rd_mol, inp_vars['num_conformers'])
    if inp_vars['min_confs']:
        confs, confs_ener = mmff_opt_confs(confs)
    else:
        confs_ener = mmff_opt_confs(confs, max_iters=0)
    conf_list = confs_to_mol_list(confs)
    rmsd_mtx = get_rmsd(conf_list)
    confs_moi = get_moments_of_inertia(confs)
    exemplars = clustering(rmsd_mtx)
    mol_list = confs_to_mol_list(confs, exemplars)
    ase_atms_list = [rdkit_mol_to_ase_atoms(mol) for mol in mol_list]
    run_calc('isolated', inp_vars, ase_atms_list)
