import logging
import numpy as np
import ase

logger = logging.getLogger('DockOnSurf')


def assign_prop(atoms: ase.Atoms, prop_name: str, prop_val):  # TODO Needed?
    atoms.info[prop_name] = prop_val


def select_confs(orig_conf_list: list, magns: list, num_sel: int, code: str):
    """Takes a list ase.Atoms and selects the most different magnitude-wise.

    Given a list of ase.Atoms objects and a list of magnitudes, it selects a
    number of the most different conformers according to every magnitude
    specified.
    @param orig_conf_list: list of ase.Atoms objects to select among.
    @param magns: list of str with the names of the magnitudes to use for the
        conformer selection.
        Supported magnitudes: 'energy', 'moi'.
    @param num_sel: number of conformers to select for every of the magnitudes.
    @param code: The code that generated the magnitude information.
         Supported codes: See formats.py
    @return: list of the selected ase.Atoms objects.
    """
    from copy import deepcopy
    from modules.formats import read_energies

    conf_list = deepcopy(orig_conf_list)
    selected_ids = []
    if num_sel >= len(conf_list):
        logger.warning('Number of conformers per magnitude is equal or larger '
                       'than the total number of conformers. Using all '
                       f'available conformers: {len(conf_list)}.')
        return conf_list

    # Read properties
    if 'energy' in magns:
        conf_enrgs = read_energies(code, 'isolated')
    if 'moi' in magns:
        mois = np.array([conf.get_moments_of_inertia() for conf in conf_list])

    # Assign values
    for i, conf in enumerate(conf_list):
        assign_prop(conf, 'idx', i)
        if 'energy' in magns:
            assign_prop(conf, 'energy', conf_enrgs[i])
        if 'moi' in magns:
            assign_prop(conf, 'moi', mois[i, 2])

    # pick ids
    for magn in magns:
        sorted_list = sorted(conf_list, key=lambda conf: abs(conf.info[magn]))
        if sorted_list[-1].info['idx'] not in selected_ids:
            selected_ids.append(sorted_list[-1].info['idx'])
        if num_sel > 1:
            for i in range(0, len(sorted_list) - 1,
                           len(conf_list) // (num_sel - 1)):
                if sorted_list[i].info['idx'] not in selected_ids:
                    selected_ids.append(sorted_list[i].info['idx'])

    logger.info(f'Selected {len(selected_ids)} conformers for adsorption.')
    return [conf_list[idx] for idx in selected_ids]


def get_vect_angle(v1, v2, degrees=False):
    """Computes the angle between two vectors.

    @param v1: The first vector.
    @param v2: The second vector.
    @param degrees: Whether the result should be in radians (True) or in
        degrees (False).
    @return: The angle in radians if degrees = False, or in degrees if
        degrees =True
    """
    v1_u = v1 / np.linalg.norm(v1)
    v2_u = v2 / np.linalg.norm(v2)
    angle = np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))
    return angle if not degrees else angle * 180 / np.pi


def vect_avg(vects):
    """Computes the element-wise mean of a set of vectors.

    @param vects: list of lists-like: containing the vectors (num_vectors,
        length_vector).
    @return: vector average computed doing the element-wise mean.
    """
    from utilities import try_command
    err = "vect_avg parameter vects must be a list-like, able to be converted" \
          " np.array"
    array = try_command(np.array, [(ValueError, err)], vects)
    if len(array.shape) == 1:
        return array
    else:
        num_vects = array.shape[1]
        return np.array([np.average(array[:, i]) for i in range(num_vects)])


def get_atom_coords(atoms: ase.Atoms, ctrs_list=None):
    """Gets the coordinates of the specified indices from a ase.Atoms object.

    Given an ase.Atoms object and a list of atom indices specified in ctrs_list
    it gets the coordinates of the specified atoms. If the element in the
    ctrs_list is not an index but yet a list of indices, it computes the
    element-wise mean of the coordinates of the atoms specified in the inner
    list.
    @param atoms: ase.Atoms object for which to obtain the coordinates of.
    @param ctrs_list: list of (indices/list of indices) of the atoms for which
                      the coordinates should be extracted.
    @return: np.ndarray of atomic coordinates.
    """
    coords = []
    err = "'ctrs_list' argument must be an integer, a list of integers or a " \
          "list of lists of integers. Every integer must be in the range " \
          "[0, num_atoms)"
    if ctrs_list is None:
        ctrs_list = range(len(atoms))
    elif isinstance(ctrs_list, int):
        if ctrs_list not in range(len(atoms)):
            logger.error(err)
            raise ValueError(err)
        return atoms[ctrs_list].position
    for elem in ctrs_list:
        if isinstance(elem, list):
            coords.append(vect_avg([atoms[c].position for c in elem]))
        elif isinstance(elem, int):
            coords.append(atoms[elem].position)
        else:

            logger.error(err)
            raise ValueError
    return np.array(coords)


