From 9b6542d350b0c1ed1979b326aaf9b3774ef46fbc Mon Sep 17 00:00:00 2001 From: wlthr Date: Wed, 5 Jul 2023 16:39:35 +0200 Subject: [PATCH] added using new spmm implementation --- base/serial/impl/psb_d_csr_impl.F90 | 126 +++++++++++++--------------- base/serial/impl/sp3mm_impl.f90 | 10 +-- 2 files changed, 58 insertions(+), 78 deletions(-) diff --git a/base/serial/impl/psb_d_csr_impl.F90 b/base/serial/impl/psb_d_csr_impl.F90 index bab49e23..e45cfc30 100644 --- a/base/serial/impl/psb_d_csr_impl.F90 +++ b/base/serial/impl/psb_d_csr_impl.F90 @@ -3367,7 +3367,7 @@ subroutine psb_dcsrspspmm(a,b,c,info, spmm_impl_id) end if ! CSR matrix multiplication - call csr_spspmm(a,b,c,spmm_impl_id_,info) + call csr_spspmm(a,b,c,info,spmm_impl_id_) call c%set_asb() call c%set_host() @@ -3381,13 +3381,13 @@ subroutine psb_dcsrspspmm(a,b,c,info, spmm_impl_id) contains - subroutine csr_spspmm(a,b,c,spmm_impl_id,info) + subroutine csr_spspmm(a,b,c,info,spmm_impl_id) implicit none type(psb_d_csr_sparse_mat), intent(in) :: a,b type(psb_d_csr_sparse_mat), intent(inout) :: c ! choice of spmm implementation from c code - integer(psb_ipk_), intent(in) :: spmm_impl_id integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_), intent(in) :: spmm_impl_id integer(psb_ipk_) :: ma,na,mb,nb integer(psb_ipk_), allocatable :: irow(:), idxs(:) real(psb_dpk_), allocatable :: row(:) @@ -3401,76 +3401,62 @@ contains mb = b%get_nrows() nb = b%get_ncols() - !! TODO : - ! * convert psb_d_csr_sparse_mat a and b to spmat_t - ! * choice of implementation - ! * code interfaces for sp3mm code - ! * call wanted interface - ! * convert result from spmat_t to psb_d_csr_sparse_mat c - - ! conversion - - ! select case (spmm_impl_id) - ! case (SPMM_ROW_BY_ROW_UB) - ! ! call spmm_row_by_row_ub - ! case (SPMM_ROW_BY_ROW_SYMB_NUM) - ! ! call spmm_row_by_row_symb_num - ! case (SPMM_ROW_BY_ROW_1D_BLOCKS_SYMB_NUM) - ! ! call spmm_row_by_row_1d_blocks_symb_num - ! case (SPMM_ROW_BY_ROW_2D_BLOCKS_SYMB_NUM) - ! ! call spmm_row_by_row_2d_blocks_symb_num - ! case default - ! ! call default choice - ! end select - - - nze = min(size(c%val),size(c%ja)) - isz = max(ma,na,mb,nb) - call psb_realloc(isz,row,info) - if (info == 0) call psb_realloc(isz,idxs,info) - if (info == 0) call psb_realloc(isz,irow,info) - if (info /= 0) return - row = dzero - irow = 0 - nzc = 1 - do j = 1,ma - c%irp(j) = nzc - nrc = 0 - do k = a%irp(j), a%irp(j+1)-1 - irw = a%ja(k) - cfb = a%val(k) - irwsz = b%irp(irw+1)-b%irp(irw) - do i = b%irp(irw),b%irp(irw+1)-1 - icl = b%ja(i) - if (irow(icl) 0 ) then - if ((nzc+nrc)>nze) then - nze = max(ma*((nzc+j-1)/j),nzc+2*nrc) - call psb_realloc(nze,c%val,info) - if (info == 0) call psb_realloc(nze,c%ja,info) - if (info /= 0) return - end if - - call psb_qsort(idxs(1:nrc)) - do i=1, nrc - irw = idxs(i) - c%ja(nzc) = irw - c%val(nzc) = row(irw) - row(irw) = dzero - nzc = nzc + 1 + if (.false.) then + nze = min(size(c%val),size(c%ja)) + isz = max(ma,na,mb,nb) + call psb_realloc(isz,row,info) + if (info == 0) call psb_realloc(isz,idxs,info) + if (info == 0) call psb_realloc(isz,irow,info) + if (info /= 0) return + row = dzero + irow = 0 + nzc = 1 + do j = 1,ma + c%irp(j) = nzc + nrc = 0 + do k = a%irp(j), a%irp(j+1)-1 + irw = a%ja(k) + cfb = a%val(k) + irwsz = b%irp(irw+1)-b%irp(irw) + do i = b%irp(irw),b%irp(irw+1)-1 + icl = b%ja(i) + if (irow(icl) 0 ) then + if ((nzc+nrc)>nze) then + nze = max(ma*((nzc+j-1)/j),nzc+2*nrc) + call psb_realloc(nze,c%val,info) + if (info == 0) call psb_realloc(nze,c%ja,info) + if (info /= 0) return + end if - c%irp(ma+1) = nzc + call psb_qsort(idxs(1:nrc)) + do i=1, nrc + irw = idxs(i) + c%ja(nzc) = irw + c%val(nzc) = row(irw) + row(irw) = dzero + nzc = nzc + 1 + end do + end if + end do + c%irp(ma+1) = nzc + else + !! TODO : + ! * convert psb_d_csr_sparse_mat a and b to spmat_t + ! * choice of implementation + ! * code interfaces for sp3mm code + ! * call wanted interface + ! * convert result from spmat_t to psb_d_csr_sparse_mat c + call dspmm(a,b,c,info,spmm_impl_id_) + end if end subroutine csr_spspmm diff --git a/base/serial/impl/sp3mm_impl.f90 b/base/serial/impl/sp3mm_impl.f90 index a8f5533c..de4de8c8 100644 --- a/base/serial/impl/sp3mm_impl.f90 +++ b/base/serial/impl/sp3mm_impl.f90 @@ -11,8 +11,8 @@ subroutine dspmm(a,b,c,info, impl_choice) implicit none type(psb_d_csr_sparse_mat), intent(in) :: a,b type(psb_d_csr_sparse_mat), intent(inout):: c - integer(psb_ipk_), intent(out) :: info - integer(psb_ipk_), intent(in), optional :: impl_choice + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_), intent(in) :: impl_choice ! Internal variables integer(c_size_t):: a_m,a_n,a_nz @@ -97,12 +97,6 @@ subroutine dspmm(a,b,c,info, impl_choice) b_irp = b%irp b_irp_ptr = c_loc(b_irp) - if (present(impl_choice)) then - impl_choice_ = impl_choice - else - impl_choice_ = 0 - end if - ! call calculateSize call psb_f_spmm_build_spacc(a_m,a_n,a_nz,a_as_ptr,& a_ja_ptr,a_irp_ptr,a_max_row_nz,&