diff --git a/openacc/Makefile b/openacc/Makefile index cdcc9f71..f60a810d 100644 --- a/openacc/Makefile +++ b/openacc/Makefile @@ -58,9 +58,10 @@ psb_oacc_mod.o : psb_i_oacc_vect_mod.o psb_l_oacc_vect_mod.o \ psb_z_oacc_ell_mat_mod.o psb_z_oacc_hll_mat_mod.o \ psb_oacc_env_mod.o -psb_s_oacc_vect_mod.o psb_d_oacc_vect_mod.o\ - psb_c_oacc_vect_mod.o psb_z_oacc_vect_mod.o: psb_i_oacc_vect_mod.o psb_l_oacc_vect_mod.o -psb_l_oacc_vect_mod.o: psb_i_oacc_vect_mod.o +psb_s_oacc_vect_mod.o psb_d_oacc_vect_mod.o \ + psb_c_oacc_vect_mod.o psb_z_oacc_vect_mod.o: psb_i_oacc_vect_mod.o psb_l_oacc_vect_mod.o psb_oacc_env_mod.o +psb_l_oacc_vect_mod.o: psb_i_oacc_vect_mod.o psb_oacc_env_mod.o +psb_i_oacc_vect_mod.o: psb_oacc_env_mod.o psb_s_oacc_csr_mat_mod.o psb_s_oacc_ell_mat_mod.o psb_s_oacc_hll_mat_mod.o: psb_s_oacc_vect_mod.o diff --git a/openacc/impl/psb_c_oacc_csr_vect_mv.F90 b/openacc/impl/psb_c_oacc_csr_vect_mv.F90 index 3c6f6494..c1030094 100644 --- a/openacc/impl/psb_c_oacc_csr_vect_mv.F90 +++ b/openacc/impl/psb_c_oacc_csr_vect_mv.F90 @@ -18,7 +18,7 @@ contains m = a%get_nrows() n = a%get_ncols() - if ((n /= size(x%v)) .or. (m /= size(y%v))) then + if ((n > size(x%v)) .or. (m > size(y%v))) then write(0,*) 'ocsrmv Size error ', m, n, size(x%v), size(y%v) info = psb_err_invalid_mat_state_ return diff --git a/openacc/impl/psb_c_oacc_ell_vect_mv.F90 b/openacc/impl/psb_c_oacc_ell_vect_mv.F90 index 8113297b..7a39c031 100644 --- a/openacc/impl/psb_c_oacc_ell_vect_mv.F90 +++ b/openacc/impl/psb_c_oacc_ell_vect_mv.F90 @@ -19,7 +19,7 @@ contains n = a%get_ncols() nzt = a%nzt nc = size(a%ja,2) - if ((n /= size(x%v)) .or. (m /= size(y%v))) then + if ((n > size(x%v)) .or. (m > size(y%v))) then write(0,*) 'oellmv Size error ', m, n, size(x%v), size(y%v) info = psb_err_invalid_mat_state_ return diff --git a/openacc/impl/psb_c_oacc_hll_vect_mv.F90 b/openacc/impl/psb_c_oacc_hll_vect_mv.F90 index 551b1a29..494ed149 100644 --- a/openacc/impl/psb_c_oacc_hll_vect_mv.F90 +++ b/openacc/impl/psb_c_oacc_hll_vect_mv.F90 @@ -20,7 +20,7 @@ contains nhacks = size(a%hkoffs) - 1 hksz = a%hksz - if ((n /= size(x%v)) .or. (m /= size(y%v))) then + if ((n > size(x%v)) .or. (m > size(y%v))) then write(0,*) 'Size error ', m, n, size(x%v), size(y%v) info = psb_err_invalid_mat_state_ return diff --git a/openacc/impl/psb_d_oacc_csr_vect_mv.F90 b/openacc/impl/psb_d_oacc_csr_vect_mv.F90 index 596f2b17..a2efdc3e 100644 --- a/openacc/impl/psb_d_oacc_csr_vect_mv.F90 +++ b/openacc/impl/psb_d_oacc_csr_vect_mv.F90 @@ -18,7 +18,7 @@ contains m = a%get_nrows() n = a%get_ncols() - if ((n /= size(x%v)) .or. (m /= size(y%v))) then + if ((n > size(x%v)) .or. (m > size(y%v))) then write(0,*) 'ocsrmv Size error ', m, n, size(x%v), size(y%v) info = psb_err_invalid_mat_state_ return diff --git a/openacc/impl/psb_d_oacc_ell_vect_mv.F90 b/openacc/impl/psb_d_oacc_ell_vect_mv.F90 index ddd4bfc8..b233669d 100644 --- a/openacc/impl/psb_d_oacc_ell_vect_mv.F90 +++ b/openacc/impl/psb_d_oacc_ell_vect_mv.F90 @@ -19,7 +19,7 @@ contains n = a%get_ncols() nzt = a%nzt nc = size(a%ja,2) - if ((n /= size(x%v)) .or. (m /= size(y%v))) then + if ((n > size(x%v)) .or. (m > size(y%v))) then write(0,*) 'oellmv Size error ', m, n, size(x%v), size(y%v) info = psb_err_invalid_mat_state_ return diff --git a/openacc/impl/psb_d_oacc_hll_vect_mv.F90 b/openacc/impl/psb_d_oacc_hll_vect_mv.F90 index f971d61a..150ade8e 100644 --- a/openacc/impl/psb_d_oacc_hll_vect_mv.F90 +++ b/openacc/impl/psb_d_oacc_hll_vect_mv.F90 @@ -20,7 +20,7 @@ contains nhacks = size(a%hkoffs) - 1 hksz = a%hksz - if ((n /= size(x%v)) .or. (m /= size(y%v))) then + if ((n > size(x%v)) .or. (m > size(y%v))) then write(0,*) 'Size error ', m, n, size(x%v), size(y%v) info = psb_err_invalid_mat_state_ return diff --git a/openacc/impl/psb_s_oacc_csr_vect_mv.F90 b/openacc/impl/psb_s_oacc_csr_vect_mv.F90 index 2799bd05..5d3cc30c 100644 --- a/openacc/impl/psb_s_oacc_csr_vect_mv.F90 +++ b/openacc/impl/psb_s_oacc_csr_vect_mv.F90 @@ -18,7 +18,7 @@ contains m = a%get_nrows() n = a%get_ncols() - if ((n /= size(x%v)) .or. (m /= size(y%v))) then + if ((n > size(x%v)) .or. (m > size(y%v))) then write(0,*) 'ocsrmv Size error ', m, n, size(x%v), size(y%v) info = psb_err_invalid_mat_state_ return diff --git a/openacc/impl/psb_s_oacc_ell_vect_mv.F90 b/openacc/impl/psb_s_oacc_ell_vect_mv.F90 index 81166643..76b1fe5b 100644 --- a/openacc/impl/psb_s_oacc_ell_vect_mv.F90 +++ b/openacc/impl/psb_s_oacc_ell_vect_mv.F90 @@ -19,7 +19,7 @@ contains n = a%get_ncols() nzt = a%nzt nc = size(a%ja,2) - if ((n /= size(x%v)) .or. (m /= size(y%v))) then + if ((n > size(x%v)) .or. (m > size(y%v))) then write(0,*) 'oellmv Size error ', m, n, size(x%v), size(y%v) info = psb_err_invalid_mat_state_ return diff --git a/openacc/impl/psb_s_oacc_hll_vect_mv.F90 b/openacc/impl/psb_s_oacc_hll_vect_mv.F90 index e289f07c..e1d42252 100644 --- a/openacc/impl/psb_s_oacc_hll_vect_mv.F90 +++ b/openacc/impl/psb_s_oacc_hll_vect_mv.F90 @@ -20,7 +20,7 @@ contains nhacks = size(a%hkoffs) - 1 hksz = a%hksz - if ((n /= size(x%v)) .or. (m /= size(y%v))) then + if ((n > size(x%v)) .or. (m > size(y%v))) then write(0,*) 'Size error ', m, n, size(x%v), size(y%v) info = psb_err_invalid_mat_state_ return diff --git a/openacc/impl/psb_z_oacc_csr_vect_mv.F90 b/openacc/impl/psb_z_oacc_csr_vect_mv.F90 index 75cc693b..b312b6b7 100644 --- a/openacc/impl/psb_z_oacc_csr_vect_mv.F90 +++ b/openacc/impl/psb_z_oacc_csr_vect_mv.F90 @@ -18,7 +18,7 @@ contains m = a%get_nrows() n = a%get_ncols() - if ((n /= size(x%v)) .or. (m /= size(y%v))) then + if ((n > size(x%v)) .or. (m > size(y%v))) then write(0,*) 'ocsrmv Size error ', m, n, size(x%v), size(y%v) info = psb_err_invalid_mat_state_ return diff --git a/openacc/impl/psb_z_oacc_ell_vect_mv.F90 b/openacc/impl/psb_z_oacc_ell_vect_mv.F90 index 8d442c1d..53283689 100644 --- a/openacc/impl/psb_z_oacc_ell_vect_mv.F90 +++ b/openacc/impl/psb_z_oacc_ell_vect_mv.F90 @@ -19,7 +19,7 @@ contains n = a%get_ncols() nzt = a%nzt nc = size(a%ja,2) - if ((n /= size(x%v)) .or. (m /= size(y%v))) then + if ((n > size(x%v)) .or. (m > size(y%v))) then write(0,*) 'oellmv Size error ', m, n, size(x%v), size(y%v) info = psb_err_invalid_mat_state_ return diff --git a/openacc/impl/psb_z_oacc_hll_vect_mv.F90 b/openacc/impl/psb_z_oacc_hll_vect_mv.F90 index e373d6ff..350592bc 100644 --- a/openacc/impl/psb_z_oacc_hll_vect_mv.F90 +++ b/openacc/impl/psb_z_oacc_hll_vect_mv.F90 @@ -20,7 +20,7 @@ contains nhacks = size(a%hkoffs) - 1 hksz = a%hksz - if ((n /= size(x%v)) .or. (m /= size(y%v))) then + if ((n > size(x%v)) .or. (m > size(y%v))) then write(0,*) 'Size error ', m, n, size(x%v), size(y%v) info = psb_err_invalid_mat_state_ return diff --git a/openacc/psb_c_oacc_vect_mod.F90 b/openacc/psb_c_oacc_vect_mod.F90 index 95c45646..4e2cca9e 100644 --- a/openacc/psb_c_oacc_vect_mod.F90 +++ b/openacc/psb_c_oacc_vect_mod.F90 @@ -3,6 +3,8 @@ module psb_c_oacc_vect_mod use openacc use psb_const_mod use psb_error_mod + use psb_realloc_mod + use psb_oacc_env_mod use psb_c_vect_mod use psb_i_vect_mod use psb_i_oacc_vect_mod @@ -26,6 +28,8 @@ module psb_c_oacc_vect_mod procedure, pass(x) :: bld_x => c_oacc_bld_x procedure, pass(x) :: bld_mn => c_oacc_bld_mn procedure, pass(x) :: free => c_oacc_vect_free + procedure, pass(x) :: free_buffer => c_oacc_vect_free_buffer + procedure, pass(x) :: maybe_free_buffer => c_oacc_vect_maybe_free_buffer procedure, pass(x) :: ins_a => c_oacc_ins_a procedure, pass(x) :: ins_v => c_oacc_ins_v procedure, pass(x) :: is_host => c_oacc_is_host @@ -36,11 +40,13 @@ module psb_c_oacc_vect_mod procedure, pass(x) :: set_sync => c_oacc_set_sync procedure, pass(x) :: set_scal => c_oacc_set_scal + procedure, pass(x) :: new_buffer => c_oacc_new_buffer procedure, pass(x) :: gthzv_x => c_oacc_gthzv_x - procedure, pass(x) :: gthzbuf_x => c_oacc_gthzbuf + procedure, pass(x) :: gthzbuf => c_oacc_gthzbuf procedure, pass(y) :: sctb => c_oacc_sctb procedure, pass(y) :: sctb_x => c_oacc_sctb_x procedure, pass(y) :: sctb_buf => c_oacc_sctb_buf + procedure, nopass :: device_wait => c_oacc_device_wait procedure, pass(x) :: get_size => c_oacc_get_size @@ -87,6 +93,11 @@ module psb_c_oacc_vect_mod contains + subroutine c_oacc_device_wait() + implicit none + call acc_wait_all() + end subroutine c_oacc_device_wait + subroutine c_oacc_absval1(x) implicit none class(psb_c_vect_oacc), intent(inout) :: x @@ -181,13 +192,17 @@ contains !$acc parallel loop reduction(max:mx) do i = 1, n if (abs(x(i)) > mx) mx = abs(x(i)) - end do - sum = szero - !$acc parallel loop reduction(+:sum) - do i = 1, n - sum = sum + abs(x(i)/mx)**2 end do - res = mx*sqrt(sum) + if (mx == szero) then + res = mx + else + sum = szero + !$acc parallel loop reduction(+:sum) + do i = 1, n + sum = sum + abs(x(i)/mx)**2 + end do + res = mx*sqrt(sum) + end if end function c_inner_oacc_nrm2 end function c_oacc_nrm2 @@ -398,29 +413,44 @@ contains class(psb_i_base_vect_type) :: idx complex(psb_spk_) :: beta class(psb_c_vect_oacc) :: y - integer(psb_ipk_) :: info - + integer(psb_ipk_) :: info, k + logical :: acc_done if (.not.allocated(y%combuf)) then call psb_errpush(psb_err_alloc_dealloc_, 'sctb_buf') return end if + acc_done = .false. select type(ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() if (y%is_host()) call y%sync() + !$acc update device(y%combuf) + call inner_sctb(n,y%combuf(i:i+n-1),beta,y%v,ii%v(i:i+n-1)) + call y%set_dev() + acc_done = .true. + end select - !$acc parallel loop - do i = 1, n - y%v(ii%v(i)) = beta * y%v(ii%v(i)) + y%combuf(i) + if (.not.acc_done) then + if (idx%is_dev()) call idx%sync() + if (y%is_dev()) call y%sync() + do k = 1, n + y%v(idx%v(k+i-1)) = beta * y%v(idx%v(k+i-1)) + y%combuf(k) end do + end if - class default - !$acc parallel loop - do i = 1, n - y%v(idx%v(i)) = beta * y%v(idx%v(i)) + y%combuf(i) + contains + subroutine inner_sctb(n,x,beta,y,idx) + integer(psb_ipk_) :: n, idx(:) + complex(psb_spk_) :: beta,x(:), y(:) + integer(psb_ipk_) :: k + !$acc parallel loop + do k = 1, n + y(idx(k)) = x(k) + beta *y(idx(k)) end do - end select + !$acc end parallel loop + end subroutine inner_sctb + end subroutine c_oacc_sctb_buf subroutine c_oacc_sctb_x(i, n, idx, x, beta, y) @@ -430,24 +460,41 @@ contains class(psb_i_base_vect_type) :: idx complex(psb_spk_) :: beta, x(:) class(psb_c_vect_oacc) :: y - integer(psb_ipk_) :: info, ni + integer(psb_ipk_) :: info, ni, k + logical :: acc_done + acc_done = .false. select type(ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() - class default - call psb_errpush(info, 'c_oacc_sctb_x') - return + if (y%is_host()) call y%sync() + if (acc_is_present(x)) then + call inner_sctb(n,x(i:i+n-1),beta,y%v,idx%v(i:i+n-1)) + acc_done = .true. + call y%set_dev() + end if end select + if (.not.acc_done) then + if (idx%is_dev()) call idx%sync() + if (y%is_dev()) call y%sync() + do k = 1, n + y%v(idx%v(k+i-1)) = beta * y%v(idx%v(k+i-1)) + x(k+i-1) + end do + call y%set_host() + end if - if (y%is_host()) call y%sync() - - !$acc parallel loop - do i = 1, n - y%v(idx%v(i)) = beta * y%v(idx%v(i)) + x(i) - end do - - call y%set_dev() + contains + subroutine inner_sctb(n,x,beta,y,idx) + integer(psb_ipk_) :: n, idx(:) + complex(psb_spk_) :: beta, x(:), y(:) + integer(psb_ipk_) :: k + !$acc parallel loop + do k = 1, n + y(idx(k)) = x(k) + beta *y(idx(k)) + end do + !$acc end parallel loop + end subroutine inner_sctb + end subroutine c_oacc_sctb_x subroutine c_oacc_sctb(n, idx, x, beta, y) @@ -463,7 +510,6 @@ contains if (n == 0) return if (y%is_dev()) call y%sync() - !$acc parallel loop do i = 1, n y%v(idx(i)) = beta * y%v(idx(i)) + x(i) end do @@ -477,30 +523,48 @@ contains integer(psb_ipk_) :: i, n class(psb_i_base_vect_type) :: idx class(psb_c_vect_oacc) :: x - integer(psb_ipk_) :: info + integer(psb_ipk_) :: info,k + logical :: acc_done info = 0 + acc_done = .false. + if (.not.allocated(x%combuf)) then call psb_errpush(psb_err_alloc_dealloc_, 'gthzbuf') return end if - select type(ii => idx) + select type (ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() - class default - call psb_errpush(info, 'c_oacc_gthzbuf') - return + if (x%is_host()) call x%sync() + call inner_gth(n,x%v,x%combuf(i:i+n-1),ii%v(i:i+n-1)) + acc_done = .true. end select - if (x%is_host()) call x%sync() + if (.not.acc_done) then + if (idx%is_dev()) call idx%sync() + if (x%is_dev()) call x%sync() + do k = 1, n + x%combuf(k+i-1) = x%v(idx%v(k+i-1)) + end do + end if - !$acc parallel loop - do i = 1, n - x%combuf(i) = x%v(idx%v(i)) - end do + contains + subroutine inner_gth(n,x,y,idx) + integer(psb_ipk_) :: n, idx(:) + complex(psb_spk_) :: x(:), y(:) + integer(psb_ipk_) :: k + + !$acc parallel loop present(y) + do k = 1, n + y(k) = x(idx(k)) + end do + !$acc end parallel loop + !$acc update self(y) + end subroutine inner_gth end subroutine c_oacc_gthzbuf - + subroutine c_oacc_gthzv_x(i, n, idx, x, y) use psb_base_mod implicit none @@ -508,24 +572,41 @@ contains class(psb_i_base_vect_type):: idx complex(psb_spk_) :: y(:) class(psb_c_vect_oacc):: x - integer(psb_ipk_) :: info + integer(psb_ipk_) :: info, k + logical :: acc_done info = 0 - - select type(ii => idx) + acc_done = .false. + select type (ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() - class default - call psb_errpush(info, 'c_oacc_gthzv_x') - return + if (x%is_host()) call x%sync() + if (acc_is_present(y)) then + call inner_gth(n,x%v,y(i:),ii%v(i:)) + acc_done=.true. + end if end select - - if (x%is_host()) call x%sync() - - !$acc parallel loop - do i = 1, n - y(i) = x%v(idx%v(i)) - end do + if (.not.acc_done) then + if (x%is_dev()) call x%sync() + if (idx%is_dev()) call idx%sync() + do k = 1, n + y(k+i-1) = x%v(idx%v(k+i-1)) + !write(0,*) 'oa gthzv ',k+i-1,idx%v(k+i-1),k,y(k) + end do + end if + contains + subroutine inner_gth(n,x,y,idx) + integer(psb_ipk_) :: n, idx(:) + complex(psb_spk_) :: x(:), y(:) + integer(psb_ipk_) :: k + + !$acc parallel loop present(y) + do k = 1, n + y(k) = x(idx(k)) + end do + !$acc end parallel loop + !$acc update self(y) + end subroutine inner_gth end subroutine c_oacc_gthzv_x subroutine c_oacc_ins_v(n, irl, val, dupl, x, info) @@ -718,7 +799,7 @@ contains integer(psb_ipk_) :: info res = czero - !write(0,*) 'dot_v' +!!$ write(0,*) 'oacc_dot_v' select type(yy => y) type is (psb_c_base_vect_type) if (x%is_dev()) call x%sync() @@ -762,6 +843,17 @@ contains end function c_oacc_dot_a + subroutine c_oacc_new_buffer(n,x,info) + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + integer(psb_ipk_), intent(out) :: info + if (n /= psb_size(x%combuf)) then + call x%psb_c_base_vect_type%new_buffer(n,info) + !$acc enter data copyin(x%combuf) + end if + end subroutine c_oacc_new_buffer + subroutine c_oacc_sync_dev_space(x) implicit none class(psb_c_vect_oacc), intent(inout) :: x @@ -860,12 +952,33 @@ contains class(psb_c_vect_oacc), intent(inout) :: x integer(psb_ipk_), intent(out) :: info info = 0 - if (allocated(x%v)) then - if (acc_is_present(x%v)) call acc_delete_finalize(x%v) - deallocate(x%v, stat=info) - end if + if (acc_is_present(x%v)) call acc_delete_finalize(x%v) + if (acc_is_present(x%combuf)) call acc_delete_finalize(x%combuf) + call x%psb_c_base_vect_type%free(info) end subroutine c_oacc_vect_free + + subroutine c_oacc_vect_maybe_free_buffer(x,info) + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + info = 0 + if (psb_oacc_get_maybe_free_buffer())& + & call x%free_buffer(info) + end subroutine c_oacc_vect_maybe_free_buffer + + subroutine c_oacc_vect_free_buffer(x,info) + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + info = 0 + if (acc_is_present(x%combuf)) call acc_delete_finalize(x%combuf) + call x%psb_c_base_vect_type%free_buffer(info) + + end subroutine c_oacc_vect_free_buffer + function c_oacc_get_size(x) result(res) implicit none class(psb_c_vect_oacc), intent(inout) :: x diff --git a/openacc/psb_d_oacc_vect_mod.F90 b/openacc/psb_d_oacc_vect_mod.F90 index 3d71e54c..80ac35f7 100644 --- a/openacc/psb_d_oacc_vect_mod.F90 +++ b/openacc/psb_d_oacc_vect_mod.F90 @@ -3,6 +3,8 @@ module psb_d_oacc_vect_mod use openacc use psb_const_mod use psb_error_mod + use psb_realloc_mod + use psb_oacc_env_mod use psb_d_vect_mod use psb_i_vect_mod use psb_i_oacc_vect_mod @@ -26,6 +28,8 @@ module psb_d_oacc_vect_mod procedure, pass(x) :: bld_x => d_oacc_bld_x procedure, pass(x) :: bld_mn => d_oacc_bld_mn procedure, pass(x) :: free => d_oacc_vect_free + procedure, pass(x) :: free_buffer => d_oacc_vect_free_buffer + procedure, pass(x) :: maybe_free_buffer => d_oacc_vect_maybe_free_buffer procedure, pass(x) :: ins_a => d_oacc_ins_a procedure, pass(x) :: ins_v => d_oacc_ins_v procedure, pass(x) :: is_host => d_oacc_is_host @@ -36,11 +40,13 @@ module psb_d_oacc_vect_mod procedure, pass(x) :: set_sync => d_oacc_set_sync procedure, pass(x) :: set_scal => d_oacc_set_scal + procedure, pass(x) :: new_buffer => d_oacc_new_buffer procedure, pass(x) :: gthzv_x => d_oacc_gthzv_x - procedure, pass(x) :: gthzbuf_x => d_oacc_gthzbuf + procedure, pass(x) :: gthzbuf => d_oacc_gthzbuf procedure, pass(y) :: sctb => d_oacc_sctb procedure, pass(y) :: sctb_x => d_oacc_sctb_x procedure, pass(y) :: sctb_buf => d_oacc_sctb_buf + procedure, nopass :: device_wait => d_oacc_device_wait procedure, pass(x) :: get_size => d_oacc_get_size @@ -87,6 +93,11 @@ module psb_d_oacc_vect_mod contains + subroutine d_oacc_device_wait() + implicit none + call acc_wait_all() + end subroutine d_oacc_device_wait + subroutine d_oacc_absval1(x) implicit none class(psb_d_vect_oacc), intent(inout) :: x @@ -181,13 +192,17 @@ contains !$acc parallel loop reduction(max:mx) do i = 1, n if (abs(x(i)) > mx) mx = abs(x(i)) - end do - sum = dzero - !$acc parallel loop reduction(+:sum) - do i = 1, n - sum = sum + abs(x(i)/mx)**2 end do - res = mx*sqrt(sum) + if (mx == dzero) then + res = mx + else + sum = dzero + !$acc parallel loop reduction(+:sum) + do i = 1, n + sum = sum + abs(x(i)/mx)**2 + end do + res = mx*sqrt(sum) + end if end function d_inner_oacc_nrm2 end function d_oacc_nrm2 @@ -398,29 +413,44 @@ contains class(psb_i_base_vect_type) :: idx real(psb_dpk_) :: beta class(psb_d_vect_oacc) :: y - integer(psb_ipk_) :: info - + integer(psb_ipk_) :: info, k + logical :: acc_done if (.not.allocated(y%combuf)) then call psb_errpush(psb_err_alloc_dealloc_, 'sctb_buf') return end if + acc_done = .false. select type(ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() if (y%is_host()) call y%sync() + !$acc update device(y%combuf) + call inner_sctb(n,y%combuf(i:i+n-1),beta,y%v,ii%v(i:i+n-1)) + call y%set_dev() + acc_done = .true. + end select - !$acc parallel loop - do i = 1, n - y%v(ii%v(i)) = beta * y%v(ii%v(i)) + y%combuf(i) + if (.not.acc_done) then + if (idx%is_dev()) call idx%sync() + if (y%is_dev()) call y%sync() + do k = 1, n + y%v(idx%v(k+i-1)) = beta * y%v(idx%v(k+i-1)) + y%combuf(k) end do + end if - class default - !$acc parallel loop - do i = 1, n - y%v(idx%v(i)) = beta * y%v(idx%v(i)) + y%combuf(i) + contains + subroutine inner_sctb(n,x,beta,y,idx) + integer(psb_ipk_) :: n, idx(:) + real(psb_dpk_) :: beta,x(:), y(:) + integer(psb_ipk_) :: k + !$acc parallel loop + do k = 1, n + y(idx(k)) = x(k) + beta *y(idx(k)) end do - end select + !$acc end parallel loop + end subroutine inner_sctb + end subroutine d_oacc_sctb_buf subroutine d_oacc_sctb_x(i, n, idx, x, beta, y) @@ -430,24 +460,41 @@ contains class(psb_i_base_vect_type) :: idx real(psb_dpk_) :: beta, x(:) class(psb_d_vect_oacc) :: y - integer(psb_ipk_) :: info, ni + integer(psb_ipk_) :: info, ni, k + logical :: acc_done + acc_done = .false. select type(ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() - class default - call psb_errpush(info, 'd_oacc_sctb_x') - return + if (y%is_host()) call y%sync() + if (acc_is_present(x)) then + call inner_sctb(n,x(i:i+n-1),beta,y%v,idx%v(i:i+n-1)) + acc_done = .true. + call y%set_dev() + end if end select + if (.not.acc_done) then + if (idx%is_dev()) call idx%sync() + if (y%is_dev()) call y%sync() + do k = 1, n + y%v(idx%v(k+i-1)) = beta * y%v(idx%v(k+i-1)) + x(k+i-1) + end do + call y%set_host() + end if - if (y%is_host()) call y%sync() - - !$acc parallel loop - do i = 1, n - y%v(idx%v(i)) = beta * y%v(idx%v(i)) + x(i) - end do - - call y%set_dev() + contains + subroutine inner_sctb(n,x,beta,y,idx) + integer(psb_ipk_) :: n, idx(:) + real(psb_dpk_) :: beta, x(:), y(:) + integer(psb_ipk_) :: k + !$acc parallel loop + do k = 1, n + y(idx(k)) = x(k) + beta *y(idx(k)) + end do + !$acc end parallel loop + end subroutine inner_sctb + end subroutine d_oacc_sctb_x subroutine d_oacc_sctb(n, idx, x, beta, y) @@ -463,7 +510,6 @@ contains if (n == 0) return if (y%is_dev()) call y%sync() - !$acc parallel loop do i = 1, n y%v(idx(i)) = beta * y%v(idx(i)) + x(i) end do @@ -477,30 +523,48 @@ contains integer(psb_ipk_) :: i, n class(psb_i_base_vect_type) :: idx class(psb_d_vect_oacc) :: x - integer(psb_ipk_) :: info + integer(psb_ipk_) :: info,k + logical :: acc_done info = 0 + acc_done = .false. + if (.not.allocated(x%combuf)) then call psb_errpush(psb_err_alloc_dealloc_, 'gthzbuf') return end if - select type(ii => idx) + select type (ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() - class default - call psb_errpush(info, 'd_oacc_gthzbuf') - return + if (x%is_host()) call x%sync() + call inner_gth(n,x%v,x%combuf(i:i+n-1),ii%v(i:i+n-1)) + acc_done = .true. end select - if (x%is_host()) call x%sync() + if (.not.acc_done) then + if (idx%is_dev()) call idx%sync() + if (x%is_dev()) call x%sync() + do k = 1, n + x%combuf(k+i-1) = x%v(idx%v(k+i-1)) + end do + end if - !$acc parallel loop - do i = 1, n - x%combuf(i) = x%v(idx%v(i)) - end do + contains + subroutine inner_gth(n,x,y,idx) + integer(psb_ipk_) :: n, idx(:) + real(psb_dpk_) :: x(:), y(:) + integer(psb_ipk_) :: k + + !$acc parallel loop present(y) + do k = 1, n + y(k) = x(idx(k)) + end do + !$acc end parallel loop + !$acc update self(y) + end subroutine inner_gth end subroutine d_oacc_gthzbuf - + subroutine d_oacc_gthzv_x(i, n, idx, x, y) use psb_base_mod implicit none @@ -508,24 +572,41 @@ contains class(psb_i_base_vect_type):: idx real(psb_dpk_) :: y(:) class(psb_d_vect_oacc):: x - integer(psb_ipk_) :: info + integer(psb_ipk_) :: info, k + logical :: acc_done info = 0 - - select type(ii => idx) + acc_done = .false. + select type (ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() - class default - call psb_errpush(info, 'd_oacc_gthzv_x') - return + if (x%is_host()) call x%sync() + if (acc_is_present(y)) then + call inner_gth(n,x%v,y(i:),ii%v(i:)) + acc_done=.true. + end if end select - - if (x%is_host()) call x%sync() - - !$acc parallel loop - do i = 1, n - y(i) = x%v(idx%v(i)) - end do + if (.not.acc_done) then + if (x%is_dev()) call x%sync() + if (idx%is_dev()) call idx%sync() + do k = 1, n + y(k+i-1) = x%v(idx%v(k+i-1)) + !write(0,*) 'oa gthzv ',k+i-1,idx%v(k+i-1),k,y(k) + end do + end if + contains + subroutine inner_gth(n,x,y,idx) + integer(psb_ipk_) :: n, idx(:) + real(psb_dpk_) :: x(:), y(:) + integer(psb_ipk_) :: k + + !$acc parallel loop present(y) + do k = 1, n + y(k) = x(idx(k)) + end do + !$acc end parallel loop + !$acc update self(y) + end subroutine inner_gth end subroutine d_oacc_gthzv_x subroutine d_oacc_ins_v(n, irl, val, dupl, x, info) @@ -718,7 +799,7 @@ contains integer(psb_ipk_) :: info res = dzero - !write(0,*) 'dot_v' +!!$ write(0,*) 'oacc_dot_v' select type(yy => y) type is (psb_d_base_vect_type) if (x%is_dev()) call x%sync() @@ -762,6 +843,17 @@ contains end function d_oacc_dot_a + subroutine d_oacc_new_buffer(n,x,info) + implicit none + class(psb_d_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + integer(psb_ipk_), intent(out) :: info + if (n /= psb_size(x%combuf)) then + call x%psb_d_base_vect_type%new_buffer(n,info) + !$acc enter data copyin(x%combuf) + end if + end subroutine d_oacc_new_buffer + subroutine d_oacc_sync_dev_space(x) implicit none class(psb_d_vect_oacc), intent(inout) :: x @@ -860,12 +952,33 @@ contains class(psb_d_vect_oacc), intent(inout) :: x integer(psb_ipk_), intent(out) :: info info = 0 - if (allocated(x%v)) then - if (acc_is_present(x%v)) call acc_delete_finalize(x%v) - deallocate(x%v, stat=info) - end if + if (acc_is_present(x%v)) call acc_delete_finalize(x%v) + if (acc_is_present(x%combuf)) call acc_delete_finalize(x%combuf) + call x%psb_d_base_vect_type%free(info) end subroutine d_oacc_vect_free + + subroutine d_oacc_vect_maybe_free_buffer(x,info) + implicit none + class(psb_d_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + info = 0 + if (psb_oacc_get_maybe_free_buffer())& + & call x%free_buffer(info) + end subroutine d_oacc_vect_maybe_free_buffer + + subroutine d_oacc_vect_free_buffer(x,info) + implicit none + class(psb_d_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + info = 0 + if (acc_is_present(x%combuf)) call acc_delete_finalize(x%combuf) + call x%psb_d_base_vect_type%free_buffer(info) + + end subroutine d_oacc_vect_free_buffer + function d_oacc_get_size(x) result(res) implicit none class(psb_d_vect_oacc), intent(inout) :: x diff --git a/openacc/psb_i_oacc_vect_mod.F90 b/openacc/psb_i_oacc_vect_mod.F90 index 42226f0c..cfd0c210 100644 --- a/openacc/psb_i_oacc_vect_mod.F90 +++ b/openacc/psb_i_oacc_vect_mod.F90 @@ -3,6 +3,8 @@ module psb_i_oacc_vect_mod use openacc use psb_const_mod use psb_error_mod + use psb_realloc_mod + use psb_oacc_env_mod use psb_i_vect_mod integer(psb_ipk_), parameter, private :: is_host = -1 @@ -24,6 +26,8 @@ module psb_i_oacc_vect_mod procedure, pass(x) :: bld_x => i_oacc_bld_x procedure, pass(x) :: bld_mn => i_oacc_bld_mn procedure, pass(x) :: free => i_oacc_vect_free + procedure, pass(x) :: free_buffer => i_oacc_vect_free_buffer + procedure, pass(x) :: maybe_free_buffer => i_oacc_vect_maybe_free_buffer procedure, pass(x) :: ins_a => i_oacc_ins_a procedure, pass(x) :: ins_v => i_oacc_ins_v procedure, pass(x) :: is_host => i_oacc_is_host @@ -34,11 +38,13 @@ module psb_i_oacc_vect_mod procedure, pass(x) :: set_sync => i_oacc_set_sync procedure, pass(x) :: set_scal => i_oacc_set_scal + procedure, pass(x) :: new_buffer => i_oacc_new_buffer procedure, pass(x) :: gthzv_x => i_oacc_gthzv_x - procedure, pass(x) :: gthzbuf_x => i_oacc_gthzbuf + procedure, pass(x) :: gthzbuf => i_oacc_gthzbuf procedure, pass(y) :: sctb => i_oacc_sctb procedure, pass(y) :: sctb_x => i_oacc_sctb_x procedure, pass(y) :: sctb_buf => i_oacc_sctb_buf + procedure, nopass :: device_wait => i_oacc_device_wait procedure, pass(x) :: get_size => i_oacc_get_size @@ -48,6 +54,11 @@ module psb_i_oacc_vect_mod contains + subroutine i_oacc_device_wait() + implicit none + call acc_wait_all() + end subroutine i_oacc_device_wait + subroutine i_oacc_sctb_buf(i, n, idx, beta, y) use psb_base_mod @@ -56,29 +67,44 @@ contains class(psb_i_base_vect_type) :: idx integer(psb_ipk_) :: beta class(psb_i_vect_oacc) :: y - integer(psb_ipk_) :: info - + integer(psb_ipk_) :: info, k + logical :: acc_done if (.not.allocated(y%combuf)) then call psb_errpush(psb_err_alloc_dealloc_, 'sctb_buf') return end if + acc_done = .false. select type(ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() if (y%is_host()) call y%sync() + !$acc update device(y%combuf) + call inner_sctb(n,y%combuf(i:i+n-1),beta,y%v,ii%v(i:i+n-1)) + call y%set_dev() + acc_done = .true. + end select - !$acc parallel loop - do i = 1, n - y%v(ii%v(i)) = beta * y%v(ii%v(i)) + y%combuf(i) + if (.not.acc_done) then + if (idx%is_dev()) call idx%sync() + if (y%is_dev()) call y%sync() + do k = 1, n + y%v(idx%v(k+i-1)) = beta * y%v(idx%v(k+i-1)) + y%combuf(k) end do + end if - class default - !$acc parallel loop - do i = 1, n - y%v(idx%v(i)) = beta * y%v(idx%v(i)) + y%combuf(i) + contains + subroutine inner_sctb(n,x,beta,y,idx) + integer(psb_ipk_) :: n, idx(:) + integer(psb_ipk_) :: beta,x(:), y(:) + integer(psb_ipk_) :: k + !$acc parallel loop + do k = 1, n + y(idx(k)) = x(k) + beta *y(idx(k)) end do - end select + !$acc end parallel loop + end subroutine inner_sctb + end subroutine i_oacc_sctb_buf subroutine i_oacc_sctb_x(i, n, idx, x, beta, y) @@ -88,24 +114,41 @@ contains class(psb_i_base_vect_type) :: idx integer(psb_ipk_) :: beta, x(:) class(psb_i_vect_oacc) :: y - integer(psb_ipk_) :: info, ni + integer(psb_ipk_) :: info, ni, k + logical :: acc_done + acc_done = .false. select type(ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() - class default - call psb_errpush(info, 'i_oacc_sctb_x') - return + if (y%is_host()) call y%sync() + if (acc_is_present(x)) then + call inner_sctb(n,x(i:i+n-1),beta,y%v,idx%v(i:i+n-1)) + acc_done = .true. + call y%set_dev() + end if end select + if (.not.acc_done) then + if (idx%is_dev()) call idx%sync() + if (y%is_dev()) call y%sync() + do k = 1, n + y%v(idx%v(k+i-1)) = beta * y%v(idx%v(k+i-1)) + x(k+i-1) + end do + call y%set_host() + end if - if (y%is_host()) call y%sync() - - !$acc parallel loop - do i = 1, n - y%v(idx%v(i)) = beta * y%v(idx%v(i)) + x(i) - end do - - call y%set_dev() + contains + subroutine inner_sctb(n,x,beta,y,idx) + integer(psb_ipk_) :: n, idx(:) + integer(psb_ipk_) :: beta, x(:), y(:) + integer(psb_ipk_) :: k + !$acc parallel loop + do k = 1, n + y(idx(k)) = x(k) + beta *y(idx(k)) + end do + !$acc end parallel loop + end subroutine inner_sctb + end subroutine i_oacc_sctb_x subroutine i_oacc_sctb(n, idx, x, beta, y) @@ -121,7 +164,6 @@ contains if (n == 0) return if (y%is_dev()) call y%sync() - !$acc parallel loop do i = 1, n y%v(idx(i)) = beta * y%v(idx(i)) + x(i) end do @@ -135,30 +177,48 @@ contains integer(psb_ipk_) :: i, n class(psb_i_base_vect_type) :: idx class(psb_i_vect_oacc) :: x - integer(psb_ipk_) :: info + integer(psb_ipk_) :: info,k + logical :: acc_done info = 0 + acc_done = .false. + if (.not.allocated(x%combuf)) then call psb_errpush(psb_err_alloc_dealloc_, 'gthzbuf') return end if - select type(ii => idx) + select type (ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() - class default - call psb_errpush(info, 'i_oacc_gthzbuf') - return + if (x%is_host()) call x%sync() + call inner_gth(n,x%v,x%combuf(i:i+n-1),ii%v(i:i+n-1)) + acc_done = .true. end select - if (x%is_host()) call x%sync() + if (.not.acc_done) then + if (idx%is_dev()) call idx%sync() + if (x%is_dev()) call x%sync() + do k = 1, n + x%combuf(k+i-1) = x%v(idx%v(k+i-1)) + end do + end if - !$acc parallel loop - do i = 1, n - x%combuf(i) = x%v(idx%v(i)) - end do + contains + subroutine inner_gth(n,x,y,idx) + integer(psb_ipk_) :: n, idx(:) + integer(psb_ipk_) :: x(:), y(:) + integer(psb_ipk_) :: k + + !$acc parallel loop present(y) + do k = 1, n + y(k) = x(idx(k)) + end do + !$acc end parallel loop + !$acc update self(y) + end subroutine inner_gth end subroutine i_oacc_gthzbuf - + subroutine i_oacc_gthzv_x(i, n, idx, x, y) use psb_base_mod implicit none @@ -166,24 +226,41 @@ contains class(psb_i_base_vect_type):: idx integer(psb_ipk_) :: y(:) class(psb_i_vect_oacc):: x - integer(psb_ipk_) :: info + integer(psb_ipk_) :: info, k + logical :: acc_done info = 0 - - select type(ii => idx) + acc_done = .false. + select type (ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() - class default - call psb_errpush(info, 'i_oacc_gthzv_x') - return + if (x%is_host()) call x%sync() + if (acc_is_present(y)) then + call inner_gth(n,x%v,y(i:),ii%v(i:)) + acc_done=.true. + end if end select - - if (x%is_host()) call x%sync() - - !$acc parallel loop - do i = 1, n - y(i) = x%v(idx%v(i)) - end do + if (.not.acc_done) then + if (x%is_dev()) call x%sync() + if (idx%is_dev()) call idx%sync() + do k = 1, n + y(k+i-1) = x%v(idx%v(k+i-1)) + !write(0,*) 'oa gthzv ',k+i-1,idx%v(k+i-1),k,y(k) + end do + end if + contains + subroutine inner_gth(n,x,y,idx) + integer(psb_ipk_) :: n, idx(:) + integer(psb_ipk_) :: x(:), y(:) + integer(psb_ipk_) :: k + + !$acc parallel loop present(y) + do k = 1, n + y(k) = x(idx(k)) + end do + !$acc end parallel loop + !$acc update self(y) + end subroutine inner_gth end subroutine i_oacc_gthzv_x subroutine i_oacc_ins_v(n, irl, val, dupl, x, info) @@ -366,6 +443,17 @@ contains end function i_oacc_get_fmt + subroutine i_oacc_new_buffer(n,x,info) + implicit none + class(psb_i_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + integer(psb_ipk_), intent(out) :: info + if (n /= psb_size(x%combuf)) then + call x%psb_i_base_vect_type%new_buffer(n,info) + !$acc enter data copyin(x%combuf) + end if + end subroutine i_oacc_new_buffer + subroutine i_oacc_sync_dev_space(x) implicit none class(psb_i_vect_oacc), intent(inout) :: x @@ -464,12 +552,33 @@ contains class(psb_i_vect_oacc), intent(inout) :: x integer(psb_ipk_), intent(out) :: info info = 0 - if (allocated(x%v)) then - if (acc_is_present(x%v)) call acc_delete_finalize(x%v) - deallocate(x%v, stat=info) - end if + if (acc_is_present(x%v)) call acc_delete_finalize(x%v) + if (acc_is_present(x%combuf)) call acc_delete_finalize(x%combuf) + call x%psb_i_base_vect_type%free(info) end subroutine i_oacc_vect_free + + subroutine i_oacc_vect_maybe_free_buffer(x,info) + implicit none + class(psb_i_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + info = 0 + if (psb_oacc_get_maybe_free_buffer())& + & call x%free_buffer(info) + + end subroutine i_oacc_vect_maybe_free_buffer + + subroutine i_oacc_vect_free_buffer(x,info) + implicit none + class(psb_i_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + info = 0 + if (acc_is_present(x%combuf)) call acc_delete_finalize(x%combuf) + call x%psb_i_base_vect_type%free_buffer(info) + end subroutine i_oacc_vect_free_buffer + function i_oacc_get_size(x) result(res) implicit none class(psb_i_vect_oacc), intent(inout) :: x diff --git a/openacc/psb_l_oacc_vect_mod.F90 b/openacc/psb_l_oacc_vect_mod.F90 index eb9b2b9a..5526796f 100644 --- a/openacc/psb_l_oacc_vect_mod.F90 +++ b/openacc/psb_l_oacc_vect_mod.F90 @@ -3,6 +3,8 @@ module psb_l_oacc_vect_mod use openacc use psb_const_mod use psb_error_mod + use psb_realloc_mod + use psb_oacc_env_mod use psb_l_vect_mod use psb_i_vect_mod use psb_i_oacc_vect_mod @@ -26,6 +28,8 @@ module psb_l_oacc_vect_mod procedure, pass(x) :: bld_x => l_oacc_bld_x procedure, pass(x) :: bld_mn => l_oacc_bld_mn procedure, pass(x) :: free => l_oacc_vect_free + procedure, pass(x) :: free_buffer => l_oacc_vect_free_buffer + procedure, pass(x) :: maybe_free_buffer => l_oacc_vect_maybe_free_buffer procedure, pass(x) :: ins_a => l_oacc_ins_a procedure, pass(x) :: ins_v => l_oacc_ins_v procedure, pass(x) :: is_host => l_oacc_is_host @@ -36,11 +40,13 @@ module psb_l_oacc_vect_mod procedure, pass(x) :: set_sync => l_oacc_set_sync procedure, pass(x) :: set_scal => l_oacc_set_scal + procedure, pass(x) :: new_buffer => l_oacc_new_buffer procedure, pass(x) :: gthzv_x => l_oacc_gthzv_x - procedure, pass(x) :: gthzbuf_x => l_oacc_gthzbuf + procedure, pass(x) :: gthzbuf => l_oacc_gthzbuf procedure, pass(y) :: sctb => l_oacc_sctb procedure, pass(y) :: sctb_x => l_oacc_sctb_x procedure, pass(y) :: sctb_buf => l_oacc_sctb_buf + procedure, nopass :: device_wait => l_oacc_device_wait procedure, pass(x) :: get_size => l_oacc_get_size @@ -50,6 +56,11 @@ module psb_l_oacc_vect_mod contains + subroutine l_oacc_device_wait() + implicit none + call acc_wait_all() + end subroutine l_oacc_device_wait + subroutine l_oacc_sctb_buf(i, n, idx, beta, y) use psb_base_mod @@ -58,29 +69,44 @@ contains class(psb_i_base_vect_type) :: idx integer(psb_lpk_) :: beta class(psb_l_vect_oacc) :: y - integer(psb_ipk_) :: info - + integer(psb_ipk_) :: info, k + logical :: acc_done if (.not.allocated(y%combuf)) then call psb_errpush(psb_err_alloc_dealloc_, 'sctb_buf') return end if + acc_done = .false. select type(ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() if (y%is_host()) call y%sync() + !$acc update device(y%combuf) + call inner_sctb(n,y%combuf(i:i+n-1),beta,y%v,ii%v(i:i+n-1)) + call y%set_dev() + acc_done = .true. + end select - !$acc parallel loop - do i = 1, n - y%v(ii%v(i)) = beta * y%v(ii%v(i)) + y%combuf(i) + if (.not.acc_done) then + if (idx%is_dev()) call idx%sync() + if (y%is_dev()) call y%sync() + do k = 1, n + y%v(idx%v(k+i-1)) = beta * y%v(idx%v(k+i-1)) + y%combuf(k) end do + end if - class default - !$acc parallel loop - do i = 1, n - y%v(idx%v(i)) = beta * y%v(idx%v(i)) + y%combuf(i) + contains + subroutine inner_sctb(n,x,beta,y,idx) + integer(psb_ipk_) :: n, idx(:) + integer(psb_lpk_) :: beta,x(:), y(:) + integer(psb_ipk_) :: k + !$acc parallel loop + do k = 1, n + y(idx(k)) = x(k) + beta *y(idx(k)) end do - end select + !$acc end parallel loop + end subroutine inner_sctb + end subroutine l_oacc_sctb_buf subroutine l_oacc_sctb_x(i, n, idx, x, beta, y) @@ -90,24 +116,41 @@ contains class(psb_i_base_vect_type) :: idx integer(psb_lpk_) :: beta, x(:) class(psb_l_vect_oacc) :: y - integer(psb_ipk_) :: info, ni + integer(psb_ipk_) :: info, ni, k + logical :: acc_done + acc_done = .false. select type(ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() - class default - call psb_errpush(info, 'l_oacc_sctb_x') - return + if (y%is_host()) call y%sync() + if (acc_is_present(x)) then + call inner_sctb(n,x(i:i+n-1),beta,y%v,idx%v(i:i+n-1)) + acc_done = .true. + call y%set_dev() + end if end select + if (.not.acc_done) then + if (idx%is_dev()) call idx%sync() + if (y%is_dev()) call y%sync() + do k = 1, n + y%v(idx%v(k+i-1)) = beta * y%v(idx%v(k+i-1)) + x(k+i-1) + end do + call y%set_host() + end if - if (y%is_host()) call y%sync() - - !$acc parallel loop - do i = 1, n - y%v(idx%v(i)) = beta * y%v(idx%v(i)) + x(i) - end do - - call y%set_dev() + contains + subroutine inner_sctb(n,x,beta,y,idx) + integer(psb_ipk_) :: n, idx(:) + integer(psb_lpk_) :: beta, x(:), y(:) + integer(psb_ipk_) :: k + !$acc parallel loop + do k = 1, n + y(idx(k)) = x(k) + beta *y(idx(k)) + end do + !$acc end parallel loop + end subroutine inner_sctb + end subroutine l_oacc_sctb_x subroutine l_oacc_sctb(n, idx, x, beta, y) @@ -123,7 +166,6 @@ contains if (n == 0) return if (y%is_dev()) call y%sync() - !$acc parallel loop do i = 1, n y%v(idx(i)) = beta * y%v(idx(i)) + x(i) end do @@ -137,30 +179,48 @@ contains integer(psb_ipk_) :: i, n class(psb_i_base_vect_type) :: idx class(psb_l_vect_oacc) :: x - integer(psb_ipk_) :: info + integer(psb_ipk_) :: info,k + logical :: acc_done info = 0 + acc_done = .false. + if (.not.allocated(x%combuf)) then call psb_errpush(psb_err_alloc_dealloc_, 'gthzbuf') return end if - select type(ii => idx) + select type (ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() - class default - call psb_errpush(info, 'l_oacc_gthzbuf') - return + if (x%is_host()) call x%sync() + call inner_gth(n,x%v,x%combuf(i:i+n-1),ii%v(i:i+n-1)) + acc_done = .true. end select - if (x%is_host()) call x%sync() + if (.not.acc_done) then + if (idx%is_dev()) call idx%sync() + if (x%is_dev()) call x%sync() + do k = 1, n + x%combuf(k+i-1) = x%v(idx%v(k+i-1)) + end do + end if - !$acc parallel loop - do i = 1, n - x%combuf(i) = x%v(idx%v(i)) - end do + contains + subroutine inner_gth(n,x,y,idx) + integer(psb_ipk_) :: n, idx(:) + integer(psb_lpk_) :: x(:), y(:) + integer(psb_ipk_) :: k + + !$acc parallel loop present(y) + do k = 1, n + y(k) = x(idx(k)) + end do + !$acc end parallel loop + !$acc update self(y) + end subroutine inner_gth end subroutine l_oacc_gthzbuf - + subroutine l_oacc_gthzv_x(i, n, idx, x, y) use psb_base_mod implicit none @@ -168,24 +228,41 @@ contains class(psb_i_base_vect_type):: idx integer(psb_lpk_) :: y(:) class(psb_l_vect_oacc):: x - integer(psb_ipk_) :: info + integer(psb_ipk_) :: info, k + logical :: acc_done info = 0 - - select type(ii => idx) + acc_done = .false. + select type (ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() - class default - call psb_errpush(info, 'l_oacc_gthzv_x') - return + if (x%is_host()) call x%sync() + if (acc_is_present(y)) then + call inner_gth(n,x%v,y(i:),ii%v(i:)) + acc_done=.true. + end if end select - - if (x%is_host()) call x%sync() - - !$acc parallel loop - do i = 1, n - y(i) = x%v(idx%v(i)) - end do + if (.not.acc_done) then + if (x%is_dev()) call x%sync() + if (idx%is_dev()) call idx%sync() + do k = 1, n + y(k+i-1) = x%v(idx%v(k+i-1)) + !write(0,*) 'oa gthzv ',k+i-1,idx%v(k+i-1),k,y(k) + end do + end if + contains + subroutine inner_gth(n,x,y,idx) + integer(psb_ipk_) :: n, idx(:) + integer(psb_lpk_) :: x(:), y(:) + integer(psb_ipk_) :: k + + !$acc parallel loop present(y) + do k = 1, n + y(k) = x(idx(k)) + end do + !$acc end parallel loop + !$acc update self(y) + end subroutine inner_gth end subroutine l_oacc_gthzv_x subroutine l_oacc_ins_v(n, irl, val, dupl, x, info) @@ -368,6 +445,17 @@ contains end function l_oacc_get_fmt + subroutine l_oacc_new_buffer(n,x,info) + implicit none + class(psb_l_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + integer(psb_ipk_), intent(out) :: info + if (n /= psb_size(x%combuf)) then + call x%psb_l_base_vect_type%new_buffer(n,info) + !$acc enter data copyin(x%combuf) + end if + end subroutine l_oacc_new_buffer + subroutine l_oacc_sync_dev_space(x) implicit none class(psb_l_vect_oacc), intent(inout) :: x @@ -466,12 +554,33 @@ contains class(psb_l_vect_oacc), intent(inout) :: x integer(psb_ipk_), intent(out) :: info info = 0 - if (allocated(x%v)) then - if (acc_is_present(x%v)) call acc_delete_finalize(x%v) - deallocate(x%v, stat=info) - end if + if (acc_is_present(x%v)) call acc_delete_finalize(x%v) + if (acc_is_present(x%combuf)) call acc_delete_finalize(x%combuf) + call x%psb_l_base_vect_type%free(info) end subroutine l_oacc_vect_free + + subroutine l_oacc_vect_maybe_free_buffer(x,info) + implicit none + class(psb_l_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + info = 0 + if (psb_oacc_get_maybe_free_buffer())& + & call x%free_buffer(info) + + end subroutine l_oacc_vect_maybe_free_buffer + + subroutine l_oacc_vect_free_buffer(x,info) + implicit none + class(psb_l_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + info = 0 + if (acc_is_present(x%combuf)) call acc_delete_finalize(x%combuf) + call x%psb_l_base_vect_type%free_buffer(info) + end subroutine l_oacc_vect_free_buffer + function l_oacc_get_size(x) result(res) implicit none class(psb_l_vect_oacc), intent(inout) :: x diff --git a/openacc/psb_oacc_env_mod.F90 b/openacc/psb_oacc_env_mod.F90 index 83c9426d..dc01ad3a 100644 --- a/openacc/psb_oacc_env_mod.F90 +++ b/openacc/psb_oacc_env_mod.F90 @@ -1,18 +1,29 @@ module psb_oacc_env_mod -contains + use psb_penv_mod + use psb_const_mod + use psb_error_mod + logical, private :: oacc_do_maybe_free_buffer = .false. - subroutine psb_oacc_init(ctxt, dev) - use psb_penv_mod - use psb_const_mod - use psb_error_mod - type(psb_ctxt_type), intent(in) :: ctxt - integer, intent(in), optional :: dev +contains + function psb_oacc_get_maybe_free_buffer() result(res) + logical :: res + res = oacc_do_maybe_free_buffer + end function psb_oacc_get_maybe_free_buffer - end subroutine psb_oacc_init + subroutine psb_oacc_set_maybe_free_buffer(val) + logical, intent(in) :: val + oacc_do_maybe_free_buffer = val + end subroutine psb_oacc_set_maybe_free_buffer - subroutine psb_oacc_exit() - integer :: res + subroutine psb_oacc_init(ctxt, dev) + type(psb_ctxt_type), intent(in) :: ctxt + integer, intent(in), optional :: dev + + end subroutine psb_oacc_init - end subroutine psb_oacc_exit + subroutine psb_oacc_exit() + integer :: res + + end subroutine psb_oacc_exit end module psb_oacc_env_mod diff --git a/openacc/psb_s_oacc_vect_mod.F90 b/openacc/psb_s_oacc_vect_mod.F90 index 16b45461..b80108ab 100644 --- a/openacc/psb_s_oacc_vect_mod.F90 +++ b/openacc/psb_s_oacc_vect_mod.F90 @@ -3,6 +3,8 @@ module psb_s_oacc_vect_mod use openacc use psb_const_mod use psb_error_mod + use psb_realloc_mod + use psb_oacc_env_mod use psb_s_vect_mod use psb_i_vect_mod use psb_i_oacc_vect_mod @@ -26,6 +28,8 @@ module psb_s_oacc_vect_mod procedure, pass(x) :: bld_x => s_oacc_bld_x procedure, pass(x) :: bld_mn => s_oacc_bld_mn procedure, pass(x) :: free => s_oacc_vect_free + procedure, pass(x) :: free_buffer => s_oacc_vect_free_buffer + procedure, pass(x) :: maybe_free_buffer => s_oacc_vect_maybe_free_buffer procedure, pass(x) :: ins_a => s_oacc_ins_a procedure, pass(x) :: ins_v => s_oacc_ins_v procedure, pass(x) :: is_host => s_oacc_is_host @@ -36,11 +40,13 @@ module psb_s_oacc_vect_mod procedure, pass(x) :: set_sync => s_oacc_set_sync procedure, pass(x) :: set_scal => s_oacc_set_scal + procedure, pass(x) :: new_buffer => s_oacc_new_buffer procedure, pass(x) :: gthzv_x => s_oacc_gthzv_x - procedure, pass(x) :: gthzbuf_x => s_oacc_gthzbuf + procedure, pass(x) :: gthzbuf => s_oacc_gthzbuf procedure, pass(y) :: sctb => s_oacc_sctb procedure, pass(y) :: sctb_x => s_oacc_sctb_x procedure, pass(y) :: sctb_buf => s_oacc_sctb_buf + procedure, nopass :: device_wait => s_oacc_device_wait procedure, pass(x) :: get_size => s_oacc_get_size @@ -87,6 +93,11 @@ module psb_s_oacc_vect_mod contains + subroutine s_oacc_device_wait() + implicit none + call acc_wait_all() + end subroutine s_oacc_device_wait + subroutine s_oacc_absval1(x) implicit none class(psb_s_vect_oacc), intent(inout) :: x @@ -181,13 +192,17 @@ contains !$acc parallel loop reduction(max:mx) do i = 1, n if (abs(x(i)) > mx) mx = abs(x(i)) - end do - sum = szero - !$acc parallel loop reduction(+:sum) - do i = 1, n - sum = sum + abs(x(i)/mx)**2 end do - res = mx*sqrt(sum) + if (mx == szero) then + res = mx + else + sum = szero + !$acc parallel loop reduction(+:sum) + do i = 1, n + sum = sum + abs(x(i)/mx)**2 + end do + res = mx*sqrt(sum) + end if end function s_inner_oacc_nrm2 end function s_oacc_nrm2 @@ -398,29 +413,44 @@ contains class(psb_i_base_vect_type) :: idx real(psb_spk_) :: beta class(psb_s_vect_oacc) :: y - integer(psb_ipk_) :: info - + integer(psb_ipk_) :: info, k + logical :: acc_done if (.not.allocated(y%combuf)) then call psb_errpush(psb_err_alloc_dealloc_, 'sctb_buf') return end if + acc_done = .false. select type(ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() if (y%is_host()) call y%sync() + !$acc update device(y%combuf) + call inner_sctb(n,y%combuf(i:i+n-1),beta,y%v,ii%v(i:i+n-1)) + call y%set_dev() + acc_done = .true. + end select - !$acc parallel loop - do i = 1, n - y%v(ii%v(i)) = beta * y%v(ii%v(i)) + y%combuf(i) + if (.not.acc_done) then + if (idx%is_dev()) call idx%sync() + if (y%is_dev()) call y%sync() + do k = 1, n + y%v(idx%v(k+i-1)) = beta * y%v(idx%v(k+i-1)) + y%combuf(k) end do + end if - class default - !$acc parallel loop - do i = 1, n - y%v(idx%v(i)) = beta * y%v(idx%v(i)) + y%combuf(i) + contains + subroutine inner_sctb(n,x,beta,y,idx) + integer(psb_ipk_) :: n, idx(:) + real(psb_spk_) :: beta,x(:), y(:) + integer(psb_ipk_) :: k + !$acc parallel loop + do k = 1, n + y(idx(k)) = x(k) + beta *y(idx(k)) end do - end select + !$acc end parallel loop + end subroutine inner_sctb + end subroutine s_oacc_sctb_buf subroutine s_oacc_sctb_x(i, n, idx, x, beta, y) @@ -430,24 +460,41 @@ contains class(psb_i_base_vect_type) :: idx real(psb_spk_) :: beta, x(:) class(psb_s_vect_oacc) :: y - integer(psb_ipk_) :: info, ni + integer(psb_ipk_) :: info, ni, k + logical :: acc_done + acc_done = .false. select type(ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() - class default - call psb_errpush(info, 's_oacc_sctb_x') - return + if (y%is_host()) call y%sync() + if (acc_is_present(x)) then + call inner_sctb(n,x(i:i+n-1),beta,y%v,idx%v(i:i+n-1)) + acc_done = .true. + call y%set_dev() + end if end select + if (.not.acc_done) then + if (idx%is_dev()) call idx%sync() + if (y%is_dev()) call y%sync() + do k = 1, n + y%v(idx%v(k+i-1)) = beta * y%v(idx%v(k+i-1)) + x(k+i-1) + end do + call y%set_host() + end if - if (y%is_host()) call y%sync() - - !$acc parallel loop - do i = 1, n - y%v(idx%v(i)) = beta * y%v(idx%v(i)) + x(i) - end do - - call y%set_dev() + contains + subroutine inner_sctb(n,x,beta,y,idx) + integer(psb_ipk_) :: n, idx(:) + real(psb_spk_) :: beta, x(:), y(:) + integer(psb_ipk_) :: k + !$acc parallel loop + do k = 1, n + y(idx(k)) = x(k) + beta *y(idx(k)) + end do + !$acc end parallel loop + end subroutine inner_sctb + end subroutine s_oacc_sctb_x subroutine s_oacc_sctb(n, idx, x, beta, y) @@ -463,7 +510,6 @@ contains if (n == 0) return if (y%is_dev()) call y%sync() - !$acc parallel loop do i = 1, n y%v(idx(i)) = beta * y%v(idx(i)) + x(i) end do @@ -477,30 +523,48 @@ contains integer(psb_ipk_) :: i, n class(psb_i_base_vect_type) :: idx class(psb_s_vect_oacc) :: x - integer(psb_ipk_) :: info + integer(psb_ipk_) :: info,k + logical :: acc_done info = 0 + acc_done = .false. + if (.not.allocated(x%combuf)) then call psb_errpush(psb_err_alloc_dealloc_, 'gthzbuf') return end if - select type(ii => idx) + select type (ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() - class default - call psb_errpush(info, 's_oacc_gthzbuf') - return + if (x%is_host()) call x%sync() + call inner_gth(n,x%v,x%combuf(i:i+n-1),ii%v(i:i+n-1)) + acc_done = .true. end select - if (x%is_host()) call x%sync() + if (.not.acc_done) then + if (idx%is_dev()) call idx%sync() + if (x%is_dev()) call x%sync() + do k = 1, n + x%combuf(k+i-1) = x%v(idx%v(k+i-1)) + end do + end if - !$acc parallel loop - do i = 1, n - x%combuf(i) = x%v(idx%v(i)) - end do + contains + subroutine inner_gth(n,x,y,idx) + integer(psb_ipk_) :: n, idx(:) + real(psb_spk_) :: x(:), y(:) + integer(psb_ipk_) :: k + + !$acc parallel loop present(y) + do k = 1, n + y(k) = x(idx(k)) + end do + !$acc end parallel loop + !$acc update self(y) + end subroutine inner_gth end subroutine s_oacc_gthzbuf - + subroutine s_oacc_gthzv_x(i, n, idx, x, y) use psb_base_mod implicit none @@ -508,24 +572,41 @@ contains class(psb_i_base_vect_type):: idx real(psb_spk_) :: y(:) class(psb_s_vect_oacc):: x - integer(psb_ipk_) :: info + integer(psb_ipk_) :: info, k + logical :: acc_done info = 0 - - select type(ii => idx) + acc_done = .false. + select type (ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() - class default - call psb_errpush(info, 's_oacc_gthzv_x') - return + if (x%is_host()) call x%sync() + if (acc_is_present(y)) then + call inner_gth(n,x%v,y(i:),ii%v(i:)) + acc_done=.true. + end if end select - - if (x%is_host()) call x%sync() - - !$acc parallel loop - do i = 1, n - y(i) = x%v(idx%v(i)) - end do + if (.not.acc_done) then + if (x%is_dev()) call x%sync() + if (idx%is_dev()) call idx%sync() + do k = 1, n + y(k+i-1) = x%v(idx%v(k+i-1)) + !write(0,*) 'oa gthzv ',k+i-1,idx%v(k+i-1),k,y(k) + end do + end if + contains + subroutine inner_gth(n,x,y,idx) + integer(psb_ipk_) :: n, idx(:) + real(psb_spk_) :: x(:), y(:) + integer(psb_ipk_) :: k + + !$acc parallel loop present(y) + do k = 1, n + y(k) = x(idx(k)) + end do + !$acc end parallel loop + !$acc update self(y) + end subroutine inner_gth end subroutine s_oacc_gthzv_x subroutine s_oacc_ins_v(n, irl, val, dupl, x, info) @@ -718,7 +799,7 @@ contains integer(psb_ipk_) :: info res = szero - !write(0,*) 'dot_v' +!!$ write(0,*) 'oacc_dot_v' select type(yy => y) type is (psb_s_base_vect_type) if (x%is_dev()) call x%sync() @@ -762,6 +843,17 @@ contains end function s_oacc_dot_a + subroutine s_oacc_new_buffer(n,x,info) + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + integer(psb_ipk_), intent(out) :: info + if (n /= psb_size(x%combuf)) then + call x%psb_s_base_vect_type%new_buffer(n,info) + !$acc enter data copyin(x%combuf) + end if + end subroutine s_oacc_new_buffer + subroutine s_oacc_sync_dev_space(x) implicit none class(psb_s_vect_oacc), intent(inout) :: x @@ -860,12 +952,33 @@ contains class(psb_s_vect_oacc), intent(inout) :: x integer(psb_ipk_), intent(out) :: info info = 0 - if (allocated(x%v)) then - if (acc_is_present(x%v)) call acc_delete_finalize(x%v) - deallocate(x%v, stat=info) - end if + if (acc_is_present(x%v)) call acc_delete_finalize(x%v) + if (acc_is_present(x%combuf)) call acc_delete_finalize(x%combuf) + call x%psb_s_base_vect_type%free(info) end subroutine s_oacc_vect_free + + subroutine s_oacc_vect_maybe_free_buffer(x,info) + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + info = 0 + if (psb_oacc_get_maybe_free_buffer())& + & call x%free_buffer(info) + end subroutine s_oacc_vect_maybe_free_buffer + + subroutine s_oacc_vect_free_buffer(x,info) + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + info = 0 + if (acc_is_present(x%combuf)) call acc_delete_finalize(x%combuf) + call x%psb_s_base_vect_type%free_buffer(info) + + end subroutine s_oacc_vect_free_buffer + function s_oacc_get_size(x) result(res) implicit none class(psb_s_vect_oacc), intent(inout) :: x diff --git a/openacc/psb_z_oacc_vect_mod.F90 b/openacc/psb_z_oacc_vect_mod.F90 index 9e6bbb2d..86107c31 100644 --- a/openacc/psb_z_oacc_vect_mod.F90 +++ b/openacc/psb_z_oacc_vect_mod.F90 @@ -3,6 +3,8 @@ module psb_z_oacc_vect_mod use openacc use psb_const_mod use psb_error_mod + use psb_realloc_mod + use psb_oacc_env_mod use psb_z_vect_mod use psb_i_vect_mod use psb_i_oacc_vect_mod @@ -26,6 +28,8 @@ module psb_z_oacc_vect_mod procedure, pass(x) :: bld_x => z_oacc_bld_x procedure, pass(x) :: bld_mn => z_oacc_bld_mn procedure, pass(x) :: free => z_oacc_vect_free + procedure, pass(x) :: free_buffer => z_oacc_vect_free_buffer + procedure, pass(x) :: maybe_free_buffer => z_oacc_vect_maybe_free_buffer procedure, pass(x) :: ins_a => z_oacc_ins_a procedure, pass(x) :: ins_v => z_oacc_ins_v procedure, pass(x) :: is_host => z_oacc_is_host @@ -36,11 +40,13 @@ module psb_z_oacc_vect_mod procedure, pass(x) :: set_sync => z_oacc_set_sync procedure, pass(x) :: set_scal => z_oacc_set_scal + procedure, pass(x) :: new_buffer => z_oacc_new_buffer procedure, pass(x) :: gthzv_x => z_oacc_gthzv_x - procedure, pass(x) :: gthzbuf_x => z_oacc_gthzbuf + procedure, pass(x) :: gthzbuf => z_oacc_gthzbuf procedure, pass(y) :: sctb => z_oacc_sctb procedure, pass(y) :: sctb_x => z_oacc_sctb_x procedure, pass(y) :: sctb_buf => z_oacc_sctb_buf + procedure, nopass :: device_wait => z_oacc_device_wait procedure, pass(x) :: get_size => z_oacc_get_size @@ -87,6 +93,11 @@ module psb_z_oacc_vect_mod contains + subroutine z_oacc_device_wait() + implicit none + call acc_wait_all() + end subroutine z_oacc_device_wait + subroutine z_oacc_absval1(x) implicit none class(psb_z_vect_oacc), intent(inout) :: x @@ -181,13 +192,17 @@ contains !$acc parallel loop reduction(max:mx) do i = 1, n if (abs(x(i)) > mx) mx = abs(x(i)) - end do - sum = dzero - !$acc parallel loop reduction(+:sum) - do i = 1, n - sum = sum + abs(x(i)/mx)**2 end do - res = mx*sqrt(sum) + if (mx == dzero) then + res = mx + else + sum = dzero + !$acc parallel loop reduction(+:sum) + do i = 1, n + sum = sum + abs(x(i)/mx)**2 + end do + res = mx*sqrt(sum) + end if end function z_inner_oacc_nrm2 end function z_oacc_nrm2 @@ -398,29 +413,44 @@ contains class(psb_i_base_vect_type) :: idx complex(psb_dpk_) :: beta class(psb_z_vect_oacc) :: y - integer(psb_ipk_) :: info - + integer(psb_ipk_) :: info, k + logical :: acc_done if (.not.allocated(y%combuf)) then call psb_errpush(psb_err_alloc_dealloc_, 'sctb_buf') return end if + acc_done = .false. select type(ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() if (y%is_host()) call y%sync() + !$acc update device(y%combuf) + call inner_sctb(n,y%combuf(i:i+n-1),beta,y%v,ii%v(i:i+n-1)) + call y%set_dev() + acc_done = .true. + end select - !$acc parallel loop - do i = 1, n - y%v(ii%v(i)) = beta * y%v(ii%v(i)) + y%combuf(i) + if (.not.acc_done) then + if (idx%is_dev()) call idx%sync() + if (y%is_dev()) call y%sync() + do k = 1, n + y%v(idx%v(k+i-1)) = beta * y%v(idx%v(k+i-1)) + y%combuf(k) end do + end if - class default - !$acc parallel loop - do i = 1, n - y%v(idx%v(i)) = beta * y%v(idx%v(i)) + y%combuf(i) + contains + subroutine inner_sctb(n,x,beta,y,idx) + integer(psb_ipk_) :: n, idx(:) + complex(psb_dpk_) :: beta,x(:), y(:) + integer(psb_ipk_) :: k + !$acc parallel loop + do k = 1, n + y(idx(k)) = x(k) + beta *y(idx(k)) end do - end select + !$acc end parallel loop + end subroutine inner_sctb + end subroutine z_oacc_sctb_buf subroutine z_oacc_sctb_x(i, n, idx, x, beta, y) @@ -430,24 +460,41 @@ contains class(psb_i_base_vect_type) :: idx complex(psb_dpk_) :: beta, x(:) class(psb_z_vect_oacc) :: y - integer(psb_ipk_) :: info, ni + integer(psb_ipk_) :: info, ni, k + logical :: acc_done + acc_done = .false. select type(ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() - class default - call psb_errpush(info, 'z_oacc_sctb_x') - return + if (y%is_host()) call y%sync() + if (acc_is_present(x)) then + call inner_sctb(n,x(i:i+n-1),beta,y%v,idx%v(i:i+n-1)) + acc_done = .true. + call y%set_dev() + end if end select + if (.not.acc_done) then + if (idx%is_dev()) call idx%sync() + if (y%is_dev()) call y%sync() + do k = 1, n + y%v(idx%v(k+i-1)) = beta * y%v(idx%v(k+i-1)) + x(k+i-1) + end do + call y%set_host() + end if - if (y%is_host()) call y%sync() - - !$acc parallel loop - do i = 1, n - y%v(idx%v(i)) = beta * y%v(idx%v(i)) + x(i) - end do - - call y%set_dev() + contains + subroutine inner_sctb(n,x,beta,y,idx) + integer(psb_ipk_) :: n, idx(:) + complex(psb_dpk_) :: beta, x(:), y(:) + integer(psb_ipk_) :: k + !$acc parallel loop + do k = 1, n + y(idx(k)) = x(k) + beta *y(idx(k)) + end do + !$acc end parallel loop + end subroutine inner_sctb + end subroutine z_oacc_sctb_x subroutine z_oacc_sctb(n, idx, x, beta, y) @@ -463,7 +510,6 @@ contains if (n == 0) return if (y%is_dev()) call y%sync() - !$acc parallel loop do i = 1, n y%v(idx(i)) = beta * y%v(idx(i)) + x(i) end do @@ -477,30 +523,48 @@ contains integer(psb_ipk_) :: i, n class(psb_i_base_vect_type) :: idx class(psb_z_vect_oacc) :: x - integer(psb_ipk_) :: info + integer(psb_ipk_) :: info,k + logical :: acc_done info = 0 + acc_done = .false. + if (.not.allocated(x%combuf)) then call psb_errpush(psb_err_alloc_dealloc_, 'gthzbuf') return end if - select type(ii => idx) + select type (ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() - class default - call psb_errpush(info, 'z_oacc_gthzbuf') - return + if (x%is_host()) call x%sync() + call inner_gth(n,x%v,x%combuf(i:i+n-1),ii%v(i:i+n-1)) + acc_done = .true. end select - if (x%is_host()) call x%sync() + if (.not.acc_done) then + if (idx%is_dev()) call idx%sync() + if (x%is_dev()) call x%sync() + do k = 1, n + x%combuf(k+i-1) = x%v(idx%v(k+i-1)) + end do + end if - !$acc parallel loop - do i = 1, n - x%combuf(i) = x%v(idx%v(i)) - end do + contains + subroutine inner_gth(n,x,y,idx) + integer(psb_ipk_) :: n, idx(:) + complex(psb_dpk_) :: x(:), y(:) + integer(psb_ipk_) :: k + + !$acc parallel loop present(y) + do k = 1, n + y(k) = x(idx(k)) + end do + !$acc end parallel loop + !$acc update self(y) + end subroutine inner_gth end subroutine z_oacc_gthzbuf - + subroutine z_oacc_gthzv_x(i, n, idx, x, y) use psb_base_mod implicit none @@ -508,24 +572,41 @@ contains class(psb_i_base_vect_type):: idx complex(psb_dpk_) :: y(:) class(psb_z_vect_oacc):: x - integer(psb_ipk_) :: info + integer(psb_ipk_) :: info, k + logical :: acc_done info = 0 - - select type(ii => idx) + acc_done = .false. + select type (ii => idx) class is (psb_i_vect_oacc) if (ii%is_host()) call ii%sync() - class default - call psb_errpush(info, 'z_oacc_gthzv_x') - return + if (x%is_host()) call x%sync() + if (acc_is_present(y)) then + call inner_gth(n,x%v,y(i:),ii%v(i:)) + acc_done=.true. + end if end select - - if (x%is_host()) call x%sync() - - !$acc parallel loop - do i = 1, n - y(i) = x%v(idx%v(i)) - end do + if (.not.acc_done) then + if (x%is_dev()) call x%sync() + if (idx%is_dev()) call idx%sync() + do k = 1, n + y(k+i-1) = x%v(idx%v(k+i-1)) + !write(0,*) 'oa gthzv ',k+i-1,idx%v(k+i-1),k,y(k) + end do + end if + contains + subroutine inner_gth(n,x,y,idx) + integer(psb_ipk_) :: n, idx(:) + complex(psb_dpk_) :: x(:), y(:) + integer(psb_ipk_) :: k + + !$acc parallel loop present(y) + do k = 1, n + y(k) = x(idx(k)) + end do + !$acc end parallel loop + !$acc update self(y) + end subroutine inner_gth end subroutine z_oacc_gthzv_x subroutine z_oacc_ins_v(n, irl, val, dupl, x, info) @@ -718,7 +799,7 @@ contains integer(psb_ipk_) :: info res = zzero - !write(0,*) 'dot_v' +!!$ write(0,*) 'oacc_dot_v' select type(yy => y) type is (psb_z_base_vect_type) if (x%is_dev()) call x%sync() @@ -762,6 +843,17 @@ contains end function z_oacc_dot_a + subroutine z_oacc_new_buffer(n,x,info) + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + integer(psb_ipk_), intent(out) :: info + if (n /= psb_size(x%combuf)) then + call x%psb_z_base_vect_type%new_buffer(n,info) + !$acc enter data copyin(x%combuf) + end if + end subroutine z_oacc_new_buffer + subroutine z_oacc_sync_dev_space(x) implicit none class(psb_z_vect_oacc), intent(inout) :: x @@ -860,12 +952,33 @@ contains class(psb_z_vect_oacc), intent(inout) :: x integer(psb_ipk_), intent(out) :: info info = 0 - if (allocated(x%v)) then - if (acc_is_present(x%v)) call acc_delete_finalize(x%v) - deallocate(x%v, stat=info) - end if + if (acc_is_present(x%v)) call acc_delete_finalize(x%v) + if (acc_is_present(x%combuf)) call acc_delete_finalize(x%combuf) + call x%psb_z_base_vect_type%free(info) end subroutine z_oacc_vect_free + + subroutine z_oacc_vect_maybe_free_buffer(x,info) + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + info = 0 + if (psb_oacc_get_maybe_free_buffer())& + & call x%free_buffer(info) + end subroutine z_oacc_vect_maybe_free_buffer + + subroutine z_oacc_vect_free_buffer(x,info) + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + + info = 0 + if (acc_is_present(x%combuf)) call acc_delete_finalize(x%combuf) + call x%psb_z_base_vect_type%free_buffer(info) + + end subroutine z_oacc_vect_free_buffer + function z_oacc_get_size(x) result(res) implicit none class(psb_z_vect_oacc), intent(inout) :: x