Attempt to parallelize the scattered_E Cython methods.

This commit is contained in:
Marek Nečada 2020-07-03 14:20:07 +03:00
parent 1328077490
commit 17824b062e
2 changed files with 45 additions and 40 deletions

View File

@ -17,6 +17,8 @@ from .cymaterials cimport EpsMuGenerator, EpsMu
from libc.stdlib cimport malloc, free, calloc from libc.stdlib cimport malloc, free, calloc
import warnings import warnings
from cython.parallel import prange, parallel
from cython import boundscheck, wraparound
# Set custom GSL error handler. N.B. this is obviously not thread-safe. # Set custom GSL error handler. N.B. this is obviously not thread-safe.
cdef char *pgsl_err_reason cdef char *pgsl_err_reason
@ -901,6 +903,7 @@ cdef class ScatteringSystem:
return retdict return retdict
@boundscheck(False)
def scattered_E(self, cdouble wavenumber, scatcoeffvector_full, evalpos, bint alt=False, btyp=BesselType.HANKEL_PLUS): def scattered_E(self, cdouble wavenumber, scatcoeffvector_full, evalpos, bint alt=False, btyp=BesselType.HANKEL_PLUS):
cdef qpms_bessel_t btyp_c = BesselType(btyp) cdef qpms_bessel_t btyp_c = BesselType(btyp)
evalpos = np.array(evalpos, dtype=float, copy=False) evalpos = np.array(evalpos, dtype=float, copy=False)
@ -912,18 +915,19 @@ cdef class ScatteringSystem:
cdef np.ndarray[complex, ndim=2] results = np.empty((evalpos_a.shape[0],3), dtype=complex) cdef np.ndarray[complex, ndim=2] results = np.empty((evalpos_a.shape[0],3), dtype=complex)
cdef ccart3_t res cdef ccart3_t res
cdef cart3_t pos cdef cart3_t pos
cdef size_t i cdef Py_ssize_t i
for i in range(evalpos_a.shape[0]): with nogil, parallel():
pos.x = evalpos_a[i,0] for i in prange(evalpos_a.shape[0]):
pos.y = evalpos_a[i,1] pos.x = evalpos_a[i,0]
pos.z = evalpos_a[i,2] pos.y = evalpos_a[i,1]
if alt: pos.z = evalpos_a[i,2]
res = qpms_scatsys_scattered_E__alt(self.s, btyp_c, wavenumber, &scv_view[0], pos) if alt:
else: res = qpms_scatsys_scattered_E__alt(self.s, btyp_c, wavenumber, &scv_view[0], pos)
res = qpms_scatsys_scattered_E(self.s, btyp_c, wavenumber, &scv_view[0], pos) else:
results[i,0] = res.x res = qpms_scatsys_scattered_E(self.s, btyp_c, wavenumber, &scv_view[0], pos)
results[i,1] = res.y results[i,0] = res.x
results[i,2] = res.z results[i,1] = res.y
results[i,2] = res.z
return results.reshape(evalpos.shape) return results.reshape(evalpos.shape)
def empty_lattice_modes_xy(EpsMu epsmu, reciprocal_basis, wavevector, double maxomega): def empty_lattice_modes_xy(EpsMu epsmu, reciprocal_basis, wavevector, double maxomega):
@ -971,6 +975,7 @@ cdef class _ScatteringSystemAtOmegaK:
def __set__(self, double eta): def __set__(self, double eta):
self.sswk.eta = eta self.sswk.eta = eta
@boundscheck(False)
def scattered_E(self, scatcoeffvector_full, evalpos, btyp=QPMS_HANKEL_PLUS): # TODO DOC!!! def scattered_E(self, scatcoeffvector_full, evalpos, btyp=QPMS_HANKEL_PLUS): # TODO DOC!!!
if(btyp != QPMS_HANKEL_PLUS): if(btyp != QPMS_HANKEL_PLUS):
raise NotImplementedError("Only first kind Bessel function-based fields are supported") raise NotImplementedError("Only first kind Bessel function-based fields are supported")
@ -984,15 +989,16 @@ cdef class _ScatteringSystemAtOmegaK:
cdef np.ndarray[complex, ndim=2] results = np.empty((evalpos_a.shape[0],3), dtype=complex) cdef np.ndarray[complex, ndim=2] results = np.empty((evalpos_a.shape[0],3), dtype=complex)
cdef ccart3_t res cdef ccart3_t res
cdef cart3_t pos cdef cart3_t pos
cdef size_t i cdef Py_ssize_t i
for i in range(evalpos_a.shape[0]): with nogil, wraparound(False), parallel():
pos.x = evalpos_a[i,0] for i in prange(evalpos_a.shape[0]):
pos.y = evalpos_a[i,1] pos.x = evalpos_a[i,0]
pos.z = evalpos_a[i,2] pos.y = evalpos_a[i,1]
res = qpms_scatsyswk_scattered_E(&self.sswk, btyp_c, &scv_view[0], pos) pos.z = evalpos_a[i,2]
results[i,0] = res.x res = qpms_scatsyswk_scattered_E(&self.sswk, btyp_c, &scv_view[0], pos)
results[i,1] = res.y results[i,0] = res.x
results[i,2] = res.z results[i,1] = res.y
results[i,2] = res.z
return results.reshape(evalpos.shape) return results.reshape(evalpos.shape)
cdef class _ScatteringSystemAtOmega: cdef class _ScatteringSystemAtOmega:
@ -1126,7 +1132,7 @@ cdef class _ScatteringSystemAtOmega:
cdef np.ndarray[complex, ndim=2] results = np.empty((evalpos_a.shape[0],3), dtype=complex) cdef np.ndarray[complex, ndim=2] results = np.empty((evalpos_a.shape[0],3), dtype=complex)
cdef ccart3_t res cdef ccart3_t res
cdef cart3_t pos cdef cart3_t pos
cdef size_t i cdef Py_ssize_t i
for i in range(evalpos_a.shape[0]): for i in range(evalpos_a.shape[0]):
pos.x = evalpos_a[i,0] pos.x = evalpos_a[i,0]
pos.y = evalpos_a[i,1] pos.y = evalpos_a[i,1]

