diff --git a/base/modules/serial/sp3mm_mod.F90 b/base/modules/serial/sp3mm_mod.F90 new file mode 100644 index 00000000..45b0502f --- /dev/null +++ b/base/modules/serial/sp3mm_mod.F90 @@ -0,0 +1,44 @@ +module sp3mm_mod + use iso_c_binding + use psb_const_mod + use psb_error_mod + + interface spmm_row_by_row + subroutine dspmm_row_by_row_ub(a,b,c,info) + use psb_d_mat_mod, only : psb_dspmat_type + import :: psb_ipk_ + implicit none + type(psb_dspmat_type), intent(in) :: a,b + type(psb_dspmat_type), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + end subroutine dspmm_row_by_row_ub + + subroutine dspmm_row_by_row_symb_num(a,b,c,info) + use psb_d_mat_mod, only : psb_dspmat_type + import :: psb_ipk_ + implicit none + type(psb_dspmat_type), intent(in) :: a,b + type(psb_dspmat_type), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + end subroutine dspmm_row_by_row_symb_num + + subroutine dspmm_row_by_row_1d_blocks_symb_num(a,b,c,info) + use psb_d_mat_mod, only : psb_dspmat_type + import :: psb_ipk_ + implicit none + type(psb_dspmat_type), intent(in) :: a,b + type(psb_dspmat_type), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + end subroutine dspmm_row_by_row_1d_blocks_symb_num + + subroutine dspmm_row_by_row_2d_blocks_symb_num(a,b,c,info) + use psb_d_mat_mod, only : psb_dspmat_type + import :: psb_ipk_ + implicit none + type(psb_dspmat_type), intent(in) :: a,b + type(psb_dspmat_type), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + end subroutine dspmm_row_by_row_2d_blocks_symb_num + end interface spmm_row_by_row + +end module sp3mm_mod \ No newline at end of file diff --git a/base/serial/impl/psb_d_csr_impl.F90 b/base/serial/impl/psb_d_csr_impl.F90 index 10518a2d..1c637401 100644 --- a/base/serial/impl/psb_d_csr_impl.F90 +++ b/base/serial/impl/psb_d_csr_impl.F90 @@ -3317,18 +3317,20 @@ subroutine psb_d_csr_clean_zeros(a, info) call a%set_host() end subroutine psb_d_csr_clean_zeros -subroutine psb_dcsrspspmm(a,b,c,info) +subroutine psb_dcsrspspmm(a,b,c,info, spmm_impl_id) use psb_d_mat_mod use psb_serial_mod, psb_protect_name => psb_dcsrspspmm implicit none class(psb_d_csr_sparse_mat), intent(in) :: a,b - type(psb_d_csr_sparse_mat), intent(out) :: c - integer(psb_ipk_), intent(out) :: info + type(psb_d_csr_sparse_mat), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_), intent(in), optional :: spmm_impl_id integer(psb_ipk_) :: ma,na,mb,nb, nzc, nza, nzb character(len=20) :: name integer(psb_ipk_) :: err_act + integer(psb_ipk_) :: spmm_impl_id_ name='psb_csrspspmm' call psb_erractionsave(err_act) info = psb_success_ @@ -3355,7 +3357,17 @@ subroutine psb_dcsrspspmm(a,b,c,info) nzc = 2*(nza+nzb) call c%allocate(ma,nb,nzc) - call csr_spspmm(a,b,c,info) + ! Uses optional argument to choose c + ! implementation of spmm or sets default + ! choice if argument is missing + if (present(spmm_impl_id)) then + spmm_impl_id_ = spmm_impl_id + else + spmm_impl_id_ = 0 + end if + + ! CSR matrix multiplication + call csr_spspmm(a,b,c,spmm_impl_id_,info) call c%set_asb() call c%set_host() @@ -3369,11 +3381,13 @@ subroutine psb_dcsrspspmm(a,b,c,info) contains - subroutine csr_spspmm(a,b,c,info) + subroutine csr_spspmm(a,b,c,spmm_impl_id,info) implicit none - type(psb_d_csr_sparse_mat), intent(in) :: a,b + type(psb_d_csr_sparse_mat), intent(in) :: a,b type(psb_d_csr_sparse_mat), intent(inout) :: c - integer(psb_ipk_), intent(out) :: info + ! choice of spmm implementation from c code + integer(psb_ipk_), intent(in) :: spmm_impl_id + integer(psb_ipk_), intent(out) :: info integer(psb_ipk_) :: ma,na,mb,nb integer(psb_ipk_), allocatable :: irow(:), idxs(:) real(psb_dpk_), allocatable :: row(:) @@ -3388,6 +3402,37 @@ contains mb = b%get_nrows() nb = b%get_ncols() + !! TODO : + ! * convert psb_d_csr_sparse_mat a and b to spmat_t + ! * choice of implementation + ! * code interfaces for sp3mm code + ! * call wanted interface + ! * convert result from spmat_t to psb_d_csr_sparse_mat c + + ! conversion + + ! available choices of implementation + enum, bind(C) + enumerator :: SPMM_ROW_BY_ROW_UB = 1 + enumerator SPMM_ROW_BY_ROW_SYMB_NUM + enumerator SPMM_ROW_BY_ROW_1D_BLOCKS_SYMB_NUM + enumerator SPMM_ROW_BY_ROW_2D_BLOCKS_SYMB_NUM + end enum + + select case (spmm_impl_id) + case (SPMM_ROW_BY_ROW_UB) + ! call spmm_row_by_row_ub + case (SPMM_ROW_BY_ROW_SYMB_NUM) + ! call spmm_row_by_row_symb_num + case (SPMM_ROW_BY_ROW_1D_BLOCKS_SYMB_NUM) + ! call spmm_row_by_row_1d_blocks_symb_num + case (SPMM_ROW_BY_ROW_2D_BLOCKS_SYMB_NUM) + ! call spmm_row_by_row_2d_blocks_symb_num + case default + ! call default choice + end select + + nze = min(size(c%val),size(c%ja)) isz = max(ma,na,mb,nb) call psb_realloc(isz,row,info) diff --git a/base/serial/impl/sp3mm_impl.f90 b/base/serial/impl/sp3mm_impl.f90 new file mode 100644 index 00000000..d79a05ac --- /dev/null +++ b/base/serial/impl/sp3mm_impl.f90 @@ -0,0 +1,12 @@ +subroutine dspmm_row_by_row_ub(a,b,c,info) + use psb_error_mod + use psb_base_mat_mod + use psb_d_mat_mod, only : psb_dspmat_type + use psb_objhandle_mod, only: spmat_t, config_t + implicit none + type(psb_dspmat_type), intent(in) :: a,b + type(psb_dspmat_type), intent(out) :: c + integer(psb_ipk_), intent(out) :: info + + ! TODO : implement the C interface +end subroutine dspmm_row_by_row_ub \ No newline at end of file diff --git a/cbind/base/psb_objhandle_mod.F90 b/cbind/base/psb_objhandle_mod.F90 index b77b17ab..d6200294 100644 --- a/cbind/base/psb_objhandle_mod.F90 +++ b/cbind/base/psb_objhandle_mod.F90 @@ -50,7 +50,7 @@ module psb_objhandle_mod ! number of non zeros and dimensions integer(c_size_t) :: nz, m, n ! value array - real(c_float), allocatable :: as(:) + real(c_double), allocatable :: as(:) ! columns array integer(c_size_t), allocatable :: ja(:) ! row index pointers array