Solve scattering problem using LU decomposition.
Former-commit-id: 32388ecc2da5d60a3b3616a765c0ecaed87fb4fd
This commit is contained in:
parent
5758c5d587
commit
0ffee1a073
|
@ -1681,6 +1681,8 @@ cdef class ScatteringSystem:
|
||||||
return ar
|
return ar
|
||||||
|
|
||||||
def planewave_full(self, k_cart, E_cart):
|
def planewave_full(self, k_cart, E_cart):
|
||||||
|
k_cart = np.array(k_cart)
|
||||||
|
E_cart = np.array(E_cart)
|
||||||
if k_cart.shape != (3,) or E_cart.shape != (3,):
|
if k_cart.shape != (3,) or E_cart.shape != (3,):
|
||||||
raise ValueError("k_cart and E_cart must be ndarrays of shape (3,)")
|
raise ValueError("k_cart and E_cart must be ndarrays of shape (3,)")
|
||||||
cdef qpms_incfield_planewave_params_t p
|
cdef qpms_incfield_planewave_params_t p
|
||||||
|
@ -1709,7 +1711,57 @@ cdef class ScatteringSystem:
|
||||||
cdef cdouble[::1] target_view = target_np
|
cdef cdouble[::1] target_view = target_np
|
||||||
qpms_scatsys_apply_Tmatrices_full(&target_view[0], &a_view[0], self.s)
|
qpms_scatsys_apply_Tmatrices_full(&target_view[0], &a_view[0], self.s)
|
||||||
return target_np
|
return target_np
|
||||||
|
|
||||||
|
cdef qpms_scatsys_t *rawpointer(self):
|
||||||
|
return self.s
|
||||||
|
|
||||||
|
def scatter_solver(self, double k, iri=None):
|
||||||
|
return ScatteringMatrix(self, k, iri)
|
||||||
|
|
||||||
|
cdef class ScatteringMatrix:
|
||||||
|
'''
|
||||||
|
Wrapper over the C qpms_ss_LU structure that keeps the factorised mode problem matrix.
|
||||||
|
'''
|
||||||
|
cdef ScatteringSystem ss # Here we keep the reference to the parent scattering system
|
||||||
|
cdef qpms_ss_LU lu
|
||||||
|
|
||||||
|
def __cinit__(self, ScatteringSystem ss, double k, iri=None):
|
||||||
|
self.ss = ss
|
||||||
|
# TODO? pre-allocate the matrix with numpy to make it transparent?
|
||||||
|
if iri is None:
|
||||||
|
self.lu = qpms_scatsys_build_modeproblem_matrix_full_LU(
|
||||||
|
NULL, NULL, ss.rawpointer(), k)
|
||||||
|
else:
|
||||||
|
self.lu = qpms_scatsys_build_modeproblem_matrix_irrep_packed_LU(
|
||||||
|
NULL, NULL, ss.rawpointer(), iri, k)
|
||||||
|
|
||||||
|
def __dealloc__(self):
|
||||||
|
qpms_ss_LU_free(self.lu)
|
||||||
|
|
||||||
|
property iri:
|
||||||
|
def __get__(self):
|
||||||
|
return None if self.lu.full else self.lu.iri
|
||||||
|
|
||||||
|
def __call__(self, a_inc):
|
||||||
|
cdef size_t vlen
|
||||||
|
cdef qpms_iri_t iri = -1;
|
||||||
|
if self.lu.full:
|
||||||
|
vlen = self.lu.ss[0].fecv_size
|
||||||
|
if len(a_inc) != vlen:
|
||||||
|
raise ValueError("Length of a full coefficient vector has to be %d, not %d"
|
||||||
|
% (vlen, len(a_inc)))
|
||||||
|
else:
|
||||||
|
iri = self.lu.iri
|
||||||
|
vlen = self.lu.ss[0].saecv_sizes[iri]
|
||||||
|
if len(a_inc) != vlen:
|
||||||
|
raise ValueError("Length of a %d. irrep packed coefficient vector has to be %d, not %d"
|
||||||
|
% (iri, vlen, len(a_inc)))
|
||||||
|
a_inc = np.array(a_inc, dtype=complex, copy=False, order='C')
|
||||||
|
cdef const cdouble[::1] a_view = a_inc;
|
||||||
|
cdef np.ndarray f = np.empty((vlen,), dtype=complex, order='C')
|
||||||
|
cdef cdouble[::1] f_view = f
|
||||||
|
qpms_scatsys_scatter_solve(&f_view[0], &a_view[0], self.lu)
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
def tlm2uvswfi(t, l, m):
|
def tlm2uvswfi(t, l, m):
|
||||||
|
|
|
@ -398,6 +398,20 @@ cdef extern from "scatsystem.h":
|
||||||
const void *args, bint add)
|
const void *args, bint add)
|
||||||
cdouble *qpms_scatsys_apply_Tmatrices_full(cdouble *target_full, const cdouble *inc_full,
|
cdouble *qpms_scatsys_apply_Tmatrices_full(cdouble *target_full, const cdouble *inc_full,
|
||||||
const qpms_scatsys_t *ss)
|
const qpms_scatsys_t *ss)
|
||||||
|
struct qpms_ss_LU:
|
||||||
|
const qpms_scatsys_t *ss
|
||||||
|
bint full
|
||||||
|
qpms_iri_t iri
|
||||||
|
cdouble *a
|
||||||
|
int *ipiv
|
||||||
|
void qpms_ss_LU_free(qpms_ss_LU lu)
|
||||||
|
qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_full_LU(cdouble *target,
|
||||||
|
int *target_piv, const qpms_scatsys_t *ss, double k)
|
||||||
|
qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_irrep_packed_LU(cdouble *target,
|
||||||
|
int *target_piv, const qpms_scatsys_t *ss, qpms_iri_t iri, double k)
|
||||||
|
qpms_ss_LU qpms_scatsys_modeproblem_matrix_full_factorise(cdouble *modeproblem_matrix_full,
|
||||||
|
int *target_piv, const qpms_scatsys_t *ss)
|
||||||
|
qpms_ss_LU qpms_scatsys_modeproblem_matrix_irrep_packed_factorise(cdouble *modeproblem_matrix_full,
|
||||||
|
int *target_piv, const qpms_scatsys_t *ss, qpms_iri_t iri)
|
||||||
|
cdouble *qpms_scatsys_scatter_solve(cdouble *target_f, const cdouble *a_inc, qpms_ss_LU ludata)
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,9 @@
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
|
#define lapack_int int
|
||||||
|
#define lapack_complex_double complex double
|
||||||
|
#define lapack_complex_double_real(z) (creal(z))
|
||||||
|
#define lapack_complex_double_imag(z) (cimag(z))
|
||||||
|
#include <lapacke.h>
|
||||||
#include <cblas.h>
|
#include <cblas.h>
|
||||||
#include <lapacke.h>
|
#include <lapacke.h>
|
||||||
#include "scatsystem.h"
|
#include "scatsystem.h"
|
||||||
|
@ -1738,3 +1743,63 @@ ccart3_t qpms_scatsys_eval_E_irrep(const qpms_scatsys_t *ss,
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
void qpms_ss_LU_free(qpms_ss_LU lu) {
|
||||||
|
free(lu.a);
|
||||||
|
free(lu.ipiv);
|
||||||
|
}
|
||||||
|
|
||||||
|
qpms_ss_LU qpms_scatsys_modeproblem_matrix_full_factorise(complex double *mpmatrix_full,
|
||||||
|
int *target_piv, const qpms_scatsys_t *ss) {
|
||||||
|
QPMS_ENSURE(mpmatrix_full, "A non-NULL pointer to the pre-calculated mode matrix is required");
|
||||||
|
if (!target_piv) QPMS_CRASHING_MALLOC(target_piv, ss->fecv_size * sizeof(int));
|
||||||
|
QPMS_ENSURE_SUCCESS(LAPACKE_zgetrf(LAPACK_ROW_MAJOR, ss->fecv_size, ss->fecv_size,
|
||||||
|
mpmatrix_full, ss->fecv_size, target_piv));
|
||||||
|
qpms_ss_LU lu;
|
||||||
|
lu.a = mpmatrix_full;
|
||||||
|
lu.ipiv = target_piv;
|
||||||
|
lu.ss = ss;
|
||||||
|
lu.full = true;
|
||||||
|
lu.iri = -1;
|
||||||
|
return lu;
|
||||||
|
}
|
||||||
|
|
||||||
|
qpms_ss_LU qpms_scatsys_modeproblem_matrix_irrep_packed_factorise(complex double *mpmatrix_packed,
|
||||||
|
int *target_piv, const qpms_scatsys_t *ss, qpms_iri_t iri) {
|
||||||
|
QPMS_ENSURE(mpmatrix_packed, "A non-NULL pointer to the pre-calculated mode matrix is required");
|
||||||
|
size_t n = ss->saecv_sizes[iri];
|
||||||
|
if (!target_piv) QPMS_CRASHING_MALLOC(target_piv, n * sizeof(int));
|
||||||
|
QPMS_ENSURE_SUCCESS(LAPACKE_zgetrf(LAPACK_ROW_MAJOR, n, n,
|
||||||
|
mpmatrix_packed, n, target_piv));
|
||||||
|
qpms_ss_LU lu;
|
||||||
|
lu.a = mpmatrix_packed;
|
||||||
|
lu.ipiv = target_piv;
|
||||||
|
lu.ss = ss;
|
||||||
|
lu.full = false;
|
||||||
|
lu.iri = iri;
|
||||||
|
return lu;
|
||||||
|
}
|
||||||
|
|
||||||
|
qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_full_LU(
|
||||||
|
complex double *target, int *target_piv,
|
||||||
|
const qpms_scatsys_t *ss, double k){
|
||||||
|
target = qpms_scatsys_build_modeproblem_matrix_full(target, ss, k);
|
||||||
|
return qpms_scatsys_modeproblem_matrix_full_factorise(target, target_piv, ss);
|
||||||
|
}
|
||||||
|
|
||||||
|
qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_irrep_packed_LU(
|
||||||
|
complex double *target, int *target_piv,
|
||||||
|
const qpms_scatsys_t *ss, qpms_iri_t iri, double k){
|
||||||
|
target = qpms_scatsys_build_modeproblem_matrix_irrep_packed(target, ss, iri, k);
|
||||||
|
return qpms_scatsys_modeproblem_matrix_irrep_packed_factorise(target, target_piv, ss, iri);
|
||||||
|
}
|
||||||
|
|
||||||
|
complex double *qpms_scatsys_scatter_solve(
|
||||||
|
complex double *f, const complex double *a_inc, qpms_ss_LU lu) {
|
||||||
|
const size_t n = lu.full ? lu.ss->fecv_size : lu.ss->saecv_sizes[lu.iri];
|
||||||
|
if (!f) QPMS_CRASHING_MALLOC(f, n * sizeof(complex double));
|
||||||
|
memcpy(f, a_inc, n*sizeof(complex double)); // It will be rewritten by zgetrs
|
||||||
|
QPMS_ENSURE_SUCCESS(LAPACKE_zgetrs(LAPACK_ROW_MAJOR, 'N' /*trans*/, n /*n*/, 1 /*nrhs number of right hand sides*/,
|
||||||
|
lu.a /*a*/, n /*lda*/, lu.ipiv /*ipiv*/, f/*b*/, 1 /*ldb; CHECKME*/));
|
||||||
|
return f;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
#include "vswf.h"
|
#include "vswf.h"
|
||||||
#include <stdbool.h>
|
#include <stdbool.h>
|
||||||
|
|
||||||
|
|
||||||
/// Overrides the number of threads spawned by the paralellized functions.
|
/// Overrides the number of threads spawned by the paralellized functions.
|
||||||
/** TODO MORE DOC which are those? */
|
/** TODO MORE DOC which are those? */
|
||||||
void qpms_scatsystem_set_nthreads(long n);
|
void qpms_scatsystem_set_nthreads(long n);
|
||||||
|
@ -297,13 +296,53 @@ complex double *qpms_scatsys_build_modeproblem_matrix_irrep_packed_orbitorder_pa
|
||||||
|
|
||||||
/// LU factorisation (LAPACKE_zgetrf) result holder.
|
/// LU factorisation (LAPACKE_zgetrf) result holder.
|
||||||
typedef struct qpms_ss_LU {
|
typedef struct qpms_ss_LU {
|
||||||
|
const qpms_scatsys_t *ss;
|
||||||
|
bool full; ///< true if full matrix; false if irrep-packed.
|
||||||
|
qpms_iri_t iri; ///< Irrep index if `full == false`.
|
||||||
/// LU decomposition array.
|
/// LU decomposition array.
|
||||||
complex double *a;
|
complex double *a;
|
||||||
/// Pivot index array, size at least max(1,min(m, n)).
|
/// Pivot index array, size at least max(1,min(m, n)).
|
||||||
lapack_int *ipiv;
|
int *ipiv;
|
||||||
} qpms_ss_LU;
|
} qpms_ss_LU;
|
||||||
|
void qpms_ss_LU_free(qpms_ss_LU);
|
||||||
|
|
||||||
|
/// Builds an LU-factorised mode/scattering problem \f$ (I - TS) \f$ matrix from scratch.
|
||||||
|
qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_full_LU(
|
||||||
|
complex double *target, ///< Pre-allocated target array. Optional (if NULL, new one is allocated).
|
||||||
|
int *target_piv, ///< Pre-allocated pivot array. Optional (if NULL, new one is allocated).
|
||||||
|
const qpms_scatsys_t *ss,
|
||||||
|
/*COMPLEXIFY*/ double k ///< Wave number to use in the translation matrix.
|
||||||
|
);
|
||||||
|
|
||||||
|
/// Builds an irrep-packed LU-factorised mode/scattering problem matrix from scratch.
|
||||||
|
qpms_ss_LU qpms_scatsys_build_modeproblem_matrix_irrep_packed_LU(
|
||||||
|
complex double *target, ///< Pre-allocated target array. Optional (if NULL, new one is allocated).
|
||||||
|
int *target_piv, ///< Pre-allocated pivot array. Optional (if NULL, new one is allocated).
|
||||||
|
const qpms_scatsys_t *ss, qpms_iri_t iri,
|
||||||
|
/*COMPLEXIFY*/ double k ///< Wave number to use in the translation matrix.
|
||||||
|
);
|
||||||
|
|
||||||
|
/// Computes LU factorisation of a pre-calculated mode/scattering problem matrix, replacing its contents.
|
||||||
|
qpms_ss_LU qpms_scatsys_modeproblem_matrix_full_factorise(
|
||||||
|
complex double *modeproblem_matrix_full, ///< Pre-calculated mode problem matrix (I-TS). Mandatory.
|
||||||
|
int *target_piv, ///< Pre-allocated pivot array. Optional (if NULL, new one is allocated).
|
||||||
|
const qpms_scatsys_t *ss
|
||||||
|
);
|
||||||
|
|
||||||
|
/// Computes LU factorisation of a pre-calculated irrep-packed mode/scattering problem matrix, replacing its contents.
|
||||||
|
qpms_ss_LU qpms_scatsys_modeproblem_matrix_irrep_packed_factorise(
|
||||||
|
complex double *modeproblem_matrix_irrep_packed, ///< Pre-calculated mode problem matrix (I-TS). Mandatory.
|
||||||
|
int *target_piv, ///< Pre-allocated pivot array. Optional (if NULL, new one is allocated).
|
||||||
|
const qpms_scatsys_t *ss, qpms_iri_t iri
|
||||||
|
);
|
||||||
|
|
||||||
|
/// Solves a (possibly partial, irrep-packed) scattering problem \f$ (I-TS)f = Ta_\mathrm{inc} \f$ using a pre-factorised \f$ (I-TS) \f$.
|
||||||
|
complex double *qpms_scatsys_scatter_solve(
|
||||||
|
complex double *target_f, ///< Target (full or irrep-packed, depending on `ludata.full`) array for \a f. If NULL, a new one is allocated.
|
||||||
|
const complex double *a_inc, ///< Incident field expansion coefficient vector \a a (full or irrep-packed, depending on `ludata.full`).
|
||||||
|
qpms_ss_LU ludata ///< Pre-factorised \f$ I - TS \f$ matrix data.
|
||||||
|
);
|
||||||
|
|
||||||
void qpms_ss_LU_free(qpms_ss_LU *);
|
|
||||||
|
|
||||||
/// NOT IMPLEMENTED Dumps a qpms_scatsys_t structure to a file.
|
/// NOT IMPLEMENTED Dumps a qpms_scatsys_t structure to a file.
|
||||||
qpms_errno_t qpms_scatsys_dump(qpms_scatsys_t *ss, char *path);
|
qpms_errno_t qpms_scatsys_dump(qpms_scatsys_t *ss, char *path);
|
||||||
|
|
Loading…
Reference in New Issue