Non-mandatory numba import and conditional @jit decorator.

Former-commit-id: 327ece069ac0aa7ccc3c2aedefc4255a935699e7
This commit is contained in:
Marek Nečada 2016-12-04 22:27:03 +02:00
parent febd164740
commit 47daced370
2 changed files with 39 additions and 9 deletions

View File

@ -10,10 +10,39 @@ import math
import cmath
import quaternion, spherical_functions as sf # because of the Wigner matrices
import sys, time
from numba import jit
'''
Try to import numba. Its pre-0.28.0 versions can not handle functions
containing utf8 identifiers, so we keep track about that.
'''
try:
import numba
use_jit = True
if numba.__version__ >= '0.28.0':
use_jit_utf8 = True
else:
use_jit_utf8 = False
except ImportError:
use_jit = False
use_jit_utf8 = False
'''
Accordingly, we define our own jit decorator that handles
different versions of numba or does nothing if numba is not
present. Note that functions that include unicode identifiers
must be decorated with @jit(u=True)
'''
def jit(u=False):
def resdec(f):
if u and use_jit_utf8:
return numba.jit(f)
if (not u) and use_jit:
return numba.jit(f)
return f
return resdec
# Coordinate transforms for arrays of "arbitrary" shape
#@jit
@jit(u=True)
def cart2sph(cart,axis=-1):
if (cart.shape[axis] != 3):
raise ValueError("The converted array has to have dimension 3"
@ -25,7 +54,7 @@ def cart2sph(cart,axis=-1):
φ = np.arctan2(y,x) # arctan2 handles zeroes correctly itself
return np.concatenate((r,θ,φ),axis=axis)
#@jit
@jit(u=True)
def sph2cart(sph, axis=-1):
if (sph.shape[axis] != 3):
raise ValueError("The converted array has to have dimension 3"
@ -37,7 +66,7 @@ def sph2cart(sph, axis=-1):
z = r * np.cos(θ)
return np.concatenate((x,y,z),axis=axis)
#@jit
@jit(u=True)
def sph_loccart2cart(loccart, sph, axis=-1):
"""
Transformation of vector specified in local orthogonal coordinates
@ -87,7 +116,7 @@ def sph_loccart2cart(loccart, sph, axis=-1):
out=inr̂*r̂+inθ̂*θ̂+inφ̂*φ̂
return out
#@jit
@jit(u=True)
def sph_loccart_basis(sph, sphaxis=-1, cartaxis=None):
"""
Returns the local cartesian basis in terms of global cartesian basis.
@ -304,7 +333,7 @@ def get_π̃τ̃_y1(θ,nmax):
τ̃_y = prenorm * dPy * (- math.sin(θ)) # TADY BACHA!!!!!!!!!! * (- math.sin(pos_sph[1])) ???
return (π̃_y,τ̃_y)
#@jit
@jit(u=True)
def vswf_yr1(pos_sph,nmax,J=1):
"""
As vswf_yr, but evaluated only at single position (i.e. pos_sph has
@ -510,7 +539,7 @@ def Ã(m,n,μ,ν,kdlj,θlj,φlj,r_ge_d,J):
return presum * np.sum(summandq)
# ZDE OPĚT JINAK ZNAMÉNKA než v Xu (J. comp. phys 127, 285)
#@jit
@jit(u=True)
def B̃(m,n,μ,ν,kdlj,θlj,φlj,r_ge_d,J):
"""
The B̃ translation coefficient for spherical vector waves.
@ -661,7 +690,7 @@ def mie_coefficients(a, nmax, #ω, ε_i, ε_e=1, J_ext=1, J_scat=3
TH = -(( η_inv_e * že * zs - η_inv_e * ze * žs)/(-η_inv_i * ži * zs + η_inv_e * zi * žs))
return (RH, RV, TH, TV)
#@jit
@jit(u=True)
def G_Mie_scat_precalc_cart_new(source_cart, dest_cart, RH, RV, a, nmax, k_i, k_e, μ_i=1, μ_e=1, J_ext=1, J_scat=3):
"""
Implementation according to Kristensson, page 50
@ -698,6 +727,7 @@ def G_Mie_scat_precalc_cart_new(source_cart, dest_cart, RH, RV, a, nmax, k_i, k_
RV[ny][:,ň,ň] * Ñlo_cart_y[:,:,ň].conj() * Ñhi_cart_y[:,ň,:]) / (ny * (ny+1))[:,ň,ň]
return 1j* k_e*np.sum(G_y,axis=0)
@jit(u=True)
def G_Mie_scat_precalc_cart(source_cart, dest_cart, RH, RV, a, nmax, k_i, k_e, μ_i=1, μ_e=1, J_ext=1, J_scat=3):
"""
r1_cart (destination), r2_cart (source) and the result are in cartesian coordinates

View File

@ -12,7 +12,7 @@ qpms_c = Extension('qpms_c',
sources = ['qpms/qpms_c.pyx'])
setup(name='qpms',
version = "0.1",
version = "0.1.2",
packages=['qpms'],
# setup_requires=['setuptools_cython'],
install_requires=['cython>=0.21','quaternion','spherical_functions','py_gmm'],