diff --git a/examples/api/custom_tmatrix/custom_tmatrices_simple.py b/examples/api/custom_tmatrix/custom_tmatrices_simple.py new file mode 100755 index 0000000..b6f3d50 --- /dev/null +++ b/examples/api/custom_tmatrix/custom_tmatrices_simple.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +from qpms import TMatrixGenerator, BaseSpec, eV, hbar +import numpy as np +import sys + +errors = 0 + +def tmg_diagonal_fun(tmatrix, omega): + ''' + Example of a python function used as a custom T-matrix generator + + It receives a CTMatrix argument with pre-filled BaseSpec + (in tmatrix.spec) and angular frequency. + + It has to fill in the T-matrix elements tmatrix[...] + (a numpy array of shape (len(tmatrix.spec),len(tmatrix.spec))) + and return zero (on success) or other integral value on error. + + Note that this in justa an example of using the API, + not supposed to be anything physical. + ''' + l = tmatrix.spec.l() + tmatrix[...] = np.diag(1./l**2) + return 0 + +# Wrap the function as an actual TMatrixGenerator +tmg_diagonal = TMatrixGenerator(tmg_diagonal_fun) + +bspec = BaseSpec(lMax=2) + +tmatrix = tmg_diagonal(bspec, (2.0+.01j) * eV/hbar) + +errors += np.sum(tmatrix[...] != np.diag(1./bspec.l()**2)) + +sys.exit(errors) + diff --git a/qpms/cybspec.pxd b/qpms/cybspec.pxd index 853c372..4d080fc 100644 --- a/qpms/cybspec.pxd +++ b/qpms/cybspec.pxd @@ -6,4 +6,6 @@ cdef class BaseSpec: cdef qpms_vswf_set_spec_t s cdef np.ndarray __ilist + @staticmethod + cdef BaseSpec from_cpointer(const qpms_vswf_set_spec_t *orig) cdef qpms_vswf_set_spec_t *rawpointer(BaseSpec self) diff --git a/qpms/cybspec.pyx b/qpms/cybspec.pyx index e2813b8..9983272 100644 --- a/qpms/cybspec.pyx +++ b/qpms/cybspec.pyx @@ -27,6 +27,7 @@ try: except AttributeError: # For older Python versions, use IntEnum instead VSWFNorm = enum.IntEnum('VSWFNorm', names=__VSWF_norm_dict, module=__name__) + cdef class BaseSpec: '''Cython wrapper over qpms_vswf_set_spec_t. @@ -39,6 +40,16 @@ cdef class BaseSpec: #cdef np.ndarray __ilist # in pxd #cdef const qpms_uvswfi_t[:] __ilist + @staticmethod + cdef BaseSpec from_cpointer(const qpms_vswf_set_spec_t *orig): + '''Makes an instance of BaseSpec from an existing + C pointer, copying the contents + ''' + cdef const qpms_uvswfi_t[::1] ilist_orig = orig[0].ilist + cdef BaseSpec bs = BaseSpec(ilist_orig) + bs.s.norm = orig[0].norm + return bs + def __cinit__(self, *args, **kwargs): cdef const qpms_uvswfi_t[:] ilist_memview if len(args) == 0: diff --git a/qpms/cytmatrices.pyx b/qpms/cytmatrices.pyx index fc858b6..91a22af 100644 --- a/qpms/cytmatrices.pyx +++ b/qpms/cytmatrices.pyx @@ -72,14 +72,14 @@ cdef class CTMatrix: # N.B. there is another type called TMatrix in tmatrices.py Wrapper over the C qpms_tmatrix_t stucture. ''' - def __cinit__(CTMatrix self, BaseSpec spec, matrix): + def __cinit__(CTMatrix self, BaseSpec spec, matrix, copy=True): self.spec = spec self.t.spec = self.spec.rawpointer(); if (matrix is None) or not np.any(matrix): self.m = np.zeros((len(spec),len(spec)), dtype=complex, order='C') else: # The following will raise an exception if shape is wrong - self.m = np.array(matrix, dtype=complex, copy=True, order='C').reshape((len(spec), len(spec))) + self.m = np.array(matrix, dtype=complex, copy=copy, order='C').reshape((len(spec), len(spec))) #self.m.setflags(write=False) # checkme cdef cdouble[:,:] m_memview = self.m self.t.m = &(m_memview[0,0]) @@ -167,6 +167,22 @@ cdef qpms_arc_function_retval_t userarc(double theta, const void *params): retval.r, retval.beta = fun(theta) return retval +cdef qpms_errno_t qpms_tmatrix_generator_pythoncallable( + qpms_tmatrix_t *t, cdouble omega, const void *param): + '''Use a python callable as a T-matrix generator + + fun is expected to be callable as + fun(CTMatrix tmatrix, cdouble omega) + Therefore we must recreate the CTMatrix object + and its BaseSpec object before passing it to fun + ''' + cdef object fun = param + cdef size_t n = t[0].spec[0].n + cdef BaseSpec bspec = BaseSpec.from_cpointer(t[0].spec) + cdef CTMatrix tmatrix = CTMatrix(bspec, + np.asarray((t[0].m)), copy=False) + return fun(tmatrix, omega) + cdef class ArcFunction: cdef qpms_arc_function_t g @@ -270,6 +286,11 @@ cdef class TMatrixGenerator: self.holder = what self.g.function = qpms_tmatrix_generator_interpolator self.g.params = (self.holder).rawpointer() + elif callable(what): + warnings.warn("Custom python T-matrix generators are an experimental feature. Also expect it to be slow.") + self.holder = what + self.g.function = qpms_tmatrix_generator_pythoncallable + self.g.params = self.holder else: raise TypeError("Can't construct TMatrixGenerator from that")