Solve scattering problem using LU decomposition.

Former-commit-id: 32388ecc2da5d60a3b3616a765c0ecaed87fb4fd
This commit is contained in:
Marek Nečada 2019-07-23 15:50:31 +03:00
parent 5758c5d587
commit 0ffee1a073
4 changed files with 175 additions and 5 deletions

View File

@ -1681,6 +1681,8 @@ cdef class ScatteringSystem:
return ar
def planewave_full(self, k_cart, E_cart):
k_cart = np.array(k_cart)
E_cart = np.array(E_cart)
if k_cart.shape != (3,) or E_cart.shape != (3,):
raise ValueError("k_cart and E_cart must be ndarrays of shape (3,)")
cdef qpms_incfield_planewave_params_t p
@ -1709,7 +1711,57 @@ cdef class ScatteringSystem:
cdef cdouble[::1] target_view = target_np
qpms_scatsys_apply_Tmatrices_full(&target_view[0], &a_view[0], self.s)
return target_np
cdef qpms_scatsys_t *rawpointer(self):
return self.s
def scatter_solver(self, double k, iri=None):
return ScatteringMatrix(self, k, iri)
cdef class ScatteringMatrix:
'''
Wrapper over the C qpms_ss_LU structure that keeps the factorised mode problem matrix.
'''
cdef ScatteringSystem ss # Here we keep the reference to the parent scattering system
cdef qpms_ss_LU lu
def __cinit__(self, ScatteringSystem ss, double k, iri=None):
self.ss = ss
# TODO? pre-allocate the matrix with numpy to make it transparent?
if iri is None:
self.lu = qpms_scatsys_build_modeproblem_matrix_full_LU(
NULL, NULL, ss.rawpointer(), k)
else:
self.lu = qpms_scatsys_build_modeproblem_matrix_irrep_packed_LU(
NULL, NULL, ss.rawpointer(), iri, k)
def __dealloc__(self):
qpms_ss_LU_free(self.lu)
property iri:
def __get__(self):
return None if self.lu.full else self.lu.iri
def __call__(self, a_inc):
cdef size_t vlen
cdef qpms_iri_t iri = -1;
if self.lu.full:
vlen = self.lu.ss[0].fecv_size
if len(a_inc) != vlen:
raise ValueError("Length of a full coefficient vector has to be %d, not %d"
% (vlen, len(a_inc)))
else:
iri = self.lu.iri
vlen = self.lu.ss[0].saecv_sizes[iri]
if len(a_inc) != vlen:
raise ValueError("Length of a %d. irrep packed coefficient vector has to be %d, not %d"
% (iri, vlen, len(a_inc)))
a_inc = np.array(a_inc, dtype=complex, copy=False, order='C')
cdef const cdouble[::1] a_view = a_inc;
cdef np.ndarray f = np.empty((vlen,), dtype=complex, order='C')
cdef cdouble[::1] f_view = f
qpms_scatsys_scatter_solve(&f_view[0], &a_view[0], self.lu)
return f
def tlm2uvswfi(t, l, m):

View File

@ -398,6 +398,20 @@ cdef extern from "scatsystem.h":
const void *args, bint add)
cdouble *qpms_scatsys_apply_Tmatrices_full(cdouble *target_full, const cdouble *inc_full,
const qpms_scatsys_t *ss)
struct qpms_ss_LU:
const qpms_scatsys_t *ss
bint full
qpms_iri_t iri
cdouble *a
int *ipiv
void qpms_ss_LU_free(qpms_ss_LU lu)
qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_full_LU(cdouble *target,
int *target_piv, const qpms_scatsys_t *ss, double k)
qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_irrep_packed_LU(cdouble *target,
int *target_piv, const qpms_scatsys_t *ss, qpms_iri_t iri, double k)
qpms_ss_LU qpms_scatsys_modeproblem_matrix_full_factorise(cdouble *modeproblem_matrix_full,
int *target_piv, const qpms_scatsys_t *ss)
qpms_ss_LU qpms_scatsys_modeproblem_matrix_irrep_packed_factorise(cdouble *modeproblem_matrix_full,
int *target_piv, const qpms_scatsys_t *ss, qpms_iri_t iri)
cdouble *qpms_scatsys_scatter_solve(cdouble *target_f, const cdouble *a_inc, qpms_ss_LU ludata)

View File

@ -1,4 +1,9 @@
#include <stdlib.h>
#define lapack_int int
#define lapack_complex_double complex double
#define lapack_complex_double_real(z) (creal(z))
#define lapack_complex_double_imag(z) (cimag(z))
#include <lapacke.h>
#include <cblas.h>
#include <lapacke.h>
#include "scatsystem.h"
@ -1738,3 +1743,63 @@ ccart3_t qpms_scatsys_eval_E_irrep(const qpms_scatsys_t *ss,
}
#endif
void qpms_ss_LU_free(qpms_ss_LU lu) {
free(lu.a);
free(lu.ipiv);
}
qpms_ss_LU qpms_scatsys_modeproblem_matrix_full_factorise(complex double *mpmatrix_full,
int *target_piv, const qpms_scatsys_t *ss) {
QPMS_ENSURE(mpmatrix_full, "A non-NULL pointer to the pre-calculated mode matrix is required");
if (!target_piv) QPMS_CRASHING_MALLOC(target_piv, ss->fecv_size * sizeof(int));
QPMS_ENSURE_SUCCESS(LAPACKE_zgetrf(LAPACK_ROW_MAJOR, ss->fecv_size, ss->fecv_size,
mpmatrix_full, ss->fecv_size, target_piv));
qpms_ss_LU lu;
lu.a = mpmatrix_full;
lu.ipiv = target_piv;
lu.ss = ss;
lu.full = true;
lu.iri = -1;
return lu;
}
qpms_ss_LU qpms_scatsys_modeproblem_matrix_irrep_packed_factorise(complex double *mpmatrix_packed,
int *target_piv, const qpms_scatsys_t *ss, qpms_iri_t iri) {
QPMS_ENSURE(mpmatrix_packed, "A non-NULL pointer to the pre-calculated mode matrix is required");
size_t n = ss->saecv_sizes[iri];
if (!target_piv) QPMS_CRASHING_MALLOC(target_piv, n * sizeof(int));
QPMS_ENSURE_SUCCESS(LAPACKE_zgetrf(LAPACK_ROW_MAJOR, n, n,
mpmatrix_packed, n, target_piv));
qpms_ss_LU lu;
lu.a = mpmatrix_packed;
lu.ipiv = target_piv;
lu.ss = ss;
lu.full = false;
lu.iri = iri;
return lu;
}
qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_full_LU(
complex double *target, int *target_piv,
const qpms_scatsys_t *ss, double k){
target = qpms_scatsys_build_modeproblem_matrix_full(target, ss, k);
return qpms_scatsys_modeproblem_matrix_full_factorise(target, target_piv, ss);
}
qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_irrep_packed_LU(
complex double *target, int *target_piv,
const qpms_scatsys_t *ss, qpms_iri_t iri, double k){
target = qpms_scatsys_build_modeproblem_matrix_irrep_packed(target, ss, iri, k);
return qpms_scatsys_modeproblem_matrix_irrep_packed_factorise(target, target_piv, ss, iri);
}
complex double *qpms_scatsys_scatter_solve(
complex double *f, const complex double *a_inc, qpms_ss_LU lu) {
const size_t n = lu.full ? lu.ss->fecv_size : lu.ss->saecv_sizes[lu.iri];
if (!f) QPMS_CRASHING_MALLOC(f, n * sizeof(complex double));
memcpy(f, a_inc, n*sizeof(complex double)); // It will be rewritten by zgetrs
QPMS_ENSURE_SUCCESS(LAPACKE_zgetrs(LAPACK_ROW_MAJOR, 'N' /*trans*/, n /*n*/, 1 /*nrhs number of right hand sides*/,
lu.a /*a*/, n /*lda*/, lu.ipiv /*ipiv*/, f/*b*/, 1 /*ldb; CHECKME*/));
return f;
}

