Split away MaterialInterpolator

Former-commit-id: a882ac5e1a6fa38137fa9059fda53c671534128b
This commit is contained in:
Marek Nečada 2019-08-10 09:20:39 +03:00
parent 5d1f05984e
commit 6ea386d759
7 changed files with 82 additions and 53 deletions

View File

@ -6,6 +6,7 @@ from .qpms_p import *
from .cyquaternions import CQuat, IRot3 from .cyquaternions import CQuat, IRot3
from .cybspec import VSWFNorm, BaseSpec from .cybspec import VSWFNorm, BaseSpec
from .cytranslations import trans_calculator from .cytranslations import trans_calculator
from .cymaterials import MaterialInterpolator
from .lattices2d import * from .lattices2d import *
from .hexpoints import * from .hexpoints import *
from .tmatrices import * from .tmatrices import *

1
qpms/cycommon.pxd Normal file
View File

@ -0,0 +1 @@
cdef char *make_c_string(pythonstring)

View File

@ -1,5 +1,6 @@
import numpy as np import numpy as np
from qpms_cdefs cimport * from qpms_cdefs cimport *
from libc.stdlib cimport malloc
cimport cython cimport cython
import enum import enum
@ -180,3 +181,23 @@ def uvswfi2tlm(u):
ma.append(m) ma.append(m)
return (ta, la, ma) return (ta, la, ma)
cdef char *make_c_string(pythonstring):
'''
Copies contents of a python string into a char[]
(allocating the memory with malloc())
'''
bytestring = pythonstring.encode('UTF-8')
cdef Py_ssize_t n = len(bytestring)
cdef Py_ssize_t i
cdef char *s
s = <char *>malloc(n+1)
if not s:
raise MemoryError
#s[:n] = bytestring # This segfaults; why?
for i in range(n): s[i] = bytestring[i]
s[n] = <char>0
return s
def string_c2py(const char* cstring):
return cstring.decode('UTF-8')

8
qpms/cymaterials.pxd Normal file
View File

@ -0,0 +1,8 @@
from qpms_cdefs cimport qpms_permittivity_interpolator_t
cdef class MaterialInterpolator:
cdef qpms_permittivity_interpolator_t *interp
cdef readonly double omegamin
cdef readonly double omegamax

44
qpms/cymaterials.pyx Normal file
View File

@ -0,0 +1,44 @@
# Cythonized parts of QPMS here
# -----------------------------
import numpy as np
import cmath
from qpms_cdefs cimport *
from cybspec cimport *
from cycommon import *
from cycommon cimport make_c_string
cimport cython
import enum
import warnings
import os
from libc.stdlib cimport malloc, free, calloc, abort
cdef class MaterialInterpolator:
'''
Wrapper over the qpms_permittivity_interpolator_t structure.
'''
def __cinit__(self, filename, *args, **kwargs):
'''Creates a permittivity interpolator.'''
cdef char *cpath = make_c_string(filename)
self.interp = qpms_permittivity_interpolator_from_yml(cpath, gsl_interp_cspline)
if not self.interp:
raise IOError("Could not load permittivity data from %s" % filename)
self.omegamin = qpms_permittivity_interpolator_omega_min(self.interp)
self.omegamax = qpms_permittivity_interpolator_omega_max(self.interp)
def __dealloc__(self):
qpms_permittivity_interpolator_free(self.interp)
def __call__(self, double freq):
'''Returns interpolated permittivity, corresponding to a given angular frequency.'''
if freq < self.omegamin or freq > self.omegamax:
raise ValueError("Input frequency %g is outside the interpolator domain (%g, %g)."
% (freq, self.minomega, self.freqs[self.maxomega]))
return qpms_permittivity_interpolator_eps_at_omega(self.interp, freq)
property freq_interval:
def __get__(self):
return [self.omegamin, self.omegamax]

View File

