"""Source code for pst.model."""

from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, Set
import numpy as np
from astropy import units as u

Number = Union[int, float, np.number]


class Parameter:
    """
    A model parameter that behaves like an astropy Quantity.

    The current value is stored as an :class:`astropy.units.Quantity`; an
    optional allowed interval (``vrange``) and a ``fixed`` flag are kept
    alongside it.

    Parameters
    ----------
    value : number or astropy.units.Quantity
        Initial value. Plain numbers are attached to ``unit`` (or made
        dimensionless when ``unit`` is None).
    unit : astropy.units.Unit, optional
        Target unit. A Quantity ``value`` is converted to it and must be
        equivalent to it.
    vrange : 2-element tuple/list or None, optional
        Allowed ``(minimum, maximum)``. Bounds may be numbers, Quantities or
        other Parameters.
    fixed : bool, optional
        If True, :meth:`set` and :meth:`convert_to` refuse to modify the value.
    doc : str, optional
        Free-text description of the parameter.

    Attributes
    ----------
    value : astropy.units.Quantity
        The current Quantity (the same object returned by :attr:`q`).
    unit : astropy.units.Unit
        Unit of the current Quantity.
    """

    def __init__(
        self,
        value: Union[Number, u.Quantity],
        unit: Optional[u.Unit] = None,
        vrange: Optional[Tuple[Union[Number, u.Quantity], Union[Number, u.Quantity]]] = None,
        fixed: bool = False,
        doc: str = "",
    ):
        self.fixed = fixed
        self.doc = doc
        q, unit_norm = self._normalize_value_and_unit(value, unit)
        self._q = q
        self.unit = unit_norm
        self.value = q
        self._vrange = self._normalize_vrange(vrange, unit_norm)

    @staticmethod
    def _normalize_value_and_unit(
        value: Union[Number, u.Quantity], unit: Optional[u.Unit]
    ) -> Tuple[u.Quantity, u.Unit]:
        """Return ``(quantity, unit)`` for the given value and optional unit.

        Raises
        ------
        ValueError
            If ``value`` is a Quantity whose unit is not equivalent to ``unit``.
        """
        if isinstance(value, u.Quantity):
            if unit is None:
                q = value
            else:
                target_unit = u.Unit(unit)
                if not value.unit.is_equivalent(target_unit):
                    raise ValueError(
                        f"Input quantity ({value.unit}) must be equivalent to requested unit ({target_unit})."
                    )
                q = value.to(target_unit)
        else:
            target_unit = u.Unit(unit) if unit is not None else u.dimensionless_unscaled
            q = value << target_unit
        return q, q.unit

    @classmethod
    def _normalize_vrange(
        cls,
        vrange: Optional[Tuple[Union[Number, u.Quantity], Union[Number, u.Quantity]]],
        unit_ref: u.Unit,
    ):
        """Validate a ``(vmin, vmax)`` pair against ``unit_ref``.

        If at least one bound is a Quantity, both bounds are returned as
        Quantities converted to ``unit_ref``. If neither is, the bounds are
        returned as given (unit-less).

        Raises
        ------
        ValueError
            If ``vrange`` is not a 2-element tuple/list, a bound has an
            incompatible unit, or the minimum exceeds the maximum.
        """
        if vrange is None:
            return None
        if not isinstance(vrange, (tuple, list)) or len(vrange) != 2:
            raise ValueError("vrange must be a 2-element tuple/list or None.")

        # Bounds given as Parameters are replaced by their Quantities.
        vmin, vmax = (b.q if isinstance(b, cls) else b for b in vrange)

        if not (isinstance(vmin, u.Quantity) or isinstance(vmax, u.Quantity)):
            if np.any(np.asarray(vmin) > np.asarray(vmax)):
                raise ValueError(f"Invalid vrange: minimum {vmin} exceeds maximum {vmax}.")
            return (vmin, vmax)

        # Mixed input: the bare number takes the reference unit.
        if not isinstance(vmin, u.Quantity):
            vmin = vmin << unit_ref
        if not isinstance(vmax, u.Quantity):
            vmax = vmax << unit_ref
        for label, bound in (("minimum", vmin), ("maximum", vmax)):
            if not bound.unit.is_equivalent(unit_ref):
                raise ValueError(
                    f"vrange {label} unit ({bound.unit}) is not compatible with parameter unit ({unit_ref})."
                )
        vmin = vmin.to(unit_ref)
        vmax = vmax.to(unit_ref)
        if np.any(vmin > vmax):
            raise ValueError(f"Invalid vrange: minimum {vmin} exceeds maximum {vmax}.")
        return (vmin, vmax)

    def _store(self, q: u.Quantity) -> None:
        """Install ``q`` as the current value and re-express ``vrange`` in its unit.

        The range is normalized *before* any attribute is touched, so a
        failure there leaves the parameter unchanged.
        """
        new_range = self._normalize_vrange(self._vrange, q.unit)
        self._q = q
        self.unit = q.unit
        self.value = q
        self._vrange = new_range

    @property
    def vrange(self):
        """Allowed interval as ``(vmin, vmax)``, or None."""
        return self._vrange

    @vrange.setter
    def vrange(self, value):
        self._vrange = self._normalize_vrange(value, self.unit)

    @property
    def q(self) -> u.Quantity:
        """Return the underlying Quantity."""
        return self._q

    @q.setter
    def q(self, value):
        """Set the underlying quantity, preserving unit compatibility."""
        if not isinstance(value, u.Quantity):
            raise ValueError("Input value must be a quantity")
        if not value.unit.is_equivalent(self._q.unit):
            raise ValueError(
                f"Input quantity ({value.unit}) must be equivalent to current units ({self._q.unit})"
            )
        self._store(value)

    @property
    def value_raw(self):
        """Raw numeric value of the underlying Quantity (unit-stripped)."""
        return self._q.value

    @property
    def unit_raw(self) -> u.Unit:
        """Unit of the underlying Quantity."""
        return self._q.unit

    def to(self, unit: u.Unit, equivalencies=None) -> u.Quantity:
        """Return a converted Quantity (does not modify the Parameter)."""
        return self._q.to(unit, equivalencies=equivalencies)

    def to_value(self, unit: Optional[u.Unit] = None, equivalencies=None):
        """Return numeric value in the requested unit (current unit if None)."""
        if unit is None:
            return self._q.value
        return self._q.to_value(unit, equivalencies=equivalencies)

    def convert_to(self, unit: u.Unit, equivalencies=None) -> "Parameter":
        """Convert the parameter in place and return self.

        Raises
        ------
        RuntimeError
            If the parameter is fixed.
        ValueError
            If a Quantity-valued ``vrange`` cannot be expressed in the new
            unit; the parameter is left unchanged in that case.
        """
        if self.fixed:
            raise RuntimeError("Parameter is fixed and cannot be modified.")
        self._store(self._q.to(unit, equivalencies=equivalencies))
        return self

    # --- Set and validate -------------------------------------------------------

    def set(self, new_value: Union[Number, u.Quantity], *, validate: bool = True) -> None:
        """
        Update the parameter value.

        Parameters
        ----------
        new_value : number or astropy.units.Quantity
            New parameter value. Plain numbers take the current unit; a
            Quantity keeps its own unit, which becomes the parameter unit.
        validate : bool, optional
            If True, enforce ``vrange`` constraints when defined.

        Raises
        ------
        RuntimeError
            If the parameter is fixed (checked regardless of ``validate``).
        astropy.units.UnitConversionError
            If ``new_value`` is a Quantity whose unit is not equivalent to the
            current unit. This is a ``ValueError`` subclass.
        ValueError
            If validation is on and the value lies outside ``vrange``.
        """
        if self.fixed:
            raise RuntimeError("Parameter is fixed and cannot be modified.")

        unit = self.unit if self.unit is not None else u.dimensionless_unscaled
        if isinstance(new_value, u.Quantity):
            # Same compatibility rule as the ``q`` setter: the dimension of a
            # parameter must not change silently.
            if not new_value.unit.is_equivalent(self._q.unit):
                raise u.UnitConversionError(
                    f"Input quantity ({new_value.unit}) must be equivalent to current units ({self._q.unit})"
                )
            q = new_value
        else:
            q = new_value << unit

        if validate and self.vrange is not None:
            vmin, vmax = self.vrange
            # NOTE: unit-less bounds are interpreted in the parameter's
            # current unit at the time of the check.
            vmin_q = vmin if isinstance(vmin, u.Quantity) else (vmin << unit)
            vmax_q = vmax if isinstance(vmax, u.Quantity) else (vmax << unit)
            q_cmp = q.to(unit)
            if np.any(q_cmp < vmin_q) or np.any(q_cmp > vmax_q):
                raise ValueError(f"Value {q_cmp} outside allowed range [{vmin_q}, {vmax_q}].")

        self._store(q)

    def as_quantity(self) -> u.Quantity:
        """Return the current parameter value as an astropy Quantity."""
        return self._q

    @property
    def size(self):
        """Number of scalar elements in the underlying quantity."""
        return self.q.size

    def __repr__(self) -> str:
        meta = []
        if self.fixed:
            meta.append("fixed")
        if self.vrange is not None:
            meta.append(f"range={self.vrange}")
        meta_str = (", " + ", ".join(meta)) if meta else ""
        return f"Parameter({self._q!r}{meta_str})"

    def __float__(self) -> float:
        # Returns the raw number in the current unit; no dimensionless check
        # is performed here.
        return float(self.value_raw)

    def __array__(self, dtype=None, copy=None):
        """Return the unit-stripped value as an array (``np.asarray`` hook).

        ``copy`` follows the NumPy 2 array protocol: True forces a copy,
        False forbids one, None copies only if needed.
        """
        arr = np.asarray(self._q.value)
        if dtype is not None and np.dtype(dtype) != arr.dtype:
            if copy is False:
                raise ValueError("Unable to avoid copy while creating an array as requested.")
            return arr.astype(dtype)
        return arr.copy() if copy else arr

    @property
    def __array_priority__(self):
        # Encourage numpy to use our __array_ufunc__
        return 1000

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """Intercept numpy ufuncs and delegate to Quantity."""
        # Replace Parameters with their Quantities, then re-dispatch so that
        # astropy's unit-aware ufunc handling produces the result.
        q_inputs = [x._q if isinstance(x, Parameter) else x for x in inputs]
        return getattr(ufunc, method)(*q_inputs, **kwargs)

    # --- Arithmetic and binary operations -----------------------------------

    def _binop(self, other: Any, op):
        """Apply ``op(self_quantity, other)``, unwrapping ``other`` if it is a Parameter."""
        other_q = other._q if isinstance(other, Parameter) else other
        return op(self._q, other_q)

    def __add__(self, other):
        return self._binop(other, lambda a, b: a + b)

    def __radd__(self, other):
        return self._binop(other, lambda a, b: b + a)

    def __sub__(self, other):
        return self._binop(other, lambda a, b: a - b)

    def __rsub__(self, other):
        return self._binop(other, lambda a, b: b - a)

    def __mul__(self, other):
        return self._binop(other, lambda a, b: a * b)

    def __rmul__(self, other):
        return self._binop(other, lambda a, b: b * a)

    def __truediv__(self, other):
        return self._binop(other, lambda a, b: a / b)

    def __rtruediv__(self, other):
        return self._binop(other, lambda a, b: b / a)

    def __pow__(self, other):
        return self._binop(other, lambda a, b: a ** b)

    def __rpow__(self, other):
        return self._binop(other, lambda a, b: b ** a)
