!-------------------------------------------------------------------------------
! Copyright (c) 2019 FrontISTR Commons
! This software is released under the MIT License, see LICENSE.txt
!-------------------------------------------------------------------------------
!> This module provides a linear equation solver interface to monolish.

module hecmw_solver_monolish
  use iso_c_binding
  use hecmw_util
  use hecmw_matrix_misc
  implicit none

  private
  public :: hecmw_solve_monolish

  interface
     integer(c_int) function monolish_sparse_solver &
          (n, nnz, row, col, val, X, Y, tol, maxiter) &
          BIND(C, name='monolish_sparse_solver')
       use, intrinsic :: iso_c_binding
       implicit none
       integer(c_int), value :: n
       integer(c_int), value :: nnz
       type(c_ptr), value :: row
       type(c_ptr), value :: col
       type(c_ptr), value :: val
       type(c_ptr), value :: X
       type(c_ptr), value :: Y
       real(c_double), value :: tol
       integer(c_int), value :: maxiter
     end function monolish_sparse_solver
  end interface

contains

  !> Solve the linear system stored in hecMAT through monolish_sparse_solver.
  !>
  !> usage: !SOLVER, TYPE={CG, BICGSTAB, DIRECT}, ARCH={CPU, GPU}, LIB={HECMW, MONOLISH, CUSOLVER, MUMPS, MKL}
  !> If LIB={CUSOLVER, MUMPS, MKL} is set, those solvers are called through monolish.
  !>
  !> The block-CSR matrix is expanded to COO form and hecMAT%B is copied into
  !> the right-hand side. Both are passed to monolish_sparse_solver together
  !> with the residual tolerance and the iteration limit from hecMAT. The
  !> returned vector is copied into hecMAT%X.
  !> Without HECMW_WITH_MONOLISH the program terminates with an error.
  !>
  !> @param[in]    hecMESH  mesh (currently unused; kept for interface compatibility)
  !> @param[inout] hecMAT   matrix, right-hand side and solution vector
  subroutine hecmw_solve_monolish(hecMESH, hecMAT)
    implicit none
    type (hecmwST_local_mesh), intent(in) :: hecMESH
    type (hecmwST_matrix    ), intent(inout) :: hecMAT
    real(c_double), allocatable, target :: X(:)
    real(c_double), allocatable, target :: B(:)
    real(c_double), allocatable, target :: val(:)
    integer(c_int), allocatable, target :: col(:)
    integer(c_int), allocatable, target :: row(:)
    integer(c_int) :: ierr, nnz, n, maxiter, method, precond, lib, arch
    integer(kint) :: method_hecmw, precond_hecmw, lib_hecmw, arch_hecmw
    real(c_double) :: tol

#ifdef HECMW_WITH_MONOLISH
    call hecmw_convert_csr_to_coo(hecMAT, X, B, val, col, row, nnz)

    n = size(X)
    ! copy exactly the n entries the solver sees
    B(:) = hecMAT%B(1:n)
    tol = hecmw_mat_get_resid(hecMAT)
    maxiter = hecmw_mat_get_iter(hecMAT)

    !> get parameters defined by FrontISTR
    method_hecmw = hecmw_mat_get_method(hecMAT)
    precond_hecmw = hecmw_mat_get_precond(hecMAT)
    arch_hecmw = hecmw_mat_get_flag_architecture(hecMAT)
    lib_hecmw = hecmw_mat_get_flag_libtype(hecMAT)

    !> convert parameters to monolish ids
    ! The converters reject unsupported ARCH/LIB settings (error stop). The
    ! resulting ids are not yet forwarded, because the C interface
    ! monolish_sparse_solver has no arguments for them.
    method = monolish_convert_parm_method(method_hecmw)
    precond = monolish_convert_parm_precond(precond_hecmw)
    arch = monolish_convert_parm_arch(arch_hecmw)
    lib = monolish_convert_parm_lib(lib_hecmw)

    ierr = monolish_sparse_solver(n, nnz, c_loc(row), c_loc(col), c_loc(val), c_loc(X), c_loc(B), tol, maxiter)
    ! NOTE(review): ierr is not checked; its meaning is defined by the C
    ! wrapper monolish_sparse_solver -- confirm the convention and handle failures.
    hecMAT%X(1:n) = X(:)

    deallocate(X, B, val, col, row)
#else
    error stop "error: monolish solver not found"
