Some cython code cleanup

Former-commit-id: 187ae611bcd6112caf7b0a0b37223f0bc1836392
This commit is contained in:
Marek Nečada 2017-05-08 22:17:19 +03:00
parent 92b33d7993
commit 535fb3c3d7
1 changed files with 4 additions and 73 deletions

View File

@ -564,7 +564,7 @@ cdef class trans_calculator:
# BIG FIXME: make skalars valid arguments, now r, theta, phi, r_ge_d have to be ndarrays # BIG FIXME: make skalars valid arguments, now r, theta, phi, r_ge_d have to be ndarrays
cdef: cdef:
int daxis, saxis, smallaxis, bigaxis, resnd, i, j, d, ax, errval int daxis, saxis, smallaxis, bigaxis, resnd, i, j, d, ax, errval
np.npy_intp sstride, dstride, longi, longstride np.npy_intp sstride, dstride, longi
int *local_indices int *local_indices
char *r_p char *r_p
char *theta_p char *theta_p
@ -612,14 +612,10 @@ cdef class trans_calculator:
resnd = len(baseshape)+2 resnd = len(baseshape)+2
daxis = (resnd-2) if destaxis is None else destaxis daxis = (resnd-2) if destaxis is None else destaxis
saxis = (resnd-1) if srcaxis is None else srcaxis saxis = (resnd-1) if srcaxis is None else srcaxis
print(daxis)
print(saxis)
if daxis < 0: if daxis < 0:
daxis = resnd + daxis daxis = resnd + daxis
print(daxis)
if saxis < 0: if saxis < 0:
saxis = resnd + saxis saxis = resnd + saxis
print(saxis)
if daxis < 0 or saxis < 0 or daxis >= resnd or saxis >= resnd or daxis == saxis: if daxis < 0 or saxis < 0 or daxis >= resnd or saxis >= resnd or daxis == saxis:
raise ValueError('invalid axes provided') # TODO better error formulation raise ValueError('invalid axes provided') # TODO better error formulation
resultshape = list(baseshape) resultshape = list(baseshape)
@ -635,22 +631,7 @@ cdef class trans_calculator:
theta = np.expand_dims(np.expand_dims(theta.astype(np.float_, copy=False), smallaxis), bigaxis) theta = np.expand_dims(np.expand_dims(theta.astype(np.float_, copy=False), smallaxis), bigaxis)
phi = np.expand_dims(np.expand_dims(phi.astype(np.float_, copy=False), smallaxis), bigaxis) phi = np.expand_dims(np.expand_dims(phi.astype(np.float_, copy=False), smallaxis), bigaxis)
r_ge_d = np.expand_dims(np.expand_dims(r_ge_d.astype(np.bool_, copy=False), smallaxis), bigaxis) r_ge_d = np.expand_dims(np.expand_dims(r_ge_d.astype(np.bool_, copy=False), smallaxis), bigaxis)
print(baseshape)
print(len(baseshape), resnd,smallaxis, bigaxis)
print(r.shape, theta.shape,phi.shape,r_ge_d.shape)
'''
cdef int longest_axis = 0
# FIxME: the whole thing with longest_axis will fail if none is longer than 1
for i in range(resnd):
if resultshape[i] > resultshape[longest_axis]:
longest_axis = i
cdef int* innerloop_shape = <int *> malloc(resnd * sizeof(int))
if innerloop_shape == NULL:
abort()
for i in range(resnd):
innerloop_shape[i] = resultshape[i]
innerloop_shape[longest_axis] = 1 # longest axis will be iterated in the outer (parallelized) loop. Therefore, longest axis, together with saxis and daxis, will not be iterated in the inner loop
'''
resultshape[daxis] = self.c[0].nelem resultshape[daxis] = self.c[0].nelem
resultshape[saxis] = self.c[0].nelem resultshape[saxis] = self.c[0].nelem
cdef np.ndarray r_c = np.broadcast_to(r,resultshape) cdef np.ndarray r_c = np.broadcast_to(r,resultshape)
@ -661,11 +642,7 @@ cdef class trans_calculator:
cdef np.ndarray b = np.empty(resultshape, dtype=complex) cdef np.ndarray b = np.empty(resultshape, dtype=complex)
dstride = a.strides[daxis] dstride = a.strides[daxis]
sstride = a.strides[saxis] sstride = a.strides[saxis]
#longstride = a.strides[longest_axis] with nogil:
#if innerloop_shape[daxis] != 1: raise
#if innerloop_shape[saxis] != 1: raise
# TODO write this in C (as a function) and parallelize there
with nogil: #, parallel(): # FIXME rewrite this part in C
errval = qpms_cython_trans_calculator_get_AB_arrays_loop( errval = qpms_cython_trans_calculator_get_AB_arrays_loop(
self.c, J, resnd, self.c, J, resnd,
daxis, saxis, daxis, saxis,
@ -676,52 +653,6 @@ cdef class trans_calculator:
phi_c.data, phi_c.shape, phi_c.strides, phi_c.data, phi_c.shape, phi_c.strides,
r_ge_d_c.data, r_ge_d_c.shape, r_ge_d_c.strides r_ge_d_c.data, r_ge_d_c.shape, r_ge_d_c.strides
) )
"""
local_indices = <int *> calloc(resnd, sizeof(int))
if local_indices == NULL: abort()
for longi in range(a.shape[longest_axis]): # outer loop (to be parallelized)
# this might be done also in the inverse order, but this is more 'c-contiguous' way of incrementing the indices
ax = resnd - 1
while ax >= 0:
# calculate the correct index/pointer for each array used. This can be further optimized from O(resnd * total size of the result array) to O(total size of the result array), but fick that now
r_p = r_c.data + r_c.strides[longest_axis] * longi
theta_p = theta_c.data + theta_c.strides[longest_axis] * longi
phi_p = phi_c.data + phi_c.strides[longest_axis] * longi
r_ge_d_p = r_ge_d_c.data + r_ge_d_c.strides[longest_axis] * longi
a_p = a.data + a.strides[longest_axis] * longi
b_p = b.data + b.strides[longest_axis] * longi
for i in range(resnd):
if i == longest_axis: continue
if i == saxis or i == daxis: continue
r_p += r_c.strides[i] * local_indices[i]
theta_p += theta_c.strides[i] * local_indices[i]
phi_p += phi_c.strides[i] * local_indices[i]
r_ge_d_p += r_ge_d_c.strides[i] * local_indices[i]
a_p += a.strides[i] * local_indices[i]
b_p += b.strides[i] * local_indices[i]
# perform the actual task here
errval = qpms_trans_calculator_get_AB_arrays_ext(self.c,
<cdouble*>a_p, <cdouble*>b_p,
dstride // sizeof(cdouble), sstride // sizeof(cdouble),
(<double*>r_p)[0], (<double*>theta_p)[0], (<double*>phi_p)[0], <int>((<np.npy_bool*>r_ge_d_p)[0]), J)
if errval: abort()
# increment the last index 'digit' (ax is now resnd-1; we don't have do-while loop in python)
local_indices[ax] += 1
while (local_indices[ax] == innerloop_shape[ax] and ax >= 0): # overflow to the next digit but stop when we reach below the last one
local_indices[ax] = 0
ax -= 1
local_indices[ax] += 1
if ax >= 0: # did not overflow, get back to the lowest index
ax = resnd - 1
''' wtf?
for ax in range(a.ndim):
if (ax == longest_axis or ax == daxis or ax == saxis):
continue
'''
free(local_indices)
free(innerloop_shape)
"""
return a, b return a, b
# TODO make possible to access the attributes (to show normalization etc) # TODO make possible to access the attributes (to show normalization etc)