from __future__ import annotations
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, Set
import numpy as np
from astropy import units as u
# Scalar numeric types accepted wherever a plain (unit-less) number is allowed.
Number = Union[(int, float, np.number)]
class Parameter:
    """
    A model parameter that behaves like an astropy Quantity.

    The current value is stored as an :class:`astropy.units.Quantity`.
    Arithmetic, comparisons and NumPy ufuncs are delegated to that Quantity,
    so they return Quantities rather than Parameters.

    Parameters
    ----------
    value : number, array-like, astropy.units.Quantity or Parameter
        Initial value. Non-Quantity input is interpreted in ``unit``
        (dimensionless when ``unit`` is None). A Parameter is unwrapped to
        its Quantity.
    unit : astropy.units.Unit or str, optional
        Target unit. A Quantity ``value`` is converted to it and must be
        equivalent to it.
    vrange : 2-tuple, optional
        Allowed ``(vmin, vmax)`` interval checked by :meth:`set`. Plain
        numbers are interpreted in the parameter unit.
    fixed : bool, optional
        If True, :meth:`set` and :meth:`convert_to` raise ``RuntimeError``.
    doc : str, optional
        Free-form description.

    Attributes
    ----------
    unit : astropy.units.Unit
        Unit of the current value.
    value : astropy.units.Quantity
        Mirrors the underlying Quantity (use :attr:`value_raw` for the bare
        number(s)).
    """

    # Encourage numpy to defer to our __array_ufunc__.
    __array_priority__ = 1000

    def __init__(
        self,
        value: Union[Number, u.Quantity],
        unit: Optional[u.Unit] = None,
        vrange: Optional[Tuple[Union[Number, u.Quantity], Union[Number, u.Quantity]]] = None,
        fixed: bool = False,
        doc: str = "",
    ):
        self.fixed = fixed
        self.doc = doc
        q, unit_norm = self._normalize_value_and_unit(value, unit)
        self._q = q
        self.unit = unit_norm
        self.value = q
        self._vrange = self._normalize_vrange(vrange, unit_norm)

    # --- Normalization helpers ---------------------------------------------

    @staticmethod
    def _normalize_value_and_unit(
        value: Union[Number, u.Quantity, "Parameter"], unit: Optional[u.Unit]
    ) -> Tuple[u.Quantity, u.Unit]:
        """
        Turn ``value`` (plus an optional target ``unit``) into ``(quantity, unit)``.

        Raises
        ------
        ValueError
            If ``value`` is a Quantity whose unit is not equivalent to ``unit``.
        """
        if isinstance(value, Parameter):
            # Use the wrapped Quantity so the Parameter's own unit is honoured
            # instead of pushing a non-Quantity through ``<<`` below.
            value = value.q
        if isinstance(value, u.Quantity):
            if unit is None:
                q = value
            else:
                target_unit = u.Unit(unit)
                if not value.unit.is_equivalent(target_unit):
                    raise ValueError(
                        f"Input quantity ({value.unit}) must be equivalent to requested unit ({target_unit})."
                    )
                q = value.to(target_unit)
        else:
            target_unit = u.Unit(unit) if unit is not None else u.dimensionless_unscaled
            q = value << target_unit
        return q, q.unit

    @classmethod
    def _normalize_vrange(
        cls,
        vrange: Optional[Tuple[Union[Number, u.Quantity], Union[Number, u.Quantity]]],
        unit_ref: u.Unit,
    ):
        """
        Validate ``vrange`` and express it in ``unit_ref``.

        Returns None for None. If either endpoint is a Quantity (or a
        Parameter), both endpoints are returned as Quantities in ``unit_ref``.
        Otherwise the plain endpoints are returned unchanged.

        Raises
        ------
        ValueError
            If ``vrange`` is not a 2-element tuple/list, an endpoint unit is
            incompatible with ``unit_ref``, or ``vmin > vmax`` anywhere.
        """
        if vrange is None:
            return None
        if not isinstance(vrange, (tuple, list)) or len(vrange) != 2:
            raise ValueError("vrange must be a 2-element tuple/list or None.")
        vmin, vmax = vrange
        if isinstance(vmin, cls):
            vmin = vmin.q
        if isinstance(vmax, cls):
            vmax = vmax.q
        if not (isinstance(vmin, u.Quantity) or isinstance(vmax, u.Quantity)):
            # Plain numbers: keep them as given, only check the ordering.
            if np.any(np.asarray(vmin) > np.asarray(vmax)):
                raise ValueError(f"Invalid vrange: minimum {vmin} exceeds maximum {vmax}.")
            return (vmin, vmax)
        if not isinstance(vmin, u.Quantity):
            vmin = vmin << unit_ref
        if not isinstance(vmax, u.Quantity):
            vmax = vmax << unit_ref
        if not vmin.unit.is_equivalent(unit_ref):
            raise ValueError(
                f"vrange minimum unit ({vmin.unit}) is not compatible with parameter unit ({unit_ref})."
            )
        if not vmax.unit.is_equivalent(unit_ref):
            raise ValueError(
                f"vrange maximum unit ({vmax.unit}) is not compatible with parameter unit ({unit_ref})."
            )
        vmin = vmin.to(unit_ref)
        vmax = vmax.to(unit_ref)
        if np.any(vmin > vmax):
            raise ValueError(f"Invalid vrange: minimum {vmin} exceeds maximum {vmax}.")
        return (vmin, vmax)

    @classmethod
    def _convert_vrange(cls, vrange, old_unit: u.Unit, new_unit: u.Unit, equivalencies=None):
        """
        Re-express a stored ``vrange`` after the parameter unit changes.

        Plain-number endpoints are expressed in ``old_unit``. When the unit
        actually changes they are converted (becoming Quantities) rather than
        silently reinterpreted in ``new_unit``. ``equivalencies`` are applied
        to the endpoints as well. Endpoints are swapped where a decreasing
        mapping (e.g. wavelength -> frequency) inverts their order.
        """
        if vrange is None:
            return None
        vmin, vmax = vrange
        plain = not isinstance(vmin, u.Quantity) and not isinstance(vmax, u.Quantity)
        if plain and old_unit == new_unit:
            return cls._normalize_vrange(vrange, new_unit)
        if not isinstance(vmin, u.Quantity):
            vmin = vmin << old_unit
        if not isinstance(vmax, u.Quantity):
            vmax = vmax << old_unit
        vmin = vmin.to(new_unit, equivalencies=equivalencies)
        vmax = vmax.to(new_unit, equivalencies=equivalencies)
        swap = np.asarray(vmin > vmax)
        if swap.ndim == 0:
            if swap:
                vmin, vmax = vmax, vmin
        elif swap.any():
            vmin, vmax = np.where(swap, vmax, vmin), np.where(swap, vmin, vmax)
        return cls._normalize_vrange((vmin, vmax), new_unit)

    # --- Properties ---------------------------------------------------------

    @property
    def vrange(self):
        """Allowed interval as ``(vmin, vmax)``, or None."""
        return self._vrange

    @vrange.setter
    def vrange(self, value):
        self._vrange = self._normalize_vrange(value, self.unit)

    @property
    def q(self) -> u.Quantity:
        """Return the underlying Quantity."""
        return self._q

    @q.setter
    def q(self, value):
        """
        Set the underlying quantity, preserving unit compatibility.

        NOTE(review): unlike :meth:`set` and :meth:`convert_to`, this setter
        checks neither ``fixed`` nor ``vrange`` — confirm that is intended.

        Raises
        ------
        ValueError
            If ``value`` is not a Quantity or its unit is not equivalent.
        """
        if not isinstance(value, u.Quantity):
            raise ValueError("Input value must be a quantity")
        if not value.unit.is_equivalent(self._q.unit):
            raise ValueError(f"Input quantity ({value.unit}) must be equivalent to current units ({self._q.unit})")
        new_vrange = self._convert_vrange(self._vrange, self._q.unit, value.unit)
        self._q = value
        self.unit = value.unit
        self.value = value
        self._vrange = new_vrange

    @property
    def value_raw(self):
        """Raw numeric value of the underlying Quantity (unit-stripped)."""
        return self._q.value

    @property
    def unit_raw(self) -> u.Unit:
        """Unit of the underlying Quantity."""
        return self._q.unit

    # --- Conversion ---------------------------------------------------------

    def to(self, unit: u.Unit, equivalencies=None) -> u.Quantity:
        """Return a converted Quantity (does not modify the Parameter)."""
        return self._q.to(unit, equivalencies=equivalencies)

    def to_value(self, unit: Optional[u.Unit] = None, equivalencies=None):
        """Return numeric value in the requested unit (current unit if None)."""
        if unit is None:
            return self._q.value
        return self._q.to_value(unit, equivalencies=equivalencies)

    def convert_to(self, unit: u.Unit, equivalencies=None) -> "Parameter":
        """
        Convert the parameter (and its ``vrange``) in place and return self.

        ``equivalencies`` are applied to both the value and the range. All
        conversions happen before any attribute is modified, so a failure
        leaves the parameter unchanged.

        Raises
        ------
        RuntimeError
            If the parameter is fixed.
        """
        if self.fixed:
            raise RuntimeError("Parameter is fixed and cannot be modified.")
        new_q = self._q.to(unit, equivalencies=equivalencies)
        new_vrange = self._convert_vrange(self._vrange, self._q.unit, new_q.unit, equivalencies)
        self._q = new_q
        self.unit = new_q.unit
        self.value = new_q
        self._vrange = new_vrange
        return self

    # --- Set and validate -------------------------------------------------

    def set(self, new_value: Union[Number, u.Quantity], *, validate: bool = True) -> None:
        """
        Update the parameter value.

        Parameters
        ----------
        new_value : number, array-like, astropy.units.Quantity or Parameter
            New parameter value. Non-Quantity input is interpreted in the
            current unit. A Quantity must have an equivalent unit; it is
            stored as given and its unit becomes the parameter unit.
        validate : bool, optional
            If True, enforce ``vrange`` constraints when defined.

        Raises
        ------
        RuntimeError
            If the parameter is fixed.
        ValueError
            If the unit is incompatible or the value lies outside ``vrange``.
        """
        if self.fixed:
            raise RuntimeError("Parameter is fixed and cannot be modified.")
        if isinstance(new_value, Parameter):
            new_value = new_value.q
        unit = self.unit if self.unit is not None else u.dimensionless_unscaled
        if isinstance(new_value, u.Quantity):
            # Reject dimension changes (e.g. metres -> seconds), as the q setter does.
            if not new_value.unit.is_equivalent(unit):
                raise ValueError(
                    f"Input quantity ({new_value.unit}) must be equivalent to current units ({unit})"
                )
            q = new_value
        else:
            q = new_value << unit
        if validate and self._vrange is not None:
            vmin, vmax = self._vrange
            # Plain-number endpoints are expressed in the current unit.
            vmin_q = vmin if isinstance(vmin, u.Quantity) else (vmin << unit)
            vmax_q = vmax if isinstance(vmax, u.Quantity) else (vmax << unit)
            q_cmp = q.to(unit)
            if np.any(q_cmp < vmin_q) or np.any(q_cmp > vmax_q):
                raise ValueError(f"Value {q_cmp} outside allowed range [{vmin_q}, {vmax_q}].")
        new_vrange = self._convert_vrange(self._vrange, unit, q.unit)
        self._q = q
        self.unit = q.unit
        self.value = q
        self._vrange = new_vrange

    def as_quantity(self) -> u.Quantity:
        """Return the current parameter value as an astropy Quantity."""
        return self._q

    @property
    def size(self):
        """Number of scalar elements in the underlying quantity."""
        return self.q.size

    def __repr__(self) -> str:
        meta = []
        if self.fixed:
            meta.append("fixed")
        if self.vrange is not None:
            meta.append(f"range={self.vrange}")
        meta_str = (", " + ", ".join(meta)) if meta else ""
        return f"Parameter({self._q!r}{meta_str})"

    def __float__(self) -> float:
        # NOTE: the unit is not checked; this returns the raw number in the
        # current unit, which is only meaningful for dimensionless values.
        return float(self.value_raw)

    def __array__(self, dtype=None, copy=None):
        """
        Return the raw numeric value(s) so ``np.asarray(param)`` works.

        ``copy`` is accepted per the NumPy 2 ``__array__`` protocol. With
        ``copy=True`` the result never aliases the parameter's storage.
        """
        if copy:
            return np.array(self._q.value, dtype=dtype, copy=True)
        return np.asarray(self._q.value, dtype=dtype)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """Intercept numpy ufuncs and delegate to the underlying Quantity."""
        # Replace Parameters with their Quantities and let astropy handle units.
        q_inputs = [x._q if isinstance(x, Parameter) else x for x in inputs]
        return getattr(ufunc, method)(*q_inputs, **kwargs)

    # --- Arithmetic, comparison and unary operations -------------------------
    # All operators return plain Quantities. __eq__ is deliberately not
    # overridden so Parameters keep identity-based hashing.

    def _binop(self, other: Any, op):
        other_q = other._q if isinstance(other, Parameter) else other
        return op(self._q, other_q)

    def __add__(self, other): return self._binop(other, lambda a, b: a + b)
    def __radd__(self, other): return self._binop(other, lambda a, b: b + a)
    def __sub__(self, other): return self._binop(other, lambda a, b: a - b)
    def __rsub__(self, other): return self._binop(other, lambda a, b: b - a)
    def __mul__(self, other): return self._binop(other, lambda a, b: a * b)
    def __rmul__(self, other): return self._binop(other, lambda a, b: b * a)
    def __truediv__(self, other): return self._binop(other, lambda a, b: a / b)
    def __rtruediv__(self, other): return self._binop(other, lambda a, b: b / a)
    def __pow__(self, other): return self._binop(other, lambda a, b: a ** b)
    def __rpow__(self, other): return self._binop(other, lambda a, b: b ** a)

    def __lt__(self, other): return self._binop(other, lambda a, b: a < b)
    def __le__(self, other): return self._binop(other, lambda a, b: a <= b)
    def __gt__(self, other): return self._binop(other, lambda a, b: a > b)
    def __ge__(self, other): return self._binop(other, lambda a, b: a >= b)

    def __neg__(self): return -self._q
    def __pos__(self): return +self._q
    def __abs__(self): return abs(self._q)