View File

@ -14,7 +14,6 @@
#include "vswf.h"
#include <stdbool.h>
/// Overrides the number of threads spawned by the paralellized functions.
/** TODO MORE DOC which are those? */
void qpms_scatsystem_set_nthreads(long n);
@ -297,13 +296,53 @@ complex double *qpms_scatsys_build_modeproblem_matrix_irrep_packed_orbitorder_pa
/// LU factorisation (LAPACKE_zgetrf) result holder.
typedef struct qpms_ss_LU {
const qpms_scatsys_t *ss;
bool full; ///< true if full matrix; false if irrep-packed.
qpms_iri_t iri; ///< Irrep index if `full == false`.
/// LU decomposition array.
complex double *a;
/// Pivot index array, size at least max(1,min(m, n)).
lapack_int *ipiv;
int *ipiv;
} qpms_ss_LU;
void qpms_ss_LU_free(qpms_ss_LU);
/// Builds an LU-factorised mode/scattering problem \f$ (I - TS) \f$ matrix from scratch.
qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_full_LU(
complex double *target, ///< Pre-allocated target array. Optional (if NULL, new one is allocated).
int *target_piv, ///< Pre-allocated pivot array. Optional (if NULL, new one is allocated).
const qpms_scatsys_t *ss,
/*COMPLEXIFY*/ double k ///< Wave number to use in the translation matrix.
);
/// Builds an irrep-packed LU-factorised mode/scattering problem matrix from scratch.
qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_irrep_packed_LU(
complex double *target, ///< Pre-allocated target array. Optional (if NULL, new one is allocated).
int *target_piv, ///< Pre-allocated pivot array. Optional (if NULL, new one is allocated).
const qpms_scatsys_t *ss, qpms_iri_t iri,
/*COMPLEXIFY*/ double k ///< Wave number to use in the translation matrix.
);
/// Computes LU factorisation of a pre-calculated mode/scattering problem matrix, replacing its contents.
qpms_ss_LU qpms_scatsys_modeproblem_matrix_full_factorise(
complex double *modeproblem_matrix_full, ///< Pre-calculated mode problem matrix (I-TS). Mandatory.
int *target_piv, ///< Pre-allocated pivot array. Optional (if NULL, new one is allocated).
const qpms_scatsys_t *ss
);
/// Computes LU factorisation of a pre-calculated irrep-packed mode/scattering problem matrix, replacing its contents.
qpms_ss_LU qpms_scatsys_modeproblem_matrix_irrep_packed_factorise(
complex double *modeproblem_matrix_irrep_packed, ///< Pre-calculated mode problem matrix (I-TS). Mandatory.
int *target_piv, ///< Pre-allocated pivot array. Optional (if NULL, new one is allocated).
const qpms_scatsys_t *ss, qpms_iri_t iri
);
/// Solves a (possibly partial, irrep-packed) scattering problem \f$ (I-TS)f = Ta_\mathrm{inc} \f$ using a pre-factorised \f$ (I-TS) \f$.
complex double *qpms_scatsys_scatter_solve(
complex double *target_f, ///< Target (full or irrep-packed, depending on `ludata.full`) array for \a f. If NULL, a new one is allocated.
const complex double *a_inc, ///< Incident field expansion coefficient vector \a a (full or irrep-packed, depending on `ludata.full`).
qpms_ss_LU ludata ///< Pre-factorised \f$ I - TS \f$ matrix data.
);
void qpms_ss_LU_free(qpms_ss_LU *);
/// NOT IMPLEMENTED Dumps a qpms_scatsys_t structure to a file.
qpms_errno_t qpms_scatsys_dump(qpms_scatsys_t *ss, char *path);