import logging
import numpy as np
import ase

logger = logging.getLogger('DockOnSurf')


def assign_prop(atoms: ase.Atoms, prop_name: str, prop_val):  # TODO Needed?
    atoms.info[prop_name] = prop_val


def select_confs(orig_conf_list: list, magns: list, num_sel: int, code: str):
    """Takes a list ase.Atoms and selects the most different magnitude-wise.

    Given a list of ase.Atoms objects and a list of magnitudes, it selects a
    number of the most different conformers according to every magnitude
    specified.
    @param orig_conf_list: list of ase.Atoms objects to select among.
    @param magns: list of str with the names of the magnitudes to use for the
        conformer selection.
        Supported magnitudes: 'energy', 'moi'.
    @param num_sel: number of conformers to select for every of the magnitudes.
    @param code: The code that generated the magnitude information.
         Supported codes: See formats.py
    @return: list of the selected ase.Atoms objects.
    """
    from copy import deepcopy
    from modules.formats import read_energies

    conf_list = deepcopy(orig_conf_list)
    selected_ids = []
    if num_sel >= len(conf_list):
        logger.warning('Number of conformers per magnitude is equal or larger '
                       'than the total number of conformers. Using all '
                       f'available conformers: {len(conf_list)}.')
        return conf_list

    # Read properties
    if 'energy' in magns:
        conf_enrgs = read_energies(code, 'isolated')
    if 'moi' in magns:
        mois = np.array([conf.get_moments_of_inertia() for conf in conf_list])

    # Assign values
    for i, conf in enumerate(conf_list):
        assign_prop(conf, 'idx', i)
        if 'energy' in magns:
            assign_prop(conf, 'energy', conf_enrgs[i])
        if 'moi' in magns:
            assign_prop(conf, 'moi', mois[i, 2])

    # pick ids
    for magn in magns:
        sorted_list = sorted(conf_list, key=lambda conf: abs(conf.info[magn]))
        if sorted_list[-1].info['idx'] not in selected_ids:
            selected_ids.append(sorted_list[-1].info['idx'])
        if num_sel > 1:
            for i in range(0, len(sorted_list) - 1,
                           len(conf_list) // (num_sel - 1)):
                if sorted_list[i].info['idx'] not in selected_ids:
                    selected_ids.append(sorted_list[i].info['idx'])

    logger.info(f'Selected {len(selected_ids)} conformers for adsorption.')
    return [conf_list[idx] for idx in selected_ids]


def get_vect_angle(v1, v2, degrees=False):
    """Computes the angle between two vectors.

    @param v1: The first vector.
    @param v2: The second vector.
    @param degrees: Whether the result should be in radians (True) or in
        degrees (False).
    @return: The angle in radians if degrees = False, or in degrees if
        degrees =True
    """
    v1_u = v1 / np.linalg.norm(v1)
    v2_u = v2 / np.linalg.norm(v2)
    angle = np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))
    return angle if not degrees else angle * 180 / np.pi


def vect_avg(vects):
    """Computes the element-wise mean of a set of vectors.

    @param vects: list of lists-like: containing the vectors (num_vectors,
        length_vector).
    @return: vector average computed doing the element-wise mean.
    """
    from utilities import try_command
    err = "vect_avg parameter vects must be a list-like, able to be converted" \
          " np.array"
    array = try_command(np.array, [(ValueError, err)], vects)
    if len(array.shape) == 1:
        return array
    else:
        num_vects = array.shape[1]
        return np.array([np.average(array[:, i]) for i in range(num_vects)])


def get_atom_coords(atoms: ase.Atoms, ctrs_list=None):
    """Gets the coordinates of the specified indices from a ase.Atoms object.

    Given an ase.Atoms object and a list of atom indices specified in ctrs_list
    it gets the coordinates of the specified atoms. If the element in the
    ctrs_list is not an index but yet a list of indices, it computes the
    element-wise mean of the coordinates of the atoms specified in the inner
    list.
    @param atoms: ase.Atoms object for which to obtain the coordinates of.
    @param ctrs_list: list of (indices/list of indices) of the atoms for which
                      the coordinates should be extracted.
    @return: np.ndarray of atomic coordinates.
    """
    coords = []
    err = "'ctrs_list' argument must be an integer, a list of integers or a " \
          "list of lists of integers. Every integer must be in the range " \
          "[0, num_atoms)"
    if ctrs_list is None:
        ctrs_list = range(len(atoms))
    elif isinstance(ctrs_list, int):
        if ctrs_list not in range(len(atoms)):
            logger.error(err)
            raise ValueError(err)
        return atoms[ctrs_list].position
    for elem in ctrs_list:
        if isinstance(elem, list):
            coords.append(vect_avg([atoms[c].position for c in elem]))
        elif isinstance(elem, int):
            coords.append(atoms[elem].position)
        else:

            logger.error(err)
            raise ValueError
    return np.array(coords)


