Custom python T-matrix generators
This commit is contained in:
parent
314cde1b99
commit
a5b137847a
|
@ -72,14 +72,14 @@ cdef class CTMatrix: # N.B. there is another type called TMatrix in tmatrices.py
|
||||||
Wrapper over the C qpms_tmatrix_t stucture.
|
Wrapper over the C qpms_tmatrix_t stucture.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __cinit__(CTMatrix self, BaseSpec spec, matrix):
|
def __cinit__(CTMatrix self, BaseSpec spec, matrix, copy=True):
|
||||||
self.spec = spec
|
self.spec = spec
|
||||||
self.t.spec = self.spec.rawpointer();
|
self.t.spec = self.spec.rawpointer();
|
||||||
if (matrix is None) or not np.any(matrix):
|
if (matrix is None) or not np.any(matrix):
|
||||||
self.m = np.zeros((len(spec),len(spec)), dtype=complex, order='C')
|
self.m = np.zeros((len(spec),len(spec)), dtype=complex, order='C')
|
||||||
else:
|
else:
|
||||||
# The following will raise an exception if shape is wrong
|
# The following will raise an exception if shape is wrong
|
||||||
self.m = np.array(matrix, dtype=complex, copy=True, order='C').reshape((len(spec), len(spec)))
|
self.m = np.array(matrix, dtype=complex, copy=copy, order='C').reshape((len(spec), len(spec)))
|
||||||
#self.m.setflags(write=False) # checkme
|
#self.m.setflags(write=False) # checkme
|
||||||
cdef cdouble[:,:] m_memview = self.m
|
cdef cdouble[:,:] m_memview = self.m
|
||||||
self.t.m = &(m_memview[0,0])
|
self.t.m = &(m_memview[0,0])
|
||||||
|
@ -167,6 +167,22 @@ cdef qpms_arc_function_retval_t userarc(double theta, const void *params):
|
||||||
retval.r, retval.beta = fun(theta)
|
retval.r, retval.beta = fun(theta)
|
||||||
return retval
|
return retval
|
||||||
|
|
||||||
|
cdef qpms_errno_t qpms_tmatrix_generator_pythoncallable(
|
||||||
|
qpms_tmatrix_t *t, cdouble omega, const void *param):
|
||||||
|
'''Use a python callable as a T-matrix generator
|
||||||
|
|
||||||
|
fun is expected to be callable as
|
||||||
|
fun(CTMatrix tmatrix, cdouble omega)
|
||||||
|
Therefore we must recreate the CTMatrix object
|
||||||
|
and its BaseSpec object before passing it to fun
|
||||||
|
'''
|
||||||
|
cdef object fun = <object> param
|
||||||
|
cdef size_t n = t[0].spec[0].n
|
||||||
|
cdef BaseSpec bspec = BaseSpec.from_cpointer(t[0].spec)
|
||||||
|
cdef CTMatrix tmatrix = CTMatrix(bspec,
|
||||||
|
np.asarray(<np.complex[:n, :n]>(t[0].m)), copy=False)
|
||||||
|
return fun(tmatrix, omega)
|
||||||
|
|
||||||
|
|
||||||
cdef class ArcFunction:
|
cdef class ArcFunction:
|
||||||
cdef qpms_arc_function_t g
|
cdef qpms_arc_function_t g
|
||||||
|
@ -270,6 +286,11 @@ cdef class TMatrixGenerator:
|
||||||
self.holder = what
|
self.holder = what
|
||||||
self.g.function = qpms_tmatrix_generator_interpolator
|
self.g.function = qpms_tmatrix_generator_interpolator
|
||||||
self.g.params = <void*>(<TMatrixInterpolator?>self.holder).rawpointer()
|
self.g.params = <void*>(<TMatrixInterpolator?>self.holder).rawpointer()
|
||||||
|
elif callable(what):
|
||||||
|
warnings.warn("Custom python T-matrix generators are an experimental feature. Also expect it to be slow.")
|
||||||
|
self.holder = what
|
||||||
|
self.g.function = qpms_tmatrix_generator_pythoncallable
|
||||||
|
self.g.params = <void*>self.holder
|
||||||
else:
|
else:
|
||||||
raise TypeError("Can't construct TMatrixGenerator from that")
|
raise TypeError("Can't construct TMatrixGenerator from that")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue