"""Module for KPoints

This module contains the base classes and functions for generating and
representing k points.
"""

from ASE.Utilities import Vector
from UserList import UserList
import Numeric
import copy

# Modules in use:
# Scientific.IO.NetCDF

def MonkhorstPackKPointGridGenerator(nkpoints,unitcell):
    """Function for generating Monkhorst-Pack k points.

    This function can be used for generating a list of special k points in
    the Brillouin zone according to the Monkhorst-Pack scheme, see Phys. Rev. B
    13, 5188 (1976). The number of k points along each direction of the
    reciprocal cell must be specified along with the unit cell (in real-space).

    The output generated by this function is a list of instances of the class
    'KPoint'. The space of the k points is the reciprocal cell of 'unitcell',
    i.e. their coordinates are given in terms of the reciprocal cell basis.

    **An example**

    To generate a list containing 12x12x12 Monkhorst-Pack special k points in
    a fcc lattice write:

    'kpointlist=MonkhorstPackKPointGridGenerator(nkpoints=(12,12,12),unitcell=fcclattice)'

    where 'fcclattice' is an instance of 'BravaisLattice' representing the
    (real-space) fcc lattice vectors.
    """
    # Number of kpoints along each direction (must be exactly three values):
    N1,N2,N3=nkpoints
    # GENERATING THE MONKHORST-PACK GRID IN UNITS OF THE RECIPROCAL UNIT CELL:
    # Along direction i the n_i'th coordinate (n_i=0,...,N_i-1) is
    #   k_i=-1/2+(2*N_i)^{-1}+n_i/N_i=(1+2*n_i-N_i)/(2*N_i),
    # where N_i is the number of k points along direction i.
    reclattice=unitcell.GetReciprocalBravaisLattice()
    listofkpoints=[]
    for n1 in range(N1):
        kp1=float(1+2*n1-N1)/(2*N1)
        for n2 in range(N2):
            kp2=float(1+2*n2-N2)/(2*N2)
            for n3 in range(N3):
                kp3=float(1+2*n3-N3)/(2*N3)
                listofkpoints.append(KPoint(coordinates=(kp1,kp2,kp3),space=reclattice))
    return listofkpoints

def ChadiCohenKPointGridGenerator(type,unitcell):
    """Function for generating Chadi-Cohen k points.

    This function can be used to generate a list of special k points in the
    Brillouin zone according to the Chadi-Cohen scheme, see Phys. Rev. B
    8, 5747 (1973). The type of Chadi-Cohen k point grid along with the unit
    cell (in real-space) must be specified.

    The following types of Chadi-Cohen k point grids are recognized:

    * type=(6,"1x1"): 6 special k points in 1x1 symmetry.

    * type=(12,"2x3"): 12 special k points in 2x3 symmetry.

    * type=(18,"sq3xsq3"): 18 special k points in sq(3)xsq(3) symmetry.

    * type=(18,"1x1"): 18 special k points in 1x1 symmetry.

    * type=(54,"sq3xsq3"): 54 special k points in sq(3)xsq(3) symmetry.

    * type=(54,"1x1"): 54 special k points in 1x1 symmetry.

    * type=(162,"1x1"): 162 special k points in 1x1 symmetry.

    'type' may be any two-element sequence (tuple or list). A KeyError is
    raised if it does not name one of the grids above.

    Note that these currently defined k point grids are suitable for slab
    calculations only.

    The output generated by this function is a list of instances of the class
    'KPoint'. The coordinates of these k points are specified in terms of the
    reciprocal cell.

    **An example**

    To generate 54 special k points in 1x1 symmetry write:

    'kpointlist=ChadiCohenKPointGridGenerator(type=(54,"1x1"),unitcell=slabunitcell)'

    where 'slabunitcell' is an instance of 'BravaisLattice' representing the
    unit cell for the slab calculation.
    """
    # Importing grids from database:
    from Structures import ChadiCohenKPointGrids
    # The type is specified according to the following convention:
    # type=(nkpoints,shape), where nkpoints is the number of CC k points and
    # shape (of type string) is the shape of the grid.
    # Mapping type to the name of the corresponding attribute of the
    # ChadiCohenKPointGrids module:
    variablemap={(6,'1x1'):'CC6_1x1',
                 (12,'2x3'):'CC12_2x3',
                 (18,'sq3xsq3'):'CC18_sq3xsq3',
                 (18,'1x1'):'CC18_1x1',
                 (54,'sq3xsq3'):'CC54_sq3xsq3',
                 (54,'1x1'):'CC54_1x1',
                 (162,'1x1'):'CC162_1x1'}
    try: # Is the type of CC grid defined ?
        # tuple() lets an (unhashable) list such as [54,"1x1"] be used too;
        # values that cannot be converted are reported like unknown types.
        variablename=variablemap[tuple(type)]
    except (KeyError,TypeError):
        raise KeyError("The specified type of ChadiCohen k point grid is not defined.")
    # Retrieving k points from the ChadiCohenKPointGrids module
    kpointgrid=getattr(ChadiCohenKPointGrids,variablename)
    # GENERATING THE LIST OF CHADICOHEN K POINTS FROM THE COORDINATES
    reclattice=unitcell.GetReciprocalBravaisLattice()
    return [KPoint(coordinates=kpoint,space=reclattice) for kpoint in kpointgrid]

