from os import environ
from typing import Any, cast
from astropy.units import Quantity
from pydantic import (
BaseModel,
ConfigDict,
SerializerFunctionWrapHandler,
model_serializer,
)
from pydantic.config import ExtraValues
from pydantic_core import PydanticUndefined
def _extra_fields_from_env() -> ExtraValues | None:
    """Read the model-config ``extra`` behaviour from ``$EXTRA_FIELDS``.

    Returns ``None`` (leaving pydantic's default in place) when the variable
    is unset or blank; otherwise the stripped, lower-cased value, which must
    be one of ``"allow"``, ``"ignore"`` or ``"forbid"``.

    Raises:
        ValueError: if the variable holds any other value, so a typo fails at
            import with a clear message instead of surfacing later from
            pydantic's schema construction.
    """
    raw = environ.get("EXTRA_FIELDS", "").strip().lower()
    if not raw:
        return None
    if raw not in ("allow", "ignore", "forbid"):
        raise ValueError(
            "EXTRA_FIELDS must be one of 'allow', 'ignore' or 'forbid', "
            f"got {raw!r}"
        )
    return cast(ExtraValues, raw)


# Resolved once at import; consumed by PdmObject.model_config below.
EXTRA_FIELDS = _extra_fields_from_env()
class PdmObject(BaseModel):
    """Shared Base Class for all PDM Entities.

    On top of :class:`pydantic.BaseModel` this class:

    * validates assignments and defaults, and serializes timedeltas in JSON
      as floats (see ``model_config``);
    * omits ``None`` / ``[]`` / ``{}`` entries from serialized output when the
      field still holds its default value;
    * compares instances by their ``mode="python"`` dumps, comparing astropy
      ``Quantity`` values element-wise.
    """

    # https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict
    model_config = ConfigDict(
        extra=EXTRA_FIELDS,  # Defaults to 'ignore'; can we prefer 'forbid' here?
        # Validate assignments and defaults to help keep ourselves honest:
        validate_assignment=True,
        validate_default=True,
        ser_json_timedelta="float",
    )

    @classmethod
    def _field_name_for_key(cls, key: str) -> str | None:
        """Map a key of a dumped dict back to a declared field name.

        Dump keys are field names normally, but aliases when dumping with
        ``by_alias=True``. Extra fields (``extra="allow"``) and computed
        fields have no declared field; ``None`` is returned for those.
        """
        fields = cls.model_fields
        if key in fields:
            return key
        for name, info in fields.items():
            if key in (info.serialization_alias, info.alias):
                return name
        return None

    def _is_default(self, key: str) -> bool:
        """Return True if the field behind dump ``key`` equals its default.

        Keys without a declared field, and required fields (no default at
        all), are never considered default.
        """
        name = self._field_name_for_key(key)
        if name is None:
            return False
        field_info = type(self).model_fields[name]
        if field_info.default_factory is not None:
            # Build a fresh default to compare against.
            # NOTE(review): factories that expect validated data as an
            # argument are not supported here - confirm none are used.
            default = field_info.default_factory()  # type: ignore
        elif field_info.default is not PydanticUndefined:
            default = field_info.default
        else:
            # Required field: nothing to be equal to.
            return False
        return getattr(self, name) == default

    @staticmethod
    def _is_empty(value: Any) -> bool:
        """Return True for ``None`` and for an empty ``list`` or ``dict``.

        Checked by type rather than with ``value in (None, [], {})``: the
        membership test calls ``==`` on arbitrary values, and for numpy
        arrays / astropy Quantities that yields an array whose truth value
        is ambiguous (or raises outright).
        """
        if value is None:
            return True
        return isinstance(value, (list, dict)) and not value

    def _exclude_default_nulls_and_empty(
        self, dumped: dict[str, Any]
    ) -> dict[str, Any]:
        """To avoid cluttering output, omit ``None``, ``[]`` and ``{}`` values
        whose field still equals its default.

        An empty value that differs from the field's default (e.g. ``[]`` on a
        field defaulting to ``None``) is kept, so deliberately-set empties
        survive. An empty value that merely equals the default is dropped,
        even if a caller assigned it explicitly.
        """
        return {
            key: val
            for key, val in dumped.items()
            if not (self._is_empty(val) and self._is_default(key))
        }

    @model_serializer(mode="wrap")
    def _serialize(
        self, default_serializer: SerializerFunctionWrapHandler
    ) -> dict[str, Any]:
        """Run pydantic's default serialization, then drop default empties."""
        dumped = default_serializer(self)
        return self._exclude_default_nulls_and_empty(dumped)

    @classmethod
    def _values_equal(cls, left: Any, right: Any) -> bool:
        """Recursively compare two dumped values.

        dicts must have the same keys and pairwise-equal values; lists and
        tuples (interchangeably) must have equal length and pairwise-equal
        items; anything else falls back to ``==``.
        """
        # Astropy quantities can be scalars or vectors. Quantity ``==`` is
        # element-wise and unit-aware (incompatible units compare unequal),
        # so reduce vector results with ``all()``. Shapes are checked first
        # because comparing non-broadcastable arrays raises.
        if isinstance(left, Quantity) and isinstance(right, Quantity):
            if left.shape != right.shape:
                return False
            result = left == right
            return bool(result.all()) if hasattr(result, "all") else bool(result)
        if isinstance(left, dict) and isinstance(right, dict):
            if left.keys() != right.keys():
                return False
            return all(cls._values_equal(left[k], right[k]) for k in left)
        if isinstance(left, (list, tuple)) and isinstance(right, (list, tuple)):
            return len(left) == len(right) and all(
                cls._values_equal(li, ri) for li, ri in zip(left, right)
            )
        return left == right

    def __eq__(self, other: Any) -> bool:
        """Equal when ``other`` is an instance of this class with equal dumps.

        Both sides are dumped with ``mode="python"`` (so default empties are
        omitted by ``_serialize``) and compared with ``_values_equal``.
        """
        if not isinstance(other, self.__class__):
            return False
        return self._values_equal(
            self.model_dump(mode="python"),
            other.model_dump(mode="python"),
        )