# Copyright (C) 2003  CAMP
# Please see the accompanying LICENSE file for further information.

# BUGS:
# an empty series raises an error!

"""Classes for IO with dacapo netcdf files.

  These classes can be used directly to communicate with a dacapo netcdf
  file, or used to inherit these methods to simple dacapo entities.

Usage:

  To create a simple entry in a netcdf say "PrintSpatial" attach the following
  attribute to your simulations (in this case 'mysim' ) object:

  '>>> mysim.ps=NetCDF.Entry(name="PrintSpatial")'

  note that the name 'ps' of the attribute is immaterial (see also Simulation).
  If you want to add attributes to the netcdf entry simply add these
  to the simulations attribute:

  '>>> mysim.ps.Quantity="TotalDensity"'

  If you want the entry to have a value you can add this using the SetValue
  method or simply do it at initialization time:

  '>>> mysim.kp=NetCDF.Entry(name="KpointSetup",value=[3,3,3])'

  If the entry needs special names for its dimensions there is an optional
  dimensionnames parameter:

  '>>> mysim.daa=NetCDF.Entry(name="DynamicAtomAttributes", value=["fix:123"],dimensionnames=["number_of_dynamic_atoms"])'

  In the special case of the unlimited dimension the special NetCDF.Series
  class is offered. It works just like NetCDF.Entry except the first dimension
  is always the unlimited one. Example:

  '>>> mysim.tot=NetCDF.Series("TotalEnergy")'

  The final class is NetCDF.File. This class actually represents all the
  elements in a given netcdf file:

  '>>> oldsim=NetCDF.File("cu.nc")'

  after the initialisation it contains attributes for all the entries in
  the file. Note that the actual values are not read before the
  ReadFromNetCDFFile method is called.

  When all Entries are attached to the simulation object, the latter is capable
  of making them write and read themselves to a netcdf file (see Simulation).

**Trouble shooting**

  If two or more entries try to write to the same variable in the file
  an error is raised. This is to avoid unintended overwriting of values
  due to "forgotten" attributes. This is especially prone to happen if
  a 'NetCDF.File' is present as an attribute of a simulation. The solution
  is usually to 'del' the attributes (of the simulation or 'NetCDF.File' )
  which do not represent the wanted data. Say the simulation has an attached
  'ListOfAtoms' in addition to a 'NetCDF.File' with a 'UnitCell' attribute.
  Both will try to write the unit cell, resulting in an error. Deleting the
  UnitCell attribute:

  '>>> del mysim.oldsim.UnitCell'

  will cure the problem (of course there might be other conflicts).
  
**When writing dacapo entities**

  In order for this to work two important methods should always be available:

  'WriteToNetCDFFile(file =' *netcdffile* ') ' -- Method to write important data to a netcdffile.

  'ReadFromNetCDFFile( ' *file* = *netcdffile* ') ' -- Method to read important data from a netcdffile.

  These two are present in all the classes in this module. When
  writing new dacapo entities these should either be inherited (as in
  'PlanewaveCutoff' ) or be implemented directly (as in 'ListOfAtoms' ).
  
"""

__docformat__ = 'reStructuredText'

from Numeric import asarray,concatenate
import string
from UserList import UserList

class NameConflict(IOError):
    """Raised by Entry.WriteToNetCDFFile when a variable name is written twice.

    If two or more entries try to write to the same variable in the file
    this error is raised, to avoid unintended overwriting of values due to
    "forgotten" attributes. This is especially prone to happen if a
    'NetCDF.File' is present as an attribute of a simulation. The solution
    is usually to 'del' the attributes (of the simulation or of the
    'NetCDF.File') which do not represent the wanted data. Say the
    simulation has an attached 'ListOfAtoms' in addition to a 'NetCDF.File'
    with a 'UnitCell' attribute. Both will try to write the unit cell,
    resulting in this error. Deleting the UnitCell attribute:

    '>>> del mysim.oldsim.UnitCell'

    will cure the problem (of course there might be other conflicts).
    """

class Dimension:
    """A named dimension of a netcdf file.

    Mostly useful for creating dimensions that have no variable associated
    with them ("dummy" dimensions). The name is kept in _name_ and the size
    in _size_; the size is handed to createDimension unchanged (it defaults
    to None).
    """

    def __init__(self, name, size=None):
        self.SetName(name)
        self.SetSize(size)

    def WriteToNetCDFFile(self, file):
        """Create this dimension in *file* unless it is already defined.

        The file is synced in either case.
        """
        already_defined = self._name_ in file.dimensions
        if not already_defined:
            file.createDimension(self._name_, self._size_)
        file.sync()

    def ReadFromNetCDFFile(self, file, index=None):
        """Take the size of this dimension from *file* and return self.

        *index* is accepted only for interface compatibility with the other
        ReadFromNetCDFFile methods and is ignored. A dimension missing from
        the file propagates the KeyError raised by file.dimensions.
        """
        # TODO(review): decide whether a missing dimension should be handled.
        self._size_ = file.dimensions[self._name_]
        return self

    def GetSize(self):
        """Return the stored size."""
        return self._size_

    def SetSize(self, size):
        """Store *size* as the size of the dimension."""
        self._size_ = size

    def GetName(self):
        """Return the dimension name."""
        return self._name_

    def SetName(self, name):
        """Store *name* as the dimension name."""
        self._name_ = name


class Entry:
    """Class for simple entries in netcdf files.

    To create a simple entry in a netcdf file, say "PrintSpatial", attach
    the following attribute to your simulation (in this case 'mysim')
    object:

    '>>> mysim.ps=NetCDF.Entry(name="PrintSpatial")'

    Note that the name 'ps' of the attribute is immaterial (see also
    Simulation). If you want to add attributes to the netcdf entry simply
    add them to the simulation's attribute:

    '>>> mysim.ps.Quantity="TotalDensity"'

    If you want the entry to have a value you can add it using the
    SetValue method or simply at initialization time:

    '>>> mysim.kp=NetCDF.Entry(name="KpointSetup",value=[3,3,3])'

    If the entry needs special names for its dimensions there is an
    optional dimensionnames parameter (any sequence of strings; it is
    copied, never modified in place):

    '>>> mysim.daa=NetCDF.Entry(name="DynamicAtomAttributes", value=["fix:123"],dimensionnames=["number_of_dynamic_atoms"])'

    Instance attributes whose names start or end with '_' hold private
    state (name, value, dimension names and the cached dimensions). All
    other instance attributes, except 'dimensions', are written to the file
    as netcdf attributes of the variable.
    """

    # These hooks spell out the default attribute behaviour; they are kept
    # so that code relying on their presence keeps working.
    def __setattr__(self, name, value):
        self.__dict__[name] = value

    def __delattr__(self, name):
        del self.__dict__[name]

    def __init__(self, name=None, value='', dimensionnames=None):
        """Create an entry named *name* holding *value*.

        An empty or missing *name* falls back to the class name. If
        *dimensionnames* is given it is passed to SetDimensionNames.
        """
        if not name:
            self.SetName(self.__class__.__name__)
        else:
            self.SetName(name)
        self.SetValue(value)
        if dimensionnames is not None:
            self.SetDimensionNames(dimensionnames)

    def __repr__(self):
        s = 'NetCDF.Entry(name=' + repr(self.GetName()) + ', value=' + repr(self.GetValue())
        dimensionnames = self.GetDimensionNames()
        if dimensionnames is not None:
            s = s + ', dimensionnames=' + repr(dimensionnames)
        return s + ')'

    def reprWithName(self, name):
        """Return Python source lines that recreate this entry as *name*.

        The first line assigns repr(self) to *name*; each following line
        assigns one netcdf attribute (see GetAttributes).
        """
        lines = [name + "=" + repr(self) + "\n"]
        attr = self.GetAttributes()
        for key in attr.keys():
            lines.append(name + "." + key + "=" + repr(attr[key]) + "\n")
        return "".join(lines)

    def __str__(self):
        text = self.GetName()
        value = self.GetValue()
        # Show the value unless it is an empty array or the NUL character.
        if asarray(value).shape != (0,) and value != '\000':
            text = text + '= ' + str(value)
        attr = self.GetAttributes()
        if attr != {}:
            text = text + "\n" + str(attr)
        return text

    def WriteToNetCDFFile(self, file, indices=None):
        """Write the value and attributes of this entry to *file*.

        Missing dimensions are created first. Without *indices* the
        variable must not exist yet, otherwise NameConflict is raised. In
        that case a scalar value is stored with assignValue and an array is
        written item by item along its first axis.

        With *indices* an existing variable is written into:
        - an int index stores the last item of the value at that position;
        - a sequence of indices copies value[i] to position i for each i;
        - an empty list falls back to assignValue.
        """
        value = self.GetValue()
        if value is None:
            data = asarray('')
        else:
            data = asarray(value)
        name = self.GetName()
        attr = self.GetAttributes()
        dimlist = self.GetDimensions()
        dimnames = dimlist[0]
        dimsizes = dimlist[1]
        for dimname, dimsize in zip(dimnames, dimsizes):
            Dimension(dimname, dimsize).WriteToNetCDFFile(file)
        if name in file.variables:
            v = file.variables[name]
            if indices is None:
                # Refuse to silently overwrite a variable written by another
                # entry (only explicit positional writes may reuse it).
                raise NameConflict("Trying to write twice under the same variablename: " + str(name))
        else:
            v = file.createVariable(name, data.typecode(), tuple(dimnames))
        # NOTE: the original author observed that assignValue does not seem
        # to accept the unlimited dimension, so arrays are written item by
        # item along their first axis.
        if indices is None:
            if data.shape == ():
                indices = []
            else:
                indices = list(range(data.shape[0]))
        if indices != []:
            if isinstance(indices, int):
                v[indices] = data[-1]
            else:
                for i in indices:
                    v[i] = data[i]
        else:
            v.assignValue(data)
        file.sync()
        for key in attr.keys():
            setattr(v, key, attr[key])
        if attr:
            # One sync after all attributes instead of one per attribute.
            file.sync()

    def ReadFromNetCDFFile(self, file, index=None):
        """Read value, attributes and dimensions of this entry from *file*.

        With *index* only that item of the variable is read. It is wrapped
        in a list so the value keeps its dimensionality. Otherwise the whole
        variable is read. Every netcdf attribute of the variable becomes an
        instance attribute. Returns self.
        """
        v = file.variables[self.GetName()]
        if index is not None:
            self.SetValue([v[index]])
        else:
            self.SetValue(v.getValue())
        for key in v.__dict__.keys():
            setattr(self, key, v.__dict__[key])
        names = tuple(v.dimensions)
        sizes = tuple([file.dimensions[dim] for dim in names])
        # Remember the names, so they survive a later SetValue, and then cache
        # the names/sizes as found in the file. The order matters because
        # SetDimensionNames clears the cache.
        self.SetDimensionNames(list(names))
        self._dimensions_ = [names, sizes]
        return self

    def _isNetCDFAttributeName(self, key):
        # Private bookkeeping (leading/trailing '_') and the reserved
        # 'dimensions' key are never netcdf attributes.
        return key[:1] != "_" and key[-1:] != "_" and key != 'dimensions'

    def GetAttributes(self):
        """Return the attributes to be written in the netcdf file.

        By default this is just the attributes of the instance, except
        those that start or end with an underscore ('_') and 'dimensions'.
        """
        attr = {}
        for key, val in self.__dict__.items():
            if self._isNetCDFAttributeName(key):
                attr[key] = val
        return attr

    def SetAttributes(self, attr):
        """Replace the attributes to be written in the netcdf file.

        Removes all old attributes except those that start or end with an
        underscore ('_'), then sets those entries of *attr* that are valid
        attribute names.
        """
        # Snapshot the keys: attributes are deleted while looping.
        for key in list(self.__dict__.keys()):
            if self._isNetCDFAttributeName(key):
                delattr(self, key)
        for key in attr.keys():
            if self._isNetCDFAttributeName(key):
                setattr(self, key, attr[key])

    def GetDimensions(self):
        """Return [names, sizes] of the dimensions used when writing.

        By default the names are constructed from the shape of the value
        simply as 'dim' *nn*, where *nn* is the length of that axis. An
        empty (shape (0,)) value has no dimensions. Non-empty names in the
        '_dimensionnames_' attribute override the defaults. The name
        'number_ionic_steps' gets size None.

        The result is cached in '_dimensions_' until SetValue or
        SetDimensionNames is called.
        """
        if '_dimensions_' in self.__dict__:
            return self._dimensions_
        data = asarray(self.GetValue())
        names = []
        sizes = []
        if data.shape != (0,):
            for d in data.shape:
                names.append('dim' + repr(d))
                sizes.append(d)
        if '_dimensionnames_' in self.__dict__:
            for i, dname in enumerate(self._dimensionnames_):
                if len(dname) != 0 and i < len(names):
                    names[i] = dname
                    if dname == 'number_ionic_steps':
                        sizes[i] = None
        dimlist = [names, sizes]
        self._dimensions_ = dimlist
        return dimlist

    def GetDimensionNames(self):
        """Return the special dimension names, or None if none were set."""
        return getattr(self, '_dimensionnames_', None)

    def SetDimensionNames(self, names):
        """Set special names for the dimensions.

        *names* may be any sequence of strings; a copy is stored. Names of
        the default form 'dim' *nn* (nn an integer) are stored as '' so the
        default is recomputed from the value. The cached dimensions are
        discarded.
        """
        cleaned = list(names)
        for i, dname in enumerate(cleaned):
            if dname[:3] == 'dim':
                try:
                    int(dname[3:])
                except ValueError:
                    pass
                else:
                    cleaned[i] = ''
        self._dimensionnames_ = cleaned
        self._forgetDimensions()

    def _forgetDimensions(self):
        # Drop the cached GetDimensions result so it is rebuilt on demand.
        if '_dimensions_' in self.__dict__:
            del self._dimensions_

    def GetName(self):
        """Return the name to be used for writing in the netcdf file."""
        return self._name_

    def SetName(self, name):
        """Set the name to be used for writing in the netcdf file."""
        self._name_ = name

    def GetValue(self):
        """Return the value to be written in the netcdf file."""
        return self._value_

    def SetValue(self, value):
        """Set the value to be written in the netcdf file.

        The cached dimensions are discarded because they depend on the value.
        """
        self._forgetDimensions()
        self._value_ = value

    def GetwxManipulator(self, parent=None):
        from Manipulators.wxPython.NetCDFEntryManipulator import NetCDFEntryManipulator
        return NetCDFEntryManipulator(parent, self)

