# Copyright (C) 2003  CAMP
# Please see the accompanying LICENSE file for further information.


"""Class representing the atomic projected density of states

The atomic projected density of states corresponding to a single
calculation is read using the ReadAtoms method. 

Example:
-----------------------------------------------------------------------------
loa = Dacapo.ReadAtoms('outp1x1.nc')
calc = loa.GetCalculator()

# the sum of all angular channels for atom 1. 
atom1_all = calc.GetLDOS(atoms=[1],angularchannels=['s','p','d']) 
# the sum of all d channels for atom 1. 
atom1_d   = calc.GetLDOS(atoms=[1],angularchannels=['d'])

# make a combined plot
plot1 = atom1_all.GetPlot()
plot2 = atom1_d.GetPlot(parent=plot1) 
plot1.Update()
-----------------------------------------------------------------------------

Parameters for GetDOS method: 
* atoms           - list of atom numbers, counted from 1 (default is all
                    atoms)
* angularchannels - list of angular channel names. The full list of names
                    is:
                     's', 'p_x', 'p_y', 'p_z', 'd_zz', 'dxx-yy', 'd_xy',
                     'd_xz', 'd_yz'
                    'p' and 'd' can be used as shorthand for all p and
                    all d channels, respectively.
                    (default is all d channels)
* spin            - list of spin indices, 0 or 1 (default is all spins)
* cutoffradius    - 'short' or 'infinite'.
                    For cutoffradius = 'short' the integrals are
                    truncated at 1 Angstrom.
                    For cutoffradius = 'infinite' the integrals are
                    not truncated.
                    (default 'infinite')

The resulting DOS is summed over all members of the three lists
(atoms, angularchannels and spin).

GetDOS returns an instance of the class AtomProjectedDOSTool.
Methods for AtomProjectedDOSTool:
* GetPlot()          - returns a GnuplotAvatar for the energy resolved DOS.
                       A parent can be given to combine plots.
* GetIntegratedDOS() - returns the integral up to the Fermi energy
* GetData()          - returns the energy resolved DOS

* GetBandMoment()      - returns the first moment of the projected DOS
                         relative to the Fermi level (center of energy)
* GetBandMoment(1,2)   - returns the first and second moments (m1, m2) of
                         the projected DOS, both taken relative to the
                         Fermi level; the variance about the band center
                         is m2 - m1**2
* SaveData(filename)   - saves the data to the file filename
* GetGracePlot()       - makes a GracePlot (if GracePlot is installed)

"""

__docformat__ = 'reStructuredText'

from Dacapo import NetCDF

import Numeric,operator
from Scientific.Functions.Interpolation import InterpolatingFunction
from Scientific.IO import ArrayIO
try:
	from GracePlot.GracePlot import *
except:
	pass

class AtomProjectedDOS:
    """Atom projected density of states of a single dacapo calculation.

    The raw NetCDF data is loaded with ReadFromNetCDFFile(): the Fermi
    level, the integrated and energy resolved projections, the ordinal map
    (with its 'AngularChannels' attribute) and the energy grid.
    GetDOS() returns an AtomProjectedDOSTool that sums a selection of the
    projections.
    """

    def __init__(self):
        pass

    def ReadFromNetCDFFile(self, file, index=None):
        """Read the Fermi level and the atom projected DOS from a netcdf file.

        ``index`` is accepted for interface compatibility and is not used.
        """
        self.ReadEFermi(file)
        self.ReadAtomProjectedDOS(file)

    def GetEFermi(self):
        """Return the Fermi level."""
        return self._efermi_

    def SetEFermi(self, efermi):
        """Set the Fermi level; a value of None leaves it unchanged."""
        if efermi is not None:
            self._efermi_ = efermi

    def ReadEFermi(self, file):
        """Internal method: read the Fermi level from a netcdf file.

        The last entry of the 'FermiLevel' series is used (presumably the
        value from the final step -- confirm against the dacapo output).
        """
        series = NetCDF.Series(name="FermiLevel").ReadFromNetCDFFile(file)
        self.SetEFermi(series.GetValue()[-1])

    def ReadAtomProjectedDOS(self, file):
        """Internal method: read the atom projected DOS variables.

        Sets _atomprojecteddos_integrated, _atomdosprojections (the length
        of its first axis), _atomprojecteddos_energyresolved,
        _atomprojecteddos_ordinalmap,
        _atomprojecteddos_ordinalmap_angularchannels_names and
        _atomprojecteddos_energygrid.
        """
        def read(name):
            # Helper returning the NetCDF entry object for variable ``name``.
            return NetCDF.Entry(name=name).ReadFromNetCDFFile(file)

        self._atomdosprojections = 0
        self._atomprojecteddos_integrated = read(
            "AtomProjectedDOS_IntegratedDOS").GetValue()
        # number of projections (first axis; the second axis is the spin)
        self._atomdosprojections = self._atomprojecteddos_integrated.shape[0]

        self._atomprojecteddos_energyresolved = read(
            "AtomProjectedDOS_EnergyResolvedDOS").GetValue()

        # The ordinal map is read once: its value maps
        # (atom, cutoff, angular channel) to a projection index and its
        # 'AngularChannels' attribute holds the comma separated channel names.
        ordinalmap = read("AtomProjectedDOS_OrdinalMap")
        self._atomprojecteddos_ordinalmap = ordinalmap.GetValue()
        self._atomprojecteddos_ordinalmap_angularchannels_names = \
            ordinalmap.GetAttributes()['AngularChannels']

        self._atomprojecteddos_energygrid = read(
            "AtomProjectedDOS_EnergyGrid").GetValue()

    def GetDOS(self, **keywords):
        """Return an AtomProjectedDOSTool summing the selected projections.

        Accepted keywords (all optional):

        * 'atoms' : list of atom numbers, counted from 1 (default: all
          atoms).
        * 'angularchannels' : list of angular channel names; 'p' and 'd'
          are shorthands for all p and all d channels (default: all d
          channels).
        * 'cutoffradius' : 'short' or 'infinite'; any value other than
          'short' selects 'infinite' (the default).
        * 'spin' : list of spin indices (default: all spins).

        If an unknown keyword is given, the allowed keywords are printed
        and None is returned.
        """
        allkeywords = ['atoms', 'angularchannels', 'cutoffradius', 'spin']

        for key in keywords:
            if key not in allkeywords:
                print('Allowed keywords :  %s' % (allkeywords,))
                return None

        # Missing keywords default to None; the tool fills in the defaults.
        return AtomProjectedDOSTool(
            atomprojecteddos=self,
            atoms=keywords.get('atoms'),
            angularchannels=keywords.get('angularchannels'),
            cutoffradius=keywords.get('cutoffradius'),
            spin=keywords.get('spin'))

class AtomProjectedDOSTool:
    """Sum of a selection of atom projected DOS channels.

    The projections are taken from an AtomProjectedDOS instance (read from
    the dacapo NetCDF variables) and summed over the requested atoms,
    angular channels and spins.

    * GetPlot()
      returns a Gnuplot avatar of the energy resolved DOS for this set of
      atoms and angular channels
    * GetIntegratedDOS()
      returns the integrated DOS for this set of atoms and angular channels
      (presumably integrated up to the Fermi level by dacapo)
    * GetData()
      returns the energy resolved data as a list of (energy, dos) tuples
    * GetBandMoment(*order)
      returns moments of the projected DOS relative to the Fermi level
    * SaveData(filename) / GetGracePlot()
      write the data to a file / plot it with GracePlot
    """

    # Channels that the 'p' and 'd' shorthand names expand to.
    _P_CHANNELS = ('p_x', 'p_y', 'p_z')
    _D_CHANNELS = ('d_zz', 'dxx-yy', 'd_xy', 'd_xz', 'd_yz')

    def __init__(self, atomprojecteddos, atoms=None, angularchannels=None,
                 cutoffradius=None, spin=None):
        """Sum the selected projections of ``atomprojecteddos``.

        atoms           -- list of atom numbers counted from 1
                           (default: all atoms)
        angularchannels -- sequence of channel names; 'p' and 'd' expand to
                           all p / all d channels (default: all d channels).
                           The caller's sequence is never modified.
        cutoffradius    -- 'short' selects the short-cutoff projections,
                           anything else the 'infinite' ones
        spin            -- list of spin indices (default: all spins)

        On invalid input a message is printed and the constructor returns
        early, leaving the instance without data.
        """
        self.ados = atomprojecteddos

        # test if atom projections are present
        if self.ados._atomdosprojections == 0:
            print('Atom projected DOS is not present in the NetCDF file')
            return

        fermilevel = self.ados.GetEFermi()
        ordinalmap = self.ados._atomprojecteddos_ordinalmap

        # number of atoms and number of spins
        nions = ordinalmap.shape[0]
        numberofspin = self.ados._atomprojecteddos_integrated.shape[1]
        self._numberofspin = numberofspin

        # comma separated channel names, surrounding white space removed
        names = self.ados._atomprojecteddos_ordinalmap_angularchannels_names
        self._angularchannelsnames = [name.strip()
                                      for name in names.strip().split(',')]

        if atoms is None:
            atoms = list(range(1, nions + 1))

        if angularchannels is None:
            angularchannels = list(self._D_CHANNELS)
        angularchannels = self._ExpandAngularChannels(angularchannels)

        channelindices = self._AngularChannelIndices(angularchannels)
        if channelindices is None:
            return

        if spin is None:
            if numberofspin > 1:
                spin = [0, 1]
            else:
                spin = [0]

        self._atoms = atoms
        self._angularchannels = angularchannels

        # second axis of the ordinal map: 0 = infinite, 1 = short cutoff
        if cutoffradius == 'short':
            self._cutoffradius = 'short'
            icut = 1
        else:
            self._cutoffradius = 'infinite'
            icut = 0

        # check atoms
        if not atoms or max(atoms) > nions or min(atoms) < 1:
            print('Atom numbers should be between 1 and %s' % (nions,))
            return

        # check spin; an invalid selection for a spin-polarized calculation
        # is rejected instead of failing later with an index error
        if not spin or min(spin) < 0 or max(spin) > numberofspin - 1:
            if numberofspin == 2:
                print('spin should be 0 or 1')
                return
            print('Calculation not spin-polarized: spin should be 0 ')
            spin = [0]
        self._spin = spin

        integrated = self.ados._atomprojecteddos_integrated
        energyresolved = self.ados._atomprojecteddos_energyresolved
        energygrid = self.ados._atomprojecteddos_energygrid

        energyresolveddata = Numeric.zeros((energygrid.shape[0],),
                                           Numeric.Float)
        integrateddos = 0.0
        for atom in atoms:
            for channel in channelindices:
                # ordinal map: (atom, cutoff, channel) -> projection index
                index = int(ordinalmap[atom - 1, icut, channel])
                for spinnumber in spin:
                    integrateddos = integrateddos + integrated[index, spinnumber]
                    energyresolveddata = (energyresolveddata
                                          + energyresolved[index, spinnumber])

        self.SetIntegratedDOS(integrateddos)

        # NOTE(review): the energy resolved DOS is doubled for spin-polarized
        # calculations while the integrated DOS is not; kept as in the
        # original code -- confirm the intended normalisation.
        if numberofspin == 2:
            factor = 2.0
        else:
            factor = 1.0

        # list of (energy, dos) with energies shifted by the Fermi level
        self.SetData([(energy + fermilevel, factor * dos)
                      for energy, dos in zip(energygrid, energyresolveddata)])

    def _ExpandAngularChannels(self, angularchannels):
        """Return a new list with the 'p' and 'd' shorthands expanded.

        The first occurrence of each shorthand is removed and its channels
        are appended at the end; the argument itself is left untouched.
        """
        channels = list(angularchannels)
        for shorthand, expansion in (('p', self._P_CHANNELS),
                                     ('d', self._D_CHANNELS)):
            if shorthand in channels:
                channels.remove(shorthand)
                channels.extend(expansion)
        return channels

    def _AngularChannelIndices(self, angularchannels):
        """Map channel names to their index in the ordinal map.

        Returns None (after printing the accepted names) if a name is
        unknown.
        """
        indices = []
        for name in angularchannels:
            try:
                indices.append(self._angularchannelsnames.index(name))
            except ValueError:
                print('Angular Channel could not be found:  %s' % (name,))
                print('Accepted values :  %s' % (self._angularchannelsnames,))
                print('Plus d (all d angular channels) and p (all p angular channels) ')
                return None
        return indices

    def GetEFermi(self):
        """Return the Fermi level of the underlying calculation."""
        return self.ados.GetEFermi()

    def GetIntegratedDOS(self):
        """Return the integrated DOS."""
        return self._integrateddos

    def SetIntegratedDOS(self, integrateddos):
        """Set the integrated DOS."""
        self._integrateddos = integrateddos

    def GetData(self):
        """Return the energy resolved DOS as a list of (energy, dos)."""
        return self._energyresolveddata

    def SetData(self, energyresolveddata):
        """Set the energy resolved DOS list."""
        self._energyresolveddata = energyresolveddata

    def GetPlot(self, parent=None):
        """Return a GnuplotAvatar plotting the energy resolved DOS.

        ``parent`` is passed on to GnuplotAvatar (used to combine plots).
        """
        from Visualization.Avatars.GnuplotAvatars import GnuplotAvatar
        import Gnuplot
        title = ('Atoms: ' + str(self._atoms)
                 + '  Angular channels: ' + str(self._angularchannels))
        # the spin selection is only informative for spin-polarized runs
        if self._numberofspin > 1:
            title = title + '   Spin: ' + str(self._spin)

        if self._cutoffradius == 'short':
            title = title + '   Cutoffradius: ' + str(self._cutoffradius)

        # 'with' is a reserved word since Python 2.6 and cannot be written
        # as a literal keyword argument; pass the option through a dict.
        plotitem = Gnuplot.Data(self.GetData(), title=title,
                                **{'with': 'lines'})
        return GnuplotAvatar(parent=parent, plotitem=plotitem)

    # The methods below (_Integrate, _Moment, _GetData, GetBandMoment,
    # SaveData, GetGracePlot) were added by John Kitchin
    # <jkitchin@andrew.cmu.edu>, together with the Numeric / Scientific
    # imports and the optional GracePlot import at the top of the module.
    # GetBandMoment() is the main new function and calculates band
    # centers, mean square widths, etc.  SaveData() makes it convenient to
    # get the data into other programs, and GetGracePlot() makes a
    # GracePlot.

    def _Integrate(self, xdata, ydata):
        """Internal: integrate ydata over xdata (trapezoid rule).

        A linearly interpolating function is built through the points and
        its definite integral is returned.

        John Kitchin <jkitchin@andrew.cmu.edu>
        """
        f = InterpolatingFunction((Numeric.array(xdata),),
                                  Numeric.array(ydata))
        return f.definiteIntegral()

    def _Moment(self, xdata, ydata, nth=1):
        """Internal: return the nth moment of the (xdata, ydata) distribution.

        For nth == 0 the integral of ydata is returned; otherwise
        integral(x**nth * y) / integral(y).  By default the first moment
        (center of energy) is computed.

        John Kitchin <jkitchin@andrew.cmu.edu>
        """
        xdata = Numeric.array(xdata)
        ydata = Numeric.array(ydata)
        normalization = self._Integrate(xdata, ydata)
        if nth == 0:
            return normalization
        return self._Integrate(xdata, (xdata ** nth) * ydata) / normalization

    def _GetData(self, relative=None, scaled=None):
        """Internal: return the DOS as two arrays (energy, density).

        If relative is not None the energies are made relative to the Fermi
        level.  If scaled is not None the density is divided by its
        integral, so it integrates to 1.  The underscore distinguishes this
        from GetData(), which returns a list of (energy, dos) tuples.

        John Kitchin <jkitchin@andrew.cmu.edu>
        """
        columns = Numeric.transpose(Numeric.array(self.GetData()))
        energy = columns[0]
        density = columns[1]

        if relative is not None:
            energy = energy - self.GetEFermi()
        if scaled is not None:
            density = density / self._Integrate(energy, density)

        return energy, density

    def GetBandMoment(self, *order):
        """Return moments of the projected DOS relative to the Fermi level.

        GetBandMoment()      # returns the first moment
        GetBandMoment(2)     # returns the second moment
        GetBandMoment(1, 2)  # returns a list [first, second]

        John Kitchin <jkitchin@andrew.cmu.edu>
        """
        energy, density = self._GetData(relative=1)

        if len(order) == 0:
            return self._Moment(energy, density, 1)
        if len(order) == 1:
            return self._Moment(energy, density, order[0])
        return [self._Moment(energy, density, nth) for nth in order]

    def SaveData(self, filename):
        """Save the DOS to ``filename`` in two space delimited columns.

        The first column is the energy, the second the density.

        John Kitchin <jkitchin@andrew.cmu.edu>
        """
        ArrayIO.writeArray(Numeric.array(self.GetData()), filename)

    def GetGracePlot(self, relative=0, title=None):
        """Plot the DOS with GracePlot and return the GracePlot object.

        With relative == 0 (the default) the energies are plotted as is;
        any other value plots them relative to the Fermi level.  The
        returned object can be manipulated further, e.g. to change the
        plot features.  Returns None if GracePlot cannot be created.

        John Kitchin <jkitchin@andrew.cmu.edu>
        """
        try:
            p = GracePlot()
        except Exception:
            # GracePlot is only defined if the optional module-level import
            # succeeded; otherwise this raises NameError.
            print("GracePlot does not seem to be installed.")
            return None

        if relative != 0:
            efermi = self.GetEFermi()
        else:
            efermi = 0

        energy, weight = self._GetData()
        p.plot(Data(energy - efermi, weight))
        if title is not None:
            p.title(title)

        return p

