Source code for katgpucbf.vgpu.engine

################################################################################
# Copyright (c) 2025-2026, National Research Foundation (SARAO)
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use
# this file except in compliance with the License. You may obtain a copy
# of the License at
#
#   https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################

"""Engine class, which does all the actual processing."""

import asyncio
import logging
from collections.abc import AsyncIterator, Sequence
from dataclasses import dataclass
from fractions import Fraction

import aiokatcp
import cupy as cp
import cupyx
import katcbf_vlbi_resample.cupy_bridge
import katcbf_vlbi_resample.parameters
import katcbf_vlbi_resample.polarisation
import katcbf_vlbi_resample.power
import katcbf_vlbi_resample.rechunk
import katcbf_vlbi_resample.resample
import katcbf_vlbi_resample.stream
import katcbf_vlbi_resample.utils
import katcbf_vlbi_resample.vdif_writer
import numpy as np
import spead2.recv.asyncio
import xarray as xr
from astropy.time import Time

from .. import COMPLEX, N_POLS
from .. import recv as base_recv
from ..monitor import Monitor
from ..recv import RECV_SENSOR_TIMEOUT_CHUNKS, RECV_SENSOR_TIMEOUT_MIN
from ..ringbuffer import ChunkRingbuffer
from ..utils import Engine, TimeConverter
from . import N_SIDEBANDS, recv, send

logger = logging.getLogger(__name__)


