Merge branch 'python_tmg': T-matrix generators from python functions
This commit is contained in:
commit
572e15edbb
|
@ -0,0 +1,36 @@
|
|||
#!/usr/bin/env python3
|
||||
from qpms import TMatrixGenerator, BaseSpec, eV, hbar
|
||||
import numpy as np
|
||||
import sys
|
||||
|
||||
errors = 0
|
||||
|
||||
def tmg_diagonal_fun(tmatrix, omega):
|
||||
'''
|
||||
Example of a python function used as a custom T-matrix generator
|
||||
|
||||
It receives a CTMatrix argument with pre-filled BaseSpec
|
||||
(in tmatrix.spec) and angular frequency.
|
||||
|
||||
It has to fill in the T-matrix elements tmatrix[...]
|
||||
(a numpy array of shape (len(tmatrix.spec),len(tmatrix.spec)))
|
||||
and return zero (on success) or other integral value on error.
|
||||
|
||||
Note that this in justa an example of using the API,
|
||||
not supposed to be anything physical.
|
||||
'''
|
||||
l = tmatrix.spec.l()
|
||||
tmatrix[...] = np.diag(1./l**2)
|
||||
return 0
|
||||
|
||||
# Wrap the function as an actual TMatrixGenerator
|
||||
tmg_diagonal = TMatrixGenerator(tmg_diagonal_fun)
|
||||
|
||||
bspec = BaseSpec(lMax=2)
|
||||
|
||||
tmatrix = tmg_diagonal(bspec, (2.0+.01j) * eV/hbar)
|
||||
|
||||
errors += np.sum(tmatrix[...] != np.diag(1./bspec.l()**2))
|
||||
|
||||
sys.exit(errors)
|
||||
|
|
@ -6,4 +6,6 @@ cdef class BaseSpec:
|
|||
cdef qpms_vswf_set_spec_t s
|
||||
cdef np.ndarray __ilist
|
||||
|
||||
@staticmethod
|
||||
cdef BaseSpec from_cpointer(const qpms_vswf_set_spec_t *orig)
|
||||
cdef qpms_vswf_set_spec_t *rawpointer(BaseSpec self)
|
||||
|
|
|
@ -27,6 +27,7 @@ try:
|
|||
except AttributeError: # For older Python versions, use IntEnum instead
|
||||
VSWFNorm = enum.IntEnum('VSWFNorm', names=__VSWF_norm_dict, module=__name__)
|
||||
|
||||
|
||||
cdef class BaseSpec:
|
||||
'''Cython wrapper over qpms_vswf_set_spec_t.
|
||||
|
||||
|
@ -39,6 +40,16 @@ cdef class BaseSpec:
|
|||
#cdef np.ndarray __ilist # in pxd
|
||||
#cdef const qpms_uvswfi_t[:] __ilist
|
||||
|
||||
@staticmethod
|
||||
cdef BaseSpec from_cpointer(const qpms_vswf_set_spec_t *orig):
|
||||
'''Makes an instance of BaseSpec from an existing
|
||||
C pointer, copying the contents
|
||||
'''
|
||||
cdef const qpms_uvswfi_t[::1] ilist_orig = <qpms_uvswfi_t[:orig[0].n]> orig[0].ilist
|
||||
cdef BaseSpec bs = BaseSpec(ilist_orig)
|
||||
bs.s.norm = orig[0].norm
|
||||
return bs
|
||||
|
||||
def __cinit__(self, *args, **kwargs):
|
||||
cdef const qpms_uvswfi_t[:] ilist_memview
|
||||
if len(args) == 0:
|
||||
|
|
|
@ -72,14 +72,14 @@ cdef class CTMatrix: # N.B. there is another type called TMatrix in tmatrices.py
|
|||
Wrapper over the C qpms_tmatrix_t stucture.
|
||||
'''
|
||||
|
||||
def __cinit__(CTMatrix self, BaseSpec spec, matrix):
|
||||
def __cinit__(CTMatrix self, BaseSpec spec, matrix, copy=True):
|
||||
self.spec = spec
|
||||
self.t.spec = self.spec.rawpointer();
|
||||
if (matrix is None) or not np.any(matrix):
|
||||
self.m = np.zeros((len(spec),len(spec)), dtype=complex, order='C')
|
||||
else:
|
||||
# The following will raise an exception if shape is wrong
|
||||
self.m = np.array(matrix, dtype=complex, copy=True, order='C').reshape((len(spec), len(spec)))
|
||||
self.m = np.array(matrix, dtype=complex, copy=copy, order='C').reshape((len(spec), len(spec)))
|
||||
#self.m.setflags(write=False) # checkme
|
||||
cdef cdouble[:,:] m_memview = self.m
|
||||
self.t.m = &(m_memview[0,0])
|
||||
|
@ -167,6 +167,22 @@ cdef qpms_arc_function_retval_t userarc(double theta, const void *params):
|
|||
retval.r, retval.beta = fun(theta)
|
||||
return retval
|
||||
|
||||
cdef qpms_errno_t qpms_tmatrix_generator_pythoncallable(
|
||||
qpms_tmatrix_t *t, cdouble omega, const void *param):
|
||||
'''Use a python callable as a T-matrix generator
|
||||
|
||||
fun is expected to be callable as
|
||||
fun(CTMatrix tmatrix, cdouble omega)
|
||||
Therefore we must recreate the CTMatrix object
|
||||
and its BaseSpec object before passing it to fun
|
||||
'''
|
||||
cdef object fun = <object> param
|
||||
cdef size_t n = t[0].spec[0].n
|
||||
cdef BaseSpec bspec = BaseSpec.from_cpointer(t[0].spec)
|
||||
cdef CTMatrix tmatrix = CTMatrix(bspec,
|
||||
np.asarray(<np.complex[:n, :n]>(t[0].m)), copy=False)
|
||||
return fun(tmatrix, omega)
|
||||
|
||||
|
||||
cdef class ArcFunction:
|
||||
cdef qpms_arc_function_t g
|
||||
|
@ -270,6 +286,11 @@ cdef class TMatrixGenerator:
|
|||
self.holder = what
|
||||
self.g.function = qpms_tmatrix_generator_interpolator
|
||||
self.g.params = <void*>(<TMatrixInterpolator?>self.holder).rawpointer()
|
||||
elif callable(what):
|
||||
warnings.warn("Custom python T-matrix generators are an experimental feature. Also expect it to be slow.")
|
||||
self.holder = what
|
||||
self.g.function = qpms_tmatrix_generator_pythoncallable
|
||||
self.g.params = <void*>self.holder
|
||||
else:
|
||||
raise TypeError("Can't construct TMatrixGenerator from that")
|
||||
|
||||
|
|
Loading…
Reference in New Issue