From fa5e7ff9455730ec6e5a07ad5fa6a9018eeac6a3 Mon Sep 17 00:00:00 2001 From: sfilippone Date: Tue, 20 Aug 2024 19:38:34 +0200 Subject: [PATCH] Fixes for vector methods and sync() --- openacc/impl/psb_c_oacc_mlt_v.f90 | 29 ++- openacc/impl/psb_c_oacc_mlt_v_2.f90 | 98 +++++----- openacc/impl/psb_d_oacc_mlt_v.f90 | 29 ++- openacc/impl/psb_d_oacc_mlt_v_2.f90 | 98 +++++----- openacc/impl/psb_s_oacc_mlt_v.f90 | 29 ++- openacc/impl/psb_s_oacc_mlt_v_2.f90 | 98 +++++----- openacc/impl/psb_z_oacc_mlt_v.f90 | 29 ++- openacc/impl/psb_z_oacc_mlt_v_2.f90 | 98 +++++----- openacc/psb_c_oacc_vect_mod.F90 | 279 ++++++++++++++++++---------- openacc/psb_d_oacc_vect_mod.F90 | 279 ++++++++++++++++++---------- openacc/psb_i_oacc_vect_mod.F90 | 54 +++--- openacc/psb_l_oacc_vect_mod.F90 | 54 +++--- openacc/psb_s_oacc_vect_mod.F90 | 279 ++++++++++++++++++---------- openacc/psb_z_oacc_vect_mod.F90 | 279 ++++++++++++++++++---------- 14 files changed, 1082 insertions(+), 650 deletions(-) diff --git a/openacc/impl/psb_c_oacc_mlt_v.f90 b/openacc/impl/psb_c_oacc_mlt_v.f90 index 66c4e865..a366543a 100644 --- a/openacc/impl/psb_c_oacc_mlt_v.f90 +++ b/openacc/impl/psb_c_oacc_mlt_v.f90 @@ -1,6 +1,6 @@ -subroutine c_oacc_mlt_v(x, y, info) - use psb_c_oacc_vect_mod, psb_protect_name => c_oacc_mlt_v +subroutine psb_c_oacc_mlt_v(x, y, info) + use psb_c_oacc_vect_mod, psb_protect_name => psb_c_oacc_mlt_v implicit none class(psb_c_base_vect_type), intent(inout) :: x @@ -9,16 +9,19 @@ subroutine c_oacc_mlt_v(x, y, info) integer(psb_ipk_) :: i, n + info = 0 + n = min(x%get_nrows(), y%get_nrows()) info = 0 n = min(x%get_nrows(), y%get_nrows()) select type(xx => x) class is (psb_c_vect_oacc) if (y%is_host()) call y%sync() if (xx%is_host()) call xx%sync() - !$acc parallel loop - do i = 1, n - y%v(i) = y%v(i) * xx%v(i) - end do + call c_inner_oacc_mlt_v(n,xx%v, y%v) +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ y%v(i) = y%v(i) * xx%v(i) +!!$ end do call y%set_dev() class default if (xx%is_dev()) call xx%sync() @@ -28,4 +31,16 @@ subroutine c_oacc_mlt_v(x, y, info) end do call y%set_host() end select -end subroutine c_oacc_mlt_v +contains + subroutine c_inner_oacc_mlt_v(n,x, y) + implicit none + integer(psb_ipk_), intent(in) :: n + complex(psb_spk_), intent(inout) :: x(:), y(:) + + integer(psb_ipk_) :: i + !$acc parallel loop present(x,y) + do i = 1, n + y(i) = (x(i)) * (y(i)) + end do + end subroutine c_inner_oacc_mlt_v +end subroutine psb_c_oacc_mlt_v diff --git a/openacc/impl/psb_c_oacc_mlt_v_2.f90 b/openacc/impl/psb_c_oacc_mlt_v_2.f90 index a6bb6cc5..f7bceae7 100644 --- a/openacc/impl/psb_c_oacc_mlt_v_2.f90 +++ b/openacc/impl/psb_c_oacc_mlt_v_2.f90 @@ -1,5 +1,5 @@ -subroutine c_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) - use psb_c_oacc_vect_mod, psb_protect_name => c_oacc_mlt_v_2 +subroutine psb_c_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) + use psb_c_oacc_vect_mod, psb_protect_name => psb_c_oacc_mlt_v_2 use psb_string_mod implicit none complex(psb_spk_), intent(in) :: alpha, beta @@ -25,33 +25,13 @@ subroutine c_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) if (xx%is_host()) call xx%sync() if (yy%is_host()) call yy%sync() if ((beta /= czero) .and. (z%is_host())) call z%sync() - if (conjgx_.and.conjgy_) then - !$acc parallel loop - do i = 1, n - z%v(i) = alpha * conjg(xx%v(i)) * conjg(yy%v(i)) + beta * z%v(i) - end do - else if (conjgx_.and.(.not.conjgy_)) then - !$acc parallel loop - do i = 1, n - z%v(i) = alpha * conjg(xx%v(i)) * (yy%v(i)) + beta * z%v(i) - end do - else if ((.not.conjgx_).and.(conjgy_)) then - !$acc parallel loop - do i = 1, n - z%v(i) = alpha * (xx%v(i)) * conjg(yy%v(i)) + beta * z%v(i) - end do - else - !$acc parallel loop - do i = 1, n - z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) - end do - - end if + call c_inner_oacc_mlt_v_2(n,alpha, xx%v, yy%v, beta, z%v, info, conjgx_, conjgy_) call z%set_dev() class default if (xx%is_dev()) call xx%sync() if (yy%is_dev()) call yy%sync() if ((beta /= czero) .and. (z%is_dev())) call z%sync() + !call c_inner_oacc_mlt_v_2(n,alpha, xx%v, yy%v, beta, z%v, info, conjgx_, conjgy_) if (conjgx_.and.conjgy_) then do i = 1, n z%v(i) = alpha * conjg(xx%v(i)) * conjg(yy%v(i)) + beta * z%v(i) @@ -67,7 +47,7 @@ subroutine c_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) else do i = 1, n z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) - end do + end do end if call z%set_host() end select @@ -75,24 +55,56 @@ subroutine c_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) if (x%is_dev()) call x%sync() if (y%is_dev()) call y%sync() if ((beta /= czero) .and. (z%is_dev())) call z%sync() - if (conjgx_.and.conjgy_) then - do i = 1, n - z%v(i) = alpha * conjg(x%v(i)) * conjg(y%v(i)) + beta * z%v(i) - end do - else if (conjgx_.and.(.not.conjgy_)) then - do i = 1, n - z%v(i) = alpha * conjg(x%v(i)) * (y%v(i)) + beta * z%v(i) - end do - else if ((.not.conjgx_).and.(conjgy_)) then - do i = 1, n - z%v(i) = alpha * (x%v(i)) * conjg(y%v(i)) + beta * z%v(i) - end do - else - do i = 1, n - z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) - end do - end if + if (conjgx_.and.conjgy_) then + do i = 1, n + z%v(i) = alpha * conjg(x%v(i)) * conjg(y%v(i)) + beta * z%v(i) + end do + else if (conjgx_.and.(.not.conjgy_)) then + do i = 1, n + z%v(i) = alpha * conjg(x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + else if ((.not.conjgx_).and.(conjgy_)) then + do i = 1, n + z%v(i) = alpha * (x%v(i)) * conjg(y%v(i)) + beta * z%v(i) + end do + else + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + end if call z%set_host() end select -end subroutine c_oacc_mlt_v_2 +contains + subroutine c_inner_oacc_mlt_v_2(n,alpha, x, y, beta, z, info, conjgx, conjgy) + implicit none + integer(psb_ipk_), intent(in) :: n +complex(psb_spk_), intent(in) :: alpha, beta +complex(psb_spk_), intent(inout) :: x(:), y(:), z(:) + integer(psb_ipk_), intent(out) :: info + logical, intent(in) :: conjgx, conjgy + + integer(psb_ipk_) :: i + if (conjgx.and.conjgy) then + !$acc parallel loop present(x,y,z) + do i = 1, n + z(i) = alpha * conjg(x(i)) * conjg(y(i)) + beta * z(i) + end do + else if (conjgx.and.(.not.conjgy)) then + !$acc parallel loop present(x,y,z) + do i = 1, n + z(i) = alpha * conjg(x(i)) * (y(i)) + beta * z(i) + end do + else if ((.not.conjgx).and.(conjgy)) then + !$acc parallel loop present(x,y,z) + do i = 1, n + z(i) = alpha * (x(i)) * conjg(y(i)) + beta * z(i) + end do + else + !$acc parallel loop present(x,y,z) + do i = 1, n + z(i) = alpha * (x(i)) * (y(i)) + beta * z(i) + end do + end if + end subroutine c_inner_oacc_mlt_v_2 +end subroutine psb_c_oacc_mlt_v_2 diff --git a/openacc/impl/psb_d_oacc_mlt_v.f90 b/openacc/impl/psb_d_oacc_mlt_v.f90 index bedd0247..dac62a65 100644 --- a/openacc/impl/psb_d_oacc_mlt_v.f90 +++ b/openacc/impl/psb_d_oacc_mlt_v.f90 @@ -1,6 +1,6 @@ -subroutine d_oacc_mlt_v(x, y, info) - use psb_d_oacc_vect_mod, psb_protect_name => d_oacc_mlt_v +subroutine psb_d_oacc_mlt_v(x, y, info) + use psb_d_oacc_vect_mod, psb_protect_name => psb_d_oacc_mlt_v implicit none class(psb_d_base_vect_type), intent(inout) :: x @@ -9,16 +9,19 @@ subroutine d_oacc_mlt_v(x, y, info) integer(psb_ipk_) :: i, n + info = 0 + n = min(x%get_nrows(), y%get_nrows()) info = 0 n = min(x%get_nrows(), y%get_nrows()) select type(xx => x) class is (psb_d_vect_oacc) if (y%is_host()) call y%sync() if (xx%is_host()) call xx%sync() - !$acc parallel loop - do i = 1, n - y%v(i) = y%v(i) * xx%v(i) - end do + call d_inner_oacc_mlt_v(n,xx%v, y%v) +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ y%v(i) = y%v(i) * xx%v(i) +!!$ end do call y%set_dev() class default if (xx%is_dev()) call xx%sync() @@ -28,4 +31,16 @@ subroutine d_oacc_mlt_v(x, y, info) end do call y%set_host() end select -end subroutine d_oacc_mlt_v +contains + subroutine d_inner_oacc_mlt_v(n,x, y) + implicit none + integer(psb_ipk_), intent(in) :: n + real(psb_dpk_), intent(inout) :: x(:), y(:) + + integer(psb_ipk_) :: i + !$acc parallel loop present(x,y) + do i = 1, n + y(i) = (x(i)) * (y(i)) + end do + end subroutine d_inner_oacc_mlt_v +end subroutine psb_d_oacc_mlt_v diff --git a/openacc/impl/psb_d_oacc_mlt_v_2.f90 b/openacc/impl/psb_d_oacc_mlt_v_2.f90 index e7dd604f..3f3a457d 100644 --- a/openacc/impl/psb_d_oacc_mlt_v_2.f90 +++ b/openacc/impl/psb_d_oacc_mlt_v_2.f90 @@ -1,5 +1,5 @@ -subroutine d_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) - use psb_d_oacc_vect_mod, psb_protect_name => d_oacc_mlt_v_2 +subroutine psb_d_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) + use psb_d_oacc_vect_mod, psb_protect_name => psb_d_oacc_mlt_v_2 use psb_string_mod implicit none real(psb_dpk_), intent(in) :: alpha, beta @@ -25,33 +25,13 @@ subroutine d_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) if (xx%is_host()) call xx%sync() if (yy%is_host()) call yy%sync() if ((beta /= dzero) .and. (z%is_host())) call z%sync() - if (conjgx_.and.conjgy_) then - !$acc parallel loop - do i = 1, n - z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) - end do - else if (conjgx_.and.(.not.conjgy_)) then - !$acc parallel loop - do i = 1, n - z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) - end do - else if ((.not.conjgx_).and.(conjgy_)) then - !$acc parallel loop - do i = 1, n - z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) - end do - else - !$acc parallel loop - do i = 1, n - z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) - end do - - end if + call d_inner_oacc_mlt_v_2(n,alpha, xx%v, yy%v, beta, z%v, info, conjgx_, conjgy_) call z%set_dev() class default if (xx%is_dev()) call xx%sync() if (yy%is_dev()) call yy%sync() if ((beta /= dzero) .and. (z%is_dev())) call z%sync() + !call d_inner_oacc_mlt_v_2(n,alpha, xx%v, yy%v, beta, z%v, info, conjgx_, conjgy_) if (conjgx_.and.conjgy_) then do i = 1, n z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) @@ -67,7 +47,7 @@ subroutine d_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) else do i = 1, n z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) - end do + end do end if call z%set_host() end select @@ -75,24 +55,56 @@ subroutine d_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) if (x%is_dev()) call x%sync() if (y%is_dev()) call y%sync() if ((beta /= dzero) .and. (z%is_dev())) call z%sync() - if (conjgx_.and.conjgy_) then - do i = 1, n - z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) - end do - else if (conjgx_.and.(.not.conjgy_)) then - do i = 1, n - z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) - end do - else if ((.not.conjgx_).and.(conjgy_)) then - do i = 1, n - z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) - end do - else - do i = 1, n - z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) - end do - end if + if (conjgx_.and.conjgy_) then + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + else if (conjgx_.and.(.not.conjgy_)) then + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + else if ((.not.conjgx_).and.(conjgy_)) then + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + else + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + end if call z%set_host() end select -end subroutine d_oacc_mlt_v_2 +contains + subroutine d_inner_oacc_mlt_v_2(n,alpha, x, y, beta, z, info, conjgx, conjgy) + implicit none + integer(psb_ipk_), intent(in) :: n +real(psb_dpk_), intent(in) :: alpha, beta +real(psb_dpk_), intent(inout) :: x(:), y(:), z(:) + integer(psb_ipk_), intent(out) :: info + logical, intent(in) :: conjgx, conjgy + + integer(psb_ipk_) :: i + if (conjgx.and.conjgy) then + !$acc parallel loop present(x,y,z) + do i = 1, n + z(i) = alpha * (x(i)) * (y(i)) + beta * z(i) + end do + else if (conjgx.and.(.not.conjgy)) then + !$acc parallel loop present(x,y,z) + do i = 1, n + z(i) = alpha * (x(i)) * (y(i)) + beta * z(i) + end do + else if ((.not.conjgx).and.(conjgy)) then + !$acc parallel loop present(x,y,z) + do i = 1, n + z(i) = alpha * (x(i)) * (y(i)) + beta * z(i) + end do + else + !$acc parallel loop present(x,y,z) + do i = 1, n + z(i) = alpha * (x(i)) * (y(i)) + beta * z(i) + end do + end if + end subroutine d_inner_oacc_mlt_v_2 +end subroutine psb_d_oacc_mlt_v_2 diff --git a/openacc/impl/psb_s_oacc_mlt_v.f90 b/openacc/impl/psb_s_oacc_mlt_v.f90 index fb043cf2..61a1d152 100644 --- a/openacc/impl/psb_s_oacc_mlt_v.f90 +++ b/openacc/impl/psb_s_oacc_mlt_v.f90 @@ -1,6 +1,6 @@ -subroutine s_oacc_mlt_v(x, y, info) - use psb_s_oacc_vect_mod, psb_protect_name => s_oacc_mlt_v +subroutine psb_s_oacc_mlt_v(x, y, info) + use psb_s_oacc_vect_mod, psb_protect_name => psb_s_oacc_mlt_v implicit none class(psb_s_base_vect_type), intent(inout) :: x @@ -9,16 +9,19 @@ subroutine s_oacc_mlt_v(x, y, info) integer(psb_ipk_) :: i, n + info = 0 + n = min(x%get_nrows(), y%get_nrows()) info = 0 n = min(x%get_nrows(), y%get_nrows()) select type(xx => x) class is (psb_s_vect_oacc) if (y%is_host()) call y%sync() if (xx%is_host()) call xx%sync() - !$acc parallel loop - do i = 1, n - y%v(i) = y%v(i) * xx%v(i) - end do + call s_inner_oacc_mlt_v(n,xx%v, y%v) +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ y%v(i) = y%v(i) * xx%v(i) +!!$ end do call y%set_dev() class default if (xx%is_dev()) call xx%sync() @@ -28,4 +31,16 @@ subroutine s_oacc_mlt_v(x, y, info) end do call y%set_host() end select -end subroutine s_oacc_mlt_v +contains + subroutine s_inner_oacc_mlt_v(n,x, y) + implicit none + integer(psb_ipk_), intent(in) :: n + real(psb_spk_), intent(inout) :: x(:), y(:) + + integer(psb_ipk_) :: i + !$acc parallel loop present(x,y) + do i = 1, n + y(i) = (x(i)) * (y(i)) + end do + end subroutine s_inner_oacc_mlt_v +end subroutine psb_s_oacc_mlt_v diff --git a/openacc/impl/psb_s_oacc_mlt_v_2.f90 b/openacc/impl/psb_s_oacc_mlt_v_2.f90 index 04ee8e09..bcaebfbe 100644 --- a/openacc/impl/psb_s_oacc_mlt_v_2.f90 +++ b/openacc/impl/psb_s_oacc_mlt_v_2.f90 @@ -1,5 +1,5 @@ -subroutine s_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) - use psb_s_oacc_vect_mod, psb_protect_name => s_oacc_mlt_v_2 +subroutine psb_s_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) + use psb_s_oacc_vect_mod, psb_protect_name => psb_s_oacc_mlt_v_2 use psb_string_mod implicit none real(psb_spk_), intent(in) :: alpha, beta @@ -25,33 +25,13 @@ subroutine s_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) if (xx%is_host()) call xx%sync() if (yy%is_host()) call yy%sync() if ((beta /= szero) .and. (z%is_host())) call z%sync() - if (conjgx_.and.conjgy_) then - !$acc parallel loop - do i = 1, n - z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) - end do - else if (conjgx_.and.(.not.conjgy_)) then - !$acc parallel loop - do i = 1, n - z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) - end do - else if ((.not.conjgx_).and.(conjgy_)) then - !$acc parallel loop - do i = 1, n - z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) - end do - else - !$acc parallel loop - do i = 1, n - z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) - end do - - end if + call s_inner_oacc_mlt_v_2(n,alpha, xx%v, yy%v, beta, z%v, info, conjgx_, conjgy_) call z%set_dev() class default if (xx%is_dev()) call xx%sync() if (yy%is_dev()) call yy%sync() if ((beta /= szero) .and. (z%is_dev())) call z%sync() + !call s_inner_oacc_mlt_v_2(n,alpha, xx%v, yy%v, beta, z%v, info, conjgx_, conjgy_) if (conjgx_.and.conjgy_) then do i = 1, n z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) @@ -67,7 +47,7 @@ subroutine s_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) else do i = 1, n z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) - end do + end do end if call z%set_host() end select @@ -75,24 +55,56 @@ subroutine s_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) if (x%is_dev()) call x%sync() if (y%is_dev()) call y%sync() if ((beta /= szero) .and. (z%is_dev())) call z%sync() - if (conjgx_.and.conjgy_) then - do i = 1, n - z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) - end do - else if (conjgx_.and.(.not.conjgy_)) then - do i = 1, n - z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) - end do - else if ((.not.conjgx_).and.(conjgy_)) then - do i = 1, n - z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) - end do - else - do i = 1, n - z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) - end do - end if + if (conjgx_.and.conjgy_) then + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + else if (conjgx_.and.(.not.conjgy_)) then + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + else if ((.not.conjgx_).and.(conjgy_)) then + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + else + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + end if call z%set_host() end select -end subroutine s_oacc_mlt_v_2 +contains + subroutine s_inner_oacc_mlt_v_2(n,alpha, x, y, beta, z, info, conjgx, conjgy) + implicit none + integer(psb_ipk_), intent(in) :: n +real(psb_spk_), intent(in) :: alpha, beta +real(psb_spk_), intent(inout) :: x(:), y(:), z(:) + integer(psb_ipk_), intent(out) :: info + logical, intent(in) :: conjgx, conjgy + + integer(psb_ipk_) :: i + if (conjgx.and.conjgy) then + !$acc parallel loop present(x,y,z) + do i = 1, n + z(i) = alpha * (x(i)) * (y(i)) + beta * z(i) + end do + else if (conjgx.and.(.not.conjgy)) then + !$acc parallel loop present(x,y,z) + do i = 1, n + z(i) = alpha * (x(i)) * (y(i)) + beta * z(i) + end do + else if ((.not.conjgx).and.(conjgy)) then + !$acc parallel loop present(x,y,z) + do i = 1, n + z(i) = alpha * (x(i)) * (y(i)) + beta * z(i) + end do + else + !$acc parallel loop present(x,y,z) + do i = 1, n + z(i) = alpha * (x(i)) * (y(i)) + beta * z(i) + end do + end if + end subroutine s_inner_oacc_mlt_v_2 +end subroutine psb_s_oacc_mlt_v_2 diff --git a/openacc/impl/psb_z_oacc_mlt_v.f90 b/openacc/impl/psb_z_oacc_mlt_v.f90 index 7018f009..4bc582d2 100644 --- a/openacc/impl/psb_z_oacc_mlt_v.f90 +++ b/openacc/impl/psb_z_oacc_mlt_v.f90 @@ -1,6 +1,6 @@ -subroutine z_oacc_mlt_v(x, y, info) - use psb_z_oacc_vect_mod, psb_protect_name => z_oacc_mlt_v +subroutine psb_z_oacc_mlt_v(x, y, info) + use psb_z_oacc_vect_mod, psb_protect_name => psb_z_oacc_mlt_v implicit none class(psb_z_base_vect_type), intent(inout) :: x @@ -9,16 +9,19 @@ subroutine z_oacc_mlt_v(x, y, info) integer(psb_ipk_) :: i, n + info = 0 + n = min(x%get_nrows(), y%get_nrows()) info = 0 n = min(x%get_nrows(), y%get_nrows()) select type(xx => x) class is (psb_z_vect_oacc) if (y%is_host()) call y%sync() if (xx%is_host()) call xx%sync() - !$acc parallel loop - do i = 1, n - y%v(i) = y%v(i) * xx%v(i) - end do + call z_inner_oacc_mlt_v(n,xx%v, y%v) +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ y%v(i) = y%v(i) * xx%v(i) +!!$ end do call y%set_dev() class default if (xx%is_dev()) call xx%sync() @@ -28,4 +31,16 @@ subroutine z_oacc_mlt_v(x, y, info) end do call y%set_host() end select -end subroutine z_oacc_mlt_v +contains + subroutine z_inner_oacc_mlt_v(n,x, y) + implicit none + integer(psb_ipk_), intent(in) :: n + complex(psb_dpk_), intent(inout) :: x(:), y(:) + + integer(psb_ipk_) :: i + !$acc parallel loop present(x,y) + do i = 1, n + y(i) = (x(i)) * (y(i)) + end do + end subroutine z_inner_oacc_mlt_v +end subroutine psb_z_oacc_mlt_v diff --git a/openacc/impl/psb_z_oacc_mlt_v_2.f90 b/openacc/impl/psb_z_oacc_mlt_v_2.f90 index dbc0929c..337a0a96 100644 --- a/openacc/impl/psb_z_oacc_mlt_v_2.f90 +++ b/openacc/impl/psb_z_oacc_mlt_v_2.f90 @@ -1,5 +1,5 @@ -subroutine z_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) - use psb_z_oacc_vect_mod, psb_protect_name => z_oacc_mlt_v_2 +subroutine psb_z_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) + use psb_z_oacc_vect_mod, psb_protect_name => psb_z_oacc_mlt_v_2 use psb_string_mod implicit none complex(psb_dpk_), intent(in) :: alpha, beta @@ -25,33 +25,13 @@ subroutine z_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) if (xx%is_host()) call xx%sync() if (yy%is_host()) call yy%sync() if ((beta /= zzero) .and. (z%is_host())) call z%sync() - if (conjgx_.and.conjgy_) then - !$acc parallel loop - do i = 1, n - z%v(i) = alpha * conjg(xx%v(i)) * conjg(yy%v(i)) + beta * z%v(i) - end do - else if (conjgx_.and.(.not.conjgy_)) then - !$acc parallel loop - do i = 1, n - z%v(i) = alpha * conjg(xx%v(i)) * (yy%v(i)) + beta * z%v(i) - end do - else if ((.not.conjgx_).and.(conjgy_)) then - !$acc parallel loop - do i = 1, n - z%v(i) = alpha * (xx%v(i)) * conjg(yy%v(i)) + beta * z%v(i) - end do - else - !$acc parallel loop - do i = 1, n - z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) - end do - - end if + call z_inner_oacc_mlt_v_2(n,alpha, xx%v, yy%v, beta, z%v, info, conjgx_, conjgy_) call z%set_dev() class default if (xx%is_dev()) call xx%sync() if (yy%is_dev()) call yy%sync() if ((beta /= zzero) .and. (z%is_dev())) call z%sync() + !call z_inner_oacc_mlt_v_2(n,alpha, xx%v, yy%v, beta, z%v, info, conjgx_, conjgy_) if (conjgx_.and.conjgy_) then do i = 1, n z%v(i) = alpha * conjg(xx%v(i)) * conjg(yy%v(i)) + beta * z%v(i) @@ -67,7 +47,7 @@ subroutine z_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) else do i = 1, n z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) - end do + end do end if call z%set_host() end select @@ -75,24 +55,56 @@ subroutine z_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) if (x%is_dev()) call x%sync() if (y%is_dev()) call y%sync() if ((beta /= zzero) .and. (z%is_dev())) call z%sync() - if (conjgx_.and.conjgy_) then - do i = 1, n - z%v(i) = alpha * conjg(x%v(i)) * conjg(y%v(i)) + beta * z%v(i) - end do - else if (conjgx_.and.(.not.conjgy_)) then - do i = 1, n - z%v(i) = alpha * conjg(x%v(i)) * (y%v(i)) + beta * z%v(i) - end do - else if ((.not.conjgx_).and.(conjgy_)) then - do i = 1, n - z%v(i) = alpha * (x%v(i)) * conjg(y%v(i)) + beta * z%v(i) - end do - else - do i = 1, n - z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) - end do - end if + if (conjgx_.and.conjgy_) then + do i = 1, n + z%v(i) = alpha * conjg(x%v(i)) * conjg(y%v(i)) + beta * z%v(i) + end do + else if (conjgx_.and.(.not.conjgy_)) then + do i = 1, n + z%v(i) = alpha * conjg(x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + else if ((.not.conjgx_).and.(conjgy_)) then + do i = 1, n + z%v(i) = alpha * (x%v(i)) * conjg(y%v(i)) + beta * z%v(i) + end do + else + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + end if call z%set_host() end select -end subroutine z_oacc_mlt_v_2 +contains + subroutine z_inner_oacc_mlt_v_2(n,alpha, x, y, beta, z, info, conjgx, conjgy) + implicit none + integer(psb_ipk_), intent(in) :: n +complex(psb_dpk_), intent(in) :: alpha, beta +complex(psb_dpk_), intent(inout) :: x(:), y(:), z(:) + integer(psb_ipk_), intent(out) :: info + logical, intent(in) :: conjgx, conjgy + + integer(psb_ipk_) :: i + if (conjgx.and.conjgy) then + !$acc parallel loop present(x,y,z) + do i = 1, n + z(i) = alpha * conjg(x(i)) * conjg(y(i)) + beta * z(i) + end do + else if (conjgx.and.(.not.conjgy)) then + !$acc parallel loop present(x,y,z) + do i = 1, n + z(i) = alpha * conjg(x(i)) * (y(i)) + beta * z(i) + end do + else if ((.not.conjgx).and.(conjgy)) then + !$acc parallel loop present(x,y,z) + do i = 1, n + z(i) = alpha * (x(i)) * conjg(y(i)) + beta * z(i) + end do + else + !$acc parallel loop present(x,y,z) + do i = 1, n + z(i) = alpha * (x(i)) * (y(i)) + beta * z(i) + end do + end if + end subroutine z_inner_oacc_mlt_v_2 +end subroutine psb_z_oacc_mlt_v_2 diff --git a/openacc/psb_c_oacc_vect_mod.F90 b/openacc/psb_c_oacc_vect_mod.F90 index fc501e04..7362ba0e 100644 --- a/openacc/psb_c_oacc_vect_mod.F90 +++ b/openacc/psb_c_oacc_vect_mod.F90 @@ -1,5 +1,6 @@ module psb_c_oacc_vect_mod use iso_c_binding + use openacc use psb_const_mod use psb_error_mod use psb_c_vect_mod @@ -50,8 +51,8 @@ module psb_c_oacc_vect_mod procedure, pass(z) :: upd_xyz => c_oacc_upd_xyz procedure, pass(y) :: mlt_a => c_oacc_mlt_a procedure, pass(z) :: mlt_a_2 => c_oacc_mlt_a_2 - procedure, pass(y) :: mlt_v => c_oacc_mlt_v - procedure, pass(z) :: mlt_v_2 => c_oacc_mlt_v_2 + procedure, pass(y) :: mlt_v => psb_c_oacc_mlt_v + procedure, pass(z) :: mlt_v_2 => psb_c_oacc_mlt_v_2 procedure, pass(x) :: scal => c_oacc_scal procedure, pass(x) :: nrm2 => c_oacc_nrm2 procedure, pass(x) :: amax => c_oacc_amax @@ -62,17 +63,17 @@ module psb_c_oacc_vect_mod end type psb_c_vect_oacc interface - subroutine c_oacc_mlt_v(x, y, info) + subroutine psb_c_oacc_mlt_v(x, y, info) import implicit none class(psb_c_base_vect_type), intent(inout) :: x class(psb_c_vect_oacc), intent(inout) :: y integer(psb_ipk_), intent(out) :: info - end subroutine c_oacc_mlt_v + end subroutine psb_c_oacc_mlt_v end interface interface - subroutine c_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) + subroutine psb_c_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) import implicit none complex(psb_spk_), intent(in) :: alpha, beta @@ -81,7 +82,7 @@ module psb_c_oacc_vect_mod class(psb_c_vect_oacc), intent(inout) :: z integer(psb_ipk_), intent(out) :: info character(len=1), intent(in), optional :: conjgx, conjgy - end subroutine c_oacc_mlt_v_2 + end subroutine psb_c_oacc_mlt_v_2 end interface contains @@ -89,15 +90,23 @@ contains subroutine c_oacc_absval1(x) implicit none class(psb_c_vect_oacc), intent(inout) :: x - integer(psb_ipk_) :: n, i + integer(psb_ipk_) :: n - if (x%is_host()) call x%sync_space() + if (x%is_host()) call x%sync() n = size(x%v) - !$acc parallel loop - do i = 1, n - x%v(i) = abs(x%v(i)) - end do + call c_inner_oacc_absval1(n,x%v) call x%set_dev() + contains + subroutine c_inner_oacc_absval1(n,x) + implicit none + complex(psb_spk_), intent(inout) :: x(:) + integer(psb_ipk_) :: n + integer(psb_ipk_) :: i + !$acc parallel loop + do i = 1, n + x(i) = abs(x(i)) + end do + end subroutine c_inner_oacc_absval1 end subroutine c_oacc_absval1 subroutine c_oacc_absval2(x, y) @@ -112,15 +121,23 @@ contains class is (psb_c_vect_oacc) if (x%is_host()) call x%sync() if (yy%is_host()) call yy%sync() - !$acc parallel loop - do i = 1, n - yy%v(i) = abs(x%v(i)) - end do + call c_inner_oacc_absval2(n,x%v,yy%v) class default if (x%is_dev()) call x%sync() if (y%is_dev()) call y%sync() call x%psb_c_base_vect_type%absval(y) end select + contains + subroutine c_inner_oacc_absval2(n,x,y) + implicit none + complex(psb_spk_), intent(inout) :: x(:),y(:) + integer(psb_ipk_) :: n + integer(psb_ipk_) :: i + !$acc parallel loop + do i = 1, n + y(i) = abs(x(i)) + end do + end subroutine c_inner_oacc_absval2 end subroutine c_oacc_absval2 subroutine c_oacc_scal(alpha, x) @@ -128,32 +145,46 @@ contains class(psb_c_vect_oacc), intent(inout) :: x complex(psb_spk_), intent(in) :: alpha integer(psb_ipk_) :: info - integer(psb_ipk_) :: i - - if (x%is_host()) call x%sync_space() - !$acc parallel loop - do i = 1, size(x%v) - x%v(i) = alpha * x%v(i) - end do + if (x%is_host()) call x%sync() + call c_inner_oacc_scal(alpha, x%v) call x%set_dev() + contains + subroutine c_inner_oacc_scal(alpha, x) + complex(psb_spk_), intent(in) :: alpha + complex(psb_spk_), intent(inout) :: x(:) + integer(psb_ipk_) :: i + !$acc parallel loop + do i = 1, size(x) + x(i) = alpha * x(i) + end do + end subroutine c_inner_oacc_scal end subroutine c_oacc_scal function c_oacc_nrm2(n, x) result(res) implicit none class(psb_c_vect_oacc), intent(inout) :: x integer(psb_ipk_), intent(in) :: n - real(psb_spk_) :: res + real(psb_spk_) :: res + real(psb_spk_) :: mx integer(psb_ipk_) :: info - real(psb_spk_) :: sum - integer(psb_ipk_) :: i - if (x%is_host()) call x%sync_space() - sum = 0.0 - !$acc parallel loop reduction(+:sum) - do i = 1, n - sum = sum + abs(x%v(i))**2 - end do - res = sqrt(sum) + if (x%is_host()) call x%sync() + mx = c_oacc_amax(n,x) + res = c_inner_oacc_nrm2(n, mx, x%v) + contains + function c_inner_oacc_nrm2(n, mx,x) result(res) + integer(psb_ipk_) :: n + complex(psb_spk_) :: x(:) + real(psb_spk_) :: mx, res + real(psb_spk_) :: sum + integer(psb_ipk_) :: i + sum = 0.0 + !$acc parallel loop reduction(+:sum) + do i = 1, n + sum = sum + abs(x(i)/mx)**2 + end do + res = mx*sqrt(sum) + end function c_inner_oacc_nrm2 end function c_oacc_nrm2 function c_oacc_amax(n, x) result(res) @@ -162,18 +193,25 @@ contains integer(psb_ipk_), intent(in) :: n real(psb_spk_) :: res integer(psb_ipk_) :: info - real(psb_spk_) :: max_val - integer(psb_ipk_) :: i - if (x%is_host()) call x%sync_space() - max_val = -huge(0.0) - !$acc parallel loop reduction(max:max_val) - do i = 1, n - if (abs(x%v(i)) > max_val) max_val = abs(x%v(i)) - end do - res = max_val + if (x%is_host()) call x%sync() + res = c_inner_oacc_amax(n, x%v) + contains + function c_inner_oacc_amax(n, x) result(res) + integer(psb_ipk_) :: n + complex(psb_spk_) :: x(:) + real(psb_spk_) :: res + real(psb_spk_) :: max_val + integer(psb_ipk_) :: i + max_val = -huge(0.0) + !$acc parallel loop reduction(max:max_val) + do i = 1, n + if (abs(x(i)) > max_val) max_val = abs(x(i)) + end do + res = max_val + end function c_inner_oacc_amax end function c_oacc_amax - + function c_oacc_asum(n, x) result(res) implicit none class(psb_c_vect_oacc), intent(inout) :: x @@ -182,14 +220,20 @@ contains integer(psb_ipk_) :: info complex(psb_spk_) :: sum integer(psb_ipk_) :: i - - if (x%is_host()) call x%sync_space() - sum = 0.0 - !$acc parallel loop reduction(+:sum) - do i = 1, n - sum = sum + abs(x%v(i)) - end do - res = sum + if (x%is_host()) call x%sync() + res = c_inner_oacc_asum(n, x%v) + contains + function c_inner_oacc_asum(n, x) result(res) + integer(psb_ipk_) :: n + complex(psb_spk_) :: x(:) + real(psb_spk_) :: res + integer(psb_ipk_) :: i + res = 0.0 + !$acc parallel loop reduction(+:res) + do i = 1, n + res = res + abs(x(i)) + end do + end function c_inner_oacc_asum end function c_oacc_asum @@ -201,7 +245,7 @@ contains integer(psb_ipk_) :: i, n info = 0 - if (y%is_dev()) call y%sync_space() + if (y%is_dev()) call y%sync() !$acc parallel loop do i = 1, size(x) y%v(i) = y%v(i) * x(i) @@ -219,7 +263,7 @@ contains integer(psb_ipk_) :: i, n info = 0 - if (z%is_dev()) call z%sync_space() + if (z%is_dev()) call z%sync() !$acc parallel loop do i = 1, size(x) z%v(i) = alpha * x(i) * y(i) + beta * z%v(i) @@ -282,18 +326,18 @@ contains !!$ class is (psb_c_vect_oacc) !!$ select type (yy => y) !!$ class is (psb_c_vect_oacc) -!!$ if (xx%is_host()) call xx%sync_space() -!!$ if (yy%is_host()) call yy%sync_space() -!!$ if ((beta /= czero) .and. (z%is_host())) call z%sync_space() +!!$ if (xx%is_host()) call xx%sync() +!!$ if (yy%is_host()) call yy%sync() +!!$ if ((beta /= czero) .and. (z%is_host())) call z%sync() !!$ !$acc parallel loop !!$ do i = 1, n !!$ z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) !!$ end do !!$ call z%set_dev() !!$ class default -!!$ if (xx%is_dev()) call xx%sync_space() +!!$ if (xx%is_dev()) call xx%sync() !!$ if (yy%is_dev()) call yy%sync() -!!$ if ((beta /= czero) .and. (z%is_dev())) call z%sync_space() +!!$ if ((beta /= czero) .and. (z%is_dev())) call z%sync() !!$ !$acc parallel loop !!$ do i = 1, n !!$ z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) @@ -303,7 +347,7 @@ contains !!$ class default !!$ if (x%is_dev()) call x%sync() !!$ if (y%is_dev()) call y%sync() -!!$ if ((beta /= czero) .and. (z%is_dev())) call z%sync_space() +!!$ if ((beta /= czero) .and. (z%is_dev())) call z%sync() !!$ !$acc parallel loop !!$ do i = 1, n !!$ z%v(i) = alpha * x%v(i) * y%v(i) + beta * z%v(i) @@ -327,23 +371,36 @@ contains select type(xx => x) type is (psb_c_vect_oacc) - if ((beta /= czero) .and. y%is_host()) call y%sync_space() - if (xx%is_host()) call xx%sync_space() + if ((beta /= czero) .and. y%is_host()) call y%sync() + if (xx%is_host()) call xx%sync() nx = size(xx%v) ny = size(y%v) if ((nx < m) .or. (ny < m)) then info = psb_err_internal_error_ else - !$acc parallel loop - do i = 1, m - y%v(i) = alpha * xx%v(i) + beta * y%v(i) - end do + call c_inner_oacc_axpby(m, alpha, x%v, beta, y%v, info) end if call y%set_dev() class default if ((alpha /= czero) .and. (x%is_dev())) call x%sync() call y%axpby(m, alpha, x%v, beta, info) - end select + end select + contains + subroutine c_inner_oacc_axpby(m, alpha, x, beta, y, info) + !use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + complex(psb_spk_), intent(inout) :: x(:) + complex(psb_spk_), intent(inout) :: y(:) + complex(psb_spk_), intent(in) :: alpha, beta + integer(psb_ipk_), intent(out) :: info + !$acc parallel + !$acc loop + do i = 1, m + y(i) = alpha * x(i) + beta * y(i) + end do + !$acc end parallel + end subroutine c_inner_oacc_axpby end subroutine c_oacc_axpby_v subroutine c_oacc_axpby_a(m, alpha, x, beta, y, info) @@ -356,7 +413,7 @@ contains integer(psb_ipk_), intent(out) :: info integer(psb_ipk_) :: i - if ((beta /= czero) .and. (y%is_dev())) call y%sync_space() + if ((beta /= czero) .and. (y%is_dev())) call y%sync() !$acc parallel loop do i = 1, m y%v(i) = alpha * x(i) + beta * y%v(i) @@ -375,7 +432,7 @@ contains integer(psb_ipk_), intent(out) :: info integer(psb_ipk_) :: nx, ny, nz, i logical :: gpu_done - + write(0,*)'upd_xyz' info = psb_success_ gpu_done = .false. @@ -385,9 +442,9 @@ contains class is (psb_c_vect_oacc) select type(zz => z) class is (psb_c_vect_oacc) - if ((beta /= czero) .and. yy%is_host()) call yy%sync_space() - if ((delta /= czero) .and. zz%is_host()) call zz%sync_space() - if (xx%is_host()) call xx%sync_space() + if ((beta /= czero) .and. yy%is_host()) call yy%sync() + if ((delta /= czero) .and. zz%is_host()) call zz%sync() + if (xx%is_host()) call xx%sync() nx = size(xx%v) ny = size(yy%v) nz = size(zz%v) @@ -432,8 +489,8 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() - if (y%is_host()) call y%sync_space() + if (ii%is_host()) call ii%sync() + if (y%is_host()) call y%sync() !$acc parallel loop do i = 1, n @@ -459,13 +516,13 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() + if (ii%is_host()) call ii%sync() class default call psb_errpush(info, 'c_oacc_sctb_x') return end select - if (y%is_host()) call y%sync_space() + if (y%is_host()) call y%sync() !$acc parallel loop do i = 1, n @@ -486,7 +543,7 @@ contains integer(psb_ipk_) :: i if (n == 0) return - if (y%is_dev()) call y%sync_space() + if (y%is_dev()) call y%sync() !$acc parallel loop do i = 1, n @@ -512,13 +569,13 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() + if (ii%is_host()) call ii%sync() class default call psb_errpush(info, 'c_oacc_gthzbuf') return end select - if (x%is_host()) call x%sync_space() + if (x%is_host()) call x%sync() !$acc parallel loop do i = 1, n @@ -539,13 +596,13 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() + if (ii%is_host()) call ii%sync() class default call psb_errpush(info, 'c_oacc_gthzv_x') return end select - if (x%is_host()) call x%sync_space() + if (x%is_host()) call x%sync() !$acc parallel loop do i = 1, n @@ -573,9 +630,9 @@ contains type is (psb_i_vect_oacc) select type(vval => val) type is (psb_c_vect_oacc) - if (vval%is_host()) call vval%sync_space() - if (virl%is_host()) call virl%sync_space() - if (x%is_host()) call x%sync_space() + if (vval%is_host()) call vval%sync() + if (virl%is_host()) call virl%sync() + if (x%is_host()) call x%sync() !$acc parallel loop do i = 1, n x%v(virl%v(i)) = vval%v(i) @@ -588,11 +645,11 @@ contains if (.not.done_oacc) then select type(virl => irl) type is (psb_i_vect_oacc) - if (virl%is_dev()) call virl%sync_space() + if (virl%is_dev()) call virl%sync() end select select type(vval => val) type is (psb_c_vect_oacc) - if (vval%is_dev()) call vval%sync_space() + if (vval%is_dev()) call vval%sync() end select call x%ins(n, irl%v, val%v, dupl, info) end if @@ -616,7 +673,7 @@ contains integer(psb_ipk_) :: i info = 0 - if (x%is_dev()) call x%sync_space() + if (x%is_dev()) call x%sync() call x%psb_c_base_vect_type%ins(n, irl, val, dupl, info) call x%set_host() !$acc update device(x%v) @@ -635,7 +692,10 @@ contains call psb_errpush(info, 'c_oacc_bld_mn', i_err=(/n, n, n, n, n/)) end if call x%set_host() - !$acc update device(x%v) + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if + !$acc enter data copyin(x%v) end subroutine c_oacc_bld_mn @@ -657,7 +717,10 @@ contains x%v(:) = this(:) call x%set_host() - !$acc update device(x%v) + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if + !$acc enter data copyin(x%v) end subroutine c_oacc_bld_x @@ -676,13 +739,13 @@ contains if (nd < n) then call x%sync() call x%psb_c_base_vect_type%asb(n, info) - if (info == psb_success_) call x%sync_space() + if (info == psb_success_) call x%sync() call x%set_host() end if else if (size(x%v) < n) then call x%psb_c_base_vect_type%asb(n, info) - if (info == psb_success_) call x%sync_space() + if (info == psb_success_) call x%sync() call x%set_host() end if end if @@ -740,10 +803,9 @@ contains complex(psb_spk_) :: res complex(psb_spk_), external :: ddot integer(psb_ipk_) :: info - integer(psb_ipk_) :: i res = czero - + !write(0,*) 'dot_v' select type(yy => y) type is (psb_c_base_vect_type) if (x%is_dev()) call x%sync() @@ -751,18 +813,26 @@ contains type is (psb_c_vect_oacc) if (x%is_host()) call x%sync() if (yy%is_host()) call yy%sync() - - !$acc parallel loop reduction(+:res) present(x%v, yy%v) - do i = 1, n - res = res + x%v(i) * yy%v(i) - end do - !$acc end parallel loop - + res = c_inner_oacc_dot(n, x%v, yy%v) class default call x%sync() res = y%dot(n, x%v) end select - + contains + function c_inner_oacc_dot(n, x, y) result(res) + implicit none + complex(psb_spk_), intent(in) :: x(:) + complex(psb_spk_), intent(in) :: y(:) + integer(psb_ipk_), intent(in) :: n + complex(psb_spk_) :: res + integer(psb_ipk_) :: i + + !$acc parallel loop reduction(+:res) present(x, y) + do i = 1, n + res = res + x(i) * y(i) + end do + !$acc end parallel loop + end function c_inner_oacc_dot end function c_oacc_vect_dot function c_oacc_dot_a(n, x, y) result(res) @@ -808,7 +878,7 @@ contains implicit none class(psb_c_vect_oacc), intent(inout) :: x if (allocated(x%v)) then - call c_oacc_create_dev(x%v) + if (.not.acc_is_present(x%v)) call c_oacc_create_dev(x%v) end if contains subroutine c_oacc_create_dev(v) @@ -886,6 +956,9 @@ contains call psb_realloc(n, x%v, info) if (info == 0) then call x%set_host() + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if !$acc enter data create(x%v) call x%sync_space() end if @@ -902,7 +975,9 @@ contains integer(psb_ipk_), intent(out) :: info info = 0 if (allocated(x%v)) then - !$acc exit data delete(x%v) finalize + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if deallocate(x%v, stat=info) end if diff --git a/openacc/psb_d_oacc_vect_mod.F90 b/openacc/psb_d_oacc_vect_mod.F90 index bfb97b5c..9ecbccb4 100644 --- a/openacc/psb_d_oacc_vect_mod.F90 +++ b/openacc/psb_d_oacc_vect_mod.F90 @@ -1,5 +1,6 @@ module psb_d_oacc_vect_mod use iso_c_binding + use openacc use psb_const_mod use psb_error_mod use psb_d_vect_mod @@ -50,8 +51,8 @@ module psb_d_oacc_vect_mod procedure, pass(z) :: upd_xyz => d_oacc_upd_xyz procedure, pass(y) :: mlt_a => d_oacc_mlt_a procedure, pass(z) :: mlt_a_2 => d_oacc_mlt_a_2 - procedure, pass(y) :: mlt_v => d_oacc_mlt_v - procedure, pass(z) :: mlt_v_2 => d_oacc_mlt_v_2 + procedure, pass(y) :: mlt_v => psb_d_oacc_mlt_v + procedure, pass(z) :: mlt_v_2 => psb_d_oacc_mlt_v_2 procedure, pass(x) :: scal => d_oacc_scal procedure, pass(x) :: nrm2 => d_oacc_nrm2 procedure, pass(x) :: amax => d_oacc_amax @@ -62,17 +63,17 @@ module psb_d_oacc_vect_mod end type psb_d_vect_oacc interface - subroutine d_oacc_mlt_v(x, y, info) + subroutine psb_d_oacc_mlt_v(x, y, info) import implicit none class(psb_d_base_vect_type), intent(inout) :: x class(psb_d_vect_oacc), intent(inout) :: y integer(psb_ipk_), intent(out) :: info - end subroutine d_oacc_mlt_v + end subroutine psb_d_oacc_mlt_v end interface interface - subroutine d_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) + subroutine psb_d_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) import implicit none real(psb_dpk_), intent(in) :: alpha, beta @@ -81,7 +82,7 @@ module psb_d_oacc_vect_mod class(psb_d_vect_oacc), intent(inout) :: z integer(psb_ipk_), intent(out) :: info character(len=1), intent(in), optional :: conjgx, conjgy - end subroutine d_oacc_mlt_v_2 + end subroutine psb_d_oacc_mlt_v_2 end interface contains @@ -89,15 +90,23 @@ contains subroutine d_oacc_absval1(x) implicit none class(psb_d_vect_oacc), intent(inout) :: x - integer(psb_ipk_) :: n, i + integer(psb_ipk_) :: n - if (x%is_host()) call x%sync_space() + if (x%is_host()) call x%sync() n = size(x%v) - !$acc parallel loop - do i = 1, n - x%v(i) = abs(x%v(i)) - end do + call d_inner_oacc_absval1(n,x%v) call x%set_dev() + contains + subroutine d_inner_oacc_absval1(n,x) + implicit none + real(psb_dpk_), intent(inout) :: x(:) + integer(psb_ipk_) :: n + integer(psb_ipk_) :: i + !$acc parallel loop + do i = 1, n + x(i) = abs(x(i)) + end do + end subroutine d_inner_oacc_absval1 end subroutine d_oacc_absval1 subroutine d_oacc_absval2(x, y) @@ -112,15 +121,23 @@ contains class is (psb_d_vect_oacc) if (x%is_host()) call x%sync() if (yy%is_host()) call yy%sync() - !$acc parallel loop - do i = 1, n - yy%v(i) = abs(x%v(i)) - end do + call d_inner_oacc_absval2(n,x%v,yy%v) class default if (x%is_dev()) call x%sync() if (y%is_dev()) call y%sync() call x%psb_d_base_vect_type%absval(y) end select + contains + subroutine d_inner_oacc_absval2(n,x,y) + implicit none + real(psb_dpk_), intent(inout) :: x(:),y(:) + integer(psb_ipk_) :: n + integer(psb_ipk_) :: i + !$acc parallel loop + do i = 1, n + y(i) = abs(x(i)) + end do + end subroutine d_inner_oacc_absval2 end subroutine d_oacc_absval2 subroutine d_oacc_scal(alpha, x) @@ -128,32 +145,46 @@ contains class(psb_d_vect_oacc), intent(inout) :: x real(psb_dpk_), intent(in) :: alpha integer(psb_ipk_) :: info - integer(psb_ipk_) :: i - - if (x%is_host()) call x%sync_space() - !$acc parallel loop - do i = 1, size(x%v) - x%v(i) = alpha * x%v(i) - end do + if (x%is_host()) call x%sync() + call d_inner_oacc_scal(alpha, x%v) call x%set_dev() + contains + subroutine d_inner_oacc_scal(alpha, x) + real(psb_dpk_), intent(in) :: alpha + real(psb_dpk_), intent(inout) :: x(:) + integer(psb_ipk_) :: i + !$acc parallel loop + do i = 1, size(x) + x(i) = alpha * x(i) + end do + end subroutine d_inner_oacc_scal end subroutine d_oacc_scal function d_oacc_nrm2(n, x) result(res) implicit none class(psb_d_vect_oacc), intent(inout) :: x integer(psb_ipk_), intent(in) :: n - real(psb_dpk_) :: res + real(psb_dpk_) :: res + real(psb_dpk_) :: mx integer(psb_ipk_) :: info - real(psb_dpk_) :: sum - integer(psb_ipk_) :: i - if (x%is_host()) call x%sync_space() - sum = 0.0 - !$acc parallel loop reduction(+:sum) - do i = 1, n - sum = sum + abs(x%v(i))**2 - end do - res = sqrt(sum) + if (x%is_host()) call x%sync() + mx = d_oacc_amax(n,x) + res = d_inner_oacc_nrm2(n, mx, x%v) + contains + function d_inner_oacc_nrm2(n, mx,x) result(res) + integer(psb_ipk_) :: n + real(psb_dpk_) :: x(:) + real(psb_dpk_) :: mx, res + real(psb_dpk_) :: sum + integer(psb_ipk_) :: i + sum = 0.0 + !$acc parallel loop reduction(+:sum) + do i = 1, n + sum = sum + abs(x(i)/mx)**2 + end do + res = mx*sqrt(sum) + end function d_inner_oacc_nrm2 end function d_oacc_nrm2 function d_oacc_amax(n, x) result(res) @@ -162,18 +193,25 @@ contains integer(psb_ipk_), intent(in) :: n real(psb_dpk_) :: res integer(psb_ipk_) :: info - real(psb_dpk_) :: max_val - integer(psb_ipk_) :: i - if (x%is_host()) call x%sync_space() - max_val = -huge(0.0) - !$acc parallel loop reduction(max:max_val) - do i = 1, n - if (abs(x%v(i)) > max_val) max_val = abs(x%v(i)) - end do - res = max_val + if (x%is_host()) call x%sync() + res = d_inner_oacc_amax(n, x%v) + contains + function d_inner_oacc_amax(n, x) result(res) + integer(psb_ipk_) :: n + real(psb_dpk_) :: x(:) + real(psb_dpk_) :: res + real(psb_dpk_) :: max_val + integer(psb_ipk_) :: i + max_val = -huge(0.0) + !$acc parallel loop reduction(max:max_val) + do i = 1, n + if (abs(x(i)) > max_val) max_val = abs(x(i)) + end do + res = max_val + end function d_inner_oacc_amax end function d_oacc_amax - + function d_oacc_asum(n, x) result(res) implicit none class(psb_d_vect_oacc), intent(inout) :: x @@ -182,14 +220,20 @@ contains integer(psb_ipk_) :: info real(psb_dpk_) :: sum integer(psb_ipk_) :: i - - if (x%is_host()) call x%sync_space() - sum = 0.0 - !$acc parallel loop reduction(+:sum) - do i = 1, n - sum = sum + abs(x%v(i)) - end do - res = sum + if (x%is_host()) call x%sync() + res = d_inner_oacc_asum(n, x%v) + contains + function d_inner_oacc_asum(n, x) result(res) + integer(psb_ipk_) :: n + real(psb_dpk_) :: x(:) + real(psb_dpk_) :: res + integer(psb_ipk_) :: i + res = 0.0 + !$acc parallel loop reduction(+:res) + do i = 1, n + res = res + abs(x(i)) + end do + end function d_inner_oacc_asum end function d_oacc_asum @@ -201,7 +245,7 @@ contains integer(psb_ipk_) :: i, n info = 0 - if (y%is_dev()) call y%sync_space() + if (y%is_dev()) call y%sync() !$acc parallel loop do i = 1, size(x) y%v(i) = y%v(i) * x(i) @@ -219,7 +263,7 @@ contains integer(psb_ipk_) :: i, n info = 0 - if (z%is_dev()) call z%sync_space() + if (z%is_dev()) call z%sync() !$acc parallel loop do i = 1, size(x) z%v(i) = alpha * x(i) * y(i) + beta * z%v(i) @@ -282,18 +326,18 @@ contains !!$ class is (psb_d_vect_oacc) !!$ select type (yy => y) !!$ class is (psb_d_vect_oacc) -!!$ if (xx%is_host()) call xx%sync_space() -!!$ if (yy%is_host()) call yy%sync_space() -!!$ if ((beta /= dzero) .and. (z%is_host())) call z%sync_space() +!!$ if (xx%is_host()) call xx%sync() +!!$ if (yy%is_host()) call yy%sync() +!!$ if ((beta /= dzero) .and. (z%is_host())) call z%sync() !!$ !$acc parallel loop !!$ do i = 1, n !!$ z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) !!$ end do !!$ call z%set_dev() !!$ class default -!!$ if (xx%is_dev()) call xx%sync_space() +!!$ if (xx%is_dev()) call xx%sync() !!$ if (yy%is_dev()) call yy%sync() -!!$ if ((beta /= dzero) .and. (z%is_dev())) call z%sync_space() +!!$ if ((beta /= dzero) .and. (z%is_dev())) call z%sync() !!$ !$acc parallel loop !!$ do i = 1, n !!$ z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) @@ -303,7 +347,7 @@ contains !!$ class default !!$ if (x%is_dev()) call x%sync() !!$ if (y%is_dev()) call y%sync() -!!$ if ((beta /= dzero) .and. (z%is_dev())) call z%sync_space() +!!$ if ((beta /= dzero) .and. (z%is_dev())) call z%sync() !!$ !$acc parallel loop !!$ do i = 1, n !!$ z%v(i) = alpha * x%v(i) * y%v(i) + beta * z%v(i) @@ -327,23 +371,36 @@ contains select type(xx => x) type is (psb_d_vect_oacc) - if ((beta /= dzero) .and. y%is_host()) call y%sync_space() - if (xx%is_host()) call xx%sync_space() + if ((beta /= dzero) .and. y%is_host()) call y%sync() + if (xx%is_host()) call xx%sync() nx = size(xx%v) ny = size(y%v) if ((nx < m) .or. (ny < m)) then info = psb_err_internal_error_ else - !$acc parallel loop - do i = 1, m - y%v(i) = alpha * xx%v(i) + beta * y%v(i) - end do + call d_inner_oacc_axpby(m, alpha, x%v, beta, y%v, info) end if call y%set_dev() class default if ((alpha /= dzero) .and. (x%is_dev())) call x%sync() call y%axpby(m, alpha, x%v, beta, info) - end select + end select + contains + subroutine d_inner_oacc_axpby(m, alpha, x, beta, y, info) + !use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + real(psb_dpk_), intent(inout) :: x(:) + real(psb_dpk_), intent(inout) :: y(:) + real(psb_dpk_), intent(in) :: alpha, beta + integer(psb_ipk_), intent(out) :: info + !$acc parallel + !$acc loop + do i = 1, m + y(i) = alpha * x(i) + beta * y(i) + end do + !$acc end parallel + end subroutine d_inner_oacc_axpby end subroutine d_oacc_axpby_v subroutine d_oacc_axpby_a(m, alpha, x, beta, y, info) @@ -356,7 +413,7 @@ contains integer(psb_ipk_), intent(out) :: info integer(psb_ipk_) :: i - if ((beta /= dzero) .and. (y%is_dev())) call y%sync_space() + if ((beta /= dzero) .and. (y%is_dev())) call y%sync() !$acc parallel loop do i = 1, m y%v(i) = alpha * x(i) + beta * y%v(i) @@ -375,7 +432,7 @@ contains integer(psb_ipk_), intent(out) :: info integer(psb_ipk_) :: nx, ny, nz, i logical :: gpu_done - + write(0,*)'upd_xyz' info = psb_success_ gpu_done = .false. @@ -385,9 +442,9 @@ contains class is (psb_d_vect_oacc) select type(zz => z) class is (psb_d_vect_oacc) - if ((beta /= dzero) .and. yy%is_host()) call yy%sync_space() - if ((delta /= dzero) .and. zz%is_host()) call zz%sync_space() - if (xx%is_host()) call xx%sync_space() + if ((beta /= dzero) .and. yy%is_host()) call yy%sync() + if ((delta /= dzero) .and. zz%is_host()) call zz%sync() + if (xx%is_host()) call xx%sync() nx = size(xx%v) ny = size(yy%v) nz = size(zz%v) @@ -432,8 +489,8 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() - if (y%is_host()) call y%sync_space() + if (ii%is_host()) call ii%sync() + if (y%is_host()) call y%sync() !$acc parallel loop do i = 1, n @@ -459,13 +516,13 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() + if (ii%is_host()) call ii%sync() class default call psb_errpush(info, 'd_oacc_sctb_x') return end select - if (y%is_host()) call y%sync_space() + if (y%is_host()) call y%sync() !$acc parallel loop do i = 1, n @@ -486,7 +543,7 @@ contains integer(psb_ipk_) :: i if (n == 0) return - if (y%is_dev()) call y%sync_space() + if (y%is_dev()) call y%sync() !$acc parallel loop do i = 1, n @@ -512,13 +569,13 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() + if (ii%is_host()) call ii%sync() class default call psb_errpush(info, 'd_oacc_gthzbuf') return end select - if (x%is_host()) call x%sync_space() + if (x%is_host()) call x%sync() !$acc parallel loop do i = 1, n @@ -539,13 +596,13 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() + if (ii%is_host()) call ii%sync() class default call psb_errpush(info, 'd_oacc_gthzv_x') return end select - if (x%is_host()) call x%sync_space() + if (x%is_host()) call x%sync() !$acc parallel loop do i = 1, n @@ -573,9 +630,9 @@ contains type is (psb_i_vect_oacc) select type(vval => val) type is (psb_d_vect_oacc) - if (vval%is_host()) call vval%sync_space() - if (virl%is_host()) call virl%sync_space() - if (x%is_host()) call x%sync_space() + if (vval%is_host()) call vval%sync() + if (virl%is_host()) call virl%sync() + if (x%is_host()) call x%sync() !$acc parallel loop do i = 1, n x%v(virl%v(i)) = vval%v(i) @@ -588,11 +645,11 @@ contains if (.not.done_oacc) then select type(virl => irl) type is (psb_i_vect_oacc) - if (virl%is_dev()) call virl%sync_space() + if (virl%is_dev()) call virl%sync() end select select type(vval => val) type is (psb_d_vect_oacc) - if (vval%is_dev()) call vval%sync_space() + if (vval%is_dev()) call vval%sync() end select call x%ins(n, irl%v, val%v, dupl, info) end if @@ -616,7 +673,7 @@ contains integer(psb_ipk_) :: i info = 0 - if (x%is_dev()) call x%sync_space() + if (x%is_dev()) call x%sync() call x%psb_d_base_vect_type%ins(n, irl, val, dupl, info) call x%set_host() !$acc update device(x%v) @@ -635,7 +692,10 @@ contains call psb_errpush(info, 'd_oacc_bld_mn', i_err=(/n, n, n, n, n/)) end if call x%set_host() - !$acc update device(x%v) + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if + !$acc enter data copyin(x%v) end subroutine d_oacc_bld_mn @@ -657,7 +717,10 @@ contains x%v(:) = this(:) call x%set_host() - !$acc update device(x%v) + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if + !$acc enter data copyin(x%v) end subroutine d_oacc_bld_x @@ -676,13 +739,13 @@ contains if (nd < n) then call x%sync() call x%psb_d_base_vect_type%asb(n, info) - if (info == psb_success_) call x%sync_space() + if (info == psb_success_) call x%sync() call x%set_host() end if else if (size(x%v) < n) then call x%psb_d_base_vect_type%asb(n, info) - if (info == psb_success_) call x%sync_space() + if (info == psb_success_) call x%sync() call x%set_host() end if end if @@ -740,10 +803,9 @@ contains real(psb_dpk_) :: res real(psb_dpk_), external :: ddot integer(psb_ipk_) :: info - integer(psb_ipk_) :: i res = dzero - + !write(0,*) 'dot_v' select type(yy => y) type is (psb_d_base_vect_type) if (x%is_dev()) call x%sync() @@ -751,18 +813,26 @@ contains type is (psb_d_vect_oacc) if (x%is_host()) call x%sync() if (yy%is_host()) call yy%sync() - - !$acc parallel loop reduction(+:res) present(x%v, yy%v) - do i = 1, n - res = res + x%v(i) * yy%v(i) - end do - !$acc end parallel loop - + res = d_inner_oacc_dot(n, x%v, yy%v) class default call x%sync() res = y%dot(n, x%v) end select - + contains + function d_inner_oacc_dot(n, x, y) result(res) + implicit none + real(psb_dpk_), intent(in) :: x(:) + real(psb_dpk_), intent(in) :: y(:) + integer(psb_ipk_), intent(in) :: n + real(psb_dpk_) :: res + integer(psb_ipk_) :: i + + !$acc parallel loop reduction(+:res) present(x, y) + do i = 1, n + res = res + x(i) * y(i) + end do + !$acc end parallel loop + end function d_inner_oacc_dot end function d_oacc_vect_dot function d_oacc_dot_a(n, x, y) result(res) @@ -808,7 +878,7 @@ contains implicit none class(psb_d_vect_oacc), intent(inout) :: x if (allocated(x%v)) then - call d_oacc_create_dev(x%v) + if (.not.acc_is_present(x%v)) call d_oacc_create_dev(x%v) end if contains subroutine d_oacc_create_dev(v) @@ -886,6 +956,9 @@ contains call psb_realloc(n, x%v, info) if (info == 0) then call x%set_host() + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if !$acc enter data create(x%v) call x%sync_space() end if @@ -902,7 +975,9 @@ contains integer(psb_ipk_), intent(out) :: info info = 0 if (allocated(x%v)) then - !$acc exit data delete(x%v) finalize + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if deallocate(x%v, stat=info) end if diff --git a/openacc/psb_i_oacc_vect_mod.F90 b/openacc/psb_i_oacc_vect_mod.F90 index 72e9ada2..3dbc48f1 100644 --- a/openacc/psb_i_oacc_vect_mod.F90 +++ b/openacc/psb_i_oacc_vect_mod.F90 @@ -1,5 +1,6 @@ module psb_i_oacc_vect_mod use iso_c_binding + use openacc use psb_const_mod use psb_error_mod use psb_i_vect_mod @@ -64,8 +65,8 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() - if (y%is_host()) call y%sync_space() + if (ii%is_host()) call ii%sync() + if (y%is_host()) call y%sync() !$acc parallel loop do i = 1, n @@ -91,13 +92,13 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() + if (ii%is_host()) call ii%sync() class default call psb_errpush(info, 'i_oacc_sctb_x') return end select - if (y%is_host()) call y%sync_space() + if (y%is_host()) call y%sync() !$acc parallel loop do i = 1, n @@ -118,7 +119,7 @@ contains integer(psb_ipk_) :: i if (n == 0) return - if (y%is_dev()) call y%sync_space() + if (y%is_dev()) call y%sync() !$acc parallel loop do i = 1, n @@ -144,13 +145,13 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() + if (ii%is_host()) call ii%sync() class default call psb_errpush(info, 'i_oacc_gthzbuf') return end select - if (x%is_host()) call x%sync_space() + if (x%is_host()) call x%sync() !$acc parallel loop do i = 1, n @@ -171,13 +172,13 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() + if (ii%is_host()) call ii%sync() class default call psb_errpush(info, 'i_oacc_gthzv_x') return end select - if (x%is_host()) call x%sync_space() + if (x%is_host()) call x%sync() !$acc parallel loop do i = 1, n @@ -205,9 +206,9 @@ contains type is (psb_i_vect_oacc) select type(vval => val) type is (psb_i_vect_oacc) - if (vval%is_host()) call vval%sync_space() - if (virl%is_host()) call virl%sync_space() - if (x%is_host()) call x%sync_space() + if (vval%is_host()) call vval%sync() + if (virl%is_host()) call virl%sync() + if (x%is_host()) call x%sync() !$acc parallel loop do i = 1, n x%v(virl%v(i)) = vval%v(i) @@ -220,11 +221,11 @@ contains if (.not.done_oacc) then select type(virl => irl) type is (psb_i_vect_oacc) - if (virl%is_dev()) call virl%sync_space() + if (virl%is_dev()) call virl%sync() end select select type(vval => val) type is (psb_i_vect_oacc) - if (vval%is_dev()) call vval%sync_space() + if (vval%is_dev()) call vval%sync() end select call x%ins(n, irl%v, val%v, dupl, info) end if @@ -248,7 +249,7 @@ contains integer(psb_ipk_) :: i info = 0 - if (x%is_dev()) call x%sync_space() + if (x%is_dev()) call x%sync() call x%psb_i_base_vect_type%ins(n, irl, val, dupl, info) call x%set_host() !$acc update device(x%v) @@ -267,7 +268,10 @@ contains call psb_errpush(info, 'i_oacc_bld_mn', i_err=(/n, n, n, n, n/)) end if call x%set_host() - !$acc update device(x%v) + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if + !$acc enter data copyin(x%v) end subroutine i_oacc_bld_mn @@ -289,7 +293,10 @@ contains x%v(:) = this(:) call x%set_host() - !$acc update device(x%v) + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if + !$acc enter data copyin(x%v) end subroutine i_oacc_bld_x @@ -308,13 +315,13 @@ contains if (nd < n) then call x%sync() call x%psb_i_base_vect_type%asb(n, info) - if (info == psb_success_) call x%sync_space() + if (info == psb_success_) call x%sync() call x%set_host() end if else if (size(x%v) < n) then call x%psb_i_base_vect_type%asb(n, info) - if (info == psb_success_) call x%sync_space() + if (info == psb_success_) call x%sync() call x%set_host() end if end if @@ -393,7 +400,7 @@ contains implicit none class(psb_i_vect_oacc), intent(inout) :: x if (allocated(x%v)) then - call i_oacc_create_dev(x%v) + if (.not.acc_is_present(x%v)) call i_oacc_create_dev(x%v) end if contains subroutine i_oacc_create_dev(v) @@ -471,6 +478,9 @@ contains call psb_realloc(n, x%v, info) if (info == 0) then call x%set_host() + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if !$acc enter data create(x%v) call x%sync_space() end if @@ -487,7 +497,9 @@ contains integer(psb_ipk_), intent(out) :: info info = 0 if (allocated(x%v)) then - !$acc exit data delete(x%v) finalize + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if deallocate(x%v, stat=info) end if diff --git a/openacc/psb_l_oacc_vect_mod.F90 b/openacc/psb_l_oacc_vect_mod.F90 index aeba4537..cdf28366 100644 --- a/openacc/psb_l_oacc_vect_mod.F90 +++ b/openacc/psb_l_oacc_vect_mod.F90 @@ -1,5 +1,6 @@ module psb_l_oacc_vect_mod use iso_c_binding + use openacc use psb_const_mod use psb_error_mod use psb_l_vect_mod @@ -66,8 +67,8 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() - if (y%is_host()) call y%sync_space() + if (ii%is_host()) call ii%sync() + if (y%is_host()) call y%sync() !$acc parallel loop do i = 1, n @@ -93,13 +94,13 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() + if (ii%is_host()) call ii%sync() class default call psb_errpush(info, 'l_oacc_sctb_x') return end select - if (y%is_host()) call y%sync_space() + if (y%is_host()) call y%sync() !$acc parallel loop do i = 1, n @@ -120,7 +121,7 @@ contains integer(psb_ipk_) :: i if (n == 0) return - if (y%is_dev()) call y%sync_space() + if (y%is_dev()) call y%sync() !$acc parallel loop do i = 1, n @@ -146,13 +147,13 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() + if (ii%is_host()) call ii%sync() class default call psb_errpush(info, 'l_oacc_gthzbuf') return end select - if (x%is_host()) call x%sync_space() + if (x%is_host()) call x%sync() !$acc parallel loop do i = 1, n @@ -173,13 +174,13 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() + if (ii%is_host()) call ii%sync() class default call psb_errpush(info, 'l_oacc_gthzv_x') return end select - if (x%is_host()) call x%sync_space() + if (x%is_host()) call x%sync() !$acc parallel loop do i = 1, n @@ -207,9 +208,9 @@ contains type is (psb_i_vect_oacc) select type(vval => val) type is (psb_l_vect_oacc) - if (vval%is_host()) call vval%sync_space() - if (virl%is_host()) call virl%sync_space() - if (x%is_host()) call x%sync_space() + if (vval%is_host()) call vval%sync() + if (virl%is_host()) call virl%sync() + if (x%is_host()) call x%sync() !$acc parallel loop do i = 1, n x%v(virl%v(i)) = vval%v(i) @@ -222,11 +223,11 @@ contains if (.not.done_oacc) then select type(virl => irl) type is (psb_i_vect_oacc) - if (virl%is_dev()) call virl%sync_space() + if (virl%is_dev()) call virl%sync() end select select type(vval => val) type is (psb_l_vect_oacc) - if (vval%is_dev()) call vval%sync_space() + if (vval%is_dev()) call vval%sync() end select call x%ins(n, irl%v, val%v, dupl, info) end if @@ -250,7 +251,7 @@ contains integer(psb_ipk_) :: i info = 0 - if (x%is_dev()) call x%sync_space() + if (x%is_dev()) call x%sync() call x%psb_l_base_vect_type%ins(n, irl, val, dupl, info) call x%set_host() !$acc update device(x%v) @@ -269,7 +270,10 @@ contains call psb_errpush(info, 'l_oacc_bld_mn', i_err=(/n, n, n, n, n/)) end if call x%set_host() - !$acc update device(x%v) + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if + !$acc enter data copyin(x%v) end subroutine l_oacc_bld_mn @@ -291,7 +295,10 @@ contains x%v(:) = this(:) call x%set_host() - !$acc update device(x%v) + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if + !$acc enter data copyin(x%v) end subroutine l_oacc_bld_x @@ -310,13 +317,13 @@ contains if (nd < n) then call x%sync() call x%psb_l_base_vect_type%asb(n, info) - if (info == psb_success_) call x%sync_space() + if (info == psb_success_) call x%sync() call x%set_host() end if else if (size(x%v) < n) then call x%psb_l_base_vect_type%asb(n, info) - if (info == psb_success_) call x%sync_space() + if (info == psb_success_) call x%sync() call x%set_host() end if end if @@ -395,7 +402,7 @@ contains implicit none class(psb_l_vect_oacc), intent(inout) :: x if (allocated(x%v)) then - call l_oacc_create_dev(x%v) + if (.not.acc_is_present(x%v)) call l_oacc_create_dev(x%v) end if contains subroutine l_oacc_create_dev(v) @@ -473,6 +480,9 @@ contains call psb_realloc(n, x%v, info) if (info == 0) then call x%set_host() + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if !$acc enter data create(x%v) call x%sync_space() end if @@ -489,7 +499,9 @@ contains integer(psb_ipk_), intent(out) :: info info = 0 if (allocated(x%v)) then - !$acc exit data delete(x%v) finalize + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if deallocate(x%v, stat=info) end if diff --git a/openacc/psb_s_oacc_vect_mod.F90 b/openacc/psb_s_oacc_vect_mod.F90 index 47922d6a..c3b31af7 100644 --- a/openacc/psb_s_oacc_vect_mod.F90 +++ b/openacc/psb_s_oacc_vect_mod.F90 @@ -1,5 +1,6 @@ module psb_s_oacc_vect_mod use iso_c_binding + use openacc use psb_const_mod use psb_error_mod use psb_s_vect_mod @@ -50,8 +51,8 @@ module psb_s_oacc_vect_mod procedure, pass(z) :: upd_xyz => s_oacc_upd_xyz procedure, pass(y) :: mlt_a => s_oacc_mlt_a procedure, pass(z) :: mlt_a_2 => s_oacc_mlt_a_2 - procedure, pass(y) :: mlt_v => s_oacc_mlt_v - procedure, pass(z) :: mlt_v_2 => s_oacc_mlt_v_2 + procedure, pass(y) :: mlt_v => psb_s_oacc_mlt_v + procedure, pass(z) :: mlt_v_2 => psb_s_oacc_mlt_v_2 procedure, pass(x) :: scal => s_oacc_scal procedure, pass(x) :: nrm2 => s_oacc_nrm2 procedure, pass(x) :: amax => s_oacc_amax @@ -62,17 +63,17 @@ module psb_s_oacc_vect_mod end type psb_s_vect_oacc interface - subroutine s_oacc_mlt_v(x, y, info) + subroutine psb_s_oacc_mlt_v(x, y, info) import implicit none class(psb_s_base_vect_type), intent(inout) :: x class(psb_s_vect_oacc), intent(inout) :: y integer(psb_ipk_), intent(out) :: info - end subroutine s_oacc_mlt_v + end subroutine psb_s_oacc_mlt_v end interface interface - subroutine s_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) + subroutine psb_s_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) import implicit none real(psb_spk_), intent(in) :: alpha, beta @@ -81,7 +82,7 @@ module psb_s_oacc_vect_mod class(psb_s_vect_oacc), intent(inout) :: z integer(psb_ipk_), intent(out) :: info character(len=1), intent(in), optional :: conjgx, conjgy - end subroutine s_oacc_mlt_v_2 + end subroutine psb_s_oacc_mlt_v_2 end interface contains @@ -89,15 +90,23 @@ contains subroutine s_oacc_absval1(x) implicit none class(psb_s_vect_oacc), intent(inout) :: x - integer(psb_ipk_) :: n, i + integer(psb_ipk_) :: n - if (x%is_host()) call x%sync_space() + if (x%is_host()) call x%sync() n = size(x%v) - !$acc parallel loop - do i = 1, n - x%v(i) = abs(x%v(i)) - end do + call s_inner_oacc_absval1(n,x%v) call x%set_dev() + contains + subroutine s_inner_oacc_absval1(n,x) + implicit none + real(psb_spk_), intent(inout) :: x(:) + integer(psb_ipk_) :: n + integer(psb_ipk_) :: i + !$acc parallel loop + do i = 1, n + x(i) = abs(x(i)) + end do + end subroutine s_inner_oacc_absval1 end subroutine s_oacc_absval1 subroutine s_oacc_absval2(x, y) @@ -112,15 +121,23 @@ contains class is (psb_s_vect_oacc) if (x%is_host()) call x%sync() if (yy%is_host()) call yy%sync() - !$acc parallel loop - do i = 1, n - yy%v(i) = abs(x%v(i)) - end do + call s_inner_oacc_absval2(n,x%v,yy%v) class default if (x%is_dev()) call x%sync() if (y%is_dev()) call y%sync() call x%psb_s_base_vect_type%absval(y) end select + contains + subroutine s_inner_oacc_absval2(n,x,y) + implicit none + real(psb_spk_), intent(inout) :: x(:),y(:) + integer(psb_ipk_) :: n + integer(psb_ipk_) :: i + !$acc parallel loop + do i = 1, n + y(i) = abs(x(i)) + end do + end subroutine s_inner_oacc_absval2 end subroutine s_oacc_absval2 subroutine s_oacc_scal(alpha, x) @@ -128,32 +145,46 @@ contains class(psb_s_vect_oacc), intent(inout) :: x real(psb_spk_), intent(in) :: alpha integer(psb_ipk_) :: info - integer(psb_ipk_) :: i - - if (x%is_host()) call x%sync_space() - !$acc parallel loop - do i = 1, size(x%v) - x%v(i) = alpha * x%v(i) - end do + if (x%is_host()) call x%sync() + call s_inner_oacc_scal(alpha, x%v) call x%set_dev() + contains + subroutine s_inner_oacc_scal(alpha, x) + real(psb_spk_), intent(in) :: alpha + real(psb_spk_), intent(inout) :: x(:) + integer(psb_ipk_) :: i + !$acc parallel loop + do i = 1, size(x) + x(i) = alpha * x(i) + end do + end subroutine s_inner_oacc_scal end subroutine s_oacc_scal function s_oacc_nrm2(n, x) result(res) implicit none class(psb_s_vect_oacc), intent(inout) :: x integer(psb_ipk_), intent(in) :: n - real(psb_spk_) :: res + real(psb_spk_) :: res + real(psb_spk_) :: mx integer(psb_ipk_) :: info - real(psb_spk_) :: sum - integer(psb_ipk_) :: i - if (x%is_host()) call x%sync_space() - sum = 0.0 - !$acc parallel loop reduction(+:sum) - do i = 1, n - sum = sum + abs(x%v(i))**2 - end do - res = sqrt(sum) + if (x%is_host()) call x%sync() + mx = s_oacc_amax(n,x) + res = s_inner_oacc_nrm2(n, mx, x%v) + contains + function s_inner_oacc_nrm2(n, mx,x) result(res) + integer(psb_ipk_) :: n + real(psb_spk_) :: x(:) + real(psb_spk_) :: mx, res + real(psb_spk_) :: sum + integer(psb_ipk_) :: i + sum = 0.0 + !$acc parallel loop reduction(+:sum) + do i = 1, n + sum = sum + abs(x(i)/mx)**2 + end do + res = mx*sqrt(sum) + end function s_inner_oacc_nrm2 end function s_oacc_nrm2 function s_oacc_amax(n, x) result(res) @@ -162,18 +193,25 @@ contains integer(psb_ipk_), intent(in) :: n real(psb_spk_) :: res integer(psb_ipk_) :: info - real(psb_spk_) :: max_val - integer(psb_ipk_) :: i - if (x%is_host()) call x%sync_space() - max_val = -huge(0.0) - !$acc parallel loop reduction(max:max_val) - do i = 1, n - if (abs(x%v(i)) > max_val) max_val = abs(x%v(i)) - end do - res = max_val + if (x%is_host()) call x%sync() + res = s_inner_oacc_amax(n, x%v) + contains + function s_inner_oacc_amax(n, x) result(res) + integer(psb_ipk_) :: n + real(psb_spk_) :: x(:) + real(psb_spk_) :: res + real(psb_spk_) :: max_val + integer(psb_ipk_) :: i + max_val = -huge(0.0) + !$acc parallel loop reduction(max:max_val) + do i = 1, n + if (abs(x(i)) > max_val) max_val = abs(x(i)) + end do + res = max_val + end function s_inner_oacc_amax end function s_oacc_amax - + function s_oacc_asum(n, x) result(res) implicit none class(psb_s_vect_oacc), intent(inout) :: x @@ -182,14 +220,20 @@ contains integer(psb_ipk_) :: info real(psb_spk_) :: sum integer(psb_ipk_) :: i - - if (x%is_host()) call x%sync_space() - sum = 0.0 - !$acc parallel loop reduction(+:sum) - do i = 1, n - sum = sum + abs(x%v(i)) - end do - res = sum + if (x%is_host()) call x%sync() + res = s_inner_oacc_asum(n, x%v) + contains + function s_inner_oacc_asum(n, x) result(res) + integer(psb_ipk_) :: n + real(psb_spk_) :: x(:) + real(psb_spk_) :: res + integer(psb_ipk_) :: i + res = 0.0 + !$acc parallel loop reduction(+:res) + do i = 1, n + res = res + abs(x(i)) + end do + end function s_inner_oacc_asum end function s_oacc_asum @@ -201,7 +245,7 @@ contains integer(psb_ipk_) :: i, n info = 0 - if (y%is_dev()) call y%sync_space() + if (y%is_dev()) call y%sync() !$acc parallel loop do i = 1, size(x) y%v(i) = y%v(i) * x(i) @@ -219,7 +263,7 @@ contains integer(psb_ipk_) :: i, n info = 0 - if (z%is_dev()) call z%sync_space() + if (z%is_dev()) call z%sync() !$acc parallel loop do i = 1, size(x) z%v(i) = alpha * x(i) * y(i) + beta * z%v(i) @@ -282,18 +326,18 @@ contains !!$ class is (psb_s_vect_oacc) !!$ select type (yy => y) !!$ class is (psb_s_vect_oacc) -!!$ if (xx%is_host()) call xx%sync_space() -!!$ if (yy%is_host()) call yy%sync_space() -!!$ if ((beta /= szero) .and. (z%is_host())) call z%sync_space() +!!$ if (xx%is_host()) call xx%sync() +!!$ if (yy%is_host()) call yy%sync() +!!$ if ((beta /= szero) .and. (z%is_host())) call z%sync() !!$ !$acc parallel loop !!$ do i = 1, n !!$ z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) !!$ end do !!$ call z%set_dev() !!$ class default -!!$ if (xx%is_dev()) call xx%sync_space() +!!$ if (xx%is_dev()) call xx%sync() !!$ if (yy%is_dev()) call yy%sync() -!!$ if ((beta /= szero) .and. (z%is_dev())) call z%sync_space() +!!$ if ((beta /= szero) .and. (z%is_dev())) call z%sync() !!$ !$acc parallel loop !!$ do i = 1, n !!$ z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) @@ -303,7 +347,7 @@ contains !!$ class default !!$ if (x%is_dev()) call x%sync() !!$ if (y%is_dev()) call y%sync() -!!$ if ((beta /= szero) .and. (z%is_dev())) call z%sync_space() +!!$ if ((beta /= szero) .and. (z%is_dev())) call z%sync() !!$ !$acc parallel loop !!$ do i = 1, n !!$ z%v(i) = alpha * x%v(i) * y%v(i) + beta * z%v(i) @@ -327,23 +371,36 @@ contains select type(xx => x) type is (psb_s_vect_oacc) - if ((beta /= szero) .and. y%is_host()) call y%sync_space() - if (xx%is_host()) call xx%sync_space() + if ((beta /= szero) .and. y%is_host()) call y%sync() + if (xx%is_host()) call xx%sync() nx = size(xx%v) ny = size(y%v) if ((nx < m) .or. (ny < m)) then info = psb_err_internal_error_ else - !$acc parallel loop - do i = 1, m - y%v(i) = alpha * xx%v(i) + beta * y%v(i) - end do + call s_inner_oacc_axpby(m, alpha, x%v, beta, y%v, info) end if call y%set_dev() class default if ((alpha /= szero) .and. (x%is_dev())) call x%sync() call y%axpby(m, alpha, x%v, beta, info) - end select + end select + contains + subroutine s_inner_oacc_axpby(m, alpha, x, beta, y, info) + !use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + real(psb_spk_), intent(inout) :: x(:) + real(psb_spk_), intent(inout) :: y(:) + real(psb_spk_), intent(in) :: alpha, beta + integer(psb_ipk_), intent(out) :: info + !$acc parallel + !$acc loop + do i = 1, m + y(i) = alpha * x(i) + beta * y(i) + end do + !$acc end parallel + end subroutine s_inner_oacc_axpby end subroutine s_oacc_axpby_v subroutine s_oacc_axpby_a(m, alpha, x, beta, y, info) @@ -356,7 +413,7 @@ contains integer(psb_ipk_), intent(out) :: info integer(psb_ipk_) :: i - if ((beta /= szero) .and. (y%is_dev())) call y%sync_space() + if ((beta /= szero) .and. (y%is_dev())) call y%sync() !$acc parallel loop do i = 1, m y%v(i) = alpha * x(i) + beta * y%v(i) @@ -375,7 +432,7 @@ contains integer(psb_ipk_), intent(out) :: info integer(psb_ipk_) :: nx, ny, nz, i logical :: gpu_done - + write(0,*)'upd_xyz' info = psb_success_ gpu_done = .false. @@ -385,9 +442,9 @@ contains class is (psb_s_vect_oacc) select type(zz => z) class is (psb_s_vect_oacc) - if ((beta /= szero) .and. yy%is_host()) call yy%sync_space() - if ((delta /= szero) .and. zz%is_host()) call zz%sync_space() - if (xx%is_host()) call xx%sync_space() + if ((beta /= szero) .and. yy%is_host()) call yy%sync() + if ((delta /= szero) .and. zz%is_host()) call zz%sync() + if (xx%is_host()) call xx%sync() nx = size(xx%v) ny = size(yy%v) nz = size(zz%v) @@ -432,8 +489,8 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() - if (y%is_host()) call y%sync_space() + if (ii%is_host()) call ii%sync() + if (y%is_host()) call y%sync() !$acc parallel loop do i = 1, n @@ -459,13 +516,13 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() + if (ii%is_host()) call ii%sync() class default call psb_errpush(info, 's_oacc_sctb_x') return end select - if (y%is_host()) call y%sync_space() + if (y%is_host()) call y%sync() !$acc parallel loop do i = 1, n @@ -486,7 +543,7 @@ contains integer(psb_ipk_) :: i if (n == 0) return - if (y%is_dev()) call y%sync_space() + if (y%is_dev()) call y%sync() !$acc parallel loop do i = 1, n @@ -512,13 +569,13 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() + if (ii%is_host()) call ii%sync() class default call psb_errpush(info, 's_oacc_gthzbuf') return end select - if (x%is_host()) call x%sync_space() + if (x%is_host()) call x%sync() !$acc parallel loop do i = 1, n @@ -539,13 +596,13 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() + if (ii%is_host()) call ii%sync() class default call psb_errpush(info, 's_oacc_gthzv_x') return end select - if (x%is_host()) call x%sync_space() + if (x%is_host()) call x%sync() !$acc parallel loop do i = 1, n @@ -573,9 +630,9 @@ contains type is (psb_i_vect_oacc) select type(vval => val) type is (psb_s_vect_oacc) - if (vval%is_host()) call vval%sync_space() - if (virl%is_host()) call virl%sync_space() - if (x%is_host()) call x%sync_space() + if (vval%is_host()) call vval%sync() + if (virl%is_host()) call virl%sync() + if (x%is_host()) call x%sync() !$acc parallel loop do i = 1, n x%v(virl%v(i)) = vval%v(i) @@ -588,11 +645,11 @@ contains if (.not.done_oacc) then select type(virl => irl) type is (psb_i_vect_oacc) - if (virl%is_dev()) call virl%sync_space() + if (virl%is_dev()) call virl%sync() end select select type(vval => val) type is (psb_s_vect_oacc) - if (vval%is_dev()) call vval%sync_space() + if (vval%is_dev()) call vval%sync() end select call x%ins(n, irl%v, val%v, dupl, info) end if @@ -616,7 +673,7 @@ contains integer(psb_ipk_) :: i info = 0 - if (x%is_dev()) call x%sync_space() + if (x%is_dev()) call x%sync() call x%psb_s_base_vect_type%ins(n, irl, val, dupl, info) call x%set_host() !$acc update device(x%v) @@ -635,7 +692,10 @@ contains call psb_errpush(info, 's_oacc_bld_mn', i_err=(/n, n, n, n, n/)) end if call x%set_host() - !$acc update device(x%v) + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if + !$acc enter data copyin(x%v) end subroutine s_oacc_bld_mn @@ -657,7 +717,10 @@ contains x%v(:) = this(:) call x%set_host() - !$acc update device(x%v) + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if + !$acc enter data copyin(x%v) end subroutine s_oacc_bld_x @@ -676,13 +739,13 @@ contains if (nd < n) then call x%sync() call x%psb_s_base_vect_type%asb(n, info) - if (info == psb_success_) call x%sync_space() + if (info == psb_success_) call x%sync() call x%set_host() end if else if (size(x%v) < n) then call x%psb_s_base_vect_type%asb(n, info) - if (info == psb_success_) call x%sync_space() + if (info == psb_success_) call x%sync() call x%set_host() end if end if @@ -740,10 +803,9 @@ contains real(psb_spk_) :: res real(psb_spk_), external :: ddot integer(psb_ipk_) :: info - integer(psb_ipk_) :: i res = szero - + !write(0,*) 'dot_v' select type(yy => y) type is (psb_s_base_vect_type) if (x%is_dev()) call x%sync() @@ -751,18 +813,26 @@ contains type is (psb_s_vect_oacc) if (x%is_host()) call x%sync() if (yy%is_host()) call yy%sync() - - !$acc parallel loop reduction(+:res) present(x%v, yy%v) - do i = 1, n - res = res + x%v(i) * yy%v(i) - end do - !$acc end parallel loop - + res = s_inner_oacc_dot(n, x%v, yy%v) class default call x%sync() res = y%dot(n, x%v) end select - + contains + function s_inner_oacc_dot(n, x, y) result(res) + implicit none + real(psb_spk_), intent(in) :: x(:) + real(psb_spk_), intent(in) :: y(:) + integer(psb_ipk_), intent(in) :: n + real(psb_spk_) :: res + integer(psb_ipk_) :: i + + !$acc parallel loop reduction(+:res) present(x, y) + do i = 1, n + res = res + x(i) * y(i) + end do + !$acc end parallel loop + end function s_inner_oacc_dot end function s_oacc_vect_dot function s_oacc_dot_a(n, x, y) result(res) @@ -808,7 +878,7 @@ contains implicit none class(psb_s_vect_oacc), intent(inout) :: x if (allocated(x%v)) then - call s_oacc_create_dev(x%v) + if (.not.acc_is_present(x%v)) call s_oacc_create_dev(x%v) end if contains subroutine s_oacc_create_dev(v) @@ -886,6 +956,9 @@ contains call psb_realloc(n, x%v, info) if (info == 0) then call x%set_host() + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if !$acc enter data create(x%v) call x%sync_space() end if @@ -902,7 +975,9 @@ contains integer(psb_ipk_), intent(out) :: info info = 0 if (allocated(x%v)) then - !$acc exit data delete(x%v) finalize + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if deallocate(x%v, stat=info) end if diff --git a/openacc/psb_z_oacc_vect_mod.F90 b/openacc/psb_z_oacc_vect_mod.F90 index be03b1cd..bab1a0a0 100644 --- a/openacc/psb_z_oacc_vect_mod.F90 +++ b/openacc/psb_z_oacc_vect_mod.F90 @@ -1,5 +1,6 @@ module psb_z_oacc_vect_mod use iso_c_binding + use openacc use psb_const_mod use psb_error_mod use psb_z_vect_mod @@ -50,8 +51,8 @@ module psb_z_oacc_vect_mod procedure, pass(z) :: upd_xyz => z_oacc_upd_xyz procedure, pass(y) :: mlt_a => z_oacc_mlt_a procedure, pass(z) :: mlt_a_2 => z_oacc_mlt_a_2 - procedure, pass(y) :: mlt_v => z_oacc_mlt_v - procedure, pass(z) :: mlt_v_2 => z_oacc_mlt_v_2 + procedure, pass(y) :: mlt_v => psb_z_oacc_mlt_v + procedure, pass(z) :: mlt_v_2 => psb_z_oacc_mlt_v_2 procedure, pass(x) :: scal => z_oacc_scal procedure, pass(x) :: nrm2 => z_oacc_nrm2 procedure, pass(x) :: amax => z_oacc_amax @@ -62,17 +63,17 @@ module psb_z_oacc_vect_mod end type psb_z_vect_oacc interface - subroutine z_oacc_mlt_v(x, y, info) + subroutine psb_z_oacc_mlt_v(x, y, info) import implicit none class(psb_z_base_vect_type), intent(inout) :: x class(psb_z_vect_oacc), intent(inout) :: y integer(psb_ipk_), intent(out) :: info - end subroutine z_oacc_mlt_v + end subroutine psb_z_oacc_mlt_v end interface interface - subroutine z_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) + subroutine psb_z_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) import implicit none complex(psb_dpk_), intent(in) :: alpha, beta @@ -81,7 +82,7 @@ module psb_z_oacc_vect_mod class(psb_z_vect_oacc), intent(inout) :: z integer(psb_ipk_), intent(out) :: info character(len=1), intent(in), optional :: conjgx, conjgy - end subroutine z_oacc_mlt_v_2 + end subroutine psb_z_oacc_mlt_v_2 end interface contains @@ -89,15 +90,23 @@ contains subroutine z_oacc_absval1(x) implicit none class(psb_z_vect_oacc), intent(inout) :: x - integer(psb_ipk_) :: n, i + integer(psb_ipk_) :: n - if (x%is_host()) call x%sync_space() + if (x%is_host()) call x%sync() n = size(x%v) - !$acc parallel loop - do i = 1, n - x%v(i) = abs(x%v(i)) - end do + call z_inner_oacc_absval1(n,x%v) call x%set_dev() + contains + subroutine z_inner_oacc_absval1(n,x) + implicit none + complex(psb_dpk_), intent(inout) :: x(:) + integer(psb_ipk_) :: n + integer(psb_ipk_) :: i + !$acc parallel loop + do i = 1, n + x(i) = abs(x(i)) + end do + end subroutine z_inner_oacc_absval1 end subroutine z_oacc_absval1 subroutine z_oacc_absval2(x, y) @@ -112,15 +121,23 @@ contains class is (psb_z_vect_oacc) if (x%is_host()) call x%sync() if (yy%is_host()) call yy%sync() - !$acc parallel loop - do i = 1, n - yy%v(i) = abs(x%v(i)) - end do + call z_inner_oacc_absval2(n,x%v,yy%v) class default if (x%is_dev()) call x%sync() if (y%is_dev()) call y%sync() call x%psb_z_base_vect_type%absval(y) end select + contains + subroutine z_inner_oacc_absval2(n,x,y) + implicit none + complex(psb_dpk_), intent(inout) :: x(:),y(:) + integer(psb_ipk_) :: n + integer(psb_ipk_) :: i + !$acc parallel loop + do i = 1, n + y(i) = abs(x(i)) + end do + end subroutine z_inner_oacc_absval2 end subroutine z_oacc_absval2 subroutine z_oacc_scal(alpha, x) @@ -128,32 +145,46 @@ contains class(psb_z_vect_oacc), intent(inout) :: x complex(psb_dpk_), intent(in) :: alpha integer(psb_ipk_) :: info - integer(psb_ipk_) :: i - - if (x%is_host()) call x%sync_space() - !$acc parallel loop - do i = 1, size(x%v) - x%v(i) = alpha * x%v(i) - end do + if (x%is_host()) call x%sync() + call z_inner_oacc_scal(alpha, x%v) call x%set_dev() + contains + subroutine z_inner_oacc_scal(alpha, x) + complex(psb_dpk_), intent(in) :: alpha + complex(psb_dpk_), intent(inout) :: x(:) + integer(psb_ipk_) :: i + !$acc parallel loop + do i = 1, size(x) + x(i) = alpha * x(i) + end do + end subroutine z_inner_oacc_scal end subroutine z_oacc_scal function z_oacc_nrm2(n, x) result(res) implicit none class(psb_z_vect_oacc), intent(inout) :: x integer(psb_ipk_), intent(in) :: n - real(psb_dpk_) :: res + real(psb_dpk_) :: res + real(psb_dpk_) :: mx integer(psb_ipk_) :: info - real(psb_dpk_) :: sum - integer(psb_ipk_) :: i - if (x%is_host()) call x%sync_space() - sum = 0.0 - !$acc parallel loop reduction(+:sum) - do i = 1, n - sum = sum + abs(x%v(i))**2 - end do - res = sqrt(sum) + if (x%is_host()) call x%sync() + mx = z_oacc_amax(n,x) + res = z_inner_oacc_nrm2(n, mx, x%v) + contains + function z_inner_oacc_nrm2(n, mx,x) result(res) + integer(psb_ipk_) :: n + complex(psb_dpk_) :: x(:) + real(psb_dpk_) :: mx, res + real(psb_dpk_) :: sum + integer(psb_ipk_) :: i + sum = 0.0 + !$acc parallel loop reduction(+:sum) + do i = 1, n + sum = sum + abs(x(i)/mx)**2 + end do + res = mx*sqrt(sum) + end function z_inner_oacc_nrm2 end function z_oacc_nrm2 function z_oacc_amax(n, x) result(res) @@ -162,18 +193,25 @@ contains integer(psb_ipk_), intent(in) :: n real(psb_dpk_) :: res integer(psb_ipk_) :: info - real(psb_dpk_) :: max_val - integer(psb_ipk_) :: i - if (x%is_host()) call x%sync_space() - max_val = -huge(0.0) - !$acc parallel loop reduction(max:max_val) - do i = 1, n - if (abs(x%v(i)) > max_val) max_val = abs(x%v(i)) - end do - res = max_val + if (x%is_host()) call x%sync() + res = z_inner_oacc_amax(n, x%v) + contains + function z_inner_oacc_amax(n, x) result(res) + integer(psb_ipk_) :: n + complex(psb_dpk_) :: x(:) + real(psb_dpk_) :: res + real(psb_dpk_) :: max_val + integer(psb_ipk_) :: i + max_val = -huge(0.0) + !$acc parallel loop reduction(max:max_val) + do i = 1, n + if (abs(x(i)) > max_val) max_val = abs(x(i)) + end do + res = max_val + end function z_inner_oacc_amax end function z_oacc_amax - + function z_oacc_asum(n, x) result(res) implicit none class(psb_z_vect_oacc), intent(inout) :: x @@ -182,14 +220,20 @@ contains integer(psb_ipk_) :: info complex(psb_dpk_) :: sum integer(psb_ipk_) :: i - - if (x%is_host()) call x%sync_space() - sum = 0.0 - !$acc parallel loop reduction(+:sum) - do i = 1, n - sum = sum + abs(x%v(i)) - end do - res = sum + if (x%is_host()) call x%sync() + res = z_inner_oacc_asum(n, x%v) + contains + function z_inner_oacc_asum(n, x) result(res) + integer(psb_ipk_) :: n + complex(psb_dpk_) :: x(:) + real(psb_dpk_) :: res + integer(psb_ipk_) :: i + res = 0.0 + !$acc parallel loop reduction(+:res) + do i = 1, n + res = res + abs(x(i)) + end do + end function z_inner_oacc_asum end function z_oacc_asum @@ -201,7 +245,7 @@ contains integer(psb_ipk_) :: i, n info = 0 - if (y%is_dev()) call y%sync_space() + if (y%is_dev()) call y%sync() !$acc parallel loop do i = 1, size(x) y%v(i) = y%v(i) * x(i) @@ -219,7 +263,7 @@ contains integer(psb_ipk_) :: i, n info = 0 - if (z%is_dev()) call z%sync_space() + if (z%is_dev()) call z%sync() !$acc parallel loop do i = 1, size(x) z%v(i) = alpha * x(i) * y(i) + beta * z%v(i) @@ -282,18 +326,18 @@ contains !!$ class is (psb_z_vect_oacc) !!$ select type (yy => y) !!$ class is (psb_z_vect_oacc) -!!$ if (xx%is_host()) call xx%sync_space() -!!$ if (yy%is_host()) call yy%sync_space() -!!$ if ((beta /= zzero) .and. (z%is_host())) call z%sync_space() +!!$ if (xx%is_host()) call xx%sync() +!!$ if (yy%is_host()) call yy%sync() +!!$ if ((beta /= zzero) .and. (z%is_host())) call z%sync() !!$ !$acc parallel loop !!$ do i = 1, n !!$ z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) !!$ end do !!$ call z%set_dev() !!$ class default -!!$ if (xx%is_dev()) call xx%sync_space() +!!$ if (xx%is_dev()) call xx%sync() !!$ if (yy%is_dev()) call yy%sync() -!!$ if ((beta /= zzero) .and. (z%is_dev())) call z%sync_space() +!!$ if ((beta /= zzero) .and. (z%is_dev())) call z%sync() !!$ !$acc parallel loop !!$ do i = 1, n !!$ z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) @@ -303,7 +347,7 @@ contains !!$ class default !!$ if (x%is_dev()) call x%sync() !!$ if (y%is_dev()) call y%sync() -!!$ if ((beta /= zzero) .and. (z%is_dev())) call z%sync_space() +!!$ if ((beta /= zzero) .and. (z%is_dev())) call z%sync() !!$ !$acc parallel loop !!$ do i = 1, n !!$ z%v(i) = alpha * x%v(i) * y%v(i) + beta * z%v(i) @@ -327,23 +371,36 @@ contains select type(xx => x) type is (psb_z_vect_oacc) - if ((beta /= zzero) .and. y%is_host()) call y%sync_space() - if (xx%is_host()) call xx%sync_space() + if ((beta /= zzero) .and. y%is_host()) call y%sync() + if (xx%is_host()) call xx%sync() nx = size(xx%v) ny = size(y%v) if ((nx < m) .or. (ny < m)) then info = psb_err_internal_error_ else - !$acc parallel loop - do i = 1, m - y%v(i) = alpha * xx%v(i) + beta * y%v(i) - end do + call z_inner_oacc_axpby(m, alpha, x%v, beta, y%v, info) end if call y%set_dev() class default if ((alpha /= zzero) .and. (x%is_dev())) call x%sync() call y%axpby(m, alpha, x%v, beta, info) - end select + end select + contains + subroutine z_inner_oacc_axpby(m, alpha, x, beta, y, info) + !use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + complex(psb_dpk_), intent(inout) :: x(:) + complex(psb_dpk_), intent(inout) :: y(:) + complex(psb_dpk_), intent(in) :: alpha, beta + integer(psb_ipk_), intent(out) :: info + !$acc parallel + !$acc loop + do i = 1, m + y(i) = alpha * x(i) + beta * y(i) + end do + !$acc end parallel + end subroutine z_inner_oacc_axpby end subroutine z_oacc_axpby_v subroutine z_oacc_axpby_a(m, alpha, x, beta, y, info) @@ -356,7 +413,7 @@ contains integer(psb_ipk_), intent(out) :: info integer(psb_ipk_) :: i - if ((beta /= zzero) .and. (y%is_dev())) call y%sync_space() + if ((beta /= zzero) .and. (y%is_dev())) call y%sync() !$acc parallel loop do i = 1, m y%v(i) = alpha * x(i) + beta * y%v(i) @@ -375,7 +432,7 @@ contains integer(psb_ipk_), intent(out) :: info integer(psb_ipk_) :: nx, ny, nz, i logical :: gpu_done - + write(0,*)'upd_xyz' info = psb_success_ gpu_done = .false. @@ -385,9 +442,9 @@ contains class is (psb_z_vect_oacc) select type(zz => z) class is (psb_z_vect_oacc) - if ((beta /= zzero) .and. yy%is_host()) call yy%sync_space() - if ((delta /= zzero) .and. zz%is_host()) call zz%sync_space() - if (xx%is_host()) call xx%sync_space() + if ((beta /= zzero) .and. yy%is_host()) call yy%sync() + if ((delta /= zzero) .and. zz%is_host()) call zz%sync() + if (xx%is_host()) call xx%sync() nx = size(xx%v) ny = size(yy%v) nz = size(zz%v) @@ -432,8 +489,8 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() - if (y%is_host()) call y%sync_space() + if (ii%is_host()) call ii%sync() + if (y%is_host()) call y%sync() !$acc parallel loop do i = 1, n @@ -459,13 +516,13 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() + if (ii%is_host()) call ii%sync() class default call psb_errpush(info, 'z_oacc_sctb_x') return end select - if (y%is_host()) call y%sync_space() + if (y%is_host()) call y%sync() !$acc parallel loop do i = 1, n @@ -486,7 +543,7 @@ contains integer(psb_ipk_) :: i if (n == 0) return - if (y%is_dev()) call y%sync_space() + if (y%is_dev()) call y%sync() !$acc parallel loop do i = 1, n @@ -512,13 +569,13 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() + if (ii%is_host()) call ii%sync() class default call psb_errpush(info, 'z_oacc_gthzbuf') return end select - if (x%is_host()) call x%sync_space() + if (x%is_host()) call x%sync() !$acc parallel loop do i = 1, n @@ -539,13 +596,13 @@ contains select type(ii => idx) class is (psb_i_vect_oacc) - if (ii%is_host()) call ii%sync_space() + if (ii%is_host()) call ii%sync() class default call psb_errpush(info, 'z_oacc_gthzv_x') return end select - if (x%is_host()) call x%sync_space() + if (x%is_host()) call x%sync() !$acc parallel loop do i = 1, n @@ -573,9 +630,9 @@ contains type is (psb_i_vect_oacc) select type(vval => val) type is (psb_z_vect_oacc) - if (vval%is_host()) call vval%sync_space() - if (virl%is_host()) call virl%sync_space() - if (x%is_host()) call x%sync_space() + if (vval%is_host()) call vval%sync() + if (virl%is_host()) call virl%sync() + if (x%is_host()) call x%sync() !$acc parallel loop do i = 1, n x%v(virl%v(i)) = vval%v(i) @@ -588,11 +645,11 @@ contains if (.not.done_oacc) then select type(virl => irl) type is (psb_i_vect_oacc) - if (virl%is_dev()) call virl%sync_space() + if (virl%is_dev()) call virl%sync() end select select type(vval => val) type is (psb_z_vect_oacc) - if (vval%is_dev()) call vval%sync_space() + if (vval%is_dev()) call vval%sync() end select call x%ins(n, irl%v, val%v, dupl, info) end if @@ -616,7 +673,7 @@ contains integer(psb_ipk_) :: i info = 0 - if (x%is_dev()) call x%sync_space() + if (x%is_dev()) call x%sync() call x%psb_z_base_vect_type%ins(n, irl, val, dupl, info) call x%set_host() !$acc update device(x%v) @@ -635,7 +692,10 @@ contains call psb_errpush(info, 'z_oacc_bld_mn', i_err=(/n, n, n, n, n/)) end if call x%set_host() - !$acc update device(x%v) + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if + !$acc enter data copyin(x%v) end subroutine z_oacc_bld_mn @@ -657,7 +717,10 @@ contains x%v(:) = this(:) call x%set_host() - !$acc update device(x%v) + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if + !$acc enter data copyin(x%v) end subroutine z_oacc_bld_x @@ -676,13 +739,13 @@ contains if (nd < n) then call x%sync() call x%psb_z_base_vect_type%asb(n, info) - if (info == psb_success_) call x%sync_space() + if (info == psb_success_) call x%sync() call x%set_host() end if else if (size(x%v) < n) then call x%psb_z_base_vect_type%asb(n, info) - if (info == psb_success_) call x%sync_space() + if (info == psb_success_) call x%sync() call x%set_host() end if end if @@ -740,10 +803,9 @@ contains complex(psb_dpk_) :: res complex(psb_dpk_), external :: ddot integer(psb_ipk_) :: info - integer(psb_ipk_) :: i res = zzero - + !write(0,*) 'dot_v' select type(yy => y) type is (psb_z_base_vect_type) if (x%is_dev()) call x%sync() @@ -751,18 +813,26 @@ contains type is (psb_z_vect_oacc) if (x%is_host()) call x%sync() if (yy%is_host()) call yy%sync() - - !$acc parallel loop reduction(+:res) present(x%v, yy%v) - do i = 1, n - res = res + x%v(i) * yy%v(i) - end do - !$acc end parallel loop - + res = z_inner_oacc_dot(n, x%v, yy%v) class default call x%sync() res = y%dot(n, x%v) end select - + contains + function z_inner_oacc_dot(n, x, y) result(res) + implicit none + complex(psb_dpk_), intent(in) :: x(:) + complex(psb_dpk_), intent(in) :: y(:) + integer(psb_ipk_), intent(in) :: n + complex(psb_dpk_) :: res + integer(psb_ipk_) :: i + + !$acc parallel loop reduction(+:res) present(x, y) + do i = 1, n + res = res + x(i) * y(i) + end do + !$acc end parallel loop + end function z_inner_oacc_dot end function z_oacc_vect_dot function z_oacc_dot_a(n, x, y) result(res) @@ -808,7 +878,7 @@ contains implicit none class(psb_z_vect_oacc), intent(inout) :: x if (allocated(x%v)) then - call z_oacc_create_dev(x%v) + if (.not.acc_is_present(x%v)) call z_oacc_create_dev(x%v) end if contains subroutine z_oacc_create_dev(v) @@ -886,6 +956,9 @@ contains call psb_realloc(n, x%v, info) if (info == 0) then call x%set_host() + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if !$acc enter data create(x%v) call x%sync_space() end if @@ -902,7 +975,9 @@ contains integer(psb_ipk_), intent(out) :: info info = 0 if (allocated(x%v)) then - !$acc exit data delete(x%v) finalize + if (acc_is_present(x%v)) then + !$acc exit data delete(x%v) finalize + end if deallocate(x%v, stat=info) end if