from ASE.Utilities.Grid import Grid
from ASE.Utilities.VectorSpaces import BravaisLattice
import ASE.Utilities.ElectronicStates
import Numeric as num

class ElectronicState(ASE.Utilities.ElectronicStates.ElectronicState):
    """ class holding a single eigenstate from a dacapo calculation.

    Implements its own GetArray and GetWavefunctionOnGrid, which translate
    the representation in reciprocal space (see Set/GetWaveFunction) to
    real space.

    The wave-function data is fetched lazily through a "grid method" object
    registered with _SetGridMethod, together with the keywords 'band',
    'spin', 'kpt' and 'filename' that select this state in the dacapo
    NetCDF file.
    """

    def _PrepareGridMethod(self):
        """Point the registered grid method at this state and load its data.

        Returns the grid-method object after SetBand, SetSpin,
        SetKPointNumber and ReadFromNetCDFFile have been called with this
        state's keywords.
        """
        # The grid-method object can be shared by many states (dacapo
        # ElectronicStates hands one WaveFunction to every state), so it has
        # to be re-pointed at this state before every read.
        wf, keywords = self._GetGridMethod()
        wf.SetBand(keywords['band'])
        wf.SetSpin(keywords['spin'])
        wf.SetKPointNumber(keywords['kpt'])
        wf.ReadFromNetCDFFile(keywords['filename'])
        return wf

    def GetArray(self, phase=True):
        """Return this state's wave function on the real-space grid.

        If a reciprocal-space wave function has been set on this state, it
        is passed explicitly to the grid method's GetWaveFunction; otherwise
        the grid method uses the data it just read from the NetCDF file.

        phase: forwarded unchanged to the grid method's GetWaveFunction.
        """
        wf = self._PrepareGridMethod()
        if self.wavefunction is not None:
            return wf.GetWaveFunction(wavefunction=self.wavefunction, phase=phase)
        return wf.GetWaveFunction(phase=phase)

    def GetGridValues(self, phase=True):
        """Return the real-space wave function (same as GetArray).

        Overrides the GetGridValues method from the Grid class in order to
        allow specification of the phase.
        """
        return self.GetArray(phase=phase)

    def GetWavefunctionOnGrid(self, phase=True):
        """Return the real-space wave function (alias for GetArray)."""
        return self.GetArray(phase=phase)

    def GetWaveFunction(self):
        """Return the reciprocal-space wave function, loading it on first use.

        The Bloch function read from the NetCDF file is stored with
        SetWaveFunction, so later calls return it without re-reading.
        """
        if self.wavefunction is None:
            wf = self._PrepareGridMethod()
            self.SetWaveFunction(wf.GetReciprocalBlochFunction())
        return self.wavefunction

    def SetWaveFunction(self, wavefunction):
        """Store the reciprocal-space wave function for this state."""
        self.wavefunction = wavefunction

    def _SetGridMethod(self, gridmethod, **keywords):
        """Register the object and keywords used to read this state's data.

        keywords must contain 'band', 'spin', 'kpt' and 'filename'.
        """
        self.gridmethod = gridmethod
        self.keywords = keywords

    def _GetGridMethod(self):
        """Return the (gridmethod, keywords) pair set by _SetGridMethod."""
        return self.gridmethod, self.keywords

    def GetReciprocalIndex(self):
        """Return the grid method's FFT index for this state's wave function."""
        wf = self._PrepareGridMethod()
        return wf.GetWaveFunctionFFTindex()

    def _Copy(self, **kwargs):
        """Create a new state from keyword arguments.

        The new state shares this state's grid method and keywords.
        """
        # kwargs must be unpacked; passing the dict positionally would bind
        # it to the constructor's first parameter instead.
        state = ElectronicState(**kwargs)
        state.gridmethod = self.gridmethod
        state.keywords = self.keywords
        return state

    def Copy(self):
        """Return a new state with the same attributes as this one.

        The wave function, grid method and keywords are shared by reference,
        not duplicated.
        """
        state = ElectronicState(kpoint=self.kpoint, kpointindex=self.kpointindex,
                                bandindex=self.bandindex, wavefunction=self.wavefunction,
                                spin=self.spin, energy=self.energy, occ=self.occ,
                                kpointweight=self.kpointweight,
                                unitcell=self.GetUnitCell(),
                                degeneracy=self.degeneracy)
        state.gridmethod = self.gridmethod
        state.keywords = self.keywords
        return state




