#include "definitions.h"
module fftdimensions

contains 

!   entry   : setup_fft_grid
!   syntax  : call setup_fft_grid(recc,ecut_soft,ngxs,ngys,ngzs,
!                                 ecut_dense,ngx,ngy,ngz,lmastr)
!   purpose :
!         Determine the soft (wave function) and dense (density) FFT grids.
!         On the master node (lmastr) the following are read from the
!         netCDF input file:
!             ecut_soft      = netCDF variable PlaneWaveCutoff    (required)
!             ecut_dense     = netCDF variable Density_WaveCutoff (optional;
!                              defaults to ecut_soft, i.e. no double grid)
!             ngxs,ngys,ngzs = netCDF dimensions softgrid_dim1..3
!             ngx ,ngy ,ngz  = netCDF dimensions hardgrid_dim1..3
!         Both cutoffs must lie in [1,2000] eV, otherwise the program stops.
!         The minimum grids required by the cutoffs and the reciprocal unit
!         cell (recc) are found with min_fft_grid and rounded up with
!         fft_number to a transform length listed for the FFT routine used.
!         - All six grid dimensions present: the given grids are used. A
!           grid smaller than the minimum, or with lengths that fft_number
!           would change, only gives a warning.
!         - Some, but not all six, present: the program stops.
!         - None present: the minimum grids are used (cold start).
!
!         The chosen grids are then written as the netCDF dimensions
!         softgrid_dim1..3 / hardgrid_dim1..3 of the output file. In PARAL
!         builds the cutoffs and grids are passed through
!         mspack_double_scalar / mspack_integer_scalar, presumably to share
!         the master's values with the other nodes (TODO confirm).
!
!-----------------------------------------------------------------------------
subroutine setup_fft_grid(recc,ecut_soft,ngxs,ngys,ngzs,ecut_dense,ngx,ngy,ngz,lmastr)
use netcdfinterface
use run_context
#ifdef PARAL 
use par_functions_module
#endif
implicit none 

! recc  (in)           :  reciprocal unit cell (passed on to min_fft_grid)
! ecut_soft (out)      :  plane wave cutoff (eV)
! ecut_dense (out)     :  dense grid cutoff (eV); equals ecut_soft without double grid
! ngxs,ngys,ngzs (out) :  soft  fft grid returned
! ngx,ngy,ngz (out)    :  dense fft grid returned
! lmastr (in)          :  true if this is the master node (always true for the serial program)

real*8 , intent(in)  :: recc(3,3)
real*8 , intent(out) :: ecut_soft,ecut_dense  
integer, intent(out) :: ngx,ngy,ngz,ngxs,ngys,ngzs
logical*4,intent(in) :: lmastr
 
! locals 
integer :: grid(3),softgrid(3),hardgrid(3),mingrid_soft(3),mingrid_dense(3)
integer :: i,status(6),stat,ncid,nOK,n_found
logical :: found_all
#include "ms.h"

if (lmastr) then 

  ! open the netCDF input file
  stat = nf_open(netCDF_input_filename,NF_NOWRITE,ncid)
  if (stat /= nf_noerr) then 
    write(nconso,*) 'FFT: Could not open NetCDF file'
    call clexit(nconso) 
  endif

  ! the wave function cutoff is mandatory
  stat = nfget(ncid,'PlaneWaveCutoff',ecut_soft)
  if (stat /= nf_noerr) then 
    write(nconso,*) 'netCDF read error: PlaneWaveCutoff must exist'
    call clexit(nconso)
  endif

  write(nconso,'(1x,a,f12.4,a)') &
    'PAD: ecut (wave function cutoff) = ', ecut_soft, ' eV'    

  ! sanity check: accepted range is [1, 2000] eV
  if ((ecut_soft < 1.0d0) .or. (ecut_soft > 2000.0d0)) then 
    write(nconso,*) 'FFT:  netCDF PlaneWaveCutoff value out of range'
    write(nconso,*) 'FFT:  the program will stop' 
    call clexit(nconso)
  endif

  ! the density cutoff is optional; without it the dense grid uses ecut_soft
  stat = nfget(ncid,'Density_WaveCutoff',ecut_dense)
  if (stat /= nf_noerr) then
    write(nconso,*) 'FFT: Double grid not used'
    ecut_dense = ecut_soft 
  else 
    write(nconso,*) 'FFT: Double grid used' 
  endif

  if ((ecut_dense < 1.0d0) .or. (ecut_dense > 2000.0d0)) then
    write(nconso,*) 'FFT:  netCDF Density_WaveCutoff value out of range'
    write(nconso,*) 'FFT:  the program will stop'
    call clexit(nconso)
  endif           

  write(nconso,'(1x,a,f12.4,a)') &
    'PAD: ecut (wave function cutoff) = ', ecut_soft, ' eV'
  write(nconso,'(1x,a,f12.4,a)') &
    'PAD: ecut (dense grid cutoff)    = ', ecut_dense, ' eV'

  ! minimum grids, rounded up to lengths listed for the FFT routine
  call min_fft_grid(recc,ecut_soft,'WaweFct',grid,nconso)
  do i = 1,3
    mingrid_soft(i) = fft_number(grid(i)) 
  enddo 

  call min_fft_grid(recc,ecut_dense,'Density',grid,nconso)
  do i = 1,3
    mingrid_dense(i) = fft_number(grid(i))
  enddo     

  ! grids given in the input file (if any)
  status(1) = nfgetglobaldim(ncid,'softgrid_dim1',softgrid(1))
  status(2) = nfgetglobaldim(ncid,'softgrid_dim2',softgrid(2))
  status(3) = nfgetglobaldim(ncid,'softgrid_dim3',softgrid(3))
  status(4) = nfgetglobaldim(ncid,'hardgrid_dim1',hardgrid(1))
  status(5) = nfgetglobaldim(ncid,'hardgrid_dim2',hardgrid(2))
  status(6) = nfgetglobaldim(ncid,'hardgrid_dim3',hardgrid(3))

  stat = nf_close(ncid)
  if (stat /= nf_noerr) then 
    write(nconso,*) 'FFT: Could not close NetCDF file'
    call clexit(nconso) 
  endif

  n_found   = count(status == nf_noerr)
  found_all = (n_found == 6)

  call uflush(nconso)
  if (found_all) then 
    write(nconso,*) 'FFT: netCDF softgrid_dim found'
    write(nconso,'(1x,"FFT: ngx = ",i3," ngy = ",i3," ngz = ",i3)') softgrid
    write(nconso,*) 'FFT: netCDF hardgrid_dim found'
    write(nconso,'(1x,"FFT: ngxhard = ",i3," ngyhard = ",i3," ngzhard = ",i3)') hardgrid
  elseif (n_found > 0) then  
    ! a partial specification is ambiguous, so it is treated as fatal
    write(nconso,*) 'FFT: All 6 softgrid_dims and hardgrid_dims not set correctly (',status,')'
    ! TODO: use add_warning instead.
    write(nconso,*) 'FFT: The program will stop'
    call clexit(nconso)
  endif

  if (found_all) then 
    ! the given grids are used as is; problems only produce warnings
    call fft_warn_if_bad_grid('softgrid_dim',softgrid,mingrid_soft,nconso)
    call fft_warn_if_bad_grid('hardgrid_dim',hardgrid,mingrid_dense,nconso)
    ngxs = softgrid(1)
    ngys = softgrid(2)
    ngzs = softgrid(3)
    ngx  = hardgrid(1)
    ngy  = hardgrid(2)
    ngz  = hardgrid(3)
  else 
    ! cold start: use the minimum grids
    write(nconso,'(" FFT: Using default minimum soft grid (",i3,"x",i3,"x",i3,")")') mingrid_soft
    ngxs = mingrid_soft(1)
    ngys = mingrid_soft(2)
    ngzs = mingrid_soft(3)
    write(nconso,'(" FFT: Using default minimum dense grid (",i3,"x",i3,"x",i3,")")') mingrid_dense
    ngx = mingrid_dense(1)
    ngy = mingrid_dense(2)
    ngz = mingrid_dense(3)
  endif

  ! write the chosen grid dimensions to the output netCDF file
  stat = nf_open(netCDF_output_filename,NF_WRITE,ncid)
  if (stat /= nf_noerr) then
    write(nconso,*) 'FFT: Could not open output NetCDF file'
    call clexit(nconso)
  endif

  status(1) = nfputglobaldim(ncid,'softgrid_dim1',ngxs)
  status(2) = nfputglobaldim(ncid,'softgrid_dim2',ngys)
  status(3) = nfputglobaldim(ncid,'softgrid_dim3',ngzs)
  status(4) = nfputglobaldim(ncid,'hardgrid_dim1',ngx)
  status(5) = nfputglobaldim(ncid,'hardgrid_dim2',ngy)
  status(6) = nfputglobaldim(ncid,'hardgrid_dim3',ngz)
  ! an existing dimension of the same size is accepted; anything else is fatal
  do i = 1,6
    if ((status(i) /= nfif_OK) .and. (status(i) /= nfif_dimexist_butOKsize)) then
      write(nconso,*) 'FFT: netCDF interface error writing softgrid_dim/hardgrid_dim (errorcode:',status(i),')'
      write(nconso,*) 'FFT: the program will stop'
      call clexit(nconso)
    endif
  enddo

  stat = nf_close(ncid)
  if (stat /= nf_noerr) then 
    write(nconso,*) 'FFT: Could not close NetCDF file'
    call clexit(nconso) 
  endif

endif  ! lmastr

#ifdef PARAL
! send ecut_soft,ngxs,ngys,ngzs,ecut_dense,ngx,ngy,ngz
call mspack_double_scalar(nconso,ANY,MSG_SETUP,&
        REAL8,ecut_dense,   1, nOK)
call mspack_double_scalar(nconso,ANY,MSG_SETUP,&
        REAL8,ecut_soft,   1, nOK)

call mspack_integer_scalar (nconso,ANY, MSG_SETUP,&
        INTEGER4, ngxs,  1, nOK)
call mspack_integer_scalar (nconso,ANY, MSG_SETUP,&
        INTEGER4, ngys,  1, nOK)
call mspack_integer_scalar (nconso,ANY, MSG_SETUP,&
        INTEGER4, ngzs,  1, nOK)

call mspack_integer_scalar (nconso,ANY, MSG_SETUP,&
        INTEGER4, ngx,  1, nOK)
call mspack_integer_scalar (nconso,ANY, MSG_SETUP,&
        INTEGER4, ngy,  1, nOK)
call mspack_integer_scalar (nconso,ANY, MSG_SETUP,&
        INTEGER4, ngz,  1, nOK)
#endif

end subroutine setup_fft_grid

!-----------------------------------------------------------------------------
! fft_warn_if_bad_grid: write warnings (never stop) for a grid read from the
! netCDF file that is smaller than the required minimum grid, or whose
! lengths would be changed by fft_number.
!   dimname (in) : netCDF dimension base name used in the messages
!   given   (in) : grid read from the netCDF file
!   minimum (in) : minimum required grid (already rounded with fft_number)
!   nconso  (in) : output unit for the warnings
!-----------------------------------------------------------------------------
subroutine fft_warn_if_bad_grid(dimname,given,minimum,nconso)
implicit none
character(len=*), intent(in) :: dimname
integer,          intent(in) :: given(3),minimum(3),nconso

! locals
integer :: optimal(3),i

if (any(given < minimum)) then
  write(nconso,*) 'FFT: Warning: Too small fft grid specified in '//dimname
  write(nconso,'("FFT: Minimum required grid is (",i3,"x",i3,"x",i3,")")') minimum
endif

! the given grid should already be a length listed for the FFT routine
do i = 1,3
  optimal(i) = fft_number(given(i))
enddo
if (any(optimal /= given)) then
  write(nconso,*) 'FFT: Warning : Grid specified in '//dimname//' is not optimal for the FFT routine used'
  write(nconso,'("FFT: next optimal grid is (",i3,"x",i3,"x",i3,")")') optimal
endif

end subroutine fft_warn_if_bad_grid


!=====================================================================
      subroutine min_fft_grid(recc,ecut,grid_description,grid,nconso) 
!=====================================================================
!
!     Determine the smallest FFT grid enclosing all G-vectors with
!     hsqdtm*|G|^2 < 4*ecut; ngdens counts those G-vectors.
!
!     Method: the index box -n(d)..n(d) (d = 1,2,3) starts at n = 0.
!     For each direction d in turn the layer index +-(n(d)+1) is tried,
!     with the other two indices running over the current box. The
!     layer is accepted if any of its G-vectors is inside the sphere.
!     The search ends after a sweep over all three directions accepts
!     no layer. Each half-width is then padded by one and doubled, so
!     grid(d) = 2*(n(d)+1) is always even. Only k = 0 is considered.
!
!       recc             (in)  : reciprocal space unit cell; the G-vector
!                                for given indices is built by get_gvector
!                                from rows of recc
!       ecut             (in)  : kinetic energy cutoff for plane waves (eV)
!       grid_description (in)  : grid id, printed in the log (format a7)
!       grid             (out) : returned minimum fft grid
!       nconso           (in)  : output unit for the log lines
!
!=======================================================================          
      implicit none 
      real*8,           intent(in)  :: recc(3,3), ecut
      character(len=*), intent(in)  :: grid_description
      integer,          intent(out) :: grid(3)
      integer,          intent(in)  :: nconso

