# Copyright (C) 2003  CAMP
# Please see the accompanying LICENSE file for further information.

"""Read ultrasoft-pseudopotential data from Dacapo NetCDF files.

`UltraSoftReader` loads the Wannier augmentation factors, the projections
on the non-local projectors, the number of projectors per atom and the
structure factor, and combines them into the ultrasoft contribution to a
band-band matrix (`UltraSoftReader.GetUltraSoftNonLocMatrix`).  `Dagger`
returns the conjugate transpose of a matrix.
"""

__docformat__ = 'reStructuredText'

from math import sqrt

import Numeric as num
from Dacapo import NetCDF
from Scientific.IO.NetCDF import NetCDFFile

def Dagger(matrix, copy=1):
    """Return the conjugate transpose of ``matrix`` (axes 0 and 1 swapped).

    :param matrix: array whose first two axes are transposed.
    :param copy: when true (default) a new, conjugated array is returned and
        ``matrix`` is left untouched.  When false, the imaginary part is
        negated in place on the swapped-axes view, so the returned array
        shares its data with ``matrix`` and ``matrix`` itself is conjugated.
    """
    # swapaxes only changes the axis order; it does not allocate new storage.
    transposed = num.swapaxes(matrix, 0, 1)
    if not copy:
        # Reuse the input's storage: flip the sign of the imaginary part.
        num.multiply(transposed.imag, -1, transposed.imag)
        return transposed
    return num.conjugate(transposed)

