Merge branch 'python_tmg': T-matrix generators from python functions

This commit is contained in:
Marek Nečada 2021-08-23 08:42:50 +03:00
commit 572e15edbb
4 changed files with 72 additions and 2 deletions

View File

@ -0,0 +1,36 @@
#!/usr/bin/env python3
from qpms import TMatrixGenerator, BaseSpec, eV, hbar
import numpy as np
import sys
errors = 0
def tmg_diagonal_fun(tmatrix, omega):
'''
Example of a python function used as a custom T-matrix generator
It receives a CTMatrix argument with pre-filled BaseSpec
(in tmatrix.spec) and angular frequency.
It has to fill in the T-matrix elements tmatrix[...]
(a numpy array of shape (len(tmatrix.spec),len(tmatrix.spec)))
and return zero (on success) or other integral value on error.
Note that this in justa an example of using the API,
not supposed to be anything physical.
'''
l = tmatrix.spec.l()
tmatrix[...] = np.diag(1./l**2)
return 0
# Wrap the function as an actual TMatrixGenerator
tmg_diagonal = TMatrixGenerator(tmg_diagonal_fun)
bspec = BaseSpec(lMax=2)
tmatrix = tmg_diagonal(bspec, (2.0+.01j) * eV/hbar)
errors += np.sum(tmatrix[...] != np.diag(1./bspec.l()**2))
sys.exit(errors)

View File

@ -6,4 +6,6 @@ cdef class BaseSpec:
cdef qpms_vswf_set_spec_t s cdef qpms_vswf_set_spec_t s
cdef np.ndarray __ilist cdef np.ndarray __ilist
@staticmethod
cdef BaseSpec from_cpointer(const qpms_vswf_set_spec_t *orig)
cdef qpms_vswf_set_spec_t *rawpointer(BaseSpec self) cdef qpms_vswf_set_spec_t *rawpointer(BaseSpec self)

View File

@ -27,6 +27,7 @@ try:
except AttributeError: # For older Python versions, use IntEnum instead except AttributeError: # For older Python versions, use IntEnum instead
VSWFNorm = enum.IntEnum('VSWFNorm', names=__VSWF_norm_dict, module=__name__) VSWFNorm = enum.IntEnum('VSWFNorm', names=__VSWF_norm_dict, module=__name__)
cdef class BaseSpec: cdef class BaseSpec:
'''Cython wrapper over qpms_vswf_set_spec_t. '''Cython wrapper over qpms_vswf_set_spec_t.
@ -39,6 +40,16 @@ cdef class BaseSpec:
#cdef np.ndarray __ilist # in pxd #cdef np.ndarray __ilist # in pxd
#cdef const qpms_uvswfi_t[:] __ilist #cdef const qpms_uvswfi_t[:] __ilist
@staticmethod
cdef BaseSpec from_cpointer(const qpms_vswf_set_spec_t *orig):
'''Makes an instance of BaseSpec from an existing
C pointer, copying the contents
'''
cdef const qpms_uvswfi_t[::1] ilist_orig = <qpms_uvswfi_t[:orig[0].n]> orig[0].ilist
cdef BaseSpec bs = BaseSpec(ilist_orig)
bs.s.norm = orig[0].norm
return bs
def __cinit__(self, *args, **kwargs): def __cinit__(self, *args, **kwargs):
cdef const qpms_uvswfi_t[:] ilist_memview cdef const qpms_uvswfi_t[:] ilist_memview
if len(args) == 0: if len(args) == 0:

View File

@ -72,14 +72,14 @@ cdef class CTMatrix: # N.B. there is another type called TMatrix in tmatrices.py
Wrapper over the C qpms_tmatrix_t stucture. Wrapper over the C qpms_tmatrix_t stucture.
''' '''
def __cinit__(CTMatrix self, BaseSpec spec, matrix): def __cinit__(CTMatrix self, BaseSpec spec, matrix, copy=True):
self.spec = spec self.spec = spec
self.t.spec = self.spec.rawpointer(); self.t.spec = self.spec.rawpointer();
if (matrix is None) or not np.any(matrix): if (matrix is None) or not np.any(matrix):
self.m = np.zeros((len(spec),len(spec)), dtype=complex, order='C') self.m = np.zeros((len(spec),len(spec)), dtype=complex, order='C')
else: else:
# The following will raise an exception if shape is wrong # The following will raise an exception if shape is wrong
self.m = np.array(matrix, dtype=complex, copy=True, order='C').reshape((len(spec), len(spec))) self.m = np.array(matrix, dtype=complex, copy=copy, order='C').reshape((len(spec), len(spec)))
#self.m.setflags(write=False) # checkme #self.m.setflags(write=False) # checkme
cdef cdouble[:,:] m_memview = self.m cdef cdouble[:,:] m_memview = self.m
self.t.m = &(m_memview[0,0]) self.t.m = &(m_memview[0,0])
@ -167,6 +167,22 @@ cdef qpms_arc_function_retval_t userarc(double theta, const void *params):
retval.r, retval.beta = fun(theta) retval.r, retval.beta = fun(theta)
return retval return retval
cdef qpms_errno_t qpms_tmatrix_generator_pythoncallable(
qpms_tmatrix_t *t, cdouble omega, const void *param):
'''Use a python callable as a T-matrix generator
fun is expected to be callable as
fun(CTMatrix tmatrix, cdouble omega)
Therefore we must recreate the CTMatrix object
and its BaseSpec object before passing it to fun
'''
cdef object fun = <object> param
cdef size_t n = t[0].spec[0].n
cdef BaseSpec bspec = BaseSpec.from_cpointer(t[0].spec)
cdef CTMatrix tmatrix = CTMatrix(bspec,
np.asarray(<np.complex[:n, :n]>(t[0].m)), copy=False)
return fun(tmatrix, omega)
cdef class ArcFunction: cdef class ArcFunction:
cdef qpms_arc_function_t g cdef qpms_arc_function_t g
@ -270,6 +286,11 @@ cdef class TMatrixGenerator:
self.holder = what self.holder = what
self.g.function = qpms_tmatrix_generator_interpolator self.g.function = qpms_tmatrix_generator_interpolator
self.g.params = <void*>(<TMatrixInterpolator?>self.holder).rawpointer() self.g.params = <void*>(<TMatrixInterpolator?>self.holder).rawpointer()
elif callable(what):
warnings.warn("Custom python T-matrix generators are an experimental feature. Also expect it to be slow.")
self.holder = what
self.g.function = qpms_tmatrix_generator_pythoncallable
self.g.params = <void*>self.holder
else: else:
raise TypeError("Can't construct TMatrixGenerator from that") raise TypeError("Can't construct TMatrixGenerator from that")