#!/usr/bin/env python

from math import sqrt

import Numeric as num

from Dacapo import Dacapo
from ASE.Filters.Subset import Subset
from ASE.Filters.NudgedElasticBand import NudgedElasticBand
from ASE.Dynamics.QuasiNewton import QuasiNewton
from ASE.Trajectories.NetCDFTrajectory import NetCDFTrajectory
from ASE import Atom,ListOfAtoms

# Static Al slab: four Al atoms in a layer at z = 0 and four more in a layer
# at z = 0.125*2.  The coordinates look fractional with respect to the cell
# set below -- NOTE(review): this presumes SetUnitCell rescales the atoms into
# the new cell (ASE2 ListOfAtoms default); confirm.
bottom_layer = [(0, 0, 0),
                (0, 0.5, 0),
                (0.5, 0, 0),
                (0.5, 0.5, 0)]
upper_layer = [(0.3333333, 0.166666667, 0.125*2.),
               (0.3333333, 0.666666667, 0.125*2.),
               (0.8333333, 0.166666667, 0.125*2.),
               (0.8333333, 0.666666667, 0.125*2.)]
alslab = ListOfAtoms([Atom('Al', position)
                      for position in bottom_layer + upper_layer])

# The two end points share the slab and differ only in where the tagged
# (tag=1) O adatom sits on the z < 0 side of the slab.
initial = alslab.Copy()
initial.append(Atom('O', (0.166586, 0.33341, -0.043448*2.), tag=1))

final = alslab.Copy()
final.append(Atom('O', (0.33331, 0.166677, -0.04417119*2.), tag=1))

# In-plane vectors of equal length at 120 degrees (hexagonal); the third
# vector is half of 18.7061487217439 along z.
unitcell = [[5.72756492761103, 0, 0],
            [-2.86378246380552, 4.96021672913593, 0],
            [0, 0, 18.7061487217439/2.]]

initial.SetUnitCell(unitcell)
final.SetUnitCell(unitcell)

mask=[a.GetTag()==1 for a in initial]

calc = Dacapo(planewavecutoff = 340, 
              densitycutoff = 500,
              nbands = 22) 

calc.StayAliveOff()
calc.SetNetCDFFile('initial.nc') 
initial.SetCalculator(calc) 
initialenergy = initial.GetPotentialEnergy() 
# Create a qn object:
subset = Subset(initial,mask) 
relax = QuasiNewton(subset,fmax=0.05)
relax.Converge(maxsteps=1)
initialenergy1 = initial.GetPotentialEnergy() 

calc.SetNetCDFFile('final.nc') 
final.SetCalculator(calc) 
finalenergy = final.GetPotentialEnergy() 
# Create a qn object:
subset = Subset(final,mask) 
relax = QuasiNewton(subset,fmax=0.05)
relax.Converge(maxsteps=1)
finalenergy1 = final.GetPotentialEnergy() 

print 'initial state energy             : ',initialenergy
print 'initial state energy (minimized) : ',initialenergy1
print 'final state energy               : ',finalenergy
print 'final state energy (minimized)   : ',finalenergy1

# NEB images, in path order: the relaxed initial state, two more copies of it
# (presumably repositioned later by band.SetInterpolatedPositions()), and the
# relaxed final state.  Each image is also wrapped in a Subset restricted to
# the masked (tag == 1) atoms.
atomslist = []
configs = []
for ncfile in ['initial.nc'] * 3 + ['final.nc']:
    atoms = Dacapo.ReadAtoms(ncfile)
    atomslist.append(atoms)
    configs.append(Subset(atoms, mask=mask))

# Give every image its own text log and NetCDF file: out.<n>.txt / out.<n>.nc.
# Images 0-2 were all read from initial.nc, so this presumably keeps their
# calculator output separate.
for n, image in enumerate(atomslist):
    calc = image.GetCalculator()
    calc.StayAliveOff()
    calc.SetTxtFile('out.%d.txt' % n)
    calc.SetNetCDFFile('out.%d.nc' % n)

band = NudgedElasticBand(configs)
band.SetInterpolatedPositions()

# Create a qn object:
relax = QuasiNewton(band,fmax=0.05,forcemin=False)

# Create a trajectory for the each image
listoftraj = []
for n in range(len(atomslist)): 
    path = NetCDFTrajectory('image'+str(n)+'.nc', atomslist[n])
    print 'n= ',n,atomslist[n].GetPotentialEnergy()
    listoftraj.append(path) 
    path.Update()
    relax.Attach(path) 


relax.Converge()

# Gather the last frame of every image trajectory into neb-path.nc, giving one
# frame per image along the path.
atomslist = [traj.GetListOfAtoms(-1) for traj in listoftraj]

# The first image's ListOfAtoms serves as the carrier for the combined file.
# It is written once, then its positions and calculator are overwritten with
# each following image before every further write.
atom0 = atomslist[0]
newtraj = NetCDFTrajectory('neb-path.nc',atom0)
newtraj.Update()

for image in atomslist[1:]:
    atom0.SetCartesianPositions(image.GetCartesianPositions())
    atom0.SetCalculator(image.GetCalculator())
    newtraj.Update()
