#!/usr/bin/env python

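""" Restart a NEB calculation.

Script to restart a nudged elastic band (NEB) calculation.
Assumes the M images of the band are stored as initial.nc (first
end point), out.<n>.nc for the interior images n = 1, ..., M-2,
and final.nc (last end point).
"""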

from math import sqrt

import Numeric as num

from Dacapo import Dacapo
from ASE.Filters.Subset import Subset
from ASE.Filters.NudgedElasticBand import NudgedElasticBand
from ASE.Dynamics.QuasiNewton import QuasiNewton
from ASE.Trajectories.NetCDFTrajectory import NetCDFTrajectory
from ASE import Atom,ListOfAtoms

# Total number of images in the band, counting both end points.
M = 4

# Image files in band order: the initial end point, the M-2 interior
# images written by the previous run, and the final end point.
filenames = (['initial.nc']
             + ['out.%d.nc' % n for n in range(1, M - 1)]
             + ['final.nc'])

# Read every image (end points included) back from its NetCDF file.
# Each returned atoms object is queried for its calculator below.
# The loop variable is deliberately not called `file`: under Python 2 a
# list-comprehension variable leaks into module scope, which would leave
# the `file` builtin rebound to a string for the rest of the script.
atomslist = [Dacapo.ReadAtoms(fname, save_memory=True) for fname in filenames]

# Point each image's calculator at new out-restart.<n>.* output files.
# Dacapo does not remember StayAliveOff or the output file names when an
# atoms object is read back, so they must be set again here.
for n, atoms in enumerate(atomslist):
    calc = atoms.GetCalculator()
    # calc.StayAliveOff()  # uncomment to re-apply StayAliveOff on restart
    calc.SetTxtFile('out-restart.' + str(n) + '.txt')
    calc.SetNetCDFFile('out-restart.' + str(n) + '.nc')


# Select the atoms tagged 1 in the first image; only these are exposed to
# the band through the Subset filter applied to every image.
mask = [a.GetTag() == 1 for a in atomslist[0]]
band = NudgedElasticBand([Subset(atoms, mask=mask) for atoms in atomslist])

# Create a quasi-Newton optimizer for the band (converge to fmax=0.02,
# log progress to restart.log).
relax = QuasiNewton(band, fmax=0.02, logfilename='restart.log')
# relax.ReadHessian('hessian.pickle')  # optionally reuse a saved Hessian

# Create one trajectory per image. Each trajectory wraps the full atoms
# object rather than the subset. It records the starting configuration and
# is attached to the optimizer, presumably so it is updated as the
# optimization proceeds.
listoftraj = []
for n, atoms in enumerate(atomslist):
    path = NetCDFTrajectory('image-restart' + str(n) + '.nc', atoms)
    listoftraj.append(path)
    path.Update()
    relax.Attach(path)


relax.Converge()

# Collect the last frame of every image trajectory and write the converged
# band, in image order, as successive frames of neb-path-restart.nc.
atomslist = [traj.GetListOfAtoms(-1) for traj in listoftraj]

atom0 = atomslist[0]
newtraj = NetCDFTrajectory('neb-path-restart.nc', atom0)
newtraj.Update()  # first frame: the first image as read back

# Reuse the first image's atoms object as the trajectory template. For each
# remaining image, copy in its positions and calculator, then write a frame.
for image in atomslist[1:]:
    atom0.SetCartesianPositions(image.GetCartesianPositions())
    atom0.SetCalculator(image.GetCalculator())
    newtraj.Update()