def add_adsorbate(slab, adsorbate, site_coord, ctr_coord, height, offset=None,
                  norm_vect=(0, 0, 1)):
    """Add an adsorbate to a surface.

    This function extends the functionality of ase.build.add_adsorbate
    (https://wiki.fysik.dtu.dk/ase/ase/build/surface.html#ase.build.add_adsorbate)
    by enabling to change the z coordinate and the axis perpendicular to the
    surface.
    @param slab: ase.Atoms object containing the coordinates of the surface
    @param adsorbate: ase.Atoms object containing the coordinates of the
        adsorbate.
    @param site_coord: The coordinates of the adsorption site on the surface.
    @param ctr_coord: The coordinates of the adsorption center in the molecule.
    @param height: The height above the surface where to adsorb.
    @param offset: Offsets the adsorbate by a number of unit cells. Mostly
        useful when adding more than one adsorbate.
    @param norm_vect: The vector perpendicular to the surface.
    """
    from copy import deepcopy
    info = slab.info.get('adsorbate_info', {})
    pos = np.array([0.0, 0.0, 0.0])  # part of absolute coordinates
    spos = np.array([0.0, 0.0, 0.0])  # part relative to unit cell
    norm_vect_u = np.array(norm_vect) / np.linalg.norm(norm_vect)
    if offset is not None:
        spos += np.asarray(offset, float)
    if isinstance(site_coord, str):
        # A site-name:
        if 'sites' not in info:
            raise TypeError('If the atoms are not made by an ase.build '
                            'function, position cannot be a name.')
        if site_coord not in info['sites']:
            raise TypeError('Adsorption site %s not supported.' % site_coord)
        spos += info['sites'][site_coord]
    else:
        pos += site_coord
    if 'cell' in info:
        cell = info['cell']
    else:
        cell = slab.get_cell()
    pos += np.dot(spos, cell)
    # Convert the adsorbate to an Atoms object
    if isinstance(adsorbate, ase.Atoms):
        ads = deepcopy(adsorbate)
    elif isinstance(adsorbate, ase.Atom):
        ads = ase.Atoms([adsorbate])
    else:
        # Assume it is a string representing a single Atom
        ads = ase.Atoms([ase.Atom(adsorbate)])
    pos += height * norm_vect_u
    # Move adsorbate into position
    ads.translate(pos - ctr_coord)
    # Attach the adsorbate
    slab.extend(ads)


def check_collision(slab_molec, slab_num_atoms, min_height, vect, nn_slab=0,
                    nn_molec=0, coll_coeff=0.9):
    """Checks whether a slab and a molecule collide or not.

    @param slab_molec: The system of adsorbate-slab for which to detect if there
        are collisions.
    @param nn_slab: Number of neigbors in the surface.
    @param nn_molec: Number of neighbors in the molecule.
    @param coll_coeff: The coefficient that multiplies the covalent radius of
        atoms resulting in a distance that two atoms being closer to that is
        considered as atomic collision.
    @param slab_num_atoms: Number of atoms of the bare slab.
    @param min_height: The minimum height atoms can have to not be considered as
        colliding.
    @param vect: The vector perpendicular to the slab.
    @return: bool, whether the surface and the molecule collide.
    """
    from ase.neighborlist import natural_cutoffs, neighbor_list
    if min_height is not False:
        cart_axes = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0],
                     [-1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0]]
        if vect.tolist() in cart_axes:
            for atom in slab_molec[slab_num_atoms:]:
                for i, coord in enumerate(vect):
                    if coord == 0:
                        continue
                    if atom.position[i] * coord < min_height:
                        return True
            return False
    else:
        slab_molec_cutoffs = natural_cutoffs(slab_molec, mult=coll_coeff)
        slab_molec_nghbs = len(neighbor_list("i", slab_molec, slab_molec_cutoffs))
        if slab_molec_nghbs > nn_slab + nn_molec:
            return True
        else:
            return False