class UltraSoftReader:
    """Reader for the ultrasoft-pseudopotential data of a Dacapo calculation.

    The constructor reads, from NetCDF files, ``WannierAugFactor``
    (augmentation factors), ``NLProjectorPsi`` (projections on the
    non-local projectors), ``NumberOfNLProjectors`` and ``StructureFactor``.
    When ``WannierAugFactor`` cannot be found, the calculation is treated as
    non-ultrasoft: ``is_ultrasoft`` is set to None and the remaining
    quantities are not read.
    """

    def __init__(self, filename1, spin, nbands, filename2=None):
        """Read the ultrasoft data and close the input file(s) afterwards.

        :param filename1: NetCDF file holding ``NLProjectorPsi``,
            ``NumberOfNLProjectors`` and ``StructureFactor``.
        :param spin: spin channel selected by default in
            `GetNLProjectorPsi`.
        :param nbands: number of bands kept from ``NLProjectorPsi``.
        :param filename2: optional NetCDF file holding ``WannierAugFactor``;
            when None the factor is read from ``filename1``.
        """
        ncfile = NetCDFFile(filename1, 'r')
        self.spin = spin
        self.numberofbands = nbands

        augfile = None
        try:
            if filename2 is not None:
                augfile = NetCDFFile(filename2, 'r')
                self._ReadUltraSoftData(ncfile, augfile)
            else:
                # Reuse the open handle rather than opening filename1 twice.
                self._ReadUltraSoftData(ncfile, ncfile)
        finally:
            # NOTE(review): assumes Entry.GetValue() returns data already
            # loaded into memory, so the handles can be closed here -- confirm.
            if augfile is not None:
                augfile.close()
            ncfile.close()

    def _ReadUltraSoftData(self, ncfile, augfile):
        """Read all ultrasoft quantities from the open NetCDF handles."""
        try:
            waugfactor = NetCDF.Entry(name='WannierAugFactor').ReadFromNetCDFFile(augfile)
        except KeyError:
            # No augmentation factors: not an ultrasoft calculation.
            self.is_ultrasoft = None
            return
        self.is_ultrasoft = True
        self.SetWannierAugFactor(waugfactor.GetValue())

        nlprojpsi = NetCDF.Entry(name='NLProjectorPsi').ReadFromNetCDFFile(ncfile)
        self.SetNLProjectorPsi(nlprojpsi.GetValue())

        numberprojectors = NetCDF.Entry(name='NumberOfNLProjectors').ReadFromNetCDFFile(ncfile)
        self.SetNumberOfProjectors(numberprojectors.GetValue())

        # The atom count is derived from the shape of NLProjectorPsi.
        self.SetNumberOfAtoms()

        structurefactor = NetCDF.Entry(name='StructureFactor').ReadFromNetCDFFile(ncfile)
        self.SetStructureFactor(structurefactor.GetValue())

    def SetNumberOfAtoms(self):
        """Set the number of atoms from the length of axis 2 of ``nlprojpsi``."""
        self.natoms = num.shape(self.nlprojpsi)[2]

    def GetNumberOfAtoms(self):
        """Return the number of atoms stored by `SetNumberOfAtoms`."""
        return self.natoms

    def SetStructureFactor(self, value):
        """Store ``value`` as a complex array.

        ``value`` holds the real part at ``[..., 0]`` and the imaginary part
        at ``[..., 1]`` of its last axis.  The stored array is indexed as
        ``[atom][direction]`` in `GetUltraSoftNonLocMatrix`.
        """
        strfactor = num.zeros(num.shape(value[:, :, 0]), num.Complex)
        strfactor.real = value[:, :, 0]
        strfactor.imag = value[:, :, 1]
        self.strfactor = strfactor

    def GetStructureFactor(self):
        """Return the complex structure factor set by `SetStructureFactor`."""
        return self.strfactor

    def SetNLProjectorPsi(self, value):
        """Store ``value`` as a complex array with its axes 3 and 4 swapped.

        ``value`` is laid out as (kpoint, spin, atom, projector, band,
        real/imag).  The stored array is (kpoint, spin, atom, band,
        projector).
        """
        nlprojpsi = num.zeros(num.shape(value[:, :, :, :, :, 0]), num.Complex)
        nlprojpsi.real = value[:, :, :, :, :, 0]
        nlprojpsi.imag = value[:, :, :, :, :, 1]
        self.nlprojpsi = num.swapaxes(nlprojpsi, 3, 4)

    def ReadAugFactor(self, file):
        """Read ``WannierAugFactor`` from the open NetCDF ``file`` and store it."""
        waugfactor = NetCDF.Entry(name='WannierAugFactor').ReadFromNetCDFFile(file)
        self.SetWannierAugFactor(waugfactor.GetValue())

    def SetWannierAugFactor(self, value):
        """Store ``value`` as a complex array with its first two axes swapped.

        ``value`` is laid out as (projector, projector, atom, dim3,
        real/imag).  Swapping the two projector axes transposes each
        projector block of the stored (projector, projector, atom, dim3)
        array.
        """
        augfactor = num.zeros(num.shape(value[:, :, :, :, 0]), num.Complex)
        augfactor.real = value[:, :, :, :, 0]
        augfactor.imag = value[:, :, :, :, 1]
        self.augfactor = num.swapaxes(augfactor, 0, 1)

    def SetNumberOfProjectors(self, value):
        """Store the per-atom projector counts, indexed by atom."""
        self.nproj = value

    def GetNumberOfProjectors(self, atom=0):
        """Return the number of projectors of ``atom``."""
        return self.nproj[atom]

    def GetNLProjectorPsi(self, atom=0, kpoint=0, spin=None):
        """Return the projector projections for one atom and k-point.

        The result is a (band, projector) array restricted to the first
        ``self.numberofbands`` bands.

        :param spin: spin channel to select.  When None (default), the
            reader's own ``spin`` is used.  Earlier versions ignored this
            argument and always used ``self.spin``.
        """
        if spin is None:
            spin = self.spin
        # Stored axes: kpoint, spin, atom, band, projector.
        return self.nlprojpsi[kpoint, spin, atom, :self.numberofbands, :]

    def GetWannierAugFactor(self, atom=0, index=0):
        """Return the (projector, projector) augmentation block of ``atom``.

        ``index`` selects the entry along the last stored axis (dim3).
        `GetUltraSoftNonLocMatrix` uses 0 for k != k1, and the direction
        number plus one for k == k1.
        """
        # Stored axes: projector, projector, atom, dim3.
        return self.augfactor[:, :, atom, index]

    def GetUltraSoftNonLocMatrix(self, GI, kpt, kpt1):
        """Return the ultrasoft contribution to the (nbands, nbands) matrix.

        Computes::

            W_ij = sum_{I,m,n} <psi_{i,k}|beta^I_m> <beta^I_n|psi_{j,k1}> q^I_mn

        with I running over atoms, m and n over the projectors, k = ``kpt``
        and k1 = ``kpt1``.

        For k == k1, q^I is the complex conjugate of the augmentation entry
        for the direction of ``GI``.  Otherwise q^I is entry 0, multiplied by
        the conjugated structure factor when ``GI`` is one of the first three
        directions.

        :param GI: direction vector; must equal one of (1,0,0), (0,1,0),
            (0,0,1), (1,1,0), (1,0,1), (0,1,1).  Any sequence (list, tuple,
            array) is accepted.
        :param kpt: k-point index k.
        :param kpt1: k-point index k1.
        :returns: the complex matrix, or None when the data is not ultrasoft.
        :raises ValueError: if ``GI`` is not one of the supported directions.
        """
        if not self.is_ultrasoft:
            return None

        directions = [[1, 0, 0], [0, 1, 0], [0, 0, 1],
                      [1, 1, 0], [1, 0, 1], [0, 1, 1]]
        try:
            # list() lets tuples and arrays match the list entries.
            direction = directions.index(list(GI))
        except (TypeError, ValueError):
            raise ValueError('GI must be one of %s, got %r' % (directions, GI))

        nbands = self.numberofbands
        matrix = num.zeros([nbands, nbands], num.Complex)
        for atom in range(self.GetNumberOfAtoms()):
            if kpt == kpt1:
                q = num.conjugate(
                    self.GetWannierAugFactor(atom=atom, index=direction + 1))
            else:
                q = self.GetWannierAugFactor(atom=atom, index=0)
                if direction < 3:
                    # Only the three single-axis directions take a
                    # structure-factor phase.
                    strfactor = self.GetStructureFactor()[atom]
                    q = q * num.conjugate(strfactor[direction])

            A = self.GetNLProjectorPsi(atom=atom, kpoint=kpt)
            B = self.GetNLProjectorPsi(atom=atom, kpoint=kpt1)
            # (band, proj) x (proj, proj) x (proj, band) -> (band, band)
            matrix = matrix + num.matrixmultiply(A, num.matrixmultiply(q, Dagger(B)))
        return matrix
            
          
            
    
