diff --git a/qpms/cytmatrices.pxd b/qpms/cytmatrices.pxd index 6e462b3..9fcc1e4 100644 --- a/qpms/cytmatrices.pxd +++ b/qpms/cytmatrices.pxd @@ -22,6 +22,16 @@ cdef class TMatrixGenerator: cdef inline qpms_tmatrix_generator_t *rawpointer(self): return &(self.g) +cdef class TMatrixFunction: + cdef readonly qpms_tmatrix_function_t f + cdef readonly TMatrixGenerator generator # reference holder + cdef readonly BaseSpec spec # reference holder + cdef inline qpms_tmatrix_function_t raw(self): + return self.f + cdef inline qpms_tmatrix_function_t *rawpointer(self): + return &self.f + + cdef class TMatrixGeneratorTransformed: pass diff --git a/qpms/cytmatrices.pyx b/qpms/cytmatrices.pyx index 6cdfd4b..99ca7be 100644 --- a/qpms/cytmatrices.pyx +++ b/qpms/cytmatrices.pyx @@ -228,6 +228,28 @@ cdef class __AxialSymParams: qpms_tmatrix_generator_axialsym_RQ_transposed_fill(&arrview[0][0], omega, &self.p, norm, QPMS_BESSEL_REGULAR) return arr +cdef class TMatrixFunction: + ''' + Wrapper over qpms_tmatrix_function_t. The main functional difference between this + and TMatrixGenerator is that this remembers a specific BaseSpec + and its __call__ method takes only one mandatory argument (in addition to self). + ''' + def __init__(self, TMatrixGenerator tmg, BaseSpec spec): + self.generator = tmg + self.spec = spec + self.f.spec = self.generator.rawpointer() + self.f.gen = self.spec.rawpointer() + + def __call__(self, cdouble omega, fill = None): + cdef CTMatrix tm + if fill is None: # make a new CTMatrix + tm = CTMatrix(self.spec, None) + else: # TODO check whether fill has the same bspec as self? + tm = fill + if self.g.function(tm.rawpointer(), omega, self.f.gen.params) != 0: + raise ValueError("Something went wrong") + else: + return tm cdef class TMatrixGenerator: