"""
@author:       Torsten Kerber and Paul Fleurat-Lessard
@organization: Ecole Normale Superieure de Lyon
@contact:      Paul.Fleurat-Lessard@ens-lyon.fr

@data: June, 6, 2012

This work is supported by Award No. UK-C0017 made by King Abdullah University of Science and Technology(KAUST)
"""
import os, sys
import numpy as np

from ase.units import Hartree, Bohr
from ase.calculators.general import Calculator

str_keys = [
    'functional', # functional
    'basis',      # basis set
    'flags'       # more options
    'memory',     # RAM memory used
    'disk',       # disk space available
    'filename',   # name of the gaussian files
    'command'     # command to run Gaussian
    ]

int_keys = [
    'multiplicity', # multiplicity of the system
    'nproc',        # number of processors
    ]

class Gaussian(Calculator):
    def __init__(self, atoms=None, **kwargs):
        self.set_standard()
        
        self.set(**kwargs)
        self.atoms = atoms
        
        self.potential_energy = 0.0
        self.forces = None

        self.new_calculation = True
        if atoms != None:
            self.forces = np.zeros((len(atoms), 3))
        self.stress = None
        
    def set_standard(self):
        self.str_params = {}
        self.int_params = {}
        
        for key in str_keys:
            self.str_params['key'] = None
        
        for key in int_keys:
            self.int_params['key'] = None
        
        self.str_params['functional'] = "B3LYP"
        self.str_params['basis'] = '6-31G*'
        self.str_params['flags'] = ' Force '
        self.str_params['memory'] = '60MW'
        self.str_params['disk'] = '16GB'
        
        self.str_params['command'] = 'g09'
        self.str_params['filename'] = 'ASE-gaussian'

        self.int_params['multiplicity'] = 1
        self.int_params['nproc'] = 8
        
        
    def set(self, **kwargs):
        for key in kwargs:
            if self.str_params.has_key(key):
                self.str_params[key] = kwargs[key]
            if self.int_params.has_key(key):
                self.int_params[key] = kwargs[key]
                
                
    def set_atoms(self, atoms):
        if self.atoms != atoms:
            self.atoms = atoms.copy()
            self.new_calculation = True
    
    def write_file_line(self, file, line):
        file.write(line)
        file.write('\n')
    
    def write_input(self):
        name = self.str_params['filename']
        
        inputfile = open(name + '.com', 'w')
        self.write_file_line(inputfile, '%chk=' + name + '.chk')
        self.write_file_line(inputfile, '%mem=' + self.str_params['memory'])
        self.write_file_line(inputfile, '%nproc=' + ('%d' % self.int_params['nproc']))
        self.write_file_line(inputfile, 'MaxDisk=' + self.str_params['disk'])

        line = '#' + self.str_params['functional'] + '/' + self.str_params['basis'] + ' ' + self.str_params['flags']
        self.write_file_line(inputfile, line)
        
        self.write_file_line(inputfile, '')
        self.write_file_line(inputfile, 'Gaussian job created by ASE (supported by King Abdullah University of Science and Technology ,KAUST)')
        self.write_file_line(inputfile, '')

        line = '%d %d' % (sum(self.atoms.get_charges()), self.int_params['multiplicity'])
        self.write_file_line(inputfile, line)

        coords = self.atoms.get_positions()
        for i, at in enumerate(self.atoms.get_chemical_symbols()):
            line = at
            coord = coords[i]
            line += '  %16.5f %16.5f %16.5f'%(coord[0], coord[1], coord[2])
            self.write_file_line(inputfile, line)

        if self.atoms.get_pbc().any(): 
            cell = self.atoms.get_cell()
            line = ''
            for v in cell: 
                line += 'TV %8.3f %8.3f %8.3f\n'%(v[0], v[1], v[2])
            self.write_file_line(inputfile, line)
            
        self.write_file_line(inputfile, '')
        
    def update(self, atoms):
        if (self.atoms is None or self.atoms.positions.shape != atoms.positions.shape):
            self.atoms = atoms.copy()
            self.calculate(atoms)

    def calculate(self, atoms):
        self.write_input()
        line = self.str_params['command'] + ' ' + self.str_params['filename'] + '.com'
        exitcode = os.system(line)
        if exitcode != 0:
            raise RuntimeError('Gaussian exited with exit code: %d.  ' % exitcode)
        self.read_output()
        self.new_calculation = False
            
    def get_potential_energy(self, atoms=None, force_consistent=False):
        self.set_atoms(atoms)
        if self.new_calculation:
            self.calculate(atoms)
        return self.potential_energy

    def get_forces(self, atoms):
        self.set_atoms(atoms)
        if self.new_calculation:
            self.calculate(atoms)
        return self.forces

    def read_output(self):
        outfile = open(self.str_params['filename'] + '.log')
        lines = outfile.readlines()
        outfile.close()

        factor = Hartree
        for line in lines:
            if 'SCF Done' in line:
                line = line.split()
                self.potential_energy = float(line[4]) * factor

        factor = Hartree / Bohr
        nats = len(self.atoms)
        for iline, line in enumerate(lines):
            if 'Forces (Hartrees/Bohr)' in line:
                self.forces = np.zeros((nats, 3), float)
                for iat in range(nats):
                    line = lines[iline+iat+3].split()
                    for idir, val in enumerate(line[2:5]):
                        self.forces[iat, idir] = float(val) * factor
                break