View File

@ -658,9 +658,9 @@ cdef extern from "scatsystem.h":
cdouble *target, const qpms_scatsys_at_omega_t *ssw, qpms_iri_t iri) nogil cdouble *target, const qpms_scatsys_at_omega_t *ssw, qpms_iri_t iri) nogil
cdouble *qpms_scatsys_incident_field_vector_full(cdouble *target_full, cdouble *qpms_scatsys_incident_field_vector_full(cdouble *target_full,
const qpms_scatsys_t *ss, qpms_incfield_t field_at_point, const qpms_scatsys_t *ss, qpms_incfield_t field_at_point,
const void *args, bint add) const void *args, bint add) nogil
cdouble *qpms_scatsysw_apply_Tmatrices_full(cdouble *target_full, const cdouble *inc_full, cdouble *qpms_scatsysw_apply_Tmatrices_full(cdouble *target_full, const cdouble *inc_full,
const qpms_scatsys_at_omega_t *ssw) const qpms_scatsys_at_omega_t *ssw) nogil
struct qpms_ss_LU: struct qpms_ss_LU:
const qpms_scatsys_at_omega_t *ssw const qpms_scatsys_at_omega_t *ssw
const qpms_scatsys_at_omega_k_t *sswk const qpms_scatsys_at_omega_k_t *sswk
@ -670,19 +670,19 @@ cdef extern from "scatsystem.h":
int *ipiv int *ipiv
void qpms_ss_LU_free(qpms_ss_LU lu) void qpms_ss_LU_free(qpms_ss_LU lu)
qpms_ss_LU qpms_scatsysw_build_modeproblem_matrix_full_LU(cdouble *target, qpms_ss_LU qpms_scatsysw_build_modeproblem_matrix_full_LU(cdouble *target,
int *target_piv, const qpms_scatsys_at_omega_t *ssw) int *target_piv, const qpms_scatsys_at_omega_t *ssw) nogil
qpms_ss_LU qpms_scatsysw_build_modeproblem_matrix_irrep_packed_LU(cdouble *target, qpms_ss_LU qpms_scatsysw_build_modeproblem_matrix_irrep_packed_LU(cdouble *target,
int *target_piv, const qpms_scatsys_at_omega_t *ssw, qpms_iri_t iri) int *target_piv, const qpms_scatsys_at_omega_t *ssw, qpms_iri_t iri) nogil
qpms_ss_LU qpms_scatsysw_modeproblem_matrix_full_factorise(cdouble *modeproblem_matrix_full, qpms_ss_LU qpms_scatsysw_modeproblem_matrix_full_factorise(cdouble *modeproblem_matrix_full,
int *target_piv, const qpms_scatsys_at_omega_t *ssw) int *target_piv, const qpms_scatsys_at_omega_t *ssw) nogil
qpms_ss_LU qpms_scatsysw_modeproblem_matrix_irrep_packed_factorise(cdouble *modeproblem_matrix_full, qpms_ss_LU qpms_scatsysw_modeproblem_matrix_irrep_packed_factorise(cdouble *modeproblem_matrix_full,
int *target_piv, const qpms_scatsys_at_omega_t *ssw, qpms_iri_t iri) int *target_piv, const qpms_scatsys_at_omega_t *ssw, qpms_iri_t iri) nogil
cdouble *qpms_scatsys_scatter_solve(cdouble *target_f, const cdouble *a_inc, qpms_ss_LU ludata) cdouble *qpms_scatsys_scatter_solve(cdouble *target_f, const cdouble *a_inc, qpms_ss_LU ludata) nogil
const qpms_vswf_set_spec_t *qpms_ss_bspec_tmi(const qpms_scatsys_t *ss, qpms_ss_tmi_t tmi) const qpms_vswf_set_spec_t *qpms_ss_bspec_tmi(const qpms_scatsys_t *ss, qpms_ss_tmi_t tmi) nogil
const qpms_vswf_set_spec_t *qpms_ss_bspec_pi(const qpms_scatsys_t *ss, qpms_ss_pi_t pi) const qpms_vswf_set_spec_t *qpms_ss_bspec_pi(const qpms_scatsys_t *ss, qpms_ss_pi_t pi) nogil
beyn_result_t *qpms_scatsys_finite_find_eigenmodes(const qpms_scatsys_t *ss, qpms_iri_t iri, beyn_result_t *qpms_scatsys_finite_find_eigenmodes(const qpms_scatsys_t *ss, qpms_iri_t iri,
cdouble omega_centre, double omega_rr, double omega_ri, size_t contour_npoints, cdouble omega_centre, double omega_rr, double omega_ri, size_t contour_npoints,
double rank_tol, size_t rank_sel_min, double res_tol) double rank_tol, size_t rank_sel_min, double res_tol) nogil
# periodic-related funs # periodic-related funs
struct qpms_scatsys_at_omega_k_t: struct qpms_scatsys_at_omega_k_t:
const qpms_scatsys_at_omega_t *ssw const qpms_scatsys_at_omega_t *ssw
@ -693,19 +693,18 @@ cdef extern from "scatsystem.h":
qpms_ss_LU qpms_scatsyswk_build_modeproblem_matrix_full_LU(cdouble *target, int *target_piv, const qpms_scatsys_at_omega_k_t *sswk) qpms_ss_LU qpms_scatsyswk_build_modeproblem_matrix_full_LU(cdouble *target, int *target_piv, const qpms_scatsys_at_omega_k_t *sswk)
beyn_result_t *qpms_scatsys_periodic_find_eigenmodes(const qpms_scatsys_t *ss, const double *k, beyn_result_t *qpms_scatsys_periodic_find_eigenmodes(const qpms_scatsys_t *ss, const double *k,
cdouble omega_centre, double omega_rr, double omega_ri, size_t contour_npoints, cdouble omega_centre, double omega_rr, double omega_ri, size_t contour_npoints,
double rank_tol, size_t rank_sel_min, double res_tol) double rank_tol, size_t rank_sel_min, double res_tol) nogil
const qpms_vswf_set_spec_t *qpms_ss_bspec_pi(const qpms_scatsys_t *ss, qpms_ss_pi_t pi)
ccart3_t qpms_scatsys_scattered_E(const qpms_scatsys_t *ss, qpms_bessel_t btyp, cdouble wavenumber, ccart3_t qpms_scatsys_scattered_E(const qpms_scatsys_t *ss, qpms_bessel_t btyp, cdouble wavenumber,
const cdouble *f_excitation_vector_full, cart3_t where) const cdouble *f_excitation_vector_full, cart3_t where) nogil
ccart3_t qpms_scatsysw_scattered_E(const qpms_scatsys_at_omega_t *ssw, qpms_bessel_t btyp, ccart3_t qpms_scatsysw_scattered_E(const qpms_scatsys_at_omega_t *ssw, qpms_bessel_t btyp,
const cdouble *f_excitation_vector_full, cart3_t where) const cdouble *f_excitation_vector_full, cart3_t where) nogil
ccart3_t qpms_scatsys_scattered_E__alt(const qpms_scatsys_t *ss, qpms_bessel_t btyp, cdouble wavenumber, ccart3_t qpms_scatsys_scattered_E__alt(const qpms_scatsys_t *ss, qpms_bessel_t btyp, cdouble wavenumber,
const cdouble *f_excitation_vector_full, cart3_t where) const cdouble *f_excitation_vector_full, cart3_t where) nogil
ccart3_t qpms_scatsysw_scattered_E__alt(const qpms_scatsys_at_omega_t *ssw, qpms_bessel_t btyp, ccart3_t qpms_scatsysw_scattered_E__alt(const qpms_scatsys_at_omega_t *ssw, qpms_bessel_t btyp,
const cdouble *f_excitation_vector_full, cart3_t where) const cdouble *f_excitation_vector_full, cart3_t where) nogil
ccart3_t qpms_scatsyswk_scattered_E(const qpms_scatsys_at_omega_k_t *sswk, qpms_bessel_t btyp, ccart3_t qpms_scatsyswk_scattered_E(const qpms_scatsys_at_omega_k_t *sswk, qpms_bessel_t btyp,
const cdouble *f_excitation_vector_full, cart3_t where) const cdouble *f_excitation_vector_full, cart3_t where) nogil
double qpms_ss_adjusted_eta(const qpms_scatsys_t *ss, cdouble wavenumber, const double *wavevector); double qpms_ss_adjusted_eta(const qpms_scatsys_t *ss, cdouble wavenumber, const double *wavevector) nogil
cdef extern from "ewald.h": cdef extern from "ewald.h":