# Copyright (C) 2003  CAMP
# Please see the accompanying LICENSE file for further information.

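"""Multicenter projected density of states (DOS) for Dacapo.

- Define the multicenters to project onto. Each multicenter is a list of
  (atomno, l, m, weight) tuples::

    mulcen = MultiCenterProjectedDOS()
    mulcen.add([(0,0,0,1.),(1,0,0,-1.)])


- Read the multicenter projections back. The entry must first have been
  loaded from a netCDF file with ReadFromNetCDFFile. ``read`` returns a list
  of MultiCenterProjectedDOSTool instances::

    mulcen = MultiCenterProjectedDOS()
    mulcen.read()

"""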

# Declares the markup used by the docstrings in this module for documentation
# tools that honour __docformat__ (e.g. epydoc).
__docformat__ = 'reStructuredText'

from ASE.Utilities.Grid import Grid
from Dacapo import NetCDF
from Numeric import zeros,Float


class MultiCenterProjectedDOS(NetCDF.EntryWithStaticName):

       def __init__(self,mcenters=None,
                    energywindow=(-15,5),
                    energywidth=0.2,
                    numberenergypoints=250,
                    cutoffradius=1.0):

            NetCDF.Entry.__init__(self,name="MultiCenterProjectedDOS",value="")

            NetCDF.Entry.__init__(self,name="MultiCenterProjectedDOS",value="")
            self.EnergyWindow = energywindow
            self.EnergyWidth = energywidth
            self.NumberEnergyPoints = numberenergypoints
	    self.CutoffRadius =  cutoffradius

            self._listofmcenters_ = mcenters  # import to use _attrib_ so it is not written to nc file

       def WriteToNetCDFFile(self,file,indices=None):
            """overload the NetCDF entries own method 
               usually called from the Simulation class. 
               The multicenters are written to NetCDGF as one matrix: 
                  M(number_of_multicenters,max_centers,4)
                  M(:                     ,:          ,0) = atomno 
                  M(:                     ,:          ,1) = l 
                  M(:                     ,:          ,2) = m 
                  M(:                     ,:          ,3) = weight
               This matrix is real, to be able to represent the weight's.  
            """

            # get number of multi centers
            ncenters = len(self._listofmcenters_)
            # get max number of orbitals any center 
            max_orbitals = max(map(len,self._listofmcenters_))

            mmatrix = zeros([ncenters,max_orbitals,4],Float)
            ncenter = 0
            for multicenter in self._listofmcenters_: 
                norbital = 0
                for orbital in multicenter: 
                   mmatrix[ncenter,norbital] = orbital 
                   norbital = norbital + 1 
              
                # signal that this multicenter contains less than max_orbital orbitals
                if len(multicenter)<max_orbitals: 
                     mmatrix[ncenter,len(multicenter):max_orbitals] = (-1.0,0,0,0)

                ncenter = ncenter + 1

            self.SetDimensionNames(['number_of_multicenters','max_orbitals','dim4'])
            self.SetValue(mmatrix)

            # NetCDF.Entry(name="MultiCenterProjectedDOS",value=mmatrix,
            # dimensionnames=['number_of_multicenters','max_orbitals','dim4']).WriteToNetCDFFile(file)

            NetCDF.Entry.WriteToNetCDFFile(self,file)


       def ReadFromNetCDFFile(self,file,index=None):


            # get data
            integrated = NetCDF.Entry(name="MultiCenterProjectedDOS_IntegratedDOS")
            integrated.ReadFromNetCDFFile(file)
            self._integrated_ = integrated.GetValue()

            energyresolved = NetCDF.Entry(name="MultiCenterProjectedDOS_EnergyResolvedDOS")
            energyresolved.ReadFromNetCDFFile(file)
            self._energyresolved_ = energyresolved.GetValue()

            energygrid = NetCDF.Entry(name="MultiCenterProjectedDOS_EnergyGrid")
            energygrid.ReadFromNetCDFFile(file)
            self._energygrid_ = energygrid.GetValue()

            self.ReadEFermi(file)
           
            # read rest of attributes in so we can set the description of the multicenter orbitals
            NetCDF.Entry.ReadFromNetCDFFile(self,file)



       def add(self,multicenter):
            self._listofmcenters_.append(multicenter)


       def __repr__(self):
            return "MultiCenterProjectedDOS()"


       def read(self,spin= 0,cutoff = "Infinite"):
            """ method to read multicenter projected dos from a netcdf file. 
                    mulcen.read() return a list of instances of the class 
                    MultiCenterProjectedDOSTool  """

            self._multicenterprojections_ = []

            icut = 1
            if cutoff=="Infinite":
			icut = 0

            # get number of projections/per spin/ncut
            try: 
            	        self._number_of_multicenters_  = self._integrated_.shape[0]
            except: 
			print 'Multi center projections not present in netcdf file'
			return

            # get number of cutoff used 
            self._number_of_cutoff_ = self._integrated_.shape[1]

            # get number of spins
            self._number_of_spin_ = self._integrated_.shape[2]

            for multicenter in range(self._number_of_multicenters_): 
                        orbitals = self.GetValue()[multicenter]
			energyresolveddata = self._energyresolved_[multicenter,icut,spin,:]
			integrateddata     = self._integrated_[multicenter,icut,spin]
                        self._multicenterprojections_.append(MultiCenterProjectedDOSTool(spin=spin,maxspin=self._number_of_spin_, 
                                                    cutoff=icut,multicenter=multicenter, 
                                                    energyresolveddata = energyresolveddata,integrateddata=integrateddata, 
                                                    energygrid=self._energygrid_, efermi=self.GetEFermi(), 
                                                    orbitals = orbitals ))

            print 'Found %d multicenters'%len( self._multicenterprojections_ )

            # return list of projections
            return self._multicenterprojections_ 

       def GetEFermi(self):
            """Returns the Fermi level"""
            return self._efermi_

       def SetEFermi(self,efermi):
            """Sets the Fermi level"""
            if efermi is not None:
                       	self._efermi_=efermi

       def ReadEFermi(self,file):
            """Internal method to read the Fermi level from a netcdf file"""
            efermi=NetCDF.Series(name="FermiLevel").ReadFromNetCDFFile(file).GetValue()[-1:][0]
            self.SetEFermi(efermi)




from AtomProjectedDOS import AtomProjectedDOSTool

class MultiCenterProjectedDOSTool(AtomProjectedDOSTool): 
	""" Define new init method to setup the 
	    methods GetData, GetIntegratedData
        """ 
        def __init__(self,spin,maxspin,cutoff,multicenter,energyresolveddata,integrateddata, 
                       energygrid,efermi,orbitals): 


                self._SetEFermi(efermi)

                # make a list  [(energy,dos) ]
                energylist = []
                factor = 1.0
                if maxspin==2:
                       factor = 2.0
                for i in range(len(energygrid)): 
                       energylist.append((energygrid[i]+self.GetEFermi(),factor*energyresolveddata[i]))

                self.SetData(energylist)

                self.SetIntegratedDOS(integrateddata) 

                self._SetDescription(spin,cutoff,multicenter,orbitals) 

                self._SetSpin(spin) 
                self._SetCutoff(cutoff) 

                self._atoms = None
                self._angularchannels = None

        def GetEFermi(self): 
                return self._efermi_

        def _SetEFermi(self,efermi): 
                self._efermi_ = efermi

        # add method describing this multicenter orbital
        def GetDescription(self): 
		print 'Description : ' 
                print 'number : ',self._description.get('number'), ' cutoff : ',self._description.get('cutoff'), ' spin : ',self._description.get('spin')
                print 'list of orbitals : ' 
                for orb  in self._description.get('orbitals'): 
			print '  atom:',int(orb[0]),' l:',int(orb[1]),' m:',int(orb[2]),'weight:',orb[3]

        def _SetDescription(self,spin,cutoff,multicenter,orbitals): 
		self._description = {}
                self._description.update({'spin':spin})
                self._description.update({'cutoff':cutoff})
                self._description.update({'number':multicenter})
                self._description.update({'orbitals':orbitals})

        def _SetSpin(self,spin): 
		self._numberofspin = spin 

        def GetSpin(self): 
		print 'Spin : ',self._numberofspin

        def _SetCutoff(self,cutoff): 
		self._cutoffradius = cutoff 

        def GetCutoff(self,cutoff): 
		print 'Cutoff : ',self._cutoffradius

