estimation_phase Module

The default definition of an estimation phase in Spyral

EstimationPhase

Bases: PhaseLike

The default Spyral estimation phase, inheriting from PhaseLike

The goal of the estimation phase is to get reasonable estimates of the physical properties of a particle trajectory (Bρ, reaction angle, etc.) for use in the more complex solving phase to follow. EstimationPhase should come after ClusterPhase and before InterpSolverPhase in the Pipeline.

Parameters:

Name             Type                Description                                       Default
estimate_params  EstimateParameters  Parameters controlling the estimation algorithm   required
det_params       DetectorParameters  Parameters describing the detector                required

Attributes:

Name             Type                Description
estimate_params  EstimateParameters  Parameters controlling the estimation algorithm
det_params       DetectorParameters  Parameters describing the detector
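
A schematic of constructing the phase and where it sits relative to its neighbors. The import path mirrors the source location shown below; estimate_params and det_params stand for configuration objects built elsewhere, and the neighboring phases and the final Pipeline call are placeholders, not exact signatures.

from spyral.phases.estimation_phase import EstimationPhase

# estimate_params (EstimateParameters) and det_params (DetectorParameters)
# are assumed to have been constructed from your configuration already.
estimation = EstimationPhase(estimate_params, det_params)

# Expected ordering when the phases are handed to the Pipeline:
#   ... -> ClusterPhase -> EstimationPhase -> InterpSolverPhase -> ...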

Source code in src/spyral/phases/estimation_phase.py
class EstimationPhase(PhaseLike):
    """The default Spyral estimation phase, inheriting from PhaseLike

    The goal of the estimation phase is to get reasonable estimations of
    the physical properties of a particle trajectory (Bρ , reaction angle, etc.)
    for use in the more complex solving phase to follow. EstimationPhase should come
    after ClusterPhase and before InterpSolverPhase in the Pipeline.

    Parameters
    ----------
    estimate_params: EstimateParameters
        Parameters controlling the estimation algorithm
    det_params: DetectorParameters
        Parameters describing the detector

    Attributes
    ----------
    estimate_params: EstimateParameters
        Parameters controlling the estimation algorithm
    det_params: DetectorParameters
        Parameters describing the detector

    """

    def __init__(
        self, estimate_params: EstimateParameters, det_params: DetectorParameters
    ):
        super().__init__(
            "Estimation",
            incoming_schema=ResultSchema(CLUSTER_SCHEMA),
            outgoing_schema=ResultSchema(ESTIMATE_SCHEMA),
        )
        self.estimate_params = estimate_params
        self.det_params = det_params

    def create_assets(self, workspace_path: Path) -> bool:
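        # Estimation needs no pre-generated assets, so there is nothing to create here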
        return True

    def construct_artifact(
        self, payload: PhaseResult, workspace_path: Path
    ) -> PhaseResult:
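        # The estimation output is a per-run parquet file; the incoming cluster
        # artifact path is forwarded so later phases can still locate the clusters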
        result = PhaseResult(
            artifacts={
                "estimation": self.get_artifact_path(workspace_path)
                / f"{form_run_string(payload.run_number)}.parquet",
                "cluster": payload.artifacts["cluster"],
            },
            successful=True,
            run_number=payload.run_number,
        )
        return result

    def run(
        self,
        payload: PhaseResult,
        workspace_path: Path,
        msg_queue: SimpleQueue,
        rng: Generator,
    ) -> PhaseResult:
        # Check that clusters exist
        cluster_path = payload.artifacts["cluster"]
        if not cluster_path.exists() or not payload.successful:
            spyral_warn(
                __name__,
                f"Cluster file for run {payload.run_number} not present for phase 3. Skipping.",
            )
            return PhaseResult.invalid_result(payload.run_number)

        result = self.construct_artifact(payload, workspace_path)

        cluster_file = h5.File(cluster_path, "r")
        cluster_group: h5.Group = cluster_file["cluster"]  # type: ignore
        if not isinstance(cluster_group, h5.Group):
            spyral_error(
                __name__, f"Cluster group not present for run {payload.run_number}!"
            )
            return PhaseResult.invalid_result(payload.run_number)

        min_event: int = cluster_group.attrs["min_event"]  # type: ignore
        max_event: int = cluster_group.attrs["max_event"]  # type: ignore

        nevents = max_event - min_event + 1
        total: int
        flush_val: int
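        # Progress cadence: report every event for small runs, otherwise
        # report roughly once per 1% of the event range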
        if nevents < 100:
            total = nevents
            flush_val = 0
        else:
            flush_percent = 0.01
            flush_val = int(flush_percent * (max_event - min_event))
            total = 100

        count = 0

        # estimation results, list of dicts
        results = []

        msg = StatusMessage(
            self.name, 1, total, payload.run_number
        )  # We always increment by 1
        # Process data
        for idx in range(min_event, max_event + 1):
            count += 1
            if count > flush_val:
                count = 0
                msg_queue.put(msg)

            event: h5.Group | None = None
            event_name = f"event_{idx}"
            if event_name not in cluster_group:
                continue
            else:
                event = cluster_group[event_name]  # type: ignore

            nclusters: int = event.attrs["nclusters"]  # type: ignore
            ic_amp = float(event.attrs["ic_amplitude"])  # type: ignore
            ic_cent = float(event.attrs["ic_centroid"])  # type: ignore
            ic_int = float(event.attrs["ic_integral"])  # type: ignore
            ic_mult = float(event.attrs["ic_multiplicity"])  # type: ignore
            orig_run = int(event.attrs["orig_run"])  # type: ignore
            orig_event = int(event.attrs["orig_event"])  # type: ignore
            # Go through every cluster in each event
            for cidx in range(0, nclusters):
                local_cluster: h5.Group | None = None
                cluster_name = f"cluster_{cidx}"
                if cluster_name not in event:  # type: ignore
                    continue
                else:
                    local_cluster = event[cluster_name]  # type: ignore

                cluster = Cluster(
                    idx,
                    local_cluster.attrs["label"],  # type: ignore
                    local_cluster["cloud"][:].copy(),  # type: ignore
                )

                # Cluster is loaded do some analysis
                res = estimate_physics(
                    cidx,
                    cluster,
                    ic_amp,
                    ic_cent,
                    ic_int,
                    ic_mult,
                    orig_run,
                    orig_event,
                    self.estimate_params,
                    self.det_params,
                )
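                # A None result means no valid estimate could be made for this
                # cluster; only successful estimates are recorded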
                if res is not None:
                    results.append(vars(res))

        # Write the results to a DataFrame
        df = pl.DataFrame(results)
        df.write_parquet(result.artifacts["estimation"])
        spyral_info(__name__, f"Phase Estimation complete for run {payload.run_number}")
        # Next step also needs to know where to find the clusters
        return result
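
The output artifact is one Parquet file per run, with one row for each cluster that produced a valid estimate. It can be inspected directly with polars; a minimal sketch, using a hypothetical workspace path (the real directory and file names come from get_artifact_path and form_run_string):

from pathlib import Path
import polars as pl

# Hypothetical location for illustration; the actual path is built by
# get_artifact_path and form_run_string for your workspace and run number.
estimate_file = Path("/path/to/workspace/Estimation/run_0001.parquet")

df = pl.read_parquet(estimate_file)
print(df.columns)  # the fields of the estimate_physics result objects
print(df.head())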