# Source code for ska_oso_pdm.sb_definition.scan_definition

"""
The ska_oso_pdm.sb_definition.scan_definition module defines a
simple Python representation of a single observation scan.
"""

# Only ScanDefinition is part of the star-import API; the unit and enum
# helpers below are importable by name but are not re-exported via ``*``.
__all__ = ["ScanDefinition"]

import warnings
from collections.abc import Mapping
from datetime import timedelta
from typing import Annotated, Any, Type

from astropy import units
from pydantic import AfterValidator, Field, WithJsonSchema, model_validator
from typing_extensions import Self

from ska_oso_pdm._shared import (
    AstropyQuantity,
    AstropyUnit,
    CSPConfigurationID,
    DishAllocationID,
    MCCSAllocationID,
    PdmObject,
    ScanDefinitionID,
    TargetID,
    TerseStrEnum,
    TimedeltaMs,
    UnitHelpers,
)
from ska_oso_pdm._shared.custom_types import Quantity


class ScanDurationUnits(TerseStrEnum):
    """
    Units for scan durations.

    Member values are astropy unit strings (they are passed directly to
    ``units.Unit`` / ``units.Quantity`` in this module): seconds ("s"),
    minutes ("min") and hours ("h").
    """

    SECONDS = "s"
    MINUTES = "min"
    HOURS = "h"


# Astropy unit type for scan durations. The AfterValidator presumably restricts
# accepted units to the ScanDurationUnits members (via
# UnitHelpers.constrain_unit_to — confirm there), and WithJsonSchema replaces
# the generated JSON schema with UnitHelpers.enum_jsonschema(ScanDurationUnits).
ScanDurationUnitType = Annotated[
    AstropyUnit,
    AfterValidator(UnitHelpers.constrain_unit_to(ScanDurationUnits)),
    WithJsonSchema(UnitHelpers.enum_jsonschema(ScanDurationUnits)),
]


class ScanDurationQuantity(Quantity):
    """
    Annotation metadata for scan-duration quantities, with ``unit`` set to seconds.

    How ``unit`` is consumed is defined by the ``Quantity`` base class in
    ska_oso_pdm._shared.custom_types (presumably as the canonical/default
    unit — confirm there).
    """

    unit = units.Unit(ScanDurationUnits.SECONDS)


# Astropy Quantity type used for ScanDefinition.scan_duration. The
# ScanDurationQuantity class is attached as annotation metadata (its handling is
# defined by the Quantity base class), and the AfterValidator presumably
# restricts the quantity's unit to the ScanDurationUnits members — confirm in
# UnitHelpers.constrain_unit_to. Unlike ScanDurationUnitType, no WithJsonSchema
# override is applied here.
ScanDurationQuantityType = Annotated[
    AstropyQuantity,
    ScanDurationQuantity,
    AfterValidator(UnitHelpers.constrain_unit_to(ScanDurationUnits)),
]


class PointingCorrection(TerseStrEnum):
    """
    Operation to apply to the pointing correction model.

    MAINTAIN: continue applying the current pointing correction model
    UPDATE: wait for (if necessary) and apply new pointing calibration solution
    RESET: reset the applied pointing correction to the pointing model defaults

    Member values are identical to the member names.
    """

    MAINTAIN = "MAINTAIN"
    UPDATE = "UPDATE"
    RESET = "RESET"


def _as_timedelta(duration_ms: Any) -> timedelta:
    """
    Return a ``scan_duration_ms`` value as a timedelta.

    A timedelta is returned unchanged; any other value is interpreted as a
    number of milliseconds.
    """
    if isinstance(duration_ms, timedelta):
        return duration_ms
    return timedelta(milliseconds=duration_ms)


def _as_duration_quantity(duration: Any) -> units.Quantity:
    """
    Return a ``scan_duration`` value as an astropy Quantity.

    Accepts a Quantity (returned unchanged), a mapping in the serialised
    ``{"value": ..., "unit": ...}`` form, or anything ``units.Quantity``
    itself accepts.
    """
    if isinstance(duration, units.Quantity):
        return duration
    if isinstance(duration, Mapping):
        return units.Quantity(value=duration.get("value"), unit=duration.get("unit"))
    return units.Quantity(duration)


class ScanDefinition(PdmObject):
    """
    ScanDefinition represents the instrument configuration for a single scan.

    The scan duration is held in two fields: ``scan_duration`` (an astropy
    Quantity in s, min or h) and the deprecated ``scan_duration_ms`` (a
    timedelta). Whichever one is supplied is used to populate the other, both
    when the model is validated from a dict of field values and when either
    field is assigned on an existing instance.

    :param scan_definition_id: the unique ID for this scan definition
    :param target_ref: ID of target to observe
    :param mccs_allocation_ref: ID of MCCS Config
    :param dish_allocation_ref: SKA MID dish configuration ID during this scan.
    :param csp_configuration_ref: SKA MID Central Signal Processor ID
    :param scan_intent: optional scan intent string
    :param pointing_correction: operation to apply to the pointing correction model.
    :param scan_duration_ms: scan duration (deprecated; use ``scan_duration``)
    :param scan_duration: scan duration
    """

    scan_definition_id: ScanDefinitionID | None = None
    target_ref: TargetID | None = None
    mccs_allocation_ref: MCCSAllocationID | None = None
    dish_allocation_ref: DishAllocationID | None = None
    csp_configuration_ref: CSPConfigurationID | None = None
    scan_intent: str | None = None
    pointing_correction: PointingCorrection = PointingCorrection.MAINTAIN
    # marked as optional for forwards compatibility as this is to be replaced by scan_duration
    scan_duration_ms: TimedeltaMs | None = Field(default=None, deprecated=True)
    scan_duration: ScanDurationQuantityType | None = units.Quantity(
        value=0, unit=ScanDurationUnits.SECONDS
    )

    # Intercept runtime updates so that updating one duration field updates the other
    def __setattr__(self, name: str, value: Any) -> None:
        """
        Assign ``name`` and keep the two scan duration fields in step.

        Setting ``scan_duration_ms`` (a timedelta or a number of milliseconds)
        also sets ``scan_duration`` in seconds; setting ``scan_duration`` also
        sets ``scan_duration_ms``. Assigning ``None`` to either field leaves
        the other unchanged.
        """
        # Let pydantic/BaseModel do the assignment first to keep it happy
        super().__setattr__(name, value)
        if value is None:
            return

        # The mirrored field is written via the parent implementation so this
        # override is not re-entered, which would otherwise recurse forever.
        if name == "scan_duration_ms":
            duration_seconds = _as_timedelta(value).total_seconds()
            super().__setattr__(
                "scan_duration",
                units.Quantity(
                    value=float(duration_seconds), unit=ScanDurationUnits.SECONDS
                ),
            )
        elif name == "scan_duration":
            duration_ms = _as_duration_quantity(value).to(units.Unit("ms")).value
            super().__setattr__(
                "scan_duration_ms", timedelta(milliseconds=duration_ms)
            )

    # Existing validator for transition (Runs SECOND, after validator in base class)
    @model_validator(mode="before")
    @classmethod
    def _handle_scan_duration_ms_to_scan_duration_transition(
        cls: Type[Self], data: Any
    ) -> Any:
        """
        Fill in whichever duration field is missing from raw input data.

        Only dict input is handled; anything else is returned untouched. When
        exactly one of ``scan_duration`` / ``scan_duration_ms`` is non-None,
        the other is derived from it. A modified copy of the dict is returned
        so the caller's object is never mutated.
        """
        if not isinstance(data, dict):
            return data

        scan_duration_ms = data.get("scan_duration_ms")
        scan_duration = data.get("scan_duration")
        if scan_duration is None and scan_duration_ms is not None:
            duration_seconds = _as_timedelta(scan_duration_ms).total_seconds()
            data = {
                **data,
                "scan_duration": units.Quantity(
                    value=float(duration_seconds),
                    unit=ScanDurationUnits.SECONDS,
                ),
            }
        elif scan_duration_ms is None and scan_duration is not None:
            duration_seconds = (
                _as_duration_quantity(scan_duration).to(units.Unit("s")).value
            )
            data = {**data, "scan_duration_ms": timedelta(seconds=duration_seconds)}
        return data

    # After validator ensures scan_duration is populated
    @model_validator(mode="after")
    def _check_scan_duration_populated(self) -> Self:
        """
        Ensure both duration fields are populated after validation.

        :raises ValueError: if ``scan_duration`` or ``scan_duration_ms`` is None
        """
        if self.scan_duration is None:
            raise ValueError("'scan_duration' could not be determined")
        # NOTE(review): ``scan_duration`` defaults to 0 s but nothing derives
        # ``scan_duration_ms`` from that default, so constructing with neither
        # duration field appears to raise below — confirm this is intended.
        # Reading the deprecated field emits a DeprecationWarning; this internal
        # check should not surface it to users.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            if self.scan_duration_ms is None:
                raise ValueError("'scan_duration_ms' could not be determined")
        return self