! Residual minimization method, direct inversion in the iterative 
! subspace 
!   (1) Subspace rotation 
!   (2) RMM-DIIS 
!   (3) Orthonormalization 
! 
#include "definitions.h" 
#define PARAL_ARGS "parallel_args.h"
#define PARAL_DECL "parallel_decl.h"  

       module module_rmm_diss

#ifdef PARAL
       use par_functions_module
#endif PARAL

       logical*4, save             :: rmm_init=.true.
       logical*4, save,allocatable :: init_becp(:)

       contains 

       subroutine rmm_diss(cptwfp, nplwkp,eigen,&
#include "apply_h_args.h"
                          , timer)         

      ! One RMM-DIIS sweep over all bands of the current k-point (nkp):
      !   (1) subspace rotation: build <b|H|b> and <b|S|b>, hand them to
      !       cdiaghg and rotate cptwfp, (H-Vnl)|psi> and becp with the
      !       returned matrix U;
      !   (2) block by block, form the residual for every band,
      !       precondition it into a trial vector psi1, and mix psi and
      !       psi1 using the first eigenvector of a 2x2 generalized
      !       eigenproblem built from <Ri|Rj> and <psi_i|S|psi_j>;
      !   (3) S-orthonormalize the updated wavefunctions.
      !
      ! Arguments declared in this routine:
      !   cptwfp  wavefunctions (nrplwv,nbands), updated in place
      !   nplwkp  plane-wave count handed on to the helper routines; it is
      !           used as the local vector length in the serial build
      !   eigen   eigenvalue array (nbands); passed to cdiaghg and then
      !           incremented by the estimated change deigen
      !   timer   timer array handed to time_start/time_stop and helpers
      ! All other arguments are supplied by the apply_h include files.

      use netcdfinterface
      use run_context
      use us_hpsi_module 
      use van_us_data_module, only : nkbmax,becp
      implicit none    

#include "apply_h_decl.h"
      complex*WF_PRECISION :: cptwfp(nrplwv,nbands)
      integer nplwkp 
      real*8  eigen(nbands) 
      real*8 timer(*)

      ! locals 
      integer  :: block_size
      complex*WF_PRECISION,allocatable :: Rpsi(:,:),Rpsi1(:,:) 
      complex*WF_PRECISION,allocatable :: spsi(:,:),hpsi(:,:)
      complex*WF_PRECISION,allocatable :: psi1(:,:),spsi1(:,:)
      complex*WF_PRECISION,allocatable :: h_m_vnlpsi(:,:)

      complex*16 hmat(nbands,nbands),smat(nbands,nbands)
      complex*16 U(nbands,nbands)
      complex*16 r11,r22,r12,s11,s22,s12
      complex*16 lambda(2,2),e1,e12,e21,s1
      real*8     eig(2)
      complex*16 smat2(2,2),rmat2(2,2)
      integer    n,i,info,j,n1,nb,nn,nblock,actual_block_size
      integer    n_local
      complex*16 becp_psi(nkbmax,nbands) 
      real*8     deigen(nbands) 
      logical    exists
#ifdef PARAL
      integer offset(par_pw_np),nlocal(par_pw_np),irank
#endif                

      ! first call: mark becp as not yet initialized for every k-point
      if (rmm_init) then 
        if (.not.allocated(init_becp)) allocate(init_becp(nkprun))
        init_becp(:) = .true. 
        rmm_init = .false.
      endif

#ifdef PARAL 
      ! find length of wavefunction array for this process
      call  par_defwfk (nrplwv,nrplwv_global,nplwkp,nlocal,&
                        offset,n_local,&
#include PARAL_ARGS
        , nconso)
      n_global= nplwkp   ! parallel_args
      ! get the rank of this processor
      call par_rank_pw(irank,&
#include PARAL_ARGS
        ,nconso   )
      block_size = par_pw_np
#else 
      n_local = nplwkp 
      block_size = min(nbands,12)
