Solve scattering problem using LU decomposition.
Former-commit-id: 32388ecc2da5d60a3b3616a765c0ecaed87fb4fd
This commit is contained in:
parent
5758c5d587
commit
0ffee1a073
|
@ -1681,6 +1681,8 @@ cdef class ScatteringSystem:
|
|||
return ar
|
||||
|
||||
def planewave_full(self, k_cart, E_cart):
|
||||
k_cart = np.array(k_cart)
|
||||
E_cart = np.array(E_cart)
|
||||
if k_cart.shape != (3,) or E_cart.shape != (3,):
|
||||
raise ValueError("k_cart and E_cart must be ndarrays of shape (3,)")
|
||||
cdef qpms_incfield_planewave_params_t p
|
||||
|
@ -1709,7 +1711,57 @@ cdef class ScatteringSystem:
|
|||
cdef cdouble[::1] target_view = target_np
|
||||
qpms_scatsys_apply_Tmatrices_full(&target_view[0], &a_view[0], self.s)
|
||||
return target_np
|
||||
|
||||
cdef qpms_scatsys_t *rawpointer(self):
|
||||
return self.s
|
||||
|
||||
def scatter_solver(self, double k, iri=None):
|
||||
return ScatteringMatrix(self, k, iri)
|
||||
|
||||
cdef class ScatteringMatrix:
|
||||
'''
|
||||
Wrapper over the C qpms_ss_LU structure that keeps the factorised mode problem matrix.
|
||||
'''
|
||||
cdef ScatteringSystem ss # Here we keep the reference to the parent scattering system
|
||||
cdef qpms_ss_LU lu
|
||||
|
||||
def __cinit__(self, ScatteringSystem ss, double k, iri=None):
|
||||
self.ss = ss
|
||||
# TODO? pre-allocate the matrix with numpy to make it transparent?
|
||||
if iri is None:
|
||||
self.lu = qpms_scatsys_build_modeproblem_matrix_full_LU(
|
||||
NULL, NULL, ss.rawpointer(), k)
|
||||
else:
|
||||
self.lu = qpms_scatsys_build_modeproblem_matrix_irrep_packed_LU(
|
||||
NULL, NULL, ss.rawpointer(), iri, k)
|
||||
|
||||
def __dealloc__(self):
|
||||
qpms_ss_LU_free(self.lu)
|
||||
|
||||
property iri:
|
||||
def __get__(self):
|
||||
return None if self.lu.full else self.lu.iri
|
||||
|
||||
def __call__(self, a_inc):
|
||||
cdef size_t vlen
|
||||
cdef qpms_iri_t iri = -1;
|
||||
if self.lu.full:
|
||||
vlen = self.lu.ss[0].fecv_size
|
||||
if len(a_inc) != vlen:
|
||||
raise ValueError("Length of a full coefficient vector has to be %d, not %d"
|
||||
% (vlen, len(a_inc)))
|
||||
else:
|
||||
iri = self.lu.iri
|
||||
vlen = self.lu.ss[0].saecv_sizes[iri]
|
||||
if len(a_inc) != vlen:
|
||||
raise ValueError("Length of a %d. irrep packed coefficient vector has to be %d, not %d"
|
||||
% (iri, vlen, len(a_inc)))
|
||||
a_inc = np.array(a_inc, dtype=complex, copy=False, order='C')
|
||||
cdef const cdouble[::1] a_view = a_inc;
|
||||
cdef np.ndarray f = np.empty((vlen,), dtype=complex, order='C')
|
||||
cdef cdouble[::1] f_view = f
|
||||
qpms_scatsys_scatter_solve(&f_view[0], &a_view[0], self.lu)
|
||||
return f
|
||||
|
||||
|
||||
def tlm2uvswfi(t, l, m):
|
||||
|
|
|
@ -398,6 +398,20 @@ cdef extern from "scatsystem.h":
|
|||
const void *args, bint add)
|
||||
cdouble *qpms_scatsys_apply_Tmatrices_full(cdouble *target_full, const cdouble *inc_full,
|
||||
const qpms_scatsys_t *ss)
|
||||
|
||||
|
||||
struct qpms_ss_LU:
|
||||
const qpms_scatsys_t *ss
|
||||
bint full
|
||||
qpms_iri_t iri
|
||||
cdouble *a
|
||||
int *ipiv
|
||||
void qpms_ss_LU_free(qpms_ss_LU lu)
|
||||
qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_full_LU(cdouble *target,
|
||||
int *target_piv, const qpms_scatsys_t *ss, double k)
|
||||
qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_irrep_packed_LU(cdouble *target,
|
||||
int *target_piv, const qpms_scatsys_t *ss, qpms_iri_t iri, double k)
|
||||
qpms_ss_LU qpms_scatsys_modeproblem_matrix_full_factorise(cdouble *modeproblem_matrix_full,
|
||||
int *target_piv, const qpms_scatsys_t *ss)
|
||||
qpms_ss_LU qpms_scatsys_modeproblem_matrix_irrep_packed_factorise(cdouble *modeproblem_matrix_full,
|
||||
int *target_piv, const qpms_scatsys_t *ss, qpms_iri_t iri)
|
||||
cdouble *qpms_scatsys_scatter_solve(cdouble *target_f, const cdouble *a_inc, qpms_ss_LU ludata)
|
||||
|
||||
|
|
|
@ -1,4 +1,9 @@
|
|||
#include <stdlib.h>
|
||||
#define lapack_int int
|
||||
#define lapack_complex_double complex double
|
||||
#define lapack_complex_double_real(z) (creal(z))
|
||||
#define lapack_complex_double_imag(z) (cimag(z))
|
||||
#include <lapacke.h>
|
||||
#include <cblas.h>
|
||||
#include <lapacke.h>
|
||||
#include "scatsystem.h"
|
||||
|
@ -1738,3 +1743,63 @@ ccart3_t qpms_scatsys_eval_E_irrep(const qpms_scatsys_t *ss,
|
|||
}
|
||||
#endif
|
||||
|
||||
void qpms_ss_LU_free(qpms_ss_LU lu) {
|
||||
free(lu.a);
|
||||
free(lu.ipiv);
|
||||
}
|
||||
|
||||
qpms_ss_LU qpms_scatsys_modeproblem_matrix_full_factorise(complex double *mpmatrix_full,
|
||||
int *target_piv, const qpms_scatsys_t *ss) {
|
||||
QPMS_ENSURE(mpmatrix_full, "A non-NULL pointer to the pre-calculated mode matrix is required");
|
||||
if (!target_piv) QPMS_CRASHING_MALLOC(target_piv, ss->fecv_size * sizeof(int));
|
||||
QPMS_ENSURE_SUCCESS(LAPACKE_zgetrf(LAPACK_ROW_MAJOR, ss->fecv_size, ss->fecv_size,
|
||||
mpmatrix_full, ss->fecv_size, target_piv));
|
||||
qpms_ss_LU lu;
|
||||
lu.a = mpmatrix_full;
|
||||
lu.ipiv = target_piv;
|
||||
lu.ss = ss;
|
||||
lu.full = true;
|
||||
lu.iri = -1;
|
||||
return lu;
|
||||
}
|
||||
|
||||
qpms_ss_LU qpms_scatsys_modeproblem_matrix_irrep_packed_factorise(complex double *mpmatrix_packed,
|
||||
int *target_piv, const qpms_scatsys_t *ss, qpms_iri_t iri) {
|
||||
QPMS_ENSURE(mpmatrix_packed, "A non-NULL pointer to the pre-calculated mode matrix is required");
|
||||
size_t n = ss->saecv_sizes[iri];
|
||||
if (!target_piv) QPMS_CRASHING_MALLOC(target_piv, n * sizeof(int));
|
||||
QPMS_ENSURE_SUCCESS(LAPACKE_zgetrf(LAPACK_ROW_MAJOR, n, n,
|
||||
mpmatrix_packed, n, target_piv));
|
||||
qpms_ss_LU lu;
|
||||
lu.a = mpmatrix_packed;
|
||||
lu.ipiv = target_piv;
|
||||
lu.ss = ss;
|
||||
lu.full = false;
|
||||
lu.iri = iri;
|
||||
return lu;
|
||||
}
|
||||
|
||||
qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_full_LU(
|
||||
complex double *target, int *target_piv,
|
||||
const qpms_scatsys_t *ss, double k){
|
||||
target = qpms_scatsys_build_modeproblem_matrix_full(target, ss, k);
|
||||
return qpms_scatsys_modeproblem_matrix_full_factorise(target, target_piv, ss);
|
||||
}
|
||||
|
||||
qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_irrep_packed_LU(
|
||||
complex double *target, int *target_piv,
|
||||
const qpms_scatsys_t *ss, qpms_iri_t iri, double k){
|
||||
target = qpms_scatsys_build_modeproblem_matrix_irrep_packed(target, ss, iri, k);
|
||||
return qpms_scatsys_modeproblem_matrix_irrep_packed_factorise(target, target_piv, ss, iri);
|
||||
}
|
||||
|
||||
complex double *qpms_scatsys_scatter_solve(
|
||||
complex double *f, const complex double *a_inc, qpms_ss_LU lu) {
|
||||
const size_t n = lu.full ? lu.ss->fecv_size : lu.ss->saecv_sizes[lu.iri];
|
||||
if (!f) QPMS_CRASHING_MALLOC(f, n * sizeof(complex double));
|
||||
memcpy(f, a_inc, n*sizeof(complex double)); // It will be rewritten by zgetrs
|
||||
QPMS_ENSURE_SUCCESS(LAPACKE_zgetrs(LAPACK_ROW_MAJOR, 'N' /*trans*/, n /*n*/, 1 /*nrhs number of right hand sides*/,
|
||||
lu.a /*a*/, n /*lda*/, lu.ipiv /*ipiv*/, f/*b*/, 1 /*ldb; CHECKME*/));
|
||||
return f;
|
||||
}
|
||||
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
#include "vswf.h"
|
||||
#include <stdbool.h>
|
||||
|
||||
|
||||
/// Overrides the number of threads spawned by the paralellized functions.
|
||||
/** TODO MORE DOC which are those? */
|
||||
void qpms_scatsystem_set_nthreads(long n);
|
||||
|
@ -297,13 +296,53 @@ complex double *qpms_scatsys_build_modeproblem_matrix_irrep_packed_orbitorder_pa
|
|||
|
||||
/// LU factorisation (LAPACKE_zgetrf) result holder.
|
||||
typedef struct qpms_ss_LU {
|
||||
const qpms_scatsys_t *ss;
|
||||
bool full; ///< true if full matrix; false if irrep-packed.
|
||||
qpms_iri_t iri; ///< Irrep index if `full == false`.
|
||||
/// LU decomposition array.
|
||||
complex double *a;
|
||||
/// Pivot index array, size at least max(1,min(m, n)).
|
||||
lapack_int *ipiv;
|
||||
int *ipiv;
|
||||
} qpms_ss_LU;
|
||||
void qpms_ss_LU_free(qpms_ss_LU);
|
||||
|
||||
/// Builds an LU-factorised mode/scattering problem \f$ (I - TS) \f$ matrix from scratch.
|
||||
qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_full_LU(
|
||||
complex double *target, ///< Pre-allocated target array. Optional (if NULL, new one is allocated).
|
||||
int *target_piv, ///< Pre-allocated pivot array. Optional (if NULL, new one is allocated).
|
||||
const qpms_scatsys_t *ss,
|
||||
/*COMPLEXIFY*/ double k ///< Wave number to use in the translation matrix.
|
||||
);
|
||||
|
||||
/// Builds an irrep-packed LU-factorised mode/scattering problem matrix from scratch.
|
||||
qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_irrep_packed_LU(
|
||||
complex double *target, ///< Pre-allocated target array. Optional (if NULL, new one is allocated).
|
||||
int *target_piv, ///< Pre-allocated pivot array. Optional (if NULL, new one is allocated).
|
||||
const qpms_scatsys_t *ss, qpms_iri_t iri,
|
||||
/*COMPLEXIFY*/ double k ///< Wave number to use in the translation matrix.
|
||||
);
|
||||
|
||||
/// Computes LU factorisation of a pre-calculated mode/scattering problem matrix, replacing its contents.
|
||||
qpms_ss_LU qpms_scatsys_modeproblem_matrix_full_factorise(
|
||||
complex double *modeproblem_matrix_full, ///< Pre-calculated mode problem matrix (I-TS). Mandatory.
|
||||
int *target_piv, ///< Pre-allocated pivot array. Optional (if NULL, new one is allocated).
|
||||
const qpms_scatsys_t *ss
|
||||
);
|
||||
|
||||
/// Computes LU factorisation of a pre-calculated irrep-packed mode/scattering problem matrix, replacing its contents.
|
||||
qpms_ss_LU qpms_scatsys_modeproblem_matrix_irrep_packed_factorise(
|
||||
complex double *modeproblem_matrix_irrep_packed, ///< Pre-calculated mode problem matrix (I-TS). Mandatory.
|
||||
int *target_piv, ///< Pre-allocated pivot array. Optional (if NULL, new one is allocated).
|
||||
const qpms_scatsys_t *ss, qpms_iri_t iri
|
||||
);
|
||||
|
||||
/// Solves a (possibly partial, irrep-packed) scattering problem \f$ (I-TS)f = Ta_\mathrm{inc} \f$ using a pre-factorised \f$ (I-TS) \f$.
|
||||
complex double *qpms_scatsys_scatter_solve(
|
||||
complex double *target_f, ///< Target (full or irrep-packed, depending on `ludata.full`) array for \a f. If NULL, a new one is allocated.
|
||||
const complex double *a_inc, ///< Incident field expansion coefficient vector \a a (full or irrep-packed, depending on `ludata.full`).
|
||||
qpms_ss_LU ludata ///< Pre-factorised \f$ I - TS \f$ matrix data.
|
||||
);
|
||||
|
||||
void qpms_ss_LU_free(qpms_ss_LU *);
|
||||
|
||||
/// NOT IMPLEMENTED Dumps a qpms_scatsys_t structure to a file.
|
||||
qpms_errno_t qpms_scatsys_dump(qpms_scatsys_t *ss, char *path);
|
||||
|
|
Loading…
Reference in New Issue