Attempt to parallelize scattered_E cython methods.

This commit is contained in:
Marek Nečada 2020-07-03 14:20:07 +03:00
parent 1328077490
commit 17824b062e
2 changed files with 45 additions and 40 deletions

View File

@ -17,6 +17,8 @@ from .cymaterials cimport EpsMuGenerator, EpsMu
from libc.stdlib cimport malloc, free, calloc
import warnings
from cython.parallel import prange, parallel
from cython import boundscheck, wraparound
# Set custom GSL error handler. N.B. this is obviously not thread-safe.
cdef char *pgsl_err_reason
@ -901,6 +903,7 @@ cdef class ScatteringSystem:
return retdict
@boundscheck(False)
def scattered_E(self, cdouble wavenumber, scatcoeffvector_full, evalpos, bint alt=False, btyp=BesselType.HANKEL_PLUS):
cdef qpms_bessel_t btyp_c = BesselType(btyp)
evalpos = np.array(evalpos, dtype=float, copy=False)
@ -912,8 +915,9 @@ cdef class ScatteringSystem:
cdef np.ndarray[complex, ndim=2] results = np.empty((evalpos_a.shape[0],3), dtype=complex)
cdef ccart3_t res
cdef cart3_t pos
cdef size_t i
for i in range(evalpos_a.shape[0]):
cdef Py_ssize_t i
with nogil, parallel():
for i in prange(evalpos_a.shape[0]):
pos.x = evalpos_a[i,0]
pos.y = evalpos_a[i,1]
pos.z = evalpos_a[i,2]
@ -971,6 +975,7 @@ cdef class _ScatteringSystemAtOmegaK:
def __set__(self, double eta):
self.sswk.eta = eta
@boundscheck(False)
def scattered_E(self, scatcoeffvector_full, evalpos, btyp=QPMS_HANKEL_PLUS): # TODO DOC!!!
if(btyp != QPMS_HANKEL_PLUS):
raise NotImplementedError("Only first kind Bessel function-based fields are supported")
@ -984,8 +989,9 @@ cdef class _ScatteringSystemAtOmegaK:
cdef np.ndarray[complex, ndim=2] results = np.empty((evalpos_a.shape[0],3), dtype=complex)
cdef ccart3_t res
cdef cart3_t pos
cdef size_t i
for i in range(evalpos_a.shape[0]):
cdef Py_ssize_t i
with nogil, wraparound(False), parallel():
for i in prange(evalpos_a.shape[0]):
pos.x = evalpos_a[i,0]
pos.y = evalpos_a[i,1]
pos.z = evalpos_a[i,2]
@ -1126,7 +1132,7 @@ cdef class _ScatteringSystemAtOmega:
cdef np.ndarray[complex, ndim=2] results = np.empty((evalpos_a.shape[0],3), dtype=complex)
cdef ccart3_t res
cdef cart3_t pos
cdef size_t i
cdef Py_ssize_t i
for i in range(evalpos_a.shape[0]):
pos.x = evalpos_a[i,0]
pos.y = evalpos_a[i,1]

View File