class EntryWithStaticName(Entry):
    """An abstract variant of Entry whose name is its class name.

    Subclasses only have to supply a value, for example::

        class PlaneWaveCutoff(NetCDF.EntryWithStaticName):

            def __init__(self, value=340):
                NetCDF.EntryWithStaticName.__init__(self, value=value)
    """

    def __repr__(self):
        pieces = [self.GetName(), '(value=', repr(self.GetValue())]
        if self.GetDimensionNames() is not None:
            pieces.append(', dimensionnames=')
            pieces.append(repr(self.GetDimensionNames()))
        pieces.append(')')
        return ''.join(pieces)

    def GetName(self):
        """Return the class name, used as the variable name in the netcdf file."""
        return self.__class__.__name__

    def GetwxManipulator(self, parent=None):
        from Manipulators.wxPython.NetCDFEntryManipulator import NetCDFEntryWithStaticNameManipulator, NetCDFEntryWithStaticNameAndNoValueManipulator
        # Entries without a value get the value-less manipulator.
        if self.GetValue() is not None:
            return NetCDFEntryWithStaticNameManipulator(parent, self)
        return NetCDFEntryWithStaticNameAndNoValueManipulator(parent, self)

class Series(Entry):
    """A variant of Entry whose first dimension is the unlimited one.

    This works just like NetCDF.Entry (q.v.) except that the first
    dimension is always named "number_ionic_steps", which
    Entry.GetDimensions gives size None. Example:

    '>>> mysim.tot=NetCDF.Series("TotalEnergy")'
    """

    def __init__(self, name=None, value=None, dimensionnames=None):
        """Create a series entry.

        *value* defaults to an empty list. *dimensionnames* (any sequence,
        optional) names the remaining dimensions. It is not modified; any
        "number_ionic_steps" in it is dropped because that name is always
        prepended.
        """
        # 'is None' rather than '== None': value may be an array, for which
        # '==' compares element-wise.
        if value is None:
            value = []
        if dimensionnames is None:
            extra = []
        else:
            # Build a new list instead of editing the caller's sequence; drop
            # duplicates of the leading unlimited dimension (the original
            # author noted this could perhaps raise an error instead).
            extra = [dname for dname in dimensionnames if dname != "number_ionic_steps"]
        Entry.__init__(self, name, value, ["number_ionic_steps"] + extra)

class SeriesOfEntries(UserList):
    """Represents a series of dacapo entries.

    This is used to build a series out of simpler objects. The
    Read/WriteToNetCDFFile methods just loop over the list and let each
    object read/write itself, passing its position as the index.

    The objects do not have to be simple entries, as long as they provide
    both a ReadFrom- and a WriteToNetCDFFile method.

    '>>> oldsim.configs=NetCDF.SeriesOfEntries([ListOfAtoms(),ListOfAtoms()])'
    """

    def ReadFromNetCDFFile(self, file, index=None):
        """Read entries from a netcdffile.

        Without *index* entry i reads item i for every entry; otherwise only
        entry *index* reads item *index*.
        """
        if index is not None:
            self[index].ReadFromNetCDFFile(file, index)
            return
        for position, entry in enumerate(self):
            entry.ReadFromNetCDFFile(file, position)

    def WriteToNetCDFFile(self, file, indices=None):
        """Write entries to a netcdffile.

        Each selected entry i (all of them when *indices* is None) writes
        itself at position i.
        """
        if indices is None:
            positions = range(len(self))
        else:
            positions = indices
        for position in positions:
            self[position].WriteToNetCDFFile(file, position)


