argproc.py WIP constant t-matrix, mutual exclusivity checks

This commit is contained in:
Marek Nečada 2022-06-01 16:07:16 +03:00
parent 68d894ec96
commit ccfb13a15c
1 changed files with 99 additions and 29 deletions

View File

@ -6,6 +6,7 @@ import argparse
import sys import sys
import warnings import warnings
import ast import ast
from collections import namedtuple
def flatten(S): def flatten(S):
if S == []: if S == []:
@ -175,6 +176,33 @@ def sint(string):
else: raise exc else: raise exc
return res return res
def _resolve_exclusive_groups(namespace, poslabel, *groups):
"""
namespace: arguments namespace whose elements
are dicts that possibly
each element of groups shall be a sequence of strings,
each of the strings should correspond
if mutual exclusivity is respected, returns index
of the "selected" group.
"""
selected = None
example = None
for i, group in enumerate(groups):
for name in group:
ledict = getattr(namespace, name, None)
if ledict is not None: # TODO Maybe it shuld always be not none?
res = ledict.get(poslabel, None)
if res is not None:
if selected is None:
selected = i
example = name
elif selected != i:
raise ArgumentProcessingError("Mutually exclusive parameters (%s and %s) specified for label %s." %
(name, example, poslabel))
if selected is None and poslabel is not None: # try default params
return _resolve_exclusive_groups(namespace, None, *groups)
return selected
def material_spec(string): def material_spec(string):
"""Tries to parse a string as a material specification, i.e. a """Tries to parse a string as a material specification, i.e. a
real or complex number or one of the string in built-in Lorentz-Drude models. real or complex number or one of the string in built-in Lorentz-Drude models.
@ -193,14 +221,21 @@ def material_spec(string):
raise argparse.ArgumentTypeError("Material specification must be a supported material name %s, or a number" % (str(lorentz_drude.keys()),)) from ve raise argparse.ArgumentTypeError("Material specification must be a supported material name %s, or a number" % (str(lorentz_drude.keys()),)) from ve
return lemat return lemat
def string2hashable_complex_matrix(string): class hashable_complex_matrix(tuple):
""" """
Converts string to a hashable equivalent of two-dimensional Converts string to a hashable equivalent of two-dimensional
numpy.ndarray(..., dtype=complex) (which by itself is not hashable) numpy.ndarray(..., dtype=complex) (which by itself is not hashable)
""" """
matrix = np.ndarray(ast.literal_eval(string), dtype=complex) def __new__(self, matrix):
# TODO here could be some dimensionality checks etc. if isinstance(matrix, str):
return tuple(tuple(row) for row in matrix) matrix = ast.literal_eval(string)
matrix = np.ndarray(matrix, dtype=complex, copy=False)
return super().__new__(hashable_complex_matrix, (tuple(row) for row in matrix))
# auxiliary structures for t-matrix generator specifications
constant_tmatrix_spec = namedtuple("constant_tmatrix_spec", ("bspec", "matrix"))
cyl_sph_dimensions = namedtuple("cyl_sph_dimensions", ("radius", "height", "lMax_extend"))
tmgen_spec = namedtuple("tmgen_spec", ("bgspec", "fgspec", "dims"))
def string2bspec(string): def string2bspec(string):
""" """
@ -273,10 +308,10 @@ class ArgParser:
action=make_dict_action(argtype=string2basespec, postaction='store', first_is_key=True), action=make_dict_action(argtype=string2basespec, postaction='store', first_is_key=True),
help='Manual specification of VSWF set codes (format as a python list of integers); see docs on qpms_uvswfi_t for valid codes or simply use ++lMax instead. Overrides ++lMax and --lMax.') help='Manual specification of VSWF set codes (format as a python list of integers); see docs on qpms_uvswfi_t for valid codes or simply use ++lMax instead. Overrides ++lMax and --lMax.')
mpgrp.add_argmuent("-T", "--constant-tmatrix", nargs=1, default={}, mpgrp.add_argmuent("-T", "--constant-tmatrix", nargs=1, default={},
action=make_dict_action(argtype=string2hashable_complex_matrix, postaction='store', first_is_key=False), action=make_dict_action(argtype=hashable_complex_matrix, postaction='store', first_is_key=False),
help='constant T-matrix (elements must correspond to --vswf-set)') help='constant T-matrix (elements must correspond to --vswf-set)')
mpgrp.add_argmuent("+T", "++constant-tmatrix", nargs=2, default={}, mpgrp.add_argmuent("+T", "++constant-tmatrix", nargs=2, default={},
action=make_dict_action(argtype=string2hashable_complex_matrix, postaction='store', first_is_key=True), action=make_dict_action(argtype=hashable_complex_matrix, postaction='store', first_is_key=True),
help='constant T-matrix (elements must correspond to ++vswf-set)') help='constant T-matrix (elements must correspond to ++vswf-set)')
atomic_arguments = { atomic_arguments = {
@ -355,7 +390,10 @@ class ArgParser:
if tmgspec in self._tmg_register.keys(): if tmgspec in self._tmg_register.keys():
return self._tmg_register[tmgspec] return self._tmg_register[tmgspec]
else: else:
from .cytmatrices import TMatrixGenerator from .cytmatrices import TMatrixGenerator, CTMatrix
if isinstance(tmgspec, constant_tmatrix_spec):
tmgen = TMatrixGenerator(CTMatrix(tmgspec.bspec, tmgspec.matrix))
else:
bgspec, fgspec, (radius, height, lMax_extend) = tmgspec bgspec, fgspec, (radius, height, lMax_extend) = tmgspec
bg = self._add_emg(bgspec) bg = self._add_emg(bgspec)
fg = self._add_emg(fgspec) fg = self._add_emg(fgspec)
@ -446,7 +484,8 @@ class ArgParser:
from .cymaterials import EpsMuGenerator, lorentz_drude from .cymaterials import EpsMuGenerator, lorentz_drude
from .cytmatrices import TMatrixGenerator from .cytmatrices import TMatrixGenerator
self.foreground_emg = self._add_emg(a.material) self.foreground_emg = self._add_emg(a.material)
self.tmgen = self._add_tmg((a.background, a.material, (a.radius, a.height, a.lMax_extend))) self.tmgen = self._add_tmg(tmgen_spec(a.background, a.material,
cyl_sph_dimensions(a.radius, a.height, a.lMax_extend)))
self.bspec = self._add_bspec(a.lMax) self.bspec = self._add_bspec(a.lMax)
def _eval_single_omega(self): # feature: single_omega def _eval_single_omega(self): # feature: single_omega
@ -525,15 +564,38 @@ class ArgParser:
self.positions = {} self.positions = {}
pos13, pos23, pos33 = False, False, False # used to pos13, pos23, pos33 = False, False, False # used to
if len(a.position.keys()) == 0: if len(a.position.keys()) == 0:
warnings.warn("No particle position (-p or +p) specified, assuming single particle in the origin / single particle per unit cell!") warnings.warn("No particle position (-p or +p) specified, assuming single particle in the "
"origin / single particle per unit cell!")
a.position[None] = [(0.,0.,0.)] a.position[None] = [(0.,0.,0.)]
for poslabel in a.position.keys(): for poslabel in a.position.keys():
# TODO HERE GOES THE CODE TRYING TO LOAD CONSTANT T-MATRIX _resolve_exclusive_groups(a, poslabel, ("lMax",), ("vswf_set",))
try: try:
lMax_or_bspec = ( a.vswf_set.get(poslabel, False) lMax_or_bspec = ( a.vswf_set.get(poslabel, False)
or a.lMax.get(poslabel, False) or a.lMax.get(poslabel, False)
or a.vswf_set.get(None, False) or a.vswf_set.get(None, False)
or a.lMax[None] ) or a.lMax[None] )
bspec = self._add_bspec(lMax_or_bspec)
self.bspecs[poslabel] = bspec
except (TypeError, KeyError) as exc:
if poslabel is None:
raise ArgumentProcessingError("Unlabeled particles' positions (-p) specified, "
"but neither --lMax nor --vswf-set is specified") from exc
else:
raise ArgumentProcessingError(("Incomplete specification of '%s'-labeled particles: you must"
"provide either ++lMax or ++vswf-set argument with the label, "
"or one of the fallback arguments --lMax or --vswf-set."
)%(str(poslabel),)) from exc
agi = _resolve_exclusive_groups(a, poslabel,
("constant_tmatrix",),
("radius", "height", "material", "lMax_extend"),
)
if agi == 0: # constant T-matrix
tmspec = constant_tmatrix_spec(bspec,
hashable_complex_matrix(a.constant_tmatrix.get(poslabel, None)
or a.constant_tmatrix[None]))
elif agi == 1: #
try:
radius = a.radius.get(poslabel, False) or a.radius[None] radius = a.radius.get(poslabel, False) or a.radius[None]
# Height is "inherited" only together with radius # Height is "inherited" only together with radius
height = a.height.get(poslabel, None) if poslabel in a.radius.keys() else a.height.get(None, None) height = a.height.get(poslabel, None) if poslabel in a.radius.keys() else a.height.get(None, None)
@ -544,15 +606,23 @@ class ArgParser:
material = a.material.get(poslabel, False) or a.material[None] material = a.material.get(poslabel, False) or a.material[None]
except (TypeError, KeyError) as exc: except (TypeError, KeyError) as exc:
if poslabel is None: if poslabel is None:
raise ArgumentProcessingError("Unlabeled particles' positions (-p) specified, but some default particle properties are missing (--lMax, --radius, and --material have to be specified)") from exc raise ArgumentProcessingError("Unlabeled particles' positions (-p) specified, "
"but some default particle properties are missing "
"(either --lMax or --vswf-set, and --constant-tmatrix or "
"--radius and --material have to be specified)") from exc
else: else:
raise ArgumentProcessingError(("Incomplete specification of '%s'-labeled particles: you must" raise ArgumentProcessingError(("Incomplete specification of '%s'-labeled particles: you must"
"provide at least ++lMax, ++radius, ++material arguments with the label, or the fallback arguments" "provide at least ++lMax (or ++vswf-set), "
"--lMax, --radius, --material.")%(str(poslabel),)) from exc "++radius and ++material (or ++constant-tmatrix) arguments with the label, "
tmspec = (a.background, material, (radius, height, lMax_extend)) "or the fallback arguments --lMax (or --vswf-sit), --radius, --material "
"(or --constant-tmatrix)."
)%(str(poslabel),)) from exc
tmspec = tmgen_spec(a.background, material,
cyl_sph_dimensions(radius, height, lMax_extend))
else: raise AssertionErrorw
self.tmspecs[poslabel] = tmspec self.tmspecs[poslabel] = tmspec
self.tmgens[poslabel] = self._add_tmg(tmspec) self.tmgens[poslabel] = self._add_tmg(tmspec)
self.bspecs[poslabel] = self._add_bspec(lMax_or_bspec)
poslist_cured = [] poslist_cured = []
for pos in a.position[poslabel]: for pos in a.position[poslabel]:
if len(pos) == 1: if len(pos) == 1: