"""Implementations of function spaces"""

from ASE.Utilities import ArrayTools
import LinearAlgebra
import Numeric
import copy
import FFT

def ValueMatrixFromIndexMatrix(indexmatrix,indices,values):
	"""Build a value matrix by replacing index labels with values.

	Every element of indexmatrix equal to indices[i] receives values[i];
	elements matching none of the indices are zero.  If the same index
	occurs several times in indices, the corresponding values are added.
	The result has the shape of indexmatrix and the typecode of values.
	"""
	# Converting values to a Numeric array, if necessary
	values=Numeric.asarray(values)
	# The output takes its typecode from values
	valuematrix=Numeric.zeros(indexmatrix.shape,values.typecode())
	for i,index in enumerate(indices):
		# The integer fill value 0 (instead of 0.0) keeps the contribution
		# integer when values are integers.  Numeric ufuncs with an
		# explicit output array require the result to fit the output
		# typecode, so a Float contribution could not be added in place
		# into an integer valuematrix.
		contribution=Numeric.where(Numeric.equal(indexmatrix,index),values[i],0)
		Numeric.add(valuematrix,contribution,valuematrix)
	return valuematrix


def BlockFormFromVector(vector,dimensions,vectorshape=None):
	"""Reshape vector(a,dim) into blockform(vectorshape,dim1,...,dimi).

	dimensions: sequence (tuple or list) of the block dimensions dim1..dimi.
	vectorshape: shape of the leading (vector index) axes of the result,
	    given as a sequence or as a single integer.  When None, the leading
	    axes of vector (all but the last) are kept unchanged.
	"""
	if vectorshape is None:
		vectorshape=vector.shape[:-1]
	try:
		vectorshape=tuple(vectorshape)
	except TypeError:
		# A single integer: one leading axis of that length
		vectorshape=(int(vectorshape),)
	return Numeric.reshape(vector,vectorshape+tuple(dimensions))

def VectorFromBlockForm(blockform,dimensions,vectorshape=None):
	"""Reshape blockform(a,dim1,...,dimi) into vector(vectorshape,dim).

	dim is the product of the block dimensions.
	dimensions: sequence (tuple or list) of the block dimensions dim1..dimi.
	vectorshape: shape of the leading (vector index) axes of the result,
	    given as a sequence or as a single integer (e.g. the total number of
	    states).  When None, the leading axes of blockform are kept.
	"""
	dimensions=tuple(dimensions)
	mul_dim=(Numeric.multiply.reduce(dimensions),)
	if vectorshape is None:
		ndim=len(dimensions)
		vectorshape=blockform.shape[:-ndim]
	try:
		vectorshape=tuple(vectorshape)
	except TypeError:
		# A single integer was given (PlaneWaveSpace2D.Operate passes the
		# product of the state axes); int+tuple would raise TypeError
		vectorshape=(int(vectorshape),)
	return Numeric.reshape(blockform,vectorshape+mul_dim)
	
def MatrixFromBlockForm(representation,dimensions):
	"""Reshape a block-form operator (dimensions+dimensions) into an
	(n,n) matrix, where n is the product of dimensions."""
	size=Numeric.multiply.reduce(dimensions)
	squareshape=(size,size)
	return Numeric.reshape(representation,squareshape)

def BlockFormFromMatrix(matrix,dimensions):
	"""Reshape an (n,n) matrix into block form dimensions+dimensions
	(the inverse of MatrixFromBlockForm)."""
	blockshape=dimensions+dimensions
	return Numeric.reshape(matrix,blockshape)

def ReverseCyclicPermutation(matrix,contiguous=1):
	"""Move the first axis to the end: matrix(a,b,...,c,d) -> matrix(b,...,c,d,a).

	Numeric.transpose returns a non-contiguous array; unless contiguous
	is something other than 1, a copy.copy of it is returned instead.
	"""
	order=list(range(len(matrix.shape)))
	# Rotate the axis order left by one: [1,2,...,n-1,0]
	order=order[1:]+order[:1]
	permuted=Numeric.transpose(matrix,order)
	if contiguous!=1:
		return permuted
	return copy.copy(permuted)

def CyclicPermutation(matrix,contiguous=1):
	"""Move the last axis to the front: matrix(a,b,...,c,d) -> matrix(d,a,b,...,c).

	Numeric.transpose returns a non-contiguous array; unless contiguous
	is something other than 1, a copy.copy of it is returned instead.
	"""
	order=list(range(len(matrix.shape)))
	# Rotate the axis order right by one: [n-1,0,1,...,n-2]
	order=[order[-1]]+order[:-1]
	permuted=Numeric.transpose(matrix,order)
	if contiguous!=1:
		return permuted
	return copy.copy(permuted)

def Diag(values):
	"""Return a square matrix with values on the diagonal and zeros elsewhere.

	Multiplying the identity by values scales column j by values[j];
	off-diagonal entries stay zero.
	"""
	unit=Numeric.identity(len(values))
	return unit*values

def Round(array):
	"""Apply the built-in round() to every element of array.

	The result has the same shape as array.  Numeric.ravel is used
	instead of array.flat because .flat is unavailable on non-contiguous
	Numeric arrays (e.g. transposed views).
	"""
	rounded=[round(x) for x in Numeric.ravel(array)]
	return Numeric.reshape(rounded,array.shape)

def Dagger(array):
	"""Return the conjugate transpose of array (Hermitian adjoint for a
	2-D matrix; Numeric.transpose without axes reverses all axes)."""
	conjugated=Numeric.conjugate(array)
	return Numeric.transpose(conjugated)


class FunctionSpace1D:
	"""Base class for one-dimensional function spaces of a given length."""

	def __init__(self,length):
		self.SetLength(length)

	def SetLength(self,length):
		self._length=length

	def GetLength(self):
		return self._length

	def GetDimensions(self):
		# Common to all function spaces: the dimensions as a tuple
		length=self.GetLength()
		return (length,)

	def GetLaplaceRepresentation(self,laplacerepresentation,metric=None):
		# Scales the given laplace representation by the (0,0) element of
		# the inverse metric; a unit 1x1 metric is assumed if none is given
		if metric is None:
			metric=Numeric.array([[1.0]])
		inverse=LinearAlgebra.inverse(Numeric.asarray(metric))
		return inverse[0,0]*laplacerepresentation

	def InnerProduct(self,state1,operatorrepresentation,state2):
		# <state1|A|state2> for every pair of states.  Both states have the
		# form (a,dim), dim being the dimension of the space:
		#   state1^*(a,dim) x operator(dim,dim) x (state2(b,dim))^T
		# The transpose of state2 is requested from
		# ArrayTools.MatrixMultiplication via btype='Transpose'.
		mm=ArrayTools.MatrixMultiplication
		bra=Numeric.conjugate(state1)
		ket=mm(operatorrepresentation,state2,btype='Transpose')
		return mm(bra,ket)

