""" The Dacapo.Tools.Assistants is a module containing 
    various functions and classes that can facilitate setting
    up Dacapo calculations and treating the results of the
    calculations.

    The PreTreatment class sets up various filenames automatically,
    so that these do not have to be set in the Dacapo run-script.

    The PostTreatment class performs tasks that are often needed
    when analysing a Dacapo calculation: it writes a text summary
    of the results and plots the electrostatic potential and the
    atom-projected density of states.

    Examples of use:

    pretreatment = PreTreatment("nametoken")
    .
    .
    .
    .......

    posttreatment = PostTreatment(calculator=calc,name="nametoken")
    posttreatment()

    runs a post treatment on the Dacapo calculator calc. The resulting
    files are named using the name "nametoken" (e.g. nametoken_out.txt).
"""

import os
import Numeric as num
from ASE.Visualization.gnuplot import gnuplot
from ASE.Utilities.List import List

def DropInTextFile(textfile,text,mode="a"):
   """Write text followed by a newline to textfile.

   mode is passed straight to open(): the default "a" appends, "w"
   truncates the file first.  The line is built before the file is
   opened, so a non-string text raises TypeError without touching the
   file, and the file is closed even if the write itself fails.

   Note: this module defines DropInTextFile twice with the same
   signature; the later definition is the one bound at import time.
   """
   line = text+"\n"
   f = open(textfile,mode)
   try:
      f.write(line)
   finally:
      f.close()

class PreTreatment:
   """Derive the standard set of Dacapo file names from one name token.

   After construction the instance carries these attributes (shown for
   the token "tok"):
     inncfile       -> "in_tok_tmp.nc"
     outncfile      -> "tok_tmp.nc"
     outtxtfile     -> "tok_tmp.txt"
     trajectoryfile -> "tok_traj.nc"
     qnlogfile      -> "tok_qn.log"
     hessianfile    -> "tok_hessian.pickle"
   """
   def __init__(self,nametoken):
      # (attribute name, text before the token, text after the token)
      naming = (
         ("inncfile",       "in_", "_tmp.nc"),
         ("outncfile",      "",    "_tmp.nc"),
         ("outtxtfile",     "",    "_tmp.txt"),
         ("trajectoryfile", "",    "_traj.nc"),
         ("qnlogfile",      "",    "_qn.log"),
         ("hessianfile",    "",    "_hessian.pickle"),
      )
      for attribute, head, tail in naming:
         setattr(self, attribute, head + nametoken + tail)

