Custom python T-matrix generators

This commit is contained in:
Marek Nečada 2021-08-22 09:01:16 +03:00
parent 314cde1b99
commit a5b137847a
1 changed files with 23 additions and 2 deletions

View File

@ -72,14 +72,14 @@ cdef class CTMatrix: # N.B. there is another type called TMatrix in tmatrices.py
Wrapper over the C qpms_tmatrix_t stucture. Wrapper over the C qpms_tmatrix_t stucture.
''' '''
def __cinit__(CTMatrix self, BaseSpec spec, matrix): def __cinit__(CTMatrix self, BaseSpec spec, matrix, copy=True):
self.spec = spec self.spec = spec
self.t.spec = self.spec.rawpointer(); self.t.spec = self.spec.rawpointer();
if (matrix is None) or not np.any(matrix): if (matrix is None) or not np.any(matrix):
self.m = np.zeros((len(spec),len(spec)), dtype=complex, order='C') self.m = np.zeros((len(spec),len(spec)), dtype=complex, order='C')
else: else:
# The following will raise an exception if shape is wrong # The following will raise an exception if shape is wrong
self.m = np.array(matrix, dtype=complex, copy=True, order='C').reshape((len(spec), len(spec))) self.m = np.array(matrix, dtype=complex, copy=copy, order='C').reshape((len(spec), len(spec)))
#self.m.setflags(write=False) # checkme #self.m.setflags(write=False) # checkme
cdef cdouble[:,:] m_memview = self.m cdef cdouble[:,:] m_memview = self.m
self.t.m = &(m_memview[0,0]) self.t.m = &(m_memview[0,0])
@ -167,6 +167,22 @@ cdef qpms_arc_function_retval_t userarc(double theta, const void *params):
retval.r, retval.beta = fun(theta) retval.r, retval.beta = fun(theta)
return retval return retval
cdef qpms_errno_t qpms_tmatrix_generator_pythoncallable(
qpms_tmatrix_t *t, cdouble omega, const void *param):
'''Use a python callable as a T-matrix generator
fun is expected to be callable as
fun(CTMatrix tmatrix, cdouble omega)
Therefore we must recreate the CTMatrix object
and its BaseSpec object before passing it to fun
'''
cdef object fun = <object> param
cdef size_t n = t[0].spec[0].n
cdef BaseSpec bspec = BaseSpec.from_cpointer(t[0].spec)
cdef CTMatrix tmatrix = CTMatrix(bspec,
np.asarray(<np.complex[:n, :n]>(t[0].m)), copy=False)
return fun(tmatrix, omega)
cdef class ArcFunction: cdef class ArcFunction:
cdef qpms_arc_function_t g cdef qpms_arc_function_t g
@ -270,6 +286,11 @@ cdef class TMatrixGenerator:
self.holder = what self.holder = what
self.g.function = qpms_tmatrix_generator_interpolator self.g.function = qpms_tmatrix_generator_interpolator
self.g.params = <void*>(<TMatrixInterpolator?>self.holder).rawpointer() self.g.params = <void*>(<TMatrixInterpolator?>self.holder).rawpointer()
elif callable(what):
warnings.warn("Custom python T-matrix generators are an experimental feature. Also expect it to be slow.")
self.holder = what
self.g.function = qpms_tmatrix_generator_pythoncallable
self.g.params = <void*>self.holder
else: else:
raise TypeError("Can't construct TMatrixGenerator from that") raise TypeError("Can't construct TMatrixGenerator from that")