#endif
  end subroutine hecmw_solve_monolish

  !> Map a HEC-MW solver method id to a monolish method id.
  !>
  !> 1 -> 1 (CG), 2 -> 2 (BiCGSTAB), any id > 100 -> 100 (direct).
  !> Any other id has no monolish counterpart and maps to 0.
  function monolish_convert_parm_method(flag) result(res)
    implicit none
    integer(kint), intent(in) :: flag
    integer(c_int) :: res

    select case (flag)
    case (1)
      res = 1        !> CG method
    case (2)
      res = 2        !> BiCGSTAB method
    case (101:)
      res = 100      !> Direct method
    case default
      res = 0        !> not supported by monolish
    end select
  end function monolish_convert_parm_method

  !> Map a HEC-MW preconditioner id to a monolish preconditioner id.
  !>
  !> 1, 2 -> 1 (SOR); 3 -> 2 (diagonal); 10, 11, 12 -> 3 (ILU);
  !> 5 -> 4 (multigrid). Any other id maps to 0.
  function monolish_convert_parm_precond(flag) result(res)
    implicit none
    integer(kint), intent(in) :: flag
    integer(c_int) :: res

    select case (flag)
    case (1, 2)
      res = 1        !> SOR preconditioner
    case (3)
      res = 2        !> Diag preconditioner
    case (10:12)
      res = 3        !> ILU preconditioner
    case (5)
      res = 4        !> Multi Grid preconditioner
    case default
      res = 0        !> not supported by monolish
    end select
  end function monolish_convert_parm_precond

  !> Map a HEC-MW architecture flag to a monolish architecture id.
  !>
  !> 0 (default) and 1 -> 1 (CPU); 2 -> 2 (GPU).
  !> Any other value terminates the program.
  function monolish_convert_parm_arch(flag) result(res)
    implicit none
    integer(kint), intent(in) :: flag
    integer(c_int) :: res

    select case (flag)
    case (0, 1)
      res = 1        !> CPU (0 selects CPU as default)
    case (2)
      res = 2        !> GPU
    case default
      error stop "monolish_convert_parm_arch"
    end select
  end function monolish_convert_parm_arch

  !> Map a HEC-MW library flag to a monolish library id.
  !>
  !> 2..5 (MONOLISH, CUSOLVER, MUMPS, MKL) map to themselves.
  !> 1 (HECMW) terminates the program, because that selection cannot be
  !> served by this module. Any other value maps to 0.
  function monolish_convert_parm_lib(flag) result(res)
    implicit none
    integer(kint), intent(in) :: flag
    integer(c_int) :: res

    select case (flag)
    case (1)
      error stop "monolish_convert_parm_lib" !> HECMW
    case (2)
      res = 2        !> MONOLISH
    case (3)
      res = 3        !> CUSOLVER
    case (4)
      res = 4        !> MUMPS
    case (5)
      res = 5        !> MKL
    case default
      res = 0        !> not recognised
    end select
  end function monolish_convert_parm_lib

  !> Expand the block-CSR matrix in hecMAT into a scalar COO triplet list.
  !>
  !> For every scalar row, entries are emitted in this order:
  !>   1. the lower-triangle blocks (indexL/itemL/AL),
  !>   2. the diagonal block (D),
  !>   3. the upper-triangle blocks (indexU/itemU/AU).
  !> Inside each NDOF x NDOF block, value k = NDOF*(idof-1) + jdof is read for
  !> scalar row idof and column jdof. Row and column indices are 1-based,
  !> NDOF*(node - 1) + dof.
  !>
  !> @param[in]  hecMAT  source matrix (only read)
  !> @param[out] X       initial guess of length NP*NDOF (zeros, first entry 1)
  !> @param[out] B       zero vector of length NP*NDOF (filled by the caller)
  !> @param[out] val     COO values, length nnz
  !> @param[out] col     COO column indices, length nnz
  !> @param[out] row     COO row indices, length nnz
  !> @param[out] nnz     number of scalar entries, NZ*NDOF**2
  subroutine hecmw_convert_csr_to_coo(hecMAT, X, B, val, col, row, nnz)
    implicit none
    type (hecmwST_matrix    ), intent(in) :: hecMAT
    real(c_double), intent(out), allocatable :: X(:)
    real(c_double), intent(out), allocatable :: B(:)
    real(c_double), intent(out), allocatable :: val(:)
    integer(c_int), intent(out), allocatable :: col(:)
    integer(c_int), intent(out), allocatable :: row(:)
    integer(c_int), intent(out) :: nnz
    integer(kint) :: NP, NDOF, NDOF2, NZ
    integer(kint) :: i, idof, irow, j, jn, jdof, pos
    integer(c_int64_t) :: nnz_wide

    NP    = hecMAT%NP
    NDOF  = hecMAT%NDOF
    NDOF2 = NDOF*NDOF
    ! number of NDOF x NDOF blocks: L and U off-diagonal blocks plus one diagonal block per node
    NZ = hecMAT%indexL(NP) + hecMAT%indexU(NP) + NP

    ! monolish_sparse_solver receives nnz and the COO indices as C int.
    ! Every row/column index is <= NP*NDOF <= nnz, because the diagonal blocks
    ! alone contribute NP*NDOF2 entries. Checking nnz therefore also covers
    ! the indices.
    nnz_wide = int(NZ, c_int64_t) * int(NDOF2, c_int64_t)
    if (nnz_wide > int(huge(0_c_int), c_int64_t)) then
      error stop "hecmw_convert_csr_to_coo: nnz exceeds the C int range"
    end if
    nnz = int(nnz_wide, c_int)

    allocate(X(NP*NDOF), source = 0.0_c_double)
    ! NOTE(review): the first entry of the initial guess is set to 1 instead
    ! of 0; the reason is not visible here -- confirm against the monolish wrapper.
    if (size(X) > 0) X(1) = 1.0_c_double
    allocate(B(NP*NDOF), source = 0.0_c_double)
    allocate(val(nnz), source = 0.0_c_double)
    allocate(col(nnz), source = 0_c_int)
    allocate(row(nnz), source = 0_c_int)

    pos = 0
    do i = 1, NP
      do idof = 1, NDOF
        irow = NDOF*(i - 1) + idof

        ! lower-triangle blocks of node row i
        do j = hecMAT%indexL(i-1) + 1, hecMAT%indexL(i)
          jn = hecMAT%itemL(j)
          do jdof = 1, NDOF
            pos = pos + 1
            val(pos) = hecMAT%AL(NDOF2*(j-1) + NDOF*(idof-1) + jdof)
            col(pos) = int(NDOF*(jn - 1) + jdof, c_int)
            row(pos) = int(irow, c_int)
          end do
        end do

        ! diagonal block of node row i
        do jdof = 1, NDOF
          pos = pos + 1
          val(pos) = hecMAT%D(NDOF2*(i-1) + NDOF*(idof-1) + jdof)
          col(pos) = int(NDOF*(i - 1) + jdof, c_int)
          row(pos) = int(irow, c_int)
        end do

        ! upper-triangle blocks of node row i
        do j = hecMAT%indexU(i-1) + 1, hecMAT%indexU(i)
          jn = hecMAT%itemU(j)
          do jdof = 1, NDOF
            pos = pos + 1
            val(pos) = hecMAT%AU(NDOF2*(j-1) + NDOF*(idof-1) + jdof)
            col(pos) = int(NDOF*(jn - 1) + jdof, c_int)
            row(pos) = int(irow, c_int)
          end do
        end do
      end do
    end do
  end subroutine hecmw_convert_csr_to_coo

  !> Manual check of hecmw_convert_csr_to_coo.
  !>
  !> Builds a 3-node, 3-DOF matrix with 3 lower and 3 upper off-diagonal
  !> blocks. Every entry of D, AU and AL gets a distinct value (i, i+100 and
  !> i+200 respectively). The resulting COO list is printed as
  !> "row col value".
  subroutine hecmw_convert_csr_to_coo_test()
    implicit none
    type (hecmwST_matrix    ) :: hecMAT
    integer(kint) :: i, NP, NZ, NDOF, NDOF2
    real(c_double), allocatable :: X(:)
    real(c_double), allocatable :: Y(:)
    real(c_double), allocatable :: val(:)
    integer(c_int), allocatable :: col(:)
    integer(c_int), allocatable :: row(:)
    integer(c_int) :: nnz

    write(*,*)"** hecmw_convert_csr_to_coo_test"

    hecMAT%N = 3
    hecMAT%NP = 3
    hecMAT%NDOF = 3
    NP = hecMAT%NP
    NDOF = hecMAT%NDOF
    NDOF2 = NDOF*NDOF
    NZ = 3
    allocate(hecMAT%indexL(0:NP), source = 0)
    allocate(hecMAT%indexU(0:NP), source = 0)
    allocate(hecMAT%itemU(0:NDOF2*NZ), source = 0)
    allocate(hecMAT%itemL(0:NDOF2*NZ), source = 0)
    allocate(hecMAT%D(NDOF2*NP), source = 0.0d0)
    allocate(hecMAT%AU(NDOF2*NZ), source = 0.0d0)
    allocate(hecMAT%AL(NDOF2*NZ), source = 0.0d0)

    hecMAT%D(:)  = [(dble(i),            i = 1, NDOF2*NP)]
    hecMAT%AU(:) = [(dble(i) + 100.0d0,  i = 1, NDOF2*NZ)]
    hecMAT%AL(:) = [(dble(i) + 200.0d0,  i = 1, NDOF2*NZ)]

    ! lower part: row 2 -> col 1; row 3 -> cols 1, 2
    hecMAT%indexL(1) = 0
    hecMAT%indexL(2) = 1
    hecMAT%indexL(3) = 3
    hecMAT%itemL(1) = 1
    hecMAT%itemL(2) = 1
    hecMAT%itemL(3) = 2

    ! upper part: row 1 -> cols 2, 3; row 2 -> col 3
    hecMAT%indexU(1) = 2
    hecMAT%indexU(2) = 3
    hecMAT%indexU(3) = 3
    hecMAT%itemU(1) = 2
    hecMAT%itemU(2) = 3
    hecMAT%itemU(3) = 3

    call hecmw_convert_csr_to_coo(hecMAT, X, Y, val, col, row, nnz)

    write(*,*)"** coo"
    do i = 1, nnz
      write(*,"(2i3,f5.0)") row(i), col(i), val(i)
    enddo

    deallocate(hecMAT%indexL, hecMAT%indexU, hecMAT%itemL, hecMAT%itemU)
    deallocate(hecMAT%D, hecMAT%AU, hecMAT%AL)
  end subroutine hecmw_convert_csr_to_coo_test
end module hecmw_solver_monolish
