From 347352fe1e5b3ba8f8ebfa085c34c010a8474df0 Mon Sep 17 00:00:00 2001 From: sfilippone Date: Fri, 2 Jun 2023 11:07:23 +0200 Subject: [PATCH] Make spins work in OpenMP from either par or serial --- base/tools/psb_cspins.F90 | 123 ++++++++++++++++++++++++++++++++---- base/tools/psb_dspins.F90 | 119 +++++++++++++++++++++++++++++++--- base/tools/psb_sspins.F90 | 123 ++++++++++++++++++++++++++++++++---- base/tools/psb_zspins.F90 | 123 ++++++++++++++++++++++++++++++++---- test/pargen/psb_d_pde3d.F90 | 1 - 5 files changed, 445 insertions(+), 44 deletions(-) diff --git a/base/tools/psb_cspins.F90 b/base/tools/psb_cspins.F90 index f523a529..e5f2731d 100644 --- a/base/tools/psb_cspins.F90 +++ b/base/tools/psb_cspins.F90 @@ -135,28 +135,132 @@ subroutine psb_cspins(nz,ia,ja,val,a,desc_a,info,rebuild,local) goto 9999 end if #if defined(OPENMP) - !$omp parallel private(ila,jla,nrow,ncol,nnl,k) -#endif + block + logical :: is_in_parallel + is_in_parallel = omp_in_parallel() + if (is_in_parallel) then + !$omp parallel private(ila,jla,nrow,ncol,nnl,k) + call desc_a%indxmap%g2l(ia(1:nz),ila(1:nz),info,owned=.true.) + !$omp critical(spins) + if (info == 0) call desc_a%indxmap%g2l_ins(ja(1:nz),jla(1:nz),info,& + & mask=(ila(1:nz)>0)) + !$omp end critical(spins) + !write(0,*) me,' after g2l_ins ',psb_errstatus_fatal(),info + if (info /= psb_success_) then + call psb_errpush(psb_err_from_subroutine_ai_,name,& + & a_err='psb_cdins',i_err=(/info/)) + goto 9998 + end if + nrow = desc_a%get_local_rows() + ncol = desc_a%get_local_cols() + !write(0,*) me,' Before csput',psb_errstatus_fatal() + if (a%is_bld()) then + call a%csput(nz,ila,jla,val,ione,nrow,ione,ncol,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='a%csput') + goto 9998 + end if + + if (a%is_remote_build()) then + nnl = count(ila(1:nz)<0) + if (nnl > 0) then + allocate(lila(nnl),ljla(nnl),lval(nnl)) + k = 0 + do i=1,nz + if (ila(i)<0) then + k=k+1 + lila(k) = ia(i) + ljla(k) = ja(i) + lval(k) = val(i) + end if + end do + if (k /= nnl) write(0,*) name,' Wrong conversion?',k,nnl + call a%rmta%csput(nnl,lila,ljla,lval,1_psb_lpk_,desc_a%get_global_rows(),& + & 1_psb_lpk_,desc_a%get_global_rows(),info) + end if + end if + + else + info = psb_err_invalid_a_and_cd_state_ + call psb_errpush(info,name) + end if +9998 continue + !write(0,*) me,' after csput',psb_errstatus_fatal() + !$omp end parallel + else + call desc_a%indxmap%g2l(ia(1:nz),ila(1:nz),info,owned=.true.) + !write(0,*) me,' Before g2l_ins ',psb_errstatus_fatal() + if (info == 0) call desc_a%indxmap%g2l_ins(ja(1:nz),jla(1:nz),info,& + & mask=(ila(1:nz)>0)) + !write(0,*) me,' after g2l_ins ',psb_errstatus_fatal(),info + if (info /= psb_success_) then + call psb_errpush(psb_err_from_subroutine_ai_,name,& + & a_err='psb_cdins',i_err=(/info/)) + goto 9999 + end if + nrow = desc_a%get_local_rows() + ncol = desc_a%get_local_cols() + !write(0,*) me,' Before csput',psb_errstatus_fatal() + if (a%is_bld()) then + call a%csput(nz,ila,jla,val,ione,nrow,ione,ncol,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='a%csput') + goto 9999 + end if + + if (a%is_remote_build()) then + nnl = count(ila(1:nz)<0) + if (nnl > 0) then + allocate(lila(nnl),ljla(nnl),lval(nnl)) + k = 0 + do i=1,nz + if (ila(i)<0) then + k=k+1 + lila(k) = ia(i) + ljla(k) = ja(i) + lval(k) = val(i) + end if + end do + if (k /= nnl) write(0,*) name,' Wrong conversion?',k,nnl + call a%rmta%csput(nnl,lila,ljla,lval,1_psb_lpk_,desc_a%get_global_rows(),& + & 1_psb_lpk_,desc_a%get_global_rows(),info) + end if + end if + + else + info = psb_err_invalid_a_and_cd_state_ + call psb_errpush(info,name) + goto 9999 + end if + end if + end block +#else + + !write(0,*) me,' Before g2l ',psb_errstatus_fatal() call desc_a%indxmap%g2l(ia(1:nz),ila(1:nz),info,owned=.true.) + if (info == 0) call desc_a%indxmap%g2l_ins(ja(1:nz),jla(1:nz),info,& & mask=(ila(1:nz)>0)) - + + !write(0,*) me,' after g2l_ins ',psb_errstatus_fatal(),info if (info /= psb_success_) then call psb_errpush(psb_err_from_subroutine_ai_,name,& & a_err='psb_cdins',i_err=(/info/)) - !goto 9999 + goto 9999 end if nrow = desc_a%get_local_rows() ncol = desc_a%get_local_cols() - + !write(0,*) me,' Before csput',psb_errstatus_fatal() if (a%is_bld()) then call a%csput(nz,ila,jla,val,ione,nrow,ione,ncol,info) if (info /= psb_success_) then info=psb_err_from_subroutine_ call psb_errpush(info,name,a_err='a%csput') - !goto 9999 + goto 9999 end if - + if (a%is_remote_build()) then nnl = count(ila(1:nz)<0) if (nnl > 0) then @@ -179,11 +283,8 @@ subroutine psb_cspins(nz,ia,ja,val,a,desc_a,info,rebuild,local) else info = psb_err_invalid_a_and_cd_state_ call psb_errpush(info,name) - !goto 9999 + goto 9999 end if - -#if defined(OPENMP) - !$omp end parallel #endif if (info /= 0) goto 9999 endif diff --git a/base/tools/psb_dspins.F90 b/base/tools/psb_dspins.F90 index 094d0a4b..cdeaa931 100644 --- a/base/tools/psb_dspins.F90 +++ b/base/tools/psb_dspins.F90 @@ -135,18 +135,120 @@ subroutine psb_dspins(nz,ia,ja,val,a,desc_a,info,rebuild,local) goto 9999 end if #if defined(OPENMP) - !$omp parallel private(ila,jla,nrow,ncol,nnl,k) -#endif + block + logical :: is_in_parallel + is_in_parallel = omp_in_parallel() + if (is_in_parallel) then + !$omp parallel private(ila,jla,nrow,ncol,nnl,k) + call desc_a%indxmap%g2l(ia(1:nz),ila(1:nz),info,owned=.true.) + !$omp critical(spins) + if (info == 0) call desc_a%indxmap%g2l_ins(ja(1:nz),jla(1:nz),info,& + & mask=(ila(1:nz)>0)) + !$omp end critical(spins) + !write(0,*) me,' after g2l_ins ',psb_errstatus_fatal(),info + if (info /= psb_success_) then + call psb_errpush(psb_err_from_subroutine_ai_,name,& + & a_err='psb_cdins',i_err=(/info/)) + goto 9998 + end if + nrow = desc_a%get_local_rows() + ncol = desc_a%get_local_cols() + !write(0,*) me,' Before csput',psb_errstatus_fatal() + if (a%is_bld()) then + call a%csput(nz,ila,jla,val,ione,nrow,ione,ncol,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='a%csput') + goto 9998 + end if + + if (a%is_remote_build()) then + nnl = count(ila(1:nz)<0) + if (nnl > 0) then + allocate(lila(nnl),ljla(nnl),lval(nnl)) + k = 0 + do i=1,nz + if (ila(i)<0) then + k=k+1 + lila(k) = ia(i) + ljla(k) = ja(i) + lval(k) = val(i) + end if + end do + if (k /= nnl) write(0,*) name,' Wrong conversion?',k,nnl + call a%rmta%csput(nnl,lila,ljla,lval,1_psb_lpk_,desc_a%get_global_rows(),& + & 1_psb_lpk_,desc_a%get_global_rows(),info) + end if + end if + + else + info = psb_err_invalid_a_and_cd_state_ + call psb_errpush(info,name) + end if +9998 continue + !write(0,*) me,' after csput',psb_errstatus_fatal() + !$omp end parallel + else + call desc_a%indxmap%g2l(ia(1:nz),ila(1:nz),info,owned=.true.) + !write(0,*) me,' Before g2l_ins ',psb_errstatus_fatal() + if (info == 0) call desc_a%indxmap%g2l_ins(ja(1:nz),jla(1:nz),info,& + & mask=(ila(1:nz)>0)) + !write(0,*) me,' after g2l_ins ',psb_errstatus_fatal(),info + if (info /= psb_success_) then + call psb_errpush(psb_err_from_subroutine_ai_,name,& + & a_err='psb_cdins',i_err=(/info/)) + goto 9999 + end if + nrow = desc_a%get_local_rows() + ncol = desc_a%get_local_cols() + !write(0,*) me,' Before csput',psb_errstatus_fatal() + if (a%is_bld()) then + call a%csput(nz,ila,jla,val,ione,nrow,ione,ncol,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='a%csput') + goto 9999 + end if + + if (a%is_remote_build()) then + nnl = count(ila(1:nz)<0) + if (nnl > 0) then + allocate(lila(nnl),ljla(nnl),lval(nnl)) + k = 0 + do i=1,nz + if (ila(i)<0) then + k=k+1 + lila(k) = ia(i) + ljla(k) = ja(i) + lval(k) = val(i) + end if + end do + if (k /= nnl) write(0,*) name,' Wrong conversion?',k,nnl + call a%rmta%csput(nnl,lila,ljla,lval,1_psb_lpk_,desc_a%get_global_rows(),& + & 1_psb_lpk_,desc_a%get_global_rows(),info) + end if + end if + + else + info = psb_err_invalid_a_and_cd_state_ + call psb_errpush(info,name) + goto 9999 + end if + end if + end block +#else + !write(0,*) me,' Before g2l ',psb_errstatus_fatal() call desc_a%indxmap%g2l(ia(1:nz),ila(1:nz),info,owned=.true.) - !write(0,*) me,' Before g2l_ins ',psb_errstatus_fatal() + if (info == 0) call desc_a%indxmap%g2l_ins(ja(1:nz),jla(1:nz),info,& & mask=(ila(1:nz)>0)) + !write(0,*) me,' after g2l_ins ',psb_errstatus_fatal(),info if (info /= psb_success_) then call psb_errpush(psb_err_from_subroutine_ai_,name,& & a_err='psb_cdins',i_err=(/info/)) - !goto 9999 + goto 9999 end if nrow = desc_a%get_local_rows() ncol = desc_a%get_local_cols() @@ -156,9 +258,9 @@ subroutine psb_dspins(nz,ia,ja,val,a,desc_a,info,rebuild,local) if (info /= psb_success_) then info=psb_err_from_subroutine_ call psb_errpush(info,name,a_err='a%csput') - !goto 9999 + goto 9999 end if - + if (a%is_remote_build()) then nnl = count(ila(1:nz)<0) if (nnl > 0) then @@ -181,11 +283,8 @@ subroutine psb_dspins(nz,ia,ja,val,a,desc_a,info,rebuild,local) else info = psb_err_invalid_a_and_cd_state_ call psb_errpush(info,name) - !goto 9999 + goto 9999 end if - !write(0,*) me,' after csput',psb_errstatus_fatal() -#if defined(OPENMP) - !$omp end parallel #endif if (info /= 0) goto 9999 endif diff --git a/base/tools/psb_sspins.F90 b/base/tools/psb_sspins.F90 index 9781eaae..39e4ad79 100644 --- a/base/tools/psb_sspins.F90 +++ b/base/tools/psb_sspins.F90 @@ -135,28 +135,132 @@ subroutine psb_sspins(nz,ia,ja,val,a,desc_a,info,rebuild,local) goto 9999 end if #if defined(OPENMP) - !$omp parallel private(ila,jla,nrow,ncol,nnl,k) -#endif + block + logical :: is_in_parallel + is_in_parallel = omp_in_parallel() + if (is_in_parallel) then + !$omp parallel private(ila,jla,nrow,ncol,nnl,k) + call desc_a%indxmap%g2l(ia(1:nz),ila(1:nz),info,owned=.true.) + !$omp critical(spins) + if (info == 0) call desc_a%indxmap%g2l_ins(ja(1:nz),jla(1:nz),info,& + & mask=(ila(1:nz)>0)) + !$omp end critical(spins) + !write(0,*) me,' after g2l_ins ',psb_errstatus_fatal(),info + if (info /= psb_success_) then + call psb_errpush(psb_err_from_subroutine_ai_,name,& + & a_err='psb_cdins',i_err=(/info/)) + goto 9998 + end if + nrow = desc_a%get_local_rows() + ncol = desc_a%get_local_cols() + !write(0,*) me,' Before csput',psb_errstatus_fatal() + if (a%is_bld()) then + call a%csput(nz,ila,jla,val,ione,nrow,ione,ncol,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='a%csput') + goto 9998 + end if + + if (a%is_remote_build()) then + nnl = count(ila(1:nz)<0) + if (nnl > 0) then + allocate(lila(nnl),ljla(nnl),lval(nnl)) + k = 0 + do i=1,nz + if (ila(i)<0) then + k=k+1 + lila(k) = ia(i) + ljla(k) = ja(i) + lval(k) = val(i) + end if + end do + if (k /= nnl) write(0,*) name,' Wrong conversion?',k,nnl + call a%rmta%csput(nnl,lila,ljla,lval,1_psb_lpk_,desc_a%get_global_rows(),& + & 1_psb_lpk_,desc_a%get_global_rows(),info) + end if + end if + + else + info = psb_err_invalid_a_and_cd_state_ + call psb_errpush(info,name) + end if +9998 continue + !write(0,*) me,' after csput',psb_errstatus_fatal() + !$omp end parallel + else + call desc_a%indxmap%g2l(ia(1:nz),ila(1:nz),info,owned=.true.) + !write(0,*) me,' Before g2l_ins ',psb_errstatus_fatal() + if (info == 0) call desc_a%indxmap%g2l_ins(ja(1:nz),jla(1:nz),info,& + & mask=(ila(1:nz)>0)) + !write(0,*) me,' after g2l_ins ',psb_errstatus_fatal(),info + if (info /= psb_success_) then + call psb_errpush(psb_err_from_subroutine_ai_,name,& + & a_err='psb_cdins',i_err=(/info/)) + goto 9999 + end if + nrow = desc_a%get_local_rows() + ncol = desc_a%get_local_cols() + !write(0,*) me,' Before csput',psb_errstatus_fatal() + if (a%is_bld()) then + call a%csput(nz,ila,jla,val,ione,nrow,ione,ncol,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='a%csput') + goto 9999 + end if + + if (a%is_remote_build()) then + nnl = count(ila(1:nz)<0) + if (nnl > 0) then + allocate(lila(nnl),ljla(nnl),lval(nnl)) + k = 0 + do i=1,nz + if (ila(i)<0) then + k=k+1 + lila(k) = ia(i) + ljla(k) = ja(i) + lval(k) = val(i) + end if + end do + if (k /= nnl) write(0,*) name,' Wrong conversion?',k,nnl + call a%rmta%csput(nnl,lila,ljla,lval,1_psb_lpk_,desc_a%get_global_rows(),& + & 1_psb_lpk_,desc_a%get_global_rows(),info) + end if + end if + + else + info = psb_err_invalid_a_and_cd_state_ + call psb_errpush(info,name) + goto 9999 + end if + end if + end block +#else + + !write(0,*) me,' Before g2l ',psb_errstatus_fatal() call desc_a%indxmap%g2l(ia(1:nz),ila(1:nz),info,owned=.true.) + if (info == 0) call desc_a%indxmap%g2l_ins(ja(1:nz),jla(1:nz),info,& & mask=(ila(1:nz)>0)) - + + !write(0,*) me,' after g2l_ins ',psb_errstatus_fatal(),info if (info /= psb_success_) then call psb_errpush(psb_err_from_subroutine_ai_,name,& & a_err='psb_cdins',i_err=(/info/)) - !goto 9999 + goto 9999 end if nrow = desc_a%get_local_rows() ncol = desc_a%get_local_cols() - + !write(0,*) me,' Before csput',psb_errstatus_fatal() if (a%is_bld()) then call a%csput(nz,ila,jla,val,ione,nrow,ione,ncol,info) if (info /= psb_success_) then info=psb_err_from_subroutine_ call psb_errpush(info,name,a_err='a%csput') - !goto 9999 + goto 9999 end if - + if (a%is_remote_build()) then nnl = count(ila(1:nz)<0) if (nnl > 0) then @@ -179,11 +283,8 @@ subroutine psb_sspins(nz,ia,ja,val,a,desc_a,info,rebuild,local) else info = psb_err_invalid_a_and_cd_state_ call psb_errpush(info,name) - !goto 9999 + goto 9999 end if - -#if defined(OPENMP) - !$omp end parallel #endif if (info /= 0) goto 9999 endif diff --git a/base/tools/psb_zspins.F90 b/base/tools/psb_zspins.F90 index 36b0b5a5..0c0ff91f 100644 --- a/base/tools/psb_zspins.F90 +++ b/base/tools/psb_zspins.F90 @@ -135,28 +135,132 @@ subroutine psb_zspins(nz,ia,ja,val,a,desc_a,info,rebuild,local) goto 9999 end if #if defined(OPENMP) - !$omp parallel private(ila,jla,nrow,ncol,nnl,k) -#endif + block + logical :: is_in_parallel + is_in_parallel = omp_in_parallel() + if (is_in_parallel) then + !$omp parallel private(ila,jla,nrow,ncol,nnl,k) + call desc_a%indxmap%g2l(ia(1:nz),ila(1:nz),info,owned=.true.) + !$omp critical(spins) + if (info == 0) call desc_a%indxmap%g2l_ins(ja(1:nz),jla(1:nz),info,& + & mask=(ila(1:nz)>0)) + !$omp end critical(spins) + !write(0,*) me,' after g2l_ins ',psb_errstatus_fatal(),info + if (info /= psb_success_) then + call psb_errpush(psb_err_from_subroutine_ai_,name,& + & a_err='psb_cdins',i_err=(/info/)) + goto 9998 + end if + nrow = desc_a%get_local_rows() + ncol = desc_a%get_local_cols() + !write(0,*) me,' Before csput',psb_errstatus_fatal() + if (a%is_bld()) then + call a%csput(nz,ila,jla,val,ione,nrow,ione,ncol,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='a%csput') + goto 9998 + end if + + if (a%is_remote_build()) then + nnl = count(ila(1:nz)<0) + if (nnl > 0) then + allocate(lila(nnl),ljla(nnl),lval(nnl)) + k = 0 + do i=1,nz + if (ila(i)<0) then + k=k+1 + lila(k) = ia(i) + ljla(k) = ja(i) + lval(k) = val(i) + end if + end do + if (k /= nnl) write(0,*) name,' Wrong conversion?',k,nnl + call a%rmta%csput(nnl,lila,ljla,lval,1_psb_lpk_,desc_a%get_global_rows(),& + & 1_psb_lpk_,desc_a%get_global_rows(),info) + end if + end if + + else + info = psb_err_invalid_a_and_cd_state_ + call psb_errpush(info,name) + end if +9998 continue + !write(0,*) me,' after csput',psb_errstatus_fatal() + !$omp end parallel + else + call desc_a%indxmap%g2l(ia(1:nz),ila(1:nz),info,owned=.true.) + !write(0,*) me,' Before g2l_ins ',psb_errstatus_fatal() + if (info == 0) call desc_a%indxmap%g2l_ins(ja(1:nz),jla(1:nz),info,& + & mask=(ila(1:nz)>0)) + !write(0,*) me,' after g2l_ins ',psb_errstatus_fatal(),info + if (info /= psb_success_) then + call psb_errpush(psb_err_from_subroutine_ai_,name,& + & a_err='psb_cdins',i_err=(/info/)) + goto 9999 + end if + nrow = desc_a%get_local_rows() + ncol = desc_a%get_local_cols() + !write(0,*) me,' Before csput',psb_errstatus_fatal() + if (a%is_bld()) then + call a%csput(nz,ila,jla,val,ione,nrow,ione,ncol,info) + if (info /= psb_success_) then + info=psb_err_from_subroutine_ + call psb_errpush(info,name,a_err='a%csput') + goto 9999 + end if + + if (a%is_remote_build()) then + nnl = count(ila(1:nz)<0) + if (nnl > 0) then + allocate(lila(nnl),ljla(nnl),lval(nnl)) + k = 0 + do i=1,nz + if (ila(i)<0) then + k=k+1 + lila(k) = ia(i) + ljla(k) = ja(i) + lval(k) = val(i) + end if + end do + if (k /= nnl) write(0,*) name,' Wrong conversion?',k,nnl + call a%rmta%csput(nnl,lila,ljla,lval,1_psb_lpk_,desc_a%get_global_rows(),& + & 1_psb_lpk_,desc_a%get_global_rows(),info) + end if + end if + + else + info = psb_err_invalid_a_and_cd_state_ + call psb_errpush(info,name) + goto 9999 + end if + end if + end block +#else + + !write(0,*) me,' Before g2l ',psb_errstatus_fatal() call desc_a%indxmap%g2l(ia(1:nz),ila(1:nz),info,owned=.true.) + if (info == 0) call desc_a%indxmap%g2l_ins(ja(1:nz),jla(1:nz),info,& & mask=(ila(1:nz)>0)) - + + !write(0,*) me,' after g2l_ins ',psb_errstatus_fatal(),info if (info /= psb_success_) then call psb_errpush(psb_err_from_subroutine_ai_,name,& & a_err='psb_cdins',i_err=(/info/)) - !goto 9999 + goto 9999 end if nrow = desc_a%get_local_rows() ncol = desc_a%get_local_cols() - + !write(0,*) me,' Before csput',psb_errstatus_fatal() if (a%is_bld()) then call a%csput(nz,ila,jla,val,ione,nrow,ione,ncol,info) if (info /= psb_success_) then info=psb_err_from_subroutine_ call psb_errpush(info,name,a_err='a%csput') - !goto 9999 + goto 9999 end if - + if (a%is_remote_build()) then nnl = count(ila(1:nz)<0) if (nnl > 0) then @@ -179,11 +283,8 @@ subroutine psb_zspins(nz,ia,ja,val,a,desc_a,info,rebuild,local) else info = psb_err_invalid_a_and_cd_state_ call psb_errpush(info,name) - !goto 9999 + goto 9999 end if - -#if defined(OPENMP) - !$omp end parallel #endif if (info /= 0) goto 9999 endif diff --git a/test/pargen/psb_d_pde3d.F90 b/test/pargen/psb_d_pde3d.F90 index eebc5ad8..6e895c00 100644 --- a/test/pargen/psb_d_pde3d.F90 +++ b/test/pargen/psb_d_pde3d.F90 @@ -737,7 +737,6 @@ program psb_d_pde3d ! ! allocate and fill in the coefficient matrix, rhs and initial guess ! - call psb_cd_set_large_threshold(100_psb_lpk_) call psb_barrier(ctxt) t1 = psb_wtime() call psb_gen_pde3d(ctxt,idim,a,bv,xxv,desc_a,afmt,info,partition=ipart)