class ModelBase:
    """
    Base class for models with named parameters.

    Subclasses typically declare Parameter attributes directly, for example:

        class MyModel(ModelBase):
            name: str = "my_model"
            a_v: Parameter = Parameter(0.2, vrange=(0.0, 5.0))
            r_v: Parameter = Parameter(3.1, vrange=(2.0, 6.0), fixed=True)

    Notes
    -----
    This base class provides:

    - discovery of Parameter attributes
    - getting and setting values by name
    - filtering fixed or free parameters
    - conversion to plain dictionaries for IO

    NOTE(review): a Parameter declared as a class attribute, as in the example
    above, is a single object reached from every instance through normal
    attribute lookup. Confirm that sharing is intended.
    """

    name: str = "model"

    def _nested_models(
        self, prefix: str, max_depth: Optional[int]
    ) -> List[Tuple[str, "ModelBase"]]:
        """Depth-first list of ``(dotted_path, model)`` starting at self.

        Only public instance attributes (``__dict__`` entries not starting
        with ``_``) are followed. Each model appears once, which also guards
        against reference cycles. The root path is ``prefix`` without
        surrounding dots and may be empty.
        """
        found: List[Tuple[str, ModelBase]] = []
        seen: Set[int] = set()

        def _walk(obj: ModelBase, base: str, depth: int) -> None:
            # Check depth first: a model reached too deep must not be marked
            # as seen, or a shallower path to it would be skipped later.
            if max_depth is not None and depth > max_depth:
                return
            if id(obj) in seen:
                return
            seen.add(id(obj))
            found.append((base, obj))
            for attr_name, attr_val in list(obj.__dict__.items()):
                if attr_name.startswith("_"):
                    continue
                if isinstance(attr_val, ModelBase):
                    child_base = f"{base}.{attr_name}" if base else attr_name
                    _walk(attr_val, child_base, depth + 1)

        _walk(self, prefix.strip("."), 0)
        return found

    def models_recursive(
        self,
        *,
        prefix: str = "",
        max_depth: Optional[int] = None,
    ) -> Dict[str, "ModelBase"]:
        """
        Return nested models including self.

        Parameters
        ----------
        prefix : str, optional
            Dotted path used for this model; nested models are keyed below it.
            When empty, this model is keyed by its ``name`` (or ``"model"``).
        max_depth : int or None, optional
            Maximum nesting depth to include (0 means only this model).

        Returns
        -------
        models : dict
            Mapping from dotted path to model instance.
        """
        return {
            (path or model.name or "model"): model
            for path, model in self._nested_models(prefix, max_depth)
        }

    def parameters_recursive(
        self,
        *,
        prefix: str = "",
        max_depth: Optional[int] = None,
        include_fixed: bool = True,
    ) -> Dict[str, "Parameter"]:
        """
        Return parameters from this model and nested models.

        Parameters
        ----------
        prefix : str, optional
            Dotted path prepended to every parameter key.
        max_depth : int or None, optional
            Maximum nesting depth to include (0 means only this model).
        include_fixed : bool, optional
            If False, fixed parameters are left out.

        Returns
        -------
        params : dict
            Mapping from dotted parameter path to Parameter.
        """
        params: Dict[str, Parameter] = {}
        for path, model in self._nested_models(prefix, max_depth):
            for pname, p in model.parameters().items():
                if p.fixed and not include_fixed:
                    continue
                params[f"{path}.{pname}" if path else pname] = p
        return params

    def parameters(self) -> Dict[str, Parameter]:
        """
        Return direct Parameter attributes defined on this model.

        Returns
        -------
        params : dict
            Mapping from attribute name to :class:`Parameter`.
        """
        out: Dict[str, Parameter] = {}
        for name in dir(self):
            if name.startswith("_"):
                continue
            try:
                v = getattr(self, name)
            except Exception:
                # Best effort: attributes whose access raises are skipped.
                continue
            if isinstance(v, Parameter):
                out[name] = v
        return out

    def parameter_names(self, *, include_fixed: bool = True) -> List[str]:
        """
        Return parameter names.

        Parameters
        ----------
        include_fixed : bool, optional
            If False, only return free parameters.

        Returns
        -------
        names : list of str
        """
        return [k for k, p in self.parameters().items() if include_fixed or not p.fixed]

    def get(self, name: str) -> Parameter:
        """
        Get a Parameter object by name.

        Raises
        ------
        KeyError
            If the parameter does not exist.
        """
        p = getattr(self, name, None)
        if not isinstance(p, Parameter):
            raise KeyError(f"Unknown parameter '{name}'.")
        return p

    def get_values(self, *, include_fixed: bool = True, as_quantity: bool = False) -> Dict[str, Any]:
        """
        Get parameter values as a dict.

        Parameters
        ----------
        include_fixed : bool, optional
            If False, only returns free parameters.
        as_quantity : bool, optional
            If True, returns ``Parameter.as_quantity()``; otherwise returns
            ``Parameter.value``.

        Returns
        -------
        values : dict
            Mapping from parameter name to value.
        """
        # NOTE(review): Parameter stores the Quantity itself in ``.value``, so
        # both branches appear to yield the same object. Confirm whether
        # ``as_quantity=False`` was meant to return the unit-stripped number.
        return {
            k: (p.as_quantity() if as_quantity else p.value)
            for k, p in self.parameters().items()
            if include_fixed or not p.fixed
        }

    def set_values(self, values: Mapping[str, Any], *, validate: bool = True, strict: bool = True) -> None:
        """
        Set parameter values from a mapping.

        Parameters
        ----------
        values : mapping
            Mapping from parameter name to new value.
        validate : bool, optional
            If True, enforce each parameter's ``vrange``. Default is True.
            Fixed parameters refuse updates regardless of this flag.
        strict : bool, optional
            If True, unknown keys raise KeyError. If False, unknown keys are
            ignored.

        Raises
        ------
        KeyError
            In strict mode, if a key does not name a Parameter. All keys are
            checked before any value is changed.
        """
        # Resolve every key first so an unknown name cannot leave the model
        # partially updated.
        targets: List[Tuple[Parameter, Any]] = []
        for key, new_value in values.items():
            candidate = getattr(self, key, None)
            if isinstance(candidate, Parameter):
                targets.append((candidate, new_value))
            elif strict:
                raise KeyError(f"Unknown parameter '{key}'.")
        for param, new_value in targets:
            param.set(new_value, validate=validate)

    def _set_fixed(self, names: Optional[Sequence[str]], flag: bool) -> None:
        """Set ``fixed = flag`` on all parameters (names is None) or on the named ones."""
        if names is None:
            for p in self.parameters().values():
                p.fixed = flag
            return
        if isinstance(names, str):
            # A bare string is one name, not a sequence of characters.
            names = [names]
        for n in names:
            self.get(n).fixed = flag

    def freeze(self, names: Optional[Sequence[str]] = None) -> None:
        """
        Freeze parameters.

        Parameters
        ----------
        names : sequence of str or None
            If None, freeze all parameters. Otherwise freeze selected ones.

        Raises
        ------
        KeyError
            If a given name is not a parameter of this model.
        """
        self._set_fixed(names, True)

    def unfreeze(self, names: Optional[Sequence[str]] = None) -> None:
        """
        Unfreeze parameters.

        Parameters
        ----------
        names : sequence of str or None
            If None, unfreeze all parameters. Otherwise unfreeze selected ones.

        Raises
        ------
        KeyError
            If a given name is not a parameter of this model.
        """
        self._set_fixed(names, False)

    def search(self, text: str, *, in_docs: bool = True, in_names: bool = True) -> List[str]:
        """
        Search parameters by substring.

        Parameters
        ----------
        text : str
            Substring to search for, case-insensitive.
        in_docs : bool, optional
            If True, search in Parameter.doc strings.
        in_names : bool, optional
            If True, search in parameter names.

        Returns
        -------
        matches : list of str
            Matching parameter names.
        """
        needle = text.lower().strip()
        matches: List[str] = []
        for name, p in self.parameters().items():
            name_hit = in_names and needle in name.lower()
            doc_hit = in_docs and bool(p.doc) and needle in p.doc.lower()
            if name_hit or doc_hit:
                matches.append(name)
        return matches

    @staticmethod
    def _plain(obj: Any) -> Any:
        """Convert numpy scalars/arrays to built-in Python values; pass others through."""
        if isinstance(obj, (np.generic, np.ndarray)):
            return obj.tolist()
        return obj

    @staticmethod
    def _quantity_repr(quantity: Any) -> Dict[str, Any]:
        """Represent a Quantity as ``{"value": float or nested list, "unit": str}``."""
        # ``tolist`` gives a Python float for scalars and nested lists for
        # arrays, so array-valued parameters serialize too.
        return {
            "value": np.asarray(quantity.value, dtype=float).tolist(),
            "unit": str(quantity.unit),
        }

    def to_dict(self, *, include_fixed: bool = True) -> Dict[str, Any]:
        """
        Serialize the model configuration to a plain dict.

        Parameters
        ----------
        include_fixed : bool, optional
            If False, fixed parameters are left out.

        Notes
        -----
        This returns a simple structure suitable for JSON/YAML. Quantities are
        represented as dicts with value and unit; array-valued quantities use
        (nested) lists for the value.
        """
        out: Dict[str, Any] = {"name": self.name, "parameters": {}}
        for k, p in self.parameters().items():
            if (not include_fixed) and p.fixed:
                continue

            val = p.value
            if isinstance(val, u.Quantity):
                val_repr: Any = self._quantity_repr(val)
            else:
                val_repr = self._plain(val)

            vr = p.vrange
            if vr is None:
                vr_repr = None
            else:
                vr_repr = [
                    self._quantity_repr(b) if isinstance(b, u.Quantity) else self._plain(b)
                    for b in vr
                ]

            out["parameters"][k] = {
                "value": val_repr,
                "vrange": vr_repr,
                "fixed": bool(p.fixed),
                "unit": (str(p.unit) if p.unit is not None else None),
                "doc": p.doc,
            }
        return out