def add_adsorbate(slab, adsorbate, site_coord, ctr_coord, height, offset=None,
                  norm_vect=(0, 0, 1)):
    """Add an adsorbate to a surface.

    This function extends the functionality of ase.build.add_adsorbate
    (https://wiki.fysik.dtu.dk/ase/ase/build/surface.html#ase.build.add_adsorbate)
    by enabling to change the z coordinate and the axis perpendicular to the
    surface.
    @param slab: ase.Atoms object containing the coordinates of the surface
    @param adsorbate: ase.Atoms object containing the coordinates of the
        adsorbate.
    @param site_coord: The coordinates of the adsorption site on the surface.
    @param ctr_coord: The coordinates of the adsorption center in the molecule.
    @param height: The height above the surface where to adsorb.
    @param offset: Offsets the adsorbate by a number of unit cells. Mostly
        useful when adding more than one adsorbate.
    @param norm_vect: The vector perpendicular to the surface.
    """
    from copy import deepcopy
    info = slab.info.get('adsorbate_info', {})
    pos = np.array([0.0, 0.0, 0.0])  # part of absolute coordinates
    spos = np.array([0.0, 0.0, 0.0])  # part relative to unit cell
    norm_vect_u = np.array(norm_vect) / np.linalg.norm(norm_vect)
    if offset is not None:
        spos += np.asarray(offset, float)
    if isinstance(site_coord, str):
        # A site-name:
        if 'sites' not in info:
            raise TypeError('If the atoms are not made by an ase.build '
                            'function, position cannot be a name.')
        if site_coord not in info['sites']:
            raise TypeError('Adsorption site %s not supported.' % site_coord)
        spos += info['sites'][site_coord]
    else:
        pos += site_coord
    if 'cell' in info:
        cell = info['cell']
    else:
        cell = slab.get_cell()
    pos += np.dot(spos, cell)
    # Convert the adsorbate to an Atoms object
    if isinstance(adsorbate, ase.Atoms):
        ads = deepcopy(adsorbate)
    elif isinstance(adsorbate, ase.Atom):
        ads = ase.Atoms([adsorbate])
    else:
        # Assume it is a string representing a single Atom
        ads = ase.Atoms([ase.Atom(adsorbate)])
    pos += height * norm_vect_u
    # Move adsorbate into position
    ads.translate(pos - ctr_coord)
    # Attach the adsorbate
    slab.extend(ads)


def check_collision(slab_molec, slab_num_atoms, min_height, vect, nn_slab=0,
                    nn_molec=0, coll_coeff=0.9):
    """Checks whether a slab and a molecule collide or not.

    @param slab_molec: The system of adsorbate-slab for which to detect if there
        are collisions.
    @param nn_slab: Number of neigbors in the surface.
    @param nn_molec: Number of neighbors in the molecule.
    @param coll_coeff: The coefficient that multiplies the covalent radius of
        atoms resulting in a distance that two atoms being closer to that is
        considered as atomic collision.
    @param slab_num_atoms: Number of atoms of the bare slab.
    @param min_height: The minimum height atoms can have to not be considered as
        colliding.
    @param vect: The vector perpendicular to the slab.
    @return: bool, whether the surface and the molecule collide.
    """
    from ase.neighborlist import natural_cutoffs, neighbor_list
    if min_height is not False:
        cart_axes = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0],
                     [-1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0]]
        if vect.tolist() in cart_axes:
            for atom in slab_molec[slab_num_atoms:]:
                for i, coord in enumerate(vect):
                    if coord == 0:
                        continue
                    if atom.position[i] * coord < min_height:
                        return True
            return False
    else:
        slab_molec_cutoffs = natural_cutoffs(slab_molec, mult=coll_coeff)
        slab_molec_nghbs = len(
            neighbor_list("i", slab_molec, slab_molec_cutoffs))
        if slab_molec_nghbs > nn_slab + nn_molec:
            return True
        else:
            return False


