From a5b137847a65a669cb2dd4debf09b8a8ff7bd906 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Ne=C4=8Dada?= Date: Sun, 22 Aug 2021 09:01:16 +0300 Subject: [PATCH] Custom python T-matrix generators --- qpms/cytmatrices.pyx | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/qpms/cytmatrices.pyx b/qpms/cytmatrices.pyx index fc858b6..91a22af 100644 --- a/qpms/cytmatrices.pyx +++ b/qpms/cytmatrices.pyx @@ -72,14 +72,14 @@ cdef class CTMatrix: # N.B. there is another type called TMatrix in tmatrices.py Wrapper over the C qpms_tmatrix_t stucture. ''' - def __cinit__(CTMatrix self, BaseSpec spec, matrix): + def __cinit__(CTMatrix self, BaseSpec spec, matrix, copy=True): self.spec = spec self.t.spec = self.spec.rawpointer(); if (matrix is None) or not np.any(matrix): self.m = np.zeros((len(spec),len(spec)), dtype=complex, order='C') else: # The following will raise an exception if shape is wrong - self.m = np.array(matrix, dtype=complex, copy=True, order='C').reshape((len(spec), len(spec))) + self.m = np.array(matrix, dtype=complex, copy=copy, order='C').reshape((len(spec), len(spec))) #self.m.setflags(write=False) # checkme cdef cdouble[:,:] m_memview = self.m self.t.m = &(m_memview[0,0]) @@ -167,6 +167,22 @@ cdef qpms_arc_function_retval_t userarc(double theta, const void *params): retval.r, retval.beta = fun(theta) return retval +cdef qpms_errno_t qpms_tmatrix_generator_pythoncallable( + qpms_tmatrix_t *t, cdouble omega, const void *param): + '''Use a python callable as a T-matrix generator + + fun is expected to be callable as + fun(CTMatrix tmatrix, cdouble omega) + Therefore we must recreate the CTMatrix object + and its BaseSpec object before passing it to fun + ''' + cdef object fun = param + cdef size_t n = t[0].spec[0].n + cdef BaseSpec bspec = BaseSpec.from_cpointer(t[0].spec) + cdef CTMatrix tmatrix = CTMatrix(bspec, + np.asarray((t[0].m)), copy=False) + return fun(tmatrix, omega) + cdef class ArcFunction: cdef qpms_arc_function_t g @@ -270,6 +286,11 @@ cdef class TMatrixGenerator: self.holder = what self.g.function = qpms_tmatrix_generator_interpolator self.g.params = (self.holder).rawpointer() + elif callable(what): + warnings.warn("Custom python T-matrix generators are an experimental feature. Also expect it to be slow.") + self.holder = what + self.g.function = qpms_tmatrix_generator_pythoncallable + self.g.params = self.holder else: raise TypeError("Can't construct TMatrixGenerator from that")