import numpy as np

from ase.old import OldASEListOfAtomsWrapper

try:
    import Numeric as num
except ImportError:
    pass

def np2num(a, typecode=None):
    if num.__version__ > '23.8':
        return num.array(a, typecode)
    if typecode is None:
        typecode = num.Float
    b = num.fromstring(a.tostring(), typecode)
    b.shape = a.shape
    return b

def restart(filename, **kwargs):
    calc = Dacapo(filename, **kwargs)
    atoms = calc.get_atoms()
    return atoms, calc

class Dacapo:
    def __init__(self, filename=None, stay_alive=False, stress=False,
                 **kwargs):

        self.kwargs = kwargs
        self.stay_alive = stay_alive
        self.stress = stress
        
        if filename is not None:
            from Dacapo import Dacapo
            self.loa = Dacapo.ReadAtoms(filename, **kwargs)
            self.calc = self.loa.GetCalculator()
        else:
            self.loa = None
            self.calc = None

        self.pps = []
        
    def set_pp(self, Z, path):
        self.pps.append((Z, path))

    def set_txt(self, txt):
        if self.calc is None:
            self.kwargs['txtout'] = txt
        else:
            self.calc.SetTxtFile(txt)

    def set_nc(self, nc):
        if self.calc is None:
            self.kwargs['out'] = nc
        else:
            self.calc.SetNetCDFFile(nc)

    def update(self, atoms):
        from Dacapo import Dacapo
        if self.calc is None:
            if 'nbands' not in self.kwargs:
                n = sum([valence[atom.symbol] for atom in atoms])
                self.kwargs['nbands'] = int(n * 0.65) + 4

            magmoms = atoms.get_initial_magnetic_moments()
            if magmoms.any():
                self.kwargs['spinpol'] = True

            self.calc = Dacapo(**self.kwargs)

            if self.stay_alive:
                self.calc.StayAliveOn()
            else:
                self.calc.StayAliveOff()

            if self.stress:
                self.calc.CalculateStress()

            for Z, path in self.pps:
                self.calc.SetPseudoPotential(Z, path)

        if self.loa is None:
            from ASE import Atom, ListOfAtoms
            numbers = atoms.get_atomic_numbers()
            positions = atoms.get_positions()
            magmoms = atoms.get_initial_magnetic_moments()
            self.loa = ListOfAtoms([Atom(Z=numbers[a],
                                         position=positions[a],
                                         magmom=magmoms[a])
                                    for a in range(len(atoms))],
                                   cell=np2num(atoms.get_cell()),
                                   periodic=tuple(atoms.get_pbc()))
            self.loa.SetCalculator(self.calc)
        else:
            self.loa.SetCartesianPositions(np2num(atoms.get_positions()))
            self.loa.SetUnitCell(np2num(atoms.get_cell()), fix=True)
            
    def get_atoms(self):
        atoms = OldASEListOfAtomsWrapper(self.loa).copy()
        atoms.set_calculator(self)
        return atoms
    
    def get_potential_energy(self, atoms):
        self.update(atoms)
        return self.calc.GetPotentialEnergy()

    def get_forces(self, atoms):
        self.update(atoms)
        return np.array(self.calc.GetCartesianForces())

    def get_stress(self, atoms):
        self.update(atoms)
        stress = np.array(self.calc.GetStress())
        if stress.ndim == 2:
            return stress.ravel()[[0, 4, 8, 5, 2, 1]]
        else:
            return stress

    def calculation_required(self, atoms, quantities):
        if self.calc is None:
            return True

        if atoms != self.get_atoms():
            return True

        return False
        
    def get_number_of_bands(self):
        return self.calc.GetNumberOfBands()

    def get_k_point_weights(self):
        return np.array(self.calc.GetIBZKPointWeights())

    def get_number_of_spins(self):
        return 1 + int(self.calc.GetSpinPolarized())

    def get_eigenvalues(self, kpt=0, spin=0):
        return np.array(self.calc.GetEigenvalues(kpt, spin))

    def get_fermi_level(self):
        return self.calc.GetFermiLevel()

    def get_number_of_grid_points(self):
        return np.array(self.get_pseudo_wave_function(0, 0, 0).shape)

    def get_pseudo_density(self, spin=0):
        return np.array(self.calc.GetDensityArray(s))
    
    def get_pseudo_wave_function(self, band=0, kpt=0, spin=0, pad=True):
        kpt_c = self.get_bz_k_points()[kpt]
        state = self.calc.GetElectronicStates().GetState(band=band, spin=spin,
                                                         kptindex=kpt)

        # Get wf, without bloch phase (Phase = True doesn't do anything!)
        wave = state.GetWavefunctionOnGrid(phase=False)

        # Add bloch phase if this is not the Gamma point
        if np.all(kpt_c == 0):
            return wave
        coord = state.GetCoordinates()
        phase = coord[0] * kpt_c[0] + coord[1] * kpt_c[1] + coord[2] * kpt_c[2]
        return np.array(wave) * np.exp(-2.j * np.pi * phase) # sign! XXX

        #return np.array(self.calc.GetWaveFunctionArray(n, k, s)) # No phase!

    def get_bz_k_points(self):
        return np.array(self.calc.GetBZKPoints())

    def get_ibz_k_points(self):
        return np.array(self.calc.GetIBZKPoints())

    def get_wannier_localization_matrix(self, nbands, dirG, kpoint,
                                        nextkpoint, G_I, spin):
        return np.array(self.calc.GetWannierLocalizationMatrix(
            G_I=G_I.tolist(), nbands=nbands, dirG=dirG.tolist(),
            kpoint=kpoint, nextkpoint=nextkpoint, spin=spin))
    
    def initial_wannier(self, initialwannier, kpointgrid, fixedstates,
                        edf, spin):
        # Use initial guess to determine U and C
        init = self.calc.InitialWannier(initialwannier, self.atoms,
                                        np2num(kpointgrid, num.Int))

        states = self.calc.GetElectronicStates()
        waves = [[state.GetWaveFunction()
                  for state in states.GetStatesKPoint(k, spin)]
                 for k in self.calc.GetIBZKPoints()] 

        init.SetupMMatrix(waves, self.calc.GetBZKPoints())
        c, U = init.GetListOfCoefficientsAndRotationMatrices(
            (self.calc.GetNumberOfBands(), fixedstates, edf))
        U = np.array(U)
        for k in range(len(c)):
            c[k] = np.array(c[k])
        return c, U

valence = {
'H':   1,
'B':   3,
'C':   4,
'N':   5,
'O':   6,
'Li':  1,
'Na':  1,
'K':   9,
'Mg':  8,
'Ca': 10,
'Sr': 10,
'Al':  3,
'Ga': 13,
'Sc': 11,
'Ti': 12,
'V':  13,
'Cr': 14,
'Mn':  7,
'Fe':  8,
'Co':  9,
'Ni': 10,
'Cu': 11,
'Zn': 12,
'Y':  11,
'Zr': 12,
'Nb': 13,
'Mo':  6,
'Ru':  8,
'Rh':  9,
'Pd': 10,
'Ag': 11,
'Cd': 12,
'Ir': 9,
'Pt': 10,
'Au': 11,
}