class PostTreatment:
   """Analyse a finished Dacapo calculation and write the results to files.

   Calling the instance writes, for the name token <name>:
     <name>_out.txt                    text summary: eigenvalues, occupations,
                                       forces, stress, positions, potential
                                       energy and projected-DOS moments
     <name>_electrostaticpotential.ps  x-y averaged electrostatic potential
     <name>_dos.ps                     atom-projected DOS (only written when
                                       at least one of ads/surf/sub is
                                       non-empty)
   Each PostScript file is also handed to the external pstopnm tool.
   """
   def __init__(self,calculator=None,name="unnamed",ads=None,surf=None,sub=None,isspinpolarized=False):
      """calculator      : Dacapo calculator to analyse (may be set later).
      name            : token used to build all output file names.
      ads, surf, sub  : atom index lists for adsorbate, surface and
                        substrate, passed to calc.GetLDOS(atoms=...).
                        The original author notes these are "indices +1
                        (for some reason)" -- confirm the convention
                        GetLDOS expects.  Default: empty list.
      isspinpolarized : if true, both spin channels are used for the DOS.
      """
      # None defaults give each instance its own list instead of one list
      # object shared by every instance created without these arguments.
      if ads is None:
         ads = []
      if surf is None:
         surf = []
      if sub is None:
         sub = []
      self.SetCalculator(calculator) # Dacapo calculator
      self.SetName(name)             # string
      self.adsorbate = ads
      self.surface = surf
      self.substrate = sub
      self.isspinpolarized=isspinpolarized # True/False

   def SetCalculator(self,calculator):
      """Set the Dacapo calculator to analyse."""
      self._calculator = calculator

   def SetName(self,nametoken):
      """Set the token used as prefix for all output file names."""
      self._name = nametoken

   def __call__(self):
      """Run the complete post treatment.

      Raises RuntimeError if no calculator has been set.
      """
      if self._calculator is None:
         raise RuntimeError("PostTreatment calculator not set.")

      calc = self._calculator
      nametoken = self._name
      atoms = calc.GetListOfAtoms()
      runtextfile = nametoken+"_out.txt"
      self._WriteGeneral(calc,atoms,nametoken,runtextfile)
      self._PlotElectrostatics(calc,atoms,nametoken,runtextfile)
      self._PlotDOS(calc,nametoken,runtextfile)

   def _WriteGeneral(self,calc,atoms,nametoken,runtextfile):
      """Start runtextfile and write band, force, stress, position and energy data."""
      # Pick up the pieces :
      eigs = calc.GetNetCDFEntry('EigenValues')[-1]
      occupationnumbers = calc.GetNetCDFEntry('OccupationNumbers')[-1]
      # Both arrays are indexed [kpt,spin,band] below.
      Nkpts,Nspins,Nbands = eigs.shape
      fermilevel = calc.GetFermiLevel()
      stress = calc.GetStress()
      potentialenergy = calc.GetPotentialEnergy()
      forces = calc.GetCartesianForces()
      positions = atoms.GetCartesianPositions()
      # And drop them again; mode="w" starts a fresh file for this run.
      DropInTextFile(runtextfile,72*"="+"\n"+nametoken+" run, beware - Mighty Dacapo speaks :\n"+72*"="+"\n",mode="w")
      # Largest occupation of the highest band over all k-points and spins.
      # A non-negligible value means the band count is too small.
      maxtopbandoccupation = max(num.ravel(occupationnumbers[:,:,-1]))
      if maxtopbandoccupation > 1e-6:
         DropInTextFile(runtextfile,"WARNING: Top band is occupied, occ.="+str(maxtopbandoccupation)+", increase Nbands\n")
      DropInTextFile(runtextfile,"Fermi level: "+str(fermilevel)+"\n")
      DropInTextFile(runtextfile,"  Band: Spin:  Kpt:   Eigenvalue:  Rel.Eigenvalue:  Occupation:")
      for band in range(Nbands):
         for spin in range(Nspins):
            for kpt in range(Nkpts):
               dropstring = "%5d %5d %5d " %(band,spin,kpt)
               dropstring += "%15.9f %15.10f %13.10f" %(eigs[kpt,spin,band],eigs[kpt,spin,band]-fermilevel,occupationnumbers[kpt,spin,band])
               DropInTextFile(runtextfile,dropstring)
      DropInTextFile(runtextfile,"\nForces:\n"+str(forces)+"\n")
      DropInTextFile(runtextfile,"Stress:\n"+str(stress)+"\n")
      DropInTextFile(runtextfile,"Positions:\n"+str(positions)+"\n")
      DropInTextFile(runtextfile,"Potential Energy: "+str(potentialenergy)+"\n")

   def _PlotElectrostatics(self,calc,atoms,nametoken,runtextfile):
      """Plot the x-y averaged electrostatic potential along z."""
      electrostaticplotfile = nametoken+"_electrostaticpotential.ps"
      electro = calc.GetElectrostaticPotential()
      # Placeholder: the work function is not computed yet.
      DropInTextFile(runtextfile,"Work Function:?\n")
      nx,ny,nz = num.shape(electro)
      unitcell = atoms.GetUnitCell()
      # z positions of the grid planes; only the zz cell component is used,
      # which assumes the third cell vector lies along z -- confirm for
      # non-orthogonal cells.
      xaxis = [float(z)/float(nz)*unitcell[2,2] for z in range(nz)]
      # Sum over x (axis 0), then over y (new axis 0): one value per z plane.
      # The axis is explicit so this does not depend on the library's default.
      xyaverage = num.sum(num.sum(electro,0),0)/(nx*ny)
      potential = List(zip(xaxis,xyaverage))
      potential.legend = '(x-y) Averaged Electrostatic potential'
      potential.ylabel = 'Electrostatic potential (eV)'
      potential.xlabel = 'Distance along z-axis (Angstrom)'
      potential.title = "Electrostatic Potential for "+nametoken
      plot = gnuplot(potential)
      plot.Update()
      plot.hardcopy(electrostaticplotfile,enhanced=1,color=1)
      plot.Update()
      self._ConvertToPnm(electrostaticplotfile)

   def _PlotDOS(self,calc,nametoken,runtextfile):
      """Plot and summarize atom-projected DOS for the configured atom groups."""
      # angularchannels: ['s','p','d','p_x','p_y','p_z','d_zz','dxx-yy','d_xy','d_xz','d_yz']
      # spins: [0,1]. Spin must be 0 if calculation not spin polarized.
      # cutoffradius : ['short','long']. 'short' gives 1 A cutoff. Otherwise infinite
      dosplotfile = nametoken+"_dos.ps"
      DropInTextFile(runtextfile,"Atom projected DOS:\n")
      allatoms = self.adsorbate + self.surface + self.substrate
      if len(allatoms)==0:
         DropInTextFile(runtextfile,"No atoms specified for projected density of states.\n")
         return
      plot = DOSAnalysis(calc,"All",allatoms,['s','p','d'],runtextfile,None,self.isspinpolarized)
      plot.xlabel('E-Efermi (eV)')
      plot.ylabel('Density of states')
      # The remaining curves are drawn into the same plot (passed as parent).
      if len(self.adsorbate) > 0:
         DOSAnalysis(calc,"Adsorbate",self.adsorbate,['s','p','d'],runtextfile,plot,self.isspinpolarized)
      if len(self.surface) > 0:
         DOSAnalysis(calc,"Surface, All states",self.surface,['s','p','d'],runtextfile,plot,self.isspinpolarized)
         DOSAnalysis(calc,"Surface, d-states",self.surface,['d'],runtextfile,plot,self.isspinpolarized)
      if len(self.substrate) > 0:
         DOSAnalysis(calc,"Substrate, All states",self.substrate,['s','p','d'],runtextfile,plot,self.isspinpolarized)
         DOSAnalysis(calc,"Substrate, d-states",self.substrate,['d'],runtextfile,plot,self.isspinpolarized)
      plot.Update()
      plot.hardcopy(dosplotfile,enhanced=1,color=1)
      plot.Update()
      self._ConvertToPnm(dosplotfile)

   def _ConvertToPnm(self,psfile):
      """Best-effort conversion of psfile with the external pstopnm tool.

      os.system reports a failing or missing command through its return
      status rather than by raising, so a conversion failure does not
      abort the post treatment.
      """
      os.system("pstopnm "+psfile)

