lattices: Basic stuff fixed, measuring time

Former-commit-id: 7019be9053b601e266eeddafdd859f6339313df7
This commit is contained in:
Marek Nečada 2016-12-21 14:28:55 +02:00
parent c18e1209f4
commit b642991bb1
2 changed files with 42 additions and 2 deletions

View File

@ -8,6 +8,35 @@ import sys
from qpms_c import get_mn_y # TODO be explicit about what is imported
from .qpms_p import cart2sph, nelem2lMax, Ã, B̃ # TODO be explicit about what is imported
def _time_b(active = True, name = None, step = None):
'''
Auxiliary function for keeping track of elapsed time.
Returns current time (to be used by _time_e).
'''
now = time.time()
if active:
if not name:
name = sys._getframe(1).f_code.co_name
if step:
print('%.4f: %s in function %s started.' % (now, step, name), file = sys.stderr)
else:
print('%.4f: Function %s started.' % (now, name), file=sys.stderr)
sys.stderr.flush()
return now
def _time_e(start_time, active = True, name = None, step = None):
now = time.time()
if active:
if not name:
name = sys._getframe(1).f_code.co_name
if step:
print('%.4f: %s in function %s finished (elapsed %.2f s).'
% (now, step, name, now - start_time), file = sys.stderr)
else:
print('%.4f: Function %s finished (elapsed %.2f s).'
% (now, name, now - start_time), file = sys.stderr)
sys.stderr.flush()
class Scattering(object):
'''
@ -59,6 +88,7 @@ class Scattering(object):
self.TMatrices = np.broadcast_to(TMatrices, (self.N,2,nelem,2,nelem))
def prepare(self, keep_interaction_matrix = False, verbose=False):
btime = _time_b(verbose)
if not self.prepared:
if not self.interaction_matrix:
self.build_interaction_matrix(verbose=verbose)
@ -66,12 +96,15 @@ class Scattering(object):
if not keep_interaction_matrix:
self.interaction_matrix = None
self.prepared = True
_time_e(btime, verbose)
def build_interaction_matrix(self,verbose = False):
btime = _time_b(verbose)
N = self.N
my, ny = get_mn_y(self.lMax)
nelem = len(my)
leftmatrix = np.zeros((N,2,nelem,N,2,nelem), dtype=complex)
sbtime = _time_b(verbose, step = 'Calculating interparticle translation coefficients')
for i in range(N):
for j in range(N):
for yi in range(nelem):
@ -84,6 +117,7 @@ class Scattering(object):
leftmatrix[j,1,yj,i,1,yi] = a
leftmatrix[j,0,yj,i,1,yi] = b
leftmatrix[j,1,yj,i,0,yi] = b
_time_e(sbtime, verbose, step = 'Calculating interparticle translation coefficients')
# at this point, leftmatrix is the translation matrix
n2id = np.identity(2*nelem)
n2id.shape = (2,nelem,2,nelem)
@ -94,8 +128,13 @@ class Scattering(object):
# now we are done, 1-MT
leftmatrix.shape=(N*2*nelem,N*2*nelem)
self.interaction_matrix = leftmatrix
_time_e(btime, verbose)
def scatter(self, pq_0, verbose = False):
pass
def scatter_constmultipole(self, pq_0_c, verbose = False):
btime = _time_b(verbose)
N = self.N
self.prepare(verbose=verbose)
nelem = self.nelem
@ -111,8 +150,9 @@ class Scattering(object):
for j in range(N):
MP_0[j] = np.tensordot(self.TMatrices[j], pq_0[j],axes=([-2,-1],[-2,-1]))
MP_0.shape = (N*2*nelem,)
a[N_or_M,yy] = scipy.linalg.lu_solve(lupiv,MP_0)
ab[N_or_M,yy] = scipy.linalg.lu_solve(self.lupiv,MP_0)
ab.shape = (2,nelem,N,2,nelem)
_time_e(btime, verbose)
return ab
class Scattering_lattice(Scattering):

View File

@ -12,7 +12,7 @@ qpms_c = Extension('qpms_c',
sources = ['qpms/qpms_c.pyx'])
setup(name='qpms',
version = "0.1.6",
version = "0.1.7",
packages=['qpms'],
# setup_requires=['setuptools_cython'],
install_requires=['cython>=0.21','quaternion','spherical_functions','py_gmm'],