diff --git a/openacc/Makefile b/openacc/Makefile index c82a9281..aa6ea23d 100644 --- a/openacc/Makefile +++ b/openacc/Makefile @@ -20,7 +20,7 @@ CINCLUDES= # Source files FOBJS= psb_i_oacc_vect_mod.o psb_d_oacc_vect_mod.o \ - psb_oacc_mod.o psb_d_oacc_csr_mat_mod.o + psb_oacc_mod.o psb_d_oacc_csr_mat_mod.o \ psb_oacc_env_mod.o # Library name @@ -43,7 +43,7 @@ ilib: $(OBJS) $(MAKE) -C impl lib psb_oacc_mod.o : psb_i_oacc_vect_mod.o psb_d_oacc_vect_mod.o \ - psb_d_oacc_csr_mat_mod.o + psb_d_oacc_csr_mat_mod.o psb_oacc_env_mod.o clean: cclean iclean /bin/rm -f $(FOBJS) *$(.mod) *.a diff --git a/openacc/impl/psb_d_oacc_mlt_v.f90 b/openacc/impl/psb_d_oacc_mlt_v.f90 index a4eb6660..bedd0247 100644 --- a/openacc/impl/psb_d_oacc_mlt_v.f90 +++ b/openacc/impl/psb_d_oacc_mlt_v.f90 @@ -10,22 +10,22 @@ subroutine d_oacc_mlt_v(x, y, info) integer(psb_ipk_) :: i, n info = 0 -!!$ n = min(x%get_nrows(), y%get_nrows()) -!!$ select type(xx => x) -!!$ type is (psb_d_base_vect_type) -!!$ if (y%is_dev()) call y%sync() -!!$ !$acc parallel loop -!!$ do i = 1, n -!!$ y%v(i) = y%v(i) * xx%v(i) -!!$ end do -!!$ call y%set_host() -!!$ class default -!!$ if (xx%is_dev()) call xx%sync() -!!$ if (y%is_dev()) call y%sync() -!!$ !$acc parallel loop -!!$ do i = 1, n -!!$ y%v(i) = y%v(i) * xx%v(i) -!!$ end do -!!$ call y%set_host() -!!$ end select + n = min(x%get_nrows(), y%get_nrows()) + select type(xx => x) + class is (psb_d_vect_oacc) + if (y%is_host()) call y%sync() + if (xx%is_host()) call xx%sync() + !$acc parallel loop + do i = 1, n + y%v(i) = y%v(i) * xx%v(i) + end do + call y%set_dev() + class default + if (xx%is_dev()) call xx%sync() + if (y%is_dev()) call y%sync() + do i = 1, n + y%v(i) = y%v(i) * xx%v(i) + end do + call y%set_host() + end select end subroutine d_oacc_mlt_v diff --git a/openacc/impl/psb_d_oacc_mlt_v_2.f90 b/openacc/impl/psb_d_oacc_mlt_v_2.f90 index b59b4f56..7e46495f 100644 --- a/openacc/impl/psb_d_oacc_mlt_v_2.f90 +++ b/openacc/impl/psb_d_oacc_mlt_v_2.f90 @@ -19,37 +19,35 @@ subroutine d_oacc_mlt_v_2(alpha, x, y, beta, z, info, conjgx, conjgy) n = min(x%get_nrows(), y%get_nrows(), z%get_nrows()) info = 0 -!!$ select type(xx => x) -!!$ class is (psb_d_vect_oacc) -!!$ select type (yy => y) -!!$ class is (psb_d_vect_oacc) -!!$ if (xx%is_host()) call xx%sync_space() -!!$ if (yy%is_host()) call yy%sync_space() -!!$ if ((beta /= dzero) .and. (z%is_host())) call z%sync_space() -!!$ !$acc parallel loop -!!$ do i = 1, n -!!$ z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) -!!$ end do -!!$ call z%set_dev() -!!$ class default -!!$ if (xx%is_dev()) call xx%sync_space() -!!$ if (yy%is_dev()) call yy%sync() -!!$ if ((beta /= dzero) .and. (z%is_dev())) call z%sync_space() -!!$ !$acc parallel loop -!!$ do i = 1, n -!!$ z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) -!!$ end do -!!$ call z%set_host() -!!$ end select -!!$ class default -!!$ if (x%is_dev()) call x%sync() -!!$ if (y%is_dev()) call y%sync() -!!$ if ((beta /= dzero) .and. (z%is_dev())) call z%sync_space() -!!$ !$acc parallel loop -!!$ do i = 1, n -!!$ z%v(i) = alpha * x%v(i) * y%v(i) + beta * z%v(i) -!!$ end do -!!$ call z%set_host() -!!$ end select + select type(xx => x) + class is (psb_d_vect_oacc) + select type (yy => y) + class is (psb_d_vect_oacc) + if (xx%is_host()) call xx%sync() + if (yy%is_host()) call yy%sync() + if ((beta /= dzero) .and. (z%is_host())) call z%sync() + !$acc parallel loop + do i = 1, n + z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) + end do + call z%set_dev() + class default + if (xx%is_dev()) call xx%sync() + if (yy%is_dev()) call yy%sync() + if ((beta /= dzero) .and. (z%is_dev())) call z%sync() + do i = 1, n + z%v(i) = alpha * xx%v(i) * yy%v(i) + beta * z%v(i) + end do + call z%set_host() + end select + class default + if (x%is_dev()) call x%sync() + if (y%is_dev()) call y%sync() + if ((beta /= dzero) .and. (z%is_dev())) call z%sync() + do i = 1, n + z%v(i) = alpha * x%v(i) * y%v(i) + beta * z%v(i) + end do + call z%set_host() + end select end subroutine d_oacc_mlt_v_2