[docs] class RecvStream: """Wrap the incoming data stream into a :class:`.katcbf_vlbi_resample.stream.Stream`.""" def __init__( self, layout: recv.Layout, time_converter: TimeConverter, stream_group: spead2.recv.ChunkStreamRingGroup, sensors: aiokatcp.SensorSet, pol_labels: tuple[str, str], min_timestamp: int, ) -> None: self._layout = layout self._time_converter = time_converter self._stream_group = stream_group self._sensors = sensors self._pol_labels = pol_labels self._samples_between_spectra = layout.heap_timestamp_step // layout.n_spectra_per_heap self._min_timestamp = min_timestamp # Properties required by the Stream protocol self.channels = layout.n_channels self.is_cupy = True self.time_base = Time(time_converter.sync_time, scale="utc", format="unix") self.time_scale = Fraction(self._samples_between_spectra) / Fraction(time_converter.adc_sample_rate) async def __aiter__(self) -> AsyncIterator[xr.DataArray]: for stream in self._stream_group: stream.start() data_ringbuffer = self._stream_group.data_ringbuffer assert isinstance(data_ringbuffer, spead2.recv.asyncio.ChunkRingbuffer) last_chunk_id: int | None = None async for chunk in recv.iter_chunks( data_ringbuffer, self._layout, self._sensors, self._time_converter, [label[-1] for label in self._pol_labels], ): with chunk: if chunk.timestamp < self._min_timestamp: continue # TODO: need to do something with the presence flags # TODO: pipeline these transfers (but keeping in mind # that we need to recycle the chunk only when the transfer # is complete). data = cp.asarray(chunk.data, blocking=False) await katcbf_vlbi_resample.utils.stream_future(None) # There are two time axes. Transpose to place them together, then flatten # over them. The current shape is # (N_POLS, n_batches_per_chunk, channels, n_spectra_per_heap, COMPLEX) data = data.transpose(0, 1, 3, 2, 4) # Now it is # (N_POLS, n_batches_per_chunk, n_spectra_per_heap, channels, COMPLEX) data = data.reshape(N_POLS, -1, self.channels, COMPLEX) # Now it is # (N_POLS, n_spectra_per_chunk, channels, COMPLEX) # Convert Gaussian integers to complex data = cp.ascontiguousarray(data.astype(np.float32)).view(np.complex64)[..., 0] arr = xr.DataArray( data, dims=("pol", "time", "channel"), coords={"pol": list(self._pol_labels)}, attrs={"time_bias": chunk.timestamp // self._samples_between_spectra}, ) # TODO (NGC-1689): need to properly handle missing data in # katcbf-vlbi-resample. This is a quick hack to keep things # running by injecting zero data into the stream. while last_chunk_id is not None and last_chunk_id < chunk.chunk_id - 1: last_chunk_id += 1 zero_arr = xr.zeros_like(arr) timestamp = last_chunk_id * self._layout.chunk_timestamp_step zero_arr.attrs["time_bias"] = timestamp // self._samples_between_spectra yield zero_arr last_chunk_id = chunk.chunk_id yield arr
[docs] class RecordPower(katcbf_vlbi_resample.power.RecordPower): """Record power levels to sensors.""" def __init__(self, *args, sensors: aiokatcp.SensorSet, **kwargs) -> None: super().__init__(*args, **kwargs) self.sensors = sensors
[docs] def record_rms(self, start: int, length: int, rms: xr.DataArray) -> None: # noqa: D102 end = start + length end_time = self.time_base + katcbf_vlbi_resample.utils.fraction_to_time_delta(end * self.time_scale) end_time_unix = float(end_time.unix) power = rms**2 for pol in power.coords["pol"].values: for sideband in power.coords["sideband"].values: channel = ["lsb", "usb"].index(sideband) sensor = self.sensors[f"{pol}{channel}.mean-power"] sensor.set_value(power.sel(pol=pol, sideband=sideband).item(), timestamp=end_time_unix)
[docs] @dataclass class RecvConfig: """Container for all the configuration for receiving data.""" sync_time: float adc_sample_rate: float n_channels: int n_channels_per_substream: int n_spectra_per_heap: int n_samples_between_spectra: int n_batches_per_chunk: int sample_bits: int srcs: list[list[tuple[str, int]]] interface: str | None ibv: bool affinity: int comp_vector: int buffer_size: int pols: tuple[str, str] @property def pol_labels(self) -> list[str]: """Incoming polarisations without any ± prefix.""" return [pol[-1] for pol in self.pols] def __post_init__(self) -> None: self.layout = recv.Layout( sample_bits=self.sample_bits, n_channels=self.n_channels, n_channels_per_substream=self.n_channels_per_substream, n_spectra_per_heap=self.n_spectra_per_heap, n_batches_per_chunk=self.n_batches_per_chunk, heap_timestamp_step=self.n_samples_between_spectra * self.n_spectra_per_heap, ) self.time_converter = TimeConverter(self.sync_time, self.adc_sample_rate)
[docs] @dataclass class SendConfig: """Container for all the configuration for sending data.""" pols: tuple[str, str] bandwidth: float n_samples_per_frame: int rate_factor: float station: str dsts: list[tuple[str, int]] interfaces: list[str] buffer_size: int ttl: int
[docs] @dataclass class CaptureConfig: """Container for all the configuration needed to run a capture session.""" recv_config: RecvConfig send_config: SendConfig fir_taps: int hilbert_taps: int passband: float threshold: float power_int_time: int def __post_init__(self) -> None: self.pol_matrix = katcbf_vlbi_resample.polarisation.from_linear(self.send_config.pols) self.pol_matrix @= katcbf_vlbi_resample.polarisation.to_linear(self.recv_config.pols) self.resample_parameters = katcbf_vlbi_resample.parameters.ResampleParameters( fir_taps=self.fir_taps, hilbert_taps=self.hilbert_taps, passband=self.passband, ) self.threads = [ {"sideband": sideband, "pol": pol} for sideband in ["lsb", "usb"] for pol in self.send_config.pols ]
class _CaptureSession: """Manage the lifetime of actions between ``?capture-start`` and ``?capture-stop``.""" def __init__( self, config: CaptureConfig, engine: Engine, monitor: Monitor, min_timestamp: int, sender: send.VDIFSender ) -> None: recv_chunks = 4 # TODO: may need tuning? data_ringbuffer = ChunkRingbuffer(recv_chunks, name="recv_data_ringbuffer", task_name="run", monitor=monitor) free_ringbuffer = spead2.recv.ChunkRingbuffer(recv_chunks) layout = config.recv_config.layout dtype = np.dtype(f"int{layout.sample_bits}") recv_group = recv.make_stream_group( layout, data_ringbuffer, free_ringbuffer, config.recv_config.affinity, config.recv_config.pol_labels ) for _ in range(recv_chunks): chunk = recv.Chunk( present=np.empty( (N_POLS, layout.n_batches_per_chunk, layout.n_pol_substreams), np.uint8, ), data=cupyx.empty_pinned( (N_POLS, layout.n_batches_per_chunk, layout.n_channels, layout.n_spectra_per_heap, COMPLEX), dtype, ), sink=recv_group, ) chunk.recycle() # Make available to the stream for i, stream in enumerate(recv_group): base_recv.add_reader( stream, src=config.recv_config.srcs[i], interface=config.recv_config.interface, ibv=config.recv_config.ibv, comp_vector=config.recv_config.comp_vector, buffer_size=config.recv_config.buffer_size // len(recv_group), ) self.config = config self._recv_group = recv_group self._sensors = engine.sensors self._min_timestamp = min_timestamp self._sender = sender self._capture_task = asyncio.create_task(self._capture(), name="Capture Loop") engine.add_service_task(self._capture_task, wait_on_stop=True) def _capture_complete(self) -> None: """Handle the end of all processing. This method exists only to mock from unit tests. .. todo:: Remove this method once no longer needed by unit tests. """ pass async def _capture(self) -> None: """Do all the primary work of the engine. This is an asyncio task that runs as a service task of the device server. """ # Copy some references just to make the code shorter config = self.config recv_config = config.recv_config send_config = config.send_config it: katcbf_vlbi_resample.stream.Stream[xr.DataArray] = RecvStream( recv_config.layout, recv_config.time_converter, self._recv_group, self._sensors, recv_config.pols, self._min_timestamp, ) it = katcbf_vlbi_resample.cupy_bridge.AsCupy(it) it = katcbf_vlbi_resample.resample.IFFT(it) it = katcbf_vlbi_resample.polarisation.ConvertPolarisation( it, config.pol_matrix, recv_config.pols, send_config.pols ) it = katcbf_vlbi_resample.resample.Resample(send_config.bandwidth, 0.0, config.resample_parameters, it) it = katcbf_vlbi_resample.rechunk.Rechunk.align_utc_seconds(it) it_rms: katcbf_vlbi_resample.stream.Stream[xr.Dataset] = katcbf_vlbi_resample.power.MeasurePower(it) it_rms = RecordPower(it_rms, sensors=self._sensors) it = katcbf_vlbi_resample.power.NormalisePower(it_rms, 1.0) it = katcbf_vlbi_resample.vdif_writer.VDIFEncode2Bit( it, samples_per_frame=send_config.n_samples_per_frame, threshold=config.threshold ) it = katcbf_vlbi_resample.cupy_bridge.AsNumpy(it) frameset_it = katcbf_vlbi_resample.vdif_writer.VDIFFormatter( it, config.threads, station=send_config.station, samples_per_frame=send_config.n_samples_per_frame ) async for frameset in frameset_it: await self._sender.send(frameset) self._capture_complete() async def stop(self) -> None: """Stop the capture.""" self._recv_group.stop() await self._capture_task
[docs] class VEngine(Engine): """Top-level class running the whole thing.""" VERSION = "katgpucbf-vgpu-1.0" def __init__( self, *, katcp_host: str, katcp_port: int, config: CaptureConfig, monitor: Monitor, ) -> None: super().__init__(katcp_host, katcp_port) self.config = config self.monitor = monitor recv_config = config.recv_config send_config = config.send_config recv_sensor_timeout = max( RECV_SENSOR_TIMEOUT_MIN, RECV_SENSOR_TIMEOUT_CHUNKS * recv_config.layout.chunk_timestamp_step / recv_config.adc_sample_rate, ) self._populate_sensors(self.sensors, recv_config.pol_labels, send_config.pols, recv_sensor_timeout) self._capture: _CaptureSession | None = None send_rate = send_config.bandwidth * send_config.rate_factor # Data comes out of the processing chain in chunks of size # power_int_time. We need to smooth that out, so we use a send # queue that is deeper than that (2 is the number of chunks to # buffer). queue_size = round(2 * config.power_int_time * send_config.bandwidth / send_config.n_samples_per_frame) self._sender = send.VDIFSender( send_config.dsts, send_rate, send_rate * 2.0, # Python can introduce large pauses, so catch up aggressively queue_size, ttl=send_config.ttl, buffer_size=send_config.buffer_size, interfaces=send_config.interfaces, ) # Reference counters to make the labels exist before the first scrape for pol in recv_config.pol_labels: recv.counters.labels(str(pol)) def _populate_sensors( self, sensors: aiokatcp.SensorSet, recv_pol_labels: Sequence[str], send_pols: Sequence[str], recv_sensor_timeout: float, ) -> None: """Define the sensors for the engine.""" for pol in send_pols: for channel in range(N_SIDEBANDS): sensors.add( aiokatcp.Sensor( float, f"{pol}{channel}.mean-power", "Mean power over the previous interval of length power-int-time", ) ) sensors.add( aiokatcp.Sensor( float, "delay", "Delay introduced by processing", units="s", default=0.0, initial_status=aiokatcp.Sensor.Status.NOMINAL, ) ) prefixes = [f"{pol}." for pol in recv_pol_labels] for sensor in base_recv.make_sensors(recv_sensor_timeout, prefixes).values(): sensors.add(sensor)
[docs] async def on_stop(self) -> None: # noqa: D102 if self._capture is not None: await self._stop_capture() await self._sender.stop() await super().on_stop()
[docs] async def request_vlbi_delay(self, ctx: aiokatcp.RequestContext, delay: float) -> None: """Set the delay applied to the stream, in second.""" # TODO: will need to be rounded/quantised self.sensors["delay"].value = delay
[docs] async def request_capture_start(self, ctx: aiokatcp.RequestContext, timestamp: int = 0) -> None: """Start capturing and emitting data. Parameters ---------- timestamp Minimum ADC timestamp at which to enable emitting. """ if self._capture is not None: raise aiokatcp.FailReply("a capture is already in progress") # TODO: use delay self._capture = _CaptureSession(self.config, self, self.monitor, timestamp, self._sender)
async def _stop_capture(self) -> None: assert self._capture is not None try: await self._capture.stop() finally: self._capture = None
[docs] async def request_capture_stop(self, ctx: aiokatcp.RequestContext) -> None: """Stop capturing and emitting data.""" if self._capture is None: raise aiokatcp.FailReply("no capture in progress") await self._stop_capture()