Fixes to cython wrappers
Former-commit-id: 4598958f0c4bdc3262f71ce3e85467f5b07eee96
This commit is contained in:
parent
862cacf4b2
commit
833e82ab5b
|
@ -6,6 +6,10 @@ import cmath
|
|||
from qpms_cdefs cimport *
|
||||
cimport cython
|
||||
from cython.parallel cimport parallel, prange
|
||||
#import enum
|
||||
|
||||
# Here will be enum and dtype definitions; maybe move these to a separate file
|
||||
|
||||
|
||||
import math # for copysign in crep methods
|
||||
#import re # TODO for crep methods?
|
||||
|
@ -650,17 +654,26 @@ cdef class BaseSpec:
|
|||
|
||||
def __cinit__(self, *args, **kwargs):
|
||||
cdef const qpms_uvswfi_t[:] ilist_memview
|
||||
if len(args) > 0:
|
||||
if len(args) == 0:
|
||||
if 'lMax' in kwargs.keys(): # if only lMax is specified, create the 'usual' definition in ('E','M') order
|
||||
lMax = kwargs['lMax']
|
||||
my, ny = get_mn_y(lMax)
|
||||
nelem = len(my)
|
||||
tlist = nelem * (QPMS_VSWF_ELECTRIC,) + nelem * (QPMS_VSWF_MAGNETIC,)
|
||||
mlist = 2*list(my)
|
||||
llist = 2*list(ny)
|
||||
ilist = tlm2uvswfi(tlist,llist,mlist)
|
||||
else:
|
||||
raise ValueError
|
||||
else: # len(args) > 0:
|
||||
ilist = args[0]
|
||||
#self.__ilist = np.array(args[0], dtype=qpms_uvswfi_t, order='C', copy=True) # FIXME define the dtypes at qpms_cdef.pxd level
|
||||
self.__ilist = np.array(args[0], dtype=np.ulonglong, order='C', copy=True)
|
||||
self.__ilist = np.array(ilist, dtype=np.ulonglong, order='C', copy=True)
|
||||
self.__ilist.setflags(write=False)
|
||||
ilist_memview = self.__ilist
|
||||
self.s.ilist = &ilist_memview[0]
|
||||
self.s.n = len(self.__ilist)
|
||||
self.s.capacity = 0 # is this the best way?
|
||||
else:
|
||||
raise ValueError
|
||||
if 'norm' in kwargs.keys():
|
||||
self.s.norm = kwargs['norm']
|
||||
else:
|
||||
|
@ -936,7 +949,7 @@ cdef class IRot3:
|
|||
|
||||
def __mul__(IRot3 self, IRot3 other):
|
||||
res = IRot3(CQuat(1,0,0,0), 1)
|
||||
res.qd = qpms_IRot3_mult(self.qd, other.qd)
|
||||
res.qd = qpms_irot3_mult(self.qd, other.qd)
|
||||
return res
|
||||
|
||||
def __pow__(IRot3 self, n, _):
|
||||
|
@ -946,7 +959,7 @@ cdef class IRot3:
|
|||
else:
|
||||
raise ValueError("The exponent of an IRot3 has to have an integer value.")
|
||||
res = IRot3(CQuat(1,0,0,0), 1)
|
||||
res.qd = qpms_IRot3_pow(self.qd, n)
|
||||
res.qd = qpms_irot3_pow(self.qd, n)
|
||||
return res
|
||||
|
||||
def isclose(IRot3 self, IRot3 other, rtol=1e-5, atol=1e-8):
|
||||
|
@ -1011,14 +1024,14 @@ cdef class IRot3:
|
|||
r.rot = CQuat(math.cos(math.pi/n),0,0,math.sin(math.pi/n))
|
||||
return r
|
||||
|
||||
def as_uvswf_matrix(IRot3 self, basespec bspec):
|
||||
def as_uvswf_matrix(IRot3 self, BaseSpec bspec):
|
||||
'''
|
||||
Returns the uvswf representation of the current transform as a numpy array
|
||||
'''
|
||||
cdef ssize_t sz = len(bspec)
|
||||
cdef np.ndarray m = np.empty((sz, sz), dtype=complex, order='C') # FIXME explicit dtype
|
||||
cdef cdouble[:, ::1] view = m
|
||||
qpms_IRot3_uvswfi_dense(&view[0,0], bspec.rawpointer(), self.qd)
|
||||
qpms_irot3_uvswfi_dense(&view[0,0], bspec.rawpointer(), self.qd)
|
||||
return m
|
||||
|
||||
cdef class TMatrixInterpolator:
|
||||
|
@ -1045,10 +1058,20 @@ cdef class TMatrix:
|
|||
self.t.spec = spec.rawpointer();
|
||||
# The following will raise an exception if shape is wrong
|
||||
self.m = np.array(matrix, dtype=complex, copy=True, order='C').reshape((len(spec), len(spec)))
|
||||
#self.m.setflags(write=False) # checkme
|
||||
cdef cdouble[:,::1] m_memview = self.m
|
||||
self.t.m = &(m_memview[0,0])
|
||||
self.t.owns_m = False # Memory in self.t.m is "owned" by self.m, not by self.t...
|
||||
|
||||
cdef qpms_tmatrix_t *rawpointer(TMatrix self):
|
||||
'''Pointer to the qpms_tmatrix_t structure.
|
||||
Don't forget to reference the BaseSpec object itself when storing the pointer anywhere!!!
|
||||
'''
|
||||
return &(self.t)
|
||||
|
||||
def as_array(TMatrix self):
|
||||
return np.array(self.m, copy=True)
|
||||
|
||||
cdef class FinitePointGroup:
|
||||
'''
|
||||
Wrapper over the qpms_finite_group_t structure.
|
||||
|
@ -1059,7 +1082,7 @@ cdef class Particle:
|
|||
'''
|
||||
Wrapper over the qpms_particle_t structure.
|
||||
'''
|
||||
cdef readonly qpms_particle_t p
|
||||
cdef qpms_particle_t p
|
||||
cdef readonly TMatrix t # We hold the reference to the T-matrix to ensure correct reference counting
|
||||
|
||||
def __cinit__(Particle self, position, TMatrix t):
|
||||
|
@ -1075,6 +1098,7 @@ cdef class ScatteringSystem:
|
|||
|
||||
def tlm2uvswfi(t, l, m):
|
||||
''' TODO doc
|
||||
And TODO this should rather be an ufunc.
|
||||
'''
|
||||
# Very low-priority TODO: add some types / cythonize
|
||||
if isinstance(t, int) and isinstance(l, int) and isinstance(m, int):
|
||||
|
@ -1082,16 +1106,19 @@ def tlm2uvswfi(t, l, m):
|
|||
elif len(t) == len(l) and len(t) == len(m):
|
||||
u = list()
|
||||
for i in range(len(t)):
|
||||
if not (isinstance(t[i], int) and isinstance(l[i], int) and isinstance(m[i], int)): # not the best check possible, though
|
||||
if not (t[i] % 1 == 0 and l[i] % 1 == 0 and m[i] % 1 == 0): # maybe not the best check possible, though
|
||||
raise ValueError # TODO error message
|
||||
u.append(qpms_tmn2uvswfi(t[i],m[i],l[i]))
|
||||
return u
|
||||
else:
|
||||
raise ValueError # TODO error message
|
||||
print(len(t), len(l), len(m))
|
||||
raise ValueError("Lengths of the t,l,m arrays must be equal, but they are %d, %d, %d."
|
||||
% (len(t), len(l), len(m)))
|
||||
|
||||
|
||||
def uvswfi2tlm(u):
|
||||
''' TODO doc
|
||||
and TODO this should rather be an ufunc.
|
||||
'''
|
||||
cdef qpms_vswf_type_t t
|
||||
cdef qpms_l_t l
|
||||
|
|
Loading…
Reference in New Issue