From ba06abe13f7a071ebc43250f8b209359bf1f741b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marek=20Ne=C4=8Dada?= Date: Tue, 7 Apr 2020 19:32:58 +0300 Subject: [PATCH] BaseSpec __eq__() and __hash__() Former-commit-id: da4245315207dc75a70ca6a036fb8be17e243bf9 --- qpms/cybspec.pyx | 6 ++++++ qpms/qpms_cdefs.pxd | 1 + qpms/vswf.c | 1 + 3 files changed, 8 insertions(+) diff --git a/qpms/cybspec.pyx b/qpms/cybspec.pyx index d3ff8bf..e2813b8 100644 --- a/qpms/cybspec.pyx +++ b/qpms/cybspec.pyx @@ -83,6 +83,12 @@ cdef class BaseSpec: raise ValueError # If this happens, it's probably a bug, as it should have failed already at qpms_uvswfi2tmn self.s.lMax = max(self.s.lMax, l) + def __eq__(self, BaseSpec other): + return bool(qpms_vswf_set_spec_isidentical(&self.s, &other.s)) + + def __hash__(self): # Very inefficient implementation, but this is not to be used very often + return hash((self.s.norm, self.s.n, tuple(self.__ilist[:self.s.n]))) + def tlm(self): cdef const qpms_uvswfi_t[:] ilist_memview = self.s.ilist #cdef qpms_vswf_type_t[:] t = np.empty(shape=(self.s.n,), dtype=qpms_vswf_type_t) # does not work, workaround: diff --git a/qpms/qpms_cdefs.pxd b/qpms/qpms_cdefs.pxd index 1977421..9f4fa8c 100644 --- a/qpms/qpms_cdefs.pxd +++ b/qpms/qpms_cdefs.pxd @@ -172,6 +172,7 @@ ctypedef union qpms_incfield_planewave_params_E: csphvec_t sph cdef extern from "vswf.h": + bint qpms_vswf_set_spec_isidentical(const qpms_vswf_set_spec_t *a, const qpms_vswf_set_spec_t *b) ctypedef qpms_errno_t (*qpms_incfield_t)(cdouble *target, const qpms_vswf_set_spec_t *bspec, const cart3_t evalpoint, const void *args, bint add) ctypedef struct qpms_incfield_planewave_params_t: diff --git a/qpms/vswf.c b/qpms/vswf.c index 25339d3..37e728a 100644 --- a/qpms/vswf.c +++ b/qpms/vswf.c @@ -55,6 +55,7 @@ qpms_errno_t qpms_vswf_set_spec_append(qpms_vswf_set_spec_t *s, const qpms_uvswf bool qpms_vswf_set_spec_isidentical(const qpms_vswf_set_spec_t *a, const qpms_vswf_set_spec_t *b) { if (a == b) return true; + if (a->norm != b->norm) return false; if (a->n != b->n) return false; for (size_t i = 0; i < a->n; ++i) if (a->ilist[i] != b->ilist[i])