"""
This module is the Gaussian module for ASE,  based on the PyMD Gaussian Module by R. Bulo,  
implemented by T. Kerber

 @author:       Rosa Bulo
 @organization: Vrije Universiteit Amsterdam
 @contact:      bulo@few.vu.nl

Torsten Kerber,  ENS LYON: 2011,  07,  11

This work is supported by Award No. UK-C0017,  made by King Abdullah
University of Science and Technology(KAUST)
"""
import os,  sys,  string

import numpy as np
from ase.units import Hartree,  Bohr
from ase.calculators.general import Calculator
import os, sys, re, math, commands, subprocess, time

class Gaussian(Calculator):
    """
    Class for representing the settings of a ReaxffForceJob
        
    It contains the following variables:
                        
    basis:  
        A string representing the basisset to be used

    functional:
        A string representing the DFT functional or wavefunction
        method to be used.

    """

    def __init__(self, functional='BP86', basis='TZVP', atoms=None, inputfile=None):
        self.functional = functional
        self.basis = basis
        self.restartdata = None
        self.forces = None
        self.atoms_old = None
        if inputfile != None:
            self = self.get_settings(inputfile)                
        self.atoms = atoms

    def set_functional(self, functional):
        self.functional = functional

    def set_basis(self, basis):
        self.basis = basis

    def setup_job(self,  type='force'):
        ForceJob.setup_job(self, type=type)
 
        os.chdir(self.dirname)
        self.myfiles = GaussianFileManager()
        # Add the restart restuls to the filemanager. The checksum will be empty('')
        if self.restartdata != None:
            self.r_m = GaussianResults(self.myfiles)
            self.myfiles.add_results(self.r_m, dir=self.restartdata+'/')

        os.chdir(self.dir)

    def write_parfile(self, filename):
        parfile = open(filename, 'w')
        parfile.write(self.prm)
        parfile.close()

    def write_moldata(self,  type='force'):
        out = commands.getoutput('mkdir SETUP')
        out = commands.getoutput('ls SETUP/coord.*').split()
        pdb = False
        psf = False
        prm = False

        # Change labels variable if necessary
        norderats = 0
#        for el in self.labels:
#            norderats += 1
#            if norderats != self.atoms.pdb.get_number_of_atoms():
#                self.labels = self.get_labels()

        self.write_input(type)

    def write_input(self, type):
        input = self.get_input_block(type)
        pyfile = open('ASE-gaussian.com', 'w')
        pyfile.write(input)
        pyfile.close()

    def get_inp(self,  filename):
        input = None
        lis = commands.getoutput('ls %s'%(filename)).split('\n')
        if not re.search('ls:', lis[0]):
            infile = open(filename)
            input = infile.readlines()
            infile.close()
        return input

    def run(self, type='force'):
        startrun = time.time()
        # I have to write the input every timestep,  because that is where the coordinates are
        # self.write_input(type)
        self.write_moldata(type)

        # Make sure that restart data is copied to the current directory
#        if self.r_m != None:
#            self.r_m.files.copy_job_result_files(self.r_m.fileid)

        out = commands.getoutput('g09 ASE-gaussian.com')
        endrun = time.time()
        #print 'timings: ', endrun-startrun
        # Write the output
#        outputfile = open('ASE-gaussian.log')
#        outfile = open('ASE-gaussian.out',  'w')
#        for line in outputfile:
#            outfile.write(line)
#        outputfile.close()
#        outfile.close()
        # End write output
        self.energy = self.read_energy()
        if type == 'force':
            self.forces = self.read_forces()
        self.energy_zero = self.energy

        # Change the results object to the new results
        input = self.get_input_block(type)

    def read_energy(self,  charges=True):
        hartree2kcal = 627.5095

        outfile = open('ASE-gaussian.log')
        lines = outfile.readlines()
        outfile.close()

        chargelines = 0
        if charges:
            nats = len(self.atoms)
            block = ''
        for line in lines:
            if re.search('SCF Done', line):
                words = line.split()
                energy = float(words[4]) * hartree2kcal
            if charges:
                if re.search(' Mulliken atomic charges:', line):
                    chargelines += 1
            if chargelines > 0 and chargelines <= nats+2:
                chargelines += 1
                block += line
   
        if charges:
            chargefile = open('charge.out', 'a')
            chargefile.write(block)
            chargefile.close()

        return energy

    def read_forces(self):
        factor = Hartree / Bohr
        factor = 1.0

        outfile = open('ASE-gaussian.log')
        lines = outfile.readlines()
        outfile.close()

        outputforces = None
        forces = None
        nats = len(self.atoms)

        for iline, line in enumerate(lines):
            if re.search('Forces \(Ha', line):
                forces = np.zeros((nats,  3),  float)
                for iat in range(nats):
                    forceline = lines[iline+iat+3]
                    words = forceline.split()
                    for idir,  word in enumerate(words[2:5]):
                        forces[iat,  idir] = float(word) * factor
                break
        return forces

    def get_input_block(self, type):
        block = ''
        block += '%chk=ASE-gaussian.chk\n'
        block += '%Mem=256MB\n'
        block += '%nproc=32\n'
        block += 'MaxDisk=16gb\n'

        block += '#'
        block += '%s'%(self.functional)
        block += '/'
        block += '%s'%(self.basis)

        if type == 'force':
            block += 'Force '
#        if self.r_m != None:
#            block += 'Guess=Read'

        block += '\n\nGaussian job\n\n'

        charge = sum(self.atoms.get_charges())
        block += '%i '%(charge)
        block += '%i '%(1)
        block += '\n'

        coords = self.atoms.get_positions()
        for i, at in enumerate(self.atoms.get_chemical_symbols()):
            block += ' %2s'%(at)
            coord = coords[i]
            block += '  %16.5f %16.5f %16.5f'%(coord[0], coord[1], coord[2])
            block += '\n'

        if self.atoms.get_pbc().any(): 
            cell = self.atoms.get_cell()
           
            for v in cell: 
                block += 'TV %8.3f %8.3f %8.3f\n'%(v[0], v[1], v[2])
        block += '\n'

        return block


    def get_settings(self,  filename=None):
        settings = GaussianSettings()

        if filename != None or filename == '':
            input = self.get_inp(filename)
        else:
            settings.set_functional("")
            return settings
    
        if input == None:
            settings.set_functional("")
            return settings

        for i, line in enumerate(input):
            if len(line.strip()) == 0:
                continue
            if line[0] == '#':
                keyline = i
                words = line[1:].split()
            else:
                settings.functional = words[0].split('/')[0]
                settings.basis = words[0].split('/')[1]
            break
            return settings
                
    def update(self,  atoms):
        if self.atoms_old != atoms:
            self.atoms = atoms.copy()
            self.atoms_old = atoms.copy()
            self.run()
