class ClusterPhase(PhaseLike):
"""The default Spyral clustering phase, inheriting from PhaseLike
The goal of the clustering phase is to take in a point cloud
and separate the points into individual particle trajectories. In
the default version here, we use scikit-learn's HDBSCAN clustering
algorithm. The clustering phase should come after the Pointcloud/PointcloudLegacy
Phase in the Pipeline and before the EstimationPhase.
Parameters
----------
cluster_params: ClusterParameters
Parameters controlling the clustering algorithm
det_params: DetectorParameters
Parameters describing the detector
Attributes
----------
cluster_params: ClusterParameters
Parameters controlling the clustering algorithm
det_params: DetectorParameters
Parameters describing the detector
"""
def __init__(
self, cluster_params: ClusterParameters, det_params: DetectorParameters
) -> None:
super().__init__(
"Cluster",
incoming_schema=ResultSchema(json_string=POINTCLOUD_SCHEMA),
outgoing_schema=ResultSchema(json_string=CLUSTER_SCHEMA),
)
self.cluster_params = cluster_params
self.det_params = det_params

    def create_assets(self, workspace_path: Path) -> bool:
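        """Create any phase-specific assets in the workspace

        The clustering phase does not produce any extra assets, so this
        simply reports success.

        Parameters
        ----------
        workspace_path: Path
            The path to the workspace

        Returns
        -------
        bool
            Always True
        """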
return True

    def construct_artifact(
self, payload: PhaseResult, workspace_path: Path
) -> PhaseResult:
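        """Construct the artifact (output) of this phase

        The clustering phase writes a single HDF5 file per run, named by
        the run number, in this phase's artifact directory.

        Parameters
        ----------
        payload: PhaseResult
            The result of the previous phase
        workspace_path: Path
            The path to the workspace

        Returns
        -------
        PhaseResult
            The result of this phase, whose artifact is the cluster HDF5 file
        """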
result = PhaseResult(
artifacts={
"cluster": self.get_artifact_path(workspace_path)
/ f"{form_run_string(payload.run_number)}.h5"
},
successful=True,
run_number=payload.run_number,
)
return result

    def run(
self,
payload: PhaseResult,
workspace_path: Path,
msg_queue: SimpleQueue,
rng: Generator,
) -> PhaseResult:
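        """Run the clustering phase for a single run of data

        Reads the point cloud HDF5 file from the previous phase, clusters
        each event, and writes the resulting clusters to this phase's
        HDF5 file.

        Parameters
        ----------
        payload: PhaseResult
            The result of the previous (point cloud) phase
        workspace_path: Path
            The path to the workspace
        msg_queue: SimpleQueue
            The queue used to send progress messages to the parent process
        rng: Generator
            A random number generator (not used by this phase)

        Returns
        -------
        PhaseResult
            The result of this phase
        """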
# Check that point clouds exist
point_path = payload.artifacts["pointcloud"]
if not point_path.exists() or not payload.successful:
spyral_warn(
__name__,
f"Point cloud data does not exist for run {payload.run_number} at phase 2. Skipping.",
)
return PhaseResult.invalid_result(payload.run_number)
result = self.construct_artifact(payload, workspace_path)
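        # Open the input point cloud file and create the output cluster file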
point_file = h5.File(point_path, "r")
cluster_file = h5.File(result.artifacts["cluster"], "w")
cloud_group: h5.Group = point_file["cloud"] # type: ignore
if not isinstance(cloud_group, h5.Group):
spyral_error(
__name__, f"Point cloud group not present in run {payload.run_number}!"
)
return PhaseResult.invalid_result(payload.run_number)
min_event: int = cloud_group.attrs["min_event"] # type: ignore
max_event: int = cloud_group.attrs["max_event"] # type: ignore
cluster_group: h5.Group = cluster_file.create_group("cluster")
cluster_group.attrs["min_event"] = min_event
cluster_group.attrs["max_event"] = max_event
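        # Progress reporting: for small runs send a status message for every
        # event, otherwise send one roughly every 1% of events so the message
        # queue is not flooded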
nevents = max_event - min_event + 1
total: int
flush_val: int
if nevents < 100:
total = nevents
flush_val = 0
else:
flush_percent = 0.01
flush_val = int(flush_percent * (max_event - min_event))
total = 100
count = 0
msg = StatusMessage(
self.name, 1, total, payload.run_number
) # we always increment by 1
# Process the data
for idx in range(min_event, max_event + 1):
count += 1
if count > flush_val:
count = 0
msg_queue.put(msg)
            cloud_name = f"cloud_{idx}"
            if cloud_name not in cloud_group:
                continue
            cloud_data: h5.Dataset = cloud_group[cloud_name]  # type: ignore
cloud = PointCloud(idx, cloud_data[:].copy())
if np.any(np.diff(cloud.data[:, 2]) < 0.0):
spyral_warn(
__name__,
f"Clustering for event {cloud.event_number} failed because point cloud was not sorted in z",
)
continue
# Here we don't need to use the labels array.
# We just pass it along as needed.
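            # Clustering proceeds in three steps: form raw clusters, join
            # clusters that appear to be pieces of the same trajectory, and
            # clean up the joined clusters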
clusters, labels = form_clusters(cloud, self.cluster_params)
joined, labels = join_clusters(clusters, self.cluster_params, labels)
cleaned, _ = cleanup_clusters(joined, self.cluster_params, labels)
# Each event can contain many clusters
cluster_event_group = cluster_group.create_group(f"event_{idx}")
cluster_event_group.attrs["nclusters"] = len(cleaned)
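            # Copy the point cloud metadata (original run/event identifiers and
            # ion chamber information) through to the cluster output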
cluster_event_group.attrs["orig_run"] = cloud_data.attrs["orig_run"]
cluster_event_group.attrs["orig_event"] = cloud_data.attrs["orig_event"]
cluster_event_group.attrs["ic_amplitude"] = cloud_data.attrs["ic_amplitude"]
cluster_event_group.attrs["ic_centroid"] = cloud_data.attrs["ic_centroid"]
cluster_event_group.attrs["ic_integral"] = cloud_data.attrs["ic_integral"]
cluster_event_group.attrs["ic_multiplicity"] = cloud_data.attrs[
"ic_multiplicity"
]
for cidx, cluster in enumerate(cleaned):
local_group = cluster_event_group.create_group(f"cluster_{cidx}")
local_group.attrs["label"] = cluster.label
local_group.create_dataset("cloud", data=cluster.data)
        # Close the files to ensure the cluster data is flushed to disk
        point_file.close()
        cluster_file.close()
        spyral_info(__name__, f"Phase Cluster complete for run {payload.run_number}")
        return result
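
# A minimal usage sketch (illustrative only, not part of this module): the
# phase is built from its two parameter sets and placed in the Pipeline after
# the point cloud phase and before the estimation phase. The "..." arguments
# stand in for configuration values described in the Spyral documentation;
# the exact constructor signatures of the other phases are assumptions here.
#
#   phases = [
#       PointcloudPhase(...),
#       ClusterPhase(ClusterParameters(...), DetectorParameters(...)),
#       EstimationPhase(...),
#   ]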