# $Id: import_data.py 38 2017-09-20 13:35:58Z tjiang $
import os, sys, argparse
from ase import Atoms
from ase.io import read
from ase.db import connect
from ase.calculators.vasp import Vasp
from ase.constraints import FixAtoms
from numpy.linalg import norm
from bz2 import BZ2File
from gzip import GzipFile


def _default_creator():
    """Return the user name stored as ``creator`` when the caller gives none.

    Uses the ``USER`` environment variable when it is set and falls back to
    ``getpass.getuser()`` otherwise.
    """
    user = os.environ.get('USER')
    if user is None:
        # Local import: only needed when USER is not set.
        import getpass
        user = getpass.getuser()
    return user


def _open_outcar(outcar_path):
    """Open an OUTCAR for reading, choosing the opener from its suffix.

    ``.bz2`` files are opened with BZ2File, ``.gz`` files with GzipFile and
    everything else with the builtin ``open``.
    """
    if outcar_path.endswith('.bz2'):
        return BZ2File(outcar_path)
    if outcar_path.endswith('.gz'):
        return GzipFile(outcar_path)
    return open(outcar_path)


def _parse_outcar_settings(outcar):
    """Collect the calculation settings from the head of an open OUTCAR.

    The first line is stored verbatim as ``version``.  The line following
    the first ' INCAR:' line gives ``potential``.  After the
    ' Dimension of arrays' line, parameter lines are scanned until a line
    containing 104 dashes is found.

    Returns a dict with the keys ``version``, ``potential``, ``functional``,
    ``nkpts``, ``encut``, ``enaug``, ``lsol`` and ``ldipol``; ``nbands`` and
    ``f_limit`` are present only when their lines were found.

    Raises ValueError for an unknown LEXCH flag, an unparsable number, or
    when one of the settings written to the database is missing, and
    IndexError when a line has fewer fields than expected.
    """
    xc_names = {'8': 'PBE', '91': 'PW91', 'CA': 'LDA'}
    settings = {'lsol': False, 'ldipol': False}
    settings['version'] = outcar.readline()
    in_parameters = False
    # iter(readline, '') stops at end of file and, unlike iterating the file
    # object, can be mixed with the explicit readline() call below.
    for line in iter(outcar.readline, ''):
        if line.startswith(' INCAR:'):
            fields = outcar.readline().split()
            if fields[1] == 'PAW_PBE':
                settings['potential'] = 'PAW'
            else:
                settings['potential'] = fields[1]
        elif line.startswith(' Dimension of arrays'):
            in_parameters = True
        elif in_parameters:
            # Only the first matching keyword of a line is used, so the
            # order of these tests matters.
            if 'LEXCH' in line:
                xc_flag = line.split()[2].upper()
                if xc_flag not in xc_names:
                    raise ValueError('Unknown xc-functional flag found in POTCAR,'
                                     ' LEXCH=%s' % xc_flag)
                settings['functional'] = xc_names[xc_flag]
            elif 'NKPTS' in line:
                settings['nkpts'] = int(line.split()[3])
            elif 'ENCUT' in line:
                settings['encut'] = float(line.split()[2])
            elif 'ENAUG' in line:
                settings['enaug'] = float(line.split()[2])
            elif 'NBANDS' in line:
                settings['nbands'] = int(line.split()[-1])
            elif 'EDIFFG' in line:
                # Stored with the sign flipped; it is compared with fmax.
                settings['f_limit'] = -float(line.split()[2])
            elif 'LSOL' in line:
                if line.split()[2] == 'T':
                    settings['lsol'] = True
            elif 'LDIPOL' in line:
                if line.split()[2] == 'T':
                    settings['ldipol'] = True
            elif '-' * 104 in line:
                break
    required = ('potential', 'functional', 'nkpts', 'encut', 'enaug')
    missing = [key for key in required if key not in settings]
    if missing:
        # Fail for this file instead of reusing values of a previous OUTCAR.
        raise ValueError('settings not found in OUTCAR: %s' % ', '.join(missing))
    return settings


def _attach_constraints(atoms, dirpath, outcar_path, use_poscar):
    """Set constraints on ``atoms`` and return them (None if there are none).

    With ``use_poscar`` the constraints are taken from the POSCAR located in
    ``dirpath``, i.e. next to the OUTCAR.  Otherwise the first image of the
    OUTCAR is compared with ``atoms``: if any atom moved, the atoms whose
    displacement is exactly zero are fixed; if nothing moved, no constraint
    is set.

    Raises ValueError when ``use_poscar`` is set and ``dirpath`` holds no
    POSCAR.
    """
    if use_poscar:
        poscar_path = os.path.join(dirpath, 'POSCAR')
        if not os.path.isfile(poscar_path):
            raise ValueError('no POSCAR found in %s' % dirpath)
        constraints = read(poscar_path).constraints
        atoms.set_constraint(constraints)
        print('get constraint from poscar')
        return constraints

    first_image = read(outcar_path, index=0)
    displacement = norm(atoms.positions - first_image.positions, axis=1)
    if displacement.sum() > 0:
        print('find constraint from outcar')
        constraints = FixAtoms(mask=[moved == 0 for moved in displacement])
        atoms.set_constraint(constraints)
        return constraints
    print('no constraint')
    return None


