Merge branch 'python_tmg': T-matrix generators from python functions
This commit is contained in:
commit
572e15edbb
|
@ -0,0 +1,36 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
from qpms import TMatrixGenerator, BaseSpec, eV, hbar
|
||||||
|
import numpy as np
|
||||||
|
import sys
|
||||||
|
|
||||||
|
errors = 0
|
||||||
|
|
||||||
|
def tmg_diagonal_fun(tmatrix, omega):
|
||||||
|
'''
|
||||||
|
Example of a python function used as a custom T-matrix generator
|
||||||
|
|
||||||
|
It receives a CTMatrix argument with pre-filled BaseSpec
|
||||||
|
(in tmatrix.spec) and angular frequency.
|
||||||
|
|
||||||
|
It has to fill in the T-matrix elements tmatrix[...]
|
||||||
|
(a numpy array of shape (len(tmatrix.spec),len(tmatrix.spec)))
|
||||||
|
and return zero (on success) or other integral value on error.
|
||||||
|
|
||||||
|
Note that this in justa an example of using the API,
|
||||||
|
not supposed to be anything physical.
|
||||||
|
'''
|
||||||
|
l = tmatrix.spec.l()
|
||||||
|
tmatrix[...] = np.diag(1./l**2)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Wrap the function as an actual TMatrixGenerator
|
||||||
|
tmg_diagonal = TMatrixGenerator(tmg_diagonal_fun)
|
||||||
|
|
||||||
|
bspec = BaseSpec(lMax=2)
|
||||||
|
|
||||||
|
tmatrix = tmg_diagonal(bspec, (2.0+.01j) * eV/hbar)
|
||||||
|
|
||||||
|
errors += np.sum(tmatrix[...] != np.diag(1./bspec.l()**2))
|
||||||
|
|
||||||
|
sys.exit(errors)
|
||||||
|
|
|
@ -6,4 +6,6 @@ cdef class BaseSpec:
|
||||||
cdef qpms_vswf_set_spec_t s
|
cdef qpms_vswf_set_spec_t s
|
||||||
cdef np.ndarray __ilist
|
cdef np.ndarray __ilist
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef BaseSpec from_cpointer(const qpms_vswf_set_spec_t *orig)
|
||||||
cdef qpms_vswf_set_spec_t *rawpointer(BaseSpec self)
|
cdef qpms_vswf_set_spec_t *rawpointer(BaseSpec self)
|
||||||
|
|
|
@ -27,6 +27,7 @@ try:
|
||||||
except AttributeError: # For older Python versions, use IntEnum instead
|
except AttributeError: # For older Python versions, use IntEnum instead
|
||||||
VSWFNorm = enum.IntEnum('VSWFNorm', names=__VSWF_norm_dict, module=__name__)
|
VSWFNorm = enum.IntEnum('VSWFNorm', names=__VSWF_norm_dict, module=__name__)
|
||||||
|
|
||||||
|
|
||||||
cdef class BaseSpec:
|
cdef class BaseSpec:
|
||||||
'''Cython wrapper over qpms_vswf_set_spec_t.
|
'''Cython wrapper over qpms_vswf_set_spec_t.
|
||||||
|
|
||||||
|
@ -39,6 +40,16 @@ cdef class BaseSpec:
|
||||||
#cdef np.ndarray __ilist # in pxd
|
#cdef np.ndarray __ilist # in pxd
|
||||||
#cdef const qpms_uvswfi_t[:] __ilist
|
#cdef const qpms_uvswfi_t[:] __ilist
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
cdef BaseSpec from_cpointer(const qpms_vswf_set_spec_t *orig):
|
||||||
|
'''Makes an instance of BaseSpec from an existing
|
||||||
|
C pointer, copying the contents
|
||||||
|
'''
|
||||||
|
cdef const qpms_uvswfi_t[::1] ilist_orig = <qpms_uvswfi_t[:orig[0].n]> orig[0].ilist
|
||||||
|
cdef BaseSpec bs = BaseSpec(ilist_orig)
|
||||||
|
bs.s.norm = orig[0].norm
|
||||||
|
return bs
|
||||||
|
|
||||||
def __cinit__(self, *args, **kwargs):
|
def __cinit__(self, *args, **kwargs):
|
||||||
cdef const qpms_uvswfi_t[:] ilist_memview
|
cdef const qpms_uvswfi_t[:] ilist_memview
|
||||||
if len(args) == 0:
|
if len(args) == 0:
|
||||||
|
|
|
@ -72,14 +72,14 @@ cdef class CTMatrix: # N.B. there is another type called TMatrix in tmatrices.py
|
||||||
Wrapper over the C qpms_tmatrix_t stucture.
|
Wrapper over the C qpms_tmatrix_t stucture.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __cinit__(CTMatrix self, BaseSpec spec, matrix):
|
def __cinit__(CTMatrix self, BaseSpec spec, matrix, copy=True):
|
||||||
self.spec = spec
|
self.spec = spec
|
||||||
self.t.spec = self.spec.rawpointer();
|
self.t.spec = self.spec.rawpointer();
|
||||||
if (matrix is None) or not np.any(matrix):
|
if (matrix is None) or not np.any(matrix):
|
||||||
self.m = np.zeros((len(spec),len(spec)), dtype=complex, order='C')
|
self.m = np.zeros((len(spec),len(spec)), dtype=complex, order='C')
|
||||||
else:
|
else:
|
||||||
# The following will raise an exception if shape is wrong
|
# The following will raise an exception if shape is wrong
|
||||||
self.m = np.array(matrix, dtype=complex, copy=True, order='C').reshape((len(spec), len(spec)))
|
self.m = np.array(matrix, dtype=complex, copy=copy, order='C').reshape((len(spec), len(spec)))
|
||||||
#self.m.setflags(write=False) # checkme
|
#self.m.setflags(write=False) # checkme
|
||||||
cdef cdouble[:,:] m_memview = self.m
|
cdef cdouble[:,:] m_memview = self.m
|
||||||
self.t.m = &(m_memview[0,0])
|
self.t.m = &(m_memview[0,0])
|
||||||
|
@ -167,6 +167,22 @@ cdef qpms_arc_function_retval_t userarc(double theta, const void *params):
|
||||||
retval.r, retval.beta = fun(theta)
|
retval.r, retval.beta = fun(theta)
|
||||||
return retval
|
return retval
|
||||||
|
|
||||||
|
cdef qpms_errno_t qpms_tmatrix_generator_pythoncallable(
|
||||||
|
qpms_tmatrix_t *t, cdouble omega, const void *param):
|
||||||
|
'''Use a python callable as a T-matrix generator
|
||||||
|
|
||||||
|
fun is expected to be callable as
|
||||||
|
fun(CTMatrix tmatrix, cdouble omega)
|
||||||
|
Therefore we must recreate the CTMatrix object
|
||||||
|
and its BaseSpec object before passing it to fun
|
||||||
|
'''
|
||||||
|
cdef object fun = <object> param
|
||||||
|
cdef size_t n = t[0].spec[0].n
|
||||||
|
cdef BaseSpec bspec = BaseSpec.from_cpointer(t[0].spec)
|
||||||
|
cdef CTMatrix tmatrix = CTMatrix(bspec,
|
||||||
|
np.asarray(<np.complex[:n, :n]>(t[0].m)), copy=False)
|
||||||
|
return fun(tmatrix, omega)
|
||||||
|
|
||||||
|
|
||||||
cdef class ArcFunction:
|
cdef class ArcFunction:
|
||||||
cdef qpms_arc_function_t g
|
cdef qpms_arc_function_t g
|
||||||
|
@ -270,6 +286,11 @@ cdef class TMatrixGenerator:
|
||||||
self.holder = what
|
self.holder = what
|
||||||
self.g.function = qpms_tmatrix_generator_interpolator
|
self.g.function = qpms_tmatrix_generator_interpolator
|
||||||
self.g.params = <void*>(<TMatrixInterpolator?>self.holder).rawpointer()
|
self.g.params = <void*>(<TMatrixInterpolator?>self.holder).rawpointer()
|
||||||
|
elif callable(what):
|
||||||
|
warnings.warn("Custom python T-matrix generators are an experimental feature. Also expect it to be slow.")
|
||||||
|
self.holder = what
|
||||||
|
self.g.function = qpms_tmatrix_generator_pythoncallable
|
||||||
|
self.g.params = <void*>self.holder
|
||||||
else:
|
else:
|
||||||
raise TypeError("Can't construct TMatrixGenerator from that")
|
raise TypeError("Can't construct TMatrixGenerator from that")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue