diff --git a/openacc/Makefile b/openacc/Makefile index d1311fe2..3a752ac7 100644 --- a/openacc/Makefile +++ b/openacc/Makefile @@ -17,9 +17,12 @@ CINCLUDES= #LIBS=-L$(LIBDIR) -lpsb_util -lpsb_ext -lpsb_base -lopenblas -lmetis -FOBJS= psb_i_oacc_vect_mod.o psb_d_oacc_vect_mod.o \ - psb_oacc_mod.o psb_d_oacc_csr_mat_mod.o \ - psb_oacc_env_mod.o +FOBJS= psb_i_oacc_vect_mod.o \ + psb_s_oacc_vect_mod.o psb_s_oacc_csr_mat_mod.o \ + psb_d_oacc_vect_mod.o psb_d_oacc_csr_mat_mod.o \ + psb_c_oacc_vect_mod.o psb_c_oacc_csr_mat_mod.o \ + psb_z_oacc_vect_mod.o psb_z_oacc_csr_mat_mod.o \ + psb_oacc_mod.o psb_oacc_env_mod.o LIBNAME=libpsb_openacc.a @@ -40,8 +43,23 @@ iobjs: $(OBJS) ilib: $(OBJS) $(MAKE) -C impl lib -psb_oacc_mod.o : psb_i_oacc_vect_mod.o psb_d_oacc_vect_mod.o \ - psb_d_oacc_csr_mat_mod.o psb_oacc_env_mod.o +psb_oacc_mod.o : psb_i_oacc_vect_mod.o \ + psb_s_oacc_vect_mod.o psb_s_oacc_csr_mat_mod.o \ + psb_d_oacc_vect_mod.o psb_d_oacc_csr_mat_mod.o \ + psb_c_oacc_vect_mod.o psb_c_oacc_csr_mat_mod.o \ + psb_z_oacc_vect_mod.o psb_z_oacc_csr_mat_mod.o \ + psb_oacc_env_mod.o + +psb_s_oacc_vect_mod.o psb_d_oacc_vect_mod.o \ + psb_c_oacc_vect_mod.o psb_z_oacc_vect_mod.o : psb_i_oacc_vect_mod.o + + +psb_s_oacc_csr_mat_mod.o: psb_s_oacc_vect_mod.o +psb_d_oacc_csr_mat_mod.o: psb_d_oacc_vect_mod.o +psb_c_oacc_csr_mat_mod.o: psb_c_oacc_vect_mod.o +psb_z_oacc_csr_mat_mod.o: psb_z_oacc_vect_mod.o + + clean: cclean iclean /bin/rm -f $(FOBJS) *$(.mod) *.a diff --git a/openacc/impl/psb_c_oacc_csr_allocate_mnnz.F90 b/openacc/impl/psb_c_oacc_csr_allocate_mnnz.F90 new file mode 100644 index 00000000..09cdc228 --- /dev/null +++ b/openacc/impl/psb_c_oacc_csr_allocate_mnnz.F90 @@ -0,0 +1,35 @@ +submodule (psb_c_oacc_csr_mat_mod) psb_c_oacc_csr_allocate_mnnz_impl + use psb_base_mod +contains + module subroutine psb_c_oacc_csr_allocate_mnnz(m, n, a, nz) + implicit none + integer(psb_ipk_), intent(in) :: m, n + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + integer(psb_ipk_), intent(in), optional :: nz + integer(psb_ipk_) :: info + integer(psb_ipk_) :: err_act, nz_ + character(len=20) :: name='allocate_mnz' + logical, parameter :: debug=.false. + + call psb_erractionsave(err_act) + info = psb_success_ + + call a%psb_c_csr_sparse_mat%allocate(m, n, nz) + + if (.not.allocated(a%val)) then + allocate(a%val(nz)) + allocate(a%ja(nz)) + allocate(a%irp(m+1)) + end if + + call a%set_dev() + if (info /= 0) goto 9999 + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + + end subroutine psb_c_oacc_csr_allocate_mnnz +end submodule psb_c_oacc_csr_allocate_mnnz_impl diff --git a/openacc/impl/psb_c_oacc_csr_cp_from_coo.F90 b/openacc/impl/psb_c_oacc_csr_cp_from_coo.F90 new file mode 100644 index 00000000..70380c95 --- /dev/null +++ b/openacc/impl/psb_c_oacc_csr_cp_from_coo.F90 @@ -0,0 +1,26 @@ +submodule (psb_c_oacc_csr_mat_mod) psb_c_oacc_csr_cp_from_coo_impl + use psb_base_mod +contains + module subroutine psb_c_oacc_csr_cp_from_coo(a, b, info) + implicit none + + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_c_coo_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + + call a%psb_c_csr_sparse_mat%cp_from_coo(b, info) + if (info /= 0) goto 9999 + + call a%set_dev() + if (info /= 0) goto 9999 + + return + +9999 continue + info = psb_err_alloc_dealloc_ + return + + end subroutine psb_c_oacc_csr_cp_from_coo +end submodule psb_c_oacc_csr_cp_from_coo_impl diff --git a/openacc/impl/psb_c_oacc_csr_cp_from_fmt.F90 b/openacc/impl/psb_c_oacc_csr_cp_from_fmt.F90 new file mode 100644 index 00000000..7e664791 --- /dev/null +++ b/openacc/impl/psb_c_oacc_csr_cp_from_fmt.F90 @@ -0,0 +1,24 @@ +submodule (psb_c_oacc_csr_mat_mod) psb_c_oacc_csr_cp_from_fmt_impl + use psb_base_mod +contains + module subroutine psb_c_oacc_csr_cp_from_fmt(a, b, info) + implicit none + + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_c_base_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + + select type(b) + type is (psb_c_coo_sparse_mat) + call a%cp_from_coo(b, info) + class default + call a%psb_c_csr_sparse_mat%cp_from_fmt(b, info) + if (info /= 0) return + + !$acc update device(a%val, a%ja, a%irp) + end select + + end subroutine psb_c_oacc_csr_cp_from_fmt +end submodule psb_c_oacc_csr_cp_from_fmt_impl diff --git a/openacc/impl/psb_c_oacc_csr_csmm.F90 b/openacc/impl/psb_c_oacc_csr_csmm.F90 new file mode 100644 index 00000000..c26df410 --- /dev/null +++ b/openacc/impl/psb_c_oacc_csr_csmm.F90 @@ -0,0 +1,86 @@ +submodule (psb_c_oacc_csr_mat_mod) psb_c_oacc_csr_csmm_impl + use psb_base_mod +contains + module subroutine psb_c_oacc_csr_csmm(alpha, a, x, beta, y, info, trans) + implicit none + class(psb_c_oacc_csr_sparse_mat), intent(in) :: a + complex(psb_spk_), intent(in) :: alpha, beta + complex(psb_spk_), intent(in) :: x(:,:) + complex(psb_spk_), intent(inout) :: y(:,:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + + character :: trans_ + integer(psb_ipk_) :: i, j, m, n,k, nxy + logical :: tra + integer(psb_ipk_) :: err_act + character(len=20) :: name = 'c_oacc_csmm' + logical, parameter :: debug = .false. + + info = psb_success_ + call psb_erractionsave(err_act) + + if (present(trans)) then + trans_ = trans + else + trans_ = 'N' + end if + + if (.not.a%is_asb()) then + info = psb_err_invalic_mat_state_ + call psb_errpush(info, name) + goto 9999 + endif + tra = (psb_toupper(trans_) == 'T') .or. (psb_toupper(trans_) == 'C') + + if (tra) then + m = a%get_ncols() + n = a%get_nrows() + else + n = a%get_ncols() + m = a%get_nrows() + end if + + if (size(x,1) < n) then + info = 36 + call psb_errpush(info, name, i_err = (/3 * ione, n, izero, izero, izero/)) + goto 9999 + end if + + if (size(y,1) < m) then + info = 36 + call psb_errpush(info, name, i_err = (/5 * ione, m, izero, izero, izero/)) + goto 9999 + end if + + if (tra) then + call a%psb_c_csr_sparse_mat%spmm(alpha, x, beta, y, info, trans) + else + nxy = min(size(x,2), size(y,2)) + + !$acc parallel loop collapse(2) present(a, x, y) + do j = 1, nxy + do i = 1, m + y(i,j) = beta * y(i,j) + end do + end do + + !$acc parallel loop collapse(2) present(a, x, y) + do j = 1, nxy + do i = 1, n + do k = a%irp(i), a%irp(i+1) - 1 + y(a%ja(k), j) = y(a%ja(k), j) + alpha * a%val(k) * x(i, j) + end do + end do + end do + endif + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + + end subroutine psb_c_oacc_csr_csmm +end submodule psb_c_oacc_csr_csmm_impl + diff --git a/openacc/impl/psb_c_oacc_csr_csmv.F90 b/openacc/impl/psb_c_oacc_csr_csmv.F90 new file mode 100644 index 00000000..8f37efb3 --- /dev/null +++ b/openacc/impl/psb_c_oacc_csr_csmv.F90 @@ -0,0 +1,81 @@ +submodule (psb_c_oacc_csr_mat_mod) psb_c_oacc_csr_csmv_impl + use psb_base_mod +contains + module subroutine psb_c_oacc_csr_csmv(alpha, a, x, beta, y, info, trans) + implicit none + class(psb_c_oacc_csr_sparse_mat), intent(in) :: a + complex(psb_spk_), intent(in) :: alpha, beta + complex(psb_spk_), intent(in) :: x(:) + complex(psb_spk_), intent(inout) :: y(:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + + character :: trans_ + integer(psb_ipk_) :: i, j, m, n + logical :: tra + integer(psb_ipk_) :: err_act + character(len=20) :: name = 'c_oacc_csmv' + logical, parameter :: debug = .false. + + call psb_erractionsave(err_act) + info = psb_success_ + + if (present(trans)) then + trans_ = trans + else + trans_ = 'N' + end if + + if (.not.a%is_asb()) then + info = psb_err_invalic_mat_state_ + call psb_errpush(info, name) + goto 9999 + endif + + tra = (psb_toupper(trans_) == 'T') .or. (psb_toupper(trans_) == 'C') + + if (tra) then + m = a%get_ncols() + n = a%get_nrows() + else + n = a%get_ncols() + m = a%get_nrows() + end if + + if (size(x,1) < n) then + info = 36 + call psb_errpush(info, name, i_err = (/3 * ione, n, izero, izero, izero/)) + goto 9999 + end if + + if (size(y,1) < m) then + info = 36 + call psb_errpush(info, name, i_err = (/5 * ione, m, izero, izero, izero/)) + goto 9999 + end if + + if (tra) then + call a%psb_c_csr_sparse_mat%spmm(alpha, x, beta, y, info, trans) + else + !$acc parallel loop present(a, x, y) + do i = 1, m + y(i) = beta * y(i) + end do + + !$acc parallel loop present(a, x, y) + do i = 1, n + do j = a%irp(i), a%irp(i+1) - 1 + y(a%ja(j)) = y(a%ja(j)) + alpha * a%val(j) * x(i) + end do + end do + endif + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + + end subroutine psb_c_oacc_csr_csmv +end submodule psb_c_oacc_csr_csmv_impl + diff --git a/openacc/impl/psb_c_oacc_csr_inner_vect_sv.F90 b/openacc/impl/psb_c_oacc_csr_inner_vect_sv.F90 new file mode 100644 index 00000000..2d733f48 --- /dev/null +++ b/openacc/impl/psb_c_oacc_csr_inner_vect_sv.F90 @@ -0,0 +1,83 @@ +submodule (psb_c_oacc_csr_mat_mod) psb_c_oacc_csr_inner_vect_sv_impl + use psb_base_mod +contains + module subroutine psb_c_oacc_csr_inner_vect_sv(alpha, a, x, beta, y, info, trans) + implicit none + class(psb_c_oacc_csr_sparse_mat), intent(in) :: a + complex(psb_spk_), intent(in) :: alpha, beta + class(psb_c_base_vect_type), intent(inout) :: x, y + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + + complex(psb_spk_), allocatable :: rx(:), ry(:) + logical :: tra + character :: trans_ + integer(psb_ipk_) :: err_act + character(len=20) :: name = 'c_oacc_csr_inner_vect_sv' + logical, parameter :: debug = .false. + integer(psb_ipk_) :: i + + call psb_get_erraction(err_act) + info = psb_success_ + + if (present(trans)) then + trans_ = trans + else + trans_ = 'N' + end if + + if (.not.a%is_asb()) then + info = psb_err_invalic_mat_state_ + call psb_errpush(info, name) + goto 9999 + endif + + tra = (psb_toupper(trans_) == 'T') .or. (psb_toupper(trans_) == 'C') + + if (tra .or. (beta /= dzero)) then + call x%sync() + call y%sync() + call a%psb_c_csr_sparse_mat%inner_spsm(alpha, x, beta, y, info, trans) + call y%set_host() + else + select type (xx => x) + type is (psb_c_vect_oacc) + select type(yy => y) + type is (psb_c_vect_oacc) + if (xx%is_host()) call xx%sync() + if (beta /= dzero) then + if (yy%is_host()) call yy%sync() + end if + !$acc parallel loop present(a, xx, yy) + do i = 1, size(a%val) + yy%v(i) = alpha * a%val(i) * xx%v(a%ja(i)) + beta * yy%v(i) + end do + call yy%set_dev() + class default + rx = xx%get_vect() + ry = y%get_vect() + call a%psb_c_csr_sparse_mat%inner_spsm(alpha, rx, beta, ry, info) + call y%bld(ry) + end select + class default + rx = x%get_vect() + ry = y%get_vect() + call a%psb_c_csr_sparse_mat%inner_spsm(alpha, rx, beta, ry, info) + call y%bld(ry) + end select + endif + + if (info /= psb_success_) then + info = psb_err_from_subroutine_ + call psb_errpush(info, name, a_err = 'csrg_vect_sv') + goto 9999 + endif + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + end subroutine psb_c_oacc_csr_inner_vect_sv +end submodule psb_c_oacc_csr_inner_vect_sv_impl + diff --git a/openacc/impl/psb_c_oacc_csr_mold.F90 b/openacc/impl/psb_c_oacc_csr_mold.F90 new file mode 100644 index 00000000..6ee36985 --- /dev/null +++ b/openacc/impl/psb_c_oacc_csr_mold.F90 @@ -0,0 +1,35 @@ +submodule (psb_c_oacc_csr_mat_mod) psb_c_oacc_csr_mold_impl + use psb_base_mod +contains + module subroutine psb_c_oacc_csr_mold(a, b, info) + implicit none + class(psb_c_oacc_csr_sparse_mat), intent(in) :: a + class(psb_c_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: err_act + character(len=20) :: name='csr_mold' + logical, parameter :: debug=.false. + + call psb_get_erraction(err_act) + + info = 0 + if (allocated(b)) then + call b%free() + deallocate(b, stat=info) + end if + if (info == 0) allocate(psb_c_oacc_csr_sparse_mat :: b, stat=info) + + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info, name) + goto 9999 + end if + return + +9999 call psb_error_handler(err_act) + + return + + end subroutine psb_c_oacc_csr_mold +end submodule psb_c_oacc_csr_molc_impl + diff --git a/openacc/impl/psb_c_oacc_csr_mv_from_coo.F90 b/openacc/impl/psb_c_oacc_csr_mv_from_coo.F90 new file mode 100644 index 00000000..f8c5c39d --- /dev/null +++ b/openacc/impl/psb_c_oacc_csr_mv_from_coo.F90 @@ -0,0 +1,25 @@ +submodule (psb_c_oacc_csr_mat_mod) psb_c_oacc_csr_mv_from_coo_impl + use psb_base_mod +contains + module subroutine psb_c_oacc_csr_mv_from_coo(a, b, info) + implicit none + + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_c_coo_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + + call a%psb_c_csr_sparse_mat%mv_from_coo(b, info) + if (info /= 0) goto 9999 + + !$acc update device(a%val, a%ja, a%irp) + + return + +9999 continue + info = psb_err_alloc_dealloc_ + return + + end subroutine psb_c_oacc_csr_mv_from_coo +end submodule psb_c_oacc_csr_mv_from_coo_impl diff --git a/openacc/impl/psb_c_oacc_csr_mv_from_fmt.F90 b/openacc/impl/psb_c_oacc_csr_mv_from_fmt.F90 new file mode 100644 index 00000000..7ba971b4 --- /dev/null +++ b/openacc/impl/psb_c_oacc_csr_mv_from_fmt.F90 @@ -0,0 +1,24 @@ +submodule (psb_c_oacc_csr_mat_mod) psb_c_oacc_csr_mv_from_fmt_impl + use psb_base_mod +contains + module subroutine psb_c_oacc_csr_mv_from_fmt(a, b, info) + implicit none + + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_c_base_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + + select type(b) + type is (psb_c_coo_sparse_mat) + call a%mv_from_coo(b, info) + class default + call a%psb_c_csr_sparse_mat%mv_from_fmt(b, info) + if (info /= 0) return + + !$acc update device(a%val, a%ja, a%irp) + end select + + end subroutine psb_c_oacc_csr_mv_from_fmt +end submodule psb_c_oacc_csr_mv_from_fmt_impl diff --git a/openacc/impl/psb_c_oacc_csr_reallocate_nz.F90 b/openacc/impl/psb_c_oacc_csr_reallocate_nz.F90 new file mode 100644 index 00000000..92a53116 --- /dev/null +++ b/openacc/impl/psb_c_oacc_csr_reallocate_nz.F90 @@ -0,0 +1,28 @@ +submodule (psb_c_oacc_csr_mat_mod) psb_c_oacc_csr_reallocate_nz_impl + use psb_base_mod +contains + module subroutine psb_c_oacc_csr_reallocate_nz(nz, a) + implicit none + integer(psb_ipk_), intent(in) :: nz + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + integer(psb_ipk_) :: info + integer(psb_ipk_) :: err_act + character(len=20) :: name='c_oacc_csr_reallocate_nz' + logical, parameter :: debug=.false. + + call psb_erractionsave(err_act) + info = psb_success_ + + call a%psb_c_csr_sparse_mat%reallocate(nz) + + call a%set_dev() + if (info /= 0) goto 9999 + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + + end subroutine psb_c_oacc_csr_reallocate_nz +end submodule psb_c_oacc_csr_reallocate_nz_impl diff --git a/openacc/impl/psb_c_oacc_csr_scal.F90 b/openacc/impl/psb_c_oacc_csr_scal.F90 new file mode 100644 index 00000000..5dece48b --- /dev/null +++ b/openacc/impl/psb_c_oacc_csr_scal.F90 @@ -0,0 +1,53 @@ +submodule (psb_c_oacc_csr_mat_mod) psb_c_oacc_csr_scal_impl + use psb_base_mod +contains + module subroutine psb_c_oacc_csr_scal(d, a, info, side) + implicit none + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + complex(psb_spk_), intent(in) :: d(:) + integer(psb_ipk_), intent(out) :: info + character, intent(in), optional :: side + + integer(psb_ipk_) :: err_act + character(len=20) :: name='scal' + logical, parameter :: debug=.false. + integer(psb_ipk_) :: i, j + + info = psb_success_ + call psb_erractionsave(err_act) + + if (a%is_host()) call a%sync() + + if (present(side)) then + if (side == 'L') then + !$acc parallel loop present(a, d) + do i = 1, a%get_nrows() + do j = a%irp(i), a%irp(i+1) - 1 + a%val(j) = a%val(j) * d(i) + end do + end do + else if (side == 'R') then + !$acc parallel loop present(a, d) + do i = 1, a%get_ncols() + do j = a%irp(i), a%irp(i+1) - 1 + a%val(j) = a%val(j) * d(a%ja(j)) + end do + end do + end if + else + !$acc parallel loop present(a, d) + do i = 1, size(a%val) + a%val(i) = a%val(i) * d(i) + end do + end if + + call a%set_dev() + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + + end subroutine psb_c_oacc_csr_scal +end submodule psb_c_oacc_csr_scal_impl diff --git a/openacc/impl/psb_c_oacc_csr_scals.F90 b/openacc/impl/psb_c_oacc_csr_scals.F90 new file mode 100644 index 00000000..aba22d43 --- /dev/null +++ b/openacc/impl/psb_c_oacc_csr_scals.F90 @@ -0,0 +1,34 @@ +submodule (psb_c_oacc_csr_mat_mod) psb_c_oacc_csr_scals_impl + use psb_base_mod +contains + module subroutine psb_c_oacc_csr_scals(d, a, info) + implicit none + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + complex(psb_spk_), intent(in) :: d + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: err_act + character(len=20) :: name='scal' + logical, parameter :: debug=.false. + integer(psb_ipk_) :: i + + info = psb_success_ + call psb_erractionsave(err_act) + + if (a%is_host()) call a%sync() + + !$acc parallel loop present(a) + do i = 1, size(a%val) + a%val(i) = a%val(i) * d + end do + + call a%set_dev() + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + + end subroutine psb_c_oacc_csr_scals +end submodule psb_c_oacc_csr_scals_impl diff --git a/openacc/impl/psb_c_oacc_csr_vect_mv.F90 b/openacc/impl/psb_c_oacc_csr_vect_mv.F90 new file mode 100644 index 00000000..b4b79d56 --- /dev/null +++ b/openacc/impl/psb_c_oacc_csr_vect_mv.F90 @@ -0,0 +1,63 @@ +submodule (psb_c_oacc_csr_mat_mod) psb_c_oacc_csr_vect_mv_impl + use psb_base_mod +contains + module subroutine psb_c_oacc_csr_vect_mv(alpha, a, x, beta, y, info, trans) + implicit none + + complex(psb_spk_), intent(in) :: alpha, beta + class(psb_c_oacc_csr_sparse_mat), intent(in) :: a + class(psb_c_base_vect_type), intent(inout) :: x, y + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + + integer(psb_ipk_) :: m, n + + info = psb_success_ + m = a%get_nrows() + n = a%get_ncols() + + if ((n /= size(x%v)) .or. (n /= size(y%v))) then + write(0,*) 'Size error ', m, n, size(x%v), size(y%v) + info = psb_err_invalic_mat_state_ + return + end if + + if (a%is_host()) call a%sync() + if (x%is_host()) call x%sync() + if (y%is_host()) call y%sync() + + call inner_spmv(m, n, alpha, a%val, a%ja, a%irp, x%v, beta, y%v, info) + call y%set_dev() + + contains + + subroutine inner_spmv(m, n, alpha, val, ja, irp, x, beta, y, info) + implicit none + integer(psb_ipk_) :: m, n + complex(psb_spk_), intent(in) :: alpha, beta + complex(psb_spk_) :: val(:), x(:), y(:) + integer(psb_ipk_) :: ja(:), irp(:) + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: i, j, ii, isz + complex(psb_spk_) :: tmp + integer(psb_ipk_), parameter :: vsz = 256 + + info = 0 + + !$acc parallel loop vector_length(vsz) private(isz) + do ii = 1, m, vsz + isz = min(vsz, m - ii + 1) + !$acc loop independent private(tmp) + do i = ii, ii + isz - 1 + tmp = 0.0_psb_dpk_ + !$acc loop seq + do j = irp(i), irp(i + 1) - 1 + tmp = tmp + val(j) * x(ja(j)) + end do + y(i) = alpha * tmp + beta * y(i) + end do + end do + end subroutine inner_spmv + + end subroutine psb_c_oacc_csr_vect_mv +end submodule psb_c_oacc_csr_vect_mv_impl diff --git a/openacc/impl/psb_c_oacc_mlt_v.f90 b/openacc/impl/psb_c_oacc_mlt_v.f90 new file mode 100644 index 00000000..66c4e865 --- /dev/null +++ b/openacc/impl/psb_c_oacc_mlt_v.f90 @@ -0,0 +1,31 @@ + +subroutine c_oacc_mlt_v(x, y, info) + use psb_c_oacc_vect_mod, psb_protect_name => c_oacc_mlt_v + + implicit none + class(psb_c_base_vect_type), intent(inout) :: x + class(psb_c_vect_oacc), intent(inout) :: y + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i, n + + info = 0 + n = min(x%get_nrows(), y%get_nrows()) + select type(xx => x) + class is (psb_c_vect_oacc) + if (y%is_host()) call y%sync() + if (xx%is_host()) call xx%sync() + !$acc parallel loop + do i = 1, n + y%v(i) = y%v(i) * xx%v(i) + end do + call y%set_dev() + class default + if (xx%is_dev()) call xx%sync() + if (y%is_dev()) call y%sync() + do i = 1, n + y%v(i) = y%v(i) * xx%v(i) + end do + call y%set_host() + end select +end subroutine c_oacc_mlt_v diff --git a/openacc/impl/psb_c_oacc_mlt_v_2.f90 b/openacc/impl/psb_c_oacc_mlt_v_2.f90 new file mode 100644 index 00000000..a6bb6cc5 --- /dev/null +++ b/openacc/impl/psb_c_oacc_mlt_v_2.f90 @@ -0,0 +1,98 @@ +subroutine c_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) + use psb_c_oacc_vect_mod, psb_protect_name => c_oacc_mlt_v_2 + use psb_string_mod + implicit none + complex(psb_spk_), intent(in) :: alpha, beta + class(psb_c_base_vect_type), intent(inout) :: x + class(psb_c_base_vect_type), intent(inout) :: y + class(psb_c_vect_oacc), intent(inout) :: z + integer(psb_ipk_), intent(out) :: info + character(len=1), intent(in), optional :: conjgx, conjgy + integer(psb_ipk_) :: i, n + logical :: conjgx_, conjgy_ + + conjgx_ = .false. + conjgy_ = .false. + if (present(conjgx)) conjgx_ = (psb_toupper(conjgx) == 'C') + if (present(conjgy)) conjgy_ = (psb_toupper(conjgy) == 'C') + + n = min(x%get_nrows(), y%get_nrows(), z%get_nrows()) + info = 0 + select type(xx => x) + class is (psb_c_vect_oacc) + select type (yy => y) + class is (psb_c_vect_oacc) + if (xx%is_host()) call xx%sync() + if (yy%is_host()) call yy%sync() + if ((beta /= czero) .and. (z%is_host())) call z%sync() + if (conjgx_.and.conjgy_) then + !$acc parallel loop + do i = 1, n + z%v(i) = alpha * conjg(xx%v(i)) * conjg(yy%v(i)) + beta * z%v(i) + end do + else if (conjgx_.and.(.not.conjgy_)) then + !$acc parallel loop + do i = 1, n + z%v(i) = alpha * conjg(xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + else if ((.not.conjgx_).and.(conjgy_)) then + !$acc parallel loop + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * conjg(yy%v(i)) + beta * z%v(i) + end do + else + !$acc parallel loop + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + + end if + call z%set_dev() + class default + if (xx%is_dev()) call xx%sync() + if (yy%is_dev()) call yy%sync() + if ((beta /= czero) .and. (z%is_dev())) call z%sync() + if (conjgx_.and.conjgy_) then + do i = 1, n + z%v(i) = alpha * conjg(xx%v(i)) * conjg(yy%v(i)) + beta * z%v(i) + end do + else if (conjgx_.and.(.not.conjgy_)) then + do i = 1, n + z%v(i) = alpha * conjg(xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + else if ((.not.conjgx_).and.(conjgy_)) then + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * conjg(yy%v(i)) + beta * z%v(i) + end do + else + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + end if + call z%set_host() + end select + class default + if (x%is_dev()) call x%sync() + if (y%is_dev()) call y%sync() + if ((beta /= czero) .and. (z%is_dev())) call z%sync() + if (conjgx_.and.conjgy_) then + do i = 1, n + z%v(i) = alpha * conjg(x%v(i)) * conjg(y%v(i)) + beta * z%v(i) + end do + else if (conjgx_.and.(.not.conjgy_)) then + do i = 1, n + z%v(i) = alpha * conjg(x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + else if ((.not.conjgx_).and.(conjgy_)) then + do i = 1, n + z%v(i) = alpha * (x%v(i)) * conjg(y%v(i)) + beta * z%v(i) + end do + else + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + end if + call z%set_host() + end select +end subroutine c_oacc_mlt_v_2 + diff --git a/openacc/impl/psb_d_oacc_mlt_v.f90 b/openacc/impl/psb_d_oacc_mlt_v.f90 index ab242b57..bedd0247 100644 --- a/openacc/impl/psb_d_oacc_mlt_v.f90 +++ b/openacc/impl/psb_d_oacc_mlt_v.f90 @@ -1,34 +1,31 @@ -submodule (psb_d_oacc_vect_mod) psb_d_oacc_mlt_v_impl - use psb_string_mod -contains - module subroutine psb_d_oacc_mlt_v(x, y, info) +subroutine d_oacc_mlt_v(x, y, info) + use psb_d_oacc_vect_mod, psb_protect_name => d_oacc_mlt_v - implicit none - class(psb_d_base_vect_type), intent(inout) :: x - class(psb_d_vect_oacc), intent(inout) :: y - integer(psb_ipk_), intent(out) :: info + implicit none + class(psb_d_base_vect_type), intent(inout) :: x + class(psb_d_vect_oacc), intent(inout) :: y + integer(psb_ipk_), intent(out) :: info - integer(psb_ipk_) :: i, n + integer(psb_ipk_) :: i, n - info = 0 - n = min(x%get_nrows(), y%get_nrows()) - select type(xx => x) - class is (psb_d_vect_oacc) - if (y%is_host()) call y%sync() - if (xx%is_host()) call xx%sync() - !$acc parallel loop - do i = 1, n - y%v(i) = y%v(i) * xx%v(i) - end do - call y%set_dev() - class default - if (xx%is_dev()) call xx%sync() - if (y%is_dev()) call y%sync() - do i = 1, n - y%v(i) = y%v(i) * xx%v(i) - end do - call y%set_host() - end select - end subroutine psb_d_oacc_mlt_v -end submodule psb_d_oacc_mlt_v_impl + info = 0 + n = min(x%get_nrows(), y%get_nrows()) + select type(xx => x) + class is (psb_d_vect_oacc) + if (y%is_host()) call y%sync() + if (xx%is_host()) call xx%sync() + !$acc parallel loop + do i = 1, n + y%v(i) = y%v(i) * xx%v(i) + end do + call y%set_dev() + class default + if (xx%is_dev()) call xx%sync() + if (y%is_dev()) call y%sync() + do i = 1, n + y%v(i) = y%v(i) * xx%v(i) + end do + call y%set_host() + end select +end subroutine d_oacc_mlt_v diff --git a/openacc/impl/psb_d_oacc_mlt_v_2.f90 b/openacc/impl/psb_d_oacc_mlt_v_2.f90 index 4ca2bdab..e7dd604f 100644 --- a/openacc/impl/psb_d_oacc_mlt_v_2.f90 +++ b/openacc/impl/psb_d_oacc_mlt_v_2.f90 @@ -1,55 +1,98 @@ -submodule (psb_d_oacc_vect_mod) d_oacc_mlt_v_2_impl +subroutine d_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) + use psb_d_oacc_vect_mod, psb_protect_name => d_oacc_mlt_v_2 use psb_string_mod -contains - module subroutine d_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) - implicit none - real(psb_dpk_), intent(in) :: alpha, beta - class(psb_d_base_vect_type), intent(inout) :: x - class(psb_d_base_vect_type), intent(inout) :: y - class(psb_d_vect_oacc), intent(inout) :: z - integer(psb_ipk_), intent(out) :: info - character(len=1), intent(in), optional :: conjgx, conjgy - integer(psb_ipk_) :: i, n - logical :: conjgx_, conjgy_ + implicit none + real(psb_dpk_), intent(in) :: alpha, beta + class(psb_d_base_vect_type), intent(inout) :: x + class(psb_d_base_vect_type), intent(inout) :: y + class(psb_d_vect_oacc), intent(inout) :: z + integer(psb_ipk_), intent(out) :: info + character(len=1), intent(in), optional :: conjgx, conjgy + integer(psb_ipk_) :: i, n + logical :: conjgx_, conjgy_ - conjgx_ = .false. - conjgy_ = .false. - if (present(conjgx)) conjgx_ = (psb_toupper(conjgx) == 'C') - if (present(conjgy)) conjgy_ = (psb_toupper(conjgy) == 'C') + conjgx_ = .false. + conjgy_ = .false. + if (present(conjgx)) conjgx_ = (psb_toupper(conjgx) == 'C') + if (present(conjgy)) conjgy_ = (psb_toupper(conjgy) == 'C') - n = min(x%get_nrows(), y%get_nrows(), z%get_nrows()) - - info = 0 - select type(xx => x) + n = min(x%get_nrows(), y%get_nrows(), z%get_nrows()) + info = 0 + select type(xx => x) + class is (psb_d_vect_oacc) + select type (yy => y) class is (psb_d_vect_oacc) - select type (yy => y) - class is (psb_d_vect_oacc) - if (xx%is_host()) call xx%sync() - if (yy%is_host()) call yy%sync() - if ((beta /= dzero) .and. (z%is_host())) call z%sync() + if (xx%is_host()) call xx%sync() + if (yy%is_host()) call yy%sync() + if ((beta /= dzero) .and. (z%is_host())) call z%sync() + if (conjgx_.and.conjgy_) then + !$acc parallel loop + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + else if (conjgx_.and.(.not.conjgy_)) then !$acc parallel loop do i = 1, n - z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) end do - call z%set_dev() - class default - if (xx%is_dev()) call xx%sync() - if (yy%is_dev()) call yy%sync() - if ((beta /= dzero) .and. (z%is_dev())) call z%sync() + else if ((.not.conjgx_).and.(conjgy_)) then + !$acc parallel loop do i = 1, n - z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) end do - call z%set_host() - end select + else + !$acc parallel loop + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + + end if + call z%set_dev() class default - if (x%is_dev()) call x%sync() - if (y%is_dev()) call y%sync() + if (xx%is_dev()) call xx%sync() + if (yy%is_dev()) call yy%sync() if ((beta /= dzero) .and. (z%is_dev())) call z%sync() - do i = 1, n - z%v(i) = alpha * x%v(i) * y%v(i) + beta * z%v(i) - end do + if (conjgx_.and.conjgy_) then + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + else if (conjgx_.and.(.not.conjgy_)) then + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + else if ((.not.conjgx_).and.(conjgy_)) then + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + else + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + end if call z%set_host() end select - end subroutine d_oacc_mlt_v_2 -end submodule d_oacc_mlt_v_2_impl + class default + if (x%is_dev()) call x%sync() + if (y%is_dev()) call y%sync() + if ((beta /= dzero) .and. (z%is_dev())) call z%sync() + if (conjgx_.and.conjgy_) then + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + else if (conjgx_.and.(.not.conjgy_)) then + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + else if ((.not.conjgx_).and.(conjgy_)) then + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + else + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + end if + call z%set_host() + end select +end subroutine d_oacc_mlt_v_2 diff --git a/openacc/impl/psb_s_oacc_csr_allocate_mnnz.F90 b/openacc/impl/psb_s_oacc_csr_allocate_mnnz.F90 new file mode 100644 index 00000000..08c51bce --- /dev/null +++ b/openacc/impl/psb_s_oacc_csr_allocate_mnnz.F90 @@ -0,0 +1,35 @@ +submodule (psb_s_oacc_csr_mat_mod) psb_s_oacc_csr_allocate_mnnz_impl + use psb_base_mod +contains + module subroutine psb_s_oacc_csr_allocate_mnnz(m, n, a, nz) + implicit none + integer(psb_ipk_), intent(in) :: m, n + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + integer(psb_ipk_), intent(in), optional :: nz + integer(psb_ipk_) :: info + integer(psb_ipk_) :: err_act, nz_ + character(len=20) :: name='allocate_mnz' + logical, parameter :: debug=.false. + + call psb_erractionsave(err_act) + info = psb_success_ + + call a%psb_s_csr_sparse_mat%allocate(m, n, nz) + + if (.not.allocated(a%val)) then + allocate(a%val(nz)) + allocate(a%ja(nz)) + allocate(a%irp(m+1)) + end if + + call a%set_dev() + if (info /= 0) goto 9999 + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + + end subroutine psb_s_oacc_csr_allocate_mnnz +end submodule psb_s_oacc_csr_allocate_mnnz_impl diff --git a/openacc/impl/psb_s_oacc_csr_cp_from_coo.F90 b/openacc/impl/psb_s_oacc_csr_cp_from_coo.F90 new file mode 100644 index 00000000..94ef67b3 --- /dev/null +++ b/openacc/impl/psb_s_oacc_csr_cp_from_coo.F90 @@ -0,0 +1,26 @@ +submodule (psb_s_oacc_csr_mat_mod) psb_s_oacc_csr_cp_from_coo_impl + use psb_base_mod +contains + module subroutine psb_s_oacc_csr_cp_from_coo(a, b, info) + implicit none + + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_s_coo_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + + call a%psb_s_csr_sparse_mat%cp_from_coo(b, info) + if (info /= 0) goto 9999 + + call a%set_dev() + if (info /= 0) goto 9999 + + return + +9999 continue + info = psb_err_alloc_dealloc_ + return + + end subroutine psb_s_oacc_csr_cp_from_coo +end submodule psb_s_oacc_csr_cp_from_coo_impl diff --git a/openacc/impl/psb_s_oacc_csr_cp_from_fmt.F90 b/openacc/impl/psb_s_oacc_csr_cp_from_fmt.F90 new file mode 100644 index 00000000..2c64b5fe --- /dev/null +++ b/openacc/impl/psb_s_oacc_csr_cp_from_fmt.F90 @@ -0,0 +1,24 @@ +submodule (psb_s_oacc_csr_mat_mod) psb_s_oacc_csr_cp_from_fmt_impl + use psb_base_mod +contains + module subroutine psb_s_oacc_csr_cp_from_fmt(a, b, info) + implicit none + + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_s_base_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + + select type(b) + type is (psb_s_coo_sparse_mat) + call a%cp_from_coo(b, info) + class default + call a%psb_s_csr_sparse_mat%cp_from_fmt(b, info) + if (info /= 0) return + + !$acc update device(a%val, a%ja, a%irp) + end select + + end subroutine psb_s_oacc_csr_cp_from_fmt +end submodule psb_s_oacc_csr_cp_from_fmt_impl diff --git a/openacc/impl/psb_s_oacc_csr_csmm.F90 b/openacc/impl/psb_s_oacc_csr_csmm.F90 new file mode 100644 index 00000000..2e7def53 --- /dev/null +++ b/openacc/impl/psb_s_oacc_csr_csmm.F90 @@ -0,0 +1,86 @@ +submodule (psb_s_oacc_csr_mat_mod) psb_s_oacc_csr_csmm_impl + use psb_base_mod +contains + module subroutine psb_s_oacc_csr_csmm(alpha, a, x, beta, y, info, trans) + implicit none + class(psb_s_oacc_csr_sparse_mat), intent(in) :: a + real(psb_spk_), intent(in) :: alpha, beta + real(psb_spk_), intent(in) :: x(:,:) + real(psb_spk_), intent(inout) :: y(:,:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + + character :: trans_ + integer(psb_ipk_) :: i, j, m, n,k, nxy + logical :: tra + integer(psb_ipk_) :: err_act + character(len=20) :: name = 's_oacc_csmm' + logical, parameter :: debug = .false. + + info = psb_success_ + call psb_erractionsave(err_act) + + if (present(trans)) then + trans_ = trans + else + trans_ = 'N' + end if + + if (.not.a%is_asb()) then + info = psb_err_invalis_mat_state_ + call psb_errpush(info, name) + goto 9999 + endif + tra = (psb_toupper(trans_) == 'T') .or. (psb_toupper(trans_) == 'C') + + if (tra) then + m = a%get_ncols() + n = a%get_nrows() + else + n = a%get_ncols() + m = a%get_nrows() + end if + + if (size(x,1) < n) then + info = 36 + call psb_errpush(info, name, i_err = (/3 * ione, n, izero, izero, izero/)) + goto 9999 + end if + + if (size(y,1) < m) then + info = 36 + call psb_errpush(info, name, i_err = (/5 * ione, m, izero, izero, izero/)) + goto 9999 + end if + + if (tra) then + call a%psb_s_csr_sparse_mat%spmm(alpha, x, beta, y, info, trans) + else + nxy = min(size(x,2), size(y,2)) + + !$acc parallel loop collapse(2) present(a, x, y) + do j = 1, nxy + do i = 1, m + y(i,j) = beta * y(i,j) + end do + end do + + !$acc parallel loop collapse(2) present(a, x, y) + do j = 1, nxy + do i = 1, n + do k = a%irp(i), a%irp(i+1) - 1 + y(a%ja(k), j) = y(a%ja(k), j) + alpha * a%val(k) * x(i, j) + end do + end do + end do + endif + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + + end subroutine psb_s_oacc_csr_csmm +end submodule psb_s_oacc_csr_csmm_impl + diff --git a/openacc/impl/psb_s_oacc_csr_csmv.F90 b/openacc/impl/psb_s_oacc_csr_csmv.F90 new file mode 100644 index 00000000..ba673941 --- /dev/null +++ b/openacc/impl/psb_s_oacc_csr_csmv.F90 @@ -0,0 +1,81 @@ +submodule (psb_s_oacc_csr_mat_mod) psb_s_oacc_csr_csmv_impl + use psb_base_mod +contains + module subroutine psb_s_oacc_csr_csmv(alpha, a, x, beta, y, info, trans) + implicit none + class(psb_s_oacc_csr_sparse_mat), intent(in) :: a + real(psb_spk_), intent(in) :: alpha, beta + real(psb_spk_), intent(in) :: x(:) + real(psb_spk_), intent(inout) :: y(:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + + character :: trans_ + integer(psb_ipk_) :: i, j, m, n + logical :: tra + integer(psb_ipk_) :: err_act + character(len=20) :: name = 's_oacc_csmv' + logical, parameter :: debug = .false. + + call psb_erractionsave(err_act) + info = psb_success_ + + if (present(trans)) then + trans_ = trans + else + trans_ = 'N' + end if + + if (.not.a%is_asb()) then + info = psb_err_invalis_mat_state_ + call psb_errpush(info, name) + goto 9999 + endif + + tra = (psb_toupper(trans_) == 'T') .or. (psb_toupper(trans_) == 'C') + + if (tra) then + m = a%get_ncols() + n = a%get_nrows() + else + n = a%get_ncols() + m = a%get_nrows() + end if + + if (size(x,1) < n) then + info = 36 + call psb_errpush(info, name, i_err = (/3 * ione, n, izero, izero, izero/)) + goto 9999 + end if + + if (size(y,1) < m) then + info = 36 + call psb_errpush(info, name, i_err = (/5 * ione, m, izero, izero, izero/)) + goto 9999 + end if + + if (tra) then + call a%psb_s_csr_sparse_mat%spmm(alpha, x, beta, y, info, trans) + else + !$acc parallel loop present(a, x, y) + do i = 1, m + y(i) = beta * y(i) + end do + + !$acc parallel loop present(a, x, y) + do i = 1, n + do j = a%irp(i), a%irp(i+1) - 1 + y(a%ja(j)) = y(a%ja(j)) + alpha * a%val(j) * x(i) + end do + end do + endif + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + + end subroutine psb_s_oacc_csr_csmv +end submodule psb_s_oacc_csr_csmv_impl + diff --git a/openacc/impl/psb_s_oacc_csr_inner_vect_sv.F90 b/openacc/impl/psb_s_oacc_csr_inner_vect_sv.F90 new file mode 100644 index 00000000..7af897a7 --- /dev/null +++ b/openacc/impl/psb_s_oacc_csr_inner_vect_sv.F90 @@ -0,0 +1,83 @@ +submodule (psb_s_oacc_csr_mat_mod) psb_s_oacc_csr_inner_vect_sv_impl + use psb_base_mod +contains + module subroutine psb_s_oacc_csr_inner_vect_sv(alpha, a, x, beta, y, info, trans) + implicit none + class(psb_s_oacc_csr_sparse_mat), intent(in) :: a + real(psb_spk_), intent(in) :: alpha, beta + class(psb_s_base_vect_type), intent(inout) :: x, y + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + + real(psb_spk_), allocatable :: rx(:), ry(:) + logical :: tra + character :: trans_ + integer(psb_ipk_) :: err_act + character(len=20) :: name = 's_oacc_csr_inner_vect_sv' + logical, parameter :: debug = .false. + integer(psb_ipk_) :: i + + call psb_get_erraction(err_act) + info = psb_success_ + + if (present(trans)) then + trans_ = trans + else + trans_ = 'N' + end if + + if (.not.a%is_asb()) then + info = psb_err_invalis_mat_state_ + call psb_errpush(info, name) + goto 9999 + endif + + tra = (psb_toupper(trans_) == 'T') .or. (psb_toupper(trans_) == 'C') + + if (tra .or. (beta /= dzero)) then + call x%sync() + call y%sync() + call a%psb_s_csr_sparse_mat%inner_spsm(alpha, x, beta, y, info, trans) + call y%set_host() + else + select type (xx => x) + type is (psb_s_vect_oacc) + select type(yy => y) + type is (psb_s_vect_oacc) + if (xx%is_host()) call xx%sync() + if (beta /= dzero) then + if (yy%is_host()) call yy%sync() + end if + !$acc parallel loop present(a, xx, yy) + do i = 1, size(a%val) + yy%v(i) = alpha * a%val(i) * xx%v(a%ja(i)) + beta * yy%v(i) + end do + call yy%set_dev() + class default + rx = xx%get_vect() + ry = y%get_vect() + call a%psb_s_csr_sparse_mat%inner_spsm(alpha, rx, beta, ry, info) + call y%bld(ry) + end select + class default + rx = x%get_vect() + ry = y%get_vect() + call a%psb_s_csr_sparse_mat%inner_spsm(alpha, rx, beta, ry, info) + call y%bld(ry) + end select + endif + + if (info /= psb_success_) then + info = psb_err_from_subroutine_ + call psb_errpush(info, name, a_err = 'csrg_vect_sv') + goto 9999 + endif + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + end subroutine psb_s_oacc_csr_inner_vect_sv +end submodule psb_s_oacc_csr_inner_vect_sv_impl + diff --git a/openacc/impl/psb_s_oacc_csr_mold.F90 b/openacc/impl/psb_s_oacc_csr_mold.F90 new file mode 100644 index 00000000..a85471e5 --- /dev/null +++ b/openacc/impl/psb_s_oacc_csr_mold.F90 @@ -0,0 +1,35 @@ +submodule (psb_s_oacc_csr_mat_mod) psb_s_oacc_csr_mold_impl + use psb_base_mod +contains + module subroutine psb_s_oacc_csr_mold(a, b, info) + implicit none + class(psb_s_oacc_csr_sparse_mat), intent(in) :: a + class(psb_s_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: err_act + character(len=20) :: name='csr_mold' + logical, parameter :: debug=.false. + + call psb_get_erraction(err_act) + + info = 0 + if (allocated(b)) then + call b%free() + deallocate(b, stat=info) + end if + if (info == 0) allocate(psb_s_oacc_csr_sparse_mat :: b, stat=info) + + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info, name) + goto 9999 + end if + return + +9999 call psb_error_handler(err_act) + + return + + end subroutine psb_s_oacc_csr_mold +end submodule psb_s_oacc_csr_mols_impl + diff --git a/openacc/impl/psb_s_oacc_csr_mv_from_coo.F90 b/openacc/impl/psb_s_oacc_csr_mv_from_coo.F90 new file mode 100644 index 00000000..e531d309 --- /dev/null +++ b/openacc/impl/psb_s_oacc_csr_mv_from_coo.F90 @@ -0,0 +1,25 @@ +submodule (psb_s_oacc_csr_mat_mod) psb_s_oacc_csr_mv_from_coo_impl + use psb_base_mod +contains + module subroutine psb_s_oacc_csr_mv_from_coo(a, b, info) + implicit none + + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_s_coo_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + + call a%psb_s_csr_sparse_mat%mv_from_coo(b, info) + if (info /= 0) goto 9999 + + !$acc update device(a%val, a%ja, a%irp) + + return + +9999 continue + info = psb_err_alloc_dealloc_ + return + + end subroutine psb_s_oacc_csr_mv_from_coo +end submodule psb_s_oacc_csr_mv_from_coo_impl diff --git a/openacc/impl/psb_s_oacc_csr_mv_from_fmt.F90 b/openacc/impl/psb_s_oacc_csr_mv_from_fmt.F90 new file mode 100644 index 00000000..a9dc0c70 --- /dev/null +++ b/openacc/impl/psb_s_oacc_csr_mv_from_fmt.F90 @@ -0,0 +1,24 @@ +submodule (psb_s_oacc_csr_mat_mod) psb_s_oacc_csr_mv_from_fmt_impl + use psb_base_mod +contains + module subroutine psb_s_oacc_csr_mv_from_fmt(a, b, info) + implicit none + + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_s_base_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + + select type(b) + type is (psb_s_coo_sparse_mat) + call a%mv_from_coo(b, info) + class default + call a%psb_s_csr_sparse_mat%mv_from_fmt(b, info) + if (info /= 0) return + + !$acc update device(a%val, a%ja, a%irp) + end select + + end subroutine psb_s_oacc_csr_mv_from_fmt +end submodule psb_s_oacc_csr_mv_from_fmt_impl diff --git a/openacc/impl/psb_s_oacc_csr_reallocate_nz.F90 b/openacc/impl/psb_s_oacc_csr_reallocate_nz.F90 new file mode 100644 index 00000000..77c17120 --- /dev/null +++ b/openacc/impl/psb_s_oacc_csr_reallocate_nz.F90 @@ -0,0 +1,28 @@ +submodule (psb_s_oacc_csr_mat_mod) psb_s_oacc_csr_reallocate_nz_impl + use psb_base_mod +contains + module subroutine psb_s_oacc_csr_reallocate_nz(nz, a) + implicit none + integer(psb_ipk_), intent(in) :: nz + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + integer(psb_ipk_) :: info + integer(psb_ipk_) :: err_act + character(len=20) :: name='s_oacc_csr_reallocate_nz' + logical, parameter :: debug=.false. + + call psb_erractionsave(err_act) + info = psb_success_ + + call a%psb_s_csr_sparse_mat%reallocate(nz) + + call a%set_dev() + if (info /= 0) goto 9999 + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + + end subroutine psb_s_oacc_csr_reallocate_nz +end submodule psb_s_oacc_csr_reallocate_nz_impl diff --git a/openacc/impl/psb_s_oacc_csr_scal.F90 b/openacc/impl/psb_s_oacc_csr_scal.F90 new file mode 100644 index 00000000..b9c8a986 --- /dev/null +++ b/openacc/impl/psb_s_oacc_csr_scal.F90 @@ -0,0 +1,53 @@ +submodule (psb_s_oacc_csr_mat_mod) psb_s_oacc_csr_scal_impl + use psb_base_mod +contains + module subroutine psb_s_oacc_csr_scal(d, a, info, side) + implicit none + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + real(psb_spk_), intent(in) :: d(:) + integer(psb_ipk_), intent(out) :: info + character, intent(in), optional :: side + + integer(psb_ipk_) :: err_act + character(len=20) :: name='scal' + logical, parameter :: debug=.false. + integer(psb_ipk_) :: i, j + + info = psb_success_ + call psb_erractionsave(err_act) + + if (a%is_host()) call a%sync() + + if (present(side)) then + if (side == 'L') then + !$acc parallel loop present(a, d) + do i = 1, a%get_nrows() + do j = a%irp(i), a%irp(i+1) - 1 + a%val(j) = a%val(j) * d(i) + end do + end do + else if (side == 'R') then + !$acc parallel loop present(a, d) + do i = 1, a%get_ncols() + do j = a%irp(i), a%irp(i+1) - 1 + a%val(j) = a%val(j) * d(a%ja(j)) + end do + end do + end if + else + !$acc parallel loop present(a, d) + do i = 1, size(a%val) + a%val(i) = a%val(i) * d(i) + end do + end if + + call a%set_dev() + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + + end subroutine psb_s_oacc_csr_scal +end submodule psb_s_oacc_csr_scal_impl diff --git a/openacc/impl/psb_s_oacc_csr_scals.F90 b/openacc/impl/psb_s_oacc_csr_scals.F90 new file mode 100644 index 00000000..76ad7cf2 --- /dev/null +++ b/openacc/impl/psb_s_oacc_csr_scals.F90 @@ -0,0 +1,34 @@ +submodule (psb_s_oacc_csr_mat_mod) psb_s_oacc_csr_scals_impl + use psb_base_mod +contains + module subroutine psb_s_oacc_csr_scals(d, a, info) + implicit none + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + real(psb_spk_), intent(in) :: d + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: err_act + character(len=20) :: name='scal' + logical, parameter :: debug=.false. + integer(psb_ipk_) :: i + + info = psb_success_ + call psb_erractionsave(err_act) + + if (a%is_host()) call a%sync() + + !$acc parallel loop present(a) + do i = 1, size(a%val) + a%val(i) = a%val(i) * d + end do + + call a%set_dev() + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + + end subroutine psb_s_oacc_csr_scals +end submodule psb_s_oacc_csr_scals_impl diff --git a/openacc/impl/psb_s_oacc_csr_vect_mv.F90 b/openacc/impl/psb_s_oacc_csr_vect_mv.F90 new file mode 100644 index 00000000..9b15da3b --- /dev/null +++ b/openacc/impl/psb_s_oacc_csr_vect_mv.F90 @@ -0,0 +1,63 @@ +submodule (psb_s_oacc_csr_mat_mod) psb_s_oacc_csr_vect_mv_impl + use psb_base_mod +contains + module subroutine psb_s_oacc_csr_vect_mv(alpha, a, x, beta, y, info, trans) + implicit none + + real(psb_spk_), intent(in) :: alpha, beta + class(psb_s_oacc_csr_sparse_mat), intent(in) :: a + class(psb_s_base_vect_type), intent(inout) :: x, y + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + + integer(psb_ipk_) :: m, n + + info = psb_success_ + m = a%get_nrows() + n = a%get_ncols() + + if ((n /= size(x%v)) .or. (n /= size(y%v))) then + write(0,*) 'Size error ', m, n, size(x%v), size(y%v) + info = psb_err_invalis_mat_state_ + return + end if + + if (a%is_host()) call a%sync() + if (x%is_host()) call x%sync() + if (y%is_host()) call y%sync() + + call inner_spmv(m, n, alpha, a%val, a%ja, a%irp, x%v, beta, y%v, info) + call y%set_dev() + + contains + + subroutine inner_spmv(m, n, alpha, val, ja, irp, x, beta, y, info) + implicit none + integer(psb_ipk_) :: m, n + real(psb_spk_), intent(in) :: alpha, beta + real(psb_spk_) :: val(:), x(:), y(:) + integer(psb_ipk_) :: ja(:), irp(:) + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: i, j, ii, isz + real(psb_spk_) :: tmp + integer(psb_ipk_), parameter :: vsz = 256 + + info = 0 + + !$acc parallel loop vector_length(vsz) private(isz) + do ii = 1, m, vsz + isz = min(vsz, m - ii + 1) + !$acc loop independent private(tmp) + do i = ii, ii + isz - 1 + tmp = 0.0_psb_dpk_ + !$acc loop seq + do j = irp(i), irp(i + 1) - 1 + tmp = tmp + val(j) * x(ja(j)) + end do + y(i) = alpha * tmp + beta * y(i) + end do + end do + end subroutine inner_spmv + + end subroutine psb_s_oacc_csr_vect_mv +end submodule psb_s_oacc_csr_vect_mv_impl diff --git a/openacc/impl/psb_s_oacc_mlt_v.f90 b/openacc/impl/psb_s_oacc_mlt_v.f90 new file mode 100644 index 00000000..fb043cf2 --- /dev/null +++ b/openacc/impl/psb_s_oacc_mlt_v.f90 @@ -0,0 +1,31 @@ + +subroutine s_oacc_mlt_v(x, y, info) + use psb_s_oacc_vect_mod, psb_protect_name => s_oacc_mlt_v + + implicit none + class(psb_s_base_vect_type), intent(inout) :: x + class(psb_s_vect_oacc), intent(inout) :: y + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i, n + + info = 0 + n = min(x%get_nrows(), y%get_nrows()) + select type(xx => x) + class is (psb_s_vect_oacc) + if (y%is_host()) call y%sync() + if (xx%is_host()) call xx%sync() + !$acc parallel loop + do i = 1, n + y%v(i) = y%v(i) * xx%v(i) + end do + call y%set_dev() + class default + if (xx%is_dev()) call xx%sync() + if (y%is_dev()) call y%sync() + do i = 1, n + y%v(i) = y%v(i) * xx%v(i) + end do + call y%set_host() + end select +end subroutine s_oacc_mlt_v diff --git a/openacc/impl/psb_s_oacc_mlt_v_2.f90 b/openacc/impl/psb_s_oacc_mlt_v_2.f90 new file mode 100644 index 00000000..04ee8e09 --- /dev/null +++ b/openacc/impl/psb_s_oacc_mlt_v_2.f90 @@ -0,0 +1,98 @@ +subroutine s_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) + use psb_s_oacc_vect_mod, psb_protect_name => s_oacc_mlt_v_2 + use psb_string_mod + implicit none + real(psb_spk_), intent(in) :: alpha, beta + class(psb_s_base_vect_type), intent(inout) :: x + class(psb_s_base_vect_type), intent(inout) :: y + class(psb_s_vect_oacc), intent(inout) :: z + integer(psb_ipk_), intent(out) :: info + character(len=1), intent(in), optional :: conjgx, conjgy + integer(psb_ipk_) :: i, n + logical :: conjgx_, conjgy_ + + conjgx_ = .false. + conjgy_ = .false. + if (present(conjgx)) conjgx_ = (psb_toupper(conjgx) == 'C') + if (present(conjgy)) conjgy_ = (psb_toupper(conjgy) == 'C') + + n = min(x%get_nrows(), y%get_nrows(), z%get_nrows()) + info = 0 + select type(xx => x) + class is (psb_s_vect_oacc) + select type (yy => y) + class is (psb_s_vect_oacc) + if (xx%is_host()) call xx%sync() + if (yy%is_host()) call yy%sync() + if ((beta /= szero) .and. (z%is_host())) call z%sync() + if (conjgx_.and.conjgy_) then + !$acc parallel loop + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + else if (conjgx_.and.(.not.conjgy_)) then + !$acc parallel loop + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + else if ((.not.conjgx_).and.(conjgy_)) then + !$acc parallel loop + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + else + !$acc parallel loop + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + + end if + call z%set_dev() + class default + if (xx%is_dev()) call xx%sync() + if (yy%is_dev()) call yy%sync() + if ((beta /= szero) .and. (z%is_dev())) call z%sync() + if (conjgx_.and.conjgy_) then + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + else if (conjgx_.and.(.not.conjgy_)) then + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + else if ((.not.conjgx_).and.(conjgy_)) then + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + else + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + end if + call z%set_host() + end select + class default + if (x%is_dev()) call x%sync() + if (y%is_dev()) call y%sync() + if ((beta /= szero) .and. (z%is_dev())) call z%sync() + if (conjgx_.and.conjgy_) then + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + else if (conjgx_.and.(.not.conjgy_)) then + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + else if ((.not.conjgx_).and.(conjgy_)) then + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + else + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + end if + call z%set_host() + end select +end subroutine s_oacc_mlt_v_2 + diff --git a/openacc/impl/psb_z_oacc_csr_allocate_mnnz.F90 b/openacc/impl/psb_z_oacc_csr_allocate_mnnz.F90 new file mode 100644 index 00000000..fd19d6f9 --- /dev/null +++ b/openacc/impl/psb_z_oacc_csr_allocate_mnnz.F90 @@ -0,0 +1,35 @@ +submodule (psb_z_oacc_csr_mat_mod) psb_z_oacc_csr_allocate_mnnz_impl + use psb_base_mod +contains + module subroutine psb_z_oacc_csr_allocate_mnnz(m, n, a, nz) + implicit none + integer(psb_ipk_), intent(in) :: m, n + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + integer(psb_ipk_), intent(in), optional :: nz + integer(psb_ipk_) :: info + integer(psb_ipk_) :: err_act, nz_ + character(len=20) :: name='allocate_mnz' + logical, parameter :: debug=.false. + + call psb_erractionsave(err_act) + info = psb_success_ + + call a%psb_z_csr_sparse_mat%allocate(m, n, nz) + + if (.not.allocated(a%val)) then + allocate(a%val(nz)) + allocate(a%ja(nz)) + allocate(a%irp(m+1)) + end if + + call a%set_dev() + if (info /= 0) goto 9999 + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + + end subroutine psb_z_oacc_csr_allocate_mnnz +end submodule psb_z_oacc_csr_allocate_mnnz_impl diff --git a/openacc/impl/psb_z_oacc_csr_cp_from_coo.F90 b/openacc/impl/psb_z_oacc_csr_cp_from_coo.F90 new file mode 100644 index 00000000..0485c9ca --- /dev/null +++ b/openacc/impl/psb_z_oacc_csr_cp_from_coo.F90 @@ -0,0 +1,26 @@ +submodule (psb_z_oacc_csr_mat_mod) psb_z_oacc_csr_cp_from_coo_impl + use psb_base_mod +contains + module subroutine psb_z_oacc_csr_cp_from_coo(a, b, info) + implicit none + + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_z_coo_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + + call a%psb_z_csr_sparse_mat%cp_from_coo(b, info) + if (info /= 0) goto 9999 + + call a%set_dev() + if (info /= 0) goto 9999 + + return + +9999 continue + info = psb_err_alloc_dealloc_ + return + + end subroutine psb_z_oacc_csr_cp_from_coo +end submodule psb_z_oacc_csr_cp_from_coo_impl diff --git a/openacc/impl/psb_z_oacc_csr_cp_from_fmt.F90 b/openacc/impl/psb_z_oacc_csr_cp_from_fmt.F90 new file mode 100644 index 00000000..f2c68816 --- /dev/null +++ b/openacc/impl/psb_z_oacc_csr_cp_from_fmt.F90 @@ -0,0 +1,24 @@ +submodule (psb_z_oacc_csr_mat_mod) psb_z_oacc_csr_cp_from_fmt_impl + use psb_base_mod +contains + module subroutine psb_z_oacc_csr_cp_from_fmt(a, b, info) + implicit none + + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_z_base_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + + select type(b) + type is (psb_z_coo_sparse_mat) + call a%cp_from_coo(b, info) + class default + call a%psb_z_csr_sparse_mat%cp_from_fmt(b, info) + if (info /= 0) return + + !$acc update device(a%val, a%ja, a%irp) + end select + + end subroutine psb_z_oacc_csr_cp_from_fmt +end submodule psb_z_oacc_csr_cp_from_fmt_impl diff --git a/openacc/impl/psb_z_oacc_csr_csmm.F90 b/openacc/impl/psb_z_oacc_csr_csmm.F90 new file mode 100644 index 00000000..aeaaab33 --- /dev/null +++ b/openacc/impl/psb_z_oacc_csr_csmm.F90 @@ -0,0 +1,86 @@ +submodule (psb_z_oacc_csr_mat_mod) psb_z_oacc_csr_csmm_impl + use psb_base_mod +contains + module subroutine psb_z_oacc_csr_csmm(alpha, a, x, beta, y, info, trans) + implicit none + class(psb_z_oacc_csr_sparse_mat), intent(in) :: a + complex(psb_dpk_), intent(in) :: alpha, beta + complex(psb_dpk_), intent(in) :: x(:,:) + complex(psb_dpk_), intent(inout) :: y(:,:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + + character :: trans_ + integer(psb_ipk_) :: i, j, m, n,k, nxy + logical :: tra + integer(psb_ipk_) :: err_act + character(len=20) :: name = 'z_oacc_csmm' + logical, parameter :: debug = .false. + + info = psb_success_ + call psb_erractionsave(err_act) + + if (present(trans)) then + trans_ = trans + else + trans_ = 'N' + end if + + if (.not.a%is_asb()) then + info = psb_err_invaliz_mat_state_ + call psb_errpush(info, name) + goto 9999 + endif + tra = (psb_toupper(trans_) == 'T') .or. (psb_toupper(trans_) == 'C') + + if (tra) then + m = a%get_ncols() + n = a%get_nrows() + else + n = a%get_ncols() + m = a%get_nrows() + end if + + if (size(x,1) < n) then + info = 36 + call psb_errpush(info, name, i_err = (/3 * ione, n, izero, izero, izero/)) + goto 9999 + end if + + if (size(y,1) < m) then + info = 36 + call psb_errpush(info, name, i_err = (/5 * ione, m, izero, izero, izero/)) + goto 9999 + end if + + if (tra) then + call a%psb_z_csr_sparse_mat%spmm(alpha, x, beta, y, info, trans) + else + nxy = min(size(x,2), size(y,2)) + + !$acc parallel loop collapse(2) present(a, x, y) + do j = 1, nxy + do i = 1, m + y(i,j) = beta * y(i,j) + end do + end do + + !$acc parallel loop collapse(2) present(a, x, y) + do j = 1, nxy + do i = 1, n + do k = a%irp(i), a%irp(i+1) - 1 + y(a%ja(k), j) = y(a%ja(k), j) + alpha * a%val(k) * x(i, j) + end do + end do + end do + endif + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + + end subroutine psb_z_oacc_csr_csmm +end submodule psb_z_oacc_csr_csmm_impl + diff --git a/openacc/impl/psb_z_oacc_csr_csmv.F90 b/openacc/impl/psb_z_oacc_csr_csmv.F90 new file mode 100644 index 00000000..f5501b21 --- /dev/null +++ b/openacc/impl/psb_z_oacc_csr_csmv.F90 @@ -0,0 +1,81 @@ +submodule (psb_z_oacc_csr_mat_mod) psb_z_oacc_csr_csmv_impl + use psb_base_mod +contains + module subroutine psb_z_oacc_csr_csmv(alpha, a, x, beta, y, info, trans) + implicit none + class(psb_z_oacc_csr_sparse_mat), intent(in) :: a + complex(psb_dpk_), intent(in) :: alpha, beta + complex(psb_dpk_), intent(in) :: x(:) + complex(psb_dpk_), intent(inout) :: y(:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + + character :: trans_ + integer(psb_ipk_) :: i, j, m, n + logical :: tra + integer(psb_ipk_) :: err_act + character(len=20) :: name = 'z_oacc_csmv' + logical, parameter :: debug = .false. + + call psb_erractionsave(err_act) + info = psb_success_ + + if (present(trans)) then + trans_ = trans + else + trans_ = 'N' + end if + + if (.not.a%is_asb()) then + info = psb_err_invaliz_mat_state_ + call psb_errpush(info, name) + goto 9999 + endif + + tra = (psb_toupper(trans_) == 'T') .or. (psb_toupper(trans_) == 'C') + + if (tra) then + m = a%get_ncols() + n = a%get_nrows() + else + n = a%get_ncols() + m = a%get_nrows() + end if + + if (size(x,1) < n) then + info = 36 + call psb_errpush(info, name, i_err = (/3 * ione, n, izero, izero, izero/)) + goto 9999 + end if + + if (size(y,1) < m) then + info = 36 + call psb_errpush(info, name, i_err = (/5 * ione, m, izero, izero, izero/)) + goto 9999 + end if + + if (tra) then + call a%psb_z_csr_sparse_mat%spmm(alpha, x, beta, y, info, trans) + else + !$acc parallel loop present(a, x, y) + do i = 1, m + y(i) = beta * y(i) + end do + + !$acc parallel loop present(a, x, y) + do i = 1, n + do j = a%irp(i), a%irp(i+1) - 1 + y(a%ja(j)) = y(a%ja(j)) + alpha * a%val(j) * x(i) + end do + end do + endif + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + + end subroutine psb_z_oacc_csr_csmv +end submodule psb_z_oacc_csr_csmv_impl + diff --git a/openacc/impl/psb_z_oacc_csr_inner_vect_sv.F90 b/openacc/impl/psb_z_oacc_csr_inner_vect_sv.F90 new file mode 100644 index 00000000..b5d552d3 --- /dev/null +++ b/openacc/impl/psb_z_oacc_csr_inner_vect_sv.F90 @@ -0,0 +1,83 @@ +submodule (psb_z_oacc_csr_mat_mod) psb_z_oacc_csr_inner_vect_sv_impl + use psb_base_mod +contains + module subroutine psb_z_oacc_csr_inner_vect_sv(alpha, a, x, beta, y, info, trans) + implicit none + class(psb_z_oacc_csr_sparse_mat), intent(in) :: a + complex(psb_dpk_), intent(in) :: alpha, beta + class(psb_z_base_vect_type), intent(inout) :: x, y + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + + complex(psb_dpk_), allocatable :: rx(:), ry(:) + logical :: tra + character :: trans_ + integer(psb_ipk_) :: err_act + character(len=20) :: name = 'z_oacc_csr_inner_vect_sv' + logical, parameter :: debug = .false. + integer(psb_ipk_) :: i + + call psb_get_erraction(err_act) + info = psb_success_ + + if (present(trans)) then + trans_ = trans + else + trans_ = 'N' + end if + + if (.not.a%is_asb()) then + info = psb_err_invaliz_mat_state_ + call psb_errpush(info, name) + goto 9999 + endif + + tra = (psb_toupper(trans_) == 'T') .or. (psb_toupper(trans_) == 'C') + + if (tra .or. (beta /= dzero)) then + call x%sync() + call y%sync() + call a%psb_z_csr_sparse_mat%inner_spsm(alpha, x, beta, y, info, trans) + call y%set_host() + else + select type (xx => x) + type is (psb_z_vect_oacc) + select type(yy => y) + type is (psb_z_vect_oacc) + if (xx%is_host()) call xx%sync() + if (beta /= dzero) then + if (yy%is_host()) call yy%sync() + end if + !$acc parallel loop present(a, xx, yy) + do i = 1, size(a%val) + yy%v(i) = alpha * a%val(i) * xx%v(a%ja(i)) + beta * yy%v(i) + end do + call yy%set_dev() + class default + rx = xx%get_vect() + ry = y%get_vect() + call a%psb_z_csr_sparse_mat%inner_spsm(alpha, rx, beta, ry, info) + call y%bld(ry) + end select + class default + rx = x%get_vect() + ry = y%get_vect() + call a%psb_z_csr_sparse_mat%inner_spsm(alpha, rx, beta, ry, info) + call y%bld(ry) + end select + endif + + if (info /= psb_success_) then + info = psb_err_from_subroutine_ + call psb_errpush(info, name, a_err = 'csrg_vect_sv') + goto 9999 + endif + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + end subroutine psb_z_oacc_csr_inner_vect_sv +end submodule psb_z_oacc_csr_inner_vect_sv_impl + diff --git a/openacc/impl/psb_z_oacc_csr_mold.F90 b/openacc/impl/psb_z_oacc_csr_mold.F90 new file mode 100644 index 00000000..e7e9e8b9 --- /dev/null +++ b/openacc/impl/psb_z_oacc_csr_mold.F90 @@ -0,0 +1,35 @@ +submodule (psb_z_oacc_csr_mat_mod) psb_z_oacc_csr_mold_impl + use psb_base_mod +contains + module subroutine psb_z_oacc_csr_mold(a, b, info) + implicit none + class(psb_z_oacc_csr_sparse_mat), intent(in) :: a + class(psb_z_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: err_act + character(len=20) :: name='csr_mold' + logical, parameter :: debug=.false. + + call psb_get_erraction(err_act) + + info = 0 + if (allocated(b)) then + call b%free() + deallocate(b, stat=info) + end if + if (info == 0) allocate(psb_z_oacc_csr_sparse_mat :: b, stat=info) + + if (info /= psb_success_) then + info = psb_err_alloc_dealloc_ + call psb_errpush(info, name) + goto 9999 + end if + return + +9999 call psb_error_handler(err_act) + + return + + end subroutine psb_z_oacc_csr_mold +end submodule psb_z_oacc_csr_molz_impl + diff --git a/openacc/impl/psb_z_oacc_csr_mv_from_coo.F90 b/openacc/impl/psb_z_oacc_csr_mv_from_coo.F90 new file mode 100644 index 00000000..44b01b68 --- /dev/null +++ b/openacc/impl/psb_z_oacc_csr_mv_from_coo.F90 @@ -0,0 +1,25 @@ +submodule (psb_z_oacc_csr_mat_mod) psb_z_oacc_csr_mv_from_coo_impl + use psb_base_mod +contains + module subroutine psb_z_oacc_csr_mv_from_coo(a, b, info) + implicit none + + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_z_coo_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + + call a%psb_z_csr_sparse_mat%mv_from_coo(b, info) + if (info /= 0) goto 9999 + + !$acc update device(a%val, a%ja, a%irp) + + return + +9999 continue + info = psb_err_alloc_dealloc_ + return + + end subroutine psb_z_oacc_csr_mv_from_coo +end submodule psb_z_oacc_csr_mv_from_coo_impl diff --git a/openacc/impl/psb_z_oacc_csr_mv_from_fmt.F90 b/openacc/impl/psb_z_oacc_csr_mv_from_fmt.F90 new file mode 100644 index 00000000..bf777e85 --- /dev/null +++ b/openacc/impl/psb_z_oacc_csr_mv_from_fmt.F90 @@ -0,0 +1,24 @@ +submodule (psb_z_oacc_csr_mat_mod) psb_z_oacc_csr_mv_from_fmt_impl + use psb_base_mod +contains + module subroutine psb_z_oacc_csr_mv_from_fmt(a, b, info) + implicit none + + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_z_base_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + + info = psb_success_ + + select type(b) + type is (psb_z_coo_sparse_mat) + call a%mv_from_coo(b, info) + class default + call a%psb_z_csr_sparse_mat%mv_from_fmt(b, info) + if (info /= 0) return + + !$acc update device(a%val, a%ja, a%irp) + end select + + end subroutine psb_z_oacc_csr_mv_from_fmt +end submodule psb_z_oacc_csr_mv_from_fmt_impl diff --git a/openacc/impl/psb_z_oacc_csr_reallocate_nz.F90 b/openacc/impl/psb_z_oacc_csr_reallocate_nz.F90 new file mode 100644 index 00000000..bdfb88d6 --- /dev/null +++ b/openacc/impl/psb_z_oacc_csr_reallocate_nz.F90 @@ -0,0 +1,28 @@ +submodule (psb_z_oacc_csr_mat_mod) psb_z_oacc_csr_reallocate_nz_impl + use psb_base_mod +contains + module subroutine psb_z_oacc_csr_reallocate_nz(nz, a) + implicit none + integer(psb_ipk_), intent(in) :: nz + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + integer(psb_ipk_) :: info + integer(psb_ipk_) :: err_act + character(len=20) :: name='z_oacc_csr_reallocate_nz' + logical, parameter :: debug=.false. + + call psb_erractionsave(err_act) + info = psb_success_ + + call a%psb_z_csr_sparse_mat%reallocate(nz) + + call a%set_dev() + if (info /= 0) goto 9999 + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + + end subroutine psb_z_oacc_csr_reallocate_nz +end submodule psb_z_oacc_csr_reallocate_nz_impl diff --git a/openacc/impl/psb_z_oacc_csr_scal.F90 b/openacc/impl/psb_z_oacc_csr_scal.F90 new file mode 100644 index 00000000..f09ff595 --- /dev/null +++ b/openacc/impl/psb_z_oacc_csr_scal.F90 @@ -0,0 +1,53 @@ +submodule (psb_z_oacc_csr_mat_mod) psb_z_oacc_csr_scal_impl + use psb_base_mod +contains + module subroutine psb_z_oacc_csr_scal(d, a, info, side) + implicit none + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + complex(psb_dpk_), intent(in) :: d(:) + integer(psb_ipk_), intent(out) :: info + character, intent(in), optional :: side + + integer(psb_ipk_) :: err_act + character(len=20) :: name='scal' + logical, parameter :: debug=.false. + integer(psb_ipk_) :: i, j + + info = psb_success_ + call psb_erractionsave(err_act) + + if (a%is_host()) call a%sync() + + if (present(side)) then + if (side == 'L') then + !$acc parallel loop present(a, d) + do i = 1, a%get_nrows() + do j = a%irp(i), a%irp(i+1) - 1 + a%val(j) = a%val(j) * d(i) + end do + end do + else if (side == 'R') then + !$acc parallel loop present(a, d) + do i = 1, a%get_ncols() + do j = a%irp(i), a%irp(i+1) - 1 + a%val(j) = a%val(j) * d(a%ja(j)) + end do + end do + end if + else + !$acc parallel loop present(a, d) + do i = 1, size(a%val) + a%val(i) = a%val(i) * d(i) + end do + end if + + call a%set_dev() + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + + end subroutine psb_z_oacc_csr_scal +end submodule psb_z_oacc_csr_scal_impl diff --git a/openacc/impl/psb_z_oacc_csr_scals.F90 b/openacc/impl/psb_z_oacc_csr_scals.F90 new file mode 100644 index 00000000..1fe64951 --- /dev/null +++ b/openacc/impl/psb_z_oacc_csr_scals.F90 @@ -0,0 +1,34 @@ +submodule (psb_z_oacc_csr_mat_mod) psb_z_oacc_csr_scals_impl + use psb_base_mod +contains + module subroutine psb_z_oacc_csr_scals(d, a, info) + implicit none + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + complex(psb_dpk_), intent(in) :: d + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: err_act + character(len=20) :: name='scal' + logical, parameter :: debug=.false. + integer(psb_ipk_) :: i + + info = psb_success_ + call psb_erractionsave(err_act) + + if (a%is_host()) call a%sync() + + !$acc parallel loop present(a) + do i = 1, size(a%val) + a%val(i) = a%val(i) * d + end do + + call a%set_dev() + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(err_act) + return + + end subroutine psb_z_oacc_csr_scals +end submodule psb_z_oacc_csr_scals_impl diff --git a/openacc/impl/psb_z_oacc_csr_vect_mv.F90 b/openacc/impl/psb_z_oacc_csr_vect_mv.F90 new file mode 100644 index 00000000..437dd70a --- /dev/null +++ b/openacc/impl/psb_z_oacc_csr_vect_mv.F90 @@ -0,0 +1,63 @@ +submodule (psb_z_oacc_csr_mat_mod) psb_z_oacc_csr_vect_mv_impl + use psb_base_mod +contains + module subroutine psb_z_oacc_csr_vect_mv(alpha, a, x, beta, y, info, trans) + implicit none + + complex(psb_dpk_), intent(in) :: alpha, beta + class(psb_z_oacc_csr_sparse_mat), intent(in) :: a + class(psb_z_base_vect_type), intent(inout) :: x, y + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + + integer(psb_ipk_) :: m, n + + info = psb_success_ + m = a%get_nrows() + n = a%get_ncols() + + if ((n /= size(x%v)) .or. (n /= size(y%v))) then + write(0,*) 'Size error ', m, n, size(x%v), size(y%v) + info = psb_err_invaliz_mat_state_ + return + end if + + if (a%is_host()) call a%sync() + if (x%is_host()) call x%sync() + if (y%is_host()) call y%sync() + + call inner_spmv(m, n, alpha, a%val, a%ja, a%irp, x%v, beta, y%v, info) + call y%set_dev() + + contains + + subroutine inner_spmv(m, n, alpha, val, ja, irp, x, beta, y, info) + implicit none + integer(psb_ipk_) :: m, n + complex(psb_dpk_), intent(in) :: alpha, beta + complex(psb_dpk_) :: val(:), x(:), y(:) + integer(psb_ipk_) :: ja(:), irp(:) + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: i, j, ii, isz + complex(psb_dpk_) :: tmp + integer(psb_ipk_), parameter :: vsz = 256 + + info = 0 + + !$acc parallel loop vector_length(vsz) private(isz) + do ii = 1, m, vsz + isz = min(vsz, m - ii + 1) + !$acc loop independent private(tmp) + do i = ii, ii + isz - 1 + tmp = 0.0_psb_dpk_ + !$acc loop seq + do j = irp(i), irp(i + 1) - 1 + tmp = tmp + val(j) * x(ja(j)) + end do + y(i) = alpha * tmp + beta * y(i) + end do + end do + end subroutine inner_spmv + + end subroutine psb_z_oacc_csr_vect_mv +end submodule psb_z_oacc_csr_vect_mv_impl diff --git a/openacc/impl/psb_z_oacc_mlt_v.f90 b/openacc/impl/psb_z_oacc_mlt_v.f90 new file mode 100644 index 00000000..7018f009 --- /dev/null +++ b/openacc/impl/psb_z_oacc_mlt_v.f90 @@ -0,0 +1,31 @@ + +subroutine z_oacc_mlt_v(x, y, info) + use psb_z_oacc_vect_mod, psb_protect_name => z_oacc_mlt_v + + implicit none + class(psb_z_base_vect_type), intent(inout) :: x + class(psb_z_vect_oacc), intent(inout) :: y + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i, n + + info = 0 + n = min(x%get_nrows(), y%get_nrows()) + select type(xx => x) + class is (psb_z_vect_oacc) + if (y%is_host()) call y%sync() + if (xx%is_host()) call xx%sync() + !$acc parallel loop + do i = 1, n + y%v(i) = y%v(i) * xx%v(i) + end do + call y%set_dev() + class default + if (xx%is_dev()) call xx%sync() + if (y%is_dev()) call y%sync() + do i = 1, n + y%v(i) = y%v(i) * xx%v(i) + end do + call y%set_host() + end select +end subroutine z_oacc_mlt_v diff --git a/openacc/impl/psb_z_oacc_mlt_v_2.f90 b/openacc/impl/psb_z_oacc_mlt_v_2.f90 new file mode 100644 index 00000000..dbc0929c --- /dev/null +++ b/openacc/impl/psb_z_oacc_mlt_v_2.f90 @@ -0,0 +1,98 @@ +subroutine z_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) + use psb_z_oacc_vect_mod, psb_protect_name => z_oacc_mlt_v_2 + use psb_string_mod + implicit none + complex(psb_dpk_), intent(in) :: alpha, beta + class(psb_z_base_vect_type), intent(inout) :: x + class(psb_z_base_vect_type), intent(inout) :: y + class(psb_z_vect_oacc), intent(inout) :: z + integer(psb_ipk_), intent(out) :: info + character(len=1), intent(in), optional :: conjgx, conjgy + integer(psb_ipk_) :: i, n + logical :: conjgx_, conjgy_ + + conjgx_ = .false. + conjgy_ = .false. + if (present(conjgx)) conjgx_ = (psb_toupper(conjgx) == 'C') + if (present(conjgy)) conjgy_ = (psb_toupper(conjgy) == 'C') + + n = min(x%get_nrows(), y%get_nrows(), z%get_nrows()) + info = 0 + select type(xx => x) + class is (psb_z_vect_oacc) + select type (yy => y) + class is (psb_z_vect_oacc) + if (xx%is_host()) call xx%sync() + if (yy%is_host()) call yy%sync() + if ((beta /= zzero) .and. (z%is_host())) call z%sync() + if (conjgx_.and.conjgy_) then + !$acc parallel loop + do i = 1, n + z%v(i) = alpha * conjg(xx%v(i)) * conjg(yy%v(i)) + beta * z%v(i) + end do + else if (conjgx_.and.(.not.conjgy_)) then + !$acc parallel loop + do i = 1, n + z%v(i) = alpha * conjg(xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + else if ((.not.conjgx_).and.(conjgy_)) then + !$acc parallel loop + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * conjg(yy%v(i)) + beta * z%v(i) + end do + else + !$acc parallel loop + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + + end if + call z%set_dev() + class default + if (xx%is_dev()) call xx%sync() + if (yy%is_dev()) call yy%sync() + if ((beta /= zzero) .and. (z%is_dev())) call z%sync() + if (conjgx_.and.conjgy_) then + do i = 1, n + z%v(i) = alpha * conjg(xx%v(i)) * conjg(yy%v(i)) + beta * z%v(i) + end do + else if (conjgx_.and.(.not.conjgy_)) then + do i = 1, n + z%v(i) = alpha * conjg(xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + else if ((.not.conjgx_).and.(conjgy_)) then + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * conjg(yy%v(i)) + beta * z%v(i) + end do + else + do i = 1, n + z%v(i) = alpha * (xx%v(i)) * (yy%v(i)) + beta * z%v(i) + end do + end if + call z%set_host() + end select + class default + if (x%is_dev()) call x%sync() + if (y%is_dev()) call y%sync() + if ((beta /= zzero) .and. (z%is_dev())) call z%sync() + if (conjgx_.and.conjgy_) then + do i = 1, n + z%v(i) = alpha * conjg(x%v(i)) * conjg(y%v(i)) + beta * z%v(i) + end do + else if (conjgx_.and.(.not.conjgy_)) then + do i = 1, n + z%v(i) = alpha * conjg(x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + else if ((.not.conjgx_).and.(conjgy_)) then + do i = 1, n + z%v(i) = alpha * (x%v(i)) * conjg(y%v(i)) + beta * z%v(i) + end do + else + do i = 1, n + z%v(i) = alpha * (x%v(i)) * (y%v(i)) + beta * z%v(i) + end do + end if + call z%set_host() + end select +end subroutine z_oacc_mlt_v_2 + diff --git a/openacc/psb_c_oacc_csr_mat_mod.F90 b/openacc/psb_c_oacc_csr_mat_mod.F90 new file mode 100644 index 00000000..00e79570 --- /dev/null +++ b/openacc/psb_c_oacc_csr_mat_mod.F90 @@ -0,0 +1,343 @@ +module psb_c_oacc_csr_mat_mod + + use iso_c_binding + use psb_c_mat_mod + use psb_c_oacc_vect_mod + !use oaccsparse_mod + + integer(psb_ipk_), parameter, private :: is_host = -1 + integer(psb_ipk_), parameter, private :: is_sync = 0 + integer(psb_ipk_), parameter, private :: is_dev = 1 + + type, extends(psb_c_csr_sparse_mat) :: psb_c_oacc_csr_sparse_mat + integer(psb_ipk_) :: devstate = is_host + contains + procedure, nopass :: get_fmt => c_oacc_csr_get_fmt + procedure, pass(a) :: sizeof => c_oacc_csr_sizeof + procedure, pass(a) :: vect_mv => psb_c_oacc_csr_vect_mv + procedure, pass(a) :: in_vect_sv => psb_c_oacc_csr_inner_vect_sv + procedure, pass(a) :: csmm => psb_c_oacc_csr_csmm + procedure, pass(a) :: csmv => psb_c_oacc_csr_csmv + procedure, pass(a) :: scals => psb_c_oacc_csr_scals + procedure, pass(a) :: scalv => psb_c_oacc_csr_scal + procedure, pass(a) :: reallocate_nz => psb_c_oacc_csr_reallocate_nz + procedure, pass(a) :: allocate_mnnz => psb_c_oacc_csr_allocate_mnnz + procedure, pass(a) :: cp_from_coo => psb_c_oacc_csr_cp_from_coo + procedure, pass(a) :: cp_from_fmt => psb_c_oacc_csr_cp_from_fmt + procedure, pass(a) :: mv_from_coo => psb_c_oacc_csr_mv_from_coo + procedure, pass(a) :: mv_from_fmt => psb_c_oacc_csr_mv_from_fmt + procedure, pass(a) :: free => c_oacc_csr_free + procedure, pass(a) :: mold => psb_c_oacc_csr_mold + procedure, pass(a) :: all => c_oacc_csr_all + procedure, pass(a) :: is_host => c_oacc_csr_is_host + procedure, pass(a) :: is_sync => c_oacc_csr_is_sync + procedure, pass(a) :: is_dev => c_oacc_csr_is_dev + procedure, pass(a) :: set_host => c_oacc_csr_set_host + procedure, pass(a) :: set_sync => c_oacc_csr_set_sync + procedure, pass(a) :: set_dev => c_oacc_csr_set_dev + procedure, pass(a) :: sync_space => c_oacc_csr_sync_space + procedure, pass(a) :: sync => c_oacc_csr_sync + end type psb_c_oacc_csr_sparse_mat + + interface + module subroutine psb_c_oacc_csr_mold(a,b,info) + class(psb_c_oacc_csr_sparse_mat), intent(in) :: a + class(psb_c_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_c_oacc_csr_mold + end interface + + interface + module subroutine psb_c_oacc_csr_cp_from_fmt(a,b,info) + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_c_base_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_c_oacc_csr_cp_from_fmt + end interface + + interface + module subroutine psb_c_oacc_csr_mv_from_coo(a,b,info) + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_c_coo_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_c_oacc_csr_mv_from_coo + end interface + + interface + module subroutine psb_c_oacc_csr_mv_from_fmt(a,b,info) + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_c_base_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_c_oacc_csr_mv_from_fmt + end interface + + interface + module subroutine psb_c_oacc_csr_vect_mv(alpha, a, x, beta, y, info, trans) + class(psb_c_oacc_csr_sparse_mat), intent(in) :: a + complex(psb_spk_), intent(in) :: alpha, beta + class(psb_c_base_vect_type), intent(inout) :: x, y + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + end subroutine psb_c_oacc_csr_vect_mv + end interface + + interface + module subroutine psb_c_oacc_csr_inner_vect_sv(alpha, a, x, beta, y, info, trans) + class(psb_c_oacc_csr_sparse_mat), intent(in) :: a + complex(psb_spk_), intent(in) :: alpha, beta + class(psb_c_base_vect_type), intent(inout) :: x,y + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + end subroutine psb_c_oacc_csr_inner_vect_sv + end interface + + interface + module subroutine psb_c_oacc_csr_csmm(alpha, a, x, beta, y, info, trans) + class(psb_c_oacc_csr_sparse_mat), intent(in) :: a + complex(psb_spk_), intent(in) :: alpha, beta, x(:,:) + complex(psb_spk_), intent(inout) :: y(:,:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + end subroutine psb_c_oacc_csr_csmm + end interface + + interface + module subroutine psb_c_oacc_csr_csmv(alpha, a, x, beta, y, info, trans) + class(psb_c_oacc_csr_sparse_mat), intent(in) :: a + complex(psb_spk_), intent(in) :: alpha, beta, x(:) + complex(psb_spk_), intent(inout) :: y(:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + end subroutine psb_c_oacc_csr_csmv + end interface + + interface + module subroutine psb_c_oacc_csr_scals(d, a, info) + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + complex(psb_spk_), intent(in) :: d + integer(psb_ipk_), intent(out) :: info + end subroutine psb_c_oacc_csr_scals + end interface + + interface + module subroutine psb_c_oacc_csr_scal(d,a,info,side) + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + complex(psb_spk_), intent(in) :: d(:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: side + end subroutine psb_c_oacc_csr_scal + end interface + + interface + module subroutine psb_c_oacc_csr_reallocate_nz(nz,a) + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + integer(psb_ipk_), intent(in) :: nz + end subroutine psb_c_oacc_csr_reallocate_nz + end interface + + interface + module subroutine psb_c_oacc_csr_allocate_mnnz(m,n,a,nz) + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + integer(psb_ipk_), intent(in) :: m,n + integer(psb_ipk_), intent(in), optional :: nz + end subroutine psb_c_oacc_csr_allocate_mnnz + end interface + + interface + module subroutine psb_c_oacc_csr_cp_from_coo(a,b,info) + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_c_coo_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_c_oacc_csr_cp_from_coo + end interface + +contains + + + subroutine c_oacc_csr_free(a) + use psb_base_mod + implicit none + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + integer(psb_ipk_) :: info + + if (allocated(a%val)) then + !$acc exit data delete(a%val) + end if + if (allocated(a%ja)) then + !$acc exit data delete(a%ja) + end if + if (allocated(a%irp)) then + !$acc exit data delete(a%irp) + end if + + call a%psb_c_csr_sparse_mat%free() + + return + end subroutine c_oacc_csr_free + + function c_oacc_csr_sizeof(a) result(res) + implicit none + class(psb_c_oacc_csr_sparse_mat), intent(in) :: a + integer(psb_epk_) :: res + + if (a%is_dev()) call a%sync() + + res = 8 + res = res + (2*psb_sizeof_sp) * size(a%val) + res = res + psb_sizeof_ip * size(a%irp) + res = res + psb_sizeof_ip * size(a%ja) + + end function c_oacc_csr_sizeof + + + function c_oacc_csr_get_fmt() result(res) + implicit none + character(len=5) :: res + res = 'CSR_oacc' + end function c_oacc_csr_get_fmt + + subroutine c_oacc_csr_all(m, n, nz, a, info) + implicit none + integer(psb_ipk_), intent(in) :: m, n, nz + class(psb_c_oacc_csr_sparse_mat), intent(out) :: a + integer(psb_ipk_), intent(out) :: info + + info = 0 + if (allocated(a%val)) then + !$acc exit data delete(a%val) finalize + deallocate(a%val, stat=info) + end if + if (allocated(a%ja)) then + !$acc exit data delete(a%ja) finalize + deallocate(a%ja, stat=info) + end if + if (allocated(a%irp)) then + !$acc exit data delete(a%irp) finalize + deallocate(a%irp, stat=info) + end if + + call a%set_nrows(m) + call a%set_ncols(n) + + allocate(a%val(nz),stat=info) + allocate(a%ja(nz),stat=info) + allocate(a%irp(m+1),stat=info) + if (info == 0) call a%set_host() + if (info == 0) call a%sync_space() + end subroutine c_oacc_csr_all + + function c_oacc_csr_is_host(a) result(res) + implicit none + class(psb_c_oacc_csr_sparse_mat), intent(in) :: a + logical :: res + + res = (a%devstate == is_host) + end function c_oacc_csr_is_host + + function c_oacc_csr_is_sync(a) result(res) + implicit none + class(psb_c_oacc_csr_sparse_mat), intent(in) :: a + logical :: res + + res = (a%devstate == is_sync) + end function c_oacc_csr_is_sync + + function c_oacc_csr_is_dev(a) result(res) + implicit none + class(psb_c_oacc_csr_sparse_mat), intent(in) :: a + logical :: res + + res = (a%devstate == is_dev) + end function c_oacc_csr_is_dev + + subroutine c_oacc_csr_set_host(a) + implicit none + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + + a%devstate = is_host + end subroutine c_oacc_csr_set_host + + subroutine c_oacc_csr_set_sync(a) + implicit none + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + + a%devstate = is_sync + end subroutine c_oacc_csr_set_sync + + subroutine c_oacc_csr_set_dev(a) + implicit none + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + + a%devstate = is_dev + end subroutine c_oacc_csr_set_dev + + subroutine c_oacc_csr_sync_space(a) + implicit none + class(psb_c_oacc_csr_sparse_mat), intent(inout) :: a + if (allocated(a%val)) then + call c_oacc_create_dev(a%val) + end if + if (allocated(a%ja)) then + call i_oacc_create_dev(a%ja) + end if + if (allocated(a%irp)) then + call i_oacc_create_dev(a%irp) + end if + contains + subroutine c_oacc_create_dev(v) + implicit none + complex(psb_spk_), intent(in) :: v(:) + !$acc enter data copyin(v) + end subroutine c_oacc_create_dev + subroutine i_oacc_create_dev(v) + implicit none + integer(psb_ipk_), intent(in) :: v(:) + !$acc enter data copyin(v) + end subroutine i_oacc_create_dev + end subroutine c_oacc_csr_sync_space + + subroutine c_oacc_csr_sync(a) + implicit none + class(psb_c_oacc_csr_sparse_mat), target, intent(in) :: a + class(psb_c_oacc_csr_sparse_mat), pointer :: tmpa + integer(psb_ipk_) :: info + + tmpa => a + if (a%is_dev()) then + call c_oacc_csr_to_host(a%val) + call i_oacc_csr_to_host(a%ja) + call i_oacc_csr_to_host(a%irp) + else if (a%is_host()) then + call c_oacc_csr_to_dev(a%val) + call i_oacc_csr_to_dev(a%ja) + call i_oacc_csr_to_dev(a%irp) + end if + call tmpa%set_sync() + end subroutine c_oacc_csr_sync + + subroutine c_oacc_csr_to_dev(v) + implicit none + complex(psb_spk_), intent(in) :: v(:) + !$acc update device(v) + end subroutine c_oacc_csr_to_dev + + subroutine c_oacc_csr_to_host(v) + implicit none + complex(psb_spk_), intent(in) :: v(:) + !$acc update self(v) + end subroutine c_oacc_csr_to_host + + subroutine i_oacc_csr_to_dev(v) + implicit none + integer(psb_ipk_), intent(in) :: v(:) + !$acc update device(v) + end subroutine i_oacc_csr_to_dev + + subroutine i_oacc_csr_to_host(v) + implicit none + integer(psb_ipk_), intent(in) :: v(:) + !$acc update self(v) + end subroutine i_oacc_csr_to_host + +end module psb_c_oacc_csr_mat_mod + diff --git a/openacc/psb_c_oacc_vect_mod.F90 b/openacc/psb_c_oacc_vect_mod.F90 new file mode 100644 index 00000000..6f9545ea --- /dev/null +++ b/openacc/psb_c_oacc_vect_mod.F90 @@ -0,0 +1,935 @@ +module psb_c_oacc_vect_mod + use iso_c_binding + use psb_const_mod + use psb_error_mod + use psb_c_vect_mod + use psb_i_vect_mod + use psb_i_oacc_vect_mod + + integer(psb_ipk_), parameter, private :: is_host = -1 + integer(psb_ipk_), parameter, private :: is_sync = 0 + integer(psb_ipk_), parameter, private :: is_dev = 1 + + type, extends(psb_c_base_vect_type) :: psb_c_vect_oacc + integer :: state = is_host + + contains + procedure, pass(x) :: get_nrows => c_oacc_get_nrows + procedure, nopass :: get_fmt => c_oacc_get_fmt + + procedure, pass(x) :: all => c_oacc_vect_all + procedure, pass(x) :: zero => c_oacc_zero + procedure, pass(x) :: asb_m => c_oacc_asb_m + procedure, pass(x) :: sync => c_oacc_sync + procedure, pass(x) :: sync_space => c_oacc_sync_space + procedure, pass(x) :: bld_x => c_oacc_bld_x + procedure, pass(x) :: bld_mn => c_oacc_bld_mn + procedure, pass(x) :: free => c_oacc_vect_free + procedure, pass(x) :: ins_a => c_oacc_ins_a + procedure, pass(x) :: ins_v => c_oacc_ins_v + procedure, pass(x) :: is_host => c_oacc_is_host + procedure, pass(x) :: is_dev => c_oacc_is_dev + procedure, pass(x) :: is_sync => c_oacc_is_sync + procedure, pass(x) :: set_host => c_oacc_set_host + procedure, pass(x) :: set_dev => c_oacc_set_dev + procedure, pass(x) :: set_sync => c_oacc_set_sync + procedure, pass(x) :: set_scal => c_oacc_set_scal + + procedure, pass(x) :: gthzv_x => c_oacc_gthzv_x + procedure, pass(x) :: gthzbuf_x => c_oacc_gthzbuf + procedure, pass(y) :: sctb => c_oacc_sctb + procedure, pass(y) :: sctb_x => c_oacc_sctb_x + procedure, pass(y) :: sctb_buf => c_oacc_sctb_buf + + procedure, pass(x) :: get_size => c_oacc_get_size + procedure, pass(x) :: dot_v => c_oacc_vect_dot + procedure, pass(x) :: dot_a => c_oacc_dot_a + procedure, pass(y) :: axpby_v => c_oacc_axpby_v + procedure, pass(y) :: axpby_a => c_oacc_axpby_a + procedure, pass(z) :: abgdxyz => c_oacc_abgdxyz + procedure, pass(y) :: mlt_a => c_oacc_mlt_a + procedure, pass(z) :: mlt_a_2 => c_oacc_mlt_a_2 + procedure, pass(y) :: mlt_v => c_oacc_mlt_v + procedure, pass(z) :: mlt_v_2 => c_oacc_mlt_v_2 + procedure, pass(x) :: scal => c_oacc_scal + procedure, pass(x) :: nrm2 => c_oacc_nrm2 + procedure, pass(x) :: amax => c_oacc_amax + procedure, pass(x) :: asum => c_oacc_asum + procedure, pass(x) :: absval1 => c_oacc_absval1 + procedure, pass(x) :: absval2 => c_oacc_absval2 + + end type psb_c_vect_oacc + + interface + subroutine c_oacc_mlt_v(x, y, info) + import + implicit none + class(psb_c_base_vect_type), intent(inout) :: x + class(psb_c_vect_oacc), intent(inout) :: y + integer(psb_ipk_), intent(out) :: info + end subroutine c_oacc_mlt_v + end interface + + + interface + subroutine c_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) + import + implicit none + complex(psb_spk_), intent(in) :: alpha, beta + class(psb_c_base_vect_type), intent(inout) :: x + class(psb_c_base_vect_type), intent(inout) :: y + class(psb_c_vect_oacc), intent(inout) :: z + integer(psb_ipk_), intent(out) :: info + character(len=1), intent(in), optional :: conjgx, conjgy + end subroutine c_oacc_mlt_v_2 + end interface + +contains + + subroutine c_oacc_absval1(x) + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + integer(psb_ipk_) :: n, i + + if (x%is_host()) call x%sync_space() + n = size(x%v) + !$acc parallel loop + do i = 1, n + x%v(i) = abs(x%v(i)) + end do + call x%set_dev() + end subroutine c_oacc_absval1 + + subroutine c_oacc_absval2(x, y) + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + class(psb_c_base_vect_type), intent(inout) :: y + integer(psb_ipk_) :: n + integer(psb_ipk_) :: i + + n = min(size(x%v), size(y%v)) + select type (yy => y) + class is (psb_c_vect_oacc) + if (x%is_host()) call x%sync() + if (yy%is_host()) call yy%sync() + !$acc parallel loop + do i = 1, n + yy%v(i) = abs(x%v(i)) + end do + class default + if (x%is_dev()) call x%sync() + if (y%is_dev()) call y%sync() + call x%psb_c_base_vect_type%absval(y) + end select + end subroutine c_oacc_absval2 + + subroutine c_oacc_scal(alpha, x) + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + complex(psb_spk_), intent(in) :: alpha + integer(psb_ipk_) :: info + integer(psb_ipk_) :: i + + if (x%is_host()) call x%sync_space() + !$acc parallel loop + do i = 1, size(x%v) + x%v(i) = alpha * x%v(i) + end do + call x%set_dev() + end subroutine c_oacc_scal + + function c_oacc_nrm2(n, x) result(res) + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + real(psb_spk_) :: res + integer(psb_ipk_) :: info + real(psb_spk_) :: sum + integer(psb_ipk_) :: i + + if (x%is_host()) call x%sync_space() + sum = 0.0 + !$acc parallel loop reduction(+:sum) + do i = 1, n + sum = sum + abs(x%v(i))**2 + end do + res = sqrt(sum) + end function c_oacc_nrm2 + + function c_oacc_amax(n, x) result(res) + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + real(psb_spk_) :: res + integer(psb_ipk_) :: info + real(psb_spk_) :: max_val + integer(psb_ipk_) :: i + + if (x%is_host()) call x%sync_space() + max_val = -huge(0.0) + !$acc parallel loop reduction(max:max_val) + do i = 1, n + if (abs(x%v(i)) > max_val) max_val = abs(x%v(i)) + end do + res = max_val + end function c_oacc_amax + + function c_oacc_asum(n, x) result(res) + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + real(psb_spk_) :: res + integer(psb_ipk_) :: info + complex(psb_spk_) :: sum + integer(psb_ipk_) :: i + + if (x%is_host()) call x%sync_space() + sum = 0.0 + !$acc parallel loop reduction(+:sum) + do i = 1, n + sum = sum + abs(x%v(i)) + end do + res = sum + end function c_oacc_asum + + + subroutine c_oacc_mlt_a(x, y, info) + implicit none + complex(psb_spk_), intent(in) :: x(:) + class(psb_c_vect_oacc), intent(inout) :: y + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: i, n + + info = 0 + if (y%is_dev()) call y%sync_space() + !$acc parallel loop + do i = 1, size(x) + y%v(i) = y%v(i) * x(i) + end do + call y%set_host() + end subroutine c_oacc_mlt_a + + subroutine c_oacc_mlt_a_2(alpha, x, y, beta, z, info) + implicit none + complex(psb_spk_), intent(in) :: alpha, beta + complex(psb_spk_), intent(in) :: x(:) + complex(psb_spk_), intent(in) :: y(:) + class(psb_c_vect_oacc), intent(inout) :: z + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: i, n + + info = 0 + if (z%is_dev()) call z%sync_space() + !$acc parallel loop + do i = 1, size(x) + z%v(i) = alpha * x(i) * y(i) + beta * z%v(i) + end do + call z%set_host() + end subroutine c_oacc_mlt_a_2 + + +!!$ subroutine c_oacc_mlt_v(x, y, info) +!!$ implicit none +!!$ class(psb_c_base_vect_type), intent(inout) :: x +!!$ class(psb_c_vect_oacc), intent(inout) :: y +!!$ integer(psb_ipk_), intent(out) :: info +!!$ +!!$ integer(psb_ipk_) :: i, n +!!$ +!!$ info = 0 +!!$ n = min(x%get_nrows(), y%get_nrows()) +!!$ select type(xx => x) +!!$ type is (psb_c_base_vect_type) +!!$ if (y%is_dev()) call y%sync() +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ y%v(i) = y%v(i) * xx%v(i) +!!$ end do +!!$ call y%set_host() +!!$ class default +!!$ if (xx%is_dev()) call xx%sync() +!!$ if (y%is_dev()) call y%sync() +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ y%v(i) = y%v(i) * xx%v(i) +!!$ end do +!!$ call y%set_host() +!!$ end select +!!$ end subroutine c_oacc_mlt_v +!!$ +!!$ subroutine c_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) +!!$ use psi_serial_mod +!!$ use psb_string_mod +!!$ implicit none +!!$ complex(psb_spk_), intent(in) :: alpha, beta +!!$ class(psb_c_base_vect_type), intent(inout) :: x +!!$ class(psb_c_base_vect_type), intent(inout) :: y +!!$ class(psb_c_vect_oacc), intent(inout) :: z +!!$ integer(psb_ipk_), intent(out) :: info +!!$ character(len=1), intent(in), optional :: conjgx, conjgy +!!$ integer(psb_ipk_) :: i, n +!!$ logical :: conjgx_, conjgy_ +!!$ +!!$ conjgx_ = .false. +!!$ conjgy_ = .false. +!!$ if (present(conjgx)) conjgx_ = (psb_toupper(conjgx) == 'C') +!!$ if (present(conjgy)) conjgy_ = (psb_toupper(conjgy) == 'C') +!!$ +!!$ n = min(x%get_nrows(), y%get_nrows(), z%get_nrows()) +!!$ +!!$ info = 0 +!!$ select type(xx => x) +!!$ class is (psb_c_vect_oacc) +!!$ select type (yy => y) +!!$ class is (psb_c_vect_oacc) +!!$ if (xx%is_host()) call xx%sync_space() +!!$ if (yy%is_host()) call yy%sync_space() +!!$ if ((beta /= czero) .and. (z%is_host())) call z%sync_space() +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) +!!$ end do +!!$ call z%set_dev() +!!$ class default +!!$ if (xx%is_dev()) call xx%sync_space() +!!$ if (yy%is_dev()) call yy%sync() +!!$ if ((beta /= czero) .and. (z%is_dev())) call z%sync_space() +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) +!!$ end do +!!$ call z%set_host() +!!$ end select +!!$ class default +!!$ if (x%is_dev()) call x%sync() +!!$ if (y%is_dev()) call y%sync() +!!$ if ((beta /= czero) .and. (z%is_dev())) call z%sync_space() +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ z%v(i) = alpha * x%v(i) * y%v(i) + beta * z%v(i) +!!$ end do +!!$ call z%set_host() +!!$ end select +!!$ end subroutine c_oacc_mlt_v_2 + + + subroutine c_oacc_axpby_v(m, alpha, x, beta, y, info) + !use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_c_base_vect_type), intent(inout) :: x + class(psb_c_vect_oacc), intent(inout) :: y + complex(psb_spk_), intent(in) :: alpha, beta + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: nx, ny, i + + info = psb_success_ + + select type(xx => x) + type is (psb_c_vect_oacc) + if ((beta /= czero) .and. y%is_host()) call y%sync_space() + if (xx%is_host()) call xx%sync_space() + nx = size(xx%v) + ny = size(y%v) + if ((nx < m) .or. (ny < m)) then + info = psb_err_internal_error_ + else + !$acc parallel loop + do i = 1, m + y%v(i) = alpha * xx%v(i) + beta * y%v(i) + end do + end if + call y%set_dev() + class default + if ((alpha /= czero) .and. (x%is_dev())) call x%sync() + call y%axpby(m, alpha, x%v, beta, info) + end select + end subroutine c_oacc_axpby_v + + subroutine c_oacc_axpby_a(m, alpha, x, beta, y, info) + !use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + complex(psb_spk_), intent(in) :: x(:) + class(psb_c_vect_oacc), intent(inout) :: y + complex(psb_spk_), intent(in) :: alpha, beta + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: i + + if ((beta /= czero) .and. (y%is_dev())) call y%sync_space() + !$acc parallel loop + do i = 1, m + y%v(i) = alpha * x(i) + beta * y%v(i) + end do + call y%set_host() + end subroutine c_oacc_axpby_a + + subroutine c_oacc_abgdxyz(m, alpha, beta, gamma, delta, x, y, z, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_c_base_vect_type), intent(inout) :: x + class(psb_c_base_vect_type), intent(inout) :: y + class(psb_c_vect_oacc), intent(inout) :: z + complex(psb_spk_), intent(in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: nx, ny, nz, i + logical :: gpu_done + + info = psb_success_ + gpu_done = .false. + + select type(xx => x) + class is (psb_c_vect_oacc) + select type(yy => y) + class is (psb_c_vect_oacc) + select type(zz => z) + class is (psb_c_vect_oacc) + if ((beta /= czero) .and. yy%is_host()) call yy%sync_space() + if ((delta /= czero) .and. zz%is_host()) call zz%sync_space() + if (xx%is_host()) call xx%sync_space() + nx = size(xx%v) + ny = size(yy%v) + nz = size(zz%v) + if ((nx < m) .or. (ny < m) .or. (nz < m)) then + info = psb_err_internal_error_ + else + !$acc parallel loop + do i = 1, m + yy%v(i) = alpha * xx%v(i) + beta * yy%v(i) + zz%v(i) = gamma * yy%v(i) + delta * zz%v(i) + end do + end if + call yy%set_dev() + call zz%set_dev() + gpu_done = .true. + end select + end select + end select + + if (.not. gpu_done) then + if (x%is_host()) call x%sync() + if (y%is_host()) call y%sync() + if (z%is_host()) call z%sync() + call y%axpby(m, alpha, x, beta, info) + call z%axpby(m, gamma, y, delta, info) + end if + end subroutine c_oacc_abgdxyz + + subroutine c_oacc_sctb_buf(i, n, idx, beta, y) + use psb_base_mod + implicit none + integer(psb_ipk_) :: i, n + class(psb_i_base_vect_type) :: idx + complex(psb_spk_) :: beta + class(psb_c_vect_oacc) :: y + integer(psb_ipk_) :: info + + if (.not.allocated(y%combuf)) then + call psb_errpush(psb_err_alloc_dealloc_, 'sctb_buf') + return + end if + + select type(ii => idx) + class is (psb_i_vect_oacc) + if (ii%is_host()) call ii%sync_space(info) + if (y%is_host()) call y%sync_space() + + !$acc parallel loop + do i = 1, n + y%v(ii%v(i)) = beta * y%v(ii%v(i)) + y%combuf(i) + end do + + class default + !$acc parallel loop + do i = 1, n + y%v(idx%v(i)) = beta * y%v(idx%v(i)) + y%combuf(i) + end do + end select + end subroutine c_oacc_sctb_buf + + subroutine c_oacc_sctb_x(i, n, idx, x, beta, y) + use psb_base_mod + implicit none + integer(psb_ipk_):: i, n + class(psb_i_base_vect_type) :: idx + complex(psb_spk_) :: beta, x(:) + class(psb_c_vect_oacc) :: y + integer(psb_ipk_) :: info, ni + + select type(ii => idx) + class is (psb_i_vect_oacc) + if (ii%is_host()) call ii%sync_space(info) + class default + call psb_errpush(info, 'c_oacc_sctb_x') + return + end select + + if (y%is_host()) call y%sync_space() + + !$acc parallel loop + do i = 1, n + y%v(idx%v(i)) = beta * y%v(idx%v(i)) + x(i) + end do + + call y%set_dev() + end subroutine c_oacc_sctb_x + + + + subroutine c_oacc_sctb(n, idx, x, beta, y) + use psb_base_mod + implicit none + integer(psb_ipk_) :: n + integer(psb_ipk_) :: idx(:) + complex(psb_spk_) :: beta, x(:) + class(psb_c_vect_oacc) :: y + integer(psb_ipk_) :: info + integer(psb_ipk_) :: i + + if (n == 0) return + if (y%is_dev()) call y%sync_space() + + !$acc parallel loop + do i = 1, n + y%v(idx(i)) = beta * y%v(idx(i)) + x(i) + end do + + call y%set_host() + end subroutine c_oacc_sctb + + + subroutine c_oacc_gthzbuf(i, n, idx, x) + use psb_base_mod + implicit none + integer(psb_ipk_) :: i, n + class(psb_i_base_vect_type) :: idx + class(psb_c_vect_oacc) :: x + integer(psb_ipk_) :: info + + info = 0 + if (.not.allocated(x%combuf)) then + call psb_errpush(psb_err_alloc_dealloc_, 'gthzbuf') + return + end if + + select type(ii => idx) + class is (psb_i_vect_oacc) + if (ii%is_host()) call ii%sync_space(info) + class default + call psb_errpush(info, 'c_oacc_gthzbuf') + return + end select + + if (x%is_host()) call x%sync_space() + + !$acc parallel loop + do i = 1, n + x%combuf(i) = x%v(idx%v(i)) + end do + end subroutine c_oacc_gthzbuf + + subroutine c_oacc_gthzv_x(i, n, idx, x, y) + use psb_base_mod + implicit none + integer(psb_ipk_) :: i, n + class(psb_i_base_vect_type):: idx + complex(psb_spk_) :: y(:) + class(psb_c_vect_oacc):: x + integer(psb_ipk_) :: info + + info = 0 + + select type(ii => idx) + class is (psb_i_vect_oacc) + if (ii%is_host()) call ii%sync_space(info) + class default + call psb_errpush(info, 'c_oacc_gthzv_x') + return + end select + + if (x%is_host()) call x%sync_space() + + !$acc parallel loop + do i = 1, n + y(i) = x%v(idx%v(i)) + end do + end subroutine c_oacc_gthzv_x + + subroutine c_oacc_ins_v(n, irl, val, dupl, x, info) + use psi_serial_mod + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n, dupl + class(psb_i_base_vect_type), intent(inout) :: irl + class(psb_c_base_vect_type), intent(inout) :: val + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i, isz + logical :: done_oacc + + info = 0 + if (psb_errstatus_fatal()) return + + done_oacc = .false. + select type(virl => irl) + type is (psb_i_vect_oacc) + select type(vval => val) + type is (psb_c_vect_oacc) + if (vval%is_host()) call vval%sync_space() + if (virl%is_host()) call virl%sync_space(info) + if (x%is_host()) call x%sync_space() + !$acc parallel loop + do i = 1, n + x%v(virl%v(i)) = vval%v(i) + end do + call x%set_dev() + done_oacc = .true. + end select + end select + + if (.not.done_oacc) then + select type(virl => irl) + type is (psb_i_vect_oacc) + if (virl%is_dev()) call virl%sync_space(info) + end select + select type(vval => val) + type is (psb_c_vect_oacc) + if (vval%is_dev()) call vval%sync_space() + end select + call x%ins(n, irl%v, val%v, dupl, info) + end if + + if (info /= 0) then + call psb_errpush(info, 'oacc_vect_ins') + return + end if + + end subroutine c_oacc_ins_v + + + + subroutine c_oacc_ins_a(n, irl, val, dupl, x, info) + use psi_serial_mod + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n, dupl + integer(psb_ipk_), intent(in) :: irl(:) + complex(psb_spk_), intent(in) :: val(:) + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i + + info = 0 + if (x%is_dev()) call x%sync_space() + call x%psb_c_base_vect_type%ins(n, irl, val, dupl, info) + call x%set_host() + !$acc update device(x%v) + + end subroutine c_oacc_ins_a + + + + subroutine c_oacc_bld_mn(x, n) + use psb_base_mod + implicit none + integer(psb_mpk_), intent(in) :: n + class(psb_c_vect_oacc), intent(inout) :: x + integer(psb_ipk_) :: info + + call x%all(n, info) + if (info /= 0) then + call psb_errpush(info, 'c_oacc_bld_mn', i_err=(/n, n, n, n, n/)) + end if + call x%set_host() + !$acc update device(x%v) + + end subroutine c_oacc_bld_mn + + + subroutine c_oacc_bld_x(x, this) + use psb_base_mod + implicit none + complex(psb_spk_), intent(in) :: this(:) + class(psb_c_vect_oacc), intent(inout) :: x + integer(psb_ipk_) :: info + + call psb_realloc(size(this), x%v, info) + if (info /= 0) then + info = psb_err_alloc_request_ + call psb_errpush(info, 'c_oacc_bld_x', & + i_err=(/size(this), izero, izero, izero, izero/)) + return + end if + + x%v(:) = this(:) + call x%set_host() + !$acc update device(x%v) + + end subroutine c_oacc_bld_x + + + subroutine c_oacc_asb_m(n, x, info) + use psb_base_mod + implicit none + integer(psb_mpk_), intent(in) :: n + class(psb_c_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + integer(psb_mpk_) :: nd + + info = psb_success_ + + if (x%is_dev()) then + nd = size(x%v) + if (nd < n) then + call x%sync() + call x%psb_c_base_vect_type%asb(n, info) + if (info == psb_success_) call x%sync_space() + call x%set_host() + end if + else + if (size(x%v) < n) then + call x%psb_c_base_vect_type%asb(n, info) + if (info == psb_success_) call x%sync_space() + call x%set_host() + end if + end if + end subroutine c_oacc_asb_m + + + + subroutine c_oacc_set_scal(x, val, first, last) + class(psb_c_vect_oacc), intent(inout) :: x + complex(psb_spk_), intent(in) :: val + integer(psb_ipk_), optional :: first, last + + integer(psb_ipk_) :: first_, last_ + first_ = 1 + last_ = x%get_nrows() + if (present(first)) first_ = max(1, first) + if (present(last)) last_ = min(last, last_) + + !$acc parallel loop + do i = first_, last_ + x%v(i) = val + end do + !$acc end parallel loop + + call x%set_dev() + end subroutine c_oacc_set_scal + + + + subroutine c_oacc_zero(x) + use psi_serial_mod + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + call x%set_dev() + call x%set_scal(czero) + end subroutine c_oacc_zero + + function c_oacc_get_nrows(x) result(res) + implicit none + class(psb_c_vect_oacc), intent(in) :: x + integer(psb_ipk_) :: res + + if (allocated(x%v)) res = size(x%v) + end function c_oacc_get_nrows + + function c_oacc_get_fmt() result(res) + implicit none + character(len=5) :: res + res = "cOACC" + + end function c_oacc_get_fmt + + function c_oacc_vect_dot(n, x, y) result(res) + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + class(psb_c_base_vect_type), intent(inout) :: y + integer(psb_ipk_), intent(in) :: n + complex(psb_spk_) :: res + complex(psb_spk_), external :: ddot + integer(psb_ipk_) :: info + integer(psb_ipk_) :: i + + res = czero + + select type(yy => y) + type is (psb_c_base_vect_type) + if (x%is_dev()) call x%sync() + res = ddot(n, x%v, 1, yy%v, 1) + type is (psb_c_vect_oacc) + if (x%is_host()) call x%sync() + if (yy%is_host()) call yy%sync() + + !$acc parallel loop reduction(+:res) present(x%v, yy%v) + do i = 1, n + res = res + x%v(i) * yy%v(i) + end do + !$acc end parallel loop + + class default + call x%sync() + res = y%dot(n, x%v) + end select + + end function c_oacc_vect_dot + + + + + function c_oacc_dot_a(n, x, y) result(res) + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + complex(psb_spk_), intent(in) :: y(:) + integer(psb_ipk_), intent(in) :: n + complex(psb_spk_) :: res + complex(psb_spk_), external :: ddot + + if (x%is_dev()) call x%sync() + res = ddot(n, y, 1, x%v, 1) + + end function c_oacc_dot_a + + ! subroutine c_oacc_set_vect(x,y) + ! implicit none + ! class(psb_c_vect_oacc), intent(inout) :: x + ! complex(psb_spk_), intent(in) :: y(:) + ! integer(psb_ipk_) :: info + + ! if (size(x%v) /= size(y)) then + ! call x%free(info) + ! call x%all(size(y),info) + ! end if + ! x%v(:) = y(:) + ! call x%set_host() + ! end subroutine c_oacc_set_vect + + subroutine c_oacc_to_dev(v) + implicit none + complex(psb_spk_) :: v(:) + !$acc update device(v) + end subroutine c_oacc_to_dev + + subroutine c_oacc_to_host(v) + implicit none + complex(psb_spk_) :: v(:) + !$acc update self(v) + end subroutine c_oacc_to_host + + subroutine c_oacc_sync_space(x) + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + if (allocated(x%v)) then + call c_oacc_create_dev(x%v) + end if + contains + subroutine c_oacc_create_dev(v) + implicit none + complex(psb_spk_) :: v(:) + !$acc enter data copyin(v) + end subroutine c_oacc_create_dev + end subroutine c_oacc_sync_space + + subroutine c_oacc_sync(x) + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + if (x%is_dev()) then + call c_oacc_to_host(x%v) + end if + if (x%is_host()) then + call c_oacc_to_dev(x%v) + end if + call x%set_sync() + end subroutine c_oacc_sync + + subroutine c_oacc_set_host(x) + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + + x%state = is_host + end subroutine c_oacc_set_host + + subroutine c_oacc_set_dev(x) + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + + x%state = is_dev + end subroutine c_oacc_set_dev + + subroutine c_oacc_set_sync(x) + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + + x%state = is_sync + end subroutine c_oacc_set_sync + + function c_oacc_is_dev(x) result(res) + implicit none + class(psb_c_vect_oacc), intent(in) :: x + logical :: res + + res = (x%state == is_dev) + end function c_oacc_is_dev + + function c_oacc_is_host(x) result(res) + implicit none + class(psb_c_vect_oacc), intent(in) :: x + logical :: res + + res = (x%state == is_host) + end function c_oacc_is_host + + function c_oacc_is_sync(x) result(res) + implicit none + class(psb_c_vect_oacc), intent(in) :: x + logical :: res + + res = (x%state == is_sync) + end function c_oacc_is_sync + + subroutine c_oacc_vect_all(n, x, info) + use psi_serial_mod + use psb_realloc_mod + implicit none + integer(psb_ipk_), intent(in) :: n + class(psb_c_vect_oacc), intent(out) :: x + integer(psb_ipk_), intent(out) :: info + + call psb_realloc(n, x%v, info) + if (info == 0) then + call x%set_host() + !$acc enter data create(x%v) + call x%sync_space() + end if + if (info /= 0) then + info = psb_err_alloc_request_ + call psb_errpush(info, 'c_oacc_all', & + i_err=(/n, n, n, n, n/)) + end if + end subroutine c_oacc_vect_all + + + subroutine c_oacc_vect_free(x, info) + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + info = 0 + if (allocated(x%v)) then + !$acc exit data delete(x%v) finalize + deallocate(x%v, stat=info) + end if + + end subroutine c_oacc_vect_free + + function c_oacc_get_size(x) result(res) + implicit none + class(psb_c_vect_oacc), intent(inout) :: x + integer(psb_ipk_) :: res + + if (x%is_dev()) call x%sync() + res = size(x%v) + end function c_oacc_get_size + +end module psb_c_oacc_vect_mod diff --git a/openacc/psb_d_oacc_csr_mat_mod.F90 b/openacc/psb_d_oacc_csr_mat_mod.F90 index ca4199a8..8b7e111e 100644 --- a/openacc/psb_d_oacc_csr_mat_mod.F90 +++ b/openacc/psb_d_oacc_csr_mat_mod.F90 @@ -175,8 +175,6 @@ contains return end subroutine d_oacc_csr_free - - function d_oacc_csr_sizeof(a) result(res) implicit none class(psb_d_oacc_csr_sparse_mat), intent(in) :: a @@ -341,6 +339,5 @@ contains !$acc update self(v) end subroutine i_oacc_csr_to_host - end module psb_d_oacc_csr_mat_mod diff --git a/openacc/psb_d_oacc_vect_mod.F90 b/openacc/psb_d_oacc_vect_mod.F90 index 3385f1ec..7d51766d 100644 --- a/openacc/psb_d_oacc_vect_mod.F90 +++ b/openacc/psb_d_oacc_vect_mod.F90 @@ -41,7 +41,7 @@ module psb_d_oacc_vect_mod procedure, pass(y) :: sctb_x => d_oacc_sctb_x procedure, pass(y) :: sctb_buf => d_oacc_sctb_buf - procedure, pass(x) :: get_size => oacc_get_size + procedure, pass(x) :: get_size => d_oacc_get_size procedure, pass(x) :: dot_v => d_oacc_vect_dot procedure, pass(x) :: dot_a => d_oacc_dot_a procedure, pass(y) :: axpby_v => d_oacc_axpby_v @@ -49,8 +49,8 @@ module psb_d_oacc_vect_mod procedure, pass(z) :: abgdxyz => d_oacc_abgdxyz procedure, pass(y) :: mlt_a => d_oacc_mlt_a procedure, pass(z) :: mlt_a_2 => d_oacc_mlt_a_2 - procedure, pass(y) :: mlt_v => psb_d_oacc_mlt_v - procedure, pass(z) :: mlt_v_2 => psb_d_oacc_mlt_v_2 + procedure, pass(y) :: mlt_v => d_oacc_mlt_v + procedure, pass(z) :: mlt_v_2 => d_oacc_mlt_v_2 procedure, pass(x) :: scal => d_oacc_scal procedure, pass(x) :: nrm2 => d_oacc_nrm2 procedure, pass(x) :: amax => d_oacc_amax @@ -60,20 +60,20 @@ module psb_d_oacc_vect_mod end type psb_d_vect_oacc - real(psb_dpk_), allocatable :: v1(:),v2(:),p(:) - interface - module subroutine psb_d_oacc_mlt_v(x, y, info) + subroutine d_oacc_mlt_v(x, y, info) + import implicit none class(psb_d_base_vect_type), intent(inout) :: x class(psb_d_vect_oacc), intent(inout) :: y integer(psb_ipk_), intent(out) :: info - end subroutine psb_d_oacc_mlt_v + end subroutine d_oacc_mlt_v end interface interface - module subroutine psb_d_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) + subroutine d_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) + import implicit none real(psb_dpk_), intent(in) :: alpha, beta class(psb_d_base_vect_type), intent(inout) :: x @@ -81,7 +81,7 @@ module psb_d_oacc_vect_mod class(psb_d_vect_oacc), intent(inout) :: z integer(psb_ipk_), intent(out) :: info character(len=1), intent(in), optional :: conjgx, conjgy - end subroutine psb_d_oacc_mlt_v_2 + end subroutine d_oacc_mlt_v_2 end interface contains @@ -151,7 +151,7 @@ contains sum = 0.0 !$acc parallel loop reduction(+:sum) do i = 1, n - sum = sum + x%v(i) * x%v(i) + sum = sum + abs(x%v(i))**2 end do res = sqrt(sum) end function d_oacc_nrm2 @@ -169,7 +169,7 @@ contains max_val = -huge(0.0) !$acc parallel loop reduction(max:max_val) do i = 1, n - if (x%v(i) > max_val) max_val = x%v(i) + if (abs(x%v(i)) > max_val) max_val = abs(x%v(i)) end do res = max_val end function d_oacc_amax @@ -923,41 +923,13 @@ contains end subroutine d_oacc_vect_free - function oacc_get_size(x) result(res) + function d_oacc_get_size(x) result(res) implicit none class(psb_d_vect_oacc), intent(inout) :: x integer(psb_ipk_) :: res if (x%is_dev()) call x%sync() res = size(x%v) - end function oacc_get_size - -!!$ -!!$ subroutine initialize(N) -!!$ integer(psb_ipk_) :: N -!!$ integer(psb_ipk_) :: i -!!$ allocate(v1(N),v2(N),p(N)) -!!$ !$acc enter data create(v1,v2,p) -!!$ !$acc parallel -!!$ !$acc loop -!!$ do i=1,n -!!$ v1(i) = i -!!$ v2(i) = n+i -!!$ end do -!!$ !$acc end parallel -!!$ end subroutine initialize -!!$ subroutine finalize_dev() -!!$ !$acc exit data delete(v1,v2,p) -!!$ end subroutine finalize_dev -!!$ subroutine finalize_host() -!!$ deallocate(v1,v2,p) -!!$ end subroutine finalize_host -!!$ subroutine to_dev() -!!$ !$acc update device(v1,v2) -!!$ end subroutine to_dev -!!$ subroutine to_host() -!!$ !$acc update self(v1,v2) -!!$ end subroutine to_host -!!$ + end function d_oacc_get_size end module psb_d_oacc_vect_mod diff --git a/openacc/psb_s_oacc_csr_mat_mod.F90 b/openacc/psb_s_oacc_csr_mat_mod.F90 new file mode 100644 index 00000000..89b10d08 --- /dev/null +++ b/openacc/psb_s_oacc_csr_mat_mod.F90 @@ -0,0 +1,343 @@ +module psb_s_oacc_csr_mat_mod + + use iso_c_binding + use psb_s_mat_mod + use psb_s_oacc_vect_mod + !use oaccsparse_mod + + integer(psb_ipk_), parameter, private :: is_host = -1 + integer(psb_ipk_), parameter, private :: is_sync = 0 + integer(psb_ipk_), parameter, private :: is_dev = 1 + + type, extends(psb_s_csr_sparse_mat) :: psb_s_oacc_csr_sparse_mat + integer(psb_ipk_) :: devstate = is_host + contains + procedure, nopass :: get_fmt => s_oacc_csr_get_fmt + procedure, pass(a) :: sizeof => s_oacc_csr_sizeof + procedure, pass(a) :: vect_mv => psb_s_oacc_csr_vect_mv + procedure, pass(a) :: in_vect_sv => psb_s_oacc_csr_inner_vect_sv + procedure, pass(a) :: csmm => psb_s_oacc_csr_csmm + procedure, pass(a) :: csmv => psb_s_oacc_csr_csmv + procedure, pass(a) :: scals => psb_s_oacc_csr_scals + procedure, pass(a) :: scalv => psb_s_oacc_csr_scal + procedure, pass(a) :: reallocate_nz => psb_s_oacc_csr_reallocate_nz + procedure, pass(a) :: allocate_mnnz => psb_s_oacc_csr_allocate_mnnz + procedure, pass(a) :: cp_from_coo => psb_s_oacc_csr_cp_from_coo + procedure, pass(a) :: cp_from_fmt => psb_s_oacc_csr_cp_from_fmt + procedure, pass(a) :: mv_from_coo => psb_s_oacc_csr_mv_from_coo + procedure, pass(a) :: mv_from_fmt => psb_s_oacc_csr_mv_from_fmt + procedure, pass(a) :: free => s_oacc_csr_free + procedure, pass(a) :: mold => psb_s_oacc_csr_mold + procedure, pass(a) :: all => s_oacc_csr_all + procedure, pass(a) :: is_host => s_oacc_csr_is_host + procedure, pass(a) :: is_sync => s_oacc_csr_is_sync + procedure, pass(a) :: is_dev => s_oacc_csr_is_dev + procedure, pass(a) :: set_host => s_oacc_csr_set_host + procedure, pass(a) :: set_sync => s_oacc_csr_set_sync + procedure, pass(a) :: set_dev => s_oacc_csr_set_dev + procedure, pass(a) :: sync_space => s_oacc_csr_sync_space + procedure, pass(a) :: sync => s_oacc_csr_sync + end type psb_s_oacc_csr_sparse_mat + + interface + module subroutine psb_s_oacc_csr_mold(a,b,info) + class(psb_s_oacc_csr_sparse_mat), intent(in) :: a + class(psb_s_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_s_oacc_csr_mold + end interface + + interface + module subroutine psb_s_oacc_csr_cp_from_fmt(a,b,info) + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_s_base_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_s_oacc_csr_cp_from_fmt + end interface + + interface + module subroutine psb_s_oacc_csr_mv_from_coo(a,b,info) + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_s_coo_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_s_oacc_csr_mv_from_coo + end interface + + interface + module subroutine psb_s_oacc_csr_mv_from_fmt(a,b,info) + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_s_base_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_s_oacc_csr_mv_from_fmt + end interface + + interface + module subroutine psb_s_oacc_csr_vect_mv(alpha, a, x, beta, y, info, trans) + class(psb_s_oacc_csr_sparse_mat), intent(in) :: a + real(psb_spk_), intent(in) :: alpha, beta + class(psb_s_base_vect_type), intent(inout) :: x, y + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + end subroutine psb_s_oacc_csr_vect_mv + end interface + + interface + module subroutine psb_s_oacc_csr_inner_vect_sv(alpha, a, x, beta, y, info, trans) + class(psb_s_oacc_csr_sparse_mat), intent(in) :: a + real(psb_spk_), intent(in) :: alpha, beta + class(psb_s_base_vect_type), intent(inout) :: x,y + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + end subroutine psb_s_oacc_csr_inner_vect_sv + end interface + + interface + module subroutine psb_s_oacc_csr_csmm(alpha, a, x, beta, y, info, trans) + class(psb_s_oacc_csr_sparse_mat), intent(in) :: a + real(psb_spk_), intent(in) :: alpha, beta, x(:,:) + real(psb_spk_), intent(inout) :: y(:,:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + end subroutine psb_s_oacc_csr_csmm + end interface + + interface + module subroutine psb_s_oacc_csr_csmv(alpha, a, x, beta, y, info, trans) + class(psb_s_oacc_csr_sparse_mat), intent(in) :: a + real(psb_spk_), intent(in) :: alpha, beta, x(:) + real(psb_spk_), intent(inout) :: y(:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + end subroutine psb_s_oacc_csr_csmv + end interface + + interface + module subroutine psb_s_oacc_csr_scals(d, a, info) + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + real(psb_spk_), intent(in) :: d + integer(psb_ipk_), intent(out) :: info + end subroutine psb_s_oacc_csr_scals + end interface + + interface + module subroutine psb_s_oacc_csr_scal(d,a,info,side) + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + real(psb_spk_), intent(in) :: d(:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: side + end subroutine psb_s_oacc_csr_scal + end interface + + interface + module subroutine psb_s_oacc_csr_reallocate_nz(nz,a) + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + integer(psb_ipk_), intent(in) :: nz + end subroutine psb_s_oacc_csr_reallocate_nz + end interface + + interface + module subroutine psb_s_oacc_csr_allocate_mnnz(m,n,a,nz) + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + integer(psb_ipk_), intent(in) :: m,n + integer(psb_ipk_), intent(in), optional :: nz + end subroutine psb_s_oacc_csr_allocate_mnnz + end interface + + interface + module subroutine psb_s_oacc_csr_cp_from_coo(a,b,info) + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_s_coo_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_s_oacc_csr_cp_from_coo + end interface + +contains + + + subroutine s_oacc_csr_free(a) + use psb_base_mod + implicit none + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + integer(psb_ipk_) :: info + + if (allocated(a%val)) then + !$acc exit data delete(a%val) + end if + if (allocated(a%ja)) then + !$acc exit data delete(a%ja) + end if + if (allocated(a%irp)) then + !$acc exit data delete(a%irp) + end if + + call a%psb_s_csr_sparse_mat%free() + + return + end subroutine s_oacc_csr_free + + function s_oacc_csr_sizeof(a) result(res) + implicit none + class(psb_s_oacc_csr_sparse_mat), intent(in) :: a + integer(psb_epk_) :: res + + if (a%is_dev()) call a%sync() + + res = 8 + res = res + psb_sizeof_sp * size(a%val) + res = res + psb_sizeof_ip * size(a%irp) + res = res + psb_sizeof_ip * size(a%ja) + + end function s_oacc_csr_sizeof + + + function s_oacc_csr_get_fmt() result(res) + implicit none + character(len=5) :: res + res = 'CSR_oacc' + end function s_oacc_csr_get_fmt + + subroutine s_oacc_csr_all(m, n, nz, a, info) + implicit none + integer(psb_ipk_), intent(in) :: m, n, nz + class(psb_s_oacc_csr_sparse_mat), intent(out) :: a + integer(psb_ipk_), intent(out) :: info + + info = 0 + if (allocated(a%val)) then + !$acc exit data delete(a%val) finalize + deallocate(a%val, stat=info) + end if + if (allocated(a%ja)) then + !$acc exit data delete(a%ja) finalize + deallocate(a%ja, stat=info) + end if + if (allocated(a%irp)) then + !$acc exit data delete(a%irp) finalize + deallocate(a%irp, stat=info) + end if + + call a%set_nrows(m) + call a%set_ncols(n) + + allocate(a%val(nz),stat=info) + allocate(a%ja(nz),stat=info) + allocate(a%irp(m+1),stat=info) + if (info == 0) call a%set_host() + if (info == 0) call a%sync_space() + end subroutine s_oacc_csr_all + + function s_oacc_csr_is_host(a) result(res) + implicit none + class(psb_s_oacc_csr_sparse_mat), intent(in) :: a + logical :: res + + res = (a%devstate == is_host) + end function s_oacc_csr_is_host + + function s_oacc_csr_is_sync(a) result(res) + implicit none + class(psb_s_oacc_csr_sparse_mat), intent(in) :: a + logical :: res + + res = (a%devstate == is_sync) + end function s_oacc_csr_is_sync + + function s_oacc_csr_is_dev(a) result(res) + implicit none + class(psb_s_oacc_csr_sparse_mat), intent(in) :: a + logical :: res + + res = (a%devstate == is_dev) + end function s_oacc_csr_is_dev + + subroutine s_oacc_csr_set_host(a) + implicit none + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + + a%devstate = is_host + end subroutine s_oacc_csr_set_host + + subroutine s_oacc_csr_set_sync(a) + implicit none + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + + a%devstate = is_sync + end subroutine s_oacc_csr_set_sync + + subroutine s_oacc_csr_set_dev(a) + implicit none + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + + a%devstate = is_dev + end subroutine s_oacc_csr_set_dev + + subroutine s_oacc_csr_sync_space(a) + implicit none + class(psb_s_oacc_csr_sparse_mat), intent(inout) :: a + if (allocated(a%val)) then + call s_oacc_create_dev(a%val) + end if + if (allocated(a%ja)) then + call i_oacc_create_dev(a%ja) + end if + if (allocated(a%irp)) then + call i_oacc_create_dev(a%irp) + end if + contains + subroutine s_oacc_create_dev(v) + implicit none + real(psb_spk_), intent(in) :: v(:) + !$acc enter data copyin(v) + end subroutine s_oacc_create_dev + subroutine i_oacc_create_dev(v) + implicit none + integer(psb_ipk_), intent(in) :: v(:) + !$acc enter data copyin(v) + end subroutine i_oacc_create_dev + end subroutine s_oacc_csr_sync_space + + subroutine s_oacc_csr_sync(a) + implicit none + class(psb_s_oacc_csr_sparse_mat), target, intent(in) :: a + class(psb_s_oacc_csr_sparse_mat), pointer :: tmpa + integer(psb_ipk_) :: info + + tmpa => a + if (a%is_dev()) then + call s_oacc_csr_to_host(a%val) + call i_oacc_csr_to_host(a%ja) + call i_oacc_csr_to_host(a%irp) + else if (a%is_host()) then + call s_oacc_csr_to_dev(a%val) + call i_oacc_csr_to_dev(a%ja) + call i_oacc_csr_to_dev(a%irp) + end if + call tmpa%set_sync() + end subroutine s_oacc_csr_sync + + subroutine s_oacc_csr_to_dev(v) + implicit none + real(psb_spk_), intent(in) :: v(:) + !$acc update device(v) + end subroutine s_oacc_csr_to_dev + + subroutine s_oacc_csr_to_host(v) + implicit none + real(psb_spk_), intent(in) :: v(:) + !$acc update self(v) + end subroutine s_oacc_csr_to_host + + subroutine i_oacc_csr_to_dev(v) + implicit none + integer(psb_ipk_), intent(in) :: v(:) + !$acc update device(v) + end subroutine i_oacc_csr_to_dev + + subroutine i_oacc_csr_to_host(v) + implicit none + integer(psb_ipk_), intent(in) :: v(:) + !$acc update self(v) + end subroutine i_oacc_csr_to_host + +end module psb_s_oacc_csr_mat_mod + diff --git a/openacc/psb_s_oacc_vect_mod.F90 b/openacc/psb_s_oacc_vect_mod.F90 new file mode 100644 index 00000000..36ae7da8 --- /dev/null +++ b/openacc/psb_s_oacc_vect_mod.F90 @@ -0,0 +1,935 @@ +module psb_s_oacc_vect_mod + use iso_c_binding + use psb_const_mod + use psb_error_mod + use psb_s_vect_mod + use psb_i_vect_mod + use psb_i_oacc_vect_mod + + integer(psb_ipk_), parameter, private :: is_host = -1 + integer(psb_ipk_), parameter, private :: is_sync = 0 + integer(psb_ipk_), parameter, private :: is_dev = 1 + + type, extends(psb_s_base_vect_type) :: psb_s_vect_oacc + integer :: state = is_host + + contains + procedure, pass(x) :: get_nrows => s_oacc_get_nrows + procedure, nopass :: get_fmt => s_oacc_get_fmt + + procedure, pass(x) :: all => s_oacc_vect_all + procedure, pass(x) :: zero => s_oacc_zero + procedure, pass(x) :: asb_m => s_oacc_asb_m + procedure, pass(x) :: sync => s_oacc_sync + procedure, pass(x) :: sync_space => s_oacc_sync_space + procedure, pass(x) :: bld_x => s_oacc_bld_x + procedure, pass(x) :: bld_mn => s_oacc_bld_mn + procedure, pass(x) :: free => s_oacc_vect_free + procedure, pass(x) :: ins_a => s_oacc_ins_a + procedure, pass(x) :: ins_v => s_oacc_ins_v + procedure, pass(x) :: is_host => s_oacc_is_host + procedure, pass(x) :: is_dev => s_oacc_is_dev + procedure, pass(x) :: is_sync => s_oacc_is_sync + procedure, pass(x) :: set_host => s_oacc_set_host + procedure, pass(x) :: set_dev => s_oacc_set_dev + procedure, pass(x) :: set_sync => s_oacc_set_sync + procedure, pass(x) :: set_scal => s_oacc_set_scal + + procedure, pass(x) :: gthzv_x => s_oacc_gthzv_x + procedure, pass(x) :: gthzbuf_x => s_oacc_gthzbuf + procedure, pass(y) :: sctb => s_oacc_sctb + procedure, pass(y) :: sctb_x => s_oacc_sctb_x + procedure, pass(y) :: sctb_buf => s_oacc_sctb_buf + + procedure, pass(x) :: get_size => s_oacc_get_size + procedure, pass(x) :: dot_v => s_oacc_vect_dot + procedure, pass(x) :: dot_a => s_oacc_dot_a + procedure, pass(y) :: axpby_v => s_oacc_axpby_v + procedure, pass(y) :: axpby_a => s_oacc_axpby_a + procedure, pass(z) :: abgdxyz => s_oacc_abgdxyz + procedure, pass(y) :: mlt_a => s_oacc_mlt_a + procedure, pass(z) :: mlt_a_2 => s_oacc_mlt_a_2 + procedure, pass(y) :: mlt_v => s_oacc_mlt_v + procedure, pass(z) :: mlt_v_2 => s_oacc_mlt_v_2 + procedure, pass(x) :: scal => s_oacc_scal + procedure, pass(x) :: nrm2 => s_oacc_nrm2 + procedure, pass(x) :: amax => s_oacc_amax + procedure, pass(x) :: asum => s_oacc_asum + procedure, pass(x) :: absval1 => s_oacc_absval1 + procedure, pass(x) :: absval2 => s_oacc_absval2 + + end type psb_s_vect_oacc + + interface + subroutine s_oacc_mlt_v(x, y, info) + import + implicit none + class(psb_s_base_vect_type), intent(inout) :: x + class(psb_s_vect_oacc), intent(inout) :: y + integer(psb_ipk_), intent(out) :: info + end subroutine s_oacc_mlt_v + end interface + + + interface + subroutine s_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) + import + implicit none + real(psb_spk_), intent(in) :: alpha, beta + class(psb_s_base_vect_type), intent(inout) :: x + class(psb_s_base_vect_type), intent(inout) :: y + class(psb_s_vect_oacc), intent(inout) :: z + integer(psb_ipk_), intent(out) :: info + character(len=1), intent(in), optional :: conjgx, conjgy + end subroutine s_oacc_mlt_v_2 + end interface + +contains + + subroutine s_oacc_absval1(x) + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + integer(psb_ipk_) :: n, i + + if (x%is_host()) call x%sync_space() + n = size(x%v) + !$acc parallel loop + do i = 1, n + x%v(i) = abs(x%v(i)) + end do + call x%set_dev() + end subroutine s_oacc_absval1 + + subroutine s_oacc_absval2(x, y) + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + class(psb_s_base_vect_type), intent(inout) :: y + integer(psb_ipk_) :: n + integer(psb_ipk_) :: i + + n = min(size(x%v), size(y%v)) + select type (yy => y) + class is (psb_s_vect_oacc) + if (x%is_host()) call x%sync() + if (yy%is_host()) call yy%sync() + !$acc parallel loop + do i = 1, n + yy%v(i) = abs(x%v(i)) + end do + class default + if (x%is_dev()) call x%sync() + if (y%is_dev()) call y%sync() + call x%psb_s_base_vect_type%absval(y) + end select + end subroutine s_oacc_absval2 + + subroutine s_oacc_scal(alpha, x) + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + real(psb_spk_), intent(in) :: alpha + integer(psb_ipk_) :: info + integer(psb_ipk_) :: i + + if (x%is_host()) call x%sync_space() + !$acc parallel loop + do i = 1, size(x%v) + x%v(i) = alpha * x%v(i) + end do + call x%set_dev() + end subroutine s_oacc_scal + + function s_oacc_nrm2(n, x) result(res) + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + real(psb_spk_) :: res + integer(psb_ipk_) :: info + real(psb_spk_) :: sum + integer(psb_ipk_) :: i + + if (x%is_host()) call x%sync_space() + sum = 0.0 + !$acc parallel loop reduction(+:sum) + do i = 1, n + sum = sum + abs(x%v(i))**2 + end do + res = sqrt(sum) + end function s_oacc_nrm2 + + function s_oacc_amax(n, x) result(res) + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + real(psb_spk_) :: res + integer(psb_ipk_) :: info + real(psb_spk_) :: max_val + integer(psb_ipk_) :: i + + if (x%is_host()) call x%sync_space() + max_val = -huge(0.0) + !$acc parallel loop reduction(max:max_val) + do i = 1, n + if (abs(x%v(i)) > max_val) max_val = abs(x%v(i)) + end do + res = max_val + end function s_oacc_amax + + function s_oacc_asum(n, x) result(res) + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + real(psb_spk_) :: res + integer(psb_ipk_) :: info + real(psb_spk_) :: sum + integer(psb_ipk_) :: i + + if (x%is_host()) call x%sync_space() + sum = 0.0 + !$acc parallel loop reduction(+:sum) + do i = 1, n + sum = sum + abs(x%v(i)) + end do + res = sum + end function s_oacc_asum + + + subroutine s_oacc_mlt_a(x, y, info) + implicit none + real(psb_spk_), intent(in) :: x(:) + class(psb_s_vect_oacc), intent(inout) :: y + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: i, n + + info = 0 + if (y%is_dev()) call y%sync_space() + !$acc parallel loop + do i = 1, size(x) + y%v(i) = y%v(i) * x(i) + end do + call y%set_host() + end subroutine s_oacc_mlt_a + + subroutine s_oacc_mlt_a_2(alpha, x, y, beta, z, info) + implicit none + real(psb_spk_), intent(in) :: alpha, beta + real(psb_spk_), intent(in) :: x(:) + real(psb_spk_), intent(in) :: y(:) + class(psb_s_vect_oacc), intent(inout) :: z + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: i, n + + info = 0 + if (z%is_dev()) call z%sync_space() + !$acc parallel loop + do i = 1, size(x) + z%v(i) = alpha * x(i) * y(i) + beta * z%v(i) + end do + call z%set_host() + end subroutine s_oacc_mlt_a_2 + + +!!$ subroutine s_oacc_mlt_v(x, y, info) +!!$ implicit none +!!$ class(psb_s_base_vect_type), intent(inout) :: x +!!$ class(psb_s_vect_oacc), intent(inout) :: y +!!$ integer(psb_ipk_), intent(out) :: info +!!$ +!!$ integer(psb_ipk_) :: i, n +!!$ +!!$ info = 0 +!!$ n = min(x%get_nrows(), y%get_nrows()) +!!$ select type(xx => x) +!!$ type is (psb_s_base_vect_type) +!!$ if (y%is_dev()) call y%sync() +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ y%v(i) = y%v(i) * xx%v(i) +!!$ end do +!!$ call y%set_host() +!!$ class default +!!$ if (xx%is_dev()) call xx%sync() +!!$ if (y%is_dev()) call y%sync() +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ y%v(i) = y%v(i) * xx%v(i) +!!$ end do +!!$ call y%set_host() +!!$ end select +!!$ end subroutine s_oacc_mlt_v +!!$ +!!$ subroutine s_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) +!!$ use psi_serial_mod +!!$ use psb_string_mod +!!$ implicit none +!!$ real(psb_spk_), intent(in) :: alpha, beta +!!$ class(psb_s_base_vect_type), intent(inout) :: x +!!$ class(psb_s_base_vect_type), intent(inout) :: y +!!$ class(psb_s_vect_oacc), intent(inout) :: z +!!$ integer(psb_ipk_), intent(out) :: info +!!$ character(len=1), intent(in), optional :: conjgx, conjgy +!!$ integer(psb_ipk_) :: i, n +!!$ logical :: conjgx_, conjgy_ +!!$ +!!$ conjgx_ = .false. +!!$ conjgy_ = .false. +!!$ if (present(conjgx)) conjgx_ = (psb_toupper(conjgx) == 'C') +!!$ if (present(conjgy)) conjgy_ = (psb_toupper(conjgy) == 'C') +!!$ +!!$ n = min(x%get_nrows(), y%get_nrows(), z%get_nrows()) +!!$ +!!$ info = 0 +!!$ select type(xx => x) +!!$ class is (psb_s_vect_oacc) +!!$ select type (yy => y) +!!$ class is (psb_s_vect_oacc) +!!$ if (xx%is_host()) call xx%sync_space() +!!$ if (yy%is_host()) call yy%sync_space() +!!$ if ((beta /= szero) .and. (z%is_host())) call z%sync_space() +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) +!!$ end do +!!$ call z%set_dev() +!!$ class default +!!$ if (xx%is_dev()) call xx%sync_space() +!!$ if (yy%is_dev()) call yy%sync() +!!$ if ((beta /= szero) .and. (z%is_dev())) call z%sync_space() +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) +!!$ end do +!!$ call z%set_host() +!!$ end select +!!$ class default +!!$ if (x%is_dev()) call x%sync() +!!$ if (y%is_dev()) call y%sync() +!!$ if ((beta /= szero) .and. (z%is_dev())) call z%sync_space() +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ z%v(i) = alpha * x%v(i) * y%v(i) + beta * z%v(i) +!!$ end do +!!$ call z%set_host() +!!$ end select +!!$ end subroutine s_oacc_mlt_v_2 + + + subroutine s_oacc_axpby_v(m, alpha, x, beta, y, info) + !use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_s_base_vect_type), intent(inout) :: x + class(psb_s_vect_oacc), intent(inout) :: y + real(psb_spk_), intent(in) :: alpha, beta + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: nx, ny, i + + info = psb_success_ + + select type(xx => x) + type is (psb_s_vect_oacc) + if ((beta /= szero) .and. y%is_host()) call y%sync_space() + if (xx%is_host()) call xx%sync_space() + nx = size(xx%v) + ny = size(y%v) + if ((nx < m) .or. (ny < m)) then + info = psb_err_internal_error_ + else + !$acc parallel loop + do i = 1, m + y%v(i) = alpha * xx%v(i) + beta * y%v(i) + end do + end if + call y%set_dev() + class default + if ((alpha /= szero) .and. (x%is_dev())) call x%sync() + call y%axpby(m, alpha, x%v, beta, info) + end select + end subroutine s_oacc_axpby_v + + subroutine s_oacc_axpby_a(m, alpha, x, beta, y, info) + !use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + real(psb_spk_), intent(in) :: x(:) + class(psb_s_vect_oacc), intent(inout) :: y + real(psb_spk_), intent(in) :: alpha, beta + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: i + + if ((beta /= szero) .and. (y%is_dev())) call y%sync_space() + !$acc parallel loop + do i = 1, m + y%v(i) = alpha * x(i) + beta * y%v(i) + end do + call y%set_host() + end subroutine s_oacc_axpby_a + + subroutine s_oacc_abgdxyz(m, alpha, beta, gamma, delta, x, y, z, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_s_base_vect_type), intent(inout) :: x + class(psb_s_base_vect_type), intent(inout) :: y + class(psb_s_vect_oacc), intent(inout) :: z + real(psb_spk_), intent(in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: nx, ny, nz, i + logical :: gpu_done + + info = psb_success_ + gpu_done = .false. + + select type(xx => x) + class is (psb_s_vect_oacc) + select type(yy => y) + class is (psb_s_vect_oacc) + select type(zz => z) + class is (psb_s_vect_oacc) + if ((beta /= szero) .and. yy%is_host()) call yy%sync_space() + if ((delta /= szero) .and. zz%is_host()) call zz%sync_space() + if (xx%is_host()) call xx%sync_space() + nx = size(xx%v) + ny = size(yy%v) + nz = size(zz%v) + if ((nx < m) .or. (ny < m) .or. (nz < m)) then + info = psb_err_internal_error_ + else + !$acc parallel loop + do i = 1, m + yy%v(i) = alpha * xx%v(i) + beta * yy%v(i) + zz%v(i) = gamma * yy%v(i) + delta * zz%v(i) + end do + end if + call yy%set_dev() + call zz%set_dev() + gpu_done = .true. + end select + end select + end select + + if (.not. gpu_done) then + if (x%is_host()) call x%sync() + if (y%is_host()) call y%sync() + if (z%is_host()) call z%sync() + call y%axpby(m, alpha, x, beta, info) + call z%axpby(m, gamma, y, delta, info) + end if + end subroutine s_oacc_abgdxyz + + subroutine s_oacc_sctb_buf(i, n, idx, beta, y) + use psb_base_mod + implicit none + integer(psb_ipk_) :: i, n + class(psb_i_base_vect_type) :: idx + real(psb_spk_) :: beta + class(psb_s_vect_oacc) :: y + integer(psb_ipk_) :: info + + if (.not.allocated(y%combuf)) then + call psb_errpush(psb_err_alloc_dealloc_, 'sctb_buf') + return + end if + + select type(ii => idx) + class is (psb_i_vect_oacc) + if (ii%is_host()) call ii%sync_space(info) + if (y%is_host()) call y%sync_space() + + !$acc parallel loop + do i = 1, n + y%v(ii%v(i)) = beta * y%v(ii%v(i)) + y%combuf(i) + end do + + class default + !$acc parallel loop + do i = 1, n + y%v(idx%v(i)) = beta * y%v(idx%v(i)) + y%combuf(i) + end do + end select + end subroutine s_oacc_sctb_buf + + subroutine s_oacc_sctb_x(i, n, idx, x, beta, y) + use psb_base_mod + implicit none + integer(psb_ipk_):: i, n + class(psb_i_base_vect_type) :: idx + real(psb_spk_) :: beta, x(:) + class(psb_s_vect_oacc) :: y + integer(psb_ipk_) :: info, ni + + select type(ii => idx) + class is (psb_i_vect_oacc) + if (ii%is_host()) call ii%sync_space(info) + class default + call psb_errpush(info, 's_oacc_sctb_x') + return + end select + + if (y%is_host()) call y%sync_space() + + !$acc parallel loop + do i = 1, n + y%v(idx%v(i)) = beta * y%v(idx%v(i)) + x(i) + end do + + call y%set_dev() + end subroutine s_oacc_sctb_x + + + + subroutine s_oacc_sctb(n, idx, x, beta, y) + use psb_base_mod + implicit none + integer(psb_ipk_) :: n + integer(psb_ipk_) :: idx(:) + real(psb_spk_) :: beta, x(:) + class(psb_s_vect_oacc) :: y + integer(psb_ipk_) :: info + integer(psb_ipk_) :: i + + if (n == 0) return + if (y%is_dev()) call y%sync_space() + + !$acc parallel loop + do i = 1, n + y%v(idx(i)) = beta * y%v(idx(i)) + x(i) + end do + + call y%set_host() + end subroutine s_oacc_sctb + + + subroutine s_oacc_gthzbuf(i, n, idx, x) + use psb_base_mod + implicit none + integer(psb_ipk_) :: i, n + class(psb_i_base_vect_type) :: idx + class(psb_s_vect_oacc) :: x + integer(psb_ipk_) :: info + + info = 0 + if (.not.allocated(x%combuf)) then + call psb_errpush(psb_err_alloc_dealloc_, 'gthzbuf') + return + end if + + select type(ii => idx) + class is (psb_i_vect_oacc) + if (ii%is_host()) call ii%sync_space(info) + class default + call psb_errpush(info, 's_oacc_gthzbuf') + return + end select + + if (x%is_host()) call x%sync_space() + + !$acc parallel loop + do i = 1, n + x%combuf(i) = x%v(idx%v(i)) + end do + end subroutine s_oacc_gthzbuf + + subroutine s_oacc_gthzv_x(i, n, idx, x, y) + use psb_base_mod + implicit none + integer(psb_ipk_) :: i, n + class(psb_i_base_vect_type):: idx + real(psb_spk_) :: y(:) + class(psb_s_vect_oacc):: x + integer(psb_ipk_) :: info + + info = 0 + + select type(ii => idx) + class is (psb_i_vect_oacc) + if (ii%is_host()) call ii%sync_space(info) + class default + call psb_errpush(info, 's_oacc_gthzv_x') + return + end select + + if (x%is_host()) call x%sync_space() + + !$acc parallel loop + do i = 1, n + y(i) = x%v(idx%v(i)) + end do + end subroutine s_oacc_gthzv_x + + subroutine s_oacc_ins_v(n, irl, val, dupl, x, info) + use psi_serial_mod + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n, dupl + class(psb_i_base_vect_type), intent(inout) :: irl + class(psb_s_base_vect_type), intent(inout) :: val + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i, isz + logical :: done_oacc + + info = 0 + if (psb_errstatus_fatal()) return + + done_oacc = .false. + select type(virl => irl) + type is (psb_i_vect_oacc) + select type(vval => val) + type is (psb_s_vect_oacc) + if (vval%is_host()) call vval%sync_space() + if (virl%is_host()) call virl%sync_space(info) + if (x%is_host()) call x%sync_space() + !$acc parallel loop + do i = 1, n + x%v(virl%v(i)) = vval%v(i) + end do + call x%set_dev() + done_oacc = .true. + end select + end select + + if (.not.done_oacc) then + select type(virl => irl) + type is (psb_i_vect_oacc) + if (virl%is_dev()) call virl%sync_space(info) + end select + select type(vval => val) + type is (psb_s_vect_oacc) + if (vval%is_dev()) call vval%sync_space() + end select + call x%ins(n, irl%v, val%v, dupl, info) + end if + + if (info /= 0) then + call psb_errpush(info, 'oacc_vect_ins') + return + end if + + end subroutine s_oacc_ins_v + + + + subroutine s_oacc_ins_a(n, irl, val, dupl, x, info) + use psi_serial_mod + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n, dupl + integer(psb_ipk_), intent(in) :: irl(:) + real(psb_spk_), intent(in) :: val(:) + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i + + info = 0 + if (x%is_dev()) call x%sync_space() + call x%psb_s_base_vect_type%ins(n, irl, val, dupl, info) + call x%set_host() + !$acc update device(x%v) + + end subroutine s_oacc_ins_a + + + + subroutine s_oacc_bld_mn(x, n) + use psb_base_mod + implicit none + integer(psb_mpk_), intent(in) :: n + class(psb_s_vect_oacc), intent(inout) :: x + integer(psb_ipk_) :: info + + call x%all(n, info) + if (info /= 0) then + call psb_errpush(info, 's_oacc_bld_mn', i_err=(/n, n, n, n, n/)) + end if + call x%set_host() + !$acc update device(x%v) + + end subroutine s_oacc_bld_mn + + + subroutine s_oacc_bld_x(x, this) + use psb_base_mod + implicit none + real(psb_spk_), intent(in) :: this(:) + class(psb_s_vect_oacc), intent(inout) :: x + integer(psb_ipk_) :: info + + call psb_realloc(size(this), x%v, info) + if (info /= 0) then + info = psb_err_alloc_request_ + call psb_errpush(info, 's_oacc_bld_x', & + i_err=(/size(this), izero, izero, izero, izero/)) + return + end if + + x%v(:) = this(:) + call x%set_host() + !$acc update device(x%v) + + end subroutine s_oacc_bld_x + + + subroutine s_oacc_asb_m(n, x, info) + use psb_base_mod + implicit none + integer(psb_mpk_), intent(in) :: n + class(psb_s_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + integer(psb_mpk_) :: nd + + info = psb_success_ + + if (x%is_dev()) then + nd = size(x%v) + if (nd < n) then + call x%sync() + call x%psb_s_base_vect_type%asb(n, info) + if (info == psb_success_) call x%sync_space() + call x%set_host() + end if + else + if (size(x%v) < n) then + call x%psb_s_base_vect_type%asb(n, info) + if (info == psb_success_) call x%sync_space() + call x%set_host() + end if + end if + end subroutine s_oacc_asb_m + + + + subroutine s_oacc_set_scal(x, val, first, last) + class(psb_s_vect_oacc), intent(inout) :: x + real(psb_spk_), intent(in) :: val + integer(psb_ipk_), optional :: first, last + + integer(psb_ipk_) :: first_, last_ + first_ = 1 + last_ = x%get_nrows() + if (present(first)) first_ = max(1, first) + if (present(last)) last_ = min(last, last_) + + !$acc parallel loop + do i = first_, last_ + x%v(i) = val + end do + !$acc end parallel loop + + call x%set_dev() + end subroutine s_oacc_set_scal + + + + subroutine s_oacc_zero(x) + use psi_serial_mod + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + call x%set_dev() + call x%set_scal(szero) + end subroutine s_oacc_zero + + function s_oacc_get_nrows(x) result(res) + implicit none + class(psb_s_vect_oacc), intent(in) :: x + integer(psb_ipk_) :: res + + if (allocated(x%v)) res = size(x%v) + end function s_oacc_get_nrows + + function s_oacc_get_fmt() result(res) + implicit none + character(len=5) :: res + res = "sOACC" + + end function s_oacc_get_fmt + + function s_oacc_vect_dot(n, x, y) result(res) + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + class(psb_s_base_vect_type), intent(inout) :: y + integer(psb_ipk_), intent(in) :: n + real(psb_spk_) :: res + real(psb_spk_), external :: ddot + integer(psb_ipk_) :: info + integer(psb_ipk_) :: i + + res = szero + + select type(yy => y) + type is (psb_s_base_vect_type) + if (x%is_dev()) call x%sync() + res = ddot(n, x%v, 1, yy%v, 1) + type is (psb_s_vect_oacc) + if (x%is_host()) call x%sync() + if (yy%is_host()) call yy%sync() + + !$acc parallel loop reduction(+:res) present(x%v, yy%v) + do i = 1, n + res = res + x%v(i) * yy%v(i) + end do + !$acc end parallel loop + + class default + call x%sync() + res = y%dot(n, x%v) + end select + + end function s_oacc_vect_dot + + + + + function s_oacc_dot_a(n, x, y) result(res) + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + real(psb_spk_), intent(in) :: y(:) + integer(psb_ipk_), intent(in) :: n + real(psb_spk_) :: res + real(psb_spk_), external :: ddot + + if (x%is_dev()) call x%sync() + res = ddot(n, y, 1, x%v, 1) + + end function s_oacc_dot_a + + ! subroutine s_oacc_set_vect(x,y) + ! implicit none + ! class(psb_s_vect_oacc), intent(inout) :: x + ! real(psb_spk_), intent(in) :: y(:) + ! integer(psb_ipk_) :: info + + ! if (size(x%v) /= size(y)) then + ! call x%free(info) + ! call x%all(size(y),info) + ! end if + ! x%v(:) = y(:) + ! call x%set_host() + ! end subroutine s_oacc_set_vect + + subroutine s_oacc_to_dev(v) + implicit none + real(psb_spk_) :: v(:) + !$acc update device(v) + end subroutine s_oacc_to_dev + + subroutine s_oacc_to_host(v) + implicit none + real(psb_spk_) :: v(:) + !$acc update self(v) + end subroutine s_oacc_to_host + + subroutine s_oacc_sync_space(x) + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + if (allocated(x%v)) then + call s_oacc_create_dev(x%v) + end if + contains + subroutine s_oacc_create_dev(v) + implicit none + real(psb_spk_) :: v(:) + !$acc enter data copyin(v) + end subroutine s_oacc_create_dev + end subroutine s_oacc_sync_space + + subroutine s_oacc_sync(x) + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + if (x%is_dev()) then + call s_oacc_to_host(x%v) + end if + if (x%is_host()) then + call s_oacc_to_dev(x%v) + end if + call x%set_sync() + end subroutine s_oacc_sync + + subroutine s_oacc_set_host(x) + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + + x%state = is_host + end subroutine s_oacc_set_host + + subroutine s_oacc_set_dev(x) + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + + x%state = is_dev + end subroutine s_oacc_set_dev + + subroutine s_oacc_set_sync(x) + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + + x%state = is_sync + end subroutine s_oacc_set_sync + + function s_oacc_is_dev(x) result(res) + implicit none + class(psb_s_vect_oacc), intent(in) :: x + logical :: res + + res = (x%state == is_dev) + end function s_oacc_is_dev + + function s_oacc_is_host(x) result(res) + implicit none + class(psb_s_vect_oacc), intent(in) :: x + logical :: res + + res = (x%state == is_host) + end function s_oacc_is_host + + function s_oacc_is_sync(x) result(res) + implicit none + class(psb_s_vect_oacc), intent(in) :: x + logical :: res + + res = (x%state == is_sync) + end function s_oacc_is_sync + + subroutine s_oacc_vect_all(n, x, info) + use psi_serial_mod + use psb_realloc_mod + implicit none + integer(psb_ipk_), intent(in) :: n + class(psb_s_vect_oacc), intent(out) :: x + integer(psb_ipk_), intent(out) :: info + + call psb_realloc(n, x%v, info) + if (info == 0) then + call x%set_host() + !$acc enter data create(x%v) + call x%sync_space() + end if + if (info /= 0) then + info = psb_err_alloc_request_ + call psb_errpush(info, 's_oacc_all', & + i_err=(/n, n, n, n, n/)) + end if + end subroutine s_oacc_vect_all + + + subroutine s_oacc_vect_free(x, info) + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + info = 0 + if (allocated(x%v)) then + !$acc exit data delete(x%v) finalize + deallocate(x%v, stat=info) + end if + + end subroutine s_oacc_vect_free + + function s_oacc_get_size(x) result(res) + implicit none + class(psb_s_vect_oacc), intent(inout) :: x + integer(psb_ipk_) :: res + + if (x%is_dev()) call x%sync() + res = size(x%v) + end function s_oacc_get_size + +end module psb_s_oacc_vect_mod diff --git a/openacc/psb_z_oacc_csr_mat_mod.F90 b/openacc/psb_z_oacc_csr_mat_mod.F90 new file mode 100644 index 00000000..7842d96c --- /dev/null +++ b/openacc/psb_z_oacc_csr_mat_mod.F90 @@ -0,0 +1,343 @@ +module psb_z_oacc_csr_mat_mod + + use iso_c_binding + use psb_z_mat_mod + use psb_z_oacc_vect_mod + !use oaccsparse_mod + + integer(psb_ipk_), parameter, private :: is_host = -1 + integer(psb_ipk_), parameter, private :: is_sync = 0 + integer(psb_ipk_), parameter, private :: is_dev = 1 + + type, extends(psb_z_csr_sparse_mat) :: psb_z_oacc_csr_sparse_mat + integer(psb_ipk_) :: devstate = is_host + contains + procedure, nopass :: get_fmt => z_oacc_csr_get_fmt + procedure, pass(a) :: sizeof => z_oacc_csr_sizeof + procedure, pass(a) :: vect_mv => psb_z_oacc_csr_vect_mv + procedure, pass(a) :: in_vect_sv => psb_z_oacc_csr_inner_vect_sv + procedure, pass(a) :: csmm => psb_z_oacc_csr_csmm + procedure, pass(a) :: csmv => psb_z_oacc_csr_csmv + procedure, pass(a) :: scals => psb_z_oacc_csr_scals + procedure, pass(a) :: scalv => psb_z_oacc_csr_scal + procedure, pass(a) :: reallocate_nz => psb_z_oacc_csr_reallocate_nz + procedure, pass(a) :: allocate_mnnz => psb_z_oacc_csr_allocate_mnnz + procedure, pass(a) :: cp_from_coo => psb_z_oacc_csr_cp_from_coo + procedure, pass(a) :: cp_from_fmt => psb_z_oacc_csr_cp_from_fmt + procedure, pass(a) :: mv_from_coo => psb_z_oacc_csr_mv_from_coo + procedure, pass(a) :: mv_from_fmt => psb_z_oacc_csr_mv_from_fmt + procedure, pass(a) :: free => z_oacc_csr_free + procedure, pass(a) :: mold => psb_z_oacc_csr_mold + procedure, pass(a) :: all => z_oacc_csr_all + procedure, pass(a) :: is_host => z_oacc_csr_is_host + procedure, pass(a) :: is_sync => z_oacc_csr_is_sync + procedure, pass(a) :: is_dev => z_oacc_csr_is_dev + procedure, pass(a) :: set_host => z_oacc_csr_set_host + procedure, pass(a) :: set_sync => z_oacc_csr_set_sync + procedure, pass(a) :: set_dev => z_oacc_csr_set_dev + procedure, pass(a) :: sync_space => z_oacc_csr_sync_space + procedure, pass(a) :: sync => z_oacc_csr_sync + end type psb_z_oacc_csr_sparse_mat + + interface + module subroutine psb_z_oacc_csr_mold(a,b,info) + class(psb_z_oacc_csr_sparse_mat), intent(in) :: a + class(psb_z_base_sparse_mat), intent(inout), allocatable :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_z_oacc_csr_mold + end interface + + interface + module subroutine psb_z_oacc_csr_cp_from_fmt(a,b,info) + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_z_base_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_z_oacc_csr_cp_from_fmt + end interface + + interface + module subroutine psb_z_oacc_csr_mv_from_coo(a,b,info) + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_z_coo_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_z_oacc_csr_mv_from_coo + end interface + + interface + module subroutine psb_z_oacc_csr_mv_from_fmt(a,b,info) + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_z_base_sparse_mat), intent(inout) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_z_oacc_csr_mv_from_fmt + end interface + + interface + module subroutine psb_z_oacc_csr_vect_mv(alpha, a, x, beta, y, info, trans) + class(psb_z_oacc_csr_sparse_mat), intent(in) :: a + complex(psb_dpk_), intent(in) :: alpha, beta + class(psb_z_base_vect_type), intent(inout) :: x, y + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + end subroutine psb_z_oacc_csr_vect_mv + end interface + + interface + module subroutine psb_z_oacc_csr_inner_vect_sv(alpha, a, x, beta, y, info, trans) + class(psb_z_oacc_csr_sparse_mat), intent(in) :: a + complex(psb_dpk_), intent(in) :: alpha, beta + class(psb_z_base_vect_type), intent(inout) :: x,y + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + end subroutine psb_z_oacc_csr_inner_vect_sv + end interface + + interface + module subroutine psb_z_oacc_csr_csmm(alpha, a, x, beta, y, info, trans) + class(psb_z_oacc_csr_sparse_mat), intent(in) :: a + complex(psb_dpk_), intent(in) :: alpha, beta, x(:,:) + complex(psb_dpk_), intent(inout) :: y(:,:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + end subroutine psb_z_oacc_csr_csmm + end interface + + interface + module subroutine psb_z_oacc_csr_csmv(alpha, a, x, beta, y, info, trans) + class(psb_z_oacc_csr_sparse_mat), intent(in) :: a + complex(psb_dpk_), intent(in) :: alpha, beta, x(:) + complex(psb_dpk_), intent(inout) :: y(:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: trans + end subroutine psb_z_oacc_csr_csmv + end interface + + interface + module subroutine psb_z_oacc_csr_scals(d, a, info) + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + complex(psb_dpk_), intent(in) :: d + integer(psb_ipk_), intent(out) :: info + end subroutine psb_z_oacc_csr_scals + end interface + + interface + module subroutine psb_z_oacc_csr_scal(d,a,info,side) + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + complex(psb_dpk_), intent(in) :: d(:) + integer(psb_ipk_), intent(out) :: info + character, optional, intent(in) :: side + end subroutine psb_z_oacc_csr_scal + end interface + + interface + module subroutine psb_z_oacc_csr_reallocate_nz(nz,a) + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + integer(psb_ipk_), intent(in) :: nz + end subroutine psb_z_oacc_csr_reallocate_nz + end interface + + interface + module subroutine psb_z_oacc_csr_allocate_mnnz(m,n,a,nz) + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + integer(psb_ipk_), intent(in) :: m,n + integer(psb_ipk_), intent(in), optional :: nz + end subroutine psb_z_oacc_csr_allocate_mnnz + end interface + + interface + module subroutine psb_z_oacc_csr_cp_from_coo(a,b,info) + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + class(psb_z_coo_sparse_mat), intent(in) :: b + integer(psb_ipk_), intent(out) :: info + end subroutine psb_z_oacc_csr_cp_from_coo + end interface + +contains + + + subroutine z_oacc_csr_free(a) + use psb_base_mod + implicit none + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + integer(psb_ipk_) :: info + + if (allocated(a%val)) then + !$acc exit data delete(a%val) + end if + if (allocated(a%ja)) then + !$acc exit data delete(a%ja) + end if + if (allocated(a%irp)) then + !$acc exit data delete(a%irp) + end if + + call a%psb_z_csr_sparse_mat%free() + + return + end subroutine z_oacc_csr_free + + function z_oacc_csr_sizeof(a) result(res) + implicit none + class(psb_z_oacc_csr_sparse_mat), intent(in) :: a + integer(psb_epk_) :: res + + if (a%is_dev()) call a%sync() + + res = 8 + res = res + (2*psb_sizeof_dp) * size(a%val) + res = res + psb_sizeof_ip * size(a%irp) + res = res + psb_sizeof_ip * size(a%ja) + + end function z_oacc_csr_sizeof + + + function z_oacc_csr_get_fmt() result(res) + implicit none + character(len=5) :: res + res = 'CSR_oacc' + end function z_oacc_csr_get_fmt + + subroutine z_oacc_csr_all(m, n, nz, a, info) + implicit none + integer(psb_ipk_), intent(in) :: m, n, nz + class(psb_z_oacc_csr_sparse_mat), intent(out) :: a + integer(psb_ipk_), intent(out) :: info + + info = 0 + if (allocated(a%val)) then + !$acc exit data delete(a%val) finalize + deallocate(a%val, stat=info) + end if + if (allocated(a%ja)) then + !$acc exit data delete(a%ja) finalize + deallocate(a%ja, stat=info) + end if + if (allocated(a%irp)) then + !$acc exit data delete(a%irp) finalize + deallocate(a%irp, stat=info) + end if + + call a%set_nrows(m) + call a%set_ncols(n) + + allocate(a%val(nz),stat=info) + allocate(a%ja(nz),stat=info) + allocate(a%irp(m+1),stat=info) + if (info == 0) call a%set_host() + if (info == 0) call a%sync_space() + end subroutine z_oacc_csr_all + + function z_oacc_csr_is_host(a) result(res) + implicit none + class(psb_z_oacc_csr_sparse_mat), intent(in) :: a + logical :: res + + res = (a%devstate == is_host) + end function z_oacc_csr_is_host + + function z_oacc_csr_is_sync(a) result(res) + implicit none + class(psb_z_oacc_csr_sparse_mat), intent(in) :: a + logical :: res + + res = (a%devstate == is_sync) + end function z_oacc_csr_is_sync + + function z_oacc_csr_is_dev(a) result(res) + implicit none + class(psb_z_oacc_csr_sparse_mat), intent(in) :: a + logical :: res + + res = (a%devstate == is_dev) + end function z_oacc_csr_is_dev + + subroutine z_oacc_csr_set_host(a) + implicit none + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + + a%devstate = is_host + end subroutine z_oacc_csr_set_host + + subroutine z_oacc_csr_set_sync(a) + implicit none + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + + a%devstate = is_sync + end subroutine z_oacc_csr_set_sync + + subroutine z_oacc_csr_set_dev(a) + implicit none + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + + a%devstate = is_dev + end subroutine z_oacc_csr_set_dev + + subroutine z_oacc_csr_sync_space(a) + implicit none + class(psb_z_oacc_csr_sparse_mat), intent(inout) :: a + if (allocated(a%val)) then + call z_oacc_create_dev(a%val) + end if + if (allocated(a%ja)) then + call i_oacc_create_dev(a%ja) + end if + if (allocated(a%irp)) then + call i_oacc_create_dev(a%irp) + end if + contains + subroutine z_oacc_create_dev(v) + implicit none + complex(psb_dpk_), intent(in) :: v(:) + !$acc enter data copyin(v) + end subroutine z_oacc_create_dev + subroutine i_oacc_create_dev(v) + implicit none + integer(psb_ipk_), intent(in) :: v(:) + !$acc enter data copyin(v) + end subroutine i_oacc_create_dev + end subroutine z_oacc_csr_sync_space + + subroutine z_oacc_csr_sync(a) + implicit none + class(psb_z_oacc_csr_sparse_mat), target, intent(in) :: a + class(psb_z_oacc_csr_sparse_mat), pointer :: tmpa + integer(psb_ipk_) :: info + + tmpa => a + if (a%is_dev()) then + call z_oacc_csr_to_host(a%val) + call i_oacc_csr_to_host(a%ja) + call i_oacc_csr_to_host(a%irp) + else if (a%is_host()) then + call z_oacc_csr_to_dev(a%val) + call i_oacc_csr_to_dev(a%ja) + call i_oacc_csr_to_dev(a%irp) + end if + call tmpa%set_sync() + end subroutine z_oacc_csr_sync + + subroutine z_oacc_csr_to_dev(v) + implicit none + complex(psb_dpk_), intent(in) :: v(:) + !$acc update device(v) + end subroutine z_oacc_csr_to_dev + + subroutine z_oacc_csr_to_host(v) + implicit none + complex(psb_dpk_), intent(in) :: v(:) + !$acc update self(v) + end subroutine z_oacc_csr_to_host + + subroutine i_oacc_csr_to_dev(v) + implicit none + integer(psb_ipk_), intent(in) :: v(:) + !$acc update device(v) + end subroutine i_oacc_csr_to_dev + + subroutine i_oacc_csr_to_host(v) + implicit none + integer(psb_ipk_), intent(in) :: v(:) + !$acc update self(v) + end subroutine i_oacc_csr_to_host + +end module psb_z_oacc_csr_mat_mod + diff --git a/openacc/psb_z_oacc_vect_mod.F90 b/openacc/psb_z_oacc_vect_mod.F90 new file mode 100644 index 00000000..5d03b49d --- /dev/null +++ b/openacc/psb_z_oacc_vect_mod.F90 @@ -0,0 +1,935 @@ +module psb_z_oacc_vect_mod + use iso_c_binding + use psb_const_mod + use psb_error_mod + use psb_z_vect_mod + use psb_i_vect_mod + use psb_i_oacc_vect_mod + + integer(psb_ipk_), parameter, private :: is_host = -1 + integer(psb_ipk_), parameter, private :: is_sync = 0 + integer(psb_ipk_), parameter, private :: is_dev = 1 + + type, extends(psb_z_base_vect_type) :: psb_z_vect_oacc + integer :: state = is_host + + contains + procedure, pass(x) :: get_nrows => z_oacc_get_nrows + procedure, nopass :: get_fmt => z_oacc_get_fmt + + procedure, pass(x) :: all => z_oacc_vect_all + procedure, pass(x) :: zero => z_oacc_zero + procedure, pass(x) :: asb_m => z_oacc_asb_m + procedure, pass(x) :: sync => z_oacc_sync + procedure, pass(x) :: sync_space => z_oacc_sync_space + procedure, pass(x) :: bld_x => z_oacc_bld_x + procedure, pass(x) :: bld_mn => z_oacc_bld_mn + procedure, pass(x) :: free => z_oacc_vect_free + procedure, pass(x) :: ins_a => z_oacc_ins_a + procedure, pass(x) :: ins_v => z_oacc_ins_v + procedure, pass(x) :: is_host => z_oacc_is_host + procedure, pass(x) :: is_dev => z_oacc_is_dev + procedure, pass(x) :: is_sync => z_oacc_is_sync + procedure, pass(x) :: set_host => z_oacc_set_host + procedure, pass(x) :: set_dev => z_oacc_set_dev + procedure, pass(x) :: set_sync => z_oacc_set_sync + procedure, pass(x) :: set_scal => z_oacc_set_scal + + procedure, pass(x) :: gthzv_x => z_oacc_gthzv_x + procedure, pass(x) :: gthzbuf_x => z_oacc_gthzbuf + procedure, pass(y) :: sctb => z_oacc_sctb + procedure, pass(y) :: sctb_x => z_oacc_sctb_x + procedure, pass(y) :: sctb_buf => z_oacc_sctb_buf + + procedure, pass(x) :: get_size => z_oacc_get_size + procedure, pass(x) :: dot_v => z_oacc_vect_dot + procedure, pass(x) :: dot_a => z_oacc_dot_a + procedure, pass(y) :: axpby_v => z_oacc_axpby_v + procedure, pass(y) :: axpby_a => z_oacc_axpby_a + procedure, pass(z) :: abgdxyz => z_oacc_abgdxyz + procedure, pass(y) :: mlt_a => z_oacc_mlt_a + procedure, pass(z) :: mlt_a_2 => z_oacc_mlt_a_2 + procedure, pass(y) :: mlt_v => z_oacc_mlt_v + procedure, pass(z) :: mlt_v_2 => z_oacc_mlt_v_2 + procedure, pass(x) :: scal => z_oacc_scal + procedure, pass(x) :: nrm2 => z_oacc_nrm2 + procedure, pass(x) :: amax => z_oacc_amax + procedure, pass(x) :: asum => z_oacc_asum + procedure, pass(x) :: absval1 => z_oacc_absval1 + procedure, pass(x) :: absval2 => z_oacc_absval2 + + end type psb_z_vect_oacc + + interface + subroutine z_oacc_mlt_v(x, y, info) + import + implicit none + class(psb_z_base_vect_type), intent(inout) :: x + class(psb_z_vect_oacc), intent(inout) :: y + integer(psb_ipk_), intent(out) :: info + end subroutine z_oacc_mlt_v + end interface + + + interface + subroutine z_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) + import + implicit none + complex(psb_dpk_), intent(in) :: alpha, beta + class(psb_z_base_vect_type), intent(inout) :: x + class(psb_z_base_vect_type), intent(inout) :: y + class(psb_z_vect_oacc), intent(inout) :: z + integer(psb_ipk_), intent(out) :: info + character(len=1), intent(in), optional :: conjgx, conjgy + end subroutine z_oacc_mlt_v_2 + end interface + +contains + + subroutine z_oacc_absval1(x) + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + integer(psb_ipk_) :: n, i + + if (x%is_host()) call x%sync_space() + n = size(x%v) + !$acc parallel loop + do i = 1, n + x%v(i) = abs(x%v(i)) + end do + call x%set_dev() + end subroutine z_oacc_absval1 + + subroutine z_oacc_absval2(x, y) + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + class(psb_z_base_vect_type), intent(inout) :: y + integer(psb_ipk_) :: n + integer(psb_ipk_) :: i + + n = min(size(x%v), size(y%v)) + select type (yy => y) + class is (psb_z_vect_oacc) + if (x%is_host()) call x%sync() + if (yy%is_host()) call yy%sync() + !$acc parallel loop + do i = 1, n + yy%v(i) = abs(x%v(i)) + end do + class default + if (x%is_dev()) call x%sync() + if (y%is_dev()) call y%sync() + call x%psb_z_base_vect_type%absval(y) + end select + end subroutine z_oacc_absval2 + + subroutine z_oacc_scal(alpha, x) + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + complex(psb_dpk_), intent(in) :: alpha + integer(psb_ipk_) :: info + integer(psb_ipk_) :: i + + if (x%is_host()) call x%sync_space() + !$acc parallel loop + do i = 1, size(x%v) + x%v(i) = alpha * x%v(i) + end do + call x%set_dev() + end subroutine z_oacc_scal + + function z_oacc_nrm2(n, x) result(res) + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + real(psb_dpk_) :: res + integer(psb_ipk_) :: info + real(psb_dpk_) :: sum + integer(psb_ipk_) :: i + + if (x%is_host()) call x%sync_space() + sum = 0.0 + !$acc parallel loop reduction(+:sum) + do i = 1, n + sum = sum + abs(x%v(i))**2 + end do + res = sqrt(sum) + end function z_oacc_nrm2 + + function z_oacc_amax(n, x) result(res) + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + real(psb_dpk_) :: res + integer(psb_ipk_) :: info + real(psb_dpk_) :: max_val + integer(psb_ipk_) :: i + + if (x%is_host()) call x%sync_space() + max_val = -huge(0.0) + !$acc parallel loop reduction(max:max_val) + do i = 1, n + if (abs(x%v(i)) > max_val) max_val = abs(x%v(i)) + end do + res = max_val + end function z_oacc_amax + + function z_oacc_asum(n, x) result(res) + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n + real(psb_dpk_) :: res + integer(psb_ipk_) :: info + complex(psb_dpk_) :: sum + integer(psb_ipk_) :: i + + if (x%is_host()) call x%sync_space() + sum = 0.0 + !$acc parallel loop reduction(+:sum) + do i = 1, n + sum = sum + abs(x%v(i)) + end do + res = sum + end function z_oacc_asum + + + subroutine z_oacc_mlt_a(x, y, info) + implicit none + complex(psb_dpk_), intent(in) :: x(:) + class(psb_z_vect_oacc), intent(inout) :: y + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: i, n + + info = 0 + if (y%is_dev()) call y%sync_space() + !$acc parallel loop + do i = 1, size(x) + y%v(i) = y%v(i) * x(i) + end do + call y%set_host() + end subroutine z_oacc_mlt_a + + subroutine z_oacc_mlt_a_2(alpha, x, y, beta, z, info) + implicit none + complex(psb_dpk_), intent(in) :: alpha, beta + complex(psb_dpk_), intent(in) :: x(:) + complex(psb_dpk_), intent(in) :: y(:) + class(psb_z_vect_oacc), intent(inout) :: z + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: i, n + + info = 0 + if (z%is_dev()) call z%sync_space() + !$acc parallel loop + do i = 1, size(x) + z%v(i) = alpha * x(i) * y(i) + beta * z%v(i) + end do + call z%set_host() + end subroutine z_oacc_mlt_a_2 + + +!!$ subroutine z_oacc_mlt_v(x, y, info) +!!$ implicit none +!!$ class(psb_z_base_vect_type), intent(inout) :: x +!!$ class(psb_z_vect_oacc), intent(inout) :: y +!!$ integer(psb_ipk_), intent(out) :: info +!!$ +!!$ integer(psb_ipk_) :: i, n +!!$ +!!$ info = 0 +!!$ n = min(x%get_nrows(), y%get_nrows()) +!!$ select type(xx => x) +!!$ type is (psb_z_base_vect_type) +!!$ if (y%is_dev()) call y%sync() +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ y%v(i) = y%v(i) * xx%v(i) +!!$ end do +!!$ call y%set_host() +!!$ class default +!!$ if (xx%is_dev()) call xx%sync() +!!$ if (y%is_dev()) call y%sync() +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ y%v(i) = y%v(i) * xx%v(i) +!!$ end do +!!$ call y%set_host() +!!$ end select +!!$ end subroutine z_oacc_mlt_v +!!$ +!!$ subroutine z_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) +!!$ use psi_serial_mod +!!$ use psb_string_mod +!!$ implicit none +!!$ complex(psb_dpk_), intent(in) :: alpha, beta +!!$ class(psb_z_base_vect_type), intent(inout) :: x +!!$ class(psb_z_base_vect_type), intent(inout) :: y +!!$ class(psb_z_vect_oacc), intent(inout) :: z +!!$ integer(psb_ipk_), intent(out) :: info +!!$ character(len=1), intent(in), optional :: conjgx, conjgy +!!$ integer(psb_ipk_) :: i, n +!!$ logical :: conjgx_, conjgy_ +!!$ +!!$ conjgx_ = .false. +!!$ conjgy_ = .false. +!!$ if (present(conjgx)) conjgx_ = (psb_toupper(conjgx) == 'C') +!!$ if (present(conjgy)) conjgy_ = (psb_toupper(conjgy) == 'C') +!!$ +!!$ n = min(x%get_nrows(), y%get_nrows(), z%get_nrows()) +!!$ +!!$ info = 0 +!!$ select type(xx => x) +!!$ class is (psb_z_vect_oacc) +!!$ select type (yy => y) +!!$ class is (psb_z_vect_oacc) +!!$ if (xx%is_host()) call xx%sync_space() +!!$ if (yy%is_host()) call yy%sync_space() +!!$ if ((beta /= zzero) .and. (z%is_host())) call z%sync_space() +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) +!!$ end do +!!$ call z%set_dev() +!!$ class default +!!$ if (xx%is_dev()) call xx%sync_space() +!!$ if (yy%is_dev()) call yy%sync() +!!$ if ((beta /= zzero) .and. (z%is_dev())) call z%sync_space() +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) +!!$ end do +!!$ call z%set_host() +!!$ end select +!!$ class default +!!$ if (x%is_dev()) call x%sync() +!!$ if (y%is_dev()) call y%sync() +!!$ if ((beta /= zzero) .and. (z%is_dev())) call z%sync_space() +!!$ !$acc parallel loop +!!$ do i = 1, n +!!$ z%v(i) = alpha * x%v(i) * y%v(i) + beta * z%v(i) +!!$ end do +!!$ call z%set_host() +!!$ end select +!!$ end subroutine z_oacc_mlt_v_2 + + + subroutine z_oacc_axpby_v(m, alpha, x, beta, y, info) + !use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_z_base_vect_type), intent(inout) :: x + class(psb_z_vect_oacc), intent(inout) :: y + complex(psb_dpk_), intent(in) :: alpha, beta + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: nx, ny, i + + info = psb_success_ + + select type(xx => x) + type is (psb_z_vect_oacc) + if ((beta /= zzero) .and. y%is_host()) call y%sync_space() + if (xx%is_host()) call xx%sync_space() + nx = size(xx%v) + ny = size(y%v) + if ((nx < m) .or. (ny < m)) then + info = psb_err_internal_error_ + else + !$acc parallel loop + do i = 1, m + y%v(i) = alpha * xx%v(i) + beta * y%v(i) + end do + end if + call y%set_dev() + class default + if ((alpha /= zzero) .and. (x%is_dev())) call x%sync() + call y%axpby(m, alpha, x%v, beta, info) + end select + end subroutine z_oacc_axpby_v + + subroutine z_oacc_axpby_a(m, alpha, x, beta, y, info) + !use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + complex(psb_dpk_), intent(in) :: x(:) + class(psb_z_vect_oacc), intent(inout) :: y + complex(psb_dpk_), intent(in) :: alpha, beta + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: i + + if ((beta /= zzero) .and. (y%is_dev())) call y%sync_space() + !$acc parallel loop + do i = 1, m + y%v(i) = alpha * x(i) + beta * y%v(i) + end do + call y%set_host() + end subroutine z_oacc_axpby_a + + subroutine z_oacc_abgdxyz(m, alpha, beta, gamma, delta, x, y, z, info) + use psi_serial_mod + implicit none + integer(psb_ipk_), intent(in) :: m + class(psb_z_base_vect_type), intent(inout) :: x + class(psb_z_base_vect_type), intent(inout) :: y + class(psb_z_vect_oacc), intent(inout) :: z + complex(psb_dpk_), intent(in) :: alpha, beta, gamma, delta + integer(psb_ipk_), intent(out) :: info + integer(psb_ipk_) :: nx, ny, nz, i + logical :: gpu_done + + info = psb_success_ + gpu_done = .false. + + select type(xx => x) + class is (psb_z_vect_oacc) + select type(yy => y) + class is (psb_z_vect_oacc) + select type(zz => z) + class is (psb_z_vect_oacc) + if ((beta /= zzero) .and. yy%is_host()) call yy%sync_space() + if ((delta /= zzero) .and. zz%is_host()) call zz%sync_space() + if (xx%is_host()) call xx%sync_space() + nx = size(xx%v) + ny = size(yy%v) + nz = size(zz%v) + if ((nx < m) .or. (ny < m) .or. (nz < m)) then + info = psb_err_internal_error_ + else + !$acc parallel loop + do i = 1, m + yy%v(i) = alpha * xx%v(i) + beta * yy%v(i) + zz%v(i) = gamma * yy%v(i) + delta * zz%v(i) + end do + end if + call yy%set_dev() + call zz%set_dev() + gpu_done = .true. + end select + end select + end select + + if (.not. gpu_done) then + if (x%is_host()) call x%sync() + if (y%is_host()) call y%sync() + if (z%is_host()) call z%sync() + call y%axpby(m, alpha, x, beta, info) + call z%axpby(m, gamma, y, delta, info) + end if + end subroutine z_oacc_abgdxyz + + subroutine z_oacc_sctb_buf(i, n, idx, beta, y) + use psb_base_mod + implicit none + integer(psb_ipk_) :: i, n + class(psb_i_base_vect_type) :: idx + complex(psb_dpk_) :: beta + class(psb_z_vect_oacc) :: y + integer(psb_ipk_) :: info + + if (.not.allocated(y%combuf)) then + call psb_errpush(psb_err_alloc_dealloc_, 'sctb_buf') + return + end if + + select type(ii => idx) + class is (psb_i_vect_oacc) + if (ii%is_host()) call ii%sync_space(info) + if (y%is_host()) call y%sync_space() + + !$acc parallel loop + do i = 1, n + y%v(ii%v(i)) = beta * y%v(ii%v(i)) + y%combuf(i) + end do + + class default + !$acc parallel loop + do i = 1, n + y%v(idx%v(i)) = beta * y%v(idx%v(i)) + y%combuf(i) + end do + end select + end subroutine z_oacc_sctb_buf + + subroutine z_oacc_sctb_x(i, n, idx, x, beta, y) + use psb_base_mod + implicit none + integer(psb_ipk_):: i, n + class(psb_i_base_vect_type) :: idx + complex(psb_dpk_) :: beta, x(:) + class(psb_z_vect_oacc) :: y + integer(psb_ipk_) :: info, ni + + select type(ii => idx) + class is (psb_i_vect_oacc) + if (ii%is_host()) call ii%sync_space(info) + class default + call psb_errpush(info, 'z_oacc_sctb_x') + return + end select + + if (y%is_host()) call y%sync_space() + + !$acc parallel loop + do i = 1, n + y%v(idx%v(i)) = beta * y%v(idx%v(i)) + x(i) + end do + + call y%set_dev() + end subroutine z_oacc_sctb_x + + + + subroutine z_oacc_sctb(n, idx, x, beta, y) + use psb_base_mod + implicit none + integer(psb_ipk_) :: n + integer(psb_ipk_) :: idx(:) + complex(psb_dpk_) :: beta, x(:) + class(psb_z_vect_oacc) :: y + integer(psb_ipk_) :: info + integer(psb_ipk_) :: i + + if (n == 0) return + if (y%is_dev()) call y%sync_space() + + !$acc parallel loop + do i = 1, n + y%v(idx(i)) = beta * y%v(idx(i)) + x(i) + end do + + call y%set_host() + end subroutine z_oacc_sctb + + + subroutine z_oacc_gthzbuf(i, n, idx, x) + use psb_base_mod + implicit none + integer(psb_ipk_) :: i, n + class(psb_i_base_vect_type) :: idx + class(psb_z_vect_oacc) :: x + integer(psb_ipk_) :: info + + info = 0 + if (.not.allocated(x%combuf)) then + call psb_errpush(psb_err_alloc_dealloc_, 'gthzbuf') + return + end if + + select type(ii => idx) + class is (psb_i_vect_oacc) + if (ii%is_host()) call ii%sync_space(info) + class default + call psb_errpush(info, 'z_oacc_gthzbuf') + return + end select + + if (x%is_host()) call x%sync_space() + + !$acc parallel loop + do i = 1, n + x%combuf(i) = x%v(idx%v(i)) + end do + end subroutine z_oacc_gthzbuf + + subroutine z_oacc_gthzv_x(i, n, idx, x, y) + use psb_base_mod + implicit none + integer(psb_ipk_) :: i, n + class(psb_i_base_vect_type):: idx + complex(psb_dpk_) :: y(:) + class(psb_z_vect_oacc):: x + integer(psb_ipk_) :: info + + info = 0 + + select type(ii => idx) + class is (psb_i_vect_oacc) + if (ii%is_host()) call ii%sync_space(info) + class default + call psb_errpush(info, 'z_oacc_gthzv_x') + return + end select + + if (x%is_host()) call x%sync_space() + + !$acc parallel loop + do i = 1, n + y(i) = x%v(idx%v(i)) + end do + end subroutine z_oacc_gthzv_x + + subroutine z_oacc_ins_v(n, irl, val, dupl, x, info) + use psi_serial_mod + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n, dupl + class(psb_i_base_vect_type), intent(inout) :: irl + class(psb_z_base_vect_type), intent(inout) :: val + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i, isz + logical :: done_oacc + + info = 0 + if (psb_errstatus_fatal()) return + + done_oacc = .false. + select type(virl => irl) + type is (psb_i_vect_oacc) + select type(vval => val) + type is (psb_z_vect_oacc) + if (vval%is_host()) call vval%sync_space() + if (virl%is_host()) call virl%sync_space(info) + if (x%is_host()) call x%sync_space() + !$acc parallel loop + do i = 1, n + x%v(virl%v(i)) = vval%v(i) + end do + call x%set_dev() + done_oacc = .true. + end select + end select + + if (.not.done_oacc) then + select type(virl => irl) + type is (psb_i_vect_oacc) + if (virl%is_dev()) call virl%sync_space(info) + end select + select type(vval => val) + type is (psb_z_vect_oacc) + if (vval%is_dev()) call vval%sync_space() + end select + call x%ins(n, irl%v, val%v, dupl, info) + end if + + if (info /= 0) then + call psb_errpush(info, 'oacc_vect_ins') + return + end if + + end subroutine z_oacc_ins_v + + + + subroutine z_oacc_ins_a(n, irl, val, dupl, x, info) + use psi_serial_mod + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(in) :: n, dupl + integer(psb_ipk_), intent(in) :: irl(:) + complex(psb_dpk_), intent(in) :: val(:) + integer(psb_ipk_), intent(out) :: info + + integer(psb_ipk_) :: i + + info = 0 + if (x%is_dev()) call x%sync_space() + call x%psb_z_base_vect_type%ins(n, irl, val, dupl, info) + call x%set_host() + !$acc update device(x%v) + + end subroutine z_oacc_ins_a + + + + subroutine z_oacc_bld_mn(x, n) + use psb_base_mod + implicit none + integer(psb_mpk_), intent(in) :: n + class(psb_z_vect_oacc), intent(inout) :: x + integer(psb_ipk_) :: info + + call x%all(n, info) + if (info /= 0) then + call psb_errpush(info, 'z_oacc_bld_mn', i_err=(/n, n, n, n, n/)) + end if + call x%set_host() + !$acc update device(x%v) + + end subroutine z_oacc_bld_mn + + + subroutine z_oacc_bld_x(x, this) + use psb_base_mod + implicit none + complex(psb_dpk_), intent(in) :: this(:) + class(psb_z_vect_oacc), intent(inout) :: x + integer(psb_ipk_) :: info + + call psb_realloc(size(this), x%v, info) + if (info /= 0) then + info = psb_err_alloc_request_ + call psb_errpush(info, 'z_oacc_bld_x', & + i_err=(/size(this), izero, izero, izero, izero/)) + return + end if + + x%v(:) = this(:) + call x%set_host() + !$acc update device(x%v) + + end subroutine z_oacc_bld_x + + + subroutine z_oacc_asb_m(n, x, info) + use psb_base_mod + implicit none + integer(psb_mpk_), intent(in) :: n + class(psb_z_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + integer(psb_mpk_) :: nd + + info = psb_success_ + + if (x%is_dev()) then + nd = size(x%v) + if (nd < n) then + call x%sync() + call x%psb_z_base_vect_type%asb(n, info) + if (info == psb_success_) call x%sync_space() + call x%set_host() + end if + else + if (size(x%v) < n) then + call x%psb_z_base_vect_type%asb(n, info) + if (info == psb_success_) call x%sync_space() + call x%set_host() + end if + end if + end subroutine z_oacc_asb_m + + + + subroutine z_oacc_set_scal(x, val, first, last) + class(psb_z_vect_oacc), intent(inout) :: x + complex(psb_dpk_), intent(in) :: val + integer(psb_ipk_), optional :: first, last + + integer(psb_ipk_) :: first_, last_ + first_ = 1 + last_ = x%get_nrows() + if (present(first)) first_ = max(1, first) + if (present(last)) last_ = min(last, last_) + + !$acc parallel loop + do i = first_, last_ + x%v(i) = val + end do + !$acc end parallel loop + + call x%set_dev() + end subroutine z_oacc_set_scal + + + + subroutine z_oacc_zero(x) + use psi_serial_mod + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + call x%set_dev() + call x%set_scal(zzero) + end subroutine z_oacc_zero + + function z_oacc_get_nrows(x) result(res) + implicit none + class(psb_z_vect_oacc), intent(in) :: x + integer(psb_ipk_) :: res + + if (allocated(x%v)) res = size(x%v) + end function z_oacc_get_nrows + + function z_oacc_get_fmt() result(res) + implicit none + character(len=5) :: res + res = "zOACC" + + end function z_oacc_get_fmt + + function z_oacc_vect_dot(n, x, y) result(res) + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + class(psb_z_base_vect_type), intent(inout) :: y + integer(psb_ipk_), intent(in) :: n + complex(psb_dpk_) :: res + complex(psb_dpk_), external :: ddot + integer(psb_ipk_) :: info + integer(psb_ipk_) :: i + + res = zzero + + select type(yy => y) + type is (psb_z_base_vect_type) + if (x%is_dev()) call x%sync() + res = ddot(n, x%v, 1, yy%v, 1) + type is (psb_z_vect_oacc) + if (x%is_host()) call x%sync() + if (yy%is_host()) call yy%sync() + + !$acc parallel loop reduction(+:res) present(x%v, yy%v) + do i = 1, n + res = res + x%v(i) * yy%v(i) + end do + !$acc end parallel loop + + class default + call x%sync() + res = y%dot(n, x%v) + end select + + end function z_oacc_vect_dot + + + + + function z_oacc_dot_a(n, x, y) result(res) + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + complex(psb_dpk_), intent(in) :: y(:) + integer(psb_ipk_), intent(in) :: n + complex(psb_dpk_) :: res + complex(psb_dpk_), external :: ddot + + if (x%is_dev()) call x%sync() + res = ddot(n, y, 1, x%v, 1) + + end function z_oacc_dot_a + + ! subroutine z_oacc_set_vect(x,y) + ! implicit none + ! class(psb_z_vect_oacc), intent(inout) :: x + ! complex(psb_dpk_), intent(in) :: y(:) + ! integer(psb_ipk_) :: info + + ! if (size(x%v) /= size(y)) then + ! call x%free(info) + ! call x%all(size(y),info) + ! end if + ! x%v(:) = y(:) + ! call x%set_host() + ! end subroutine z_oacc_set_vect + + subroutine z_oacc_to_dev(v) + implicit none + complex(psb_dpk_) :: v(:) + !$acc update device(v) + end subroutine z_oacc_to_dev + + subroutine z_oacc_to_host(v) + implicit none + complex(psb_dpk_) :: v(:) + !$acc update self(v) + end subroutine z_oacc_to_host + + subroutine z_oacc_sync_space(x) + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + if (allocated(x%v)) then + call z_oacc_create_dev(x%v) + end if + contains + subroutine z_oacc_create_dev(v) + implicit none + complex(psb_dpk_) :: v(:) + !$acc enter data copyin(v) + end subroutine z_oacc_create_dev + end subroutine z_oacc_sync_space + + subroutine z_oacc_sync(x) + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + if (x%is_dev()) then + call z_oacc_to_host(x%v) + end if + if (x%is_host()) then + call z_oacc_to_dev(x%v) + end if + call x%set_sync() + end subroutine z_oacc_sync + + subroutine z_oacc_set_host(x) + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + + x%state = is_host + end subroutine z_oacc_set_host + + subroutine z_oacc_set_dev(x) + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + + x%state = is_dev + end subroutine z_oacc_set_dev + + subroutine z_oacc_set_sync(x) + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + + x%state = is_sync + end subroutine z_oacc_set_sync + + function z_oacc_is_dev(x) result(res) + implicit none + class(psb_z_vect_oacc), intent(in) :: x + logical :: res + + res = (x%state == is_dev) + end function z_oacc_is_dev + + function z_oacc_is_host(x) result(res) + implicit none + class(psb_z_vect_oacc), intent(in) :: x + logical :: res + + res = (x%state == is_host) + end function z_oacc_is_host + + function z_oacc_is_sync(x) result(res) + implicit none + class(psb_z_vect_oacc), intent(in) :: x + logical :: res + + res = (x%state == is_sync) + end function z_oacc_is_sync + + subroutine z_oacc_vect_all(n, x, info) + use psi_serial_mod + use psb_realloc_mod + implicit none + integer(psb_ipk_), intent(in) :: n + class(psb_z_vect_oacc), intent(out) :: x + integer(psb_ipk_), intent(out) :: info + + call psb_realloc(n, x%v, info) + if (info == 0) then + call x%set_host() + !$acc enter data create(x%v) + call x%sync_space() + end if + if (info /= 0) then + info = psb_err_alloc_request_ + call psb_errpush(info, 'z_oacc_all', & + i_err=(/n, n, n, n, n/)) + end if + end subroutine z_oacc_vect_all + + + subroutine z_oacc_vect_free(x, info) + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + integer(psb_ipk_), intent(out) :: info + info = 0 + if (allocated(x%v)) then + !$acc exit data delete(x%v) finalize + deallocate(x%v, stat=info) + end if + + end subroutine z_oacc_vect_free + + function z_oacc_get_size(x) result(res) + implicit none + class(psb_z_vect_oacc), intent(inout) :: x + integer(psb_ipk_) :: res + + if (x%is_dev()) call x%sync() + res = size(x%v) + end function z_oacc_get_size + +end module psb_z_oacc_vect_mod