diff --git a/qpms/polynomials.c b/qpms/polynomials.c index 7e16bda..b3c7eb4 100644 --- a/qpms/polynomials.c +++ b/qpms/polynomials.c @@ -59,6 +59,62 @@ void mpzs_hh_set(mpzs_hh_t x, const mpzs_hh_t y) { mpz_set(x->_2, y->_2); } +/* Append an mpzs element to a mpzs sum that must not contain a matching key + * in advance. + */ +static inline void mpzs_hash_append(struct _qp_mpzs_hashed **hash, + const struct _qp_mpzs_hashed *newelem) { + struct _qp_mpzs_hashed *n; + QPMS_CRASHING_MALLOC(n, sizeof(mpzs_hh_t)); + mpzs_hh_init(n); + mpzs_hh_set(n, newelem); + HASH_ADD_KEYPTR(hh, *hash, mpz_limbs_read(n->_2), + mpz_size(n->_2) * sizeof(mp_limb_t), n); +} + +// Arithmetically add an mpzs element to a mpzs sum/hash. +static inline void mpzs_hash_addelem(struct _qp_mpzs_hashed **hash, + const struct _qp_mpzs_hashed *addend) { + struct _qp_mpzs_hashed *s; + HASH_FIND(hh, *hash, mpz_limbs_read(addend->_2), + mpz_size(addend->_2), s); + if (!s) mpzs_hash_append(hash, addend); // if not found + else { + mpz_add(s->_1, s->_1, addend->_1); + if(!mpz_sgn(s->_1)) { // If zero, annihilate + HASH_DEL(*hash, s); + mpzs_hh_clear(s); + free(s); + } + } +} + +// Multiplies two mpzs hashes, adding them to *prod; +// *prod must differ from x and y (this is not checked) +static inline void mpzs_hash_mul(struct _qp_mpzs_hashed **prod, + const struct _qp_mpzs_hashed *x, const struct _qp_mpzs_hashed *y) { + mpz_t gcd; mpz_init(gcd); // common denominator of the sqrt to factor out + mpz_t mx, my; mpz_init(mx); mpz_init(my); + mpzs_hh_t addend; mpzs_hh_init(addend); + + for (const struct _qp_mpzs_hashed *nx = x; + nx != NULL; nx = nx->hh.next) { + for (const struct _qp_mpzs_hashed *ny = y; + ny != NULL; ny = ny->hh.next) { + mpz_gcd(gcd, nx->_2, ny->_2); + mpz_divexact(mx, nx->_2, gcd); + mpz_divexact(my, ny->_2, gcd); + mpz_mul(addend->_2, mx, my); + mpz_mul(addend->_1, ny->_1, nx->_1); + mpz_mul(addend->_1, addend->_1, gcd); + mpzs_hash_addelem(prod, addend); + } + } + mpzs_hh_clear(addend); + mpz_clear(mx); mpz_clear(my); mpz_clear(gcd); +} + + //===== mpqs_t ===== void mpqs_init(mpqs_t x) { @@ -84,13 +140,9 @@ void mpqs_clear(mpqs_t x) { x->nt = 0; } +// Append an element to mpqs_t's numerator hash without any checks. void mpqs_nt_append(mpqs_t x, const struct _qp_mpzs_hashed *numelem) { - struct _qp_mpzs_hashed *n; - QPMS_CRASHING_MALLOC(n, sizeof(mpzs_hh_t)); - mpzs_hh_init(n); - mpzs_hh_set(n, numelem); - HASH_ADD_KEYPTR(hh, x->nt, mpz_limbs_read(n->_2), - mpz_size(n->_2) * sizeof(mp_limb_t), n); + mpzs_hash_append(&(x->nt), numelem); } void mpqs_set_z(mpqs_t x, const mpz_t num1, const mpz_t num2, @@ -121,19 +173,9 @@ void mpqs_set_si(mpqs_t x, long num1, unsigned long num2, mpzs_hh_clear(tmp); } -void mpqs_nt_add(mpqs_t x, const struct _qp_mpzs_hashed *addend) { - struct _qp_mpzs_hashed *s; - HASH_FIND(hh, x->nt, mpz_limbs_read(addend->_2), - mpz_size(addend->_2), s); - if (!s) mpqs_nt_append(x, addend); // if not found - else { - mpz_add(s->_1, s->_1, addend->_1); - if(!mpz_sgn(s->_1)) { // If zero, annihilate - HASH_DEL(x->nt, s); - mpzs_hh_clear(s); - free(s); - } - } +// Arithmetically adds an element to mpqs's numerator hash. +void mpqs_nt_addelem(mpqs_t x, const struct _qp_mpzs_hashed *addend) { + mpzs_hash_addelem(&(x->nt), addend); } void mpqs_init_set(mpqs_t dest, const mpqs_t src) { @@ -175,6 +217,9 @@ void mpqs_canonicalise(mpqs_t x) { mpz_t gcd; mpz_init(gcd); mpqs_nt_gcd(gcd, x); mpz_mul(mpq_numref(x->f), mpq_numref(x->f), gcd); + for(struct _qp_mpzs_hashed *n = (x->nt); n != NULL; n = n->hh.next) { + mpz_divexact(n->_1, n->_1, gcd); + } mpz_clear(gcd); mpq_canonicalize(x->f); } @@ -213,7 +258,7 @@ void mpqs_add(mpqs_t sum_final, const mpqs_t x, const mpqs_t y) { for(const struct _qp_mpzs_hashed *n = y->nt; n != NULL; n = n->hh.next) { mpz_mul(addend->_1, tmp, n->_1); mpz_set(addend->_2, n->_2); - mpqs_nt_add(sum, addend); + mpqs_nt_addelem(sum, addend); } mpzs_hh_clear(addend); mpz_clear(tmp); @@ -232,31 +277,15 @@ void mpqs_sub(mpqs_t dif, const mpqs_t x, const mpqs_t y) { mpqs_clear(tmp); } + void mpqs_mul(mpqs_t product_final, const mpqs_t x, const mpqs_t y) { mpqs_t prod; mpqs_init(prod); mpz_mul(mpq_numref(prod->f), mpq_numref(x->f), mpq_numref(y->f)); mpz_mul(mpq_denref(prod->f), mpq_denref(x->f), mpq_denref(y->f)); - mpz_t gcd; mpz_init(gcd); // common denominator of the sqrt to factor out - mpz_t mx, my; mpz_init(mx); mpz_init(my); - mpzs_hh_t addend; mpzs_hh_init(addend); + mpzs_hash_mul(&(prod->nt), x->nt, y->nt); - for (const struct _qp_mpzs_hashed *nx = x->nt->hh.next; - nx != NULL; nx = nx->hh.next) { - for (const struct _qp_mpzs_hashed *ny = x->nt->hh.next; - ny != NULL; ny = ny->hh.next) { - mpz_gcd(gcd, nx->_2, ny->_2); - mpz_divexact(mx, nx->_2, gcd); - mpz_divexact(my, ny->_2, gcd); - mpz_mul(addend->_2, mx, my); - mpz_mul(addend->_1, ny->_1, nx->_1); - mpz_mul(addend->_1, addend->_1, gcd); - mpqs_nt_add(prod, addend); - } - } - mpzs_hh_clear(addend); - mpz_clear(mx); mpz_clear(my); mpz_clear(gcd); mpqs_canonicalise(prod); mpqs_set(product_final, prod); mpqs_clear(prod);