From c0b2343b89cc02e3fae1b25f12b258657a68e6cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Ne=C4=8Dada?= Date: Fri, 29 Mar 2019 13:17:28 +0200 Subject: [PATCH] Try whether Triton behaves better if the parallelized code uses no OpenBLAS. Former-commit-id: ac1c17635e2d01d8a14129a90ffb43d3156510fa --- qpms/CMakeLists.txt | 2 +- qpms/own_zgemm.c | 313 ++++++++++++++++++++++++++++++++++++++++++++ qpms/qpmsblas.h | 20 +++ qpms/scatsystem.c | 13 +- setup.py | 2 + 5 files changed, 346 insertions(+), 4 deletions(-) create mode 100644 qpms/own_zgemm.c create mode 100644 qpms/qpmsblas.h diff --git a/qpms/CMakeLists.txt b/qpms/CMakeLists.txt index f2ca92d..c5fe5da 100644 --- a/qpms/CMakeLists.txt +++ b/qpms/CMakeLists.txt @@ -9,7 +9,7 @@ include_directories(${DIRS}) add_library (qpms translations.c tmatrices.c vecprint.c vswf.c wigner.c lattices2d.c gaunt.c error.c legendre.c symmetries.c vecprint.c - bessel.c) + bessel.c own_zgemm.c) use_c99() set(LIBS ${LIBS} ${GSL_LIBRARIES} ${GSLCBLAS_LIBRARIES}) diff --git a/qpms/own_zgemm.c b/qpms/own_zgemm.c new file mode 100644 index 0000000..7de5baa --- /dev/null +++ b/qpms/own_zgemm.c @@ -0,0 +1,313 @@ +/* IMPORTANT! This code is partially taken from GSL, so everything must be GPL'd + * or this has to be rewritten (or removed; the only reason to use this is the problems + * with OpenBLAS) when distributed. + */ + +#include "qpmsblas.h" +#include +#include +#include /* NOTE(review): the header names of these three includes are missing in this text; the code below uses va_list, fprintf/vfprintf and abort -- TODO restore */ + +void +cblas_xerbla (int p, const char *rout, const char *form, ...) /* Error handler: if p is nonzero, first reports on stderr that parameter p of routine rout was incorrect; then prints the printf-style message form with its variadic arguments to stderr and calls abort(), so it never returns. NOTE(review): this is a non-static definition; it may clash with a symbol of the same name in a linked CBLAS library -- verify. */ +{ + va_list ap; + + va_start (ap, form); + + if (p) + { + fprintf (stderr, "Parameter %d to routine %s was incorrect\n", p, rout); + } + + vfprintf (stderr, form, ap); + va_end (ap); + + abort (); +} + + +#define BASE double /* real scalar type that the REAL/IMAG accessor macros cast to */ + +#define INDEX QPMS_BLAS_INDEX_T +#define OFFSET(N, incX) ((incX) > 0 ? 0 : ((N) - 1) * (-(incX))) /* 0 for a positive stride, otherwise (N-1)*(-incX) */ +#define BLAS_ERROR(x) cblas_xerbla(0, __FILE__, x); /* x is passed as the message format; the expansion already ends with a semicolon */ + +#define MAX(x,y) (((x) < (y)) ? 
(y) : (x)) /* larger of x and y; an argument may be evaluated twice */ + +#define CONJUGATE(x) ((x) == CblasConjTrans) +#define TRANSPOSE(x) ((x) == CblasTrans || (x) == CblasConjTrans) /* true for CblasTrans and for CblasConjTrans */ +#define UPPER(x) ((x) == CblasUpper) +#define LOWER(x) ((x) == CblasLower) + +/* Handling of packed complex types: element i of array a is accessed as two consecutive BASE values, the real part at index 2*(i) and the imaginary part at 2*(i)+1. The REAL0 and IMAG0 forms access element 0, and the CONST_ forms go through a pointer to const BASE. */ + +#define REAL(a,i) (((BASE *) a)[2*(i)]) +#define IMAG(a,i) (((BASE *) a)[2*(i)+1]) + +#define REAL0(a) (((BASE *)a)[0]) +#define IMAG0(a) (((BASE *)a)[1]) + +#define CONST_REAL(a,i) (((const BASE *) a)[2*(i)]) +#define CONST_IMAG(a,i) (((const BASE *) a)[2*(i)+1]) + +#define CONST_REAL0(a) (((const BASE *)a)[0]) +#define CONST_IMAG0(a) (((const BASE *)a)[1]) + + +#define GB(KU,KL,lda,i,j) ((KU+1+(i-j))*lda + j) /* NOTE(review): presumably a band-storage index helper; its arguments are not parenthesized and KL is unused -- verify before use */ + +#define TRCOUNT(N,i) ((((i)+1)*(2*(N)-(i)))/2) + +/* #define TBUP(N,i,j) */ +/* #define TBLO(N,i,j) */ + +#define TPUP(N,i,j) (TRCOUNT(N,(i)-1)+(j)-(i)) +#define TPLO(N,i,j) (((i)*((i)+1))/2 + (j)) + + +/* check if CBLAS_ORDER is correct: sets pos to posIfError unless order is CblasRowMajor or CblasColMajor; expands to an unbraced if statement */ +#define CHECK_ORDER(pos,posIfError,order) \ +if(((order)!=CblasRowMajor)&&((order)!=CblasColMajor)) \ + pos = posIfError; + +/* check if CBLAS_TRANSPOSE is correct: sets pos to posIfError unless Trans is CblasNoTrans, CblasTrans or CblasConjTrans; expands to an unbraced if statement */ +#define CHECK_TRANSPOSE(pos,posIfError,Trans) \ +if(((Trans)!=CblasNoTrans)&&((Trans)!=CblasTrans)&&((Trans)!=CblasConjTrans)) \ + pos = posIfError; + +/* check if a dimension argument is correct: sets pos to posIfError when dim is negative; expands to an unbraced if statement */ +#define CHECK_DIM(pos,posIfError,dim) \ +if((dim)<0) \ + pos = posIfError; + +/* cblas_xgemm(): argument validation. Each failing check overwrites pos with the 1-based position of the offending argument (1 = Order, 2 = TransA, 3 = TransB, 4 = M, 5 = N, 6 = K), so later failures replace earlier ones. __transF and __transG hold the transposition flags with CblasConjTrans mapped to CblasTrans, taken from (TransA, TransB) for row-major order and from (TransB, TransA) otherwise. NOTE(review): the remainder of this macro is missing from this text. */ +#define CBLAS_ERROR_GEMM(pos,Order,TransA,TransB,M,N,K,alpha,A,lda,B,ldb,beta,C,ldc) \ +{ \ + CBLAS_TRANSPOSE __transF=CblasNoTrans,__transG=CblasNoTrans; \ + if((Order)==CblasRowMajor) { \ + __transF = ((TransA)!=CblasConjTrans) ? (TransA) : CblasTrans; \ + __transG = ((TransB)!=CblasConjTrans) ? (TransB) : CblasTrans; \ + } else { \ + __transF = ((TransB)!=CblasConjTrans) ? (TransB) : CblasTrans; \ + __transG = ((TransA)!=CblasConjTrans) ? 
(TransA) : CblasTrans; \ + } \ + CHECK_ORDER(pos,1,Order); \ + CHECK_TRANSPOSE(pos,2,TransA); \ + CHECK_TRANSPOSE(pos,3,TransB); \ + CHECK_DIM(pos,4,M); \ + CHECK_DIM(pos,5,N); \ + CHECK_DIM(pos,6,K); \ + if((Order)==CblasRowMajor) { \ + if(__transF==CblasNoTrans) { \ + if((lda), this must be included _afterwards_ because of the typedefs! +#ifndef QPMSBLAS_H +#define QPMSBLAS_H +#define QPMS_BLAS_INDEX_T long long int + +#ifndef CBLAS_H +typedef enum {CblasRowMajor=101, CblasColMajor=102} CBLAS_LAYOUT; +typedef enum {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113} CBLAS_TRANSPOSE; +typedef enum {CblasUpper=121, CblasLower=122} CBLAS_UPLO; +typedef enum {CblasNonUnit=131, CblasUnit=132} CBLAS_DIAG; +typedef enum {CblasLeft=141, CblasRight=142} CBLAS_SIDE; +#endif + +void qpms_zgemm(CBLAS_LAYOUT Order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, + const QPMS_BLAS_INDEX_T M, const QPMS_BLAS_INDEX_T N, const QPMS_BLAS_INDEX_T K, + const _Complex double *alpha, const _Complex double *A, const QPMS_BLAS_INDEX_T lda, + const _Complex double *B, const QPMS_BLAS_INDEX_T ldb, + const _Complex double *beta, _Complex double *C, const QPMS_BLAS_INDEX_T ldc); + +#endif //QPMSBLAS_H diff --git a/qpms/scatsystem.c b/qpms/scatsystem.c index 3a62ed3..6dc47cc 100644 --- a/qpms/scatsystem.c +++ b/qpms/scatsystem.c @@ -16,6 +16,13 @@ #include "tmatrices.h" #include +#ifdef QPMS_SCATSYSTEM_USE_OWN_BLAS +#include "qpmsblas.h" +#define SERIAL_ZGEMM qpms_zgemm +#else +#define SERIAL_ZGEMM cblas_zgemm +#endif + #define SQ(x) ((x)*(x)) #define QPMS_SCATSYS_LEN_RTOL 1e-13 #define QPMS_SCATSYS_TMATRIX_ATOL 1e-14 @@ -1362,7 +1369,7 @@ static void *qpms_scatsys_build_modeproblem_matrix_irrep_packed_parallelR_thread bspecR, bspecC->n, bspecC, 1, a->k, posR, posC)); - cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, + SERIAL_ZGEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, bspecR->n /*m*/, bspecC->n /*n*/, bspecR->n /*k*/, &one/*alpha*/, tmmR/*a*/, bspecR->n/*lda*/, Sblock/*b*/, 
bspecC->n/*ldb*/, &zero/*beta*/, @@ -1374,7 +1381,7 @@ static void *qpms_scatsys_build_modeproblem_matrix_irrep_packed_parallelR_thread } // tmp[oiR|piR,piC] = ∑_K M[piR,K] U*[K,piC] - cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasConjTrans, + SERIAL_ZGEMM(CblasRowMajor, CblasNoTrans, CblasConjTrans, particle_fullsizeR /*M*/, orbit_packedsizeC /*N*/, particle_fullsizeC /*K*/, &one /*alpha*/, TSblock/*A*/, particle_fullsizeC/*ldA*/, omC + opiC*particle_fullsizeC /*B*/, @@ -1382,7 +1389,7 @@ static void *qpms_scatsys_build_modeproblem_matrix_irrep_packed_parallelR_thread tmp /*C*/, orbit_packedsizeC /*LDC*/); // target[oiR|piR,oiC|piC] += U[...] tmp[...] - cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, + SERIAL_ZGEMM(CblasRowMajor, CblasNoTrans, CblasNoTrans, orbit_packedsizeR /*M*/, orbit_packedsizeC /*N*/, particle_fullsizeR /*K*/, &one /*alpha*/, omR + opiR*particle_fullsizeR/*A*/, orbit_fullsizeR/*ldA*/, tmp /*B*/, orbit_packedsizeC /*ldB*/, &one /*beta*/, diff --git a/setup.py b/setup.py index 486d0c3..f6c78cd 100755 --- a/setup.py +++ b/setup.py @@ -74,10 +74,12 @@ qpms_c = Extension('qpms_c', 'qpms/tmatrices.c', 'qpms/error.c', 'qpms/bessel.c', + 'qpms/own_zgemm.c', ], extra_compile_args=['-std=c99','-ggdb', '-O3', '-DQPMS_COMPILE_PYTHON_EXTENSIONS', # this is required #'-DQPMS_USE_OMP', + '-DQPMS_SCATSYSTEM_USE_OWN_BLAS', '-DDISABLE_NDEBUG', # uncomment to enable assertions in the modules #'-fopenmp', ],