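""" Dacapo interface for Wannier-function calculations.

The class Wannier provides the function GetLocalizationMatrix. This module
defines DacapoWannier, which reads plane-wave wave functions from a Dacapo
calculator and computes the Bloch overlap matrices (GetZIBlochMatrix) on the
FFT grid.
"""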

import Numeric as num
from DacapoUltraSoftReader import UltraSoftReader

class DacapoWannier:
    """Dacapo plane-wave data and Bloch overlap matrices for Wannier functions.

    On construction the wave functions of the first ``nbands`` bands of one
    spin channel are read from a Dacapo calculator for every k-point. The
    FFT-grid (reciprocal) index of each plane-wave coefficient is read as well.
    """

    def __init__(self, calculator, nbands, spin=0):
        """Read wave functions and grid information from ``calculator``.

        calculator -- Dacapo calculator; must provide GetElectronicStates,
                      GetBZKPoints and GetNetCDFFile.
        nbands     -- number of bands to keep per k-point.
        spin       -- spin channel whose wave functions are used (default 0).
        """
        print('Initialize Dacapo Wannier')
        loe = calculator.GetElectronicStates()
        self.loe = loe
        self.SetNumberOfBands(nbands)
        self.SetKPoints(calculator.GetBZKPoints())
        self.SetSpin(spin)
        self.SetListOfWaveFunctions(loe)
        self.SetFFTIndex(loe)
        # the array of the first electronic state fixes the FFT grid shape
        self.SetGridDimensions(num.shape(loe[0].GetArray()))
        ncfile = calculator.GetNetCDFFile()
        # setup reader for ultra-soft pseudo-potential
        self.usreader = UltraSoftReader(ncfile, spin, nbands)
        print("Stopped initializing")

    def SetKPoints(self, kpoints):
        self.kpoints = kpoints

    def GetKPoints(self):
        return self.kpoints

    def SetSpin(self, spin):
        self.spin = spin

    def GetSpin(self):
        return self.spin

    def SetFFTIndex(self, loe):
        """Store, for every k-point, the reciprocal (FFT) index of band 0.

        The index array is taken from the spin-0 state of band 0 at each
        k-point of ``loe``. This presumably relies on the plane-wave basis
        being identical for all bands and both spins (TODO confirm).
        """
        nkpoints = len(self.GetKPoints())
        self.fftindex = [
            loe.GetState(band=0, kptindex=kpt, spin=0).GetReciprocalIndex()
            for kpt in range(nkpoints)
        ]

    def GetFFTIndex(self):
        return self.fftindex

    def SetGridDimensions(self, dimensions):
        self.griddimensions = dimensions

    def GetGridDimensions(self):
        return self.griddimensions

    def SetListOfWaveFunctions(self, loe):
        """Store the wave functions of the selected spin, one list per k-point.

        For every k-point, at most the first GetNumberOfBands() states returned
        by ``loe.GetStatesKPoint`` are considered. Only states whose spin equals
        GetSpin() are kept.
        """
        spin = self.GetSpin()
        nbands = self.GetNumberOfBands()
        listofwavefct = []
        for kpoint in self.GetKPoints():
            states = loe.GetStatesKPoint(kpoint, spin=spin)
            listofwavefct.append([state.GetWaveFunction()
                                  for state in states[:nbands]
                                  if state.GetSpin() == spin])
        # assign once, after the loop, so the attribute exists even for an
        # empty k-point list
        self.listofwavefct = listofwavefct

    def GetListOfWaveFunctions(self):
        return self.listofwavefct

    def SetNumberOfBands(self, numberofbands):
        self.numberofbands = numberofbands

    def GetNumberOfBands(self):
        return self.numberofbands

    def GetZIBlochMatrix(self, dirG, kpoint, nextkpoint, G_I):
        """Return the band-by-band overlap matrix between two k-points.

        Element (i, j) is  sum_G conj(c_i(kpoint, G)) * c_j(nextkpoint, G+G_I),
        summed over the plane waves G at ``kpoint`` whose shifted partner G+G_I
        is present at ``nextkpoint`` (see GetGGList). The ultrasoft correction
        from self.usreader.GetUltraSoftNonLocMatrix(dirG, kpoint, nextkpoint)
        is added. The result is a single MxM matrix, M being the number of
        bands stored per k-point.
        """
        wavefunctions = self.GetListOfWaveFunctions()

        # nextkpoint and G_I are expected to satisfy K1-K-K0+G_I=0
        # (from the original note).
        list1, list2 = self.GetGGList(kpoint, nextkpoint, G_I)

        # rows: plane-wave coefficients, columns: bands
        phi = num.swapaxes(num.array(wavefunctions[kpoint]), 0, 1)
        phi1 = num.swapaxes(num.array(wavefunctions[nextkpoint]), 0, 1)

        a = num.swapaxes(num.take(phi, list1), 0, 1)   # bands x G-pairs
        b = num.take(phi1, list2)                       # G-pairs x bands

        ziblochmatrix = num.matrixmultiply(num.conjugate(a), b)
        usziblochmatrix = self.usreader.GetUltraSoftNonLocMatrix(
            dirG, kpoint, nextkpoint)
        return ziblochmatrix + usziblochmatrix

    def GetGGList(self, kpt1, kpt2, GI):
        """Pair plane waves G at kpt1 with plane waves G+GI at kpt2.

        Returns (list1, list2). For each i, plane wave list1[i] of kpt1 and
        plane wave list2[i] of kpt2 lie GI apart on the FFT grid. This defines
        the products phi(kpt1,G)*phi(kpt2,G+GI). Plane waves G whose partner
        G+GI is not in the kpt2 wave function are skipped.

        GI is an integer offset in FFT-grid units, typically one of
        [[1,0,0],[0,1,0],[0,0,1],[1,1,0],[1,0,1],[0,1,1]]. The shifted index
        wraps periodically around the grid, so other offsets (including
        negative ones) are handled too.

        The layout of fourier components along one axis is
        1   2   3   4   5   6   7   8   ngx = 8   (stored, 1-based index)
        0   1   2   3   4  -3  -2  -1   n*2pi/L
        """
        numberplanewaves = len(self.GetListOfWaveFunctions()[kpt1][0])
        reciprocalindex = self.GetFFTIndex()[kpt1]
        ngrids = self.GetGridDimensions()

        # mapping from the 3D FFT grid to the kpt2 wave-function list
        map2 = self.GetIndexMap(kpt2)

        list1 = []
        list2 = []
        for n in range(numberplanewaves):
            # 1-based FFT index of G -> 0-based, shifted by GI, wrapped
            index = [(reciprocalindex[d][n] - 1 + GI[d]) % ngrids[d]
                     for d in range(3)]
            # position of G+GI in the kpt2 wave-function list (-1 if absent)
            n1 = map2[index[0], index[1], index[2]]
            if n1 >= 0:
                list1.append(n)
                list2.append(n1)
        return list1, list2

    def GetIndexMap(self, kpt):
        """Map the 3D FFT grid to positions in the wave-function list of kpt.

        The returned integer array has the grid dimensions. Entry (g1,g2,g3)
        holds the index of that plane wave in the wave-function list of k-point
        ``kpt``, or -1 if the grid point does not exist in that list.
        """
        ngrids = self.GetGridDimensions()
        # -1 everywhere marks "plane wave not present"
        map_to_wflist = num.zeros(ngrids, num.Int) - 1

        numberplanewaves = len(self.GetListOfWaveFunctions()[kpt][0])
        reciprocalindex = self.GetFFTIndex()[kpt]

        for n in range(numberplanewaves):
            # stored reciprocal indices are 1-based; the array is 0-based
            i0 = reciprocalindex[0][n] - 1
            i1 = reciprocalindex[1][n] - 1
            i2 = reciprocalindex[2][n] - 1
            map_to_wflist[i0, i1, i2] = n

        return map_to_wflist
		