class ModelBase:
    """
    Base class for models with named parameters.

    Subclasses typically declare Parameter attributes directly, for example::

        class MyModel(ModelBase):
            name: str = "my_model"
            a_v: Parameter = Parameter(0.2, vrange=(0.0, 5.0))
            r_v: Parameter = Parameter(3.1, vrange=(2.0, 6.0), fixed=True)

    Notes
    -----
    This base class provides:

    - discovery of Parameter attributes
    - getting and setting values by name
    - filtering fixed or free parameters
    - conversion to plain dictionaries for IO

    Parameters declared as class attributes, as in the example above, are a
    single object shared by every instance of the subclass. Updating the
    value on one instance therefore updates it on all of them.
    NOTE(review): confirm this is intended, or assign per-instance Parameters
    in the subclass ``__init__``.
    """

    name: str = "model"

    # --- Traversal ----------------------------------------------------------

    def _walk_models(
        self, prefix: str, max_depth: Optional[int]
    ) -> List[Tuple[str, "ModelBase"]]:
        """
        Return ``(path, model)`` pairs for self and nested models, depth-first pre-order.

        Only public instance attributes (``__dict__`` entries not starting with
        ``_``) are followed. Each model object is visited at most once.
        """
        found: List[Tuple[str, ModelBase]] = []
        visited: Set[int] = set()

        def _walk(obj: ModelBase, base: str, depth: int) -> None:
            # Check the depth limit *before* marking as visited. Otherwise a
            # model first reached beyond max_depth would be skipped when later
            # reached within it.
            if max_depth is not None and depth > max_depth:
                return
            oid = id(obj)
            if oid in visited:
                return
            visited.add(oid)
            found.append((base, obj))
            for attr_name, attr_val in obj.__dict__.items():
                if attr_name.startswith("_"):
                    continue
                if isinstance(attr_val, ModelBase):
                    child_base = f"{base}.{attr_name}" if base else attr_name
                    _walk(attr_val, child_base, depth + 1)

        _walk(self, prefix.strip("."), 0)
        return found

    def models_recursive(
        self,
        *,
        prefix: str = "",
        max_depth: Optional[int] = None,
    ) -> Dict[str, "ModelBase"]:
        """
        Return nested models including self.

        Parameters
        ----------
        prefix : str, optional
            Dotted path to prepend (leading/trailing dots are stripped).
        max_depth : int or None, optional
            Maximum nesting depth to descend (self is depth 0).

        Returns
        -------
        models : dict
            Mapping from dotted path to model instance. Self is keyed by
            ``prefix`` if given, else by its ``name`` (or ``"model"``).
        """
        out: Dict[str, ModelBase] = {}
        for path, model in self._walk_models(prefix, max_depth):
            out[path or model.name or "model"] = model
        return out

    def parameters_recursive(
        self,
        *,
        prefix: str = "",
        max_depth: Optional[int] = None,
        include_fixed: bool = True,
    ) -> Dict[str, "Parameter"]:
        """
        Return parameters from this model and nested models.

        Parameters
        ----------
        prefix : str, optional
            Dotted path to prepend (leading/trailing dots are stripped).
        max_depth : int or None, optional
            Maximum nesting depth to descend (self is depth 0).
        include_fixed : bool, optional
            If False, skip fixed parameters.

        Returns
        -------
        params : dict
            Mapping from dotted parameter path to Parameter.
        """
        params: Dict[str, Parameter] = {}
        for path, model in self._walk_models(prefix, max_depth):
            for pname, p in model.parameters().items():
                if (not include_fixed) and p.fixed:
                    continue
                params[f"{path}.{pname}" if path else pname] = p
        return params

    # --- Direct parameters --------------------------------------------------

    def parameters(self) -> Dict[str, Parameter]:
        """
        Return direct Parameter attributes defined on this model.

        Both class-level and instance-level attributes are found (via
        ``dir``), in ``dir``'s alphabetical order. Names starting with ``_``
        are skipped.

        Returns
        -------
        params : dict
            Mapping from attribute name to :class:`Parameter`.
        """
        out: Dict[str, Parameter] = {}
        for name in dir(self):
            if name.startswith("_"):
                continue
            try:
                v = getattr(self, name)
            except Exception:
                # Best effort: a property that raises simply is not a Parameter.
                continue
            if isinstance(v, Parameter):
                out[name] = v
        return out

    def parameter_names(self, *, include_fixed: bool = True) -> List[str]:
        """
        Return parameter names.

        Parameters
        ----------
        include_fixed : bool, optional
            If False, only return free parameters.

        Returns
        -------
        names : list of str
        """
        return [k for k, p in self.parameters().items() if include_fixed or not p.fixed]

    def get(self, name: str) -> Parameter:
        """
        Get a Parameter object by name.

        Raises
        ------
        KeyError
            If the parameter does not exist.
        """
        p = getattr(self, name, None)
        if not isinstance(p, Parameter):
            raise KeyError(f"Unknown parameter '{name}'.")
        return p

    def get_values(self, *, include_fixed: bool = True, as_quantity: bool = False) -> Dict[str, Any]:
        """
        Get parameter values as a dict.

        Parameters
        ----------
        include_fixed : bool, optional
            If False, only returns free parameters.
        as_quantity : bool, optional
            If True, return ``Parameter.as_quantity()``. Otherwise return the
            ``Parameter.value`` attribute.
            NOTE(review): ``Parameter.value`` normally mirrors the Quantity
            too, so both settings usually yield Quantities; ``value_raw``
            would give bare numbers. Confirm intent before changing.

        Returns
        -------
        values : dict
            Mapping from parameter name to value.
        """
        out: Dict[str, Any] = {}
        for k, p in self.parameters().items():
            if (not include_fixed) and p.fixed:
                continue
            out[k] = p.as_quantity() if as_quantity else p.value
        return out

    def set_values(self, values: Mapping[str, Any], *, validate: bool = True, strict: bool = True) -> None:
        """
        Set parameter values from a mapping.

        All names are resolved before any parameter is modified, so an unknown
        key (with ``strict=True``) leaves the model untouched.

        Parameters
        ----------
        values : mapping
            Mapping from parameter name to new value.
        validate : bool, optional
            If True, enforce each parameter's ``vrange``. Fixed parameters
            are always rejected by ``Parameter.set`` regardless of this flag.
        strict : bool, optional
            If True, unknown keys raise KeyError. If False, unknown keys are ignored.

        Raises
        ------
        KeyError
            If ``strict`` and a key is not a parameter name.
        """
        targets: List[Tuple[Parameter, Any]] = []
        for k, v in values.items():
            p = getattr(self, k, None)
            if not isinstance(p, Parameter):
                if strict:
                    raise KeyError(f"Unknown parameter '{k}'.")
                continue
            targets.append((p, v))
        for p, v in targets:
            p.set(v, validate=validate)

    # --- Freezing -----------------------------------------------------------

    def _set_fixed(self, names: Optional[Sequence[str]], flag: bool) -> None:
        """Set ``fixed = flag`` on all parameters (None) or the named ones."""
        if names is None:
            targets = list(self.parameters().values())
        else:
            if isinstance(names, str):
                # A single name, not a sequence of one-character names.
                names = [names]
            # Resolve every name first so an unknown one changes nothing.
            targets = [self.get(n) for n in names]
        for p in targets:
            p.fixed = flag

    def freeze(self, names: Optional[Sequence[str]] = None) -> None:
        """
        Freeze parameters.

        Parameters
        ----------
        names : str, sequence of str or None
            If None, freeze all parameters. Otherwise freeze selected ones.

        Raises
        ------
        KeyError
            If a name is not a parameter (no parameter is changed then).
        """
        self._set_fixed(names, True)

    def unfreeze(self, names: Optional[Sequence[str]] = None) -> None:
        """
        Unfreeze parameters.

        Parameters
        ----------
        names : str, sequence of str or None
            If None, unfreeze all parameters. Otherwise unfreeze selected ones.

        Raises
        ------
        KeyError
            If a name is not a parameter (no parameter is changed then).
        """
        self._set_fixed(names, False)

    # --- Search and IO ------------------------------------------------------

    def search(self, text: str, *, in_docs: bool = True, in_names: bool = True) -> List[str]:
        """
        Search parameters by substring.

        Parameters
        ----------
        text : str
            Substring to search for, case-insensitive.
        in_docs : bool, optional
            If True, search in Parameter.doc strings.
        in_names : bool, optional
            If True, search in parameter names.

        Returns
        -------
        matches : list of str
            Matching parameter names.
        """
        q = text.lower().strip()
        matches: List[str] = []
        for name, p in self.parameters().items():
            name_hit = in_names and q in name.lower()
            doc_hit = in_docs and bool(p.doc) and q in p.doc.lower()
            if name_hit or doc_hit:
                matches.append(name)
        return matches

    @staticmethod
    def _quantity_repr(q: u.Quantity) -> Dict[str, Any]:
        """Return a JSON-friendly ``{"value", "unit"}`` dict for a Quantity."""
        raw = q.value
        if np.size(raw) == 1:
            value: Any = float(np.asarray(raw).item())
        else:
            # float() cannot take multi-element arrays; emit nested lists instead.
            value = np.asarray(raw, dtype=float).tolist()
        return {"value": value, "unit": str(q.unit)}

    def to_dict(self, *, include_fixed: bool = True) -> Dict[str, Any]:
        """
        Serialize the model configuration to a plain dict.

        Notes
        -----
        This returns a simple structure suitable for JSON/YAML. Quantities are
        represented as dicts with value and unit. Single-element values become
        floats and array values become (nested) lists of floats.
        """
        out: Dict[str, Any] = {"name": self.name, "parameters": {}}
        for k, p in self.parameters().items():
            if (not include_fixed) and p.fixed:
                continue
            val = p.value
            val_repr: Any = self._quantity_repr(val) if isinstance(val, u.Quantity) else val
            vr = p.vrange
            if vr is None:
                vr_repr = None
            else:
                vr_repr = [self._quantity_repr(v) if isinstance(v, u.Quantity) else v for v in vr]
            out["parameters"][k] = {
                "value": val_repr,
                "vrange": vr_repr,
                "fixed": bool(p.fixed),
                "unit": (str(p.unit) if p.unit is not None else None),
                "doc": p.doc,
            }
        return out