def dissociate_h(slab_molec_orig, h_idx, num_atoms_slab, neigh_cutoff=1):
    # TODO rethink
    """Tries to dissociate a H from the molecule and adsorbs it on the slab.

    Tries to dissociate a H atom from the molecule and adsorb in on top of the
    surface if the distance is shorter than two times the neigh_cutoff value.
    @param slab_molec_orig: The ase.Atoms object of the system adsorbate-slab.
    @param h_idx: The index of the hydrogen atom to carry out adsorption of.
    @param num_atoms_slab: The number of atoms of the slab without adsorbate.
    @param neigh_cutoff: half the maximum distance between the surface and the
        H for it to carry out dissociation.
    @return: An ase.Atoms object of the system adsorbate-surface with H
    """
    from copy import deepcopy
    from ase.neighborlist import NeighborList
    slab_molec = deepcopy(slab_molec_orig)
    cutoffs = len(slab_molec) * [neigh_cutoff]
    nl = NeighborList(cutoffs, self_interaction=False, bothways=True)
    nl.update(slab_molec)
    surf_h_vect = np.array([np.infty] * 3)
    for neigh_idx in nl.get_neighbors(h_idx)[0]:
        if neigh_idx < num_atoms_slab:
            dist = np.linalg.norm(slab_molec[neigh_idx].position -
                                  slab_molec[h_idx].position)
            if dist < np.linalg.norm(surf_h_vect):
                surf_h_vect = slab_molec[neigh_idx].position \
                              - slab_molec[h_idx].position
    if np.linalg.norm(surf_h_vect) != np.infty:
        trans_vect = surf_h_vect - surf_h_vect / np.linalg.norm(surf_h_vect)
        slab_molec[h_idx].position = slab_molec[h_idx].position + trans_vect
        return slab_molec


def dissociation(slab_molec, disso_atoms, num_atoms_slab):
    # TODO rethink
    # TODO multiple dissociation
    """Decides which H atoms to dissociate according to a list of atoms.

    Given a list of chemical symbols or atom indices it checks for every atom
    or any of its neighbor if it's a H and calls dissociate_h to try to carry
    out dissociation of that H. For atom indices, it checks both whether
    the atom index or its neighbors are H, for chemical symbols, it only checks
    if there is a neighbor H.
    @param slab_molec: The ase.Atoms object of the system adsorbate-slab.
    @param disso_atoms: The indices or chemical symbols of the atoms
    @param num_atoms_slab:
    @return:
    """
    from ase.neighborlist import natural_cutoffs, NeighborList
    molec = slab_molec[num_atoms_slab:]
    cutoffs = natural_cutoffs(molec)
    nl = NeighborList(cutoffs, self_interaction=False, bothways=True)
    nl.update(molec)
    disso_structs = []
    for el in disso_atoms:
        if isinstance(el, int):
            if molec[el].symbol == 'H':
                disso_struct = dissociate_h(slab_molec, el + num_atoms_slab,
                                            num_atoms_slab)
                if disso_struct is not None:
                    disso_structs.append(disso_struct)
            else:
                for neigh_idx in nl.get_neighbors(el)[0]:
                    if molec[neigh_idx].symbol == 'H':
                        disso_struct = dissociate_h(slab_molec, neigh_idx +
                                                    num_atoms_slab,
                                                    num_atoms_slab)
                        if disso_struct is not None:
                            disso_structs.append(disso_struct)
        else:
            for atom in molec:
                if atom.symbol.lower() == el.lower():
                    for neigh_idx in nl.get_neighbors(atom.index)[0]:
                        if molec[neigh_idx].symbol == 'H':
                            disso_struct = dissociate_h(slab_molec, neigh_idx \
                                                        + num_atoms_slab,
                                                        num_atoms_slab)
                            if disso_struct is not None:
                                disso_structs.append(disso_struct)
    return disso_structs