!     converts |G|^2 to energy; presumably hbar^2/(2 m_e) in eV*Angstrom^2,
!     which would make recc 1/Angstrom including 2*pi - TODO confirm
      real*8, parameter :: hsqdtm = 3.810033d0
      character(len=*), parameter :: fmt_grid = &
        '(1x,"PAD: minimum grid enclosing the sphere G^2 < 4*E_cut(",a7,"): ",i3,"x",i3,"x",i3)'

!     locals 
      integer :: half_width(3), trial, idir, other1, other2
      integer :: ia, ib, ic, ngdens
      real*8  :: g_vec(3), g_energy
      logical :: grew_any, grew_this

!     G = 0 is inside the sphere but never visited by the layer search
      ngdens = 1
      half_width = 0

      grew_any = .true.
      do while (grew_any)
        grew_any = .false.

        do idir = 1, 3
!         the two directions spanning the trial layer, in increasing order
          select case (idir)
          case (1)
            other1 = 2
            other2 = 3
          case (2)
            other1 = 1
            other2 = 3
          case default
            other1 = 1
            other2 = 2
          end select

          trial = half_width(idir) + 1
          grew_this = .false.

!         ia visits only the two new layers, -trial and +trial
          do ia = -trial, trial, 2*trial
            do ib = -half_width(other1), half_width(other1)
              do ic = -half_width(other2), half_width(other2)
                call get_gvector(recc,ia,ib,ic,idir,other1,other2,g_vec)
                g_energy = hsqdtm*(g_vec(1)**2 + g_vec(2)**2 + g_vec(3)**2)
                if (g_energy < 4*ecut) then 
                  ngdens    = ngdens + 1
                  grew_this = .true.
                endif
              enddo   ! ic (direction other2)
            enddo     ! ib (direction other1)
          enddo       ! ia (direction idir)

          if (grew_this) then 
!           accept the new layer
            half_width(idir) = trial
            grew_any = .true.
          endif
        enddo  ! idir
      enddo    ! do while 

!     pad by one point on each side and convert half-widths to grid sizes
      grid = 2*(half_width + 1)

      write(nconso,fmt_grid) grid_description, grid

      write(nconso,*) "PAD: Number of G-vectors (ngdens) below 4*ecut : ", ngdens
      call uflush(nconso)

      end subroutine min_fft_grid

!==========================================================================
      subroutine get_gvector(recc,ia,ib,ic,dim1,dim2,dim3,g_vec) 
!==========================================================================
!
!     Build the reciprocal-lattice vector with integer coefficients
!     ia, ib, ic on the rows dim1, dim2, dim3 of recc:
!         g_vec = ia*recc(dim1,:) + ib*recc(dim2,:) + ic*recc(dim3,:)
!     The row selectors let the caller (min_fft_grid) pair ia with
!     whichever direction it is currently growing.
!
      implicit none
      real*8,  intent(in)  :: recc(3,3) 
      integer, intent(in)  :: ia, ib, ic, dim1, dim2, dim3
      real*8,  intent(out) :: g_vec(3) 

!     parenthesised to keep the original summation order bit-for-bit
      g_vec = (ia*recc(dim1,:) + ib*recc(dim2,:)) + ic*recc(dim3,:)

      end subroutine get_gvector

!=========================================================================
function fft_number(min_fft)
implicit none
 
!     fft_number returns the smallest transform length >= min_fft taken
!     from the table for the FFT library selected at compile time
!     (preprocessor symbol FFTW, ESSL, DECALPHA, ULTRA_ZFFT or SGI,
!     tested in that order). Each table is sorted in increasing order.
!     If min_fft is larger than every entry of the table, min_fft itself
!     is returned, so the result is never smaller than the request.
!     If none of the symbols is defined, min_fft is returned unchanged.

      integer              ::  fft_number 
      integer, intent(in)  ::  min_fft
 
!     locals
      integer :: i

