
from ASE.Utilities.Grid import Grid
import Numeric as num
from ASE.Utilities.ArrayTools import Translate
from ASE.Utilities.VectorSpaces import BravaisLattice

def Dagger(array):
    """Return the Hermitian adjoint (conjugate transpose) of `array`."""
    conjugated = num.conjugate(array)
    return num.transpose(conjugated)

class InitialWannierFunctions:
    """Initial guess for Wannier functions built from localized trial orbitals.

    The trial orbitals are described by `data` (see
    ConvertDataToDetailedData).  SetupMMatrix computes the overlaps between
    the supplied eigenstates and these orbitals, and
    GetListOfCoefficientsAndRotationMatrices turns the overlap matrix into
    per-k-point coefficient and rotation matrices.

    All real-space work is done on the "repeated" cell, i.e. the unit cell
    repeated kpointgrid[i] times along basis vector i.
    """

    def __init__(self, data, griddim, listofatoms, fftindex, kpointgrid=None):
        """Store the orbital description, grid and cell information.

        data        : orbital description, see ConvertDataToDetailedData
        griddim     : number of grid points along each axis of the unit cell
        listofatoms : atoms object; must provide GetUnitCell() and indexing
        fftindex    : per-k-point index arrays, see GetCompactFFTRepresentation
        kpointgrid  : number of k-points along each axis, default [1,1,1]
        """
        if kpointgrid is None:
            # A fresh list per instance; a list literal as default argument
            # would be shared between all instances.
            kpointgrid = [1, 1, 1]
        self.SetData(data)
        self.SetListOfAtoms(listofatoms)
        self.SetFFTIndex(fftindex)
        self.SetUnitCell(listofatoms.GetUnitCell())
        self.SetGridDimensions(griddim)
        self.SetKPointGrid(kpointgrid)

    def _InvalidateCoordinateCache(self):
        """Drop the coordinates cached by GetScaledCoordinatesAndOriginIndex.

        The cache depends on the unit cell, the grid dimensions and the
        k-point grid, so it must be rebuilt whenever one of them changes.
        """
        for name in ('nc', 'originindex'):
            if hasattr(self, name):
                delattr(self, name)

    def SetKPointGrid(self, kpointgrid):
        self.kpointgrid = kpointgrid
        self._InvalidateCoordinateCache()

    def GetKPointGrid(self):
        return self.kpointgrid

    def SetData(self, data):
        self.data = data

    def GetData(self):
        return self.data

    def SetFFTIndex(self, fftindex):
        self.fftindex = fftindex

    def GetFFTIndex(self):
        return self.fftindex

    def SetDetailedData(self, detaileddata):
        """List defining each of the atomic orbitals. An element is a list
        [r_c,l,m,a], where
        r_c : is center of orbital (cartesian coordinates)
        l,m : quantum numbers for angular momentum (l=0,1,2. m=-l,...,l)
        a   : the width of the radial function. Rule of thumb: put
              a=0.5*Covalent radius of species.
        Note that the radial factor applied in SetupMMatrix is exp(-r/a)."""
        self.detaileddata = detaileddata

    def GetDetailedData(self):
        """Return the detailed orbital list, converting `data` on first use."""
        if not hasattr(self, 'detaileddata'):
            self.ConvertDataToDetailedData()
        return self.detaileddata

    def ConvertDataToDetailedData(self):
        """Convert data of the form
        [atom,l,(m),a], or [center,l,(m),a]
        into the detailed form [r_c,l,m,a] and store it.

        atom is a number in a bracket, example [2] for atom number 2.
        center is 3 numbers in a bracket, example [1,0.5,0.3], denoting
        SCALED coordinates of the center.
        Here m is optional. If m is left out all m values (-l,...,l) are used.

        Raises ValueError if the first element of an entry has neither
        length 1 nor length 3.
        """
        atoms = self.GetListOfAtoms()
        cell = None  # built lazily, only needed for scaled centers
        detaileddata = []
        for entry in self.GetData():
            center = entry[0]
            if len(center) == 1:
                r_c = num.array(atoms[center[0]].GetCartesianPosition())
            elif len(center) == 3:
                if cell is None:
                    cell = BravaisLattice(basis=self.GetUnitCell())
                r_c = cell.CartesianCoordinatesFromCoordinates(center)
            else:
                # Continuing here would either fail on an undefined r_c or
                # silently reuse the center of the previous entry.
                raise ValueError("First element in initial data must be of the form [atom] or [c1,c2,c3], where the latter is scaled coordinates of the center")
            l = entry[1]
            if len(entry) == 4:
                # m is specified
                detaileddata.append([r_c, l, entry[2], entry[3]])
            else:
                # Orbitals with all allowed m values are produced
                for m in range(-l, l + 1):
                    detaileddata.append([r_c, l, m, entry[2]])
        self.SetDetailedData(detaileddata)

    def GetListOfAtoms(self):
        return self.listofatoms

    def SetListOfAtoms(self, listofatoms):
        self.listofatoms = listofatoms

    def SetGridDimensions(self, griddim):
        self.griddim = griddim
        self._InvalidateCoordinateCache()

    def GetGridDimensions(self):
        return self.griddim

    def GetRepeatedGridDimensions(self):
        """Grid dimensions of the repeated cell: griddim*kpointgrid, elementwise."""
        return num.array(self.GetGridDimensions()) * num.array(self.GetKPointGrid())

    def SetMMatrix(self, mmatrix):
        self.mmatrix = mmatrix

    def GetMMatrix(self):
        return self.mmatrix

    def _CenteredGridCoordinates(self):
        """Cartesian coordinates on the repeated grid, origin at the grid center.

        Returns (c, originindex): c is whatever Grid.GetCartesianCoordinates
        returns after the grid origin has been moved to the grid point
        nearest scaled position (0.5,0.5,0.5); originindex is the list of
        integer grid indices of that point.  A new grid is built on every
        call, so callers may modify c freely.
        """
        from ASE.Utilities.Vector import Vector

        unitcell = self.GetRepeatedUnitCell()
        griddim = self.GetRepeatedGridDimensions()
        grid = Grid(array=num.zeros(griddim, num.Complex),
                    space=BravaisLattice(unitcell))
        # Set origin to center of grid
        originindex = [int(coor) for coor in
                       grid.GridCoordinatesFromCoordinates([0.5, 0.5, 0.5])]
        # origincoord is in scaled coordinates:
        origincoord = -num.array(grid.CoordinatesFromGridCoordinates(originindex))
        grid.SetOrigin(Vector(origincoord, grid.GetSpace()))
        return grid.GetCartesianCoordinates(), originindex

    def GetDistanceArrayAtOrigin(self):
        """Distance to the origin for every point of the repeated grid.

        The distances are computed with the origin at the grid center and
        then translated so that the origin sits at grid index (0,0,0).
        """
        c, originindex = self._CenteredGridCoordinates()
        dist = num.sqrt(c[0]**2 + c[1]**2 + c[2]**2)
        # Translate back to origin
        return Translate(dist, originindex)

    def SetUnitCell(self, unitcell):
        self.unitcell = unitcell
        self._InvalidateCoordinateCache()

    def GetUnitCell(self):
        return self.unitcell

    def GetRepeatedUnitCell(self):
        """Unit cell with basis vector (row) i scaled by kpointgrid[i]."""
        n1, n2, n3 = self.GetKPointGrid()
        basis = self.GetUnitCell()
        return num.transpose(num.transpose(basis) * num.array([n1, n2, n3]))

    def GetListOfCoefficientsAndRotationMatrices(self, matrixdimensions):
        """Build initial coefficient and rotation matrices for every k-point.

        matrixdimensions is (M,N,L): M is a single number, N and L are
        sequences with one entry per k-point.  For k-point k the returned
        c has shape (M-N[k], L[k]) and U has shape (N[k]+L[k], N[k]+L[k]).
        NOTE(review): presumably M is the total number of bands, N[k] the
        number of states kept fixed and L[k] the number of extra degrees of
        freedom - confirm against the caller.

        The matrix stored with SetMMatrix is column-normalized IN PLACE.
        Missing columns are filled with random numbers (module `random`)
        and Gram-Schmidt orthonormalized.  A message is printed when the
        columns of c or U deviate from orthonormality by more than 1e-3.

        Returns (clist, Ulist).
        """
        from ASE.Utilities.Wannier.Localize import Project, Normalize, GramSchmidtOrthonormalize
        import random

        M, N, L = matrixdimensions
        nkpt = len(N)
        Ulist = []
        clist = []
        coeffmatrix = self.GetMMatrix()

        for kpt in range(nkpt):
            nfixed = N[kpt]
            nextra = L[kpt]
            # First normalize the columns of coeffmatrix
            coeffmatrix[kpt] = Normalize(coeffmatrix[kpt])
            # T holds the rows of the overlap matrix from row N[kpt] on
            T = coeffmatrix[kpt][nfixed:].copy()
            numberoforbitals = T.shape[1]
            c = num.zeros([M - nfixed, nextra], num.Complex)
            U = num.zeros([nfixed + nextra, nfixed + nextra], num.Complex)
            # Initialize weights: squared norm of every column of T
            # (axis 0 is the Numeric default, given explicitly here)
            w = abs(num.sum(T * num.conjugate(T), 0))
            for i in range(min(nextra, numberoforbitals)):
                # Find index of maximal element in w
                t = w.tolist().index(max(w))
                c[:, i] = T[:, t]
                # Orthogonalize c[:,i] on previous vectors
                for j in range(i):
                    c[:, i] = c[:, i] - Project(c[:, j], T[:, t])
                c[:, i] = c[:, i] / num.sqrt(num.dot(c[:, i], num.conjugate(c[:, i])))
                # Update weights: remove the part already spanned by c[:,i]
                w = w - abs(num.matrixmultiply(num.conjugate(c[:, i]), T))**2
            if numberoforbitals < nextra:
                # Supplement c by random vectors
                for i in range(numberoforbitals, nextra):
                    for j in range(M - nfixed):
                        c[j, i] = random.random()
                c = GramSchmidtOrthonormalize(c)
            # Test whether columns are orthonormal
            if nextra > 0:
                test = self.GetOrthonormalityFactor(c)
                if test > 1.0e-3:
                    print("ERROR: Columns of c not orthogonal!")
            U[:nfixed, :numberoforbitals] = coeffmatrix[kpt][:nfixed]
            U[nfixed:, :numberoforbitals] = num.matrixmultiply(Dagger(c), coeffmatrix[kpt][nfixed:])
            # Gram-Schmidt is used here; a Lowdin orthonormalization of
            # U[:,:numberoforbitals] would be the "democratic" alternative.
            U[:, :numberoforbitals] = GramSchmidtOrthonormalize(U[:, :numberoforbitals])
            if numberoforbitals < (nfixed + nextra):
                # Supplement U by random vectors
                for i in range(numberoforbitals, nfixed + nextra):
                    for j in range(nfixed + nextra):
                        U[j, i] = random.random()
                # Finally orthogonalize everything
                # Note, only random vectors are affected
                U = GramSchmidtOrthonormalize(U)
            # Test whether columns are orthonormal
            test = self.GetOrthonormalityFactor(U)
            if test > 1.0e-3:
                print("ERROR: Columns of U not orthogonal for kpoint %s" % kpt)
            Ulist.append(U)
            clist.append(c)
        return clist, Ulist

    def GetOrthonormalityFactor(self, matrix):
        """Largest absolute element of |matrix^dagger matrix| - identity.

        The absolute value of the overlap is taken before the identity is
        subtracted, so phases of the overlap elements are ignored.
        """
        overlap = num.matrixmultiply(Dagger(matrix), matrix)
        defect = abs(overlap) - num.identity(matrix.shape[1], num.Float)
        return max(abs(defect.flat))

    def SetupMMatrix(self, listofeigenstates, bzkpoints):
        """Compute and store the overlaps between eigenstates and orbitals.

        listofeigenstates : listofeigenstates[kpt] holds the eigenstates of
                            k-point kpt, one per band, each as a list of
                            coefficients
        bzkpoints         : the k-points, in coordinates of the reciprocal
                            lattice of the unit cell

        The result, a complex array of shape
        (number of k-points, number of bands, number of orbitals), is
        stored with SetMMatrix.  The number of bands is taken from the
        first k-point.
        """
        from Dacapo import FunctionSpaces
        from Dacapo import Operators
        from ASE.Utilities.Vector import Vector

        fftindex = self.GetFFTIndex()

        # Initialize operator for translations
        unitcell = self.GetRepeatedUnitCell()
        transspace = FunctionSpaces.PlaneWaveSpace3D(self.GetRepeatedGridDimensions())
        transop = Operators.Translation(space=transspace, unitcell=unitcell)
        transvector = Vector(space=BravaisLattice(unitcell))
        data = self.GetDetailedData()
        Nkpoints = len(bzkpoints)
        Nbands = len(listofeigenstates[0])
        M = num.zeros([Nkpoints, Nbands, len(data)], num.Complex)
        dist = self.GetDistanceArrayAtOrigin()
        rec_cell = BravaisLattice(basis=self.GetUnitCell()).GetReciprocalBravaisLattice()
        largerec_cell = BravaisLattice(basis=self.GetRepeatedUnitCell()).GetReciprocalBravaisLattice()
        for i in range(len(data)):
            r_c = data[i][0]
            l, m = data[i][1], data[i][2]
            a = data[i][3]
            # Build the orbital at the origin, then translate it to r_c
            transvector.SetCartesianCoordinates(r_c)
            transop.SetTranslationVector(transvector)
            orbital = self.GetCubicHarmonicAtOrigin(l, m) * num.exp(-dist / a)
            orbital_fft = transspace.CoordinatesFromFunctionValues(orbital)
            orbital_fft = transop.Operate(orbital_fft)
            for kpt in range(Nkpoints):
                # Express the k-point in the reciprocal lattice of the
                # repeated cell and round to the nearest integers
                kptnumber = largerec_cell.CoordinatesFromCartesianCoordinates(
                    rec_cell.CartesianCoordinatesFromCoordinates(bzkpoints[kpt]))
                for axis in range(3):
                    kptnumber[axis] = round(kptnumber[axis])
                kptnumber = kptnumber.astype(int)
                u_k = self.ExtractPeriodicPartOfSmallCell(orbital_fft, kptnumber)
                compact_u_k = self.GetCompactFFTRepresentation(
                    u_k, fftindex[kpt], len(listofeigenstates[kpt][0]))
                M[kpt, :, i] = num.dot(num.conjugate(num.array(listofeigenstates[kpt])), compact_u_k)
        self.SetMMatrix(M)

    def ExtractPeriodicPartOfSmallCell(self, f, k):
        """Pick every kpointgrid[i]-th element of f along axis i, starting at k[i].

        A negative k[i] is wrapped by adding kpointgrid[i]; in that case the
        extracted array is additionally translated by griddim[i]-1 along
        axis i.  NOTE(review): presumably this shift compensates for the
        wrap-around of the k-index - confirm.

        k is not modified.
        """
        kgrid = self.GetKPointGrid()
        n1, n2, n3 = kgrid
        griddim = self.GetGridDimensions()
        # Work on a copy so the caller's array is left untouched
        start = [k[0], k[1], k[2]]
        trans = [0, 0, 0]
        for axis in range(3):
            if start[axis] < 0:
                start[axis] += kgrid[axis]
                trans[axis] = griddim[axis] - 1

        u = f[start[0]::n1, start[1]::n2, start[2]::n3].copy()
        return Translate(u, trans)

    def GetCompactFFTRepresentation(self, freciprocal, fftindex, numberofpws):
        """Gather the first `numberofpws` plane-wave coefficients of freciprocal.

        Column i of fftindex holds the three grid indices of plane wave i.
        One is subtracted from each index before use, i.e. the indices in
        fftindex are treated as 1-based.
        """
        wflist = num.zeros([numberofpws], num.Complex)
        for i in range(numberofpws):
            i1 = int(fftindex[0, i] - 1)
            i2 = int(fftindex[1, i] - 1)
            i3 = int(fftindex[2, i] - 1)
            wflist[i] = freciprocal[i1, i2, i3]
        return wflist

    def GetScaledCoordinatesAndOriginIndex(self):
        """Return (nc, originindex) for the repeated grid, cached after first use.

        nc are the cartesian coordinates divided by the distance to the
        origin, with the origin at the grid center; originindex is the grid
        index of that origin.  The cache is dropped by SetUnitCell,
        SetGridDimensions and SetKPointGrid.
        """
        if not hasattr(self, 'nc'):
            c, originindex = self._CenteredGridCoordinates()
            dist = num.sqrt(c[0]**2 + c[1]**2 + c[2]**2)

            # We define "normalized" coordinates. To avoid indeterminacy at
            # the origin we move it to the point (1,1,1)*1e-8.
            # A tuple index addresses exactly one grid point.
            origin = tuple(originindex)
            c[0][origin] = 1.0e-8
            c[1][origin] = 1.0e-8
            c[2][origin] = 1.0e-8
            dist[origin] = num.sqrt(3) * 1.0e-8
            self.nc = c / dist
            self.originindex = originindex
        return self.nc, self.originindex

    def GetCubicHarmonicAtOrigin(self, l, m):
        """Cubic harmonic (l,m) on the repeated grid, centered at grid index (0,0,0).

        l=0,1,2. m=-l,...,l.  The mapping used for l=1 is
        m=0: p_x, m=-1: p_z, m=1: p_y.

        Raises ValueError for any other (l,m); a silently returned zero
        orbital could not be normalized later on.
        """
        # Angular part as a function of the normalized coordinates nc
        angular = {
            (0, 0): lambda nc: (1 / num.sqrt(4 * num.pi)) * num.ones(nc[0].shape, num.Complex),
            # p_x
            (1, 0): lambda nc: num.sqrt(3 / (4 * num.pi)) * nc[0],
            # p_z
            (1, -1): lambda nc: num.sqrt(3 / (4 * num.pi)) * nc[2],
            # p_y
            (1, 1): lambda nc: num.sqrt(3 / (4 * num.pi)) * nc[1],
            (2, 0): lambda nc: 0.5 * num.sqrt(5 / (4 * num.pi)) * (3 * (nc[0]**2) - num.ones(nc[0].shape, num.Complex)),
            (2, -1): lambda nc: num.sqrt(15 / (16 * num.pi)) * (nc[2]**2 - nc[1]**2),
            (2, 1): lambda nc: num.sqrt(15 / (4 * num.pi)) * nc[0] * nc[1],
            (2, -2): lambda nc: num.sqrt(15 / (4 * num.pi)) * nc[2] * nc[1],
            (2, 2): lambda nc: num.sqrt(15 / (4 * num.pi)) * nc[2] * nc[0],
        }
        key = (l, m)
        if key not in angular:
            raise ValueError("Unsupported cubic harmonic l=%s, m=%s (need l=0,1,2 and m=-l,...,l)" % (l, m))

        nc, originindex = self.GetScaledCoordinatesAndOriginIndex()
        # Constructing cubic harmonic around the grid center, then moving
        # the center to grid index (0,0,0)
        return Translate(angular[key](nc), originindex)