@ -15,44 +15,13 @@ from cyquaternions cimport *
from cybspec cimport * from cybspec cimport *
#from cybspec import * #from cybspec import *
from cycommon import * from cycommon import *
from cycommon cimport make_c_string
cimport cython cimport cython
import enum import enum
import warnings import warnings
import os import os
from libc.stdlib cimport malloc, free, calloc, abort from libc.stdlib cimport malloc, free, calloc, abort
cdef class MaterialInterpolator:
'''
Wrapper over the qpms_permittivity_interpolator_t structure.
'''
cdef qpms_permittivity_interpolator_t *interp
cdef readonly double omegamin
cdef readonly double omegamax
def __cinit__(self, filename, *args, **kwargs):
'''Creates a permittivity interpolator.'''
cdef char *cpath = make_c_string(filename)
self.interp = qpms_permittivity_interpolator_from_yml(cpath, gsl_interp_cspline)
if not self.interp:
raise IOError("Could not load permittivity data from %s" % filename)
self.omegamin = qpms_permittivity_interpolator_omega_min(self.interp)
self.omegamax = qpms_permittivity_interpolator_omega_max(self.interp)
def __dealloc__(self):
qpms_permittivity_interpolator_free(self.interp)
def __call__(self, double freq):
'''Returns interpolated permittivity, corresponding to a given angular frequency.'''
if freq < self.omegamin or freq > self.omegamax:
raise ValueError("Input frequency %g is outside the interpolator domain (%g, %g)."
% (freq, self.minomega, self.freqs[self.maxomega]))
return qpms_permittivity_interpolator_eps_at_omega(self.interp, freq)
property freq_interval:
def __get__(self):
return [self.omegamin, self.omegamax]
cdef class TMatrixInterpolator: cdef class TMatrixInterpolator:
''' '''
Wrapper over the qpms_tmatrix_interpolator_t structure. Wrapper over the qpms_tmatrix_interpolator_t structure.
@ -177,26 +146,6 @@ cdef class CTMatrix: # N.B. there is another type called TMatrix in tmatrices.py
tm.spherical_perm_fill(radius, freq, epsilon_int, epsilon_ext) tm.spherical_perm_fill(radius, freq, epsilon_int, epsilon_ext)
return tm return tm
cdef char *make_c_string(pythonstring):
'''
Copies contents of a python string into a char[]
(allocating the memory with malloc())
'''
bytestring = pythonstring.encode('UTF-8')
cdef Py_ssize_t n = len(bytestring)
cdef Py_ssize_t i
cdef char *s
s = <char *>malloc(n+1)
if not s:
raise MemoryError
#s[:n] = bytestring # This segfaults; why?
for i in range(n): s[i] = bytestring[i]
s[n] = <char>0
return s
def string_c2py(const char* cstring):
return cstring.decode('UTF-8')
cdef class PointGroup: cdef class PointGroup:
cdef readonly qpms_pointgroup_t G cdef readonly qpms_pointgroup_t G

View File

@ -102,6 +102,11 @@ cybspec = Extension('qpms.cybspec',
extra_link_args=['qpms/libqpms.a'], extra_link_args=['qpms/libqpms.a'],
libraries=['gsl', 'lapacke', 'blas', 'gslcblas', 'pthread',] libraries=['gsl', 'lapacke', 'blas', 'gslcblas', 'pthread',]
) )
cymaterials = Extension('qpms.cymaterials',
sources = ['qpms/cymaterials.pyx'],
extra_link_args=['qpms/libqpms.a'],
libraries=['gsl', 'lapacke', 'blas', 'gslcblas', 'pthread',]
)
cyquaternions = Extension('qpms.cyquaternions', cyquaternions = Extension('qpms.cyquaternions',
sources = ['qpms/cyquaternions.pyx'], sources = ['qpms/cyquaternions.pyx'],
extra_link_args=['amos/libamos.a', 'qpms/libqpms.a'], extra_link_args=['amos/libamos.a', 'qpms/libqpms.a'],
@ -130,7 +135,7 @@ setup(name='qpms',
#'quaternion','spherical_functions', #'quaternion','spherical_functions',
'scipy>=0.18.0', 'sympy>=1.2'], 'scipy>=0.18.0', 'sympy>=1.2'],
#dependency_links=['https://github.com/moble/quaternion/archive/v2.0.tar.gz','https://github.com/moble/spherical_functions/archive/master.zip'], #dependency_links=['https://github.com/moble/quaternion/archive/v2.0.tar.gz','https://github.com/moble/spherical_functions/archive/master.zip'],
ext_modules=cythonize([qpms_c, cytranslations, cycommon, cyquaternions, cybspec], include_path=['qpms', 'amos'], gdb_debug=True), ext_modules=cythonize([qpms_c, cytranslations, cycommon, cyquaternions, cybspec, cymaterials], include_path=['qpms', 'amos'], gdb_debug=True),
cmdclass = {'build_ext': build_ext}, cmdclass = {'build_ext': build_ext},
zip_safe=False zip_safe=False
) )