Fix last commit

Former-commit-id: 0fa8ce938f13fbae048cfc8a04a09781ff7631b8
This commit is contained in:
Marek Nečada 2016-09-19 03:46:13 +03:00
parent 72b0c96d49
commit febd164740
1 changed files with 6 additions and 4 deletions

View File

@ -1496,29 +1496,31 @@ def scatter_constmultipole_rectarray(omega, epsilon_b, xN, yN, xd, yd, TMatrices
timec = time.time() timec = time.time()
print('%.4f: factorization complete (elapsed %.2f s)' % (timec, timec-timecold), print('%.4f: factorization complete (elapsed %.2f s)' % (timec, timec-timecold),
file = sys.stderr) file = sys.stderr)
print('%.4f: solving the scattering problem for %d incoming waves' % (timec, K), print('%.4f: solving the scattering problem for %d incoming multipoles' % (timec, nelem*2),
file=sys.stderr) file=sys.stderr)
sys.stderr.flush() sys.stderr.flush()
timecold = timec timecold = timec
if(pq_0_c == 1):
pq_0_c = np.full((2,nelem),1)
ab = np.empty((2,nelem,N*2*nelem), dtype=complex) ab = np.empty((2,nelem,N*2*nelem), dtype=complex)
for N_or_M in range(2): for N_or_M in range(2):
for yy in range(nelem): for yy in range(nelem):
pq_0 = np.zeros((2,nelem), dtype=np.complex_) pq_0 = np.zeros((2,nelem), dtype=np.complex_)
pq_0[N_or_M,yy] = pq_0_c[N_or_M,yy] pq_0[N_or_M,yy] = pq_0_c[N_or_M,yy]
pq_0 = broadcast_to(pq_0, (N, 2, nelem)) pq_0 = np.broadcast_to(pq_0, (N, 2, nelem))
MP_0 = np.empty((N,2,nelem),dtype=np.complex_) MP_0 = np.empty((N,2,nelem),dtype=np.complex_)
for j in range(N): # I wonder how this can be done without this loop... for j in range(N): # I wonder how this can be done without this loop...
MP_0[j] = np.tensordot(TMatrices[xij, yij],pq_0[j],axes=([-2,-1],[-2,-1])) MP_0[j] = np.tensordot(TMatrices[xij, yij],pq_0[j],axes=([-2,-1],[-2,-1]))
MP_0.shape = (N*2*nelem,) MP_0.shape = (N*2*nelem,)
ab[N_or_M, yy] = scipy.linalg.lu_solve(lupiv, MP_0) ab[N_or_M, yy] = scipy.linalg.lu_solve(lupiv, MP_0)
ab.shape = (nelem, xN, yN, 2, nelem) ab.shape = (2,nelem, xN, yN, 2, nelem)
if watch_time: if watch_time:
timec = time.time() timec = time.time()
print('%.4f: done (elapsed %.2f s)' % (timec, timec-timecold),file = sys.stderr) print('%.4f: done (elapsed %.2f s)' % (timec, timec-timecold),file = sys.stderr)
sys.stderr.flush() sys.stderr.flush()
if not (return_pq_0 + return_pq + return_xy): if not (return_pq + return_xy):
return ab return ab
returnlist = [ab] returnlist = [ab]
if (return_pq): if (return_pq):