def correct_coll(molec, slab, ctr_coord, site_coord, num_pts,
                 min_coll_height, norm_vect, slab_nghbs, molec_nghbs,
                 coll_coeff):
    # TODO Rethink this function
    """Tries to adsorb a molecule on a slab trying to avoid collisions by doing
    small rotations.

    @param molec: ase.Atoms object of the molecule to adsorb
    @param slab: ase.Atoms object of the surface on which to adsorb the
        molecule
    @param ctr_coord: The coordinates of the molecule to use as adsorption
        center.
    @param site_coord: The coordinates of the surface on which to adsorb the
        molecule
    @param num_pts: Number on which to sample Euler angles.
    @param min_coll_height: The lowermost height for which to detect a collision
    @param norm_vect: The vector perpendicular to the surface.
    @param slab_nghbs: Number of neigbors in the surface.
    @param molec_nghbs: Number of neighbors in the molecule.
    @param coll_coeff: The coefficient that multiplies the covalent radius of
        atoms resulting in a distance that two atoms being closer to that is
        considered as atomic collision.
    @return collision: bool, whether the structure generated has collisions
        between slab and adsorbate.
    """
    from copy import deepcopy
    slab_num_atoms = len(slab)
    collision = True
    max_corr = 6  # Should be an even number
    d_angle = 180 / ((max_corr / 2.0) * num_pts)
    num_corr = 0
    while collision and num_corr <= max_corr:
        k = num_corr * (-1) ** num_corr
        slab_molec = deepcopy(slab)
        molec.euler_rotate(k * d_angle, k * d_angle / 2, k * d_angle,
                           center=ctr_coord)
        add_adsorbate(slab_molec, molec, site_coord, ctr_coord, 2.5,
                      norm_vect=norm_vect)
        collision = check_collision(slab_molec, slab_num_atoms, min_coll_height,
                                    norm_vect, slab_nghbs, molec_nghbs,
                                    coll_coeff)
        num_corr += 1
    return slab_molec, collision


def ads_euler(orig_molec, slab, ctr_coord, site_coord, num_pts,
              min_coll_height, coll_coeff, norm_vect, slab_nghbs, molec_nghbs):
    """Generates adsorbate-surface structures by sampling over Euler angles.

    This function generates a number of adsorbate-surface structures at
    different orientations of the adsorbate sampled at multiple Euler (zxz)
    angles.
    @param orig_molec: ase.Atoms object of the molecule to adsorb
    @param slab: ase.Atoms object of the surface on which to adsorb the molecule
    @param ctr_coord: The coordinates of the molecule to use as adsorption
        center.
    @param site_coord: The coordinates of the surface on which to adsorb the
        molecule
    @param num_pts: Number on which to sample Euler angles.
    @param min_coll_height: The lowermost height for which to detect a collision
    @param coll_coeff: The coefficient that multiplies the covalent radius of
        atoms resulting in a distance that two atoms being closer to that is
        considered as atomic collision.
    @param norm_vect: The vector perpendicular to the surface.
    @param slab_nghbs: Number of neigbors in the surface.
    @param molec_nghbs: Number of neighbors in the molecule.
    @return: list of ase.Atoms object conatining all the orientations of a given
        conformer
    """
    from copy import deepcopy
    slab_ads_list = []
    # rotation around z
    for alpha in np.arange(0, 360, 360 / num_pts):
        # rotation around x'
        for beta in np.arange(0, 180, 180 / num_pts):
            # rotation around z'
            for gamma in np.arange(0, 360, 360 / num_pts):
                molec = deepcopy(orig_molec)
                molec.euler_rotate(alpha, beta, gamma, center=ctr_coord)
                slab_molec, collision = correct_coll(molec, slab,
                                                     ctr_coord, site_coord,
                                                     num_pts, min_coll_height,
                                                     norm_vect,
                                                     slab_nghbs, molec_nghbs,
                                                     coll_coeff)
                if not collision:
                    slab_ads_list.append(slab_molec)

    return slab_ads_list


