Ewald sum optimisation

Avoiding repeated cpow() calls yields more than 5x speedup
in the off-plane case.
This commit is contained in:
Marek Nečada 2020-07-03 11:46:34 +03:00
parent f7883a713b
commit 1328077490
1 changed files with 26 additions and 4 deletions

View File

@ -349,6 +349,17 @@ int ewald3_21_xy_sigma_long (
// space for Gamma_pq[j]'s // space for Gamma_pq[j]'s
complex double Gamma_pq[lMax/2+1]; complex double Gamma_pq[lMax/2+1];
double Gamma_pq_err[lMax/2+1]; double Gamma_pq_err[lMax/2+1];
// cpow() is expensive, so we want to save and reuse these too:
complex double minus_k_shiftz_pow[2*lMax + 1]; // __[i] = cpow(-k * particle_shift.z, i)
{
long complex double x = 1;
for(qpms_l_t i = 0; i <= 2*lMax; ++i) {
minus_k_shiftz_pow[i] = x;
x *= -k * particle_shift.z;
}
}
complex double rbeta_pq_div_k_pow[lMax + 1]; // __[i] = cpow_0lim_zi(rbeta_pq/k, i)
complex double gamma_pq_powm1[2*lMax + 1]; // __[i] = cpow(gamma_pq, i-1)
// CHOOSE POINT BEGIN // CHOOSE POINT BEGIN
// TODO maybe PGen_next_sph is not the best coordinate system choice here // TODO maybe PGen_next_sph is not the best coordinate system choice here
@ -382,8 +393,19 @@ int ewald3_21_xy_sigma_long (
//void ewald3_2_sigma_long_Delta(complex double *target, int maxn, complex double x, complex double z) { //void ewald3_2_sigma_long_Delta(complex double *target, int maxn, complex double x, complex double z) {
if (new_rbeta_pq) { if (new_rbeta_pq) {
gamma_pq = clilgamma(rbeta_pq/k); gamma_pq = clilgamma(rbeta_pq/k);
{ // fill gamma_pq_powm1[] and rbeta_pq_div_k_pow[]
long complex double x = 1./gamma_pq;
for(qpms_l_t i = 0; i <= 2*lMax; ++i) {
gamma_pq_powm1[i] = x;
x *= gamma_pq;
}
for(qpms_l_t i = 0; i <= lMax; ++i) // not fastest, but foolproof
rbeta_pq_div_k_pow[i] = cpow_0lim_zi(rbeta_pq / k, i);
}
complex double x = gamma_pq*k/(2*eta); complex double x = gamma_pq*k/(2*eta);
complex double x2 = x*x; complex double x2 = x*x;
if(particle_shift.z == 0) { if(particle_shift.z == 0) {
for(qpms_l_t j = 0; j <= lMax/2; ++j) { for(qpms_l_t j = 0; j <= lMax/2; ++j) {
qpms_csf_result Gam; qpms_csf_result Gam;
@ -421,9 +443,9 @@ int ewald3_21_xy_sigma_long (
if (particle_shift.z == 0) { // TODO remove when the general case is stable and tested if (particle_shift.z == 0) { // TODO remove when the general case is stable and tested
assert((n-abs(m))/2 == c->s1_jMaxes[y]); assert((n-abs(m))/2 == c->s1_jMaxes[y]);
for(qpms_l_t j = 0; j <= c->s1_jMaxes[y]/*(n-abs(m))/2*/; ++j) { // FIXME </<= ? for(qpms_l_t j = 0; j <= c->s1_jMaxes[y]/*(n-abs(m))/2*/; ++j) { // FIXME </<= ?
complex double summand = cpow_0lim_zi(rbeta_pq/k, n-2*j) complex double summand = rbeta_pq_div_k_pow[n-2*j]
* e_imalpha_pq * c->legendre0[gsl_sf_legendre_array_index(n,abs(m))] * min1pow_m_neg(m) // This line can actually go outside j-loop * e_imalpha_pq * c->legendre0[gsl_sf_legendre_array_index(n,abs(m))] * min1pow_m_neg(m) // This line can actually go outside j-loop
* cpow(gamma_pq, 2*j-1) // * Gamma_pq[j] bellow (GGG) after error computation * gamma_pq_powm1[2*j]// * Gamma_pq[j] bellow (GGG) after error computation
* c->s1_constfacs[y][j]; * c->s1_constfacs[y][j];
if(err) { if(err) {
// FIXME include also other errors than Gamma_pq's relative error // FIXME include also other errors than Gamma_pq's relative error
@ -448,11 +470,11 @@ int ewald3_21_xy_sigma_long (
(s <= 2 * j) && (s <= L_M); (s <= 2 * j) && (s <= L_M);
s += 2) { s += 2) {
complex double ssummand = c->S1_constfacs[y][constidx] complex double ssummand = c->S1_constfacs[y][constidx]
* cpow(-k * particle_shift.z, 2*j - s) * cpow_0lim_zi(rbeta_pq / k, n - s); * minus_k_shiftz_pow[2*j - s] * rbeta_pq_div_k_pow[n - s];
ckahanadd(&ssum, &ssum_c, ssummand); ckahanadd(&ssum, &ssum_c, ssummand);
++constidx; ++constidx;
} }
const complex double jfactor = e_imalpha_pq * Gamma_pq[j] * cpow(gamma_pq, 2*j - 1); const complex double jfactor = e_imalpha_pq * Gamma_pq[j] * gamma_pq_powm1[2*j];
if (err) { // FIXME include also other sources of error than Gamma_pq's relative error if (err) { // FIXME include also other sources of error than Gamma_pq's relative error
double jfactor_err = Gamma_pq_err[j] * pow(cabs(gamma_pq), 2*j - 1); double jfactor_err = Gamma_pq_err[j] * pow(cabs(gamma_pq), 2*j - 1);
kahanadd(&jsum_err, &jsum_err_c, jfactor_err * ssum); kahanadd(&jsum_err, &jsum_err_c, jfactor_err * ssum);