def correct_coll(molec, slab, ctr_coord, site_coord, num_pts,
                 min_coll_height, norm_vect, slab_nghbs, molec_nghbs,
                 coll_coeff, height=2.5):
    # TODO Rethink this function
    """Tries to adsorb a molecule on a slab trying to avoid collisions by doing
    small rotations.

    @param molec: ase.Atoms object of the molecule to adsorb
    @param slab: ase.Atoms object of the surface on which to adsorb the
        molecule
    @param ctr_coord: The coordinates of the molecule to use as adsorption
        center.
    @param site_coord: The coordinates of the surface on which to adsorb the
        molecule
    @param num_pts: Number on which to sample Euler angles.
    @param min_coll_height: The lowermost height for which to detect a collision
    @param norm_vect: The vector perpendicular to the surface.
    @param slab_nghbs: Number of neigbors in the surface.
    @param molec_nghbs: Number of neighbors in the molecule.
    @param coll_coeff: The coefficient that multiplies the covalent radius of
        atoms resulting in a distance that two atoms being closer to that is
        considered as atomic collision.
    @param height: Height on which to try adsorption
    @return collision: bool, whether the structure generated has collisions
        between slab and adsorbate.
    """
    from copy import deepcopy
    slab_num_atoms = len(slab)
    collision = True
    max_corr = 6  # Should be an even number
    d_angle = 180 / ((max_corr / 2.0) * num_pts)
    num_corr = 0
    while collision and num_corr <= max_corr:
        k = num_corr * (-1) ** num_corr
        slab_molec = deepcopy(slab)
        molec.euler_rotate(k * d_angle, k * d_angle / 2, k * d_angle,
                           center=ctr_coord)
        add_adsorbate(slab_molec, molec, site_coord, ctr_coord, height,
                      norm_vect=norm_vect)
        collision = check_collision(slab_molec, slab_num_atoms, min_coll_height,
                                    norm_vect, slab_nghbs, molec_nghbs,
                                    coll_coeff)
        num_corr += 1
    return slab_molec, collision


def ads_euler(orig_molec, slab, ctr_coord, site_coord, num_pts,
              min_coll_height, coll_coeff, norm_vect, slab_nghbs, molec_nghbs,
              disso_atoms):
    """Generates adsorbate-surface structures by sampling over Euler angles.

    This function generates a number of adsorbate-surface structures at
    different orientations of the adsorbate sampled at multiple Euler (zxz)
    angles.
    @param orig_molec: ase.Atoms object of the molecule to adsorb
    @param slab: ase.Atoms object of the surface on which to adsorb the molecule
    @param ctr_coord: The coordinates of the molecule to use as adsorption
        center.
    @param site_coord: The coordinates of the surface on which to adsorb the
        molecule
    @param num_pts: Number on which to sample Euler angles.
    @param min_coll_height: The lowermost height for which to detect a collision
    @param coll_coeff: The coefficient that multiplies the covalent radius of
        atoms resulting in a distance that two atoms being closer to that is
        considered as atomic collision.
    @param norm_vect: The vector perpendicular to the surface.
    @param slab_nghbs: Number of neigbors in the surface.
    @param molec_nghbs: Number of neighbors in the molecule.
    @param disso_atoms: List of atom types or atom numbers to try to dissociate.
    @return: list of ase.Atoms object conatining all the orientations of a given
        conformer
    """
    from copy import deepcopy
    slab_ads_list = []
    # rotation around z
    for alpha in np.arange(0, 360, 360 / num_pts):
        # rotation around x'
        for beta in np.arange(0, 180, 180 / num_pts):
            # rotation around z'
            for gamma in np.arange(0, 360, 360 / num_pts):
                molec = deepcopy(orig_molec)
                molec.euler_rotate(alpha, beta, gamma, center=ctr_coord)
                slab_molec, collision = correct_coll(molec, slab,
                                                     ctr_coord, site_coord,
                                                     num_pts, min_coll_height,
                                                     norm_vect,
                                                     slab_nghbs, molec_nghbs,
                                                     coll_coeff)
                if not collision and \
                        not any([(slab_molec.positions == conf.positions).all()
                                 for conf in slab_ads_list]):
                    slab_ads_list.append(slab_molec)
                    slab_ads_list.extend(dissociation(slab_molec, disso_atoms,
                                                      len(slab)))

    return slab_ads_list


def ads_chemcat(site, ctr, pts_angle):
    return "TO IMPLEMENT"


