"""
This module is the MOPAC module for ASE, 
based on the PyMD MOPAC Module by R. Bulo, 
implemented by T. Kerber

2011, ENS Lyon
"""
import commands
import numpy as np

from ase.atoms import Atoms
from ase.units import kcal, mol
from ase.calculators.general import Calculator

kcal_mol = kcal / mol
class Mopac(Calculator):
    def __init__(self, files=None, restart=0, geomopt=False, functional="PM6"):
        self.restart = restart
        self.geomopt = geomopt
        self.functional = functional

        self.spin = 0
        self.occupations = None

        self.first = True
        self.stress = None
        self.energy_zero = None
        self.energy_free = None
        self.forces = None
        
    def write_moldata(self):
        """
        Writes the files that have to be written each timestep
        """
        # Change labels variable if necessary
        # FIXME: what does norderats stands for????
        norderats = 0
        for el in self.atoms.get_chemical_symbols():
            norderats += 1
        if norderats != self.atoms.get_number_of_atoms():
            self.labels = self.get_labels()

        input = self.get_input_block()
        pyfile = open('pymopac.in','w')
        pyfile.write(input)
        pyfile.close()

    def run(self):
        self.write_moldata()
        out = commands.getoutput('mopac pymopac.in')
        
        outputfile = open('pymopac.in.out')
        outfile = open('mopac%02i.out'%(self.restart),'w')
        for line in outputfile:
            outfile.write(line)
        outfile.close()
        outputfile.close()

        energy = self.read_energy()
        if energy == None:
            energy = self.rerun()
        self.forces = self.read_forces()
        
        self.energy_zero = energy
        self.energy_free = energy
        self.first = False

    def read_energy(self, charges=True):
        outfile = open('pymopac.in.out')
        lines = outfile.readlines()
        outfile.close()

        chargelines = 0
        if charges:
            nats = len(self.atoms)
            block = ''
        for line in lines:
            if line.find('HEAT OF FORMATION') != -1:
                words = line.split()
                energy = float(words[5])
            if line.find('H.o.F. per unit cell') != -1:
                words = line.split()
                energy = float(words[5])
            if line.find('"""""""""""""UNABLE TO ACHIEVE SELF-CONSISTENCE') != -1:
                energy = None
            if charges:
                if line.find('NET ATOMIC CHARGES') != -1:
                    chargelines += 1
            if chargelines > 0 and chargelines <= nats+3:
                chargelines += 1
                block += line
             
        if charges:
            chargefile = open('charge%02i.out'%(self.restart),'a')
            chargefile.write(block)
            chargefile.close()
 
        return energy

    def read_forces(self):

        outfile = open('pymopac.in.aux')
        lines = outfile.readlines()
        outfile.close()

        outputforces = None
        nats = len(self.atoms)
        nx = nats * 3

        for i,line in enumerate(lines):
            if line.find('GRADIENTS:') != -1:
                forces = []
                l = 0
                for j in range(nx):
                    if j%10 == 0:
                        l += 1
                        gline = lines[i+l]
                    k = j -((l-1)*10)
                    try:
                        forces.append(-float(gline[k*18:(k*18)+18]))
                    except ValueError:
                        if outputforces == None:
                            outputforces = self.read_output_forces()
                        forces.append(outputforces[j])
                break

        forces = np.reshape(forces,  (3, nats))
        forces *= kcal_mol

        return forces

    def read_output_forces(self):
        outfile = open('pymopac.in.out')
        lines = outfile.readlines()
        outfile.close()

        nats = len(self.atoms)
        nx = nats * 3

        forces = []
        for i,line in enumerate(lines):
            if line.find('GRADIENT\n') != -1:
                for j in range(nx):
                    gline = lines[i+j+1]
                    forces.append(-float(gline[49:62]))
                break
        
        return forces
        
    def get_input_block(self):
        # For restarting I can use 'DENOUT' and 'OLDENS' at some point 
        block = ''
        block += '%s '%(self.functional)
        #if self.functional == 'PM6-DH2':
        block += 'NOANCI '
        block += '1SCF GRADIENTS '
        block += 'AUX(0,PRECISION=9) '
        block += 'RELSCF = 0.0001 '
        charge = 0
        if charge != 0:
            block += 'CHARGE=%i '%(charge)
        if self.spin == 1.:
            block += 'DOUBLET '
        elif self.spin == 2.:
            block += 'TRIPLET '
        block += '\n'
        block += 'Title: ASE job\n\n'

        constraints = self.atoms._get_constraints()
        nconstraints = len(constraints)
        for iat in xrange(len(self.atoms)):
            f = [1, 1, 1]
            if iat < nconstraints:
                if constraints[iat] is not None:
                    f = [0, 0, 0]
            atom = self.atoms[iat]
            xyz = atom.position
            block += ' %2s'%atom.symbol
            block += '    %16.5f %i    %16.5f %i    %16.5f %i \n'%(xyz[0],f[0],xyz[1],f[1],xyz[2],f[2])

        if self.atoms.pbc.any():
            for v in self.atoms.get_cell(): 
                block += 'Tv %8.3f %8.3f %8.3f\n'%(v[0],v[1],v[2])
        return block

    def update(self, atoms):
        if not (self.atoms.get_positions() == atoms.get_positions()).all() or self.first:
            self.atoms = atoms.copy()
            self.run()
