diff --git a/qpms/scatsystem.c b/qpms/scatsystem.c
index 67c64b6..e182318 100644
--- a/qpms/scatsystem.c
+++ b/qpms/scatsystem.c
@@ -43,19 +43,19 @@ void qpms_scatsystem_set_nthreads(long n) {
   qpms_scatsystem_nthreads_override = n;
 }
 
-static inline void qpms_ss_ensure_periodic(qpms_scatsys_t *ss) {
+static inline void qpms_ss_ensure_periodic(const qpms_scatsys_t *ss) {
   QPMS_ENSURE(ss->lattice_dimension > 0, "This method is applicable only to periodic systems.");
 }
 
-static inline void qpms_ss_ensure_periodic_a(qpms_scatsys_t *ss, const char *s) {
+static inline void qpms_ss_ensure_periodic_a(const qpms_scatsys_t *ss, const char *s) {
   QPMS_ENSURE(ss->lattice_dimension > 0, "This method is applicable only to periodic systems. Use %s instead.", s);
 }
 
-static inline void qpms_ss_ensure_nonperiodic(qpms_scatsys_t *ss) {
+static inline void qpms_ss_ensure_nonperiodic(const qpms_scatsys_t *ss) {
   QPMS_ENSURE(ss->lattice_dimension == 0, "This method is applicable only to nonperiodic systems.");
 }
 
-static inline void qpms_ss_ensure_nonperiodic_a(qpms_scatsys_t *ss, const char *s) {
+static inline void qpms_ss_ensure_nonperiodic_a(const qpms_scatsys_t *ss, const char *s) {
   QPMS_ENSURE(ss->lattice_dimension == 0, "This method is applicable only to nonperiodic systems. Use %s instead.", s);
 }
 
@@ -1145,6 +1145,20 @@ complex double *qpms_scatsys_build_translation_matrix_full(
       target, ss, k, QPMS_HANKEL_PLUS);
 }
 
+complex double *qpms_scatsyswk_build_translation_matrix_full(
+    /// Target memory with capacity for ss->fecv_size**2 elements. If NULL, new will be allocated.
+    complex double *target,
+    const qpms_scatsys_at_omega_k_t *sswk
+    )
+{
+  const qpms_scatsys_at_omega_t *ssw = sswk->ssw;
+  const complex double wavenumber = ssw->wavenumber;
+  const qpms_scatsys_t *ss = ssw->ss;
+  qpms_ss_ensure_periodic(ss);
+  const cart3_t k_cart3 = cart3_from_double_array(sswk->k);
+  return qpms_scatsys_periodic_build_translation_matrix_full(target, ss, wavenumber, &k_cart3);
+}
+
 complex double *qpms_scatsys_build_translation_matrix_e_full(
     /// Target memory with capacity for ss->fecv_size**2 elements. If NULL, new will be allocated.
     complex double *target,
@@ -1184,9 +1198,8 @@ complex double *qpms_scatsys_build_translation_matrix_e_full(
   return target;
 }
 
-static inline int qpms_ss_ppair_W32xy(qpms_scatsys_t *ss,
-    qpms_ss_pi_t pdest, qpms_ss_pi_t psrc, complex double wavenumber,
-    const cart2_t kvector,
+static inline int qpms_ss_ppair_W32xy(const qpms_scatsys_t *ss,
+    qpms_ss_pi_t pdest, qpms_ss_pi_t psrc, complex double wavenumber, const cart2_t kvector,
     complex double *target, const ptrdiff_t deststride, const ptrdiff_t srcstride,
     qpms_ewald_part parts) {
   const qpms_vswf_set_spec_t *srcspec = qpms_ss_bspec_pi(ss, psrc);
@@ -1198,11 +1211,25 @@ static inline int qpms_ss_ppair_W32xy(qpms_scatsys_t *ss,
   return qpms_trans_calculator_get_trans_array_e32_e(ss->c,
       target, NULL /*err*/,
       destspec, deststride, srcspec, srcstride,
-      ss->eta, wavenumber,
+      ss->per->eta, wavenumber,
       cart3xy2cart2(ss->per->lattice_basis[0]), cart3xy2cart2(ss->per->lattice_basis[1]),
       kvector,
-      cart2_substract(cart3xy2cart2(ss->p[pdest].pos), cart3xy2cart2(ss->p[psrc].pos))
-      u->maxR, u->maxK, parts);
+      cart2_substract(cart3xy2cart2(ss->p[pdest].pos), cart3xy2cart2(ss->p[psrc].pos)),
+      maxR, maxK, parts); // NOTE(review): maxR/maxK are not declared in any visible hunk; confirm they are in scope here
+}
+
+static inline int qpms_ss_ppair_W(const qpms_scatsys_t *ss,
+    qpms_ss_pi_t pdest, qpms_ss_pi_t psrc, complex double wavenumber, const double wavevector[],
+    complex double *target, const ptrdiff_t deststride, const ptrdiff_t srcstride,
+    qpms_ewald_part parts) { // Fills target with the W block for the (pdest <- psrc) particle pair
+  if(ss->lattice_dimension == 2 && // Currently, we can only handle lattices in the xy-plane
+      !ss->per->lattice_basis[0].z && !ss->per->lattice_basis[1].z &&
+      !wavevector[2])
+    return qpms_ss_ppair_W32xy(ss, pdest, psrc, wavenumber, cart2_from_double_array(wavevector),
+        target, // target already points at the pair's block (e.g. a scratch buffer); no fecv_pstarts offset is applied
+        deststride, srcstride, parts);
+  else
+    QPMS_NOT_IMPLEMENTED("Only 2D xy-lattices currently supported");
 }
 
 complex double *qpms_scatsys_periodic_build_translation_matrix_full(
@@ -1216,7 +1243,7 @@ complex double *qpms_scatsys_periodic_build_translation_matrix_full(
   const ptrdiff_t deststride = ss->fecv_size, srcstride = 1;
   // We have some limitations in the current implementation
   if(ss->lattice_dimension == 2 && // Currently, we can only the xy-plane
-      !ss->per->lattice_basis[0].z && !ss->per_lattice_basis[1].z &&
+      !ss->per->lattice_basis[0].z && !ss->per->lattice_basis[1].z &&
       !wavevector->z) {
     for (qpms_ss_pi_t pd = 0; pd < ss->p_count; ++pd)
       for (qpms_ss_pi_t ps = 0; ps < ss->p_count; ++ps) {
@@ -1229,21 +1256,23 @@ complex double *qpms_scatsys_periodic_build_translation_matrix_full(
   return target;
 }
 
-complex double *qpms_scatsysw_build_modeproblem_matrix_full(
+// Common implementation of qpms_scatsysw[k]_build_modeproblem_matrix_full: builds I - T S (k == NULL) or I - T W (k != NULL).
+static inline complex double *qpms_scatsysw_scatsyswk_build_modeproblem_matrix_full(
     /// Target memory with capacity for ss->fecv_size**2 elements. If NULL, new will be allocated.
     complex double *target,
-    const qpms_scatsys_at_omega_t *ssw
+    const qpms_scatsys_at_omega_t *ssw,
+    const double k[] // Wavevector (three cartesian components) for periodic systems; NULL if non-periodic
     )
 {
-  const complex double k = ssw->wavenumber;
+  const complex double wavenumber = ssw->wavenumber;
   const qpms_scatsys_t *ss = ssw->ss;
-  qpms_ss_ensure_nonperiodic(ss);
+  if (k) qpms_ss_ensure_periodic(ss); else qpms_ss_ensure_nonperiodic(ss); // k must agree with the system type
   const size_t full_len = ss->fecv_size;
   if(!target)
     QPMS_CRASHING_MALLOC(target, SQ(full_len) * sizeof(complex double));
-  complex double *tmp;
+  complex double *tmp; // scratch for one pair's translation block, S (non-periodic) or W (periodic)
   QPMS_CRASHING_MALLOC(tmp, SQ(ss->max_bspecn) * sizeof(complex double));
-  memset(target, 0, SQ(full_len) * sizeof(complex double)); //unnecessary?
+  memset(target, 0, SQ(full_len) * sizeof(complex double)); // required: skipped diagonal blocks must be zero before the identity is added
   const complex double zero = 0, minusone = -1;
   { // Non-diagonal part; M[piR, piC] = -T[piR] S(piR<-piC)
     size_t fullvec_offsetR = 0;
@@ -1255,49 +1284,55 @@ complex double *qpms_scatsysw_build_modeproblem_matrix_full(
       const complex double *tmmR = ssw->tm[ss->p[piR].tmatrix_id]->m;
       for(qpms_ss_pi_t piC = 0; piC < ss->p_count; ++piC) {
         const qpms_vswf_set_spec_t *bspecC = ssw->tm[ss->p[piC].tmatrix_id]->spec;
-        if(piC != piR) { // The diagonal will be dealt with later.
-          const cart3_t posC = ss->p[piC].pos;
-          QPMS_ENSURE_SUCCESS(qpms_trans_calculator_get_trans_array_lc3p(ss->c,
-                tmp, // tmp is S(piR<-piC)
-                bspecR, bspecC->n, bspecC, 1,
-                k, posR, posC, QPMS_HANKEL_PLUS));
-          cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
-              bspecR->n /*m*/, bspecC->n /*n*/, bspecR->n /*k*/,
-              &minusone/*alpha*/, tmmR/*a*/, bspecR->n/*lda*/,
-              tmp/*b*/, bspecC->n/*ldb*/, &zero/*beta*/,
-              target + fullvec_offsetR*full_len + fullvec_offsetC /*c*/,
-              full_len /*ldc*/);
+        if (k == NULL) { // non-periodic case
+          if(piC != piR) { // No "self-interaction" in non-periodic case
+            const cart3_t posC = ss->p[piC].pos;
+            QPMS_ENSURE_SUCCESS(qpms_trans_calculator_get_trans_array_lc3p(ss->c,
+                  tmp, // tmp is S(piR<-piC)
+                  bspecR, bspecC->n, bspecC, 1,
+                  wavenumber, posR, posC, QPMS_HANKEL_PLUS));
+          }
+        } else { // periodic case
+          QPMS_ENSURE_SUCCESS(qpms_ss_ppair_W(ss, piR, piC, wavenumber, k,
+                tmp /*target*/, bspecC->n /*deststride*/, 1 /*srcstride*/,
+                QPMS_EWALD_FULL));
         }
+        if (k != NULL || piC != piR) { // tmp was not filled for the non-periodic diagonal; leave that block zero
+          cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
+              bspecR->n /*m*/, bspecC->n /*n*/, bspecR->n /*k*/,
+              &minusone/*alpha*/, tmmR/*a*/, bspecR->n/*lda*/,
+              tmp/*b*/, bspecC->n/*ldb*/, &zero/*beta*/,
+              target + fullvec_offsetR*full_len + fullvec_offsetC /*c*/,
+              full_len /*ldc*/);
+        }
         fullvec_offsetC += bspecC->n;
       }
       fullvec_offsetR += bspecR->n;
     }
   }
-  // diagonal part M[pi,pi] = +1
-  for (size_t i = 0; i < full_len; ++i) target[full_len * i + i] = +1;
-  
+
+  // Add the identity, diagonal part M[pi,pi] += 1
+  for (size_t i = 0; i < full_len; ++i) target[full_len * i + i] += 1;
+
   free(tmp);
   return target;
 }
 
-complex double *qpms_scatsyswk_build_modeproblem_matrix_full(
-    /// Target memory with capacity for ss->fecv_size**2 elements. If NULL, new will be allocated.
-    complex double *target,
-    const qpms_scatsys_at_omega_k_t *sswk
-    )
-{
-  const qpms_scatsys_at_omega_t *ssw = sswk->ssw;
-  const complex double k = ssw->wavenumber;
-  const qpms_scatsys_t *ss = ssw->ss;
-  qpms_ss_ensure_periodic(ss);
-  const size_t full_len = ss->fecv_size;
-  if(!target)
-    QPMS_CRASHING_MALLOC(target, SQ(full_len) * sizeof(complex double));
-  //////////// TODO ////////////////
-  QPMS_NOT_IMPLEMENTED("TODO");
-  return target;
+complex double *qpms_scatsysw_build_modeproblem_matrix_full(
+    complex double *target, const qpms_scatsys_at_omega_t *ssw) {
+  qpms_ss_ensure_nonperiodic_a(ssw->ss, "qpms_scatsyswk_build_modeproblem_matrix_full()");
+  return qpms_scatsysw_scatsyswk_build_modeproblem_matrix_full(
+      target, ssw, NULL);
 }
 
+complex double *qpms_scatsyswk_build_modeproblem_matrix_full(
+    complex double *target, const qpms_scatsys_at_omega_k_t *sswk)
+{
+  qpms_ss_ensure_periodic_a(sswk->ssw->ss, "qpms_scatsysw_build_modeproblem_matrix_full()");
+  return qpms_scatsysw_scatsyswk_build_modeproblem_matrix_full(target, sswk->ssw, sswk->k);
+}
+
+
 // Serial reference implementation.
 complex double *qpms_scatsysw_build_modeproblem_matrix_irrep_packed_serial(
     /// Target memory with capacity for ss->saecv_sizes[iri]**2 elements. If NULL, new will be allocated.
@@ -1990,7 +2025,7 @@ void qpms_ss_LU_free(qpms_ss_LU lu) {
 }
 
 qpms_ss_LU qpms_scatsysw_modeproblem_matrix_full_factorise(complex double *mpmatrix_full,
-    int *target_piv, const qpms_scatsys_at_omega_t *ssw, const qpms_scatsys_at_k_omega_t *sswk) {
+    int *target_piv, const qpms_scatsys_at_omega_t *ssw, const qpms_scatsys_at_omega_k_t *sswk) {
   if (sswk) {
     QPMS_ASSERT(sswk->ssw == ssw || !ssw);
     ssw = sswk->ssw;
@@ -2016,7 +2051,7 @@ qpms_ss_LU qpms_scatsysw_modeproblem_matrix_full_factorise(complex double *mpmat
 qpms_ss_LU qpms_scatsysw_modeproblem_matrix_irrep_packed_factorise(complex double *mpmatrix_packed,
     int *target_piv, const qpms_scatsys_at_omega_t *ssw, qpms_iri_t iri) {
   QPMS_ENSURE(mpmatrix_packed, "A non-NULL pointer to the pre-calculated mode matrix is required");
-  qpms_scatsys_ensure_nonperiodic(ssw->ss);
+  qpms_ss_ensure_nonperiodic(ssw->ss);
   size_t n = ssw->ss->saecv_sizes[iri];
   if (!target_piv) QPMS_CRASHING_MALLOC(target_piv, n * sizeof(int));
   QPMS_ENSURE_SUCCESS(LAPACKE_zgetrf(LAPACK_ROW_MAJOR, n, n,
@@ -2033,7 +2068,7 @@ qpms_ss_LU qpms_scatsysw_modeproblem_matrix_irrep_packed_factorise(complex doubl
 qpms_ss_LU qpms_scatsysw_build_modeproblem_matrix_full_LU(
     complex double *target, int *target_piv,
     const qpms_scatsys_at_omega_t *ssw){
-  qpms_scatsys_ensure_nonperiodic_a(ssw->ss, "qpms_scatsyswk_build_modeproblem_matrix_full_LU()");
+  qpms_ss_ensure_nonperiodic_a(ssw->ss, "qpms_scatsyswk_build_modeproblem_matrix_full_LU()");
   target = qpms_scatsysw_build_modeproblem_matrix_full(target, ssw);
   return qpms_scatsysw_modeproblem_matrix_full_factorise(target, target_piv, ssw, NULL);
 }
diff --git a/qpms/scatsystem.h b/qpms/scatsystem.h
index d82ff40..fc30bea 100644
--- a/qpms/scatsystem.h
+++ b/qpms/scatsystem.h
@@ -356,7 +356,7 @@ complex double *qpms_scatsys_build_translation_matrix_full(
     complex double k ///< Wave number to use in the translation matrix.
     );
 
-/// Creates the full \f$ (I - WS) \f$ matrix of the periodic scattering system. NI
+/// Creates the full \f$ (I - WS) \f$ matrix of the periodic scattering system.
 /**
  * \returns \a target on success, NULL on error.
  */
@@ -422,10 +422,11 @@ complex double *qpms_scatsysw_build_modeproblem_matrix_irrep_packed_serial(
     qpms_iri_t iri ///< Index of the irreducible representation in ssw->ss->sym
     );
 
+struct qpms_scatsys_at_omega_k_t; // Forward declaration; the struct is defined below.
 /// LU factorisation (LAPACKE_zgetrf) result holder.
 typedef struct qpms_ss_LU {
   const qpms_scatsys_at_omega_t *ssw;
-  const qpms_scatsys_at_omega_k_t *sswk; ///< Only for periodic systems, otherwise NULL.
+  const struct qpms_scatsys_at_omega_k_t *sswk; ///< Only for periodic systems, otherwise NULL.
   bool full; ///< true if full matrix; false if irrep-packed.
   qpms_iri_t iri; ///< Irrep index if `full == false`.
   /// LU decomposition array.
@@ -455,7 +456,7 @@ qpms_ss_LU qpms_scatsysw_modeproblem_matrix_full_factorise(
     complex double *modeproblem_matrix_full, ///< Pre-calculated mode problem matrix (I-TS). Mandatory.
     int *target_piv, ///< Pre-allocated pivot array. Optional (if NULL, new one is allocated).
     const qpms_scatsys_at_omega_t *ssw, ///< Must be filled for non-periodic systems.
-    const qpms_scatsys_at_omega_k_t *sswk ///< Must be filled for periodic systems, otherwise must be NULL.
+    const struct qpms_scatsys_at_omega_k_t *sswk ///< Must be filled for periodic systems, otherwise must be NULL.
     );
 
 /// Computes LU factorisation of a pre-calculated irrep-packed mode/scattering problem matrix, replacing its contents.
@@ -476,12 +477,15 @@ complex double *qpms_scatsys_scatter_solve(
 // ======================= Periodic system -only related stuff =============================
 
 /// Scattering system at a given frequency and k-vector. Used only with periodic systems.
+/**
+ * N.B. Used as a stack variable for now, but this might become heap-allocated in the future (with its own constructor and destructor).
+ */
 typedef struct qpms_scatsys_at_omega_k_t {
   const qpms_scatsys_at_omega_t *ssw;
   double k[3]; ///< The k-vector's cartesian coordinates.
 } qpms_scatsys_at_omega_k_t;
 
-/// Creates the full \f$ (I - WS) \f$ matrix of the periodic scattering system. NI
+/// Creates the full \f$ (I - WS) \f$ matrix of the periodic scattering system.
 /**
  * \returns \a target on success, NULL on error.
  */
@@ -497,19 +501,19 @@ complex double *qpms_scatsys_periodic_build_translation_matrix_full(
     complex double *target,
     const qpms_scatsys_t *ss,
     complex double wavenumber, ///< Wave number to use in the translation matrix.
-    const double wavevector[] ///< Wavevector / pseudomomentum in cartesian coordinates.
+    const cart3_t *wavevector ///< Wavevector / pseudomomentum in cartesian coordinates.
     );
 
-/// Global translation matrix. NI
+/// Global translation matrix.
 complex double *qpms_scatsyswk_build_translation_matrix_full(
     /// Target memory with capacity for ss->fecv_size**2 elements. If NULL, new will be allocated.
     complex double *target,
-    const qpms_scatsys_omega_k_t *sswk
+    const qpms_scatsys_at_omega_k_t *sswk
     );
 
 /// Builds an LU-factorised mode/scattering problem \f$ (I - TS) \f$ matrix from scratch. Periodic systems only.
-qpms_ss_LU qpms_scatsysw_build_modeproblem_matrix_full_LU(
+qpms_ss_LU qpms_scatsyswk_build_modeproblem_matrix_full_LU(
     complex double *target, ///< Pre-allocated target array. Optional (if NULL, new one is allocated).
     int *target_piv, ///< Pre-allocated pivot array. Optional (if NULL, new one is allocated).
     const qpms_scatsys_at_omega_k_t *sswk