Check lMax > 0; more robust coordinate handling in lattices.py

Former-commit-id: 9e5ac45158d5c7241427a6148db2d1842dd74cc6
This commit is contained in:
Marek Nečada 2017-05-17 11:32:17 +03:00
parent 3479a30e9b
commit c16e0df2f2
4 changed files with 31 additions and 13 deletions

View File

@ -55,8 +55,8 @@ class Scattering(object):
self.N = self.positions.shape[0]
self.k_0 = k_0
self.lMax = lMax if lMax else nelem2lMax(TMatrices.shape[-1])
self.tc = trans_calculator(lMax)
nelem = lMax * (lMax + 2) #!
self.tc = trans_calculator(self.lMax)
nelem = self.lMax * (self.lMax + 2) #!
self.nelem = nelem #!
self.prepared = False
self.TMatrices = np.broadcast_to(TMatrices, (self.N,2,nelem,2,nelem))
@ -159,6 +159,7 @@ class Scattering(object):
_time_e(btime, verbose)
return ab
"""
class Scattering_2D_lattice_rectcells(Scattering):
def __init__(self, rectcell_dims, rectcell_elem_positions, cellspec, k_0, rectcell_TMatrices = None, TMatrices = None, lMax = None, verbose=False, J_scat=3):
'''
@ -177,6 +178,7 @@ class Scattering_2D_lattice_rectcells(Scattering):
self.nelem = nelem #!
self.prepared = False
self.TMatrices = np.broadcast_to(TMatrices, (self.N,2,nelem,2,nelem))
"""
class Scattering_2D_zsym(Scattering):
def __init__(self, positions, TMatrices, k_0, lMax = None, verbose=False, J_scat=3):
@ -187,7 +189,7 @@ class Scattering_2D_zsym(Scattering):
self.my, self.ny = get_mn_y(self.lMax)
self.TE_NMz = (self.my + self.ny) % 2
self.TM_NMz = 1 - self.TE_NMz
self.tc = trans_calculator(lMax)
self.tc = trans_calculator(self.lMax)
# TODO možnost zadávat T-matice rovnou ve zhuštěné podobě
TMatrices_TE = TMatrices[...,self.TE_NMz[:,nx],self.TE_yz[:,nx],self.TE_NMz[nx,:],self.TE_yz[nx,:]]
TMatrices_TM = TMatrices[...,self.TM_NMz[:,nx],self.TM_yz[:,nx],self.TM_NMz[nx,:],self.TM_yz[nx,:]]
@ -239,7 +241,7 @@ class Scattering_2D_zsym(Scattering):
elif (TE_or_TM is None):
EoMl = (0,1)
sbtime = _time_b(verbose, step = 'Calculating interparticle translation coefficients')
kdji = cart2sph(self.positions[:,nx,:] - self.positions[nx,:,:])
kdji = cart2sph(self.positions[:,nx,:] - self.positions[nx,:,:], allow2d=True)
kdji[:,:,0] *= self.k_0
# get_AB array structure: [j,yj,i,yi]
# FIXME I could save some memory by calculating only half of these coefficients

View File

@ -482,9 +482,15 @@ cdef class trans_calculator:
object get_A, get_B, get_AB
def __cinit__(self, int lMax, int normalization = 1):
if (lMax <= 0):
raise ValueError('lMax has to be greater than 0.')
self.c = qpms_trans_calculator_init(lMax, normalization)
if self.c is NULL:
raise MemoryError
def __init__(self, int lMax, int normalization = 1):
if self.c is NULL:
raise MemoryError()
self.get_A_data[0].c = self.c
self.get_A_data[0].cmethod = <void *>qpms_trans_calculator_get_A_ext
self.get_A_data_p[0] = &(self.get_A_data[0])
@ -537,7 +543,8 @@ cdef class trans_calculator:
0 # unused
)
def __dealloc__(self):
qpms_trans_calculator_free(self.c)
if self.c is not NULL:
qpms_trans_calculator_free(self.c)
# TODO Reference counts to get_A, get_B, get_AB?
def lMax(self):

View File

@ -48,14 +48,22 @@ def ujit(f):
# Coordinate transforms for arrays of "arbitrary" shape
#@ujit
def cart2sph(cart,axis=-1):
if (cart.shape[axis] != 3):
raise ValueError("The converted array has to have dimension 3"
" along the given axis")
[x, y, z] = np.split(cart,3,axis=axis)
r = np.linalg.norm(cart,axis=axis,keepdims=True)
r_zero = np.logical_not(r)
θ = np.arccos(z/(r+r_zero))
def cart2sph(cart,axis=-1, allow2d=False):
if cart.shape[axis] == 3:
[x, y, z] = np.split(cart,3,axis=axis)
r = np.linalg.norm(cart,axis=axis,keepdims=True)
r_zero = np.logical_not(r)
θ = np.arccos(z/(r+r_zero))
elif cart.shape[axis] == 2 and allow2d:
[x, y] = np.split(cart,2,axis=axis)
r = np.linalg.norm(cart,axis=axis,keepdims=True)
r_zero = np.logical_not(r)
θ = np.broadcast_to(np.pi/2, x.shape)
else:
raise ValueError("The converted array has to have dimension 3 "
"(or 2 if allow2d==True)"
" along the given axis, not %d" % cart.shape[axis])
φ = np.arctan2(y,x) # arctan2 handles zeroes correctly itself
return np.concatenate((r,θ,φ),axis=axis)

View File

@ -353,6 +353,7 @@ int qpms_trans_calculator_multipliers_B(qpms_normalization_t norm, complex doubl
qpms_trans_calculator
*qpms_trans_calculator_init (int lMax, qpms_normalization_t normalization) {
assert(lMax > 0);
qpms_trans_calculator *c = malloc(sizeof(qpms_trans_calculator));
c->lMax = lMax;
c->nelem = lMax * (lMax+2);