Fixes to cython wrappers

Former-commit-id: 4598958f0c4bdc3262f71ce3e85467f5b07eee96
This commit is contained in:
Marek Nečada 2019-03-01 14:44:00 +02:00
parent 862cacf4b2
commit 833e82ab5b
1 changed files with 44 additions and 17 deletions

View File

@ -6,6 +6,10 @@ import cmath
from qpms_cdefs cimport * from qpms_cdefs cimport *
cimport cython cimport cython
from cython.parallel cimport parallel, prange from cython.parallel cimport parallel, prange
#import enum
# Here will be enum and dtype definitions; maybe move these to a separate file
import math # for copysign in crep methods import math # for copysign in crep methods
#import re # TODO for crep methods? #import re # TODO for crep methods?
@ -650,17 +654,26 @@ cdef class BaseSpec:
def __cinit__(self, *args, **kwargs): def __cinit__(self, *args, **kwargs):
cdef const qpms_uvswfi_t[:] ilist_memview cdef const qpms_uvswfi_t[:] ilist_memview
if len(args) > 0: if len(args) == 0:
if 'lMax' in kwargs.keys(): # if only lMax is specified, create the 'usual' definition in ('E','M') order
lMax = kwargs['lMax']
my, ny = get_mn_y(lMax)
nelem = len(my)
tlist = nelem * (QPMS_VSWF_ELECTRIC,) + nelem * (QPMS_VSWF_MAGNETIC,)
mlist = 2*list(my)
llist = 2*list(ny)
ilist = tlm2uvswfi(tlist,llist,mlist)
else:
raise ValueError
else: # len(args) > 0:
ilist = args[0] ilist = args[0]
#self.__ilist = np.array(args[0], dtype=qpms_uvswfi_t, order='C', copy=True) # FIXME define the dtypes at qpms_cdef.pxd level #self.__ilist = np.array(args[0], dtype=qpms_uvswfi_t, order='C', copy=True) # FIXME define the dtypes at qpms_cdef.pxd level
self.__ilist = np.array(args[0], dtype=np.ulonglong, order='C', copy=True) self.__ilist = np.array(ilist, dtype=np.ulonglong, order='C', copy=True)
self.__ilist.setflags(write=False) self.__ilist.setflags(write=False)
ilist_memview = self.__ilist ilist_memview = self.__ilist
self.s.ilist = &ilist_memview[0] self.s.ilist = &ilist_memview[0]
self.s.n = len(self.__ilist) self.s.n = len(self.__ilist)
self.s.capacity = 0 # is this the best way? self.s.capacity = 0 # is this the best way?
else:
raise ValueError
if 'norm' in kwargs.keys(): if 'norm' in kwargs.keys():
self.s.norm = kwargs['norm'] self.s.norm = kwargs['norm']
else: else:
@ -936,7 +949,7 @@ cdef class IRot3:
def __mul__(IRot3 self, IRot3 other): def __mul__(IRot3 self, IRot3 other):
res = IRot3(CQuat(1,0,0,0), 1) res = IRot3(CQuat(1,0,0,0), 1)
res.qd = qpms_IRot3_mult(self.qd, other.qd) res.qd = qpms_irot3_mult(self.qd, other.qd)
return res return res
def __pow__(IRot3 self, n, _): def __pow__(IRot3 self, n, _):
@ -946,7 +959,7 @@ cdef class IRot3:
else: else:
raise ValueError("The exponent of an IRot3 has to have an integer value.") raise ValueError("The exponent of an IRot3 has to have an integer value.")
res = IRot3(CQuat(1,0,0,0), 1) res = IRot3(CQuat(1,0,0,0), 1)
res.qd = qpms_IRot3_pow(self.qd, n) res.qd = qpms_irot3_pow(self.qd, n)
return res return res
def isclose(IRot3 self, IRot3 other, rtol=1e-5, atol=1e-8): def isclose(IRot3 self, IRot3 other, rtol=1e-5, atol=1e-8):
@ -1011,14 +1024,14 @@ cdef class IRot3:
r.rot = CQuat(math.cos(math.pi/n),0,0,math.sin(math.pi/n)) r.rot = CQuat(math.cos(math.pi/n),0,0,math.sin(math.pi/n))
return r return r
def as_uvswf_matrix(IRot3 self, basespec bspec): def as_uvswf_matrix(IRot3 self, BaseSpec bspec):
''' '''
Returns the uvswf representation of the current transform as a numpy array Returns the uvswf representation of the current transform as a numpy array
''' '''
cdef ssize_t sz = len(bspec) cdef ssize_t sz = len(bspec)
cdef np.ndarray m = np.empty((sz, sz), dtype=complex, order='C') # FIXME explicit dtype cdef np.ndarray m = np.empty((sz, sz), dtype=complex, order='C') # FIXME explicit dtype
cdef cdouble[:, ::1] view = m cdef cdouble[:, ::1] view = m
qpms_IRot3_uvswfi_dense(&view[0,0], bspec.rawpointer(), self.qd) qpms_irot3_uvswfi_dense(&view[0,0], bspec.rawpointer(), self.qd)
return m return m
cdef class TMatrixInterpolator: cdef class TMatrixInterpolator:
@ -1045,10 +1058,20 @@ cdef class TMatrix:
self.t.spec = spec.rawpointer(); self.t.spec = spec.rawpointer();
# The following will raise an exception if shape is wrong # The following will raise an exception if shape is wrong
self.m = np.array(matrix, dtype=complex, copy=True, order='C').reshape((len(spec), len(spec))) self.m = np.array(matrix, dtype=complex, copy=True, order='C').reshape((len(spec), len(spec)))
#self.m.setflags(write=False) # checkme
cdef cdouble[:,::1] m_memview = self.m cdef cdouble[:,::1] m_memview = self.m
self.t.m = &(m_memview[0,0]) self.t.m = &(m_memview[0,0])
self.t.owns_m = False # Memory in self.t.m is "owned" by self.m, not by self.t... self.t.owns_m = False # Memory in self.t.m is "owned" by self.m, not by self.t...
cdef qpms_tmatrix_t *rawpointer(TMatrix self):
'''Pointer to the qpms_tmatrix_t structure.
Don't forget to reference the BaseSpec object itself when storing the pointer anywhere!!!
'''
return &(self.t)
def as_array(TMatrix self):
return np.array(self.m, copy=True)
cdef class FinitePointGroup: cdef class FinitePointGroup:
''' '''
Wrapper over the qpms_finite_group_t structure. Wrapper over the qpms_finite_group_t structure.
@ -1059,7 +1082,7 @@ cdef class Particle:
''' '''
Wrapper over the qpms_particle_t structure. Wrapper over the qpms_particle_t structure.
''' '''
cdef readonly qpms_particle_t p cdef qpms_particle_t p
cdef readonly TMatrix t # We hold the reference to the T-matrix to ensure correct reference counting cdef readonly TMatrix t # We hold the reference to the T-matrix to ensure correct reference counting
def __cinit__(Particle self, position, TMatrix t): def __cinit__(Particle self, position, TMatrix t):
@ -1075,6 +1098,7 @@ cdef class ScatteringSystem:
def tlm2uvswfi(t, l, m): def tlm2uvswfi(t, l, m):
''' TODO doc ''' TODO doc
And TODO this should rather be an ufunc.
''' '''
# Very low-priority TODO: add some types / cythonize # Very low-priority TODO: add some types / cythonize
if isinstance(t, int) and isinstance(l, int) and isinstance(m, int): if isinstance(t, int) and isinstance(l, int) and isinstance(m, int):
@ -1082,16 +1106,19 @@ def tlm2uvswfi(t, l, m):
elif len(t) == len(l) and len(t) == len(m): elif len(t) == len(l) and len(t) == len(m):
u = list() u = list()
for i in range(len(t)): for i in range(len(t)):
if not (isinstance(t[i], int) and isinstance(l[i], int) and isinstance(m[i], int)): # not the best check possible, though if not (t[i] % 1 == 0 and l[i] % 1 == 0 and m[i] % 1 == 0): # maybe not the best check possible, though
raise ValueError # TODO error message raise ValueError # TODO error message
u.append(qpms_tmn2uvswfi(t[i],m[i],l[i])) u.append(qpms_tmn2uvswfi(t[i],m[i],l[i]))
return u return u
else: else:
raise ValueError # TODO error message print(len(t), len(l), len(m))
raise ValueError("Lengths of the t,l,m arrays must be equal, but they are %d, %d, %d."
% (len(t), len(l), len(m)))
def uvswfi2tlm(u): def uvswfi2tlm(u):
''' TODO doc ''' TODO doc
and TODO this should rather be an ufunc.
''' '''
cdef qpms_vswf_type_t t cdef qpms_vswf_type_t t
cdef qpms_l_t l cdef qpms_l_t l