#endif
      allocate(Rpsi(nrplwv,block_size),Rpsi1(nrplwv,block_size))
      allocate(spsi(nrplwv,block_size),hpsi(nrplwv,block_size))
      allocate(psi1(nrplwv,block_size),spsi1(nrplwv,block_size))
      allocate(h_m_vnlpsi(nrplwv,nbands))
 
      h_m_vnlpsi(:,:) = 0.0d0

      ! subspace rotation 

      ! setup <b|H|b> (hmat) and <b|S|b> (smat); (H-Vnl)|psi> for all
      ! bands is returned in h_m_vnlpsi and reused for the residuals
      call setup_bhb_bsb(cptwfp,nplwkp,nbands,hmat,smat,&
#include "apply_h_args.h"
             ,timer,nconso,h_m_vnlpsi=h_m_vnlpsi)
      init_becp(nkp) = .false.

      !  diagonalize
      call time_start(timer,TORTHGN)
      call cdiaghg(nbands,hmat,smat,nbands,eigen,U )     
      call time_stop(timer,TORTHGN)

      ! vector rotate (b(:,:) = b(:,:)*u(:,:)) cptwfp
      call vec_rotate (cptwfp,nrplwv,n_local,nbands,nbands,U,nbands,&
                       timer)      

      ! vector rotate h_m_vnlpsi
      call vec_rotate(h_m_vnlpsi,nrplwv,n_local,nbands,nbands,U,&
                      nbands,timer)      

      ! vector rotate becp 
      call vec_rotate_dcmplx(becp(1,1,nkp),nkbmax,nbands,nbands,U,&
                     nbands,timer)                    

      ! set change of eigenvalues to zero for all bands
      deigen(:) = 0.0d0

      nblock = (nbands-1)/block_size + 1

      ! RMM-DIIS
      ! loop over all blocks with size block_size
      ! (n is the global index of the first band of the current block)
      n = 1
      do nb = 1,nblock

        ! find actual blocksize
        actual_block_size = get_actual_block_size(nbands,block_size,nb) 

        ! loop over block to form preconditioned residual
        do nn = 1,actual_block_size

          n1 = n + nn -1

       
          ! get residual, first only (H-Vnl)|cptwfp>
          Rpsi(:,nn) = h_m_vnlpsi(:,n1)

          ! add (Vnl-eS)|psi> (reusing becp) 
          call add_vnl_eS_psi(cptwfp(1,n1),nplwkp,eigen(n1),n1,&
                            Rpsi(1,nn),&
#include                    "apply_h_args.h"
                            ,timer)   

          ! save becp belonging to initial psi
          ! (presumably apply_H below overwrites becp(:,n1,nkp) with
          !  the projections of psi1 - confirm against apply_H)
          becp_psi(:,n1) = becp(:,n1,nkp) 
       
          ! precondition residual K(Rpsi) equal to the trial vector psi1
          ! (spsi not used)
          call precondition(.true.,nplwkp,cptwfp(1,n1),nplwkp,spsi,&
            nplwkp,eigen(n1),psi1(1,nn),nplwkp,Rpsi(1,nn),nplwkp,&
#           include "apply_h_args.h"
           ,timer,nconso,info)

        enddo  ! nn = 1,actual_block_size

        ! H|psi1> and S|psi1> for the whole block of trial vectors
        call apply_H(psi1,hpsi,spsi1,&
                  nplwkp,actual_block_size,n,.false.,&
#                 include "apply_h_args.h"
                 ,timer)                         

        do nn = 1,actual_block_size

          n1 = n + nn -1

          ! get residual for psi1 in Rpsi1
          call get_residual(psi1(1,nn),spsi1(1,nn),hpsi(1,nn),n_local,&
                       eigen(n1),Rpsi1(1,nn),timer,nconso) 

          ! build 2x2 matrices <Ri|Rj> and <psi_i|S|psi_j>
          r11 = 0.0d0; r22 = 0.0d0 ; r12 = 0.0d0
          s11 = 0.0d0; s22 = 0.0d0 ; s12 = 0.0d0
          s11 = 1.0d0   ! already S-normalized
#ifdef PARAL
            ! only one rank contributes the constant so that the
            ! parallel sum below does not count it several times
            if (irank.eq.1) then
              s11 = 1.0d0
            else
              s11 = 0.0d0
            endif
#endif                                         
          ! NOTE(review): s12 accumulates psi*conjg(S psi1), which is the
          ! complex conjugate of the form used for r12
          ! (conjg(R)*R1). The deigen expression below pairs smat2(2,1)
          ! with conjg(lambda(1,1))*lambda(2,1) accordingly; confirm
          ! that this convention matches what cdiaghg expects.
          do i = 1,n_local
            r11 = r11 +  conjg(Rpsi(i,nn))*Rpsi(i,nn)
            r22 = r22 +  conjg(Rpsi1(i,nn))*Rpsi1(i,nn)
            r12 = r12 +  conjg(Rpsi(i,nn))*Rpsi1(i,nn)
            s22 = s22 +  conjg(psi1(i,nn))*spsi1(i,nn)
            s12 = s12 +  cptwfp(i,n1)*conjg(spsi1(i,nn))
!           s12 = s12 +  conjg(psi1(i,nn))*spsi(i,nn)
          enddo 
          rmat2(1,1) = r11 ; rmat2(2,2) = r22
          rmat2(1,2) = r12 ; rmat2(2,1) = conjg(r12)   
          smat2(1,1) = s11 ; smat2(2,2) = s22
          smat2(1,2) = s12 ; smat2(2,1) = conjg(s12)   

          ! parallel sum rmat and smat
#ifdef PARAL 
          call par_sum_complex('A',rmat2,rmat2,2,2,2,&
#include PARAL_ARGS
                               ,timer ) 
          call par_sum_complex ('A',smat2,smat2,2,2,2,&
#include PARAL_ARGS
                               ,timer ) 
#endif

!           do i = 1,2
!             write(nconso,100) 'DIIS-RMAT ',i,(rmat2(i,j),j=1,2)
!           enddo
!           do i = 1,2
!             write(nconso,100) 'DIIS-SMAT ',i,(smat2(i,j),j=1,2),
!     &                          eigen(n1)
!           enddo
100       format(1x,a12,1x,i2,1x,5(f10.4,1x))

          ! solve generalized eigen-value problem
          call cdiaghg(2,rmat2,smat2,2,eig,lambda) 
!          write(nconso,110) 'DIIS-EIG 1',n1,eig(1),(lambda(i,1),i=1,2)
!         write(nconso,110) 'DIIS-EIG 2',n1,eig(2),(lambda(i,2),i=1,2)
110       format(1x,a12,1x,i2,1x,f10.4,1x,4(f10.4,1x))

          ! local parts of <psi1|R1> (e1) and <psi|R1> (e12), needed
          ! for the eigenvalue change.  They must be accumulated while
          ! cptwfp(:,n1) still holds the ORIGINAL psi, i.e. before the
          ! wavefunction update below, because smat2 and the deigen
          ! expression refer to the original psi as well.
          e1 = 0.0d0 ; e12 = 0.0d0
          do i = 1,n_local
            e1 = e1 + conjg(psi1(i,nn))    *Rpsi1(i,nn) 
            e12=e12 + conjg(cptwfp(i,n1))*Rpsi1(i,nn)
          enddo
          e21 = conjg(e12)

          ! update wavefunction, using first eigenvector returned
          cptwfp(1:n_local,n1)=lambda(1,1)*cptwfp(1:n_local,n1)&
                             +lambda(2,1)*psi1(1:n_local,nn)
          becp(:,n1,nkp)     = lambda(1,1)*becp_psi(:,n1) + &
                              lambda(2,1)*becp(:,n1,nkp) 

          ! get change in eigen values: the terms proportional to
          ! eigen(n1) use the (already globally summed) smat2 and are
          ! therefore added on one rank only in the parallel build;
          ! the e1/e12 terms are local partial sums on every rank.
          ! The complex right-hand sides are converted to real on
          ! assignment (real part is kept).
#ifdef PARAL 
          if (irank.eq.1) then 
            deigen(n1) = conjg(lambda(1,1))*lambda(1,1)*eigen(n1)&
                         -eigen(n1)                             +  &
            smat2(2,2)*eigen(n1)*conjg(lambda(2,1))*lambda(2,1) + &
            smat2(2,1)*eigen(n1)*conjg(lambda(1,1))*lambda(2,1) + &
            conjg(smat2(2,1)*eigen(n1))*conjg(lambda(2,1))*lambda(1,1)
          else
             deigen(n1) = 0.0d0
          endif
#else 
          deigen(n1) = conjg(lambda(1,1))*lambda(1,1)*eigen(n1)&
                       - eigen(n1)  + &
            smat2(2,2)*eigen(n1)*conjg(lambda(2,1))*lambda(2,1) + &
            smat2(2,1)*eigen(n1)*conjg(lambda(1,1))*lambda(2,1) + &
            conjg(smat2(2,1)*eigen(n1))*conjg(lambda(2,1))*lambda(1,1)
#endif
          deigen(n1) = deigen(n1)+conjg(lambda(2,1))*lambda(2,1)*e1+ &
                   conjg(lambda(1,1))*lambda(2,1)*e12  + &
                   conjg(lambda(2,1))*lambda(1,1)*e21 

        enddo
        n = n + actual_block_size

      enddo   ! end loop over all bands

      ! parallel sum the changes in the eigenvalues 
#ifdef PARAL 
      call par_sum_double ('A',deigen,deigen,nbands,nbands,1,&
#include PARAL_ARGS
          ,timer ) 
#endif

      ! update eigenvalues in eigen 
      eigen(:) = eigen(:) + deigen(:)

      ! S orthonormalize psi 
       call S_orthonomalize(cptwfp,nplwkp,&
#include "apply_h_args.h"
                           ,timer,nconso) 

      deallocate(Rpsi,Rpsi1,spsi,hpsi,psi1,spsi1,h_m_vnlpsi)

      end subroutine rmm_diss

      ! ------------------------------------------------------------

      subroutine setup_bhb_bsb(psi,nplwkp,bsize,hmat,smat,&
#include "apply_h_args.h"
                               ,timer,nconso,h_m_vnlpsi )

      ! Set up <b|H|b> (hmat) and <b|S|b> (smat) for the bsize vectors
      ! stored in psi.
      !
      !   psi         vectors, (nrplwv,bsize)
      !   nplwkp      plane-wave count handed on to apply_H / usvnlpsi;
      !               used as local vector length in the serial build
      !   bsize       number of vectors, dimension of hmat and smat
      !   hmat, smat  result matrices (bsize,bsize); element (n,n1) is
      !               conjg( sum_m conjg(psi(m,n1)) * (H or S psi)(m,n) )
      !   timer       timer array for time_start/time_stop
      !   nconso      output unit used for diagnostics
      !   h_m_vnlpsi  optional; if present it receives the result of
      !               apply_H called with its logical flag set to true
      !               (described in this file as (H-Vnl)|psi>), and the
      !               non-local and S parts are added band by band with
      !               usvnlpsi.  If absent, apply_H is called block-wise
      !               with the flag false and both hpsi and spsi are used.
      !
      ! In the parallel build the local partial sums are combined with
      ! par_sum_complex before returning.

      use us_hpsi_module
      use van_us_data_module, only : nkbmax,becp  
      implicit none
#include "apply_h_decl.h"
      integer bsize
      complex*WF_PRECISION psi(nrplwv,bsize)
      integer nplwkp
      complex*16 hmat(bsize,bsize),smat(bsize,bsize)
      real*8 timer(*) 
      integer nconso
      complex*WF_PRECISION, optional :: h_m_vnlpsi(nrplwv,bsize)

      ! locals
      integer n,n1,m,i
      integer n_local,nindex
      complex*16 cdum,sdum
      complex*WF_PRECISION,allocatable :: spsi(:,:),hpsi(:,:),vpsi(:)
      complex*16 zone, zzero
      data       zone, zzero /(1.0d0,0.0d0), (0.0d0,0.0d0)/      
      logical*4  parallel_pw,exists,lhpsi,lspsi,lcalbec
      integer    block_size,nblock,actual_block_size,nb,nn,n2
#ifdef PARAL
      integer offset(par_pw_np)
      integer nlocal(par_pw_np)
#endif     

      call time_start(timer,TSPROD)
      hmat(:,:) = (0.0d0,0.0d0) ; smat(:,:) = (0.0d0,0.0d0)


      ! The parallel/serial switch uses the same macro as the
      ! declarations of offset and nlocal above (and as the other
      ! routines of this module), so that par_defwfk is only compiled
      ! when its work arrays exist.
