From fcae4a16338146ae46ef699cc2df0aacf606429f Mon Sep 17 00:00:00 2001 From: Stack-1 Date: Sun, 19 Apr 2026 00:11:45 +0200 Subject: [PATCH] [UPDATE] Transient version usefull for debug on local server --- base/comm/internals/psi_dswapdata.F90 | 318 +++++++++++++++++++-- linsolve/impl/psb_dcg.F90 | 5 + test/comm/cg/psb_comm_cg_test.F90 | 48 ++-- test/comm/spmv/psb_spmv_overlap_sbatch.sh | 50 ++++ test/comm/spmv/psb_spmv_overlap_test.f90 | 234 +++++++++++---- test/comm/swapdata/psb_comm_test.F90 | 126 ++++++-- test/comm/swapdata/psb_comm_test_sbatch.sh | 62 ++++ 7 files changed, 723 insertions(+), 120 deletions(-) create mode 100644 test/comm/spmv/psb_spmv_overlap_sbatch.sh create mode 100644 test/comm/swapdata/psb_comm_test_sbatch.sh diff --git a/base/comm/internals/psi_dswapdata.F90 b/base/comm/internals/psi_dswapdata.F90 index 5d8b8ff3..6c0262dd 100644 --- a/base/comm/internals/psi_dswapdata.F90 +++ b/base/comm/internals/psi_dswapdata.F90 @@ -82,8 +82,61 @@ submodule (psi_d_comm_v_mod) psi_d_swapdata_impl use psb_desc_const_mod, only: psb_swap_start_, psb_swap_wait_ use psb_base_mod + use psb_error_mod, only: psb_get_debug_level, psb_get_debug_unit, psb_debug_ext_ use psb_comm_factory_mod + + logical, save :: psb_swap_timing_inited = .false. + logical, save :: psb_swap_timing_enabled = .false. + integer(psb_ipk_), save :: psb_swap_timing_max_report = 32 + integer(psb_ipk_), save :: psb_swap_timing_report_count = 0 + integer(psb_ipk_), save :: psb_swap_timing_wrapper_calls = 0 + integer(psb_ipk_), save :: psb_swap_timing_baseline_calls = 0 + integer(psb_ipk_), save :: psb_swap_timing_neighbor_calls = 0 + contains + + subroutine psb_swap_timing_setup() + implicit none + character(len=64) :: env_buf + integer(psb_ipk_) :: env_len, env_status, ios + + if (psb_swap_timing_inited) return + + psb_swap_timing_inited = .true. + psb_swap_timing_enabled = .false. + psb_swap_timing_max_report = 32 + + call get_environment_variable('PSB_SWAP_TIMING', env_buf, length=env_len, status=env_status) + if ((env_status == 0) .and. (env_len > 0)) then + select case(env_buf(1:1)) + case('1','t','T','y','Y') + psb_swap_timing_enabled = .true. + case default + psb_swap_timing_enabled = .false. + end select + end if + + call get_environment_variable('PSB_SWAP_TIMING_MAX_REPORT', env_buf, length=env_len, status=env_status) + if ((env_status == 0) .and. (env_len > 0)) then + read(env_buf(1:env_len), *, iostat=ios) psb_swap_timing_max_report + if ((ios /= 0) .or. (psb_swap_timing_max_report < 1)) psb_swap_timing_max_report = 32 + end if + + end subroutine psb_swap_timing_setup + + logical function psb_swap_timing_should_report() + implicit none + + call psb_swap_timing_setup() + + psb_swap_timing_should_report = .false. + if (.not. psb_swap_timing_enabled) return + if (psb_swap_timing_report_count >= psb_swap_timing_max_report) return + + psb_swap_timing_report_count = psb_swap_timing_report_count + 1 + psb_swap_timing_should_report = .true. + end function psb_swap_timing_should_report + module subroutine psi_dswapdata_vect(swap_status,beta,y,desc_a,info,data) #ifdef PSB_MPI_MOD @@ -103,8 +156,14 @@ contains ! locals type(psb_ctxt_type) :: ctxt - integer(psb_ipk_) :: np, me, total_send, total_recv, num_neighbors, data_, err_act + integer(psb_ipk_) :: np, me, total_send, total_recv, num_neighbors, data_ class(psb_i_base_vect_type), pointer :: comm_indexes + logical :: debug_on + integer(psb_ipk_) :: dbg_unit + logical :: timing_on, timing_report + real(psb_dpk_) :: t0, t1, t_get_list, t_kernel, t_total + integer(psb_ipk_) :: call_idx + character(len=24) :: phase_name, scheme_name, exchange_name ! communication scheme/status selectors logical :: baseline, neighbor_a2av @@ -126,6 +185,26 @@ contains goto 9999 endif + debug_on = (psb_get_debug_level() >= psb_debug_ext_) + call psb_swap_timing_setup() + timing_on = psb_swap_timing_enabled + timing_report = .false. + if (timing_on) then + t_get_list = dzero + t_kernel = dzero + t_total = dzero + t0 = psb_wtime() + call_idx = psb_swap_timing_wrapper_calls + 1 + end if + + if (debug_on) then + dbg_unit = psb_get_debug_unit() + if (dbg_unit <= 0) dbg_unit = psb_err_unit + write(dbg_unit,*) me, trim(name), ': enter swap_status=', swap_status, & + & ' data=', data_, ' local_rows=', desc_a%get_local_rows(), & + & ' local_cols=', desc_a%get_local_cols(), ' y_nrows=', y%get_nrows() + end if + if (.not.psb_is_asb_desc(desc_a)) then info=psb_err_invalid_cd_state_ call psb_errpush(info,name) @@ -138,17 +217,19 @@ contains data_ = psb_comm_halo_ end if + if (timing_on) t1 = psb_wtime() call desc_a%get_list_p(data_,comm_indexes,num_neighbors,total_recv,total_send,info) + if (timing_on) t_get_list = psb_wtime() - t1 if (info /= psb_success_) then call psb_errpush(psb_err_internal_error_,name,a_err='desc_a%get_list_p') goto 9999 end if - ! Debug: report list sizes - ! if(me == 0) then - ! write(psb_err_unit,*) me, 'DBG: get_list_p -> num_neighbors=', & - ! & num_neighbors, ' total_send=', total_send, ' total_recv=', total_recv - ! end if + if (debug_on) then + write(dbg_unit,*) me, trim(name), ': list_p num_neighbors=', num_neighbors, & + & ' total_send=', total_send, ' total_recv=', total_recv, & + & ' comm_indexes_size=', size(comm_indexes%v) + end if if( (swap_status /= psb_comm_status_start_).and.(swap_status /= psb_comm_status_wait_)& @@ -181,6 +262,11 @@ contains goto 9999 end if + if (debug_on) then + write(dbg_unit,*) me, trim(name), ': comm_type=', y%comm_handle%comm_type, & + & ' swap_status=', swap_status + end if + ! if(me == 0) then ! write(psb_err_unit,*) me, 'DBG: after set_swap_status, info=', info ! end if @@ -199,14 +285,18 @@ contains ! end if if (baseline) then + if (timing_on) t1 = psb_wtime() call psi_dswap_baseline_vect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,total_send,total_recv,y%comm_handle,info) + if (timing_on) t_kernel = psb_wtime() - t1 if (info /= psb_success_) then call psb_errpush(info,name,a_err='baseline swap') goto 9999 end if else if (neighbor_a2av) then + if (timing_on) t1 = psb_wtime() call psi_dswap_neighbor_topology_vect(ctxt,swap_status,beta,y,comm_indexes,num_neighbors,& & total_send,total_recv,y%comm_handle,info) + if (timing_on) t_kernel = psb_wtime() - t1 if (info /= psb_success_) then call psb_errpush(info,name,a_err='neighbor a2av swap') goto 9999 @@ -217,6 +307,45 @@ contains goto 9999 end if + if (timing_on) then + t_total = psb_wtime() - t0 + call psb_amx(ctxt, t_get_list) + call psb_amx(ctxt, t_kernel) + call psb_amx(ctxt, t_total) + if (me == psb_root_) timing_report = psb_swap_timing_should_report() + if ((me == psb_root_) .and. timing_report) then + psb_swap_timing_wrapper_calls = call_idx + select case(swap_status) + case(psb_comm_status_start_) + phase_name = 'start' + case(psb_comm_status_wait_) + phase_name = 'wait' + case(psb_comm_status_sync_) + phase_name = 'sync' + case default + phase_name = 'unknown' + end select + if (baseline) then + scheme_name = 'baseline' + else + if (y%comm_handle%comm_type == psb_comm_persistent_ineighbor_alltoallv_) then + scheme_name = 'persistent_neighbor' + else + scheme_name = 'neighbor' + end if + end if + if (call_idx == 1) then + exchange_name = 'first' + else + exchange_name = 'steady' + end if + write(psb_out_unit,'("SWAP_TIMING wrapper scheme=",a,", phase=",a,", exchange=",a,", call=",i0)') & + & trim(scheme_name), trim(phase_name), trim(exchange_name), call_idx + write(psb_out_unit,'(" get_list=",es12.5,", kernel=",es12.5,", total=",es12.5)') & + & t_get_list, t_kernel, t_total + end if + end if + call psb_erractionrestore(err_act) return @@ -253,8 +382,15 @@ contains integer(psb_ipk_) :: err_act, i, idx_pt, total_send_, total_recv_,& & snd_pt, rcv_pt, pnti, n logical :: do_send,do_recv - logical, parameter :: usersend=.false., debug=.false. + logical, parameter :: usersend=.false. + logical :: debug + logical :: timing_on, timing_report + integer(psb_ipk_) :: dbg_unit character(len=20) :: name + real(psb_dpk_) :: t0, t1 + real(psb_dpk_) :: t_buf, t_gth, t_post, t_wait, t_sct, t_dev, t_total + integer(psb_ipk_) :: call_idx + character(len=12) :: exchange_name info = psb_success_ name = 'psi_dswap_baseline_vect' @@ -268,6 +404,24 @@ contains icomm = ctxt%get_mpic() + debug = (psb_get_debug_level() >= psb_debug_ext_) + call psb_swap_timing_setup() + timing_on = psb_swap_timing_enabled + timing_report = .false. + if (timing_on) then + t_buf = dzero + t_gth = dzero + t_post = dzero + t_wait = dzero + t_sct = dzero + t_dev = dzero + t_total = dzero + t0 = psb_wtime() + call_idx = psb_swap_timing_baseline_calls + 1 + end if + dbg_unit = psb_get_debug_unit() + if (dbg_unit <= 0) dbg_unit = psb_err_unit + baseline_comm_handle => null() select type(ch => comm_handle) type is(psb_comm_baseline_handle) @@ -292,7 +446,7 @@ contains total_send_ = total_send * n call comm_indexes%sync() - if (debug) write(*,*) me,'Internal buffer' + if (debug) write(dbg_unit,*) me,'Internal buffer' if (do_send) then if (allocated(baseline_comm_handle%comid)) then if (any(baseline_comm_handle%comid /= mpi_request_null)) then @@ -304,13 +458,16 @@ contains goto 9999 end if end if - if (debug) write(*,*) me,'do_send start' + if (debug) write(dbg_unit,*) me,'do_send start' + if (timing_on) t1 = psb_wtime() call y%new_buffer(ione*size(comm_indexes%v),info) call psb_realloc(num_neighbors,2_psb_ipk_,baseline_comm_handle%comid,info) baseline_comm_handle%comid = mpi_request_null call psb_realloc(num_neighbors,prcid,info) + if (timing_on) t_buf = t_buf + (psb_wtime() - t1) ! First I post all the non blocking receives pnti = 1 + if (timing_on) t1 = psb_wtime() do i=1, num_neighbors proc_to_comm = comm_indexes%v(pnti+psb_proc_id_) nerv = comm_indexes%v(pnti+psb_n_elem_recv_) @@ -319,7 +476,7 @@ contains rcv_pt = 1+pnti+psb_n_elem_recv_ prcid(i) = psb_get_mpi_rank(ctxt,proc_to_comm) if ((nerv>0).and.(proc_to_comm /= me)) then - if (debug) write(*,*) me,'Posting receive from',prcid(i),rcv_pt + if (debug) write(dbg_unit,*) me,'Posting receive from',prcid(i),rcv_pt p2ptag = psb_double_swap_tag call mpi_irecv(y%combuf(rcv_pt),nerv,& & psb_mpi_r_dpk_,prcid(i),& @@ -327,11 +484,13 @@ contains end if pnti = pnti + nerv + nesd + 3 end do - if (debug) write(*,*) me,' Gather ' + if (timing_on) t_post = t_post + (psb_wtime() - t1) + if (debug) write(dbg_unit,*) me,' Gather ' ! ! Then gather for sending. ! pnti = 1 + if (timing_on) t1 = psb_wtime() do i=1, num_neighbors nerv = comm_indexes%v(pnti+psb_n_elem_recv_) nesd = comm_indexes%v(pnti+nerv+psb_n_elem_send_) @@ -351,13 +510,16 @@ contains call y%gth(idx_pt,nesd,comm_indexes) pnti = pnti + nerv + nesd + 3 end do + if (timing_on) t_gth = t_gth + (psb_wtime() - t1) ! ! Then wait ! + if (timing_on) t1 = psb_wtime() call y%device_wait() + if (timing_on) t_dev = t_dev + (psb_wtime() - t1) - if (debug) write(*,*) me,' isend' + if (debug) write(dbg_unit,*) me,' isend' ! ! Then send ! @@ -366,6 +528,7 @@ contains snd_pt = 1 rcv_pt = 1 p2ptag = psb_double_swap_tag + if (timing_on) t1 = psb_wtime() do i=1, num_neighbors proc_to_comm = comm_indexes%v(pnti+psb_proc_id_) nerv = comm_indexes%v(pnti+psb_n_elem_recv_) @@ -387,10 +550,11 @@ contains pnti = pnti + nerv + nesd + 3 end do + if (timing_on) t_post = t_post + (psb_wtime() - t1) end if if (do_recv) then - if (debug) write(*,*) me,' do_Recv' + if (debug) write(dbg_unit,*) me,' do_Recv' if (.not.allocated(baseline_comm_handle%comid)) then ! ! No matching send? Something is wrong.... @@ -401,9 +565,10 @@ contains end if call psb_realloc(num_neighbors,prcid,info) - if (debug) write(*,*) me,' wait' + if (debug) write(dbg_unit,*) me,' wait' pnti = 1 p2ptag = psb_double_swap_tag + if (timing_on) t1 = psb_wtime() do i=1, num_neighbors proc_to_comm = comm_indexes%v(pnti+psb_proc_id_) nerv = comm_indexes%v(pnti+psb_n_elem_recv_) @@ -438,11 +603,13 @@ contains end if pnti = pnti + nerv + nesd + 3 end do + if (timing_on) t_wait = t_wait + (psb_wtime() - t1) - if (debug) write(*,*) me,' scatter' + if (debug) write(dbg_unit,*) me,' scatter' pnti = 1 snd_pt = 1 rcv_pt = 1 + if (timing_on) t1 = psb_wtime() do i=1, num_neighbors proc_to_comm = comm_indexes%v(pnti+psb_proc_id_) nerv = comm_indexes%v(pnti+psb_n_elem_recv_) @@ -462,11 +629,12 @@ contains goto 9999 end if - if (debug) write(0,*)me,' Received from: ',prcid(i),& + if (debug) write(dbg_unit,*)me,' Received from: ',prcid(i),& & y%combuf(rcv_pt:rcv_pt+nerv-1) call y%sct(rcv_pt,nerv,comm_indexes,beta) pnti = pnti + nerv + nesd + 3 end do + if (timing_on) t_sct = t_sct + (psb_wtime() - t1) ! ! Waited for everybody, clean up ! @@ -475,9 +643,10 @@ contains ! ! Then wait for device ! - if (debug) write(*,*) me,' wait' + if (debug) write(dbg_unit,*) me,' wait' + if (timing_on) t1 = psb_wtime() call y%device_wait() - if (debug) write(*,*) me,' free buffer' + if (debug) write(dbg_unit,*) me,' free buffer' call y%maybe_free_buffer(info) if (info == 0) then if (allocated(y%comm_handle)) call y%comm_handle%free(info) @@ -486,7 +655,33 @@ contains call psb_errpush(psb_err_alloc_dealloc_,name) goto 9999 end if - if (debug) write(*,*) me,' done' + if (timing_on) t_dev = t_dev + (psb_wtime() - t1) + if (debug) write(dbg_unit,*) me,' done' + end if + + if (timing_on) then + t_total = psb_wtime() - t0 + call psb_amx(ctxt, t_buf) + call psb_amx(ctxt, t_gth) + call psb_amx(ctxt, t_post) + call psb_amx(ctxt, t_wait) + call psb_amx(ctxt, t_sct) + call psb_amx(ctxt, t_dev) + call psb_amx(ctxt, t_total) + if (me == psb_root_) timing_report = psb_swap_timing_should_report() + if ((me == psb_root_) .and. timing_report) then + psb_swap_timing_baseline_calls = call_idx + if (call_idx == 1) then + exchange_name = 'first' + else + exchange_name = 'steady' + end if + write(psb_out_unit,'("SWAP_TIMING baseline phase start=",l1,", wait=",l1)') do_send, do_recv + write(psb_out_unit,'(" exchange=",a,", call=",i0)') trim(exchange_name), call_idx + write(psb_out_unit,'(" buf=",es12.5,", gth=",es12.5,", post=",es12.5,", wait=",es12.5)') & + & t_buf, t_gth, t_post, t_wait + write(psb_out_unit,'(" sct=",es12.5,", dev=",es12.5,", total=",es12.5)') t_sct, t_dev, t_total + end if end if @@ -526,8 +721,14 @@ contains type(psb_comm_neighbor_handle), pointer :: neighbor_comm_handle integer(psb_ipk_) :: err_act, topology_total_send, topology_total_recv, buffer_size logical :: do_start, do_wait - logical, parameter :: debug = .false. + logical :: debug + logical :: timing_on, timing_report + integer(psb_ipk_) :: dbg_unit character(len=30) :: name + real(psb_dpk_) :: t0, t1 + real(psb_dpk_) :: t_topo, t_buf, t_gth, t_init, t_post, t_wait, t_sct, t_dev, t_total + integer(psb_ipk_) :: call_idx + character(len=12) :: exchange_name info = psb_success_ @@ -541,6 +742,25 @@ contains endif icomm = ctxt%get_mpic() + dbg_unit = psb_get_debug_unit() + if (dbg_unit <= 0) dbg_unit = psb_err_unit + debug = (psb_get_debug_level() >= psb_debug_ext_) + call psb_swap_timing_setup() + timing_on = psb_swap_timing_enabled + timing_report = .false. + if (timing_on) then + t_topo = dzero + t_buf = dzero + t_gth = dzero + t_init = dzero + t_post = dzero + t_wait = dzero + t_sct = dzero + t_dev = dzero + t_total = dzero + t0 = psb_wtime() + call_idx = psb_swap_timing_neighbor_calls + 1 + end if neighbor_comm_handle => null() select type(ch => comm_handle) @@ -567,7 +787,7 @@ contains ! START phase: build topology (if needed), gather, post MPI ! --------------------------------------------------------- if (do_start) then - if(debug) write(*,*) me,' nbr_vect: starting data exchange' + if(debug) write(dbg_unit,*) me,' nbr_vect: starting data exchange' if (neighbor_comm_handle%use_persistent_buffers) then if (neighbor_comm_handle%persistent_in_flight) then info = psb_err_mpi_error_ @@ -576,8 +796,10 @@ contains end if end if if (.not. neighbor_comm_handle%is_initialized) then - if (debug) write(*,*) me,' nbr_vect: building topology via handle' + if (debug) write(dbg_unit,*) me,' nbr_vect: building topology via handle' + if (timing_on) t1 = psb_wtime() call neighbor_comm_handle%topology_init(comm_indexes%v, num_neighbors, total_send, total_recv, ctxt, icomm, info) + if (timing_on) t_topo = t_topo + (psb_wtime() - t1) if (info /= psb_success_) then call psb_errpush(psb_err_internal_error_, name, a_err='neighbor_topology_init') goto 9999 @@ -592,6 +814,7 @@ contains buffer_size = topology_total_send + topology_total_recv if (buffer_size > 0) then + if (timing_on) t1 = psb_wtime() if (neighbor_comm_handle%use_persistent_buffers) then if (.not. allocated(y%combuf)) then neighbor_comm_handle%diag_buffer_reallocs = neighbor_comm_handle%diag_buffer_reallocs + 1 @@ -633,13 +856,16 @@ contains goto 9999 end if end if + if (timing_on) t_buf = t_buf + (psb_wtime() - t1) neighbor_comm_handle%comm_request = mpi_request_null ! Gather send data into contiguous send buffer (polymorphic for GPU) if (debug) write(*,*) me,' nbr_vect: gathering send data,', topology_total_send,' elems' + if (timing_on) t1 = psb_wtime() call y%gth(int(topology_total_send,psb_mpk_), & & neighbor_comm_handle%send_indexes, & & y%combuf(1:topology_total_send)) + if (timing_on) t_gth = t_gth + (psb_wtime() - t1) else ! No data to send/recv: ensure requests/buffers indicate idle state neighbor_comm_handle%comm_request = mpi_request_null @@ -648,7 +874,9 @@ contains end if ! Wait for device (important for GPU subclasses) + if (timing_on) t1 = psb_wtime() call y%device_wait() + if (timing_on) t_dev = t_dev + (psb_wtime() - t1) if (neighbor_comm_handle%use_persistent_buffers) then ! Lazy persistent-init: build the request once, then reuse with START/WAIT. @@ -656,6 +884,7 @@ contains #ifdef PSB_HAVE_MPI_NEIGHBOR_PERSISTENT if (buffer_size > 0) then if (debug) write(*,*) me,' nbr_vect: posting MPI_Neighbor_alltoallv_init' + if (timing_on) t1 = psb_wtime() call mpi_neighbor_alltoallv_init( & & y%combuf(1), & ! send buffer & neighbor_comm_handle%send_counts, & @@ -668,6 +897,7 @@ contains & neighbor_comm_handle%graph_comm, & & mpi_info_null, & & neighbor_comm_handle%persistent_request, iret) + if (timing_on) t_init = t_init + (psb_wtime() - t1) if (iret /= mpi_success) then info = psb_err_mpi_error_ call psb_errpush(info, name, m_err=(/iret/)) @@ -689,7 +919,9 @@ contains #ifdef PSB_HAVE_MPI_NEIGHBOR_PERSISTENT if (buffer_size > 0) then + if (timing_on) t1 = psb_wtime() call mpi_start(neighbor_comm_handle%persistent_request, iret) + if (timing_on) t_post = t_post + (psb_wtime() - t1) if (iret /= mpi_success) then info = psb_err_mpi_error_ call psb_errpush(info, name, m_err=(/iret/)) @@ -702,6 +934,7 @@ contains end if #else if (buffer_size > 0) then + if (timing_on) t1 = psb_wtime() call mpi_ineighbor_alltoallv( & & y%combuf(1), & ! send buffer & neighbor_comm_handle%send_counts, & @@ -713,6 +946,7 @@ contains & psb_mpi_r_dpk_, & & neighbor_comm_handle%graph_comm, & & neighbor_comm_handle%comm_request, iret) + if (timing_on) t_post = t_post + (psb_wtime() - t1) if (iret /= mpi_success) then info = psb_err_mpi_error_ call psb_errpush(info, name, m_err=(/iret/)) @@ -728,6 +962,7 @@ contains ! Post non-blocking neighborhood alltoallv if (debug) write(*,*) me,' nbr_vect: posting MPI_Ineighbor_alltoallv' if (buffer_size > 0) then + if (timing_on) t1 = psb_wtime() call mpi_ineighbor_alltoallv( & & y%combuf(1), & ! send buffer & neighbor_comm_handle%send_counts, & @@ -739,6 +974,7 @@ contains & psb_mpi_r_dpk_, & & neighbor_comm_handle%graph_comm, & & neighbor_comm_handle%comm_request, iret) + if (timing_on) t_post = t_post + (psb_wtime() - t1) if (iret /= mpi_success) then info = psb_err_mpi_error_ call psb_errpush(info, name, m_err=(/iret/)) @@ -788,6 +1024,7 @@ contains if ((topology_total_send + topology_total_recv) > 0) then ! Wait for the non-blocking collective to complete if (debug) write(*,*) me,' nbr_vect: waiting on MPI request' + if (timing_on) t1 = psb_wtime() if (neighbor_comm_handle%use_persistent_buffers) then #ifdef PSB_HAVE_MPI_NEIGHBOR_PERSISTENT call mpi_wait(neighbor_comm_handle%persistent_request, p2pstat, iret) @@ -797,6 +1034,7 @@ contains else call mpi_wait(neighbor_comm_handle%comm_request, p2pstat, iret) end if + if (timing_on) t_wait = t_wait + (psb_wtime() - t1) if (iret /= mpi_success) then info = psb_err_mpi_error_ call psb_errpush(info, name, m_err=(/iret/)) @@ -811,10 +1049,12 @@ contains ! Scatter received data to local vector positions (polymorphic for GPU) if (debug) write(*,*) me,' nbr_vect: scattering recv data,', topology_total_recv,' elems' + if (timing_on) t1 = psb_wtime() call y%sct(int(topology_total_recv,psb_mpk_), & & neighbor_comm_handle%recv_indexes, & & y%combuf(topology_total_send+1:topology_total_send+topology_total_recv), & & beta) + if (timing_on) t_sct = t_sct + (psb_wtime() - t1) else ! nothing to wait/scatter end if @@ -825,6 +1065,7 @@ contains & (neighbor_comm_handle%use_persistent_buffers .and. .not. neighbor_comm_handle%persistent_request_ready)) then neighbor_comm_handle%comm_request = mpi_request_null end if + if (timing_on) t1 = psb_wtime() call y%device_wait() if (.not. neighbor_comm_handle%use_persistent_buffers) then call y%maybe_free_buffer(info) @@ -833,10 +1074,43 @@ contains goto 9999 end if end if + if (timing_on) t_dev = t_dev + (psb_wtime() - t1) if (debug) write(*,*) me,' nbr_vect: done' end if ! do_wait + if (timing_on) then + t_total = psb_wtime() - t0 + call psb_amx(ctxt, t_topo) + call psb_amx(ctxt, t_buf) + call psb_amx(ctxt, t_gth) + call psb_amx(ctxt, t_init) + call psb_amx(ctxt, t_post) + call psb_amx(ctxt, t_wait) + call psb_amx(ctxt, t_sct) + call psb_amx(ctxt, t_dev) + call psb_amx(ctxt, t_total) + if (me == psb_root_) timing_report = psb_swap_timing_should_report() + if ((me == psb_root_) .and. timing_report) then + psb_swap_timing_neighbor_calls = call_idx + if (call_idx == 1) then + exchange_name = 'first' + else + exchange_name = 'steady' + end if + if (neighbor_comm_handle%use_persistent_buffers) then + write(psb_out_unit,'("SWAP_TIMING persistent_neighbor phase start=",l1,", wait=",l1)') do_start, do_wait + else + write(psb_out_unit,'("SWAP_TIMING neighbor phase start=",l1,", wait=",l1)') do_start, do_wait + end if + write(psb_out_unit,'(" exchange=",a,", call=",i0)') trim(exchange_name), call_idx + write(psb_out_unit,'(" topo=",es12.5,", buf=",es12.5,", gth=",es12.5,", init=",es12.5)') & + & t_topo, t_buf, t_gth, t_init + write(psb_out_unit,'(" post=",es12.5,", wait=",es12.5,", sct=",es12.5,", dev=",es12.5,", total=",es12.5)') & + & t_post, t_wait, t_sct, t_dev, t_total + end if + end if + call psb_erractionrestore(err_act) return diff --git a/linsolve/impl/psb_dcg.F90 b/linsolve/impl/psb_dcg.F90 index 5636b19e..401f2ea5 100644 --- a/linsolve/impl/psb_dcg.F90 +++ b/linsolve/impl/psb_dcg.F90 @@ -140,6 +140,11 @@ subroutine psb_dcg_vect(a,prec,b,x,eps,desc_a,info,& ctxt = desc_a%get_context() call psb_info(ctxt, me, np) + if (np == -ione) then + info = psb_err_context_error_ + call psb_errpush(info,name,a_err='invalid desc_a context in psb_dcg_vect') + goto 9999 + end if if (.not.allocated(b%v)) then info = psb_err_invalid_vect_state_ call psb_errpush(info,name) diff --git a/test/comm/cg/psb_comm_cg_test.F90 b/test/comm/cg/psb_comm_cg_test.F90 index 7d949690..abc339b0 100644 --- a/test/comm/cg/psb_comm_cg_test.F90 +++ b/test/comm/cg/psb_comm_cg_test.F90 @@ -9,12 +9,14 @@ program psb_comm_cg_test implicit none type(psb_ctxt_type) :: ctxt + type(psb_ctxt_type) :: desc_ctxt type(psb_dspmat_type) :: a type(psb_desc_type) :: desc_a type(psb_d_vect_type) :: b, x type(psb_dprec_type) :: prec integer(psb_ipk_) :: info, iam, np + integer(psb_ipk_) :: desc_me, desc_np integer(psb_ipk_) :: idim, itmax, itrace, istop, iter integer(psb_ipk_) :: scheme_idx, prec_idx, rep, nrep, nwarm integer(psb_ipk_), parameter :: n_schemes=3, n_precs=2 @@ -42,7 +44,7 @@ program psb_comm_cg_test nrep = 5 nwarm = 1 ! Keep itrace positive to avoid modulo-by-zero paths in convergence logging. - itrace = 1 + itrace = 0 istop = 2 eps = 1.d-6 scheme_type = (/ psb_comm_isend_irecv_, psb_comm_ineighbor_alltoallv_, & @@ -127,6 +129,14 @@ program psb_comm_cg_test ! call probe_ieee('after psb_d_gen_pde3d') if (info /= psb_success_) goto 9999 + ! desc_ctxt = desc_a%get_context() + ! call psb_info(desc_ctxt, desc_me, desc_np) + ! if (desc_np == -1) then + ! info = psb_err_context_error_ + ! write(psb_err_unit,*) 'Invalid descriptor context after psb_d_gen_pde3d' + ! goto 9999 + ! end if + do prec_idx = 1, n_precs do scheme_idx = 1, n_schemes do rep = 1, nrep @@ -169,14 +179,6 @@ program psb_comm_cg_test setup_time(prec_idx,scheme_idx,rep) = prec_init_time(prec_idx,scheme_idx,rep) + & & prec_bld_time(prec_idx,scheme_idx,rep) + comm_set_time(prec_idx,scheme_idx,rep) - do iter = 1, nwarm - call psb_geaxpby(dzero,b,dzero,x,desc_a,info) - if (info /= psb_success_) goto 9999 - call psb_krylov('CG',a,prec,b,x,eps,desc_a,info,& - & itmax=itmax,itrace=itrace,istop=istop) - if (info /= psb_success_) goto 9999 - end do - call psb_geaxpby(dzero,b,dzero,x,desc_a,info) if (info /= psb_success_) goto 9999 @@ -212,20 +214,20 @@ program psb_comm_cg_test final_error(prec_idx,scheme_idx,rep) = err solve_info(prec_idx,scheme_idx,rep) = info - if (iam == psb_root_) then - select type(ch => x%v%comm_handle) - type is(psb_comm_neighbor_handle) - write(psb_out_unit,'("DIAG_COMM scheme=",a,", prec=",a,", rep=",i0)') & - & trim(scheme_name(scheme_idx)), trim(prec_name(prec_idx)), rep - write(psb_out_unit,'("DIAG_COMM counters: init=",i0,", start=",i0,", wait=",i0,", realloc=",i0)') & - & ch%diag_init_calls, ch%diag_start_calls, ch%diag_wait_calls, & - & ch%diag_buffer_reallocs - write(psb_out_unit,'("DIAG_COMM state: ready=",l1,", bsz=",i0)') & - & ch%persistent_request_ready, ch%persistent_buffer_size - class default - continue - end select - end if + ! if (iam == psb_root_) then + ! select type(ch => x%v%comm_handle) + ! type is(psb_comm_neighbor_handle) + ! write(psb_out_unit,'("DIAG_COMM scheme=",a,", prec=",a,", rep=",i0)') & + ! & trim(scheme_name(scheme_idx)), trim(prec_name(prec_idx)), rep + ! write(psb_out_unit,'("DIAG_COMM counters: init=",i0,", start=",i0,", wait=",i0,", realloc=",i0)') & + ! & ch%diag_init_calls, ch%diag_start_calls, ch%diag_wait_calls, & + ! & ch%diag_buffer_reallocs + ! write(psb_out_unit,'("DIAG_COMM state: ready=",l1,", bsz=",i0)') & + ! & ch%persistent_request_ready, ch%persistent_buffer_size + ! class default + ! continue + ! end select + ! end if if (info /= psb_success_) goto 9999 end do diff --git a/test/comm/spmv/psb_spmv_overlap_sbatch.sh b/test/comm/spmv/psb_spmv_overlap_sbatch.sh new file mode 100644 index 00000000..4a80775e --- /dev/null +++ b/test/comm/spmv/psb_spmv_overlap_sbatch.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash +#SBATCH --job-name=psb_spmv_overlap +#SBATCH --partition=boost_usr_prod +#SBATCH --time=01:00:00 +#SBATCH --nodes=2 +#SBATCH --ntasks=64 +#SBATCH --ntasks-per-node=32 +#SBATCH --cpus-per-task=1 +#SBATCH --threads-per-core=1 +#SBATCH --gpus-per-node=4 +#SBATCH --export=ALL +#SBATCH -A CNHPC_1736213 +#SBATCH --output=psb_spmv_overlap_%j.out +#SBATCH --error=psb_spmv_overlap_%j.err + +set -euo pipefail + +# Environment tuned like the existing comm test script. +export UCX_VFS_ENABLE=n +export UCX_VFS_USE_FUSE=n +export UCX_STATS_DEST=none +export UCX_LOG_LEVEL=error +export UCX_TLS=dc,sm,self +export UCX_NET_DEVICES=mlx5_0:1 +export OMPI_MCA_coll=^hcoll,han +export OMPI_MCA_coll_hcoll_enable=0 +export UCX_MEMTYPE_CACHE=n +export UCX_CLOSE_TIMEOUT=10s + +EXEC=./test/comm/spmv/runs/spmv_overlap +IDIM_LIST=${IDIM_LIST:-"20 40 60 80 100 140 180 220 260 300"} +TIMES=${TIMES:-100} +BUILD_IF_MISSING=${BUILD_IF_MISSING:-1} + +if [[ ! -x "$EXEC" ]]; then + echo "Executable not found: $EXEC" >&2 + exit 1 +fi + +echo "Running SpMV overlap comm test" +echo " EXEC=$EXEC" +echo " IDIM_LIST=$IDIM_LIST" +echo " TIMES=$TIMES" +echo " BUILD_IF_MISSING=$BUILD_IF_MISSING" + +for IDIM in $IDIM_LIST; do + echo "" + echo "=== Running IDIM=$IDIM TIMES=$TIMES ===" + IDIM="$IDIM" TIMES="$TIMES" srun --exclusive -N2 -n64 "$EXEC" +done diff --git a/test/comm/spmv/psb_spmv_overlap_test.f90 b/test/comm/spmv/psb_spmv_overlap_test.f90 index 431ef5e9..eaf0793b 100644 --- a/test/comm/spmv/psb_spmv_overlap_test.f90 +++ b/test/comm/spmv/psb_spmv_overlap_test.f90 @@ -524,29 +524,47 @@ contains real(psb_dpk_) :: alpha, beta type(psb_dspmat_type) :: a - type(psb_d_vect_type) :: x_baseline, x_neighbor, x_persistent - type(psb_d_vect_type) :: y_baseline, y_neighbor, y_persistent + type(psb_d_vect_type) :: x_isend, x_neighbor, x_persistent + type(psb_d_vect_type) :: y_ov_isend, y_ov_neighbor, y_ov_persistent + type(psb_d_vect_type) :: y_no_isend, y_no_neighbor, y_no_persistent type(psb_desc_type) :: desc_a - character(len=:), allocatable :: output_file_name - character(len=32) :: idim_str + character(len=64) :: env_buf real(psb_dpk_), allocatable :: x_global(:), y_global(:) integer(psb_ipk_) :: my_rank, np, info, err_act + integer :: env_len, env_status, ios integer(psb_ipk_) :: n_global, idim integer(psb_ipk_) :: i, times real(psb_dpk_) :: t0, t1, dt - real(psb_dpk_) :: tsum_baseline, tsum_neighbor, tsum_persistent - real(psb_dpk_) :: err_bn, err_bp, tol - logical :: tnd + real(psb_dpk_) :: t_ov_isend, t_ov_neighbor, t_ov_persistent + real(psb_dpk_) :: t_no_isend, t_no_neighbor, t_no_persistent + real(psb_dpk_) :: err_isend, err_neighbor, err_persistent, tol + real(psb_dpk_) :: avg_ov, avg_no, speedup, gain_pct info = psb_success_ tol = 1.0d-10 times = 100 - tsum_baseline = 0.0_psb_dpk_ - tsum_neighbor = 0.0_psb_dpk_ - tsum_persistent = 0.0_psb_dpk_ - tnd = .false. + t_ov_isend = 0.0_psb_dpk_ + t_ov_neighbor = 0.0_psb_dpk_ + t_ov_persistent = 0.0_psb_dpk_ + t_no_isend = 0.0_psb_dpk_ + t_no_neighbor = 0.0_psb_dpk_ + t_no_persistent = 0.0_psb_dpk_ idim = 10 + call psb_erractionsave(err_act) + + call get_environment_variable('IDIM', env_buf, length=env_len, status=env_status) + if ((env_status == 0) .and. (env_len > 0)) then + read(env_buf(1:env_len), *, iostat=ios) idim + if ((ios /= 0) .or. (idim < 2)) idim = 10 + end if + + call get_environment_variable('TIMES', env_buf, length=env_len, status=env_status) + if ((env_status == 0) .and. (env_len > 0)) then + read(env_buf(1:env_len), *, iostat=ios) times + if ((ios /= 0) .or. (times < 1)) times = 100 + end if + n_global = idim * idim * idim alpha = done beta = dzero @@ -555,7 +573,7 @@ contains call psb_barrier(ctxt) - call psb_d_gen_pde3d(ctxt,idim,a,y_baseline,x_baseline,desc_a,"CSR",info,partition=1) + call psb_d_gen_pde3d(ctxt,idim,a,y_ov_isend,x_isend,desc_a,"CSR",info,partition=1) if (info /= psb_success_) goto 9999 call psb_barrier(ctxt) @@ -571,96 +589,200 @@ contains call psb_geall(x_neighbor, desc_a, info) call psb_geall(x_persistent, desc_a, info) - call psb_geall(y_neighbor, desc_a, info) - call psb_geall(y_persistent, desc_a, info) + call psb_geall(y_ov_neighbor, desc_a, info) + call psb_geall(y_ov_persistent, desc_a, info) + call psb_geall(y_no_isend, desc_a, info) + call psb_geall(y_no_neighbor, desc_a, info) + call psb_geall(y_no_persistent, desc_a, info) if (info /= psb_success_) goto 9999 - call psb_scatter(x_global, x_baseline, desc_a, info, root=psb_root_) + call psb_scatter(x_global, x_isend, desc_a, info, root=psb_root_) call psb_scatter(x_global, x_neighbor, desc_a, info, root=psb_root_) call psb_scatter(x_global, x_persistent, desc_a, info, root=psb_root_) - call psb_scatter(y_global, y_baseline, desc_a, info, root=psb_root_) - call psb_scatter(y_global, y_neighbor, desc_a, info, root=psb_root_) - call psb_scatter(y_global, y_persistent, desc_a, info, root=psb_root_) + call psb_scatter(y_global, y_ov_isend, desc_a, info, root=psb_root_) + call psb_scatter(y_global, y_ov_neighbor, desc_a, info, root=psb_root_) + call psb_scatter(y_global, y_ov_persistent, desc_a, info, root=psb_root_) + call psb_scatter(y_global, y_no_isend, desc_a, info, root=psb_root_) + call psb_scatter(y_global, y_no_neighbor, desc_a, info, root=psb_root_) + call psb_scatter(y_global, y_no_persistent, desc_a, info, root=psb_root_) if (info /= psb_success_) goto 9999 ! Set communication schemes on the x vectors used by psb_spmm. - call psb_comm_set(psb_comm_isend_irecv_, x_baseline%v%comm_handle, info) + call psb_comm_set(psb_comm_isend_irecv_, x_isend%v%comm_handle, info) if (info /= psb_success_) goto 9999 call psb_comm_set(psb_comm_ineighbor_alltoallv_, x_neighbor%v%comm_handle, info) if (info /= psb_success_) goto 9999 call psb_comm_set(psb_comm_persistent_ineighbor_alltoallv_, x_persistent%v%comm_handle, info) if (info /= psb_success_) goto 9999 - ! Warm-up all schemes once. - call psb_spmm(alpha, a, x_baseline, beta, y_baseline, desc_a, info, doswap=.true.) - call psb_spmm(alpha, a, x_neighbor, beta, y_neighbor, desc_a, info, doswap=.true.) - call psb_spmm(alpha, a, x_persistent, beta, y_persistent, desc_a, info, doswap=.true.) + ! Warm-up all schemes once: overlap and non-overlap paths. + call psb_spmm(alpha, a, x_isend, beta, y_ov_isend, desc_a, info, doswap=.true.) + call psb_halo(x_isend, desc_a, info) + call psb_spmm(alpha, a, x_isend, beta, y_no_isend, desc_a, info, doswap=.false.) + + call psb_spmm(alpha, a, x_neighbor, beta, y_ov_neighbor, desc_a, info, doswap=.true.) + call psb_halo(x_neighbor, desc_a, info) + call psb_spmm(alpha, a, x_neighbor, beta, y_no_neighbor, desc_a, info, doswap=.false.) + + call psb_spmm(alpha, a, x_persistent, beta, y_ov_persistent, desc_a, info, doswap=.true.) + call psb_halo(x_persistent, desc_a, info) + call psb_spmm(alpha, a, x_persistent, beta, y_no_persistent, desc_a, info, doswap=.false.) if (info /= psb_success_) goto 9999 - ! Restore vectors so timed loops start from same initial state. - call psb_scatter(x_global, x_baseline, desc_a, info, root=psb_root_) + ! ----------------------------- + ! isend/irecv scheme + ! ----------------------------- + call psb_scatter(x_global, x_isend, desc_a, info, root=psb_root_) + call psb_scatter(y_global, y_ov_isend, desc_a, info, root=psb_root_) + call psb_barrier(ctxt) + t0 = psb_wtime() + do i = 1, times + call psb_spmm(alpha, a, x_isend, beta, y_ov_isend, desc_a, info, doswap=.true.) + end do + t1 = psb_wtime() + dt = t1 - t0 + call psb_amx(ctxt, dt) + t_ov_isend = dt + + call psb_scatter(x_global, x_isend, desc_a, info, root=psb_root_) + call psb_scatter(y_global, y_no_isend, desc_a, info, root=psb_root_) + call psb_barrier(ctxt) + t0 = psb_wtime() + do i = 1, times + call psb_halo(x_isend, desc_a, info) + call psb_spmm(alpha, a, x_isend, beta, y_no_isend, desc_a, info, doswap=.false.) + end do + t1 = psb_wtime() + dt = t1 - t0 + call psb_amx(ctxt, dt) + t_no_isend = dt + + ! ----------------------------- + ! ineighbor_alltoallv scheme + ! ----------------------------- call psb_scatter(x_global, x_neighbor, desc_a, info, root=psb_root_) - call psb_scatter(x_global, x_persistent, desc_a, info, root=psb_root_) - call psb_scatter(y_global, y_baseline, desc_a, info, root=psb_root_) - call psb_scatter(y_global, y_neighbor, desc_a, info, root=psb_root_) - call psb_scatter(y_global, y_persistent, desc_a, info, root=psb_root_) - if (info /= psb_success_) goto 9999 + call psb_scatter(y_global, y_ov_neighbor, desc_a, info, root=psb_root_) + call psb_barrier(ctxt) + t0 = psb_wtime() + do i = 1, times + call psb_spmm(alpha, a, x_neighbor, beta, y_ov_neighbor, desc_a, info, doswap=.true.) + end do + t1 = psb_wtime() + dt = t1 - t0 + call psb_amx(ctxt, dt) + t_ov_neighbor = dt - ! Baseline (isend/irecv) overlapped SpMV. + call psb_scatter(x_global, x_neighbor, desc_a, info, root=psb_root_) + call psb_scatter(y_global, y_no_neighbor, desc_a, info, root=psb_root_) + call psb_barrier(ctxt) t0 = psb_wtime() do i = 1, times - call psb_spmm(alpha, a, x_baseline, beta, y_baseline, desc_a, info, doswap=.true.) + call psb_halo(x_neighbor, desc_a, info) + call psb_spmm(alpha, a, x_neighbor, beta, y_no_neighbor, desc_a, info, doswap=.false.) end do t1 = psb_wtime() dt = t1 - t0 call psb_amx(ctxt, dt) - tsum_baseline = tsum_baseline + dt + t_no_neighbor = dt - ! Neighbor alltoallv overlapped SpMV. + ! ---------------------------------------- + ! persistent_ineighbor_alltoallv scheme + ! ---------------------------------------- + call psb_scatter(x_global, x_persistent, desc_a, info, root=psb_root_) + call psb_scatter(y_global, y_ov_persistent, desc_a, info, root=psb_root_) + call psb_barrier(ctxt) t0 = psb_wtime() do i = 1, times - call psb_spmm(alpha, a, x_neighbor, beta, y_neighbor, desc_a, info, doswap=.true.) + call psb_spmm(alpha, a, x_persistent, beta, y_ov_persistent, desc_a, info, doswap=.true.) end do t1 = psb_wtime() dt = t1 - t0 call psb_amx(ctxt, dt) - tsum_neighbor = tsum_neighbor + dt + t_ov_persistent = dt - ! Persistent-neighbor overlapped SpMV. + call psb_scatter(x_global, x_persistent, desc_a, info, root=psb_root_) + call psb_scatter(y_global, y_no_persistent, desc_a, info, root=psb_root_) + call psb_barrier(ctxt) t0 = psb_wtime() do i = 1, times - call psb_spmm(alpha, a, x_persistent, beta, y_persistent, desc_a, info, doswap=.true.) + call psb_halo(x_persistent, desc_a, info) + call psb_spmm(alpha, a, x_persistent, beta, y_no_persistent, desc_a, info, doswap=.false.) end do t1 = psb_wtime() dt = t1 - t0 call psb_amx(ctxt, dt) - tsum_persistent = tsum_persistent + dt + t_no_persistent = dt + + if (info /= psb_success_) goto 9999 - err_bn = maxval(abs(y_baseline%get_vect() - y_neighbor%get_vect())) - err_bp = maxval(abs(y_baseline%get_vect() - y_persistent%get_vect())) - call psb_amx(ctxt, err_bn) - call psb_amx(ctxt, err_bp) + err_isend = maxval(abs(y_ov_isend%get_vect() - y_no_isend%get_vect())) + err_neighbor = maxval(abs(y_ov_neighbor%get_vect() - y_no_neighbor%get_vect())) + err_persistent = maxval(abs(y_ov_persistent%get_vect() - y_no_persistent%get_vect())) + call psb_amx(ctxt, err_isend) + call psb_amx(ctxt, err_neighbor) + call psb_amx(ctxt, err_persistent) if (my_rank == 0) then - write(psb_out_unit,'(" Avg baseline time : ",es12.5)') tsum_baseline / real(times, psb_dpk_) - write(psb_out_unit,'(" Tot baseline time : ",es12.5)') tsum_baseline - write(psb_out_unit,'(" Avg neighbor time : ",es12.5)') tsum_neighbor / real(times, psb_dpk_) - write(psb_out_unit,'(" Tot neighbor time : ",es12.5)') tsum_neighbor - write(psb_out_unit,'(" Avg pers-neigh time: ",es12.5)') tsum_persistent / real(times, psb_dpk_) - write(psb_out_unit,'(" Tot pers-neigh time: ",es12.5)') tsum_persistent - write(psb_out_unit,'(" Check baseline vs neighbor err = ",es12.5)') err_bn - write(psb_out_unit,'(" Check baseline vs persistent err = ",es12.5)') err_bp - if ((err_bn > tol) .or. (err_bp > tol)) then + write(psb_out_unit,'(/,"SpMV overlap benchmark")') + write(psb_out_unit,'(" idim : ",i0)') idim + write(psb_out_unit,'(" global unknowns : ",i0)') n_global + write(psb_out_unit,'(" repetitions : ",i0)') times + write(psb_out_unit,'(" timing metric : max over MPI ranks")') + write(psb_out_unit,'(" gain(%) = 100*(1 - overlap/no_overlap)")') + + write(psb_out_unit,'(/,"Scheme: isend_irecv")') + avg_ov = t_ov_isend / real(times, psb_dpk_) + avg_no = t_no_isend / real(times, psb_dpk_) + speedup = t_no_isend / max(t_ov_isend, tiny(done)) + gain_pct = 100.0_psb_dpk_ * (done - (t_ov_isend / max(t_no_isend, tiny(done)))) + write(psb_out_unit,'(" total overlap : ",es12.5)') t_ov_isend + write(psb_out_unit,'(" total no_overlap : ",es12.5)') t_no_isend + write(psb_out_unit,'(" avg overlap : ",es12.5)') avg_ov + write(psb_out_unit,'(" avg no_overlap : ",es12.5)') avg_no + write(psb_out_unit,'(" speedup (no/ov) : ",f10.4)') speedup + write(psb_out_unit,'(" gain (%) : ",f10.4)') gain_pct + write(psb_out_unit,'(" overlap vs no_overlap err = ",es12.5)') err_isend + + write(psb_out_unit,'(/,"Scheme: ineighbor_alltoallv")') + avg_ov = t_ov_neighbor / real(times, psb_dpk_) + avg_no = t_no_neighbor / real(times, psb_dpk_) + speedup = t_no_neighbor / max(t_ov_neighbor, tiny(done)) + gain_pct = 100.0_psb_dpk_ * (done - (t_ov_neighbor / max(t_no_neighbor, tiny(done)))) + write(psb_out_unit,'(" total overlap : ",es12.5)') t_ov_neighbor + write(psb_out_unit,'(" total no_overlap : ",es12.5)') t_no_neighbor + write(psb_out_unit,'(" avg overlap : ",es12.5)') avg_ov + write(psb_out_unit,'(" avg no_overlap : ",es12.5)') avg_no + write(psb_out_unit,'(" speedup (no/ov) : ",f10.4)') speedup + write(psb_out_unit,'(" gain (%) : ",f10.4)') gain_pct + write(psb_out_unit,'(" overlap vs no_overlap err = ",es12.5)') err_neighbor + + write(psb_out_unit,'(/,"Scheme: persistent_ineighbor_alltoallv")') + avg_ov = t_ov_persistent / real(times, psb_dpk_) + avg_no = t_no_persistent / real(times, psb_dpk_) + speedup = t_no_persistent / max(t_ov_persistent, tiny(done)) + gain_pct = 100.0_psb_dpk_ * (done - (t_ov_persistent / max(t_no_persistent, tiny(done)))) + write(psb_out_unit,'(" total overlap : ",es12.5)') t_ov_persistent + write(psb_out_unit,'(" total no_overlap : ",es12.5)') t_no_persistent + write(psb_out_unit,'(" avg overlap : ",es12.5)') avg_ov + write(psb_out_unit,'(" avg no_overlap : ",es12.5)') avg_no + write(psb_out_unit,'(" speedup (no/ov) : ",f10.4)') speedup + write(psb_out_unit,'(" gain (%) : ",f10.4)') gain_pct + write(psb_out_unit,'(" overlap vs no_overlap err = ",es12.5)') err_persistent + + if ((err_isend > tol) .or. (err_neighbor > tol) .or. (err_persistent > tol)) then write(psb_out_unit,'(" WARNING: mismatch exceeds tolerance ",es12.5)') tol end if end if - call psb_gefree(x_baseline, desc_a, info) + call psb_gefree(x_isend, desc_a, info) call psb_gefree(x_neighbor, desc_a, info) call psb_gefree(x_persistent, desc_a, info) - call psb_gefree(y_baseline, desc_a, info) - call psb_gefree(y_neighbor, desc_a, info) - call psb_gefree(y_persistent, desc_a, info) + call psb_gefree(y_ov_isend, desc_a, info) + call psb_gefree(y_ov_neighbor, desc_a, info) + call psb_gefree(y_ov_persistent, desc_a, info) + call psb_gefree(y_no_isend, desc_a, info) + call psb_gefree(y_no_neighbor, desc_a, info) + call psb_gefree(y_no_persistent, desc_a, info) call psb_spfree(a, desc_a, info) call psb_cdfree(desc_a, info) diff --git a/test/comm/swapdata/psb_comm_test.F90 b/test/comm/swapdata/psb_comm_test.F90 index 25b4e22b..c7bb08c5 100644 --- a/test/comm/swapdata/psb_comm_test.F90 +++ b/test/comm/swapdata/psb_comm_test.F90 @@ -17,6 +17,7 @@ ! program psb_comm_test use psb_base_mod + use psb_error_mod, only: psb_set_debug_level, psb_debug_ext_ use psi_mod use psb_comm_factory_mod, only: psb_comm_set, psb_comm_free use psb_comm_schemes_mod, only: psb_comm_ineighbor_alltoallv_, psb_comm_persistent_ineighbor_alltoallv_, & @@ -30,6 +31,7 @@ program psb_comm_test integer(psb_ipk_) :: iters character(len=32) :: arg character(len=16) :: mode + logical :: debug_swapdata ! ---- descriptor / context ---- type(psb_ctxt_type) :: ctxt @@ -58,13 +60,16 @@ program psb_comm_test real(psb_dpk_) :: t0, t1, dt, tsum_baseline, tsum_neighbor, tsum_neighbor_persistent integer(psb_lpk_), allocatable :: glob_col(:) character(len=40) :: name + real(psb_dpk_) :: huge_d name = 'test_halo_new' tol = 1.0d-12 + huge_d = huge(1.0_psb_dpk_) n_pass = 0 n_total = 0 iters = 5 mode = 'both' + debug_swapdata = .false. ! ---- parse command-line argument for idim ---- idim = 10 @@ -92,9 +97,15 @@ program psb_comm_test call get_command_argument(i+1, arg) read(arg, *) mode end if + else if (trim(arg) == '--debug') then + debug_swapdata = .true. end if end do + if (debug_swapdata) then + call psb_set_debug_level(psb_debug_ext_) + end if + run_baseline = .false. run_neighbor = .false. run_persistent = .false. @@ -180,11 +191,35 @@ program psb_comm_test ! 3. Allocate two D vectors (scratch) and fill owned entries ! ================================================================== call psb_geall(v_baseline, desc_a, info) + if (info /= psb_success_) then + write(psb_err_unit,*) my_rank, 'geall baseline error:', info + call psb_abort(ctxt) + end if call psb_geall(v_neighbor, desc_a, info) + if (info /= psb_success_) then + write(psb_err_unit,*) my_rank, 'geall neighbor error:', info + call psb_abort(ctxt) + end if call psb_geall(v_neighbor_persistent, desc_a, info) + if (info /= psb_success_) then + write(psb_err_unit,*) my_rank, 'geall persistent-neighbor error:', info + call psb_abort(ctxt) + end if call psb_geasb(v_baseline, desc_a, info, scratch=.true.) + if (info /= psb_success_) then + write(psb_err_unit,*) my_rank, 'geasb baseline error:', info + call psb_abort(ctxt) + end if call psb_geasb(v_neighbor, desc_a, info, scratch=.true.) + if (info /= psb_success_) then + write(psb_err_unit,*) my_rank, 'geasb neighbor error:', info + call psb_abort(ctxt) + end if call psb_geasb(v_neighbor_persistent, desc_a, info, scratch=.true.) + if (info /= psb_success_) then + write(psb_err_unit,*) my_rank, 'geasb persistent-neighbor error:', info + call psb_abort(ctxt) + end if ! Fill owned entries with the global index value allocate(vals(ncol)) @@ -207,6 +242,10 @@ program psb_comm_test do i = 1, ncol expected(i) = real(glob_col(i), psb_dpk_) end do + allocate(result_baseline(ncol), result_neighbor(ncol), result_persistent(ncol)) + result_baseline = huge_d + result_neighbor = huge_d + result_persistent = huge_d ! ================================================================== ! 6. Baseline halo exchange (Isend/Irecv in one call) @@ -342,16 +381,33 @@ program psb_comm_test ! ================================================================== ! 8. Extract results and compare ! ================================================================== - result_baseline = v_baseline%get_vect() - result_neighbor = v_neighbor%get_vect() - result_persistent = v_neighbor_persistent%get_vect() + result_baseline = v_baseline%v%v + result_neighbor = v_neighbor%v%v + result_persistent = v_neighbor_persistent%v%v + + ! Debug: Check if results are properly populated + if (my_rank == 0 .and. debug_swapdata) then + write(psb_out_unit,'("DEBUG: ncol=",i0," nrow=",i0)') ncol, nrow + write(psb_out_unit,'("DEBUG: size(result_baseline)=",i0)') size(result_baseline) + if (ncol > 0) then + write(psb_out_unit,'("DEBUG: result_baseline(1:min(5,ncol))=",5(es12.5,1x))') & + & result_baseline(1:min(5,ncol)) + write(psb_out_unit,'("DEBUG: expected(1:min(5,ncol))=",5(es12.5,1x))') & + & expected(1:min(5,ncol)) + end if + end if if (run_baseline .and. run_neighbor) then n_total = n_total + 1 - err = maxval(abs(result_baseline(1:ncol) - result_neighbor(1:ncol))) + err = huge_d + if (ncol > 0) then + err = maxval(abs(result_baseline(1:ncol) - result_neighbor(1:ncol))) + else + err = 0.0_psb_dpk_ + end if call psb_amx(ctxt, err) if (my_rank == 0) then - if (err < tol) then + if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then write(psb_out_unit,'(" [PASS] cross-check baseline vs neighbor : err = ",es12.5)') err n_pass = n_pass + 1 else @@ -362,10 +418,15 @@ program psb_comm_test if (run_baseline) then n_total = n_total + 1 - err = maxval(abs(result_baseline(1:ncol) - expected(1:ncol))) + err = huge_d + if (ncol > 0) then + err = maxval(abs(result_baseline(1:ncol) - expected(1:ncol))) + else + err = 0.0_psb_dpk_ + end if call psb_amx(ctxt, err) if (my_rank == 0) then - if (err < tol) then + if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then write(psb_out_unit,'(" [PASS] baseline absolute correctness : err = ",es12.5)') err n_pass = n_pass + 1 else @@ -376,10 +437,15 @@ program psb_comm_test if (run_neighbor) then n_total = n_total + 1 - err = maxval(abs(result_neighbor(1:ncol) - expected(1:ncol))) + err = huge_d + if (ncol > 0) then + err = maxval(abs(result_neighbor(1:ncol) - expected(1:ncol))) + else + err = 0.0_psb_dpk_ + end if call psb_amx(ctxt, err) if (my_rank == 0) then - if (err < tol) then + if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then write(psb_out_unit,'(" [PASS] neighbor absolute correctness : err = ",es12.5)') err n_pass = n_pass + 1 else @@ -390,10 +456,15 @@ program psb_comm_test if (run_baseline .and. run_persistent) then n_total = n_total + 1 - err = maxval(abs(result_baseline(1:ncol) - result_persistent(1:ncol))) + err = huge_d + if (ncol > 0) then + err = maxval(abs(result_baseline(1:ncol) - result_persistent(1:ncol))) + else + err = 0.0_psb_dpk_ + end if call psb_amx(ctxt, err) if (my_rank == 0) then - if (err < tol) then + if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then write(psb_out_unit,'(" [PASS] cross-check baseline vs pers-nei : err = ",es12.5)') err n_pass = n_pass + 1 else @@ -404,10 +475,15 @@ program psb_comm_test if (run_persistent) then n_total = n_total + 1 - err = maxval(abs(result_persistent(1:ncol) - expected(1:ncol))) + err = huge_d + if (ncol > 0) then + err = maxval(abs(result_persistent(1:ncol) - expected(1:ncol))) + else + err = 0.0_psb_dpk_ + end if call psb_amx(ctxt, err) if (my_rank == 0) then - if (err < tol) then + if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then write(psb_out_unit,'(" [PASS] pers-neigh absolute correctness : err = ",es12.5)') err n_pass = n_pass + 1 else @@ -426,12 +502,17 @@ program psb_comm_test call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_) call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor%v, desc_a, info, data=psb_comm_halo_) - result_neighbor = v_neighbor%get_vect() + result_neighbor = v_neighbor%v%v n_total = n_total + 1 - err = maxval(abs(result_neighbor(1:ncol) - expected(1:ncol))) + err = huge_d + if (ncol > 0) then + err = maxval(abs(result_neighbor(1:ncol) - expected(1:ncol))) + else + err = 0.0_psb_dpk_ + end if call psb_amx(ctxt, err) if (my_rank == 0) then - if (err < tol) then + if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then write(psb_out_unit,'(" [PASS] neighbor topology reuse : err = ",es12.5)') err n_pass = n_pass + 1 else @@ -450,12 +531,17 @@ program psb_comm_test call psi_swapdata(psb_comm_status_start_, dzero, v_neighbor_persistent%v, desc_a, info, data=psb_comm_halo_) call psi_swapdata(psb_comm_status_wait_, dzero, v_neighbor_persistent%v, desc_a, info, data=psb_comm_halo_) - result_persistent = v_neighbor_persistent%get_vect() + result_persistent = v_neighbor_persistent%v%v n_total = n_total + 1 - err = maxval(abs(result_persistent(1:ncol) - expected(1:ncol))) + err = huge_d + if (ncol > 0) then + err = maxval(abs(result_persistent(1:ncol) - expected(1:ncol))) + else + err = 0.0_psb_dpk_ + end if call psb_amx(ctxt, err) if (my_rank == 0) then - if (err < tol) then + if ((err >= 0.0_psb_dpk_) .and. (err < tol)) then write(psb_out_unit,'(" [PASS] pers-neigh buffer reuse : err = ",es12.5)') err n_pass = n_pass + 1 else @@ -467,6 +553,7 @@ program psb_comm_test ! ================================================================== ! 9. Summary ! ================================================================== + call psb_barrier(ctxt) if (my_rank == 0) then write(psb_out_unit,'("================================================")') write(psb_out_unit,'(" Results: ",i0," / ",i0," tests passed")') n_pass, n_total @@ -477,6 +564,7 @@ program psb_comm_test end if write(psb_out_unit,'("================================================")') end if + call psb_barrier(ctxt) ! ================================================================== ! 10. Cleanup diff --git a/test/comm/swapdata/psb_comm_test_sbatch.sh b/test/comm/swapdata/psb_comm_test_sbatch.sh new file mode 100644 index 00000000..81aae2c1 --- /dev/null +++ b/test/comm/swapdata/psb_comm_test_sbatch.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash +#SBATCH --job-name=psb_swapdata_test +#SBATCH --partition=boost_usr_prod +#SBATCH --time=02:00:00 +#SBATCH --nodes=4 +#SBATCH --ntasks=128 +#SBATCH --ntasks-per-node=32 +#SBATCH --cpus-per-task=1 +#SBATCH --threads-per-core=1 +#SBATCH --gpus-per-node=4 +#SBATCH --export=ALL +#SBATCH -A CNHPC_1736213 +#SBATCH --output=psb_comm_swapdata_%j.out +#SBATCH --error=psb_comm_swapdata_%j.err + +set -euo pipefail + +# Environment tuned like the main comm test script. +export UCX_VFS_ENABLE=n +export UCX_VFS_USE_FUSE=n +export UCX_STATS_DEST=none +export UCX_LOG_LEVEL=error +export UCX_TLS=dc,sm,self +export UCX_NET_DEVICES=mlx5_0:1 +export OMPI_MCA_coll=^hcoll,han +export OMPI_MCA_coll_hcoll_enable=0 +export UCX_MEMTYPE_CACHE=n +export UCX_CLOSE_TIMEOUT=10s + +EXEC=./test/comm/swapdata/runs/psb_comm_test +DIM_START=${DIM_START:-20} +DIM_END=${DIM_END:-300} +DIM_STEP=${DIM_STEP:-20} +ITERS=${ITERS:-10} +MODE=${MODE:-both} +DEBUG=${DEBUG:-0} + +EXTRA_ARGS=() +if [[ "$DEBUG" != "0" ]]; then + EXTRA_ARGS+=(--debug) +fi + +if [[ ! -x "$EXEC" ]]; then + echo "Executable not found: $EXEC" >&2 + echo "Build the test first with 'make' in test/comm/swapdata/" >&2 + exit 1 +fi + +echo "Running swapdata comm test" +echo " EXEC=$EXEC" +echo " DIM_START=$DIM_START" +echo " DIM_END=$DIM_END" +echo " DIM_STEP=$DIM_STEP" +echo " ITERS=$ITERS" +echo " MODE=$MODE" +echo " DEBUG=$DEBUG" + +for DIM in $(seq "$DIM_START" "$DIM_STEP" "$DIM_END"); do + echo "" + echo "=== Running DIM=$DIM ===" + srun --exclusive -N2 -n64 "$EXEC" --dim "$DIM" --iters "$ITERS" --mode "$MODE" "${EXTRA_ARGS[@]}" +done