def DOSAnalysis(calc,text,atomindices,channels,runtextfile,parent,isspinpolarized):
   """Compute, plot and summarize one atom-projected density of states.

   calc            : Dacapo calculator providing GetLDOS.
   text            : label used as the plot legend and as the heading
                     written to runtextfile.
   atomindices     : atoms passed to calc.GetLDOS(atoms=...).
   channels        : angular channels passed to GetLDOS, e.g. ['s','p','d'].
   runtextfile     : text file to which the heading, band moments 0-6 and
                     the integrated DOS are appended.
   parent          : existing plot to add this curve to, or None.
   isspinpolarized : if true, spins [0,1] are included, otherwise only [0].

   Returns the plot object produced by ldos.GetPlot.
   """
   # The spin selection is the only thing that depends on spin polarization.
   if isspinpolarized:
      spins = [0,1]
   else:
      spins = [0]
   ldos = calc.GetLDOS(atoms=atomindices,angularchannels=channels,cutoffradius='short',spin=spins)
   plot = ldos.GetPlot(parent=parent,legend=text,relative="NotNone")
   moments = ldos.GetBandMoment(0,1,2,3,4,5,6)
   idos = ldos.GetIntegratedDOS()
   DropInTextFile(runtextfile,text)
   DropInTextFile(runtextfile,"Moments: %6.3f %7.3f %8.3f %9.3f %9.3f %10.3f %11.3f" %tuple(moments))
   DropInTextFile(runtextfile,"Integrated DOS: "+str(idos)+"\n")
   return plot


def DropInTextFile(textfile,text,mode="a"):
   """Write text followed by a newline to textfile.

   mode is passed straight to open(): the default "a" appends, "w"
   truncates the file first.  The line is built before the file is
   opened, so a non-string text raises TypeError without touching the
   file, and the file is closed even if the write itself fails.

   Note: this is the second definition of DropInTextFile in this module;
   being the later one, it is the definition bound at import time.
   """
   line = text+"\n"
   f = open(textfile,mode)
   try:
      f.write(line)
   finally:
      f.close()

