"""
This module is the EMBED module for ASE
implemented by T. Kerber

Torsten Kerber, ENS LYON: 2011, 07, 11

This work is supported by Award No. UK-C0017, made by King Abdullah
University of Science and Technology (KAUST)
"""
import ase
import ase.atoms
import numpy as np
from general import Calculator
from ase.embed import Embed
from ase.units import Hartree

from copy import deepcopy

import sys, os

class Qmx(Calculator):
    def __init__(self, calculator_high_cluster,  calculator_low_cluster,  calculator_low_system=None,  print_forces=False):
        self._constraints=None
        
        self.calculator_low_cluster = calculator_low_cluster
        self.calculator_low_system = calculator_low_system
        if self.calculator_low_system is None:
            self.calculator_low_system = deepcopy(calculator_low_cluster)
        self.calculator_high_cluster = calculator_high_cluster
        self.print_forces = print_forces

    def get_energy_subsystem(self, path, calculator, atoms, force_consistent):
        # writing output
        line = "running energy in: " + path + "\n"
        sys.stderr.write(line)
        # go to directory and calculate energies
        os.chdir(path)
        atoms.set_calculator(calculator)
        energy = atoms.get_potential_energy()
        # return path and result
        os.chdir("..")
        return energy

    def get_forces_subsystem(self, path, calculator, atoms):
        # writing output
        line = "running forces in: " + path + "\n"
        sys.stderr.write(line)
        # go to directory and calculate forces
        os.chdir(path)
        atoms.set_calculator(calculator)
        forces = atoms.get_forces()
        # return path and result
        os.chdir("..")
        return forces

    def get_potential_energy(self, embed, force_consistent=False):
        # perform energy calculations
        e_sys_lo = self.get_energy_subsystem("system.low-level", self.calculator_low_system, embed.get_system(), force_consistent)
        e_cl_lo  = self.get_energy_subsystem("cluster.low-level", self.calculator_low_cluster, embed.get_cluster(), force_consistent)
        e_cl_hi  = self.get_energy_subsystem("cluster.high-level", self.calculator_high_cluster, embed.get_cluster(), force_consistent)
        # calculate energies
        energy = e_sys_lo - e_cl_lo + e_cl_hi
        # print energies
        print "%20s = %15s - %15s + %15s" %("E(C:S)", "E(S,LL)", "E(C,LL)", "E(C,HL)")
        print "%20f = %15f - %15f + %15f" %(energy, e_sys_lo, e_cl_lo, e_cl_hi)
        # set energies and return
        if force_consistent:
            self.energy_free = energy
            return self.energy_free
        else:
            self.energy_zero = energy
            return self.energy_zero

    def get_forces(self, embed):
        atom_map_sys_cl = embed.atom_map_sys_cl
        # get forces for the three systems
        f_sys_lo = self.get_forces_subsystem("system.low-level", self.calculator_low_system, embed.get_system())
        f_cl_lo  = self.get_forces_subsystem("cluster.low-level", self.calculator_low_cluster, embed.get_cluster())
        f_cl_hi  = self.get_forces_subsystem("cluster.high-level", self.calculator_high_cluster, embed.get_cluster())

        # forces correction for the atoms
        f_cl = f_cl_hi - f_cl_lo

        if self.print_forces:
            cluster=embed.get_cluster()
            print "Forces: System LOW - Cluster LOW + Cluster HIGH"
            for iat_sys in xrange(len(embed)):
                print "%-2s (" % embed[iat_sys].get_symbol(),
                for idir in xrange(3):
                    print "%10.6f" % f_sys_lo[iat_sys][idir], 
                print ") <system LOW>"

                iat_cl = atom_map_sys_cl[iat_sys]
                if iat_cl > -1:
                    print "%s" % "-  (",
                    for idir in xrange(3):
                        print "%10.6f" % f_cl_lo[iat_cl][idir], 
                    print ") <cluster LOW>"
                    print "%s" % "+  (",
                    for idir in xrange(3):
                        print "%10.6f" % f_cl_hi[iat_cl][idir], 
                    print ") <cluster HIGH>"
                print
            print
    
        # lo-sys + (hi-lo)
        for iat_sys in xrange(len(embed)):
            iat_cl = atom_map_sys_cl[iat_sys]
            if iat_cl > -1:
                f_sys_lo[iat_sys] += f_cl[iat_cl]
        # some settings for the output
        i_change = np.zeros(len(embed), int)
        if self.print_forces:
            f_sys_lo_orig = f_sys_lo.copy()
        # correct gradients
        # Reference: Eichler, Koelmel, Sauer, J. of Comput. Chem., 18(4). 1997, 463-477.
        for cell_L, iat_cl_sys, iat_sys, r, iat_link in embed.linkatoms:
            # calculate the bond distance (r_bond) at the border
            xyz = embed[iat_sys].get_position() - embed[iat_cl_sys].get_position() + cell_L
            # calculate the bond lenght and the factor f
            rbond = np.sqrt(np.dot(xyz, xyz))
            f = r / rbond
            #normalize xyz
            xyz /= rbond
            # receive the gradients for the link atom
            fL = f_cl[iat_link]
            # dot product fL, xyz
            fs = np.dot(fL, xyz)
            # apply corrections for each direction
            i_change[iat_sys] = 1
            i_change[iat_cl_sys] = 1
            for idir in xrange(3):
                # correct the atom in the system
                f_sys_lo[iat_sys][idir] = f_sys_lo[iat_sys][idir] + f*fL[idir] - f*fs*xyz[idir]
                # correct the atom in the cluster
                f_sys_lo[iat_cl_sys][idir] = f_sys_lo[iat_cl_sys][idir] + (1-f)*fL[idir] + f*fs*xyz[idir]

        if self.print_forces:
            print " TOTAL FORCE (uncorrected : corrected) for link atoms"
            for iat_sys in xrange(len(embed)):
                print "%-2s (" % embed[iat_sys].get_symbol(),
                for idir in xrange(3):
                    print "%10.6f" % f_sys_lo_orig[iat_sys][idir], 
                print ") : (", 
                for idir in xrange(3):
                    print "%10.6f" % f_sys_lo[iat_sys][idir], 
                print ")", 
                if i_change[iat_sys]:
                    print " *", 
                print
            print
        
        return f_sys_lo

    def get_stress(self, atoms):
    	return None
