# Copyright (C) 2003  CAMP
# Please see the accompanying LICENSE file for further information.

"""Class representing a wavefunction from a Dacapo run.

This class is responsible for translating the reciprocal-space
representation of the Dacapo wavefunction stored in the NetCDF file
to real space.
"""

__docformat__ = 'reStructuredText'

from ASE.Utilities.Grid import Grid
from ASE.Utilities import ArrayTools
from Dacapo import NetCDF
from ASE.Utilities.VectorSpaces import BravaisLattice
from Scientific.IO.NetCDF import NetCDFFile

# Uses:	
#	Numeric
#	Scientific.IO.NetCDF

class WaveFunction(Grid):
    """Class representing a single eigenstate from a Dacapo run.

    The state is selected by band, spin and k-point number (all default
    to 0). Its reciprocal-space coefficients are read from the
    calculator's NetCDF file, scattered onto the FFT grid and transformed
    to real space; the real-space Bloch function is stored as the grid
    values of this ``Grid``.
    """

    def __init__(self, calculator):
        """Create the wave function for band 0, spin 0, k-point 0.

        :param calculator: object providing ``GetListOfAtoms()``,
            ``GetIBZKPoints()`` and ``GetNetCDFFile()`` (the latter is
            used as the name of the NetCDF file to read).
        """
        self.calculator = calculator
        unitcell = calculator.GetListOfAtoms().GetUnitCell()
        kpoint = calculator.GetIBZKPoints()[0]

        Grid.__init__(self, space=BravaisLattice(unitcell))
        self.SetKPoint(kpoint, unitcell)

        # The selection must be set before reading: the readers below
        # index the NetCDF variables with band, spin and k-point number.
        self.SetBand(band=0)
        self.SetKPointNumber(kpointnumber=0)
        self.SetSpin(spin=0)

        self.ReadFromNetCDFFile(calculator.GetNetCDFFile())

    def ReadFromNetCDFFile(self, ncfile, index=None):
        """Remember the NetCDF file name and read the current state from it.

        :param ncfile: name of the NetCDF file.
        :param index: ignored; kept for backward compatibility.
        """
        self.SetNetCDFFilename(ncfile)
        # ReadBlochFunction reads the (k-point dependent) number of plane
        # waves itself, so no separate read is needed here.
        self.ReadBlochFunction()

    def ReadBlochFunction(self):
        """Read the current state's Bloch function into the grid values.

        Opens the file named by ``GetNetCDFFilename()``, reads the number
        of plane waves and the reciprocal coefficients for the current
        band/spin/k-point, then stores ``ArrayTools.FFT`` of the
        reciprocal grid as the grid values. The file is always closed.
        """
        ncfile = self.GetNetCDFFile()
        try:
            # The plane-wave count is stored per k-point, so it must be
            # re-read whenever the k-point number may have changed;
            # otherwise the coefficients would be sliced with a stale count.
            self.ReadNumberOfPlaneWaves(ncfile)
            self.ReadReciprocalBlochFunction(ncfile)
        finally:
            ncfile.close()
        reciprocal = self.GetReciprocalBlochFunctionGrid()
        self.SetGridValues(ArrayTools.FFT(reciprocal))

    def SetBlochFunction(self, reciprocalblochfunction):
        """Set the grid values from a list of reciprocal coefficients.

        The coefficients are scattered onto the FFT grid using the grid
        dimensions and FFT index from the last read, then transformed to
        real space with ``ArrayTools.FFT``.
        """
        self.SetReciprocalBlochFunction(reciprocalblochfunction)
        reciprocal = self.GetReciprocalBlochFunctionGrid()
        self.SetGridValues(ArrayTools.FFT(reciprocal))

    def GetReciprocalBlochFunction(self):
        """Return the 1-D array of reciprocal-space coefficients."""
        return self._wflist_

    def SetReciprocalBlochFunction(self, wflist):
        """Store the 1-D array of reciprocal-space coefficients."""
        self._wflist_ = wflist

    def SetWaveFunctionFFTindex(self, reciprocalindex):
        """Store the integer FFT index array (3 x number of plane waves)."""
        self._wavefunctionindex_ = reciprocalindex

    def GetWaveFunctionFFTindex(self):
        """Return the integer FFT index array."""
        return self._wavefunctionindex_

    def GetGridDimensions(self):
        """Return the FFT grid dimensions as a tuple ``(N1, N2, N3)``."""
        return self._griddimensions

    def SetGridDimensions(self, N1, N2, N3):
        """Store the FFT grid dimensions."""
        self._griddimensions = N1, N2, N3

    def ReadReciprocalBlochFunction(self, file=None):
        """Read reciprocal coefficients, grid size and FFT index.

        Reads, for the current band/spin/k-point, the first
        ``GetNumberOfPlaneWaves()`` coefficients of ``WaveFunction``, the
        soft-grid dimensions and the ``WaveFunctionFFTindex`` of the
        k-point, and stores them on the instance.

        :param file: an open NetCDF file to read from. If ``None``, the
            file named by ``GetNetCDFFilename()`` is opened and closed
            here. A file passed in is left open for the caller to close.
        """
        from Numeric import zeros, Complex, Int

        ownsfile = file is None
        if ownsfile:
            file = self.GetNetCDFFile()
        try:
            band = self.GetBand()
            kpointnumber = self.GetKPointNumber()
            spin = self.GetSpin()
            nplanewave = self.GetNumberOfPlaneWaves()

            # Reading the Bloch function coefficients.
            blochfunction = file.variables["WaveFunction"][kpointnumber, spin, band]
            wflist = zeros(nplanewave, Complex)
            # NOTE(review): column 1 is used as the real part and column 0
            # as the imaginary part, exactly as before -- confirm against
            # the Dacapo NetCDF layout.
            wflist.real = blochfunction[0:nplanewave, 1]
            wflist.imaginary = blochfunction[0:nplanewave, 0]
            self.SetReciprocalBlochFunction(wflist)

            # Reading the FFT grid dimensions.
            N1 = NetCDF.Dimension("softgrid_dim1").ReadFromNetCDFFile(file).GetSize()
            N2 = NetCDF.Dimension("softgrid_dim2").ReadFromNetCDFFile(file).GetSize()
            N3 = NetCDF.Dimension("softgrid_dim3").ReadFromNetCDFFile(file).GetSize()
            self.SetGridDimensions(N1, N2, N3)

            # The index is explicitly cast to Int; otherwise using it for
            # array assignment does not work.
            reciprocalindex = file.variables["WaveFunctionFFTindex"][kpointnumber, :, :].astype(Int)
            self.SetWaveFunctionFFTindex(reciprocalindex)
        finally:
            if ownsfile:
                file.close()

    def GetReciprocalBlochFunctionGrid(self):
        """Return the reciprocal coefficients scattered onto the FFT grid.

        Grid points not referenced by the FFT index are zero.
        """
        from Numeric import zeros, Complex

        wflist = self.GetReciprocalBlochFunction()
        N1, N2, N3 = self.GetGridDimensions()
        wfreciprocal = zeros([N1, N2, N3], Complex)
        reciprocalindex = self.GetWaveFunctionFFTindex()

        # The stored indices are 1-based (hence the -1 on each axis).
        for i in range(len(wflist)):
            wfreciprocal[reciprocalindex[0, i] - 1,
                         reciprocalindex[1, i] - 1,
                         reciprocalindex[2, i] - 1] = wflist[i]
        return wfreciprocal

    def GetBlochFunction(self):
        """Returns the Bloch function.

        This is the object used by the class internally, and hence it is
        also accessible via the method 'GetGridValues'.
        """
        return self.GetGridValues()

    def GetWaveFunction(self, band=None, kpt=None, spin=None,
                        wavefunction=None, kpoint=None, phase=True):
        """Returns the wave function.

        Returns the Bloch function multiplied by the phase factor
        'exp(ikx)' and divided by the square root of the cell volume.

        :param band, spin, kpt: select the state to read from file; any
            that are given update the stored selection before reading.
        :param wavefunction: reciprocal coefficients to use instead of
            reading from file; cannot be combined with band/spin/kpt.
        :param kpoint: scaled k-point used for the phase factor; cannot be
            combined with kpt.
        :param phase: if false, return only the periodic Bloch function.
        :raises ValueError: on conflicting arguments.
        """
        from Numeric import multiply, fromfunction, dot, exp, sqrt

        # Check arguments. Use 'is not None' so that band/spin/k-point 0
        # are also treated as "given".
        if wavefunction is not None and (band is not None or spin is not None
                                         or kpt is not None):
            raise ValueError("wavefunction can not be given together " +
                             "with band,spin or k-point number")
        if kpt is not None and kpoint is not None:
            raise ValueError("kpt number and kpoint can not be given at the " +
                             "same time")

        if wavefunction is not None:
            self.SetBlochFunction(wavefunction)
            if kpoint is not None:
                self.SetKPoint(kpoint, self.GetSpace())
        else:
            if band is not None:
                self.SetBand(band)
            if spin is not None:
                self.SetSpin(spin)
            if kpt is not None:
                self.SetKPointNumber(kpt)
                kpoint = self.calculator.GetIBZKPoints()[kpt]
            if kpoint is not None:
                self.SetKPoint(kpoint, self.GetSpace())
            self.ReadBlochFunction()

        if not phase:
            # Return the periodic part of the wave function only.
            return self.GetBlochFunction()

        # Phase function exp(i k.r) in Cartesian coordinates.
        kcartesian = self.GetKPoint().GetCartesianCoordinates()

        def phasefunction(coor):
            return exp(1.0j * dot(kcartesian, coor))

        # Start from the Bloch phase at the grid origin and build the full
        # phase array axis by axis: along each axis the phase advances by a
        # constant factor per grid step, so point i gets deltaphase**i.
        blochphase = phasefunction(self.GetOrigin().GetCartesianCoordinates())
        gridunitvectors = self.GetGridUnitVectors()
        spatialshape = self.GetSpatialShape()
        for dim in range(len(spatialshape)):
            deltaphase = phasefunction(gridunitvectors[dim])
            newphase = fromfunction(lambda i, step=deltaphase: pow(step, i),
                                    (spatialshape[dim],))
            blochphase = multiply.outer(blochphase, newphase)

        vol = self.GetSpace().GetVolumeOfCell()
        return multiply(blochphase, self.GetBlochFunction()) / sqrt(vol)

    def ReadNumberOfPlaneWaves(self, file):
        """Read the plane-wave count of the current k-point from an open file."""
        nplanewave = NetCDF.Entry(name="NumberPlaneWavesKpoint").ReadFromNetCDFFile(file).GetValue()[self.GetKPointNumber()]
        self.SetNumberOfPlaneWaves(nplanewave)

    def GetNumberOfPlaneWaves(self):
        """Returns the number of plane waves."""
        return self._nplanewave_

    def SetNumberOfPlaneWaves(self, nplanewave):
        """Sets the number of plane waves; a value of None is ignored."""
        if nplanewave is not None:
            self._nplanewave_ = nplanewave

    def GetAbsoluteValues(self):
        """Returns the absolute values of the Bloch function (grid values).

        The phase factor has modulus one, so it is not applied here.
        """
        from Numeric import absolute
        return absolute(self.GetGridValues())

    def GetPhaseValues(self):
        """Returns an array with the phase of the wave function.

        Note that this calls GetWaveFunction() without arguments, which
        re-reads the current state from file.
        """
        import copy
        from Numeric import log
        # Copying makes the (possibly non-contiguous) imaginary part
        # contiguous.
        return copy.copy(log(self.GetWaveFunction()).imaginary)

    def SetKPoint(self, kpoint, unitcell):
        """Store a KPoint built from the scaled k-point coordinates.

        The reciprocal lattice is derived from the basis of this grid's
        space; ``unitcell`` is accepted for interface compatibility but is
        not used.
        """
        from Dacapo.KPoints import KPoint

        blattice = BravaisLattice(self.GetSpace().GetBasis())
        reclattice = blattice.GetReciprocalBravaisLattice()
        self.kpoint = KPoint(coordinates=kpoint, space=reclattice)

    def GetKPoint(self):
        """Return the stored KPoint instance."""
        return self.kpoint

    def SetBand(self, band):
        """Select the electronic band."""
        self.band = band

    def GetBand(self):
        """Return the selected electronic band."""
        return self.band

    def SetKPointNumber(self, kpointnumber):
        """Select the k-point number in the IBZ."""
        self.kpointnumber = kpointnumber

    def GetKPointNumber(self):
        """Return the selected k-point number."""
        return self.kpointnumber

    def SetSpin(self, spin):
        """Select the electronic spin."""
        self.spin = spin

    def GetSpin(self):
        """Return the selected electronic spin."""
        return self.spin

    def SetNetCDFFilename(self, ncfilename):
        """Store the name of the NetCDF file to read from."""
        self.ncfilename = ncfilename

    def GetNetCDFFilename(self):
        """Return the name of the NetCDF file."""
        return self.ncfilename

    def GetNetCDFFile(self):
        """Open and return the NetCDF file; the caller must close it."""
        return NetCDFFile(self.GetNetCDFFilename())