def BandLineGenerator(kpointlist,numberofkpoints,unitcell):
    """Function for generating a bandline of k points

    This function can be used to generate a bandline of k points between two
    (or several) symmetry points in the Brillouin zone.

    The output generated by this function is a list of instances of the class
    'KPoint'. The coordinates of these k points will be given in terms of the
    reciprocal lattice vectors of the specified unit cell.

    'numberofkpoints' is either a single number, used for every segment
    (consecutive pair of symmetry points), or a sequence with one entry per
    segment. Each segment is sampled including both of its end points.
    A segment with a single k point yields only its start point.

    **Examples**

    To generate 100 k points between the Gamma point and the X point in a fcc
    lattice write

    'kpointlist=BandLineGenerator(kpointlist=(Gamma,Xpoint),numberofkpoints=100,unitcell=fcclattice)'

    where 'Gamma' is a vector for the Gamma point and 'Xpoint' is a vector for
    X (both instances of class 'Vector'). 'fcclattice' is an instance of
    'BravaisLattice' representing the lattice vectors of an fcc lattice.

    To generate 50 k points between the Gamma point and the X point and 30 k
    points between the X point and the W point in a fcc lattice write

    'kpointlist=BandLineGenerator(kpointlist=(Gamma,Xpoint,Wpoint),numberofkpoints=(50,30),unitcell=fcclattice)'

    where 'Gamma', 'Xpoint', and 'Wpoint' are again instances of 'Vector'
    representing the positions of the symmetry points. Note also that the
    number of k points in this case will be 79 since the X point only appears
    once.
    """
    # Split 'numberofkpoints' into the count for the first segment and the
    # count(s) for the remaining segments. int() raises TypeError for a
    # sequence; a plain number is used for every segment.
    try:
        int(numberofkpoints)
    except TypeError:
        firstcount=numberofkpoints[0]
        remainingcounts=numberofkpoints[1:]
    else:
        firstcount=remainingcounts=numberofkpoints
    # Is there more than an initial and a final k point ?
    if len(kpointlist)>2: # Yes, PROPAGATE THE GENERATION
        newkpointlist=BandLineGenerator(kpointlist[0:2],firstcount,unitcell)
        otherkpointlist=BandLineGenerator(kpointlist[1:],remainingcounts,unitcell)
        # Combine the results. The remaining segments start where the first
        # one ends, so their first k point is skipped to avoid duplicates.
        newkpointlist.extend(otherkpointlist[1:])
        return newkpointlist
    # No, DO THE GENERATION for a single segment
    npoints=int(firstcount)
    reclattice=unitcell.GetReciprocalBravaisLattice()
    kpbegin=kpointlist[0].GetCartesianCoordinates()
    kpfinal=kpointlist[1].GetCartesianCoordinates()
    # float() guards against truncating integer division (Python 2) should
    # the coordinates be integer-valued. max() avoids a ZeroDivisionError for
    # a one-point segment, where only kpbegin is produced anyway.
    step=(kpfinal-kpbegin)/float(max(npoints-1,1))
    listofkpoints=[]
    # STARTING K POINT GENERATION
    for nkp in range(npoints):
        newkpoint=KPoint(space=reclattice)
        newkpoint.SetCartesianCoordinates(kpbegin+step*nkp)
        listofkpoints.append(newkpoint)
    return listofkpoints

