"""
The ska_oso_pdm.sb_definition.scan_definition module defines a simple
Python representation of a single observation scan.
"""
# Public API: `from ska_oso_pdm.sb_definition.scan_definition import *`
# exports only ScanDefinition. The enums and Annotated types defined in this
# module can still be imported explicitly by name.
__all__ = ["ScanDefinition"]
import warnings
from datetime import timedelta
from typing import Annotated, Any, Type
from astropy import units
from pydantic import AfterValidator, Field, WithJsonSchema, model_validator
from typing_extensions import Self
from ska_oso_pdm._shared import (
AstropyQuantity,
AstropyUnit,
CSPConfigurationID,
DishAllocationID,
MCCSAllocationID,
PdmObject,
ScanDefinitionID,
TargetID,
TerseStrEnum,
TimedeltaMs,
UnitHelpers,
)
from ska_oso_pdm._shared.custom_types import Quantity
class ScanDurationUnits(TerseStrEnum):
    """
    Units for scan durations.

    Each member's value is an astropy unit string ("s", "min", "h") and is
    passed directly to ``astropy.units`` when building scan duration
    quantities.
    """

    # seconds
    SECONDS = "s"
    # minutes
    MINUTES = "min"
    # hours
    HOURS = "h"
# Annotated astropy unit type intended to be limited to ScanDurationUnits.
# After validation, UnitHelpers.constrain_unit_to(ScanDurationUnits) is applied
# (presumably rejecting any unit not listed in the enum; the helper lives in
# ska_oso_pdm._shared and is not visible here), and the JSON schema is taken
# from UnitHelpers.enum_jsonschema(ScanDurationUnits).
ScanDurationUnitType = Annotated[
    AstropyUnit,
    AfterValidator(UnitHelpers.constrain_unit_to(ScanDurationUnits)),
    WithJsonSchema(UnitHelpers.enum_jsonschema(ScanDurationUnits)),
]
class ScanDurationQuantity(Quantity):
    """
    Quantity metadata for scan durations, with seconds as its unit.

    Used as ``Annotated`` metadata in ScanDurationQuantityType.
    NOTE(review): how the ``unit`` class attribute is consumed is defined by
    ``Quantity`` in ska_oso_pdm._shared.custom_types (not visible here);
    presumably it is the default/reference unit - confirm there.
    """

    unit = units.Unit(ScanDurationUnits.SECONDS)
# Annotated astropy Quantity type used for ScanDefinition.scan_duration.
# ScanDurationQuantity is attached as metadata, and after validation
# UnitHelpers.constrain_unit_to(ScanDurationUnits) is applied (presumably
# rejecting quantities whose unit is not s, min or h - helper not visible here).
# NOTE(review): unlike ScanDurationUnitType, no WithJsonSchema is attached; the
# schema presumably comes from AstropyQuantity/ScanDurationQuantity - confirm.
ScanDurationQuantityType = Annotated[
    AstropyQuantity,
    ScanDurationQuantity,
    AfterValidator(UnitHelpers.constrain_unit_to(ScanDurationUnits)),
]
class PointingCorrection(TerseStrEnum):
    """
    Operation to apply to the pointing correction model.

    MAINTAIN: continue applying the current pointing correction model
    UPDATE: wait for (if necessary) and apply new pointing calibration solution
    RESET: reset the applied pointing correction to the pointing model defaults

    ScanDefinition uses MAINTAIN as its default.
    """

    # Member values equal their names.
    MAINTAIN = "MAINTAIN"
    UPDATE = "UPDATE"
    RESET = "RESET"
def _to_scan_duration_quantity(duration: Any) -> units.Quantity:
    """
    Convert a legacy ``scan_duration_ms`` value to a ``scan_duration`` Quantity.

    :param duration: a :class:`~datetime.timedelta`, or a bare number which is
        interpreted as milliseconds
    :return: the same duration as an astropy Quantity in seconds
    """
    if not isinstance(duration, timedelta):
        duration = timedelta(milliseconds=duration)
    return units.Quantity(
        value=duration.total_seconds(), unit=ScanDurationUnits.SECONDS
    )


def _to_scan_duration_timedelta(duration: Any) -> timedelta:
    """
    Convert a ``scan_duration`` value to a legacy ``scan_duration_ms`` timedelta.

    :param duration: an astropy Quantity, a ``{"value": ..., "unit": ...}``
        dict (e.g. deserialised JSON), or anything ``units.Quantity`` accepts
    :return: the same duration as a :class:`~datetime.timedelta`
    :raises astropy.units.UnitConversionError: if the duration cannot be
        converted to seconds
    """
    if isinstance(duration, dict):
        duration = units.Quantity(
            value=duration.get("value"), unit=duration.get("unit")
        )
    elif not isinstance(duration, units.Quantity):
        duration = units.Quantity(duration)
    return timedelta(seconds=duration.to(units.Unit("s")).value)


class ScanDefinition(PdmObject):
    """
    ScanDefinition represents the instrument configuration for a single scan.

    The scan duration is held in two fields while clients migrate from the
    deprecated ``scan_duration_ms`` to ``scan_duration``. Whichever one is
    missing at validation time is derived from the other. Assigning a
    non-None value to either field on an instance also updates the other.

    :param scan_definition_id: the unique ID for this scan definition
    :param target_ref: ID of the target to observe
    :param mccs_allocation_ref: ID of the MCCS allocation
    :param dish_allocation_ref: SKA MID dish allocation ID used during this scan
    :param csp_configuration_ref: SKA MID Central Signal Processor configuration ID
    :param scan_intent: intent of the scan, as a string
    :param pointing_correction: operation to apply to the pointing correction
        model; defaults to ``PointingCorrection.MAINTAIN``
    :param scan_duration_ms: deprecated - scan duration as a timedelta (a bare
        number is read as milliseconds); use ``scan_duration`` instead
    :param scan_duration: scan duration as an astropy Quantity in s, min or h;
        defaults to 0 s
    """

    scan_definition_id: ScanDefinitionID | None = None
    target_ref: TargetID | None = None
    mccs_allocation_ref: MCCSAllocationID | None = None
    dish_allocation_ref: DishAllocationID | None = None
    csp_configuration_ref: CSPConfigurationID | None = None
    scan_intent: str | None = None
    pointing_correction: PointingCorrection = PointingCorrection.MAINTAIN
    # Optional for forwards compatibility: this field is being replaced by
    # scan_duration, so documents carrying only scan_duration must validate.
    scan_duration_ms: TimedeltaMs | None = Field(default=None, deprecated=True)
    scan_duration: ScanDurationQuantityType | None = units.Quantity(
        value=0, unit=ScanDurationUnits.SECONDS
    )

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Assign an attribute, keeping the two scan duration fields in step.

        Assigning a non-None value to ``scan_duration_ms`` or ``scan_duration``
        also updates the other field. Assigning None leaves the other field
        unchanged.
        """
        # Let pydantic do the assignment (and any assignment validation) first,
        # so the sibling field is derived from the value actually stored.
        super().__setattr__(name, value)
        if value is None or name not in ("scan_duration_ms", "scan_duration"):
            return
        # Read from __dict__ directly: attribute access on the deprecated
        # scan_duration_ms field would emit a DeprecationWarning.
        stored = self.__dict__.get(name, value)
        if stored is None:
            return
        # Use the parent implementation for the sibling so this override does
        # not recurse back into itself.
        if name == "scan_duration_ms":
            super().__setattr__(
                "scan_duration", _to_scan_duration_quantity(stored)
            )
        else:
            super().__setattr__(
                "scan_duration_ms", _to_scan_duration_timedelta(stored)
            )

    # NOTE(review): the original comment says this runs after a before-validator
    # defined on the base class (PdmObject, not visible here) - confirm there.
    @model_validator(mode="before")
    @classmethod
    def _handle_scan_duration_ms_to_scan_duration_transition(
        cls: Type[Self], data: Any
    ) -> Any:
        """
        Derive whichever scan duration field is missing from the one supplied.

        If only ``scan_duration_ms`` is given, ``scan_duration`` is computed
        from it, and vice versa. If both or neither are given the data is
        returned unchanged (the two are not cross-checked).

        :param data: raw model input; only dicts are inspected
        :return: the input, or a copy of it with the missing field filled in
        """
        if not isinstance(data, dict):
            return data
        scan_duration_ms = data.get("scan_duration_ms")
        scan_duration = data.get("scan_duration")
        # Build a new dict rather than mutating ``data``: it may be the
        # caller's own object (e.g. one passed to model_validate()).
        if scan_duration is None and scan_duration_ms is not None:
            data = {
                **data,
                "scan_duration": _to_scan_duration_quantity(scan_duration_ms),
            }
        elif scan_duration_ms is None and scan_duration is not None:
            data = {
                **data,
                "scan_duration_ms": _to_scan_duration_timedelta(scan_duration),
            }
        return data

    @model_validator(mode="after")
    def _check_scan_duration_populated(self) -> Self:
        """
        Ensure both scan duration fields are set once validation completes.

        NOTE(review): because scan_duration defaults to 0 s, input containing
        neither duration reaches here with scan_duration_ms still None and is
        rejected (unless the base class fills it in). Confirm whether the 0 s
        default is meant to propagate to scan_duration_ms.

        :raises ValueError: if ``scan_duration`` or ``scan_duration_ms`` is None
        """
        if self.scan_duration is None:
            raise ValueError("'scan_duration' could not be determined")
        # scan_duration_ms is a deprecated field; reading it would otherwise
        # emit a DeprecationWarning.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            if self.scan_duration_ms is None:
                raise ValueError("'scan_duration_ms' could not be determined")
        return self