From 47daced370be2ad17ba0fff01792691a58a7ac7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Ne=C4=8Dada?= Date: Sun, 4 Dec 2016 22:27:03 +0200 Subject: [PATCH] Non-mandatory numba import and conditional @jit decorator. Former-commit-id: 327ece069ac0aa7ccc3c2aedefc4255a935699e7 --- qpms/qpms_p.py | 46 ++++++++++++++++++++++++++++++++++++++-------- setup.py | 2 +- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/qpms/qpms_p.py b/qpms/qpms_p.py index e0c2437..4bea15a 100644 --- a/qpms/qpms_p.py +++ b/qpms/qpms_p.py @@ -10,10 +10,39 @@ import math import cmath import quaternion, spherical_functions as sf # because of the Wigner matrices import sys, time -from numba import jit + +''' +Try to import numba. Its pre-0.28.0 versions can not handle functions +containing utf8 identifiers, so we keep track about that. +''' +try: + import numba + use_jit = True + if numba.__version__ >= '0.28.0': + use_jit_utf8 = True + else: + use_jit_utf8 = False +except ImportError: + use_jit = False + use_jit_utf8 = False +''' +Accordingly, we define our own jit decorator that handles +different versions of numba or does nothing if numba is not +present. Note that functions that include unicode identifiers +must be decorated with @jit(u=True) +''' +def jit(u=False): + def resdec(f): + if u and use_jit_utf8: + return numba.jit(f) + if (not u) and use_jit: + return numba.jit(f) + return f + return resdec + # Coordinate transforms for arrays of "arbitrary" shape -#@jit +@jit(u=True) def cart2sph(cart,axis=-1): if (cart.shape[axis] != 3): raise ValueError("The converted array has to have dimension 3" @@ -25,7 +54,7 @@ def cart2sph(cart,axis=-1): φ = np.arctan2(y,x) # arctan2 handles zeroes correctly itself return np.concatenate((r,θ,φ),axis=axis) -#@jit +@jit(u=True) def sph2cart(sph, axis=-1): if (sph.shape[axis] != 3): raise ValueError("The converted array has to have dimension 3" @@ -37,7 +66,7 @@ def sph2cart(sph, axis=-1): z = r * np.cos(θ) return np.concatenate((x,y,z),axis=axis) -#@jit +@jit(u=True) def sph_loccart2cart(loccart, sph, axis=-1): """ Transformation of vector specified in local orthogonal coordinates @@ -87,7 +116,7 @@ def sph_loccart2cart(loccart, sph, axis=-1): out=inr̂*r̂+inθ̂*θ̂+inφ̂*φ̂ return out -#@jit +@jit(u=True) def sph_loccart_basis(sph, sphaxis=-1, cartaxis=None): """ Returns the local cartesian basis in terms of global cartesian basis. @@ -304,7 +333,7 @@ def get_π̃τ̃_y1(θ,nmax): τ̃_y = prenorm * dPy * (- math.sin(θ)) # TADY BACHA!!!!!!!!!! * (- math.sin(pos_sph[1])) ??? return (π̃_y,τ̃_y) -#@jit +@jit(u=True) def vswf_yr1(pos_sph,nmax,J=1): """ As vswf_yr, but evaluated only at single position (i.e. pos_sph has @@ -510,7 +539,7 @@ def Ã(m,n,μ,ν,kdlj,θlj,φlj,r_ge_d,J): return presum * np.sum(summandq) # ZDE OPĚT JINAK ZNAMÉNKA než v Xu (J. comp. phys 127, 285) -#@jit +@jit(u=True) def B̃(m,n,μ,ν,kdlj,θlj,φlj,r_ge_d,J): """ The B̃ translation coefficient for spherical vector waves. @@ -661,7 +690,7 @@ def mie_coefficients(a, nmax, #ω, ε_i, ε_e=1, J_ext=1, J_scat=3 TH = -(( η_inv_e * že * zs - η_inv_e * ze * žs)/(-η_inv_i * ži * zs + η_inv_e * zi * žs)) return (RH, RV, TH, TV) -#@jit +@jit(u=True) def G_Mie_scat_precalc_cart_new(source_cart, dest_cart, RH, RV, a, nmax, k_i, k_e, μ_i=1, μ_e=1, J_ext=1, J_scat=3): """ Implementation according to Kristensson, page 50 @@ -698,6 +727,7 @@ def G_Mie_scat_precalc_cart_new(source_cart, dest_cart, RH, RV, a, nmax, k_i, k_ RV[ny][:,ň,ň] * Ñlo_cart_y[:,:,ň].conj() * Ñhi_cart_y[:,ň,:]) / (ny * (ny+1))[:,ň,ň] return 1j* k_e*np.sum(G_y,axis=0) +@jit(u=True) def G_Mie_scat_precalc_cart(source_cart, dest_cart, RH, RV, a, nmax, k_i, k_e, μ_i=1, μ_e=1, J_ext=1, J_scat=3): """ r1_cart (destination), r2_cart (source) and the result are in cartesian coordinates diff --git a/setup.py b/setup.py index 7507c5b..d87201b 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ qpms_c = Extension('qpms_c', sources = ['qpms/qpms_c.pyx']) setup(name='qpms', - version = "0.1", + version = "0.1.2", packages=['qpms'], # setup_requires=['setuptools_cython'], install_requires=['cython>=0.21','quaternion','spherical_functions','py_gmm'],