def ads_chemcat(site, ctr, pts_angle):
    return "TO IMPLEMENT"


def adsorb_confs(conf_list, surf, molec_ctrs, sites, algo, num_pts, neigh_ctrs,
                 norm_vect, min_coll_height, coll_coeff):
    """Generates a number of adsorbate-surface structure coordinates.

    Given a list of conformers, a surface, a list of atom indices (or list of
    list of indices) of both the surface and the adsorbate, it generates a
    number of adsorbate-surface structures for every possible combination of
    them at different orientations.
    @param conf_list: list of ase.Atoms of the different conformers
    @param surf: the ase.Atoms object of the surface
    @param molec_ctrs: the list atom indices of the adsorbate.
    @param sites: the list of atom indices of the surface.
    @param algo: the algorithm to use for the generation of adsorbates.
    @param num_pts: the number of points per angle orientation to sample
    @param neigh_ctrs: the indices of the neighboring atoms to the adsorption
        atoms.
    @param norm_vect: The vector perpendicular to the surface.
    @param min_coll_height: The lowermost height for which to detect a collision
    @param coll_coeff: The coefficient that multiplies the covalent radius of
        atoms resulting in a distance that two atoms being closer to that is
        considered as atomic collision.
    @return: list of ase.Atoms for the adsorbate-surface structures
    """
    from ase.neighborlist import natural_cutoffs, neighbor_list
    surf_ads_list = []
    sites_coords = get_atom_coords(surf, sites)
    if min_coll_height is not False:
        surf_cutoffs = natural_cutoffs(surf, mult=coll_coeff)
        surf_nghbs = len(neighbor_list("i", surf, surf_cutoffs))
    else:
        surf_nghbs = 0
    for conf in conf_list:
        molec_ctr_coords = get_atom_coords(conf, molec_ctrs)
        molec_neigh_coords = get_atom_coords(conf, neigh_ctrs)
        if min_coll_height is not False:
            conf_cutoffs = natural_cutoffs(conf, mult=coll_coeff)
            molec_nghbs = len(neighbor_list("i", conf, conf_cutoffs))
        else:
            molec_nghbs = 0
        for site in sites_coords:
            for ctr in molec_ctr_coords:
                if algo == 'euler':
                    surf_ads_list.extend(ads_euler(conf, surf, ctr, site,
                                                   num_pts, min_coll_height,
                                                   coll_coeff, norm_vect,
                                                   surf_nghbs, molec_nghbs))
                elif algo == 'chemcat':
                    surf_ads_list.extend(ads_chemcat(site, ctr, num_pts))
    return surf_ads_list


def run_screening(inp_vars):
    """Carry out the screening of adsorbate coordinates on a surface

    @param inp_vars: Calculation parameters from input file.
    """
    import os
    from modules.formats import read_coords, adapt_format
    from modules.calculation import run_calc

    if not os.path.isdir("isolated"):
        err = "'isolated' directory not found. It is needed in order to carry "
        "out the screening of structures to be adsorbed"
        logger.error(err)
        raise ValueError(err)

    conf_list = read_coords(inp_vars['code'], 'isolated', 'ase')
    logger.info(f"Found {len(conf_list)} structures of isolated conformers.")
    selected_confs = select_confs(conf_list, inp_vars['select_magns'],
                                  inp_vars['confs_per_magn'],
                                  inp_vars['code'])
    surf = adapt_format('ase', inp_vars['surf_file'])
    surf_ads_list = adsorb_confs(selected_confs, surf,
                                 inp_vars['molec_ads_ctrs'], inp_vars['sites'],
                                 inp_vars['ads_algo'],
                                 inp_vars['sample_points_per_angle'],
                                 inp_vars['molec_neigh_ctrs'],
                                 inp_vars['surf_norm_vect'],
                                 inp_vars['min_coll_height'],
                                 inp_vars['collision_threshold'])
    logger.info(f'Generated {len(surf_ads_list)} adsorbate-surface atomic '
                f'configurations, to carry out a calculation of.')
    run_calc('screening', inp_vars, surf_ads_list)
