created sp3mm module and implementation files as well as an interface, still needs C files and C binds

sp3mm-interface
wlther 2 years ago
parent c093a6e11d
commit 1a4ae1e973

@ -0,0 +1,44 @@
module sp3mm_mod
use iso_c_binding
use psb_const_mod
use psb_error_mod
interface spmm_row_by_row
subroutine dspmm_row_by_row_ub(a,b,c,info)
use psb_d_mat_mod, only : psb_dspmat_type
import :: psb_ipk_
implicit none
type(psb_dspmat_type), intent(in) :: a,b
type(psb_dspmat_type), intent(out) :: c
integer(psb_ipk_), intent(out) :: info
end subroutine dspmm_row_by_row_ub
subroutine dspmm_row_by_row_symb_num(a,b,c,info)
use psb_d_mat_mod, only : psb_dspmat_type
import :: psb_ipk_
implicit none
type(psb_dspmat_type), intent(in) :: a,b
type(psb_dspmat_type), intent(out) :: c
integer(psb_ipk_), intent(out) :: info
end subroutine dspmm_row_by_row_symb_num
subroutine dspmm_row_by_row_1d_blocks_symb_num(a,b,c,info)
use psb_d_mat_mod, only : psb_dspmat_type
import :: psb_ipk_
implicit none
type(psb_dspmat_type), intent(in) :: a,b
type(psb_dspmat_type), intent(out) :: c
integer(psb_ipk_), intent(out) :: info
end subroutine dspmm_row_by_row_1d_blocks_symb_num
subroutine dspmm_row_by_row_2d_blocks_symb_num(a,b,c,info)
use psb_d_mat_mod, only : psb_dspmat_type
import :: psb_ipk_
implicit none
type(psb_dspmat_type), intent(in) :: a,b
type(psb_dspmat_type), intent(out) :: c
integer(psb_ipk_), intent(out) :: info
end subroutine dspmm_row_by_row_2d_blocks_symb_num
end interface spmm_row_by_row
end module sp3mm_mod