#ifdef PARAL
!     find length of wavefunction array for this process
      call  par_defwfk (nrplwv,nrplwv_global,nplwkp,nlocal,offset,&
                        n_local,&
#include PARAL_ARGS
        , nconso)
#else
      n_local = nplwkp
#endif                    

      if (present(h_m_vnlpsi)) then 

        allocate(spsi(nrplwv,1),vpsi(nrplwv)) 
   
!       calculate (H-Vnl)|psi> for all bands
!       (spsi has a single column here; presumably apply_H does not
!        write S|psi> when the flag is true - confirm against apply_H)
        call apply_H(psi,h_m_vnlpsi,spsi,&
                   nplwkp,bsize,1,.true.,&
#                  include "apply_h_args.h"
                  ,timer)                      

        lhpsi = .true.
        lspsi = .true.
!       find out if becp overlaps should be calculated: always before
!       the module has been initialized, afterwards only while
!       init_becp is still set for this k-point
        if (rmm_init) then
          lcalbec = .true.
        else
          if (init_becp(nkp)) then
           lcalbec = .true.
          else
           lcalbec = .false.
          endif
        endif                                    
 
!       make <psi|H-Vnl|psi>
        do n = 1,bsize

!         get index into becp array, only relevant when lcalbec = false
          nindex = min(nbands,n)
 
!         add Vnl|psi>: vpsi receives the non-local part and spsi(:,1)
!         is used below as S|psi> for band n
          vpsi(:) = 0.0d0
          call usvnlpsi(spsi,nplwkp,lhpsi,lspsi,lcalbec,nindex,&
#include     "apply_h_args.h"
             ,timer,reci_psi=psi(1,n),reci_vpsi=vpsi(1:n_local))

!         only n1 >= n is computed, the other triangle is filled by
!         conjugation
          do n1 = n,bsize
            cdum = (0.0d0,0.0d0) ; sdum = (0.0d0,0.0d0)
            do m = 1,n_local
              cdum = cdum + conjg(psi(m,n1))*(h_m_vnlpsi(m,n)+vpsi(m))
              sdum = sdum + conjg(psi(m,n1))*(spsi(m,1))
            enddo
            hmat(n,n1)=conjg(cdum) ; hmat(n1,n)=conjg(hmat(n,n1))
            smat(n,n1) = conjg(sdum) ; smat(n1,n)=conjg(smat(n,n1)) 
           enddo
        enddo    ! bsize

      else ! h_m_vnlpsi not present 

#ifdef PARAL
         block_size = par_pw_np
#else
         ! at most 12 vectors per block, and never more than either
         ! bsize or nbands
         block_size = min(12,bsize,nbands)
#endif
         nblock = (bsize-1)/block_size + 1
 
         allocate(hpsi(nrplwv,block_size),spsi(nrplwv,block_size))       

         ! n is the index of the first vector of the current block
         n = 1
         do nb = 1,nblock
 
            ! find actual blocksize
            actual_block_size = &
                   get_actual_block_size(bsize,block_size,nb)
 
            call apply_H(psi(1,n),hpsi,spsi,&
                         nplwkp,actual_block_size,1,.false.,&
#                        include "apply_h_args.h"
                        ,timer)
 
            ! full rows are computed here, no use of hermiticity
            do nn = 1,actual_block_size
               n1 = n + nn -1
               do n2 = 1,bsize
                   cdum = (0.0d0,0.0d0) ; sdum = (0.0d0,0.0d0)
                   do m = 1,n_local
                     cdum = cdum + conjg(psi(m,n2))*hpsi(m,nn)
                     sdum = sdum + conjg(psi(m,n2))*spsi(m,nn)
                   enddo
                   hmat(n1,n2)=conjg(cdum) ;  smat(n1,n2)=conjg(sdum)
               enddo
            enddo
            n = n + actual_block_size
 
         enddo    ! bsize
      endif       ! present h_m_vnlpsi

#ifdef PARAL
       ! combine the local partial sums of all processes
       call par_sum_complex('N',hmat,hmat,bsize,bsize,bsize,&
#include PARAL_ARGS
             ,timer)
       call par_sum_complex('N',smat,smat,bsize,bsize,bsize,&
#include PARAL_ARGS
             ,timer)
#endif

#ifdef DEBUG
      if (bsize.lt.21) then 
      ! write out hmat and smat (real parts first, then imaginary parts)
      write(nconso,*) '---- H ------'
      do n = 1,bsize 
         write(nconso,100) n, (dble(hmat(n,n1)),n1=1,bsize)
      enddo
      do n = 1,bsize 
         write(nconso,100) n, (dimag(hmat(n,n1)),n1=1,bsize)
      enddo
      write(nconso,*) '---- S ------'
      do n = 1,bsize 
         write(nconso,100) n, (dble(smat(n,n1)),n1=1,bsize)
      enddo
      do n = 1,bsize 
         write(nconso,100) n, (dimag(smat(n,n1)),n1=1,bsize)
      enddo
100   format(1x,i3,1x,100(f8.4,1x))
      endif
#endif

      ! which work arrays exist depends on the branch taken above
      if (allocated(hpsi)) deallocate(hpsi)
      if (allocated(spsi)) deallocate(spsi)
      if (allocated(vpsi)) deallocate(vpsi)

      call time_stop(timer,TSPROD)       

      end subroutine setup_bhb_bsb


      ! ---------------------------------------------------------
      subroutine S_orthonomalize(psi,nplwkp,&
#include "apply_h_args.h"
                                ,timer,nconso) 
      ! S-orthonormalize the nbands vectors in psi.
      !
      ! Steps:
      !   (1) build the overlap matrix smat from psi and S|psi>
      !   (2) factorize smat with zpotrf (upper triangular factor)
      !   (3) invert the triangular factor with ztrtrs (unit matrix as
      !       right-hand side), result in z
      !   (4) rotate psi and becp(:,:,nkp) with z
      !
      ! The routine stops through clexit when either LAPACK call
      ! reports a non-zero info.
      !
      ! NOTE(review): smat(n,n1) holds sum_m conjg(psi(m,n1))*S psi(m,n)
      ! without the extra conjg that setup_bhb_bsb applies to the
      ! same sum; confirm that this convention matches vec_rotate.

      use us_hpsi_module
      use van_us_data_module, only : nkbmax,becp  
      use non_local_projectors
      implicit none
#include "apply_h_decl.h"
      complex*WF_PRECISION psi(nrplwv,nbands)
      integer nplwkp
      real*8 timer(*)
      integer nconso

!     locals 
      complex*16 smat(nbands,nbands),z(nbands,nbands)
      complex*WF_PRECISION          :: spsiwork(nrplwv)
      
      complex*16 cdum
      logical*4  lhpsi,lspsi,exists
      integer    n,n1,m,info,i
      integer n_local
      logical    parallel_pw
#ifdef PARAL
      integer offset(par_pw_np)
      integer nlocal(par_pw_np)
#endif         
      complex*16 zone, zzero
      data       zone, zzero /(1.0d0,0.0d0), (0.0d0,0.0d0)/               

      call time_start(timer,TORTH)

      ! flags for usvnlpsi: no H|psi> part, S|psi> part requested
      lhpsi = .false.
      lspsi = .true.

#ifdef PARAL
!     find length of wavefunction array for this process
      call  par_defwfk (nrplwv,nrplwv_global,nplwkp,nlocal,&
                        offset,n_local,&
#include PARAL_ARGS
        , nconso)
#else
      n_local = nplwkp
#endif

      ! (1) overlap matrix: one usvnlpsi call per band fills spsiwork,
      !     then the scalar products with all bands n1 >= n are formed
      !     and mirrored into the other triangle
      do n = 1,nbands 

         call usvnlpsi(spsiwork,nplwkp,lhpsi,lspsi,.false.,n,&
#include          "apply_h_args.h"
                  ,timer,reci_psi=psi(1,n))

         do n1 = n,nbands 
            cdum = (0.0d0,0.0d0) 
            do m = 1,n_local
               cdum = cdum + conjg(psi(m,n1))*spsiwork(m)
            enddo
            smat(n,n1) = cdum 
            smat(n1,n) = conjg(cdum)
         enddo 

      enddo ! nbands

#ifdef PARAL
         ! combine the local partial sums of all processes
         call par_sum_complex('N',smat,smat,nbands,nbands,nbands,&
#include PARAL_ARGS
             ,timer)