class KPoints(UserList):
    """Base class for representing a collection of k points

    This class can be used to represent a collection of k points. It behaves
    as a list and thus supports slicing operations and the addition of
    different instances of 'KPoints'. This class also supports visualization
    with VTK.

    To create an instance of this class write:

    'kpoints=KPoints(kpointlist=kpointlist,unitcell=mycell)'

    where 'kpointlist' is a sequence of k points. These k points are in turn
    expected to be instances of 'KPoint'. 'mycell' must be an instance of
    'BravaisLattice' (defined in 'Structures.VectorSpaces'). If no unit cell
    is given, 'GetUnitCell' returns None until 'SetUnitCell' is called.
    """

    def __init__(self,kpointlist=None,unitcell=None):
        UserList.__init__(self,kpointlist)
        # Always define the attribute so that 'GetUnitCell' (and therefore
        # '+') works even when no unit cell has been supplied:
        self.unitcell=None
        if unitcell is not None:
            self.SetUnitCell(unitcell)

    def __getslice__(self, i, j):
        """Reimplemented from UserList

        Returns a copy of this collection (keeping the unit cell and any other
        attributes) holding only the k points self[i:j].
        """
        # Negative bounds are clamped to 0, as UserList.__getslice__ does.
        i = max(i, 0); j = max(j, 0)
        newkpoints=copy.copy(self)
        newkpoints.data=self.data[i:j]
        return newkpoints

    def __add__(self,other):
        """Reimplemented from UserList

        Returns a new collection with the k points of 'self' followed by those
        of 'other'. Neither operand is modified.

        Raises ValueError if 'other' is not of the same class, or if the two
        collections do not have equal unit cells.
        """
        if self.__class__ is not other.__class__:
            raise ValueError("It is only possible to add KPoints of the same class.")
        # 'not ==' rather than '!=': in Python 2 a class defining only __eq__
        # does not get a consistent __ne__.
        if not (self.GetUnitCell()==other.GetUnitCell()):
            raise ValueError("The two k point collections must be defined with the same unit cell.")
        newkpoints=copy.copy(self)
        # copy.copy is shallow and would share 'data' with self; build a new
        # list so that adding never modifies the left operand.
        newkpoints.data=self.data+other.data
        return newkpoints

    def SetKPointList(self,kpointlist):
        """Sets a list of k points

        Set the list of k points. This method is equivalent to the statement

        'mykpoints.data=list(kpointlist)'

        where 'mykpoints' is an instance of 'KPoints'. 'kpointlist' must be a
        sequence containing instances 'KPoint'.
        """
        self.data=list(kpointlist)

    def GetKPointList(self):
        """Returns the k point list (the underlying list, not a copy)"""
        return self.data

    def GetNumberOfKPoints(self):
        """Returns the number of k points"""
        return len(self.data)

    def SetUnitCell(self,unitcell):
        """Sets the unit cell

        The unit cell is expected to be an instance of the class
        'BravaisLattice' defined in 'Structures.VectorSpaces'.
        """
        self.unitcell=unitcell

    def GetUnitCell(self):
        """Returns the unit cell (None if none has been set)"""
        return self.unitcell

    def GetReciprocalUnitCell(self):
        """Returns the reciprocal cell

        This method will return the reciprocal cell of the specified unitcell.
        It will be an instance of 'BravaisLattice'.
        """
        return self.unitcell.GetReciprocalBravaisLattice()

    def GetSpace(self):
        """Returns the space for the collection of k points.

        This method is identical to the method 'GetReciprocalUnitCell'.
        """
        return self.GetReciprocalUnitCell()

    def GetCartesianCoordinates(self):
        """Returns a list with the cartesian coordinates of the k points"""
        return [kpoint.GetCartesianCoordinates() for kpoint in self]

    def GetScaledCoordinates(self):
        """Returns the scaled coordinates of k points

        These coordinates are given in units of the reciprocal lattice.
        """
        space=self.GetSpace()
        return [space.CoordinatesFromCartesianCoordinates(coor)
                for coor in self.GetCartesianCoordinates()]

    def GetKPointWeights(self):
        """Returns a list with the k point weights.

        Note that the k point weights may not always be defined.
        """
        return [kpoint.GetKPointWeight() for kpoint in self]

    def GetVTKAvatar(self,parent=None,**keywords):
        """Returns a vtkavatar visualizing the k points.

        This avatar will be an instance of 'vtkKPoints' (defined in
        'Visualization.Avatars.vtkKPoints') where the individual k points are
        visualized as spheres in the reciprocal space. For more information,
        see the documentation for this class.
        """
        from Visualization.Avatars.vtkKPoints import vtkKPoints
        return vtkKPoints(self,parent,**keywords)
            
            
class KPoint(Vector.Vector):
    """Implements a k point

    This class represents the coordinates of a k point. It implements the
    functionality of the class 'Vector' through inheritance, but may in
    addition also be specified with a k point weight. Note that this quantity
    may not always be present (or even be well-defined); 'GetKPointWeight'
    returns None when no weight has been set.
    """

    def __init__(self,coordinates=None,space=None,kpointweight=None):
        Vector.Vector.__init__(self,coordinates,space)
        if kpointweight is not None:
            self.SetKPointWeight(kpointweight)

    def SetKPointWeight(self,kpointweight):
        """Sets the k point weight

        This method can be used to associate a weight to a k point.
        This weight can be used when the Brillouin zone is integrated over a
        discrete set of k points. Note that this quantity may not be
        well-defined for all concrete implementations of this class.
        """
        self.weight=kpointweight

    def GetKPointWeight(self):
        """Returns the k point weight, or None if no weight has been set"""
        # The weight attribute only exists after SetKPointWeight; k points
        # created without a weight (e.g. by the grid generators) return None
        # instead of raising AttributeError.
        return getattr(self,'weight',None)
