#!/usr/bin/env python


# Localized Wannier orbitals for ferromagnetic bcc iron,
# and the band structure computed from them.

from Dacapo import Dacapo
from ASE import Atom,ListOfAtoms
from ASE.Trajectories import NetCDFTrajectory
import Numeric as num
import os

# Run the self-consistent Dacapo calculation only if its output file does not
# exist yet; the Wannier analysis below reads everything from 'fe-bcc.nc'.
if not os.path.isfile('fe-bcc.nc'):

    # bcc Fe: a single atom in the primitive cell, started with a magnetic
    # moment so the spin-polarized calculation can converge to the
    # ferromagnetic state.
    lat = 2.87  # cubic lattice constant (presumably in Angstrom)
    atoms = ListOfAtoms([Atom('Fe', [0, 0, 0], magmom=2.32)])

    # Primitive bcc lattice vectors: (a/2)(-1,1,1), (a/2)(1,-1,1), (a/2)(1,1,-1).
    cell = num.array([[-1/2., 1/2., 1/2.],
                      [1/2., -1/2., 1/2.],
                      [1/2., 1/2., -1/2.]]) * lat
    atoms.SetUnitCell(cell)

    # Spin-polarized PBE calculation; results are written to 'fe-bcc.nc'.
    calc = Dacapo(planewavecutoff=400, nbands=16, kpts=(5, 5, 5), xc='PBE',
                  out='fe-bcc.nc', spinpol=True)
    atoms.SetCalculator(calc)

    # Displace the k-points slightly, so that the symmetry program in Dacapo
    # does not use inversion symmetry to remove k-points.
    kpoints = calc.GetBZKPoints()
    kpoints[:, 0] += 2e-5
    calc.SetBZKPoints(kpoints)

    # Requesting the energy is what actually runs the calculation
    # (presumably lazily, as ASE calculators do) and produces 'fe-bcc.nc'.
    tot = atoms.GetPotentialEnergy()

# Begin the Wannier part
from Dacapo import Dacapo
from ASE.Utilities.Wannier.Wannier import Wannier
import Numeric as num

atoms = Dacapo.ReadAtoms('fe-bcc.nc')
calc = atoms.GetCalculator()

# Start guess: the 5 d-orbitals (l=2, m=-2,..,2) centered on atom 0 with a
# radius of 0.4 A. One extra Wannier function is requested on top of these;
# a random start guess will be used for that last (6th) WF.
initialwannier = [[[0], 2, 0.4]]
numberofwannier = 5 + 1

# Include all eigenstates below the Fermi level.
occenergy = 0.0
# Construct Wannier functions for spin down.
spin = 1

wannier = Wannier(numberofwannier=numberofwannier, calculator=calc,
                  occupationenergy=occenergy,
                  initialwannier=initialwannier,
                  spin=spin)

# Store the localization matrix. It can be read again in a later run by:
# wannier.ReadZIBlochMatrix('fe_bloch.pickle')
wannier.SaveZIBlochMatrix('fe_bloch.pickle')

# Perform the localization.
wannier.Localize(tolerance=1.0e-7)

# Store the WFs. They can be read again in a later run by:
# wannier.ReadRotation('fe_wannier')
wannier.SaveRotation('fe_wannier')

# Translate all WFs to the unit cell (2,2,2).
wannier.TranslateAllWannierFunctionsToCell((2, 2, 2))

# Print the centers and radii of the WFs, one line per WF.
# The formatted single-argument print gives the same output as the old
# Python 2 "print n, x" statement, and is also valid Python 3.
centers = wannier.GetCenters()
for n, center in enumerate(centers):
    print("%d %s" % (n, center))

# Save the centers as atoms, so they can be plotted.
centers = wannier.GetCentersAsAtoms()
traj1 = NetCDFTrajectory('centers-febcc-' + str(spin) + '.traj', centers)
try:
    traj1.Update()
finally:
    # Close the file even if writing the frame fails.
    traj1.Close()

# Store a '.cube' file for each Wannier function; the count is tied to
# numberofwannier, so it can no longer drift from a hard-coded number.
for n in range(numberofwannier):
    wannier.WriteCube(n, 'fe_' + str(n) + '_spin_' + str(spin) + '.cube')


# Band structure from the Wannier functions' hopping parameters.
fermilevel = calc.GetFermilevel()

# Special points of the bcc Brillouin zone. With the primitive cell built for
# the SCF run, these coordinates are the standard G, H, P, N points expressed
# as multiples of the reciprocal lattice vectors. This assumes
# HoppingParameters expects scaled k-points (TODO confirm).
G = (0, 0, 0)
P = (0.25, 0.25, 0.25)
N = (0.0, 0.0, 0.5)
H = (-0.5, 0.5, 0.5)

from ASE.Utilities.Wannier import HoppingParameters

npoints = 80   # points per path segment (presumably k-points along the line)
cutoff = 12.0  # cutoff handed to HoppingParameters; also used in file names

hop = HoppingParameters.HoppingParameters(wannier, cutoff)

# Walk the path G-H-P-N-G-P and write one NetCDF file per segment. Each file
# is named after the segment's end points plus the cutoff, e.g. 'GH12.0.nc'.
# Deriving both the label and the points from the same path string keeps
# them from disagreeing. offset=fermilevel is presumably used to reference
# the energies to the Fermi level.
specialpoints = {'G': G, 'P': P, 'N': N, 'H': H}
path = 'GHPNGP'
for start, end in zip(path[:-1], path[1:]):
    hop.WriteBandDiagramToNetCDFFile(start + end + str(cutoff) + '.nc',
                                     npoints,
                                     specialpoints[start],
                                     specialpoints[end],
                                     offset=fermilevel)
        