@ -3317,18 +3317,20 @@ subroutine psb_d_csr_clean_zeros(a, info)
call a%set_host() call a%set_host()
end subroutine psb_d_csr_clean_zeros end subroutine psb_d_csr_clean_zeros
subroutine psb_dcsrspspmm(a,b,c,info) subroutine psb_dcsrspspmm(a,b,c,info, spmm_impl_id)
use psb_d_mat_mod use psb_d_mat_mod
use psb_serial_mod, psb_protect_name => psb_dcsrspspmm use psb_serial_mod, psb_protect_name => psb_dcsrspspmm
implicit none implicit none
class(psb_d_csr_sparse_mat), intent(in) :: a,b class(psb_d_csr_sparse_mat), intent(in) :: a,b
type(psb_d_csr_sparse_mat), intent(out) :: c type(psb_d_csr_sparse_mat), intent(out) :: c
integer(psb_ipk_), intent(out) :: info integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_), intent(in), optional :: spmm_impl_id
integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb
character(len=20) :: name character(len=20) :: name
integer(psb_ipk_) :: err_act integer(psb_ipk_) :: err_act
integer(psb_ipk_) :: spmm_impl_id_
name='psb_csrspspmm' name='psb_csrspspmm'
call psb_erractionsave(err_act) call psb_erractionsave(err_act)
info = psb_success_ info = psb_success_
@ -3355,7 +3357,17 @@ subroutine psb_dcsrspspmm(a,b,c,info)
nzc = 2*(nza+nzb) nzc = 2*(nza+nzb)
call c%allocate(ma,nb,nzc) call c%allocate(ma,nb,nzc)
call csr_spspmm(a,b,c,info) ! Uses optional argument to choose c
! implementation of spmm or sets default
! choice if argument is missing
if (present(spmm_impl_id)) then
spmm_impl_id_ = spmm_impl_id
else
spmm_impl_id_ = 0
end if
! CSR matrix multiplication
call csr_spspmm(a,b,c,spmm_impl_id_,info)
call c%set_asb() call c%set_asb()
call c%set_host() call c%set_host()
@ -3369,11 +3381,13 @@ subroutine psb_dcsrspspmm(a,b,c,info)
contains contains
subroutine csr_spspmm(a,b,c,info) subroutine csr_spspmm(a,b,c,spmm_impl_id,info)
implicit none implicit none
type(psb_d_csr_sparse_mat), intent(in) :: a,b type(psb_d_csr_sparse_mat), intent(in) :: a,b
type(psb_d_csr_sparse_mat), intent(inout) :: c type(psb_d_csr_sparse_mat), intent(inout) :: c
integer(psb_ipk_), intent(out) :: info ! choice of spmm implementation from c code
integer(psb_ipk_), intent(in) :: spmm_impl_id
integer(psb_ipk_), intent(out) :: info
integer(psb_ipk_) :: ma,na,mb,nb integer(psb_ipk_) :: ma,na,mb,nb
integer(psb_ipk_), allocatable :: irow(:), idxs(:) integer(psb_ipk_), allocatable :: irow(:), idxs(:)
real(psb_dpk_), allocatable :: row(:) real(psb_dpk_), allocatable :: row(:)
@ -3388,6 +3402,37 @@ contains
mb = b%get_nrows() mb = b%get_nrows()
nb = b%get_ncols() nb = b%get_ncols()
!! TODO :
! * convert psb_d_csr_sparse_mat a and b to spmat_t
! * choice of implementation
! * code interfaces for sp3mm code
! * call wanted interface
! * convert result from spmat_t to psb_d_csr_sparse_mat c
! conversion
! available choices of implementation
enum, bind(C)
enumerator :: SPMM_ROW_BY_ROW_UB = 1
enumerator SPMM_ROW_BY_ROW_SYMB_NUM
enumerator SPMM_ROW_BY_ROW_1D_BLOCKS_SYMB_NUM
enumerator SPMM_ROW_BY_ROW_2D_BLOCKS_SYMB_NUM
end enum
select case (spmm_impl_id)
case (SPMM_ROW_BY_ROW_UB)
! call spmm_row_by_row_ub
case (SPMM_ROW_BY_ROW_SYMB_NUM)
! call spmm_row_by_row_symb_num
case (SPMM_ROW_BY_ROW_1D_BLOCKS_SYMB_NUM)
! call spmm_row_by_row_1d_blocks_symb_num
case (SPMM_ROW_BY_ROW_2D_BLOCKS_SYMB_NUM)
! call spmm_row_by_row_2d_blocks_symb_num
case default
! call default choice
end select
nze = min(size(c%val),size(c%ja)) nze = min(size(c%val),size(c%ja))
isz = max(ma,na,mb,nb) isz = max(ma,na,mb,nb)
call psb_realloc(isz,row,info) call psb_realloc(isz,row,info)

@ -0,0 +1,12 @@
subroutine dspmm_row_by_row_ub(a,b,c,info)
use psb_error_mod
use psb_base_mat_mod
use psb_d_mat_mod, only : psb_dspmat_type
use psb_objhandle_mod, only: spmat_t, config_t
implicit none
type(psb_dspmat_type), intent(in) :: a,b
type(psb_dspmat_type), intent(out) :: c
integer(psb_ipk_), intent(out) :: info
! TODO : implement the C interface
end subroutine dspmm_row_by_row_ub

@ -50,7 +50,7 @@ module psb_objhandle_mod
! number of non zeros and dimensions ! number of non zeros and dimensions
integer(c_size_t) :: nz, m, n integer(c_size_t) :: nz, m, n
! value array ! value array
real(c_float), allocatable :: as(:) real(c_double), allocatable :: as(:)
! columns array ! columns array
integer(c_size_t), allocatable :: ja(:) integer(c_size_t), allocatable :: ja(:)
! row index pointers array ! row index pointers array

Loading…
Cancel
Save