def adsorb_confs(conf_list, surf, molec_ctrs, sites, algo, num_pts, neigh_ctrs,
                 norm_vect, min_coll_height, coll_coeff, disso_atoms):
    """Generates a number of adsorbate-surface structure coordinates.

    Given a list of conformers, a surface, a list of atom indices (or list of
    list of indices) of both the surface and the adsorbate, it generates a
    number of adsorbate-surface structures for every possible combination of
    them at different orientations.
    @param conf_list: list of ase.Atoms of the different conformers
    @param surf: the ase.Atoms object of the surface
    @param molec_ctrs: the list atom indices of the adsorbate.
    @param sites: the list of atom indices of the surface.
    @param algo: the algorithm to use for the generation of adsorbates.
    @param num_pts: the number of points per angle orientation to sample
    @param neigh_ctrs: the indices of the neighboring atoms to the adsorption
        atoms.
    @param norm_vect: The vector perpendicular to the surface.
    @param min_coll_height: The lowermost height for which to detect a collision
    @param coll_coeff: The coefficient that multiplies the covalent radius of
        atoms resulting in a distance that two atoms being closer to that is
        considered as atomic collision.
    @param disso_atoms: List of atom types or atom numbers to try to dissociate.
    @return: list of ase.Atoms for the adsorbate-surface structures
    """
    from ase.neighborlist import natural_cutoffs, neighbor_list
    surf_ads_list = []
    sites_coords = get_atom_coords(surf, sites)
    if min_coll_height is False:
        surf_cutoffs = natural_cutoffs(surf, mult=coll_coeff)
        surf_nghbs = len(neighbor_list("i", surf, surf_cutoffs))
    else:
        surf_nghbs = 0
    for conf in conf_list:
        molec_ctr_coords = get_atom_coords(conf, molec_ctrs)
        molec_neigh_coords = get_atom_coords(conf, neigh_ctrs)
        if min_coll_height is False:
            conf_cutoffs = natural_cutoffs(conf, mult=coll_coeff)
            molec_nghbs = len(neighbor_list("i", conf, conf_cutoffs))
        else:
            molec_nghbs = 0
        for site in sites_coords:
            for ctr in molec_ctr_coords:
                if algo == 'euler':
                    surf_ads_list.extend(ads_euler(conf, surf, ctr, site,
                                                   num_pts, min_coll_height,
                                                   coll_coeff, norm_vect,
                                                   surf_nghbs, molec_nghbs,
                                                   disso_atoms))
                elif algo == 'chemcat':
                    surf_ads_list.extend(ads_chemcat(site, ctr, num_pts))
    return surf_ads_list


def run_screening(inp_vars):
    """Carries out the screening of adsorbate structures on a surface.

    @param inp_vars: Calculation parameters from input file.
    """
    import os
    import random
    from modules.formats import read_coords, adapt_format
    from modules.calculation import run_calc

    if not os.path.isdir("isolated"):
        err = "'isolated' directory not found. It is needed in order to carry "
        "out the screening of structures to be adsorbed"
        logger.error(err)
        raise ValueError(err)

    logger.info('Carrying out procedures for the screening of adsorbate-surface'
                ' structures.')
    conf_list = read_coords(inp_vars['code'], 'isolated', 'ase',
                            inp_vars['special_atoms'])
    logger.info(f"Found {len(conf_list)} structures of isolated conformers.")
    selected_confs = select_confs(conf_list, inp_vars['select_magns'],
                                  inp_vars['confs_per_magn'],
                                  inp_vars['code'])
    surf = adapt_format('ase', inp_vars['surf_file'], inp_vars['special_atoms'])
    surf_ads_list = adsorb_confs(selected_confs, surf,
                                 inp_vars['molec_ads_ctrs'], inp_vars['sites'],
                                 inp_vars['ads_algo'],
                                 inp_vars['sample_points_per_angle'],
                                 inp_vars['molec_neigh_ctrs'],
                                 inp_vars['surf_norm_vect'],
                                 inp_vars['min_coll_height'],
                                 inp_vars['collision_threshold'],
                                 inp_vars['disso_atoms'])
    if len(surf_ads_list) > inp_vars['max_structures']:
        surf_ads_list = random.sample(surf_ads_list, inp_vars['max_structures'])
    logger.info(f'Generated {len(surf_ads_list)} adsorbate-surface atomic '
                f'configurations to carry out a calculation of.')
    for i in range(len(surf_ads_list)):
        for j in range(i):
            if (surf_ads_list[i].positions == surf_ads_list[j].positions).all():
                logger.warning(i, j, 'Same')

    run_calc('screening', inp_vars, surf_ads_list)
    logger.info('Finished the procedures for the screening of adsorbate-surface'
                ' structures section.')