class ElectronicStates(ASE.Utilities.ElectronicStates.ElectronicStates):
    """ dacapo ElectronicStates implements its own ReadFromFile.

    The list is filled either from a dacapo NetCDF file (ReadFromFile) or
    from a dacapo calculator (InitializeFromCalculator).
    """

    def __init__(self, filename=None):
        # Begin as an empty list; only read states when a file is supplied.
        list.__init__(self, [])
        if filename is None:
            return
        self.ReadFromFile(filename)

    def ReadFromFile(self, filename):
        """ Read Electronic states in the dacapo NetCDF format.

        The atoms are loaded with Dacapo.ReadAtoms and their calculator is
        handed to InitializeFromCalculator.
        """
        from Dacapo import Dacapo

        calculator = Dacapo.ReadAtoms(filename).GetCalculator()
        self.InitializeFromCalculator(calculator)

    def InitializeFromCalculator(self, calc):
        """Append one ElectronicState per (spin, k-point, band) of calc.

        States are appended spin-major, then by k-point, then by band. No
        wave-function data is read here: each state only records the
        WaveFunction reader and the band/spin/kpt/filename it needs to fetch
        its data later. Also sets the space (a BravaisLattice of the unit
        cell) and the Fermi level.
        """
        from Dacapo.WaveFunction import WaveFunction

        ncfile = calc.GetNetCDFFile()
        atoms = calc.GetListOfAtoms()
        ibz_kpoints = calc.GetIBZKPoints()
        ibz_weights = calc.GetIBZKPointWeights()

        # Degeneracy is 1 for a spin-polarized run and 2 otherwise
        # (presumably counting both spin channels).
        if not calc.GetSpinPolarized():
            nspins, degeneracy = 1, 2
        else:
            nspins, degeneracy = 2, 1
        self.SetNumberOfSpins(nspins)

        self.SetNumberOfKPoints(len(ibz_kpoints))
        self.SetNumberOfBands(calc.GetNumberOfBands())

        unitcell = atoms.GetUnitCell()
        # One reader is shared by every state; each state re-points it at
        # its own band/spin/k-point whenever it reads its data.
        reader = WaveFunction(calc)

        def make_state(spin, kpt, band, energies, occupations):
            # Build a single state that defers reading its data to `reader`.
            state = ElectronicState()
            state._SetGridMethod(reader, band=band, spin=spin, kpt=kpt,
                                 filename=ncfile)
            state.SetEnergy(energies[band])
            state.SetOccupationNumber(occupations[band])
            state.SetKPoint(ibz_kpoints[kpt])
            state.SetKPointIndex(kpt)
            state.SetBandIndex(band)
            state.SetUnitCell(unitcell)
            state.SetSpin(spin)
            state.SetKPointWeight(ibz_weights[kpt])
            state.SetDegeneracy(degeneracy)
            return state

        for spin in range(self.GetNumberOfSpins()):
            for kpt in range(self.GetNumberOfKPoints()):
                energies = calc.GetEigenvalues(spin=spin, kpt=kpt)
                occupations = calc.GetOccupationNumbers(spin=spin, kpt=kpt)
                for band in range(self.GetNumberOfBands()):
                    self.append(make_state(spin, kpt, band, energies, occupations))

        self.SetSpace(BravaisLattice(unitcell))
        self.SetFermiLevel(calc.GetFermiLevel())
    
if __name__ == "__main__":
    # Small demo: load the states of an example calculation and show its
    # k-points.
    loe = ElectronicStates('Examples/Al100.nc')

    # A single-argument print(...) behaves identically on Python 2 and 3.
    # The value is wrapped in a 1-tuple so a tuple result is not expanded
    # by the % operator.
    print('number of kpoints %s' % (loe.GetNumberOfKPoints(),))
    print('kpoints %s' % (loe.GetKPoints(),))

    # Demonstrates selecting all states belonging to one k-point.
    loe_kpt1 = loe.GetStatesKPoint(loe.GetKPoints()[0])
