|
|
|
@ -1,8 +1,6 @@
|
|
|
|
|
#include <stdlib.h>
|
|
|
|
|
#include <stdio.h>
|
|
|
|
|
|
|
|
|
|
#include "cintrf.h"
|
|
|
|
|
#include "vectordev.h"
|
|
|
|
|
#include "psi_cuda_common.cuh"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -62,11 +60,9 @@ __global__ void CONCAT(GEN_PSI_FUNC_NAME(TYPE_SYMBOL),_krn)(int ii, int nrws,
|
|
|
|
|
ir += ldv;
|
|
|
|
|
}
|
|
|
|
|
idiag[i]=idval;
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void CONCAT(GEN_PSI_FUNC_NAME(TYPE_SYMBOL),_)(spgpuHandle_t handle, int nrws, int i, int nr, int nza,
|
|
|
|
|
int baseIdx, int hacksz, int ldv, int nzm,
|
|
|
|
|
int *rS,int *devIdisp, int *devJa, VALUE_TYPE *devVal,
|
|
|
|
@ -76,8 +72,10 @@ void CONCAT(GEN_PSI_FUNC_NAME(TYPE_SYMBOL),_)(spgpuHandle_t handle, int nrws, i
|
|
|
|
|
dim3 grid ((nrws + THREAD_BLOCK - 1) / THREAD_BLOCK);
|
|
|
|
|
|
|
|
|
|
CONCAT(GEN_PSI_FUNC_NAME(TYPE_SYMBOL),_krn)
|
|
|
|
|
<<< grid, block, 0, handle->currentStream >>>(i,nrws, nr, nza, baseIdx, hacksz, ldv, nzm,
|
|
|
|
|
rS,devIdisp,devJa,devVal,idiag, rP,cM);
|
|
|
|
|
<<< grid, block, 0, handle->currentStream >>>(i,nrws, nr, nza, baseIdx,
|
|
|
|
|
hacksz, ldv, nzm,
|
|
|
|
|
rS,devIdisp,devJa,devVal,
|
|
|
|
|
idiag, rP,cM);
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -89,16 +87,17 @@ GEN_PSI_FUNC_NAME(TYPE_SYMBOL)
|
|
|
|
|
(spgpuHandle_t handle, int nr, int nc, int nza, int baseIdx, int hacksz, int ldv, int nzm,
|
|
|
|
|
int *rS,int *devIdisp, int *devJa, VALUE_TYPE *devVal,
|
|
|
|
|
int *idiag, int *rP, VALUE_TYPE *cM)
|
|
|
|
|
{ int i,j, nrws;
|
|
|
|
|
{ int i, nrws;
|
|
|
|
|
//int maxNForACall = THREAD_BLOCK*handle->maxGridSizeX;
|
|
|
|
|
int maxNForACall = max(handle->maxGridSizeX, THREAD_BLOCK*handle->maxGridSizeX);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//fprintf(stderr,"Loop on j: %d\n",j);
|
|
|
|
|
for (i=0; i<nr; i+=nrws) {
|
|
|
|
|
nrws = MIN(maxNForACall, nr - i);
|
|
|
|
|
//fprintf(stderr,"ifirst: %d i : %d nrws: %d i + ifirst + (nrws -1) -1 %d \n",ifirst,i,nrws,i + ifirst + (nrws -1) -1);
|
|
|
|
|
CONCAT(GEN_PSI_FUNC_NAME(TYPE_SYMBOL),_)(handle,nrws,i, nr, nza, baseIdx, hacksz, ldv, nzm,
|
|
|
|
|
rS,devIdisp, devJa, devVal, idiag, rP, cM);
|
|
|
|
|
CONCAT(GEN_PSI_FUNC_NAME(TYPE_SYMBOL),_)(handle,nrws,i, nr, nza, baseIdx,
|
|
|
|
|
hacksz, ldv, nzm,
|
|
|
|
|
rS,devIdisp, devJa, devVal,
|
|
|
|
|
idiag, rP, cM);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|