#endif

      ! (2) Choleski decomposition of smat, upper triangle
      call zpotrf('U',nbands,smat,nbands,info) 
      if (info /= 0) then 
        write(nconso,*) 'S_orthonomalize : error in zpotrf ',info 
        call clexit(nconso)
      endif 

      ! (3) z = inverse of the triangular factor: solve against the
      !     unit matrix
      z(:,:) = (0.0d0,0.0d0) 
      do i = 1,nbands 
        z(i,i) = (1.0d0,0.0d0)
      enddo
      call ztrtrs('U','N','N',nbands,nbands,smat,nbands,z,nbands,info) 
      if (info /= 0) then 
        write(nconso,*) 'S_orthonomalize : error in ztrtrs ',info 
        call clexit(nconso)
      endif 

      ! (4) rotate the wavefunctions and the projections with z
      call vec_rotate(psi,nrplwv,n_local,nbands,nbands,z,nbands,timer)
      call vec_rotate_dcmplx(becp(1,1,nkp),nkbmax,nbands,nbands,z,&
                             nbands,timer)

      call time_stop(timer,TORTH)

      end subroutine S_orthonomalize

! --------------------------------------------------------------------
      subroutine  get_residual(psi,spsi,hpsi,n,eig,Rpsi,&
                               timer,nconso)

      ! Residual of a single vector:
      !   R|psi> = H|psi> - eig*S|psi>
      !
      !   psi     not referenced in this routine
      !   spsi    S|psi>, first n elements are read
      !   hpsi    H|psi>, first n elements are read
      !   n       number of elements to process
      !   eig     eigenvalue used in the residual
      !   Rpsi    result, first n elements are written
      !   timer   timer array for time_start/time_stop
      !   nconso  not referenced in this routine
! --------------------------------------------------------------------
      implicit none
      complex*WF_PRECISION psi(*)
      integer n
      complex*WF_PRECISION spsi(*),hpsi(*),Rpsi(*) 
      real*8 eig 
      real*8 timer(*)
      integer nconso

      call time_start(timer,TRESI)

      ! element-wise over the first n entries (nothing is done for n < 1)
      Rpsi(1:n) = hpsi(1:n) - eig*spsi(1:n)

      call time_stop(timer,TRESI)
      end subroutine  get_residual

!-------------------------------------------------------------------
 
      subroutine disp_rmm_diis_timers(timer,nout,nitend)
 
      ! Report the four timers used by this module by calling print_t
      ! once per timer, in fixed order: matrix setup, vector rotation,
      ! S-orthonormalization and residual.
      !
      !   timer   timer array, passed unchanged to print_t
      !   nout    passed to print_t as first argument (presumably the
      !           output unit - confirm against print_t)
      !   nitend  passed to print_t as last argument (meaning not
      !           visible here - confirm against print_t)
      implicit none
      integer nout,nitend
      real*8 timer(*)
 
      ! each label is exactly 10 characters long
      call print_t (nout, '-Form_BHB ', TSPROD, timer,nitend)
      call print_t (nout, '-vec_rotat', TVROT,  timer,nitend)
      call print_t (nout, '-S-ortho  ', TORTH,  timer,nitend)
      call print_t (nout, '-residual ', TRESI,  timer,nitend)
 
      end subroutine disp_rmm_diis_timers
!-------------------------------------------------------------------     

! ------------------------------------------------------------------
      integer function get_actual_block_size(nbands,block_size,nb) 

!     Number of bands in block number nb (counted from 1) when nbands
!     bands are split into consecutive blocks of block_size bands.
!
!       nbands      total number of bands
!       block_size  nominal number of bands per block
!       nb          block number, 1 = first block
!
!     Every block before the last one is full; the last block holds
!     whatever is left (a full block if nbands is a multiple of
!     block_size).  A block number past the last block, or a
!     non-positive block_size, gives 0 so that a caller looping over
!     the returned count does nothing.
! --------------------------------------------------------------
      implicit none 
      integer, intent(in) :: nbands,block_size,nb
      integer :: remaining

!     bands not yet covered by the blocks 1 .. nb-1
      remaining = nbands - block_size*(nb-1)
      get_actual_block_size = max(0, min(block_size, remaining))
      end function get_actual_block_size

      end module module_rmm_diss
                                                                   
