""" This is a QM:MM embedded system for ASE

torsten.kerber@ens-lyon.fr
"""

import ase
import ase.atoms
import numpy as np
from general import Calculator
from ase.embed import Embed
from ase.units import Hartree

from copy import deepcopy

import sys, os

class Qmx(Calculator):
    def __init__(self, calculator_high_cluster,  calculator_low_cluster,  calculator_low_system=None,  print_forces=False):
        self._constraints=None
        
        self.calculator_low_cluster = calculator_low_cluster
        self.calculator_low_system = calculator_low_system
        if self.calculator_low_system is None:
            self.calculator_low_system = deepcopy(calculator_low_cluster)
        self.calculator_high_cluster = calculator_high_cluster
        self.print_forces = print_forces

    def get_energy_subsystem(self, path, calculator, atoms, force_consistent):
        # go to directory and calculate energies
        print "running energy in: ", path
        os.chdir(path)
        atoms.set_calculator(calculator)
        energy = atoms.get_potential_energy()
        os.chdir("..")
        return energy

    def get_forces_subsystem(self, path, calculator, atoms):
        # go to directory and calculate forces
        print "running forces in: ", path
        os.chdir(path)
        atoms.set_calculator(calculator)
        forces = atoms.get_forces()
        os.chdir("..")
        return forces

    def get_potential_energy(self, embed, force_consistent=False):
        # perform energy calculations
        e_sys_lo = self.get_energy_subsystem("system.low-level", self.calculator_low_system, embed.get_system(), force_consistent)
        e_cl_lo  = self.get_energy_subsystem("cluster.low-level", self.calculator_low_cluster, embed.get_cluster(), force_consistent)
        e_cl_hi  = self.get_energy_subsystem("cluster.high-level", self.calculator_high_cluster, embed.get_cluster(), force_consistent)
        # calculate energies
        energy = e_sys_lo - e_cl_lo + e_cl_hi
        # print energies
        print "%20s = %15s - %15s + %15s" %("E(C:S)", "E(S-MM)", "E(C-MM)", "E(C-QM)")
        print "%20f = %15f - %15f + %15f" %(energy, e_sys_lo, e_cl_lo, e_cl_hi)
        # set energies and return
        if force_consistent:
            self.energy_free = energy
            return self.energy_free
        else:
            self.energy_zero = energy
            return self.energy_zero

    def get_forces(self, embed):
        atom_map_sys_cl = embed.atom_map_sys_cl
        # get forces for the three systems
        f_sys_lo = self.get_forces_subsystem("system.low-level", self.calculator_low_system, embed.get_system())
        f_cl_lo  = self.get_forces_subsystem("cluster.low-level", self.calculator_low_cluster, embed.get_cluster())
        f_cl_hi  = self.get_forces_subsystem("cluster.high-level", self.calculator_high_cluster, embed.get_cluster())

        # forces correction for the atoms
        f_cl = f_cl_hi - f_cl_lo

        if self.print_forces:
            cluster=embed.get_cluster()
            print "Forces: System LOW - Cluster LOW + Cluster HIGH"
            for iat_sys in xrange(len(embed)):
                print "%-2s (" % embed[iat_sys].get_symbol(),
                for idir in xrange(3):
                    print "%10.6f" % f_sys_lo[iat_sys][idir], 
                print ") <system LOW>"

                iat_cl = atom_map_sys_cl[iat_sys]
                if iat_cl > -1:
                    print "%s" % "-  (",
                    for idir in xrange(3):
                        print "%10.6f" % f_cl_lo[iat_cl][idir], 
                    print ") <cluster LOW>"
                    print "%s" % "+  (",
                    for idir in xrange(3):
                        print "%10.6f" % f_cl_hi[iat_cl][idir], 
                    print ") <cluster HIGH>"
                print
            print
    
        # lo-sys + (hi-lo)
        for iat_sys in xrange(len(embed)):
            iat_cl = atom_map_sys_cl[iat_sys]
            if iat_cl > -1:
                f_sys_lo[iat_sys] += f_cl[iat_cl]
        # some settings for the output
        i_change = np.zeros(len(embed), int)
        if self.print_forces:
            f_sys_lo_orig = f_sys_lo.copy()
        # correct gradients
        # Reference: Eichler, Koelmel, Sauer, J. of Comput. Chem., 18(4). 1997, 463-477.
        for cell_L, iat_cl_sys, iat_sys, r, iat_link in embed.linkatoms:
            # calculate the bond distance (r_bond) at the border
            xyz = embed[iat_sys].get_position() - embed[iat_cl_sys].get_position() + cell_L
            # calculate the bond lenght and the factor f
            rbond = np.sqrt(np.dot(xyz, xyz))
            f = r / rbond
            #normalize xyz
            xyz /= rbond
            # receive the gradients for the link atom
            fL = f_cl[iat_link]
            # dot product fL, xyz
            fs = np.dot(fL, xyz)
            # apply corrections for each direction
            i_change[iat_sys] = 1
            i_change[iat_cl_sys] = 1
            for idir in xrange(3):
                # correct the atom in the system
                f_sys_lo[iat_sys][idir] = f_sys_lo[iat_sys][idir] + f*fL[idir] - f*fs*xyz[idir]
                # correct the atom in the cluster
                f_sys_lo[iat_cl_sys][idir] = f_sys_lo[iat_cl_sys][idir] + (1-f)*fL[idir] + f*fs*xyz[idir]

        if self.print_forces:
            print " TOTAL FORCE (uncorrected : corrected) for link atoms"
            for iat_sys in xrange(len(embed)):
                print "%-2s (" % embed[iat_sys].get_symbol(),
                for idir in xrange(3):
                    print "%10.6f" % f_sys_lo_orig[iat_sys][idir], 
                print ") : (", 
                for idir in xrange(3):
                    print "%10.6f" % f_sys_lo[iat_sys][idir], 
                print ")", 
                if i_change[iat_sys]:
                    print " *", 
                print
            print
        
        return f_sys_lo
