Cython wrap qpms_vswf_set_spec_t

Former-commit-id: b0c253beff2af0eb61495152addd7942c6d092e4
This commit is contained in:
Marek Nečada 2019-02-25 22:25:42 +02:00
parent 980e281e4d
commit 9785327445
3 changed files with 159 additions and 3 deletions

View File

@ -636,6 +636,89 @@ def complex_crep(complex c, parentheses = False, shortI = True, has_Imaginary =
+ (')' if parentheses else '')
)
cdef class basespec:
'''Cython wrapper over qpms_vswf_set_spec_t.
It should be kept immutable. The memory is managed by numpy/cython, not directly by the C functions, therefore
whenever used in other wrapper classes that need the pointer
to qpms_vswf_set_spec_t, remember to set a (private, probably immutable) reference to qpms.basespec to ensure
correct reference counting and garbage collection.
'''
cdef qpms_vswf_set_spec_t s
cdef np.ndarray __ilist
#cdef const qpms_uvswfi_t[:] __ilist
def __cinit__(self, *args, **kwargs):
cdef const qpms_uvswfi_t[:] ilist_memview
if len(args) > 0:
ilist = args[0]
#self.__ilist = np.array(args[0], dtype=qpms_uvswfi_t, order='C', copy=True) # FIXME define the dtypes at qpms_cdef.pxd level
self.__ilist = np.array(args[0], dtype=np.ulonglong, order='C', copy=True)
self.__ilist.setflags(write=False)
ilist_memview = self.__ilist
self.s.ilist = &ilist_memview[0]
self.s.n = len(self.__ilist)
self.s.capacity = 0 # is this the best way?
else:
raise ValueError
if 'norm' in kwargs.keys():
self.s.norm = kwargs['norm']
else:
self.s.norm = QPMS_NORMALISATION_UNDEF
# set the other metadata
cdef qpms_l_t l
cdef qpms_m_t m
cdef qpms_vswf_type_t t
for i in range(self.s.n):
if(qpms_uvswfi2tmn(ilist_memview[i], &t, &m, &l) != QPMS_SUCCESS):
raise ValueError("Invalid uvswf index")
if (t == QPMS_VSWF_ELECTRIC):
self.s.lMax_N = max(self.s.lMax_N, l)
elif (t == QPMS_VSWF_MAGNETIC):
self.s.lMax_M = max(self.s.lMax_M, l)
elif (t == QPMS_VSWF_LONGITUDINAL):
self.s.lMax.L = max(self.s.lMax_L, l)
else:
raise ValueError # If this happens, it's probably a bug, as it should have failed already at qpms_uvswfi2tmn
self.s.lMax = max(self.s.lMax, l)
def tlm(self):
cdef const qpms_uvswfi_t[:] ilist_memview = <qpms_uvswfi_t[:self.s.n]> self.s.ilist
#cdef qpms_vswf_type_t[:] t = np.empty(shape=(self.s.n,), dtype=qpms_vswf_type_t) # does not work, workaround:
cdef size_t i
cdef np.ndarray ta = np.empty(shape=(self.s.n,), dtype=np.intc)
cdef int[:] t = ta
#cdef qpms_l_t[:] l = np.empty(shape=(self.s.n,), dtype=qpms_l_t) # FIXME explicit dtype again
cdef np.ndarray la = np.empty(shape=(self.s.n,), dtype=np.intc)
cdef qpms_l_t[:] l = la
#cdef qpms_m_t[:] m = np.empty(shape=(self.s.n,), dtype=qpms_m_t) # FIXME explicit dtype again
cdef np.ndarray ma = np.empty(shape=(self.s.n,), dtype=np.intc)
cdef qpms_m_t[:] m = ma
for i in range(self.s.n):
qpms_uvswfi2tmn(self.s.ilist[i], <qpms_vswf_type_t*>&t[i], &m[i], &l[i])
return (ta, la, ma)
def m(self): # ugly
return self.tlm()[2]
def t(self): # ugly
return self.tlm()[0]
def l(self): # ugly
return self.tlm()[1]
property ilist:
def __get__(self):
return self.__ilist
property rawpointer:
'''Pointer to the qpms_vswf_set_spec_t structure.
Don't forget to reference the basespec object itself!!!
'''
def __get__(self):
return <uintptr_t> &(self.s)
# Quaternions from wigner.h
# (mainly for testing; use moble's quaternions in python)
@ -919,7 +1002,45 @@ cdef class irot3:
r.rot = cquat(math.cos(math.pi/n),0,0,math.sin(math.pi/n))
return r
def tlm2uvswfi(t, l, m):
''' TODO doc
'''
# Very low-priority TODO: add some types / cythonize
if isinstance(t, int) and isinstance(l, int) and isinstance(m, int):
return qpms_tmn2uvswfi(t, m, l)
elif len(t) == len(l) and len(t) == len(m):
u = list()
for i in range(len(t)):
if not (isinstance(t[i], int) and isinstance(l[i], int) and isinstance(m[i], int)): # not the best check possible, though
raise ValueError # TODO error message
u.append(qpms_tmn2uvswfi(t[i],m[i],l[i]))
return u
else:
raise ValueError # TODO error message
def uvswfi2tlm(u):
''' TODO doc
'''
cdef qpms_vswf_type_t t
cdef qpms_l_t l
cdef qpms_m_t m
cdef size_t i
if isinstance(u, (int, np.ulonglong)):
if (qpms_uvswfi2tmn(u, &t, &m, &l) != QPMS_SUCCESS):
raise ValueError("Invalid uvswf index")
return (t, l, m)
else:
ta = list()
la = list()
ma = list()
for i in range(len(u)):
if (qpms_uvswfi2tmn(u[i], &t, &m, &l) != QPMS_SUCCESS):
raise ValueError("Invalid uvswf index")
ta.append(t)
la.append(l)
ma.append(m)
return (ta, la, ma)

View File

@ -1,5 +1,9 @@
cimport numpy as np
ctypedef double complex cdouble
from libc.stdint cimport uintptr_t
cdef extern from "qpms_types.h":
cdef struct cart3_t:
double x
@ -48,6 +52,30 @@ cdef extern from "qpms_types.h":
struct qpms_irot3_t:
qpms_quat_t rot
short det
ctypedef np.ulonglong_t qpms_uvswfi_t
struct qpms_vswf_set_spec_t:
size_t n
qpms_uvswfi_t *ilist
qpms_l_t lMax
qpms_l_t lMax_M
qpms_l_t lMax_N
qpms_l_t lMax_L
size_t capacity
qpms_normalisation_t norm
ctypedef enum qpms_errno_t:
QPMS_SUCCESS
QPMS_ERROR
# more if needed
ctypedef enum qpms_vswf_type_t:
QPMS_VSWF_ELECTRIC
QPMS_VSWF_MAGNETIC
QPMS_VSWF_LONGITUDINAL
# maybe more if needed
cdef extern from "indexing.h":
qpms_uvswfi_t qpms_tmn2uvswfi(qpms_vswf_type_t t, qpms_m_t m, qpms_l_t n)
qpms_errno_t qpms_uvswfi2tmn(qpms_uvswfi_t u, qpms_vswf_type_t* t, qpms_m_t* m, qpms_l_t* n)
qpms_m_t qpms_uvswfi2m(qpms_uvswfi_t u)
# maybe more if needed
# Point generators from lattices.h
@ -83,7 +111,6 @@ cdef extern from "lattices.h":
PGen PGen_1D_new_minMaxR(double period, double offset, double minR, bint inc_minR,
double maxR, bint inc_maxR, PGen_1D_incrementDirection incdir)
ctypedef double complex cdouble
cdef extern from "wigner.h":
qpms_quat_t qpms_quat_2c_from_4d(qpms_quat4d_t q)
@ -103,6 +130,13 @@ cdef extern from "wigner.h":
qpms_irot3_t qpms_irot3_mult(qpms_irot3_t p, qpms_irot3_t q)
qpms_irot3_t qpms_irot3_pow(qpms_irot3_t p, int n)
cdef extern from "symmetries.h":
cdouble *qpms_zflip_uvswi_dense(cdouble *target, const qpms_vswf_set_spec_t *bspec)
cdouble *qpms_yflip_uvswi_dense(cdouble *target, const qpms_vswf_set_spec_t *bspec)
cdouble *qpms_xflip_uvswi_dense(cdouble *target, const qpms_vswf_set_spec_t *bspec)
cdouble *qpms_zrot_uvswi_dense(cdouble *target, const qpms_vswf_set_spec_t *bspec, double phi)
cdouble *qpms_zrot_rational_uvswi_dense(cdouble *target, const qpms_vswf_set_spec_t *bspec, int N, int w)
cdouble *qpms_irot3_uvswfi_dense(cdouble *target, const qpms_vswf_set_spec_t *bspec, qpms_irot3_t transf)
#cdef extern from "numpy/arrayobject.h":
# cdef enum NPY_TYPES:

View File

@ -6,6 +6,7 @@
#include <complex.h>
#include <stdbool.h>
#include <stddef.h>
//#include <stdint.h>
#ifndef M_PI_2
#define M_PI_2 (1.570796326794896619231321691639751442098584699687552910487)
@ -65,7 +66,7 @@ typedef enum {
* from qpms_types.h instead
* as the formula might change in future versions.
*/
typedef size_t qpms_uvswfi_t;
typedef unsigned long long qpms_uvswfi_t;
/// Error codes / return values for certain numerical functions.
/** These are de facto a subset of the GSL error codes. */