Few modifications, did not help...

Former-commit-id: 47121008c9846d8539d85b367763d950e037a1c9
This commit is contained in:
Marek Nečada 2017-05-03 05:19:33 +03:00
parent 40809ce90a
commit c599ee839e
1 changed files with 10 additions and 13 deletions

View File

@ -554,10 +554,9 @@ cdef class trans_calculator:
# TODO CHECK (and try to cast) INPUT ARRAY TYPES (now is done) # TODO CHECK (and try to cast) INPUT ARRAY TYPES (now is done)
# BIG FIXME: make skalars valid arguments, now r, theta, phi, r_ge_d have to be ndarrays # BIG FIXME: make skalars valid arguments, now r, theta, phi, r_ge_d have to be ndarrays
cdef: cdef:
int daxis, saxis, smallaxis, bigaxis, resnd, longest_axis, i, j, d, ax, errval int daxis, saxis, smallaxis, bigaxis, resnd, i, j, d, ax, errval
np.npy_intp sstride, dstride, longi, longstride np.npy_intp sstride, dstride, longi, longstride
int *local_indices int *local_indices
int *innerloop_shape
char *r_p char *r_p
char *theta_p char *theta_p
char *phi_p char *phi_p
@ -630,18 +629,19 @@ cdef class trans_calculator:
print(baseshape) print(baseshape)
print(len(baseshape), resnd,smallaxis, bigaxis) print(len(baseshape), resnd,smallaxis, bigaxis)
print(r.shape, theta.shape,phi.shape,r_ge_d.shape) print(r.shape, theta.shape,phi.shape,r_ge_d.shape)
#'''
longest_axis = 0 cdef int longest_axis = 0
# FIxME: the whole thing with longest_axis will fail if none is longer than 1 # FIxME: the whole thing with longest_axis will fail if none is longer than 1
for i in range(resnd): for i in range(resnd):
if resultshape[i] > resultshape[longest_axis]: if resultshape[i] > resultshape[longest_axis]:
longest_axis = i longest_axis = i
innerloop_shape = <int *> malloc(resnd * sizeof(int)) cdef int* innerloop_shape = <int *> malloc(resnd * sizeof(int))
if innerloop_shape == NULL: if innerloop_shape == NULL:
abort() abort()
for i in range(resnd): for i in range(resnd):
innerloop_shape[i] = resultshape[i] innerloop_shape[i] = resultshape[i]
innerloop_shape[longest_axis] = 1 # longest axis will be iterated in the outer (parallelized) loop. Therefore, longest axis, together with saxis and daxis, will not be iterated in the inner loop innerloop_shape[longest_axis] = 1 # longest axis will be iterated in the outer (parallelized) loop. Therefore, longest axis, together with saxis and daxis, will not be iterated in the inner loop
#'''
resultshape[daxis] = self.c[0].nelem resultshape[daxis] = self.c[0].nelem
resultshape[saxis] = self.c[0].nelem resultshape[saxis] = self.c[0].nelem
cdef np.ndarray r_c = np.broadcast_to(r,resultshape) cdef np.ndarray r_c = np.broadcast_to(r,resultshape)
@ -653,6 +653,8 @@ cdef class trans_calculator:
dstride = a.strides[daxis] dstride = a.strides[daxis]
sstride = a.strides[saxis] sstride = a.strides[saxis]
longstride = a.strides[longest_axis] longstride = a.strides[longest_axis]
if innerloop_shape[daxis] != 1: raise
if innerloop_shape[saxis] != 1: raise
# TODO write this in C (as a function) and parallelize there # TODO write this in C (as a function) and parallelize there
with nogil: #, parallel(): # FIXME rewrite this part in C with nogil: #, parallel(): # FIXME rewrite this part in C
local_indices = <int *> calloc(resnd, sizeof(int)) local_indices = <int *> calloc(resnd, sizeof(int))
@ -670,11 +672,11 @@ cdef class trans_calculator:
b_p = b.data + b.strides[longest_axis] * longi b_p = b.data + b.strides[longest_axis] * longi
for i in range(resnd): for i in range(resnd):
if i == longest_axis: continue if i == longest_axis: continue
if i == saxis or i == daxis: continue
r_p += r_c.strides[i] * local_indices[i] r_p += r_c.strides[i] * local_indices[i]
theta_p += theta_c.strides[i] * local_indices[i] theta_p += theta_c.strides[i] * local_indices[i]
phi_p += phi_c.strides[i] * local_indices[i] phi_p += phi_c.strides[i] * local_indices[i]
r_ge_d_p += r_ge_d_c.strides[i] * local_indices[i] r_ge_d_p += r_ge_d_c.strides[i] * local_indices[i]
if i == saxis or i == daxis: continue
a_p += a.strides[i] * local_indices[i] a_p += a.strides[i] * local_indices[i]
b_p += b.strides[i] * local_indices[i] b_p += b.strides[i] * local_indices[i]
@ -693,16 +695,11 @@ cdef class trans_calculator:
local_indices[ax] += 1 local_indices[ax] += 1
if ax >= 0: # did not overflow, get back to the lowest index if ax >= 0: # did not overflow, get back to the lowest index
ax = resnd - 1 ax = resnd - 1
''' wtf?
for ax in range(a.ndim): for ax in range(a.ndim):
if (ax == longest_axis or ax == daxis or ax == saxis): if (ax == longest_axis or ax == daxis or ax == saxis):
continue continue
'''
free(local_indices) free(local_indices)
free(innerloop_shape) free(innerloop_shape)
return a, b return a, b