lattices: Basic stuff fixed, measuring time
Former-commit-id: 7019be9053b601e266eeddafdd859f6339313df7
This commit is contained in:
parent
c18e1209f4
commit
b642991bb1
|
@ -8,6 +8,35 @@ import sys
|
|||
from qpms_c import get_mn_y # TODO be explicit about what is imported
|
||||
from .qpms_p import cart2sph, nelem2lMax, Ã, B̃ # TODO be explicit about what is imported
|
||||
|
||||
def _time_b(active = True, name = None, step = None):
|
||||
'''
|
||||
Auxiliary function for keeping track of elapsed time.
|
||||
Returns current time (to be used by _time_e).
|
||||
'''
|
||||
now = time.time()
|
||||
if active:
|
||||
if not name:
|
||||
name = sys._getframe(1).f_code.co_name
|
||||
if step:
|
||||
print('%.4f: %s in function %s started.' % (now, step, name), file = sys.stderr)
|
||||
else:
|
||||
print('%.4f: Function %s started.' % (now, name), file=sys.stderr)
|
||||
sys.stderr.flush()
|
||||
return now
|
||||
|
||||
def _time_e(start_time, active = True, name = None, step = None):
|
||||
now = time.time()
|
||||
if active:
|
||||
if not name:
|
||||
name = sys._getframe(1).f_code.co_name
|
||||
if step:
|
||||
print('%.4f: %s in function %s finished (elapsed %.2f s).'
|
||||
% (now, step, name, now - start_time), file = sys.stderr)
|
||||
else:
|
||||
print('%.4f: Function %s finished (elapsed %.2f s).'
|
||||
% (now, name, now - start_time), file = sys.stderr)
|
||||
sys.stderr.flush()
|
||||
|
||||
class Scattering(object):
|
||||
'''
|
||||
|
||||
|
@ -59,6 +88,7 @@ class Scattering(object):
|
|||
self.TMatrices = np.broadcast_to(TMatrices, (self.N,2,nelem,2,nelem))
|
||||
|
||||
def prepare(self, keep_interaction_matrix = False, verbose=False):
|
||||
btime = _time_b(verbose)
|
||||
if not self.prepared:
|
||||
if not self.interaction_matrix:
|
||||
self.build_interaction_matrix(verbose=verbose)
|
||||
|
@ -66,12 +96,15 @@ class Scattering(object):
|
|||
if not keep_interaction_matrix:
|
||||
self.interaction_matrix = None
|
||||
self.prepared = True
|
||||
_time_e(btime, verbose)
|
||||
|
||||
def build_interaction_matrix(self,verbose = False):
|
||||
btime = _time_b(verbose)
|
||||
N = self.N
|
||||
my, ny = get_mn_y(self.lMax)
|
||||
nelem = len(my)
|
||||
leftmatrix = np.zeros((N,2,nelem,N,2,nelem), dtype=complex)
|
||||
sbtime = _time_b(verbose, step = 'Calculating interparticle translation coefficients')
|
||||
for i in range(N):
|
||||
for j in range(N):
|
||||
for yi in range(nelem):
|
||||
|
@ -84,6 +117,7 @@ class Scattering(object):
|
|||
leftmatrix[j,1,yj,i,1,yi] = a
|
||||
leftmatrix[j,0,yj,i,1,yi] = b
|
||||
leftmatrix[j,1,yj,i,0,yi] = b
|
||||
_time_e(sbtime, verbose, step = 'Calculating interparticle translation coefficients')
|
||||
# at this point, leftmatrix is the translation matrix
|
||||
n2id = np.identity(2*nelem)
|
||||
n2id.shape = (2,nelem,2,nelem)
|
||||
|
@ -94,8 +128,13 @@ class Scattering(object):
|
|||
# now we are done, 1-MT
|
||||
leftmatrix.shape=(N*2*nelem,N*2*nelem)
|
||||
self.interaction_matrix = leftmatrix
|
||||
_time_e(btime, verbose)
|
||||
|
||||
def scatter(self, pq_0, verbose = False):
|
||||
pass
|
||||
|
||||
def scatter_constmultipole(self, pq_0_c, verbose = False):
|
||||
btime = _time_b(verbose)
|
||||
N = self.N
|
||||
self.prepare(verbose=verbose)
|
||||
nelem = self.nelem
|
||||
|
@ -111,8 +150,9 @@ class Scattering(object):
|
|||
for j in range(N):
|
||||
MP_0[j] = np.tensordot(self.TMatrices[j], pq_0[j],axes=([-2,-1],[-2,-1]))
|
||||
MP_0.shape = (N*2*nelem,)
|
||||
a[N_or_M,yy] = scipy.linalg.lu_solve(lupiv,MP_0)
|
||||
ab[N_or_M,yy] = scipy.linalg.lu_solve(self.lupiv,MP_0)
|
||||
ab.shape = (2,nelem,N,2,nelem)
|
||||
_time_e(btime, verbose)
|
||||
return ab
|
||||
|
||||
class Scattering_lattice(Scattering):
|
||||
|
|
Loading…
Reference in New Issue