#ifdef FFTW
!     FFTW: table of even transform lengths up to 640
      integer, parameter :: number_in_list = 66
      integer, parameter :: list_of_fft(number_in_list) = (/                &
                        2,  4,  6,  8, 10, 12, 14, 16, 18, 20,               &
                       22, 24, 28, 30, 32, 36, 40, 42, 44, 48,               &
                       56, 60, 64, 66, 70, 72, 80, 84, 88, 90,               &
                       96,108,110,112,120,126,128,132,140,144,154,           &
                      160,168,176,180,192,198,200,                           &
                      216,240,264,270,280,288,324,352,360,378,384,400,432,   &
                      450,480,540,576,640 /)
#elif ESSL
!     ESSL transform
!            h  i  j  k   m
!     n =   2  3  5  7  11   , h=1,2,..,25; i=0,1,2  ;  j,k,m=0,1
!     32,64 and 128 should maybe be avoided.
      integer, parameter :: number_in_list = 46
      integer, parameter :: list_of_fft(number_in_list) = (/                &
                        2,  4,  6,  8, 10, 12, 14, 16, 18, 20,               &
                       22, 24, 28, 30, 32, 36, 40, 42, 44, 48,               &
                       56, 60, 64, 66, 70, 72, 80, 84, 88, 90,               &
                       96,110,112,120,126,128,132,140,144,154,               &
                      160,168,176,180,192,198 /)
#elif DECALPHA
!     dxml transform: table of listed lengths from 8 to 576
      integer, parameter :: number_in_list = 33
      integer, parameter :: list_of_fft(number_in_list) = (/                &
                        8, 12, 16, 24, 32, 36, 48,                           &
                       64, 72, 96,108,128,144,160,176,180,192,200,           &
                      216,240,270,280,288,324,360,378,384,400,432,           &
                      450,480,540,576 /)
#elif ULTRA_ZFFT
!     zfft transform
!            h  i  j  k 
!     n =   2  3  5  7      , h=1,2,..,25; i,j=0,1,2,3,4  ;  k=0,1
      integer, parameter :: number_in_list = 72
      integer, parameter :: list_of_fft(number_in_list) = (/                &
                        4,  5,  6,  7,  8,  9, 10, 12, 14, 15,               &
                       16, 18, 20, 21, 24, 25, 27, 28, 30, 32,               &
                       35, 36, 40, 42, 45, 48, 50, 54, 56, 60,               &
                       63, 64, 70, 72, 75, 80, 81, 84, 90, 96,               &
                      100,105,108,112,120,125,126,128,135,140,               &
                      144,150,160,162,168,175,180,189,192,200,               &
                      210,216,224,225,240,250,252,256,270,280,               &
                      288,300 /)
#elif SGI
!     (SGI interface by Sergei Dudiy)     
!     zfft transform
!            h  i  j  k
!     n =   2  3  5  7      , h=1,2,..,25; i,j=0,1,2,3,4  ;  k=0,1
      integer, parameter :: number_in_list = 72
      integer, parameter :: list_of_fft(number_in_list) = (/                &
                        4,  5,  6,  7,  8,  9, 10, 12, 14, 15,               &
                       16, 18, 20, 21, 24, 25, 27, 28, 30, 32,               &
                       35, 36, 40, 42, 45, 48, 50, 54, 56, 60,               &
                       63, 64, 70, 72, 75, 80, 81, 84, 90, 96,               &
                      100,105,108,112,120,125,126,128,135,140,               &
                      144,150,160,162,168,175,180,189,192,200,               &
                      210,216,224,225,240,250,252,256,270,280,               &
                      288,300 /)
#else
!     Other FFT implementations should add a similar table above,
!     listing their allowed grid sizes. Without a table every length
!     is accepted as is.
      integer, parameter :: number_in_list = 1
      integer, parameter :: list_of_fft(number_in_list) = (/ 0 /)
      fft_number = min_fft
      return
#endif

!     Beyond the end of the table: keep the requested length instead of
!     returning the (smaller) last table entry.
      if (min_fft > list_of_fft(number_in_list)) then
         fft_number = min_fft
         return
      endif

!     first table entry >= min_fft (it exists because of the check above)
      do i = 1, number_in_list
         if (list_of_fft(i) >= min_fft) exit
      enddo

      fft_number = list_of_fft(i)

      end function fft_number

         end module fftdimensions
           