class InterpoletSpace:
	"""Function space spanned by interpolating wavelets (interpolets).

	The space is non-orthogonal (IsOrthogonal returns 0).  Coordinates
	are obtained from function values by in-place lifting steps over the
	wavelet levels 1..depth, and function values are recovered by undoing
	those steps.

	The class calls methods it does not define itself (GetLength,
	GetInverseIdentityRepresentation), and _GetScaleCoefficientMatrix is
	only a stub here; presumably these are supplied by a subclass or
	mix-in -- TODO confirm.
	"""

	# Numeric typecodes of all integer types: Int8 '1', Int16 's',
	# Int32 'i', Int 'l', UnsignedInt8 'b', UnsignedInt16 'w',
	# UnsignedInt32 'u'.  The lifting steps work in place and produce
	# non-integer values, so arrays of these types are converted to Float.
	_INTEGER_TYPECODES='1silbwu'

	def __init__(self,depth,waveletparameters):
		self.SetDepth(depth)
		self.SetWaveletParameters(waveletparameters)

	def IsOrthogonal(self):
		return 0

	def SetDepth(self,depth):
		self._depth=depth

	def GetDepth(self):
		return self._depth

	def SetWaveletParameters(self,waveletparameters):
		self._parameters=waveletparameters

	def GetWaveletParameters(self):
		return self._parameters

	def GetScaleFunctionIndices(self,level):
		# Grid indices of the scale functions at this level:
		# every 2**level'th point, starting at 0
		step=pow(2,level)
		return Numeric.arange(0,self.GetLength(),step)

	def GetWaveletIndices(self,level):
		# Grid indices of the wavelets at this level (level>=1): every
		# 2**level'th point starting at 2**(level-1), i.e. halfway
		# between neighbouring scale function indices
		step=pow(2,level)
		start=pow(2,level-1)
		return Numeric.arange(start,self.GetLength(),step)

	def _GetScaleCoefficientMatrix(self,level):
		# Stub that returns None.  The transforms use its result as the
		# matrix c(y,z) coupling wavelet index y to scale function index z
		# at `level`; presumably overridden elsewhere -- TODO confirm.
		pass

	def _LiftingStep(self,level,array,sign):
		# One in-place lifting step at `level`:
		#   v(y) = v(y) + sign*sum_z c(y,z)*v(z)
		# y: wavelet indices, z: scale function indices,
		# c: scale coefficient matrix of the level.
		# Only the wavelet entries of array are modified.
		scalecoefficients=self._GetScaleCoefficientMatrix(level=level)
		waveletindices=self.GetWaveletIndices(level)
		scaleindices=self.GetScaleFunctionIndices(level)
		v_scale=Numeric.take(array,scaleindices)
		v_wavelet=Numeric.take(array,waveletindices)
		correction=Numeric.matrixmultiply(scalecoefficients,v_scale)
		if sign>0:
			Numeric.add(v_wavelet,correction,v_wavelet)
		else:
			Numeric.subtract(v_wavelet,correction,v_wavelet)
		# Finally insert the result in array
		Numeric.put(array,waveletindices,v_wavelet)

	def ForwardWaveletTransform(self,level1,level2,array):
		# In place, with level1>level2 (level1: start, level2: final):
		# adds the scale contributions back for the levels
		# level1, level1-1, ..., level2+1 (coarsest first)
		for level in range(level1,level2,-1):
			self._LiftingStep(level,array,1)

	def InverseWaveletTransform(self,level1,level2,array):
		# In place, with level1<level2: subtracts the scale contributions
		# for the levels level1+1, ..., level2 (finest first)
		for level in range(level1+1,level2+1):
			self._LiftingStep(level,array,-1)

	def _GetLevelTransformationMatrix(self,level,sign):
		# Matrix form of _LiftingStep(level,.,sign): the identity with
		# sign*c(y,z) written at row y (wavelet index) and column z
		# (scale function index)
		length=self.GetLength()
		matrix=Numeric.identity(length).astype(Numeric.Float)
		scalematrix=self._GetScaleCoefficientMatrix(level=level)
		if sign<0:
			scalematrix=-1*scalematrix
		waveletindices=self.GetWaveletIndices(level)
		scaleindices=self.GetScaleFunctionIndices(level)
		# In the flattened length x length matrix, element (i,j) sits at
		# position i*length+j
		indices=Numeric.add.outer(waveletindices*length,scaleindices)
		# ravel (unlike .flat) also works for non-contiguous arrays
		Numeric.put(matrix,Numeric.ravel(indices),Numeric.ravel(scalematrix))
		return matrix

	def _ComposeTransformationMatrices(self,levels,sign):
		# Product of the level matrices applied in the order of `levels`
		# (each later level multiplies from the left).  Returns the Float
		# identity when `levels` is empty.
		transform=None
		for level in levels:
			step=self._GetLevelTransformationMatrix(level,sign)
			if transform is None:
				transform=step
			else:
				transform=Numeric.matrixmultiply(step,transform)
		if transform is None:
			transform=Numeric.identity(self.GetLength()).astype(Numeric.Float)
		return transform

	def GetInverseTransformationMatrix(self):
		# Matrix of InverseWaveletTransform(0,depth):
		# function values -> coordinates
		return self._ComposeTransformationMatrices(range(1,self.GetDepth()+1),-1)

	def GetForwardTransformationMatrix(self):
		# Matrix of ForwardWaveletTransform(depth,0):
		# coordinates -> function values
		return self._ComposeTransformationMatrices(range(self.GetDepth(),0,-1),1)

	def CoordinatesLevel0FromFunctionValues(self,functionvalues):
		# unit operator between 0 level wavelets and function values
		return functionvalues

	def FunctionValuesFromCoordinatesLevel0(self,coordinates):
		# unit operator between function values and 0 level wavelets
		return coordinates

	def _AsFloatIfInteger(self,array):
		# Integer arrays cannot hold the results of the in-place
		# lifting steps, so they are converted to Float
		if array.typecode() in self._INTEGER_TYPECODES:
			return array.astype(Numeric.Float)
		return array

	def CoordinatesFromFunctionValues(self,functionvalues):
		# Numeric.array copies, so the in-place transform below does not
		# modify the caller's function values
		functionvalues=self._AsFloatIfInteger(Numeric.array(functionvalues))
		coordinates=self.CoordinatesLevel0FromFunctionValues(functionvalues)
		self.InverseWaveletTransform(level1=0,level2=self.GetDepth(),array=coordinates)
		return coordinates

	def FunctionValuesFromCoordinates(self,coordinates):
		# Numeric.array copies, so the in-place transform below does not
		# modify the caller's coordinates
		coordinates=self._AsFloatIfInteger(Numeric.array(coordinates))
		self.ForwardWaveletTransform(level1=self.GetDepth(),level2=0,array=coordinates)
		return self.FunctionValuesFromCoordinatesLevel0(coordinates)

	def _GetMatrixFromLevel0Matrix(self,matrix):
		# F^T x matrix x F, with F the forward transformation matrix
		forward=self.GetForwardTransformationMatrix()
		return Numeric.matrixmultiply(Numeric.transpose(forward),Numeric.matrixmultiply(matrix,forward))

	def Operate(self,state,operatorrepresentation):
		# Carries out A*|psi>, where A: operator and |psi> is a state
		# For a nonorthogonal space this is:
		# A*|psi> = O^{-1}x<i|A|j>xpsi(j) where O=<i|1|j> is overlap

		# The state is assumed to have the form: state(a,dim),
		# where a is vectorindex

		# Calculating: operator=O^{-1}xA
		operator=Numeric.matrixmultiply(self.GetInverseIdentityRepresentation(),operatorrepresentation)
		# state(a,dim)->state(dim,a)
		state=CyclicPermutation(state)
		# Then operate_state(dim,a)=operator(dim,dim)xstate(dim,a)
		operate_state=Numeric.matrixmultiply(operator,state)
		# Then return: operate_state(dim,a)->operate(a,dim)
		return CyclicPermutation(operate_state)
		

class PlaneWaveSpace:
	"""Behaviour shared by the plane-wave spaces: an orthogonal basis with
	coordinates and function values related by (inverse) FFTs."""

	def IsOrthogonal(self):
		return 1

	def CoordinatesFromFunctionValues(self,functionvalues):
		# Function values -> coordinates via ArrayTools.InverseFFT
		values=Numeric.asarray(functionvalues)
		return ArrayTools.InverseFFT(values)

	def FunctionValuesFromCoordinates(self,coordinates):
		# Coordinates -> function values via ArrayTools.FFT
		values=Numeric.asarray(coordinates)
		return ArrayTools.FFT(values)


class PlaneWaveSpace1D(FunctionSpace1D,PlaneWaveSpace):
	"""One-dimensional plane-wave space with basis functions
	1/sqrt(L)*exp(iGx), where G=2*pi*n/L.

	States and operators are expressed in the coordinates a_G, stored in
	the order returned by GetCoordinateBasis.  States have the form
	(a,dim), a labelling the states.
	"""

	def GetCoordinateBasis(self):
		# Integer labels n of the basis functions in storage order:
		# first 0,-1,-2,... then ...,2,1.  With integer length, -length/2
		# is floor division of the negated length, so for odd lengths the
		# negative part holds one element more than the positive part.
		length=self.GetLength()
		nonpositive=Numeric.arange(0,-length/2,-1)
		positive=Numeric.arange(length/2,0,-1)
		return Numeric.concatenate((nonpositive,positive),-1)

	def GetIdentityRepresentation(self):
		size=self.GetLength()
		return Numeric.identity(size).astype(Numeric.Float)

	def Identity_InnerProduct(self,state1,state2):
		# Fast identity inner product for states state1(a,dim) and
		# state2(b,dim): state1^*(a,dim) x (state2(b,dim))^T
		bra=Numeric.conjugate(state1)
		return ArrayTools.MatrixMultiplication(bra,state2,btype="Transpose")

	def _GetGradientDiagonal(self):
		# Diagonal of the gradient, i*G, assuming unit grid spacing
		length=self.GetLength()
		factor=2*Numeric.pi*complex(0,1)/length
		return factor*self.GetCoordinateBasis()

	def GetGradientRepresentation(self):
		# One-element list (a single direction) holding the diagonal matrix
		diagonal=self._GetGradientDiagonal()
		return Numeric.array([Diag(values=diagonal)])

	def Gradient_InnerProduct(self,state1,state2):
		# <state1|gradient|state2> = <state1|diff_state2>, where
		# diff_state2 = diagonal(gradient)*state2 (elementwise)
		differentiated=Numeric.array(Numeric.multiply(state2,self._GetGradientDiagonal()))
		overlap=self.Identity_InnerProduct(state1,differentiated)
		return Numeric.array([overlap])

	def _GetLaplaceDiagonal(self):
		# Diagonal of the laplacian: (i*G)^2 = -4*pi^2*n^2/L^2
		length=self.GetLength()
		factor=-4*pow(Numeric.pi,2)/pow(length,2)
		return factor*Numeric.power(self.GetCoordinateBasis(),2)

	def GetLaplaceRepresentation(self,metric=None):
		# Diagonal laplace matrix, scaled by the inverse metric
		matrix=Diag(values=self._GetLaplaceDiagonal())
		return FunctionSpace1D.GetLaplaceRepresentation(self,metric=metric,laplacerepresentation=matrix)

	def Laplace_InnerProduct(self,state1,state2,metric=None):
		# <state1|laplace|state2> = <state1|laplace_state2>, where
		# laplace_state2 = diagonal(laplace)*state2 (elementwise) and the
		# diagonal is scaled by the inverse metric
		scaled=FunctionSpace1D.GetLaplaceRepresentation(self,metric=metric,laplacerepresentation=self._GetLaplaceDiagonal())
		applied=Numeric.multiply(scaled,state2)
		return self.Identity_InnerProduct(state1,applied)

	def GetLocalPotentialRepresentationOld(self,potential):
		# Older element-by-element version of
		# GetLocalPotentialRepresentation
		length=self.GetLength()
		potential_G=self.CoordinatesFromFunctionValues(potential)
		labels=self.GetCoordinateBasis()
		# Index matrix G-G'
		differences=Numeric.subtract.outer(labels,labels)
		# The basis is 0,-G,-2G,...,2G,G, hence the negated index
		elements=[potential_G[i] for i in -differences.flat]
		return Numeric.reshape(elements,(length,length))

	def GetLocalPotentialRepresentation(self,potential):
		# <G'|V(r)|G> = 1/sqrt(L)*V(G'-G)
		length=self.GetLength()
		potential_G=self.CoordinatesFromFunctionValues(potential)
		Numeric.divide(potential_G,Numeric.sqrt(length),potential_G)
		matrix=Numeric.zeros((length,length),Numeric.Complex)
		labels=self.GetCoordinateBasis()
		for column in range(length):
			# Column G of V(G'-G): the potential rotated as [G:]+[:G],
			# with G taken from the basis
			# {0,-1,...,-length/2+1,length/2,...,2,1}
			shift=labels[column]
			rotated=Numeric.concatenate((potential_G[shift:],potential_G[:shift]))
			matrix[:,column]=rotated
		return matrix

	def LocalPotential_InnerProduct(self,state1,state2,potential):
		# <state1|V|state2> = <state1|pot_state2>, with
		# pot_state2 = InvF(V(r)*state2(r))
		realspace=self.FunctionValuesFromCoordinates(state2)
		product=Numeric.multiply(potential,realspace)
		pot_state2=self.CoordinatesFromFunctionValues(product)
		return self.Identity_InnerProduct(state1,pot_state2)

	def _GetTranslationDiagonal(self,coordinates):
		# exp(-i*G*x_1) = pow(exp(-i*2*pi/N*x_1),n) for G=2*pi*n/L
		phase=Numeric.exp(-complex(0,1)*2*Numeric.pi*coordinates[0]/self.GetLength())
		return Numeric.array([pow(phase,n) for n in self.GetCoordinateBasis()])

	def Translation_Operate(self,state,coordinates):
		# T_{x_1}psi(x) = psi(x-x_1): each coordinate a_G is multiplied
		# by exp(-i*G*x_1)
		return Numeric.multiply(self._GetTranslationDiagonal(coordinates),state)

	def CoordinatesFromFunctionValues(self,functionvalues):
		# Function values (a,x_i) -> coordinates (a,G_i) for the basis
		# 1/sqrt(L)*exp(iGx); only the last axis is transformed
		scale=Numeric.sqrt(self.GetLength())
		functionvalues=Numeric.asarray(functionvalues)
		result=FFT.inverse_fft(functionvalues,n=functionvalues.shape[-1],axis=-1)
		Numeric.multiply(result,scale,result)
		return result

	def FunctionValuesFromCoordinates(self,coordinates):
		# Coordinates (a,G_i) -> function values (a,x_i) for the basis
		# 1/sqrt(L)*exp(iGx); only the last axis is transformed
		scale=Numeric.sqrt(self.GetLength())
		coordinates=Numeric.asarray(coordinates)
		result=FFT.fft(coordinates,n=coordinates.shape[-1],axis=-1)
		Numeric.divide(result,scale,result)
		return result

	def Operate(self,state,operatorrepresentation):
		# A|psi> = <i|A|j>xpsi(j) in this orthogonal basis, for
		# state(a,dim): transpose to (dim,a), multiply, transpose back
		columns=CyclicPermutation(state)
		product=Numeric.matrixmultiply(operatorrepresentation,columns)
		return CyclicPermutation(product)

class FunctionSpace2D:
	"""Two-dimensional function space, the tensor product of two
	one-dimensional subspaces.

	Operator representations use the block form (n1,n2,n1',n2'); states
	are assumed to have the form (a,dim1,dim2), a labelling the states.
	"""

	def __init__(self,dimensions):
		self.SetDimensions(dimensions)

	def GetDimensions(self):
		return self.dimensions

	def SetDimensions(self,dimensions):
		# Stored as a tuple so shapes can be built by concatenation
		self.dimensions=tuple(dimensions)

	def SetSubSpaces(self,subspaces):
		self.subspaces=subspaces

	def GetSubSpaces(self):
		return self.subspaces

	def _BlockTensorProduct(self,operator1,operator2):
		# O(n1,n1',n2,n2') = O1(n1,n1')xO2(n2,n2') -> O(n1,n2,n1',n2')
		product=Numeric.multiply.outer(operator1,operator2)
		return Numeric.swapaxes(product,1,2)

	def GetIdentityRepresentation(self):
		# I = I1 x I2, in block form
		first,second=self.GetSubSpaces()
		return self._BlockTensorProduct(first.GetIdentityRepresentation(),second.GetIdentityRepresentation())

	def GetGradientRepresentation(self):
		# d/dx1 = G1 x I2 and d/dx2 = I1 x G2, each in block form; the
		# single direction of each 1D gradient is element [0]
		first,second=self.GetSubSpaces()
		along1=self._BlockTensorProduct(first.GetGradientRepresentation()[0],second.GetIdentityRepresentation())
		along2=self._BlockTensorProduct(first.GetIdentityRepresentation(),second.GetGradientRepresentation()[0])
		return Numeric.array((along1,along2))

	def GetLaplaceRepresentation(self,metric=None):
		# Laplace*f = sum_{nm} ginv_{nm} d/dxn d/dxm, with ginv the
		# inverse metric (unit metric if None).  This **only** works for
		# scaled vector spaces.
		if metric is None:
			metric=Numeric.array([[1.0,0.0],[0.0,1.0]])
		inv_metric=LinearAlgebra.inverse(Numeric.asarray(metric))
		first,second=self.GetSubSpaces()

		# Diagonal terms, built in the order (n1,n1',n2,n2'):
		# ginv_11*L1 x I2 + I1 x ginv_22*L2
		laplace=Numeric.multiply.outer(inv_metric[0,0]*first.GetLaplaceRepresentation(),second.GetIdentityRepresentation())
		diagonal2=Numeric.multiply.outer(first.GetIdentityRepresentation(),inv_metric[1,1]*second.GetLaplaceRepresentation())
		Numeric.add(laplace,diagonal2,laplace)

		# Off-diagonal terms (only if present): both mixed derivatives
		# give G1 x G2, hence the factor two
		if inv_metric[0,1]!=0:
			mixed=Numeric.multiply.outer(2*inv_metric[0,1]*first.GetGradientRepresentation()[0],second.GetGradientRepresentation()[0])
			laplace=mixed+laplace

		# (n1,n1',n2,n2') -> (n1,n2,n1',n2')
		return Numeric.swapaxes(laplace,1,2)

	def InnerProduct(self,state1,operatorrepresentation,state2):
		# <state1|A|state2> for every pair of states via the BLAS-backed
		# ArrayTools.MatrixMultiplication:
		#   state1^*(a,dim1xdim2) x A(dim1xdim2,dim1xdim2) x state2(b,dim1xdim2)^T
		mm=ArrayTools.MatrixMultiplication
		dimensions=self.GetDimensions()
		matrix=MatrixFromBlockForm(operatorrepresentation,dimensions)
		flat=(Numeric.multiply.reduce(dimensions),)
		bra=Numeric.conjugate(Numeric.reshape(state1,state1.shape[:-2]+flat))
		ket=Numeric.reshape(state2,state2.shape[:-2]+flat)
		return mm(bra,mm(matrix,ket,btype='Transpose'))

class PlaneWaveSpace2D(PlaneWaveSpace,FunctionSpace2D):
	"""Two-dimensional plane-wave space built from two PlaneWaveSpace1D
	subspaces, with basis functions 1/sqrt(N1*N2)*exp(iG_dot_r).

	Operator representations use the block form (dim1,dim2,dim1',dim2');
	states have the form (a1,...,ai,dim1,dim2).
	"""

	def __init__(self,dimensions):
		self.SetDimensions(dimensions)
		self.SetSubSpaces((PlaneWaveSpace1D(dimensions[0]),PlaneWaveSpace1D(dimensions[1])))

	def GetCoordinateBasis(self):
		# Returns an array of shape (2,dim1,dim2): element [0] holds the
		# G1 label and element [1] the G2 label of each basis function
		space1,space2=self.GetSubSpaces()
		dimension1,dimension2=self.GetDimensions()
		# Creating basis for G1
		basis1=Numeric.add.outer(space1.GetCoordinateBasis(),Numeric.zeros((dimension2,)))
		# Creating basis for G2
		basis2=Numeric.add.outer(Numeric.zeros((dimension1,)),space2.GetCoordinateBasis())
		# Combining
		return Numeric.array((basis1,basis2))

	def Identity_InnerProduct(self,state1,state2):
		# Fast implementation of the identity using the BLAS-backed
		# ArrayTools.MatrixMultiplication.
		# Input: state1(a,dim1,dim2) and state2(b,dim1,dim2)
		multiply=ArrayTools.MatrixMultiplication
		dim12=Numeric.multiply.reduce(self.GetDimensions())
		# state1(a,dim1,dim2)->state1^*(a,dim1*dim2)
		state1=Numeric.conjugate(Numeric.reshape(state1,state1.shape[:-2]+(dim12,)))
		# state2(b,dim1,dim2)->state2(b,dim1*dim2)
		state2=Numeric.reshape(state2,state2.shape[:-2]+(dim12,))
		# Inner product = state1^*(a,dim1*dim2)xstate2(b,dim1*dim2)^T
		return multiply(state1,state2,btype="Transpose")

	def Gradient_InnerProduct(self,state1,state2):
		# Fast implementation of the gradient:
		# <state1|gradient_i|state2>=<state1|state2_diff_i>, where
		# state2_diff_i = diagonal(gradient_i)*state2 (elementwise)
		dimensions=self.GetDimensions()
		basis=self.GetCoordinateBasis()
		# Calculating the reciprocal vectors: G_i=2pi/L_i
		prefactor1=2*Numeric.pi*complex(0,1)/dimensions[0]
		prefactor2=2*Numeric.pi*complex(0,1)/dimensions[1]
		state2_diff1=Numeric.multiply(prefactor1*basis[0],state2)
		state2_diff2=Numeric.multiply(prefactor2*basis[1],state2)
		innerproduct1=self.Identity_InnerProduct(state1,state2_diff1)
		innerproduct2=self.Identity_InnerProduct(state1,state2_diff2)
		return Numeric.array([innerproduct1,innerproduct2])

	def Laplace_InnerProduct(self,state1,state2,metric=None):
		# Fast implementation of the inner product for the laplacian.
		# With g^mn the inverse metric:
		# <state1|laplace|state2> = sum_mn g^mn <state1|diff_mn_state2>,
		# diff_mn_state2 = diagonal(d^2/(dx_m*dx_n))*state2, and
		# diagonal(d^2/(dx_m*dx_n)) = -G_mG_n
		if metric is None:
			metric=Numeric.identity(2).astype(Numeric.Float)

		inv_metric=LinearAlgebra.inverse(Numeric.asarray(metric))
		dimensions=self.GetDimensions()
		prefactor1=2*Numeric.pi/dimensions[0]
		prefactor2=2*Numeric.pi/dimensions[1]
		basis1=prefactor1*self.GetCoordinateBasis()[0]
		basis2=prefactor2*self.GetCoordinateBasis()[1]

		# Diagonal elements d^2/dxi^2, i=1,2
		laplace_diagonal=-inv_metric[0,0]*Numeric.power(basis1,2)
		Numeric.subtract(laplace_diagonal,inv_metric[1,1]*Numeric.power(basis2,2),laplace_diagonal)

		# Off-diagonal elements (if necessary)
		# **NOTE** : Metric must be symmetric
		if inv_metric[0,1]!=0:
			Numeric.subtract(laplace_diagonal,2*inv_metric[0,1]*Numeric.multiply(basis1,basis2),laplace_diagonal)

		diff_state2=Numeric.multiply(laplace_diagonal,state2)
		return self.Identity_InnerProduct(state1,diff_state2)

	def GetLocalPotentialRepresentation(self,potential):
		# <G'|V(r)|G>=1/sqrt(A)*V(G'-G), returned with shape (n1,n2,n1',n2')
		normalization=Numeric.sqrt(Numeric.multiply.reduce(self.GetDimensions()))
		length1,length2=self.GetDimensions()
		space1,space2=self.GetSubSpaces()
		potential_G=self.CoordinatesFromFunctionValues(potential)
		Numeric.divide(potential_G,normalization,potential_G)

		potentialmatrix=Numeric.zeros((self.GetDimensions()+self.GetDimensions()),Numeric.Complex)
		basis1=space1.GetCoordinateBasis()
		basis2=space2.GetCoordinateBasis()
		# Each "column" (i,j) of V(G'-G) is the potential rotated by the
		# basis labels Gi = {0,-1,...,-length/2+1,length/2,...,2,1}
		for i in range(length1):
			index1=basis1[i]
			# Rotating along the first axis: [G1:] + [:G1]
			potential_G1=Numeric.concatenate((potential_G[index1:],potential_G[:index1]),axis=0)
			for j in range(length2):
				index2=basis2[j]
				# Rotating along the second axis: [G2:] + [:G2]
				matrixpotential=Numeric.concatenate((potential_G1[:,index2:],potential_G1[:,:index2]),axis=1)
				potentialmatrix[:,:,i,j]=matrixpotential
		return potentialmatrix

	def LocalPotential_InnerProduct(self,state1,state2,potential):
		# <state1|V|state2>=<state1|pot_state2>,
		# where pot_state2=InvF(V(r)*state2(r))
		state2_realspace=self.FunctionValuesFromCoordinates(state2)
		pot_state2_realspace=Numeric.multiply(potential,state2_realspace)
		pot_state2=self.CoordinatesFromFunctionValues(pot_state2_realspace)
		return self.Identity_InnerProduct(state1,pot_state2)

	def Translation_Operate(self,state,coordinates):
		# Multiplies the coordinates by the outer product of the 1D
		# translation phases exp(-i*G1*x1) and exp(-i*G2*x2)
		space1,space2=self.GetSubSpaces()
		translation=Numeric.multiply.outer(space1._GetTranslationDiagonal([coordinates[0]]),space2._GetTranslationDiagonal([coordinates[1]]))
		return Numeric.multiply(translation,state)

	def CoordinatesFromFunctionValues(self,functionvalues):
		# Function values (a,x_i,x_j) -> coordinates (a,G_i,G_j) for the
		# basis 1/sqrt(N1*N2)*exp(iG_dot_r); the two last axes are
		# transformed
		normalization=Numeric.sqrt(Numeric.multiply.reduce(self.GetDimensions()))
		# asarray so that sequences are accepted (.shape is used below)
		functionvalues=Numeric.asarray(functionvalues)
		coordinates=FFT.inverse_fft(FFT.inverse_fft(functionvalues,n=functionvalues.shape[-1],axis=-1),n=functionvalues.shape[-2],axis=-2)
		Numeric.multiply(coordinates,normalization,coordinates)
		return coordinates

	def FunctionValuesFromCoordinates(self,coordinates):
		# Coordinates (a,G_i,G_j) -> function values (a,x_i,x_j) for the
		# basis 1/sqrt(N1*N2)*exp(iG_dot_r); the two last axes are
		# transformed
		normalization=Numeric.sqrt(Numeric.multiply.reduce(self.GetDimensions()))
		# asarray so that sequences are accepted (.shape is used below)
		coordinates=Numeric.asarray(coordinates)
		functionvalues=FFT.fft2d(coordinates,coordinates.shape[-2:],axes=(-2,-1))
		Numeric.divide(functionvalues,normalization,functionvalues)
		return functionvalues

	def Operate(self,state,operatorrepresentation):
		# Carries out A*|psi> = <i|A|j>xpsi(j) (orthogonal basis) for each
		# state; state has the form (a1,...,ai,dim1,dim2) and the result
		# has the same shape
		dimensions=self.GetDimensions()
		# operator(dim1,dim2,dim1',dim2')->operator(dim1*dim2,dim1'*dim2')
		operatormatrix=MatrixFromBlockForm(operatorrepresentation,dimensions)
		stateshape=state.shape[:-2]
		# Number of states a1*...*ai.  It is passed as a 1-tuple because
		# VectorFromBlockForm builds the new shape by concatenating
		# vectorshape with a tuple; a bare integer raised TypeError.
		nstates=1
		for n in stateshape:
			nstates=nstates*n
		# state(a1,...,ai,dim1,dim2)->state(a1*...*ai,dim1*dim2)
		state=VectorFromBlockForm(state,dimensions,vectorshape=(nstates,))
		# state(a1*...*ai,dim1*dim2)->state(dim1*dim2,a1*...*ai)
		state=ReverseCyclicPermutation(state)
		# operate_state(dim,a1*...*ai)=operator(dim,dim)xstate(dim,a1*...*ai)
		operate_state=Numeric.matrixmultiply(operatormatrix,state)
		# operate_state(dim,a1*...*ai)->operate_state(a1*...*ai,dim)
		operate_state=CyclicPermutation(operate_state)
		# operate_state(a1*...*ai,dim1*dim2)->operate_state(a1,...,ai,dim1,dim2)
		return BlockFormFromVector(operate_state,dimensions,vectorshape=stateshape)

class FunctionSpace3D:
	"""Base class for three-dimensional function spaces built from
	one-dimensional subspaces."""

	def __init__(self,dimensions):
		self.SetDimensions(dimensions)

	def SetDimensions(self,dimensions):
		# Stored as a tuple, as in FunctionSpace2D, so that GetDimensions
		# returns a tuple whatever sequence type was passed in
		self.dimensions=tuple(dimensions)

	def GetDimensions(self):
		return self.dimensions

	def SetSubSpaces(self,subspaces):
		self.subspaces=subspaces

	def GetSubSpaces(self):
		return self.subspaces
	
class PlaneWaveSpace3D(FunctionSpace3D,PlaneWaveSpace):
	"""Three-dimensional plane-wave space built from three PlaneWaveSpace1D
	subspaces, with basis functions 1/sqrt(N1*N2*N3)*exp(iG_dot_r)."""

	def __init__(self,dimensions):
		# Initializing subspaces
		space1=PlaneWaveSpace1D(length=dimensions[0])
		space2=PlaneWaveSpace1D(length=dimensions[1])
		space3=PlaneWaveSpace1D(length=dimensions[2])
		self.SetSubSpaces((space1,space2,space3))
		FunctionSpace3D.__init__(self,dimensions=dimensions)

	def GetCoordinateBasis(self):
		# Returns an array of shape (3,dim1,dim2,dim3): element [k] holds
		# the G_{k+1} label of each basis function
		space1,space2,space3=self.GetSubSpaces()
		dimension1,dimension2,dimension3=self.GetDimensions()
		# Creating basis for G1
		basis1=Numeric.add.outer(space1.GetCoordinateBasis(),Numeric.zeros((dimension2,dimension3)))
		# Creating basis for G2
		basis2=Numeric.add.outer(Numeric.zeros((dimension1,)),space2.GetCoordinateBasis())
		basis2=Numeric.add.outer(basis2,Numeric.zeros((dimension3,)))
		# Creating basis for G3
		basis3=Numeric.add.outer(Numeric.zeros((dimension1,dimension2)),space3.GetCoordinateBasis())
		# Combining
		return Numeric.array((basis1,basis2,basis3))

	def Translation_Operate(self,state,coordinates):
		# Multiplies the coordinates by the outer product of the three 1D
		# translation phases exp(-i*Gk*xk)
		space1,space2,space3=self.GetSubSpaces()
		translation=Numeric.multiply.outer(space1._GetTranslationDiagonal([coordinates[0]]),space2._GetTranslationDiagonal([coordinates[1]]))
		translation=Numeric.multiply.outer(translation,space3._GetTranslationDiagonal([coordinates[2]]))
		return Numeric.multiply(translation,state)

	def CoordinatesFromFunctionValues(self,functionvalues):
		# Function values (a,x_i,x_j,x_k) -> coordinates (a,G_i,G_j,G_k)
		# for the basis 1/sqrt(N1*N2*N3)*exp(iG_dot_r); the three last
		# axes are transformed
		normalization=Numeric.sqrt(Numeric.multiply.reduce(self.GetDimensions()))
		# asarray so that sequences are accepted (.shape is used below)
		functionvalues=Numeric.asarray(functionvalues)
		coordinates=FFT.inverse_fft(FFT.inverse_fft(FFT.inverse_fft(functionvalues,n=functionvalues.shape[-1],axis=-1),n=functionvalues.shape[-2],axis=-2),n=functionvalues.shape[-3],axis=-3)
		Numeric.multiply(coordinates,normalization,coordinates)
		return coordinates

	def FunctionValuesFromCoordinates(self,coordinates):
		# Coordinates (a,G_i,G_j,G_k) -> function values (a,x_i,x_j,x_k)
		# for the basis 1/sqrt(N1*N2*N3)*exp(iG_dot_r); the three last
		# axes are transformed (a 2D FFT over the last two, then the third)
		normalization=Numeric.sqrt(Numeric.multiply.reduce(self.GetDimensions()))
		# asarray so that sequences are accepted (.shape is used below)
		coordinates=Numeric.asarray(coordinates)
		functionvalues=FFT.fft(FFT.fft2d(coordinates,coordinates.shape[-2:],axes=(-2,-1)),n=coordinates.shape[-3],axis=-3)
		Numeric.divide(functionvalues,normalization,functionvalues)
		return functionvalues






