Non-mandatory numba import and conditional @jit decorator.

Former-commit-id: 327ece069ac0aa7ccc3c2aedefc4255a935699e7
This commit is contained in:
Marek Nečada 2016-12-04 22:27:03 +02:00
parent febd164740
commit 47daced370
2 changed files with 39 additions and 9 deletions

View File

@ -10,10 +10,39 @@ import math
import cmath import cmath
import quaternion, spherical_functions as sf # because of the Wigner matrices import quaternion, spherical_functions as sf # because of the Wigner matrices
import sys, time import sys, time
from numba import jit
'''
Try to import numba. Its pre-0.28.0 versions can not handle functions
containing utf8 identifiers, so we keep track about that.
'''
try:
import numba
use_jit = True
if numba.__version__ >= '0.28.0':
use_jit_utf8 = True
else:
use_jit_utf8 = False
except ImportError:
use_jit = False
use_jit_utf8 = False
'''
Accordingly, we define our own jit decorator that handles
different versions of numba or does nothing if numba is not
present. Note that functions that include unicode identifiers
must be decorated with @jit(u=True)
'''
def jit(u=False):
def resdec(f):
if u and use_jit_utf8:
return numba.jit(f)
if (not u) and use_jit:
return numba.jit(f)
return f
return resdec
# Coordinate transforms for arrays of "arbitrary" shape # Coordinate transforms for arrays of "arbitrary" shape
#@jit @jit(u=True)
def cart2sph(cart,axis=-1): def cart2sph(cart,axis=-1):
if (cart.shape[axis] != 3): if (cart.shape[axis] != 3):
raise ValueError("The converted array has to have dimension 3" raise ValueError("The converted array has to have dimension 3"
@ -25,7 +54,7 @@ def cart2sph(cart,axis=-1):
φ = np.arctan2(y,x) # arctan2 handles zeroes correctly itself φ = np.arctan2(y,x) # arctan2 handles zeroes correctly itself
return np.concatenate((r,θ,φ),axis=axis) return np.concatenate((r,θ,φ),axis=axis)
#@jit @jit(u=True)
def sph2cart(sph, axis=-1): def sph2cart(sph, axis=-1):
if (sph.shape[axis] != 3): if (sph.shape[axis] != 3):
raise ValueError("The converted array has to have dimension 3" raise ValueError("The converted array has to have dimension 3"
@ -37,7 +66,7 @@ def sph2cart(sph, axis=-1):
z = r * np.cos(θ) z = r * np.cos(θ)
return np.concatenate((x,y,z),axis=axis) return np.concatenate((x,y,z),axis=axis)
#@jit @jit(u=True)
def sph_loccart2cart(loccart, sph, axis=-1): def sph_loccart2cart(loccart, sph, axis=-1):
""" """
Transformation of vector specified in local orthogonal coordinates Transformation of vector specified in local orthogonal coordinates
@ -87,7 +116,7 @@ def sph_loccart2cart(loccart, sph, axis=-1):
out=inr̂*r̂+inθ̂*θ̂+inφ̂*φ̂ out=inr̂*r̂+inθ̂*θ̂+inφ̂*φ̂
return out return out
#@jit @jit(u=True)
def sph_loccart_basis(sph, sphaxis=-1, cartaxis=None): def sph_loccart_basis(sph, sphaxis=-1, cartaxis=None):
""" """
Returns the local cartesian basis in terms of global cartesian basis. Returns the local cartesian basis in terms of global cartesian basis.
@ -304,7 +333,7 @@ def get_π̃τ̃_y1(θ,nmax):
τ̃_y = prenorm * dPy * (- math.sin(θ)) # TADY BACHA!!!!!!!!!! * (- math.sin(pos_sph[1])) ??? τ̃_y = prenorm * dPy * (- math.sin(θ)) # TADY BACHA!!!!!!!!!! * (- math.sin(pos_sph[1])) ???
return (π̃_y,τ̃_y) return (π̃_y,τ̃_y)
#@jit @jit(u=True)
def vswf_yr1(pos_sph,nmax,J=1): def vswf_yr1(pos_sph,nmax,J=1):
""" """
As vswf_yr, but evaluated only at single position (i.e. pos_sph has As vswf_yr, but evaluated only at single position (i.e. pos_sph has
@ -510,7 +539,7 @@ def Ã(m,n,μ,ν,kdlj,θlj,φlj,r_ge_d,J):
return presum * np.sum(summandq) return presum * np.sum(summandq)
# ZDE OPĚT JINAK ZNAMÉNKA než v Xu (J. comp. phys 127, 285) # ZDE OPĚT JINAK ZNAMÉNKA než v Xu (J. comp. phys 127, 285)
#@jit @jit(u=True)
def B̃(m,n,μ,ν,kdlj,θlj,φlj,r_ge_d,J): def B̃(m,n,μ,ν,kdlj,θlj,φlj,r_ge_d,J):
""" """
The B̃ translation coefficient for spherical vector waves. The B̃ translation coefficient for spherical vector waves.
@ -661,7 +690,7 @@ def mie_coefficients(a, nmax, #ω, ε_i, ε_e=1, J_ext=1, J_scat=3
TH = -(( η_inv_e * že * zs - η_inv_e * ze * žs)/(-η_inv_i * ži * zs + η_inv_e * zi * žs)) TH = -(( η_inv_e * že * zs - η_inv_e * ze * žs)/(-η_inv_i * ži * zs + η_inv_e * zi * žs))
return (RH, RV, TH, TV) return (RH, RV, TH, TV)
#@jit @jit(u=True)
def G_Mie_scat_precalc_cart_new(source_cart, dest_cart, RH, RV, a, nmax, k_i, k_e, μ_i=1, μ_e=1, J_ext=1, J_scat=3): def G_Mie_scat_precalc_cart_new(source_cart, dest_cart, RH, RV, a, nmax, k_i, k_e, μ_i=1, μ_e=1, J_ext=1, J_scat=3):
""" """
Implementation according to Kristensson, page 50 Implementation according to Kristensson, page 50
@ -698,6 +727,7 @@ def G_Mie_scat_precalc_cart_new(source_cart, dest_cart, RH, RV, a, nmax, k_i, k_
RV[ny][:,ň,ň] * Ñlo_cart_y[:,:,ň].conj() * Ñhi_cart_y[:,ň,:]) / (ny * (ny+1))[:,ň,ň] RV[ny][:,ň,ň] * Ñlo_cart_y[:,:,ň].conj() * Ñhi_cart_y[:,ň,:]) / (ny * (ny+1))[:,ň,ň]
return 1j* k_e*np.sum(G_y,axis=0) return 1j* k_e*np.sum(G_y,axis=0)
@jit(u=True)
def G_Mie_scat_precalc_cart(source_cart, dest_cart, RH, RV, a, nmax, k_i, k_e, μ_i=1, μ_e=1, J_ext=1, J_scat=3): def G_Mie_scat_precalc_cart(source_cart, dest_cart, RH, RV, a, nmax, k_i, k_e, μ_i=1, μ_e=1, J_ext=1, J_scat=3):
""" """
r1_cart (destination), r2_cart (source) and the result are in cartesian coordinates r1_cart (destination), r2_cart (source) and the result are in cartesian coordinates

View File

@ -12,7 +12,7 @@ qpms_c = Extension('qpms_c',
sources = ['qpms/qpms_c.pyx']) sources = ['qpms/qpms_c.pyx'])
setup(name='qpms', setup(name='qpms',
version = "0.1", version = "0.1.2",
packages=['qpms'], packages=['qpms'],
# setup_requires=['setuptools_cython'], # setup_requires=['setuptools_cython'],
install_requires=['cython>=0.21','quaternion','spherical_functions','py_gmm'], install_requires=['cython>=0.21','quaternion','spherical_functions','py_gmm'],