diff --git a/qpms/qpms_c.pyx b/qpms/qpms_c.pyx index 3fc1250..a37a945 100644 --- a/qpms/qpms_c.pyx +++ b/qpms/qpms_c.pyx @@ -1681,6 +1681,8 @@ cdef class ScatteringSystem: return ar def planewave_full(self, k_cart, E_cart): + k_cart = np.array(k_cart) + E_cart = np.array(E_cart) if k_cart.shape != (3,) or E_cart.shape != (3,): raise ValueError("k_cart and E_cart must be ndarrays of shape (3,)") cdef qpms_incfield_planewave_params_t p @@ -1709,7 +1711,57 @@ cdef class ScatteringSystem: cdef cdouble[::1] target_view = target_np qpms_scatsys_apply_Tmatrices_full(&target_view[0], &a_view[0], self.s) return target_np + + cdef qpms_scatsys_t *rawpointer(self): + return self.s + def scatter_solver(self, double k, iri=None): + return ScatteringMatrix(self, k, iri) + +cdef class ScatteringMatrix: + ''' + Wrapper over the C qpms_ss_LU structure that keeps the factorised mode problem matrix. + ''' + cdef ScatteringSystem ss # Here we keep the reference to the parent scattering system + cdef qpms_ss_LU lu + + def __cinit__(self, ScatteringSystem ss, double k, iri=None): + self.ss = ss + # TODO? pre-allocate the matrix with numpy to make it transparent? + if iri is None: + self.lu = qpms_scatsys_build_modeproblem_matrix_full_LU( + NULL, NULL, ss.rawpointer(), k) + else: + self.lu = qpms_scatsys_build_modeproblem_matrix_irrep_packed_LU( + NULL, NULL, ss.rawpointer(), iri, k) + + def __dealloc__(self): + qpms_ss_LU_free(self.lu) + + property iri: + def __get__(self): + return None if self.lu.full else self.lu.iri + + def __call__(self, a_inc): + cdef size_t vlen + cdef qpms_iri_t iri = -1; + if self.lu.full: + vlen = self.lu.ss[0].fecv_size + if len(a_inc) != vlen: + raise ValueError("Length of a full coefficient vector has to be %d, not %d" + % (vlen, len(a_inc))) + else: + iri = self.lu.iri + vlen = self.lu.ss[0].saecv_sizes[iri] + if len(a_inc) != vlen: + raise ValueError("Length of a %d. irrep packed coefficient vector has to be %d, not %d" + % (iri, vlen, len(a_inc))) + a_inc = np.array(a_inc, dtype=complex, copy=False, order='C') + cdef const cdouble[::1] a_view = a_inc; + cdef np.ndarray f = np.empty((vlen,), dtype=complex, order='C') + cdef cdouble[::1] f_view = f + qpms_scatsys_scatter_solve(&f_view[0], &a_view[0], self.lu) + return f def tlm2uvswfi(t, l, m): diff --git a/qpms/qpms_cdefs.pxd b/qpms/qpms_cdefs.pxd index 28a758b..24b8b61 100644 --- a/qpms/qpms_cdefs.pxd +++ b/qpms/qpms_cdefs.pxd @@ -398,6 +398,20 @@ cdef extern from "scatsystem.h": const void *args, bint add) cdouble *qpms_scatsys_apply_Tmatrices_full(cdouble *target_full, const cdouble *inc_full, const qpms_scatsys_t *ss) - - + struct qpms_ss_LU: + const qpms_scatsys_t *ss + bint full + qpms_iri_t iri + cdouble *a + int *ipiv + void qpms_ss_LU_free(qpms_ss_LU lu) + qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_full_LU(cdouble *target, + int *target_piv, const qpms_scatsys_t *ss, double k) + qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_irrep_packed_LU(cdouble *target, + int *target_piv, const qpms_scatsys_t *ss, qpms_iri_t iri, double k) + qpms_ss_LU qpms_scatsys_modeproblem_matrix_full_factorise(cdouble *modeproblem_matrix_full, + int *target_piv, const qpms_scatsys_t *ss) + qpms_ss_LU qpms_scatsys_modeproblem_matrix_irrep_packed_factorise(cdouble *modeproblem_matrix_full, + int *target_piv, const qpms_scatsys_t *ss, qpms_iri_t iri) + cdouble *qpms_scatsys_scatter_solve(cdouble *target_f, const cdouble *a_inc, qpms_ss_LU ludata) diff --git a/qpms/scatsystem.c b/qpms/scatsystem.c index 630f6bf..8e7380d 100644 --- a/qpms/scatsystem.c +++ b/qpms/scatsystem.c @@ -1,4 +1,9 @@ #include +#define lapack_int int +#define lapack_complex_double complex double +#define lapack_complex_double_real(z) (creal(z)) +#define lapack_complex_double_imag(z) (cimag(z)) +#include #include #include #include "scatsystem.h" @@ -1738,3 +1743,63 @@ ccart3_t qpms_scatsys_eval_E_irrep(const qpms_scatsys_t *ss, } #endif +void qpms_ss_LU_free(qpms_ss_LU lu) { + free(lu.a); + free(lu.ipiv); +} + +qpms_ss_LU qpms_scatsys_modeproblem_matrix_full_factorise(complex double *mpmatrix_full, + int *target_piv, const qpms_scatsys_t *ss) { + QPMS_ENSURE(mpmatrix_full, "A non-NULL pointer to the pre-calculated mode matrix is required"); + if (!target_piv) QPMS_CRASHING_MALLOC(target_piv, ss->fecv_size * sizeof(int)); + QPMS_ENSURE_SUCCESS(LAPACKE_zgetrf(LAPACK_ROW_MAJOR, ss->fecv_size, ss->fecv_size, + mpmatrix_full, ss->fecv_size, target_piv)); + qpms_ss_LU lu; + lu.a = mpmatrix_full; + lu.ipiv = target_piv; + lu.ss = ss; + lu.full = true; + lu.iri = -1; + return lu; +} + +qpms_ss_LU qpms_scatsys_modeproblem_matrix_irrep_packed_factorise(complex double *mpmatrix_packed, + int *target_piv, const qpms_scatsys_t *ss, qpms_iri_t iri) { + QPMS_ENSURE(mpmatrix_packed, "A non-NULL pointer to the pre-calculated mode matrix is required"); + size_t n = ss->saecv_sizes[iri]; + if (!target_piv) QPMS_CRASHING_MALLOC(target_piv, n * sizeof(int)); + QPMS_ENSURE_SUCCESS(LAPACKE_zgetrf(LAPACK_ROW_MAJOR, n, n, + mpmatrix_packed, n, target_piv)); + qpms_ss_LU lu; + lu.a = mpmatrix_packed; + lu.ipiv = target_piv; + lu.ss = ss; + lu.full = false; + lu.iri = iri; + return lu; +} + +qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_full_LU( + complex double *target, int *target_piv, + const qpms_scatsys_t *ss, double k){ + target = qpms_scatsys_build_modeproblem_matrix_full(target, ss, k); + return qpms_scatsys_modeproblem_matrix_full_factorise(target, target_piv, ss); +} + +qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_irrep_packed_LU( + complex double *target, int *target_piv, + const qpms_scatsys_t *ss, qpms_iri_t iri, double k){ + target = qpms_scatsys_build_modeproblem_matrix_irrep_packed(target, ss, iri, k); + return qpms_scatsys_modeproblem_matrix_irrep_packed_factorise(target, target_piv, ss, iri); +} + +complex double *qpms_scatsys_scatter_solve( + complex double *f, const complex double *a_inc, qpms_ss_LU lu) { + const size_t n = lu.full ? lu.ss->fecv_size : lu.ss->saecv_sizes[lu.iri]; + if (!f) QPMS_CRASHING_MALLOC(f, n * sizeof(complex double)); + memcpy(f, a_inc, n*sizeof(complex double)); // It will be rewritten by zgetrs + QPMS_ENSURE_SUCCESS(LAPACKE_zgetrs(LAPACK_ROW_MAJOR, 'N' /*trans*/, n /*n*/, 1 /*nrhs number of right hand sides*/, + lu.a /*a*/, n /*lda*/, lu.ipiv /*ipiv*/, f/*b*/, 1 /*ldb; CHECKME*/)); + return f; +} + diff --git a/qpms/scatsystem.h b/qpms/scatsystem.h index 2daddb0..d55002a 100644 --- a/qpms/scatsystem.h +++ b/qpms/scatsystem.h @@ -14,7 +14,6 @@ #include "vswf.h" #include - /// Overrides the number of threads spawned by the paralellized functions. /** TODO MORE DOC which are those? */ void qpms_scatsystem_set_nthreads(long n); @@ -297,13 +296,53 @@ complex double *qpms_scatsys_build_modeproblem_matrix_irrep_packed_orbitorder_pa /// LU factorisation (LAPACKE_zgetrf) result holder. typedef struct qpms_ss_LU { + const qpms_scatsys_t *ss; + bool full; ///< true if full matrix; false if irrep-packed. + qpms_iri_t iri; ///< Irrep index if `full == false`. /// LU decomposition array. complex double *a; /// Pivot index array, size at least max(1,min(m, n)). - lapack_int *ipiv; + int *ipiv; } qpms_ss_LU; +void qpms_ss_LU_free(qpms_ss_LU); + +/// Builds an LU-factorised mode/scattering problem \f$ (I - TS) \f$ matrix from scratch. +qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_full_LU( + complex double *target, ///< Pre-allocated target array. Optional (if NULL, new one is allocated). + int *target_piv, ///< Pre-allocated pivot array. Optional (if NULL, new one is allocated). + const qpms_scatsys_t *ss, + /*COMPLEXIFY*/ double k ///< Wave number to use in the translation matrix. + ); + +/// Builds an irrep-packed LU-factorised mode/scattering problem matrix from scratch. +qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_irrep_packed_LU( + complex double *target, ///< Pre-allocated target array. Optional (if NULL, new one is allocated). + int *target_piv, ///< Pre-allocated pivot array. Optional (if NULL, new one is allocated). + const qpms_scatsys_t *ss, qpms_iri_t iri, + /*COMPLEXIFY*/ double k ///< Wave number to use in the translation matrix. + ); + +/// Computes LU factorisation of a pre-calculated mode/scattering problem matrix, replacing its contents. +qpms_ss_LU qpms_scatsys_modeproblem_matrix_full_factorise( + complex double *modeproblem_matrix_full, ///< Pre-calculated mode problem matrix (I-TS). Mandatory. + int *target_piv, ///< Pre-allocated pivot array. Optional (if NULL, new one is allocated). + const qpms_scatsys_t *ss + ); + +/// Computes LU factorisation of a pre-calculated irrep-packed mode/scattering problem matrix, replacing its contents. +qpms_ss_LU qpms_scatsys_modeproblem_matrix_irrep_packed_factorise( + complex double *modeproblem_matrix_irrep_packed, ///< Pre-calculated mode problem matrix (I-TS). Mandatory. + int *target_piv, ///< Pre-allocated pivot array. Optional (if NULL, new one is allocated). + const qpms_scatsys_t *ss, qpms_iri_t iri + ); + +/// Solves a (possibly partial, irrep-packed) scattering problem \f$ (I-TS)f = Ta_\mathrm{inc} \f$ using a pre-factorised \f$ (I-TS) \f$. +complex double *qpms_scatsys_scatter_solve( + complex double *target_f, ///< Target (full or irrep-packed, depending on `ludata.full`) array for \a f. If NULL, a new one is allocated. + const complex double *a_inc, ///< Incident field expansion coefficient vector \a a (full or irrep-packed, depending on `ludata.full`). + qpms_ss_LU ludata ///< Pre-factorised \f$ I - TS \f$ matrix data. + ); -void qpms_ss_LU_free(qpms_ss_LU *); /// NOT IMPLEMENTED Dumps a qpms_scatsys_t structure to a file. qpms_errno_t qpms_scatsys_dump(qpms_scatsys_t *ss, char *path);