@ -658,9 +658,9 @@ cdef extern from "scatsystem.h":
cdouble *target, const qpms_scatsys_at_omega_t *ssw, qpms_iri_t iri) nogil
cdouble *qpms_scatsys_incident_field_vector_full(cdouble *target_full,
const qpms_scatsys_t *ss, qpms_incfield_t field_at_point,
const void *args, bint add)
const void *args, bint add) nogil
cdouble *qpms_scatsysw_apply_Tmatrices_full(cdouble *target_full, const cdouble *inc_full,
const qpms_scatsys_at_omega_t *ssw)
const qpms_scatsys_at_omega_t *ssw) nogil
struct qpms_ss_LU:
const qpms_scatsys_at_omega_t *ssw
const qpms_scatsys_at_omega_k_t *sswk
@ -670,19 +670,19 @@ cdef extern from "scatsystem.h":
int *ipiv
void qpms_ss_LU_free(qpms_ss_LU lu)
qpms_ss_LU qpms_scatsysw_build_modeproblem_matrix_full_LU(cdouble *target,
int *target_piv, const qpms_scatsys_at_omega_t *ssw)
int *target_piv, const qpms_scatsys_at_omega_t *ssw) nogil
qpms_ss_LU qpms_scatsysw_build_modeproblem_matrix_irrep_packed_LU(cdouble *target,
int *target_piv, const qpms_scatsys_at_omega_t *ssw, qpms_iri_t iri)
int *target_piv, const qpms_scatsys_at_omega_t *ssw, qpms_iri_t iri) nogil
qpms_ss_LU qpms_scatsysw_modeproblem_matrix_full_factorise(cdouble *modeproblem_matrix_full,
int *target_piv, const qpms_scatsys_at_omega_t *ssw)
int *target_piv, const qpms_scatsys_at_omega_t *ssw) nogil
qpms_ss_LU qpms_scatsysw_modeproblem_matrix_irrep_packed_factorise(cdouble *modeproblem_matrix_full,
int *target_piv, const qpms_scatsys_at_omega_t *ssw, qpms_iri_t iri)
cdouble *qpms_scatsys_scatter_solve(cdouble *target_f, const cdouble *a_inc, qpms_ss_LU ludata)
const qpms_vswf_set_spec_t *qpms_ss_bspec_tmi(const qpms_scatsys_t *ss, qpms_ss_tmi_t tmi)
const qpms_vswf_set_spec_t *qpms_ss_bspec_pi(const qpms_scatsys_t *ss, qpms_ss_pi_t pi)
int *target_piv, const qpms_scatsys_at_omega_t *ssw, qpms_iri_t iri) nogil
cdouble *qpms_scatsys_scatter_solve(cdouble *target_f, const cdouble *a_inc, qpms_ss_LU ludata) nogil
const qpms_vswf_set_spec_t *qpms_ss_bspec_tmi(const qpms_scatsys_t *ss, qpms_ss_tmi_t tmi) nogil
const qpms_vswf_set_spec_t *qpms_ss_bspec_pi(const qpms_scatsys_t *ss, qpms_ss_pi_t pi) nogil
beyn_result_t *qpms_scatsys_finite_find_eigenmodes(const qpms_scatsys_t *ss, qpms_iri_t iri,
cdouble omega_centre, double omega_rr, double omega_ri, size_t contour_npoints,
double rank_tol, size_t rank_sel_min, double res_tol)
double rank_tol, size_t rank_sel_min, double res_tol) nogil
# periodic-related funs
struct qpms_scatsys_at_omega_k_t:
const qpms_scatsys_at_omega_t *ssw
@ -693,19 +693,18 @@ cdef extern from "scatsystem.h":
qpms_ss_LU qpms_scatsyswk_build_modeproblem_matrix_full_LU(cdouble *target, int *target_piv, const qpms_scatsys_at_omega_k_t *sswk)
beyn_result_t *qpms_scatsys_periodic_find_eigenmodes(const qpms_scatsys_t *ss, const double *k,
cdouble omega_centre, double omega_rr, double omega_ri, size_t contour_npoints,
double rank_tol, size_t rank_sel_min, double res_tol)
const qpms_vswf_set_spec_t *qpms_ss_bspec_pi(const qpms_scatsys_t *ss, qpms_ss_pi_t pi)
double rank_tol, size_t rank_sel_min, double res_tol) nogil
ccart3_t qpms_scatsys_scattered_E(const qpms_scatsys_t *ss, qpms_bessel_t btyp, cdouble wavenumber,
const cdouble *f_excitation_vector_full, cart3_t where)
const cdouble *f_excitation_vector_full, cart3_t where) nogil
ccart3_t qpms_scatsysw_scattered_E(const qpms_scatsys_at_omega_t *ssw, qpms_bessel_t btyp,
const cdouble *f_excitation_vector_full, cart3_t where)
const cdouble *f_excitation_vector_full, cart3_t where) nogil
ccart3_t qpms_scatsys_scattered_E__alt(const qpms_scatsys_t *ss, qpms_bessel_t btyp, cdouble wavenumber,
const cdouble *f_excitation_vector_full, cart3_t where)
const cdouble *f_excitation_vector_full, cart3_t where) nogil
ccart3_t qpms_scatsysw_scattered_E__alt(const qpms_scatsys_at_omega_t *ssw, qpms_bessel_t btyp,
const cdouble *f_excitation_vector_full, cart3_t where)
const cdouble *f_excitation_vector_full, cart3_t where) nogil
ccart3_t qpms_scatsyswk_scattered_E(const qpms_scatsys_at_omega_k_t *sswk, qpms_bessel_t btyp,
const cdouble *f_excitation_vector_full, cart3_t where)
double qpms_ss_adjusted_eta(const qpms_scatsys_t *ss, cdouble wavenumber, const double *wavevector);
const cdouble *f_excitation_vector_full, cart3_t where) nogil
double qpms_ss_adjusted_eta(const qpms_scatsys_t *ss, cdouble wavenumber, const double *wavevector) nogil
cdef extern from "ewald.h":