def _import_outcar(con, dirpath, name, use_poscar, creator):
    """Read one OUTCAR and write it to ``con`` if it qualifies.

    Unconstrained calculations are always written (``constraint=False``).
    Constrained ones are written (``constraint=True``) only when the largest
    force norm is below the EDIFFG-derived limit.
    """
    outcar_path = os.path.join(dirpath, name)
    atoms = read(outcar_path)
    constraints = _attach_constraints(atoms, dirpath, outcar_path, use_poscar)

    with _open_outcar(outcar_path) as outcar:
        settings = _parse_outcar_settings(outcar)

    record = dict(functional=settings['functional'], path=dirpath, code='VASP',
                  filename=name, version=settings['version'],
                  potential=settings['potential'], encut=settings['encut'],
                  enaug=settings['enaug'], lsol=settings['lsol'],
                  ldipol=settings['ldipol'], nkpts=settings['nkpts'])

    if constraints is None:
        print('updating the database with calculation %s' % name)
        con.write(atoms, creator=creator, constraint=False, **record)
        return

    if 'f_limit' not in settings:
        raise ValueError('settings not found in OUTCAR: EDIFFG')
    fmax = ((atoms.get_forces()) ** 2).sum(1).max() ** 0.5
    if fmax < settings['f_limit']:
        print('updating the database with minimized calculation from file:  %s %s'
              % (fmax, name))
        # NOTE(review): the creator has never been stored in POSCAR mode;
        # kept as is -- confirm whether that omission is intended.
        if not use_poscar:
            record['creator'] = creator
        con.write(atoms, constraint=True, **record)
    else:
        print('not updating the database as fmax is too big for file:  %s %s'
              % (fmax, name))


def import_vasp_calculations(path, con, include_dir=None, exclude_dir=None,
                             use_poscar=False, creator=None):
    """Walk ``path`` and import every VASP OUTCAR found into ``con``.

    Parameters:
        path: root directory; all of its sub-directories are visited.
        con: database connection; rows are added with ``con.write``.
        include_dir: directory names to import from.  When empty or None,
            every directory is considered.
        exclude_dir: strings; a directory is skipped when its absolute path
            contains any of them.
        use_poscar: take the constraints from the POSCAR next to each OUTCAR
            instead of deriving them from the atoms that did not move.
            NOTE(review): POSCAR constraints are never None, so in this mode
            every file goes through the force check -- confirm intent.
        creator: user name stored with each row.  Defaults to the ``USER``
            environment variable (looked up at call time), falling back to
            ``getpass.getuser()``.

    Files whose name starts with 'OUTCAR' are processed.  Each processed
    file is listed in 'success.log'; files that raise IndexError or
    ValueError while being read are listed in 'err.log'.  Both logs are
    created (and truncated) in the current working directory.
    """
    include_dir = [] if include_dir is None else include_dir
    exclude_dir = [] if exclude_dir is None else exclude_dir
    if creator is None:
        creator = _default_creator()

    with open('success.log', 'w') as log, open('err.log', 'w') as errlog:
        for dirpath, _dirnames, filenames in os.walk(os.path.abspath(path)):
            outcars = [name for name in filenames if name.startswith('OUTCAR')]
            selected = not include_dir or os.path.basename(dirpath) in include_dir
            excluded = any(pattern in dirpath for pattern in exclude_dir)
            if outcars and selected and not excluded:
                for name in outcars:
                    try:
                        _import_outcar(con, dirpath, name, use_poscar, creator)
                    except (IndexError, ValueError):
                        errlog.write('%s %s\n' % (dirpath, name))
                    else:
                        log.write('%s %s\n' % (dirpath, name))
            # Flush after every directory so progress survives a crash.
            log.flush()
            errlog.flush()
                 
#def import_adf_calculations(path,con):
#    log = open('success.log','w')
#    errlog = open('err.log','w')
#    _all = os.walk(path)
#    calculator=None
#    for files in _all:
#        vasp_out = [i for i in range(len(files[2])) if ('OUTCAR' in files[2][i] and files[2][i][:6]=='OUTCAR')]
#        if len(vasp_out) != 0:
#            print files[0]
#            ncalc = 0
#            for f in vasp_out:
#            	calculator = 'vasp'
#                try:
#            	    Atoms = read(files[0]+'/'+files[2][f])
#            	    con.write(Atoms,creator='rkerber',path=files[0],code='vasp',filename=files[2][f])
#                    print >> log, files[0]
#                except (IndexError, ValueError):
#                    print >> errlog, files[0]
#                
#        log.flush()
#        errlog.flush()


if __name__ == '__main__':
    # Command-line entry point: import every OUTCAR below the current
    # directory into 'test.db'.  With -p the constraints come from POSCAR.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-p",
        action="store_true",
        help="Use POSCAR in the same directory to determine the constraint",
    )
    options = cli.parse_args()

    database = connect('test.db')
    # Directory filters can be passed as well, e.g.
    # import_vasp_calculations('.', database, include_dir=['<opt_dir>'],
    #                          exclude_dir=['<to_exclude_dir1>', '<to_exclude_dir2>'])
    import_vasp_calculations('.', database, use_poscar=bool(options.p))
