"""Functions to generate the conformers to be adsorbed and the most stable one.

functions:
confs_to_mol_list: Converts the conformers inside a rdkit mol object to a list
    of independent mol objects.
remove_C_linked_Hs: Removes hydrogens bonded to a carbon atom from a molecule.
gen_confs: Generate a number of conformers in random orientations.
get_rmsd: Gets the rmsd matrix of the conformers in a rdkit mol object.
get_moments_of_inertia: Computes moments of inertia of the given conformers.
mmff_opt_confs: Optimizes the geometry of the given conformers and returns the
    new mol object and the energies of its conformers.
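run_isolated: Directs the execution of the functions above for the isolated
    molecule: conformer generation, RMSD matrix and, on request, moments of
    inertia and MMFF energies.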
"""
import logging

import numpy as np
from rdkit.Chem import AllChem as Chem

from formats import adapt_format

# Logger shared by the functions of this module, registered under the name
# 'DockOnSurf'.
logger = logging.getLogger('DockOnSurf')


def confs_to_mol_list(mol: Chem.rdchem.Mol):
    """Splits a multi-conformer rdkit mol object into single-conformer mols.

    @param mol: rdkit mol object containing at least one conformer.
    @return list: one rdkit mol object per conformer of 'mol', in the order in
    which the conformers are stored in 'mol'.

    Every conformer is written out as a MolBlock and parsed back, so each
    element of the returned list is a mol object of its own.
    """
    # NOTE(review): MolFromMolBlock is called with its default options; confirm
    # that these keep the hydrogens if they are needed downstream.
    separate_mols = []
    for conformer in mol.GetConformers():
        mol_block = Chem.MolToMolBlock(mol, confId=conformer.GetId())
        separate_mols.append(Chem.MolFromMolBlock(mol_block))
    return separate_mols


def remove_C_linked_Hs(mol: Chem.rdchem.Mol):
    """Removes hydrogen atoms bonded to a carbon atom from a rdkit mol object.

    @param mol: rdkit mol object of the molecule with hydrogen atoms.
    @return: editable rdkit mol object (Chem.RWMol built from 'mol') without
    the hydrogen atoms linked to a carbon atom.

    The function removes the hydrogen atoms bonded to carbon atoms while
    keeping the ones bonded to other atoms or non-bonded at all. The edition is
    carried out on the Chem.RWMol built from 'mol', which is the object
    returned.
    """
    mol = Chem.RWMol(mol)

    # Walk the atoms from the highest index to the lowest one so that removing
    # an atom never changes the index of the atoms still to be visited.
    for atm_idx in range(mol.GetNumAtoms() - 1, -1, -1):
        atom = mol.GetAtomWithIdx(atm_idx)
        if atom.GetAtomicNum() != 1:
            continue
        # Decide once per hydrogen and remove it at most once: the atom handle
        # must not be used again after the atom has been removed.
        if any(neigh.GetAtomicNum() == 6 for neigh in atom.GetNeighbors()):
            mol.RemoveAtom(atm_idx)
    return mol


def gen_confs(mol: Chem.rdchem.Mol, num_confs: int, local_min=True):
    """Embeds a set of conformers into the given molecule.

    @param mol: rdkit mol object of the molecule to be adsorbed.
    @param num_confs: number of conformers requested from the embedding.
    @param local_min: bool: whether every embedded conformer is additionally
    relaxed with the UFF forcefield.
    @return: mol: the same rdkit mol object passed in, now holding the
    generated conformers.

    The conformers are embedded with rdkit's EmbedMultipleConfs. When
    'local_min' is True each embedded conformer goes through a UFF geometry
    optimisation. Finally the conformers are aligned to each other and the
    number of conformers stored in the molecule is logged.
    """
    logger.debug('Generating Conformers')

    embedded_ids = Chem.EmbedMultipleConfs(mol, numConfs=num_confs,
                                           numThreads=0)
    if local_min:
        for conf_id in embedded_ids:
            Chem.UFFOptimizeMolecule(mol, confId=conf_id)

    Chem.AlignMolConformers(mol)
    num_generated = len(mol.GetConformers())
    logger.info(f'Generated {num_generated} conformers')
    return mol


def get_rmsd(mol: Chem.rdchem.Mol, remove_Hs="c"):
    """Computes the rmsd matrix of the conformers in a rdkit mol object.

    @param mol: rdkit mol object containing at least two conformers.
    @param remove_Hs: bool or str, selects which hydrogen atoms are left out
    of the rmsd calculation:
        - False (or any other falsy value): no hydrogen is removed.
        - True or "all" (case insensitive): all hydrogens are removed.
        - "c" (case insensitive, default): only the hydrogens bonded to a
          carbon atom are removed.
      Any other string leaves the hydrogens untouched and logs a warning.
    @return rmsd_matrix: Symmetric square numpy array containing the rmsd
    values of every pair of conformers, with zeros on the diagonal. Rows and
    columns follow the order in which the conformers are stored in 'mol'.
    @raise ValueError: if 'mol' holds fewer than two conformers.

    The RMSD values of every pair of conformers is computed, stored in matrix
    form and returned back. The calculation of rmsd values can take into
    account all hydrogens, none, or only the ones not linked to carbon atoms.
    """
    if mol.GetNumConformers() < 2:
        err = "The provided molecule has less than 2 conformers"
        logger.error(err)
        raise ValueError(err)

    # Normalise the option first so that the string modes are really compared
    # against each other: a non-empty string is always truthy, hence it cannot
    # be tested for truthiness before looking at its content.
    if isinstance(remove_Hs, str):
        h_mode = remove_Hs.lower()
    else:
        h_mode = "all" if remove_Hs else ""

    if h_mode == "all":
        mol = Chem.RemoveHs(mol)
    elif h_mode == "c":
        mol = remove_C_linked_Hs(mol)
    elif h_mode:
        logger.warning(f"Unknown value '{remove_Hs}' for 'remove_Hs'. "
                       f"Keeping all hydrogen atoms.")

    # Use the actual conformer ids instead of assuming they are 0..N-1.
    conf_ids = [conf.GetId() for conf in mol.GetConformers()]
    num_confs = len(conf_ids)
    rmsd_mtx = np.zeros((num_confs, num_confs))
    for row, ref_id in enumerate(conf_ids):
        for col in range(row + 1, num_confs):
            rmsd = Chem.GetBestRMS(mol, mol, prbId=conf_ids[col],
                                   refId=ref_id)
            rmsd_mtx[row][col] = rmsd
            rmsd_mtx[col][row] = rmsd

    return rmsd_mtx


def get_moments_of_inertia(mol: Chem.rdchem.Mol):
    """Collects the principal moments of inertia of every conformer.

    @param mol: rdkit mol object of the relevant molecule.
    @return numpy array 2D: one row per conformer; each row holds the values
    given by rdkit's PMI1, PMI2 and PMI3 descriptors, in that order (smallest
    to largest principal moment as defined by rdkit's Descriptors3D).

    Conformers are addressed by the ids 0 to N-1, N being the number of
    conformers stored in 'mol'.
    """
    from rdkit.Chem.Descriptors3D import PMI1, PMI2, PMI3

    pmi_funcs = (PMI1, PMI2, PMI3)
    moments = []
    for conf_id in range(mol.GetNumConformers()):
        moments.append([pmi(mol, confId=conf_id) for pmi in pmi_funcs])
    return np.array(moments)


def mmff_opt_confs(mol: Chem.rdchem.Mol, max_iters=2000):
    """Optimizes the geometry of the given conformers and returns the new mol
    object and the energies of its conformers.

    @param mol: rdkit mol object of the relevant molecule.
    @param max_iters: Maximum number of geometry optimization iterations. With 0
    a single point energy calculation is performed and only the conformer
    energies are returned.
    @return mol: rdkit mol object of the optimized molecule (only returned when
    max_iters > 0). It is the same object passed in, modified in place.
    @return numpy.ndarray: Array with the energies of the optimized conformers.

    The MMFF forcefield is used for the geometry optimization in its rdkit
    implementation. With max_iters value set to 0, a single point energy
    calculation is performed and only the energies are returned. For values
    larger than 0, if the geometry does not converge for a certain conformer,
    the latter is removed from the list of conformers and its energy is not
    included in the returned list. The remaining conformers are renumbered
    consecutively starting from 0.
    """
    from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMoleculeConfs

    # Store the real conformer ids before optimizing: they are needed to
    # remove the right conformers even if the ids are not 0..N-1.
    conf_ids = [conf.GetId() for conf in mol.GetConformers()]
    # Each result is a (status, energy) pair, one per conformer; status 0
    # means the optimization converged. Keeping the plain list (instead of a
    # 2D array) also works when the molecule holds no conformer at all.
    results = list(MMFFOptimizeMoleculeConfs(mol, numThreads=0,
                                             maxIters=max_iters,
                                             nonBondedThresh=10))

    # Remove non-converged conformers if optimization is on, ie. maxIters > 0
    # return all conformers if optimization is switched off, ie. maxIters = 0
    if max_iters > 0:
        failed_ids = [conf_id
                      for conf_id, (status, _) in zip(conf_ids, results)
                      if status != 0]
        for conf_id in failed_ids:
            mol.RemoveConformer(conf_id)
        # Renumber the surviving conformers so their ids are consecutive.
        for new_id, conf in enumerate(mol.GetConformers()):
            conf.SetId(new_id)
        if mol.GetNumConformers() < len(conf_ids):
            logger.warning(f'MMFF Geometry optimization did not converge for '
                           f'at least one conformer. Continuing with '
                           f'{mol.GetNumConformers()} converged conformers')
        logger.info('Optimized conformers with MMFF.')
        return mol, np.array([energy for status, energy in results
                              if status == 0])
    else:
        logger.info('Computed conformers energy with MMFF.')
        return np.array([energy for _, energy in results])


def run_isolated(inp_vars):
    """Drives the generation and characterisation of the conformers of the
    isolated molecule.

    @param inp_vars: mapping with the run parameters. The keys read here are
    'molec_file', 'num_conformers', 'cluster_magns' and, only when 'energy' is
    among the cluster magnitudes, 'min_confs'.
    @return: None. The computed quantities (rmsd matrix, moments of inertia
    and energies) are kept in local variables only.

    The molecule file is converted to a rdkit mol object, conformers are
    generated and their rmsd matrix is computed. Moments of inertia are
    computed if 'moi' is among the cluster magnitudes. MMFF energies are
    computed if 'energy' is among them, optimizing the geometries first when
    'min_confs' is set.
    """
    logger.info('Carrying out procedures for the isolated molecule')
    molecule = adapt_format('rdkit', inp_vars['molec_file'])
    conformers = gen_confs(molecule, inp_vars['num_conformers'])
    rmsd_mtx = get_rmsd(conformers)

    magnitudes = inp_vars['cluster_magns']
    if 'moi' in magnitudes:
        confs_moi = get_moments_of_inertia(conformers)

    if 'energy' in magnitudes:
        optimize_geoms = inp_vars['min_confs']
        if optimize_geoms:
            conformers, confs_eners = mmff_opt_confs(conformers)
        else:
            # Single point calculation: only the energies are returned.
            confs_eners = mmff_opt_confs(conformers, max_iters=0)

    # NOTE: the clustering step is not called from here yet.
