From 4ee8b847e0258da6ee335358009c67e54e21032f Mon Sep 17 00:00:00 2001 From: federicamontes Date: Mon, 16 Mar 2026 16:09:25 +0100 Subject: [PATCH] feature(neighbor a2av communication single precision) added support for neighbor a2av in single precision for vect and multivect --- base/comm/internals/psi_sswapdata.F90 | 523 +++++++++++++++++++++++++- 1 file changed, 520 insertions(+), 3 deletions(-) diff --git a/base/comm/internals/psi_sswapdata.F90 b/base/comm/internals/psi_sswapdata.F90 index bd3e6992..ea0a2c0e 100644 --- a/base/comm/internals/psi_sswapdata.F90 +++ b/base/comm/internals/psi_sswapdata.F90 @@ -92,7 +92,7 @@ submodule (psi_s_comm_v_mod) psi_s_swapdata_impl use psb_base_mod contains - subroutine psi_sswapdata_vect(flag,beta,y,desc_a,work,info,data) + module subroutine psi_sswapdata_vect(flag,beta,y,desc_a,work,info,data) #ifdef PSB_MPI_MOD use mpi @@ -176,6 +176,104 @@ contains module subroutine psi_sswap_vidx_vect(ctxt,flag,beta,y,idx, & & totxch,totsnd,totrcv,work,info) + +#ifdef PSB_MPI_MOD + use mpi +#endif + implicit none +#ifdef PSB_MPI_H + include 'mpif.h' +#endif + + type(psb_ctxt_type), intent(in) :: ctxt + !integer(psb_mpk_), intent(in) :: icomm + integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(out) :: info + class(psb_s_base_vect_type) :: y + real(psb_spk_), intent(in) :: beta + real(psb_spk_), target :: work(:) + class(psb_i_base_vect_type), intent(inout) :: idx + integer(psb_ipk_), intent(in) :: totxch,totsnd, totrcv + + ! local variables used to detect the communication scheme + logical :: swap_mpi, swap_sync, swap_send, swap_recv, swap_start, swap_wait + logical :: baseline, neighbor_a2av + + ! local variable used for get the communicator + integer(psb_mpk_) :: icomm + + ! error handling variables + integer(psb_ipk_) :: err_act + integer(psb_mpk_) :: me, np + character(len=30) :: name + + + info=psb_success_ + name='psi_sswap_vidx_vect' + call psb_erractionsave(err_act) + call psb_info(ctxt,me,np) + if (np == -1) then + info=psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + + + swap_mpi = iand(flag,psb_swap_mpi_) /= 0 + swap_sync = iand(flag,psb_swap_sync_) /= 0 + swap_send = iand(flag,psb_swap_send_) /= 0 + swap_recv = iand(flag,psb_swap_recv_) /= 0 + swap_start = iand(flag,psb_swap_start_) /= 0 + swap_wait = iand(flag,psb_swap_wait_) /= 0 + + baseline = swap_mpi .or. swap_send .or. swap_recv .or. swap_sync + neighbor_a2av = swap_start .or. swap_wait + + icomm = ctxt%get_mpic() + + if( (baseline.eqv..true.).and.(neighbor_a2av.eqv..true.) ) then + info=psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Incompatible flag settings: both baseline and neighbor_a2av are true') + goto 9999 + end if + + + if (baseline) then + call psi_sswap_baseline_vect(ctxt,icomm,flag,beta,y,idx,totxch,totsnd,totrcv,work,info) + if (info /= psb_success_) then + call psb_errpush(info,name,a_err='baseline swap') + goto 9999 + end if + else if (neighbor_a2av) then + call psi_sswap_neighbor_topology_vect(ctxt,icomm,flag,beta,y,idx,totxch,totsnd,totrcv,work,info) + if (info /= psb_success_) then + call psb_errpush(info,name,a_err='neighbor a2av swap') + goto 9999 + end if + else + info = psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Incompatible flag settings: neither baseline nor neighbor_a2av is true') + goto 9999 + end if + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return + + end subroutine psi_sswap_vidx_vect + + + +! +! subroutine psi_sswap_baseline_vect +! This performs Isend/Irecv as a baseline communication mode +! +subroutine psi_sswap_baseline_vect(ctxt,icomm,flag,beta,y,idx, & + & totxch,totsnd,totrcv,work,info) + #ifdef PSB_MPI_MOD use mpi #endif @@ -410,7 +508,171 @@ contains 9999 call psb_error_handler(ctxt,err_act) return - end subroutine psi_sswap_vidx_vect + +end subroutine psi_sswap_baseline_vect + + +subroutine psi_sswap_neighbor_topology_vect(ctxt,icomm,flag,beta,y,idx, & + & totxch,totsnd,totrcv,work,info) + +#ifdef PSB_MPI_MOD + use mpi +#endif + implicit none +#ifdef PSB_MPI_H + include 'mpif.h' +#endif + + type(psb_ctxt_type), intent(in) :: ctxt + integer(psb_mpk_), intent(in) :: icomm + integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(out) :: info + class(psb_s_base_vect_type) :: y + real(psb_spk_), intent(in) :: beta + real(psb_spk_), target :: work(:) + class(psb_i_base_vect_type), intent(inout) :: idx + integer(psb_ipk_), intent(in) :: totxch,totsnd, totrcv + + ! locals + integer(psb_mpk_) :: np, me + integer(psb_mpk_) :: iret, p2pstat(mpi_status_size) + integer(psb_ipk_) :: err_act, topology_total_send, topology_total_recv, buffer_size + logical :: do_start, do_wait + logical, parameter :: debug = .false. + character(len=30) :: name + + + info = psb_success_ + name = 'psi_sswap_nbr_vect' + call psb_erractionsave(err_act) + call psb_info(ctxt,me,np) + if (np == -1) then + info=psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + + do_start = iand(flag,psb_swap_start_) /= 0 + do_wait = iand(flag,psb_swap_wait_) /= 0 + + call idx%sync() + + ! --------------------------------------------------------- + ! START phase: build topology (if needed), gather, post MPI + ! --------------------------------------------------------- + if (do_start) then + if(debug) write(*,*) me,' nbr_vect: starting data exchange' + ! Lazy initialization: build the topology on first call + if (.not. y%neighbor_topology%is_initialized) then + if (debug) write(*,*) me,' nbr_vect: building topology' + call y%neighbor_topology%init(idx%v, totxch, totsnd, totrcv, & + & ctxt, icomm, info) + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_, name, & + & a_err='neighbor_topology_init') + goto 9999 + end if + end if + + topology_total_send = y%neighbor_topology%total_send + topology_total_recv = y%neighbor_topology%total_recv + + ! Buffer layout: + ! combuf(1 : total_send) = send area + ! combuf(total_send+1 : total_send+total_recv) = recv area + buffer_size = topology_total_send + topology_total_recv + + call y%new_buffer(buffer_size, info) + if (info /= 0) then + call psb_errpush(psb_err_alloc_dealloc_, name) + goto 9999 + end if + y%communication_handle = mpi_request_null + + ! Gather send data into contiguous send buffer (polymorphic for GPU) + if (debug) write(*,*) me,' nbr_vect: gathering send data,', topology_total_send,' elems' + call y%gth(int(topology_total_send,psb_mpk_), & + & y%neighbor_topology%send_indexes, & + & y%combuf(1:topology_total_send)) + + ! Wait for device (important for GPU subclasses) + call y%device_wait() + + ! Post non-blocking neighborhood alltoallv + if (debug) write(*,*) me,' nbr_vect: posting MPI_Ineighbor_alltoallv' + call mpi_ineighbor_alltoallv( & + & y%combuf(1), & ! send buffer + & y%neighbor_topology%send_counts, & + & y%neighbor_topology%send_displs, & + & psb_mpi_r_spk_, & + & y%combuf(topology_total_send + 1), & ! recv buffer + & y%neighbor_topology%recv_counts, & + & y%neighbor_topology%recv_displs, & + & psb_mpi_r_spk_, & + & y%neighbor_topology%graph_comm, & + & y%communication_handle, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/iret/)) + goto 9999 + end if + + end if ! do_start + + ! --------------------------------------------------------- + ! WAIT phase: complete MPI, scatter received data + ! --------------------------------------------------------- + if (do_wait) then + + if (y%communication_handle == mpi_request_null) then + ! No matching start? Something is wrong + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/-2/)) + goto 9999 + end if + + topology_total_send = y%neighbor_topology%total_send + topology_total_recv = y%neighbor_topology%total_recv + + ! Wait for the non-blocking collective to complete + if (debug) write(*,*) me,' nbr_vect: waiting on MPI request' + call mpi_wait(y%communication_handle, p2pstat, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/iret/)) + goto 9999 + end if + + ! Scatter received data to local vector positions (polymorphic for GPU) + if (debug) write(*,*) me,' nbr_vect: scattering recv data,', topology_total_recv,' elems' + call y%sct(int(topology_total_recv,psb_mpk_), & + & y%neighbor_topology%recv_indexes, & + & y%combuf(topology_total_send+1:topology_total_send+topology_total_recv), & + & beta) + + + ! Clean up + y%communication_handle = mpi_request_null + call y%device_wait() + call y%maybe_free_buffer(info) + if (info /= 0) then + call psb_errpush(psb_err_alloc_dealloc_, name) + goto 9999 + end if + if (debug) write(*,*) me,' nbr_vect: done' + + end if ! do_wait + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return +end subroutine psi_sswap_neighbor_topology_vect + + + ! ! @@ -502,6 +764,99 @@ contains ! module subroutine psi_sswap_vidx_multivect(ctxt,flag,beta,y,idx, & & totxch,totsnd,totrcv,work,info) + #ifdef PSB_MPI_MOD + use mpi +#endif + implicit none +#ifdef PSB_MPI_H + include 'mpif.h' +#endif + + type(psb_ctxt_type), intent(in) :: ctxt + !integer(psb_mpk_), intent(in) :: icomm + integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(out) :: info + class(psb_s_base_multivect_type) :: y + real(psb_spk_), intent(in) :: beta + real(psb_spk_), target :: work(:) + class(psb_i_base_vect_type), intent(inout) :: idx + integer(psb_ipk_), intent(in) :: totxch,totsnd, totrcv + + ! local variables used to detect the communication scheme + logical :: swap_mpi, swap_sync, swap_send, swap_recv, swap_start, swap_wait + logical :: baseline, neighbor_a2av + + ! local variable used to get communicator + integer(psb_mpk_) :: icomm + + ! error handling variables + integer(psb_ipk_) :: err_act + integer(psb_mpk_) :: me, np + character(len=30) :: name + + + info=psb_success_ + name='psi_sswap_vidx_multivect' + call psb_erractionsave(err_act) + call psb_info(ctxt,me,np) + if (np == -1) then + info=psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + + + swap_mpi = iand(flag,psb_swap_mpi_) /= 0 + swap_sync = iand(flag,psb_swap_sync_) /= 0 + swap_send = iand(flag,psb_swap_send_) /= 0 + swap_recv = iand(flag,psb_swap_recv_) /= 0 + swap_start = iand(flag,psb_swap_start_) /= 0 + swap_wait = iand(flag,psb_swap_wait_) /= 0 + + baseline = swap_mpi .or. swap_send .or. swap_recv .or. swap_sync + neighbor_a2av = swap_start .or. swap_wait + + icomm = ctxt%get_mpic() + + if( (baseline.eqv..true.).and.(neighbor_a2av.eqv..true.) ) then + info=psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Incompatible flag settings: both baseline and neighbor_a2av are true') + goto 9999 + end if + + + if (baseline) then + call psi_sswap_baseline_multivect(ctxt,icomm,flag,beta,y,idx,totxch,totsnd,totrcv,work,info) + if (info /= psb_success_) then + call psb_errpush(info,name,a_err='baseline swap') + goto 9999 + end if + else if (neighbor_a2av) then + call psi_sswap_neighbor_topology_multivect(ctxt,icomm,flag,beta,y,idx,totxch,totsnd,totrcv,work,info) + if (info /= psb_success_) then + call psb_errpush(info,name,a_err='neighbor a2av swap') + goto 9999 + end if + else + info = psb_err_mpi_error_ + call psb_errpush(info,name,a_err='Incompatible flag settings: neither baseline nor neighbor_a2av is true') + goto 9999 + end if + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return + + end subroutine psi_sswap_vidx_multivect + + + + +subroutine psi_sswap_baseline_multivect(ctxt,icomm,flag,beta,y,idx, & + & totxch,totsnd,totrcv,work,info) #ifdef PSB_MPI_MOD use mpi @@ -742,6 +1097,168 @@ contains 9999 call psb_error_handler(ctxt,err_act) return - end subroutine psi_sswap_vidx_multivect +end subroutine psi_sswap_baseline_vidx_multivect + + +subroutine psi_sswap_neighbor_topology_multivect(ctxt,icomm,flag,beta,y,idx, & + & totxch,totsnd,totrcv,work,info) + +#ifdef PSB_MPI_MOD + use mpi +#endif + implicit none +#ifdef PSB_MPI_H + include 'mpif.h' +#endif + + type(psb_ctxt_type), intent(in) :: ctxt + integer(psb_mpk_), intent(in) :: icomm + integer(psb_ipk_), intent(in) :: flag + integer(psb_ipk_), intent(out) :: info + class(psb_s_base_multivect_type) :: y + real(psb_spk_), intent(in) :: beta + real(psb_spk_), target :: work(:) + class(psb_i_base_vect_type), intent(inout) :: idx + integer(psb_ipk_), intent(in) :: totxch,totsnd, totrcv + + ! locals + integer(psb_mpk_) :: np, me + integer(psb_mpk_) :: iret, p2pstat(mpi_status_size) + integer(psb_ipk_) :: err_act, topology_total_send, topology_total_recv, buffer_size + logical :: do_start, do_wait + logical, parameter :: debug = .false. + character(len=30) :: name + + + info = psb_success_ + name = 'psi_sswap_nbr_vect' + call psb_erractionsave(err_act) + call psb_info(ctxt,me,np) + if (np == -1) then + info=psb_err_context_error_ + call psb_errpush(info,name) + goto 9999 + endif + + do_start = iand(flag,psb_swap_start_) /= 0 + do_wait = iand(flag,psb_swap_wait_) /= 0 + + call idx%sync() + + ! --------------------------------------------------------- + ! START phase: build topology (if needed), gather, post MPI + ! --------------------------------------------------------- + if (do_start) then + if(debug) write(*,*) me,' nbr_vect: starting data exchange' + ! Lazy initialization: build the topology on first call + if (.not. y%neighbor_topology%is_initialized) then + if (debug) write(*,*) me,' nbr_vect: building topology' + call y%neighbor_topology%init(idx%v, totxch, totsnd, totrcv, & + & ctxt, icomm, info) + if (info /= psb_success_) then + call psb_errpush(psb_err_internal_error_, name, & + & a_err='neighbor_topology_init') + goto 9999 + end if + end if + + topology_total_send = y%neighbor_topology%total_send + topology_total_recv = y%neighbor_topology%total_recv + + ! Buffer layout: + ! combuf(1 : total_send) = send area + ! combuf(total_send+1 : total_send+total_recv) = recv area + buffer_size = topology_total_send + topology_total_recv + + call y%new_buffer(buffer_size, info) + if (info /= 0) then + call psb_errpush(psb_err_alloc_dealloc_, name) + goto 9999 + end if + y%communication_handle = mpi_request_null + + ! Gather send data into contiguous send buffer (polymorphic for GPU) + if (debug) write(*,*) me,' nbr_vect: gathering send data,', topology_total_send,' elems' + call y%gth(int(topology_total_send,psb_mpk_), & + & y%neighbor_topology%send_indexes, & + & y%combuf(1:topology_total_send)) + + ! Wait for device (important for GPU subclasses) + call y%device_wait() + + ! Post non-blocking neighborhood alltoallv + if (debug) write(*,*) me,' nbr_vect: posting MPI_Ineighbor_alltoallv' + call mpi_ineighbor_alltoallv( & + & y%combuf(1), & ! send buffer + & y%neighbor_topology%send_counts, & + & y%neighbor_topology%send_displs, & + & psb_mpi_r_spk_, & + & y%combuf(topology_total_send + 1), & ! recv buffer + & y%neighbor_topology%recv_counts, & + & y%neighbor_topology%recv_displs, & + & psb_mpi_r_spk_, & + & y%neighbor_topology%graph_comm, & + & y%communication_handle, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/iret/)) + goto 9999 + end if + + end if ! do_start + + ! --------------------------------------------------------- + ! WAIT phase: complete MPI, scatter received data + ! --------------------------------------------------------- + if (do_wait) then + + if (y%communication_handle == mpi_request_null) then + ! No matching start? Something is wrong + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/-2/)) + goto 9999 + end if + + topology_total_send = y%neighbor_topology%total_send + topology_total_recv = y%neighbor_topology%total_recv + + ! Wait for the non-blocking collective to complete + if (debug) write(*,*) me,' nbr_vect: waiting on MPI request' + call mpi_wait(y%communication_handle, p2pstat, iret) + if (iret /= mpi_success) then + info = psb_err_mpi_error_ + call psb_errpush(info, name, m_err=(/iret/)) + goto 9999 + end if + + ! Scatter received data to local vector positions (polymorphic for GPU) + if (debug) write(*,*) me,' nbr_vect: scattering recv data,', topology_total_recv,' elems' + call y%sct(int(topology_total_recv,psb_mpk_), & + & y%neighbor_topology%recv_indexes, & + & y%combuf(topology_total_send+1:topology_total_send+topology_total_recv), & + & beta) + + + ! Clean up + y%communication_handle = mpi_request_null + call y%device_wait() + call y%maybe_free_buffer(info) + if (info /= 0) then + call psb_errpush(psb_err_alloc_dealloc_, name) + goto 9999 + end if + if (debug) write(*,*) me,' nbr_vect: done' + + end if ! do_wait + + call psb_erractionrestore(err_act) + return + +9999 call psb_error_handler(ctxt,err_act) + + return +end subroutine psi_sswap_neighbor_topology_multivect + + end submodule psi_s_swapdata_impl