class Collection:
    """Represent a set of entries.

    The ReadFrom/WriteToNetCDFFile methods of this object ask every public
    attribute (name neither starting nor ending with '_') to read/write
    itself from/to a given file. Attributes without the method are skipped.
    """

    def __setattr__(self, name, value):
        self.__dict__[name] = value

    def __delattr__(self, name):
        if name not in self.__dict__:
            raise AttributeError
        del self.__dict__[name]

    def ReadFromNetCDFFile(self, file=None, index=None):
        """Let all public attributes read themselves from *file*; return self.

        Any AttributeError (e.g. an attribute without ReadFromNetCDFFile)
        is ignored.
        """
        for attr in list(self.__dict__.keys()):
            if attr[-1] == "_" or attr[0] == "_":
                continue  # private attribute
            try:
                getattr(self, attr).ReadFromNetCDFFile(file, index)
            except AttributeError:
                pass
        return self

    def WriteToNetCDFFile(self, file):
        """Let all public attributes write themselves to *file*.

        Any AttributeError (e.g. an attribute without WriteToNetCDFFile)
        is ignored.
        """
        for attr in list(self.__dict__.keys()):
            if attr[-1] == "_" or attr[0] == "_":
                continue  # private attribute
            try:
                getattr(self, attr).WriteToNetCDFFile(file)
            except AttributeError:
                pass

    def DeleteNetCDFEntry(self, collection, delattr=""):
        """Remove from *collection* every public attribute whose GetName() equals *delattr*.

        Public attributes that have a GetName method but a different name are
        searched recursively. Attributes without GetName are skipped. Note
        that the parameter name *delattr* shadows the builtin of that name.
        """
        # Snapshot the keys: matching attributes are deleted while looping.
        for key in list(collection.__dict__.keys()):
            if key[-1] == "_" or key[0] == "_":
                continue  # private attribute
            try:
                name = collection.__dict__[key].GetName()
                if name == delattr:
                    del collection.__dict__[key]
                else:
                    try:
                        self.DeleteNetCDFEntry(getattr(collection, key), delattr)
                    except AttributeError:
                        print("AttributeError %s" % (key,))
            except AttributeError:
                pass

class File(Collection):
    """Represent the set of entries in a netcdf file.

    This class represents all the elements in a given netcdf file:

    '>>> oldsim=NetCDF.File("cu.nc")'

    After initialisation it has an Entry attribute for every variable and
    a Dimension attribute for every dimension in the file. Note that the
    actual values are not read before the 'ReadFromNetCDFFile' method is
    called. The file is opened read-only and closed in __del__.
    """

    def __init__(self, filename=None):
        if filename is not None:
            from Scientific.IO.NetCDF import NetCDFFile
            self._file_ = NetCDFFile(filename, 'r')
            for name in self._file_.variables.keys():
                self.__dict__[name] = Entry(name=name)
            # NOTE(review): a dimension that shares its name with a variable
            # (a netcdf coordinate variable) replaces that variable's Entry.
            for dim in self._file_.dimensions.keys():
                self.__dict__[dim] = Dimension(name=dim)

    def __str__(self):
        parts = [repr(self) + ":\n"]
        for attr in self.__dict__.keys():
            parts.append("\n" + str(getattr(self, attr)) + "\n")
        return "".join(parts)

    def __del__(self):
        # _file_ exists only if a filename was given and opening succeeded;
        # without this guard __del__ raised AttributeError during cleanup.
        netcdf = self.__dict__.get('_file_')
        if netcdf is not None:
            netcdf.close()

    def ReadFromNetCDFFile(self, file=None, index=None):
        """Read all non-private attributes from a netcdffile and return self.

        If no file is supplied, the file opened in __init__ is used.
        """
        if file is None:
            file = self._file_
        # Return the result, like Collection.ReadFromNetCDFFile, so callers
        # can chain calls.
        return Collection.ReadFromNetCDFFile(self, file, index)
