################################################################################
# Copyright (c) 2021-2024, National Research Foundation (SARAO)
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use
# this file except in compliance with the License. You may obtain a copy
# of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""Synthesis of simulated signals."""
import asyncio
import dataclasses
import logging
import math
import multiprocessing.connection
import operator
import os
import signal
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from typing import ClassVar, Self
import dask.array as da
import dask.base
import numba
import numpy as np
import pyparsing as pp
import scipy.fft
import xarray as xr
from .. import BYTE_BITS
logger = logging.getLogger(__name__)
#: Dask chunk size for sampling signals (must be a multiple of 8)
CHUNK_SIZE = 2**20
def _sample_helper(n: int, chunk_data: da.Array, sample_chunk: Callable, **kwargs) -> da.Array:
"""Help implement :meth:`sample` in subclasses of :class:`Signal`.
This can be used when each chunk can be generated from a single value
per chunk (plus some chunk-independent state).
Currently the call to `sample_chunk` for the final chunk will use the
full chunk size, unless `n` is less than CHUNK_SIZE. In future it may
take care of truncating the final chunk.
Parameters
----------
n
Number of samples to generate
chunk_data
Array holding per-chunk data. It must have the same number of
chunks as the output.
sample_chunk
Callback to generate a single chunk. It is passed a chunk from
`chunk_data` as a positional argument, and `chunk_size` as a
keyword argument.
kwargs
Additional keyword arguments forwarded to `sample_chunk`.
"""
chunk_size = min(n, CHUNK_SIZE)
n_chunks = (n + chunk_size - 1) // chunk_size
return da.map_blocks(
sample_chunk,
chunk_data,
chunks=((chunk_size,) * n_chunks,),
chunk_size=chunk_size,
**kwargs,
)[:n]
def _sample_helper_random(n: int, seed_seq: np.random.SeedSequence, sample_chunk: Callable, **kwargs) -> da.Array:
"""Generate a 1D dask array of random data.
The random generators are seeded from `seed_seq`. Other parameters are
passed to :func:`_sample_helper`.
"""
n_chunks = (n + CHUNK_SIZE - 1) // CHUNK_SIZE
seed_seqs = seed_seq.spawn(n_chunks)
# Chunk size of 1 so that we can map each SeedSequence to an output chunk
seed_seqs_dask = da.from_array(np.array(seed_seqs, dtype=object), 1)
return _sample_helper(n, seed_seqs_dask, sample_chunk, **kwargs)
def _cw(offset: np.int64, *, amplitude: np.float32, n: int, frequency: float, dtype: np.dtype) -> np.ndarray:
r"""Compute :math:`\cos(2\pi f(offset+i))` for i in :math:`[0, n)`.
The `dtype` is the complex dtype in which to perform calculations.
The return value is in the equivalent real dtype.
"""
# Compute the complex exponential. Because it is being regularly
# sampled, it is possible to do this efficiently by repeated
# doubling. This also makes it possible to keep most of the
# computation in single precision without losing much precision
# (experimentally the results seem to be off by less than 1e-6).
scale = frequency * (2 * np.pi)
out = np.empty(n, dtype)
out[0] = np.exp(offset * scale * 1j) * amplitude
valid = 1
while valid < n:
# Rotate the segment [0, valid) by valid steps, giving the segment
# [valid, 2 * value). It's slightly complicated to handle the case
# where we have to truncate to n.
add = min(valid, n - valid)
rot = np.exp(valid * scale * 1j).astype(dtype)
np.multiply(out[0:add], rot, out[valid : valid + add])
valid += add
return out.real
[docs]
class TerminalError(TypeError):
"""Indicate that a terminal signal has been used in an expression."""
def __init__(self, signal: "Signal") -> None:
self.signal = signal
assert signal.terminal
super().__init__(f"Signal '{signal}' cannot be used in a larger expression")
[docs]
class Signal(ABC):
"""Abstract base class for signals.
An instance is simply a real-valued function of time, for a single
polarisation.
"""
[docs]
@abstractmethod
def sample(self, n: int, sample_rate: float) -> da.Array:
"""Sample the signal at regular intervals.
The returned values should be scaled to the range (-1, 1).
.. note::
Calling this method with two different values of `n` may
yield results that are not consistent with each other.
Parameters
----------
n
Number of samples to generate
sample_rate
Frequency of samples (Hz)
Returns
-------
samples
Dask array of samples, float32. The chunk size must be CHUNK_SIZE.
"""
@property
def terminal(self) -> bool:
"""Indicate whether the signal is terminal.
Terminal signals cannot be combined into larger expressions, because
they contain information about how to handle their postprocessing.
"""
return False
def __add__(self, other) -> "Signal":
if not isinstance(other, Signal):
return NotImplemented
return CombinedSignal(self, other, operator.add, "+")
def __sub__(self, other) -> "Signal":
if not isinstance(other, Signal):
return NotImplemented
return CombinedSignal(self, other, operator.sub, "-")
def __mul__(self, other) -> "Signal":
if not isinstance(other, Signal):
return NotImplemented
return CombinedSignal(self, other, operator.mul, "*")
[docs]
@dataclass
class CombinedSignal(Signal):
"""Signal built by combining two other signals.
Parameters
----------
a, b
Input signals
combine
Operator to combine two arrays
op_name
Symbol for the operator
"""
a: Signal
b: Signal
combine: Callable[[da.Array, da.Array], da.Array]
op_name: str
def __post_init__(self):
if self.a.terminal:
raise TerminalError(self.a)
if self.b.terminal:
raise TerminalError(self.b)
[docs]
def sample(self, n: int, sample_rate: float) -> da.Array: # noqa: D102
return self.combine(self.a.sample(n, sample_rate), self.b.sample(n, sample_rate))
def __str__(self) -> str:
return f"({self.a} {self.op_name} {self.b})"
[docs]
@dataclass
class Constant(Signal):
"""Fixed value."""
value: float
[docs]
def sample(self, n: int, sample_rate: float) -> da.Array: # noqa: D102
return da.full((n,), self.value, dtype=np.float32, chunks=CHUNK_SIZE)
def __str__(self) -> str:
return f"{self.value}"
[docs]
@dataclass
class Periodic(Signal):
"""Base class for periodic signals.
The frequency is adjusted during sampling so that the sampled result can
be looped.
"""
amplitude: float
frequency: float
_class_name: ClassVar[str] = ""
@staticmethod
@abstractmethod
def _sample_chunk(offset: np.ndarray, *, amplitude: np.float32, chunk_size: int, frequency: float) -> np.ndarray:
"""Compute a single chunk.
Parameters
----------
offset
Index of the first element of the chunk (1-element array)
amplitude
Amplitude of the signal
chunk_size
Number of samples to produce
frequency
Rounded frequency, expressed as cycles per sample
"""
[docs]
def sample(self, n: int, sample_rate: float) -> da.Array: # noqa: D102
# Round target frequency to fit an integer number of waves into signal_heaps
waves = max(1, round(n * self.frequency / sample_rate))
frequency = waves * sample_rate / n
logger.info(f"Rounded tone frequency to {frequency} Hz")
# Index of the firstelement of each chunk
offsets = da.arange(0, n, CHUNK_SIZE, chunks=1, dtype=np.int64)
return _sample_helper(
n,
offsets,
self._sample_chunk,
amplitude=np.float32(self.amplitude),
frequency=waves / n,
meta=np.array((), np.float32),
)
def __str__(self) -> str:
return f"{self._class_name}({self.amplitude}, {self.frequency})"
[docs]
@dataclass
class CW(Periodic):
"""Continuous wave.
To make the resulting signal periodic, the frequency is adjusted during
sampling so that the sampled result can be looped.
"""
_class_name: ClassVar[str] = "cw"
@staticmethod
def _sample_chunk(offset: np.ndarray, *, amplitude: np.float32, chunk_size: int, frequency: float) -> np.ndarray:
return _cw(offset[0], amplitude=amplitude, n=chunk_size, frequency=frequency, dtype=np.dtype(np.complex64))
[docs]
@dataclass
class MultiCW(Signal):
"""Sum of multiple continuous waves.
The amplitudes and frequencies are both arithmetic progressions. The
frequencies are adjusted during sampling so that the sampled result
can be looped.
"""
m: int
amplitude0: float
amplitude_step: float
frequency0: float
frequency_step: float
_class_name: ClassVar[str] = "multicw"
[docs]
def sample(self, n: int, sample_rate: float) -> da.Array: # noqa: D102
# Build the frequency domain
spectrum = np.zeros(n // 2 + 1, np.float32)
for i in range(self.m):
f = i * self.frequency_step + self.frequency0
# Round target frequency to fit an integer number of waves into signal_heaps
waves = max(1, round(n * f / sample_rate))
pos = waves % n # Positive frequency component
neg = (-waves) % n # Negative frequency component
amp = i * self.amplitude_step + self.amplitude0
if 0 <= pos < len(spectrum):
spectrum[pos] += 0.5 * amp
if 0 <= neg < len(spectrum):
spectrum[neg] += 0.5 * amp
# Inverse FFT to get time domain. When n is even, we can use the DCT,
# which avoids some copies.
if n % 2 == 0:
v = scipy.fft.dct(spectrum, type=1, overwrite_x=True)
# DCT only gives us the first half of the output. Mirror to
# get the second half.
v = np.r_[v, v[-2:0:-1]]
else:
# Fallback path (scipy doesn't implement the variant of the DCT
# that would be needed for this). But we generally expect n to
# be a power of 2.
v = scipy.fft.irfft(spectrum, n=n, norm="forward", overwrite_x=True)
# Generate a unique name. Dask's default is to hash the array itself,
# which is somewhat slow. Instead, hash the parameters used to
# create it.
token = dask.base.tokenize(dataclasses.astuple(self), n, sample_rate)
return da.from_array(v, chunks=CHUNK_SIZE, name="multicw-" + token)
def __str__(self) -> str:
return (
f"{self._class_name}({self.m}, {self.amplitude0}, {self.amplitude_step}, "
"{self.frequency0}, {self.frequency_step})"
)
[docs]
@dataclass
class Comb(Periodic):
"""Signal with periodic impulses.
To make the resulting signal periodic, the frequency is adjusted during
sampling so that the sampled result can be looped.
"""
_class_name: ClassVar[str] = "comb"
@staticmethod
def _sample_chunk(offset: np.ndarray, *, amplitude: np.float32, chunk_size: int, frequency: float) -> np.ndarray:
start = offset[0]
stop = start + chunk_size
start_cycle = math.ceil(start * frequency)
stop_cycle = math.ceil(stop * frequency)
# Rounding errors can make start_cycle/stop_cycle be off by 1. So we
# include an extra cycle on each end to ensure we don't miss anything
# at the edges, then trim again if necessary.
cycles = np.arange(start_cycle - 1, stop_cycle + 1)
indices = np.rint(cycles / np.float32(frequency)).astype(int)
indices = indices[(start <= indices) & (indices < stop)]
out = np.zeros(chunk_size, np.float32)
out[indices - start] = amplitude
return out
[docs]
@dataclass
class Random(Signal):
"""Base class for randomly-generated signals.
This base class is only suitable when the samples at different times
are independent. The derived class must implement :meth:`_sample_chunk`.
"""
entropy: int #: entropy used to populate a :class:`np.random.SeedSequence`
def __init__(self, entropy: int | None = None) -> None:
self.entropy = entropy if entropy is not None else self._generate_entropy()
def _generate_entropy(self) -> int:
"""Generate a random seed.
This is split into a separate method so that it can be easily mocked.
"""
# In general SeedSequence.entropy is not always an int, but it is for
# this usage.
return np.random.SeedSequence().entropy # type: ignore
@abstractmethod
def _sample_chunk(self, seed_seq: Sequence[np.random.SeedSequence], *, chunk_size: int) -> np.ndarray:
"""Sample random values from a single chunk.
Parameters
----------
seed_seq
A single-element list with the entropy to use to initialise a random generator
chunk_size
The number of elements to generate
"""
[docs]
def sample(self, n: int, sample_rate: float) -> da.Array: # noqa: D102
seed_seq = np.random.SeedSequence(self.entropy)
return _sample_helper_random(n, seed_seq, self._sample_chunk, meta=np.array((), np.float32))
[docs]
@dataclass
class WGN(Random):
"""White Gaussian Noise signal.
Each sample in time is an independent Gaussian random variable with zero
mean and a given standard deviation.
In practice, the signal has a period equal to the value of `n` given to
:meth:`sample`, which could lead to undesirable correlations.
Parameters
----------
std
Standard deviation of the samples
entropy
If provided, used to seed the random number generator
"""
std: float = 1.0 #: standard deviation
def __init__(self, std: float, entropy: int | None = None) -> None:
# __init__ is overridden to change the argument order
super().__init__(entropy)
self.std = std
def _sample_chunk(self, seed_seq: Sequence[np.random.SeedSequence], *, chunk_size: int) -> np.ndarray:
# The RNG is initialised every time this is called so that it will
# produce the same results.
rng = np.random.default_rng(seed_seq[0])
return rng.standard_normal(size=chunk_size, dtype=np.float32) * self.std
def __str__(self) -> str:
return f"wgn({self.std}, {self.entropy})"
[docs]
@dataclass
class Delay(Signal):
"""Delay another signal by an integer number of samples.
Parameters
----------
signal
Underlying signal to delay
delay
Number of samples to delay the signal (may be negative)
"""
signal: Signal
delay: int
def __post_init__(self) -> None:
if self.signal.terminal:
raise TerminalError(self.signal)
[docs]
def sample(self, n: int, sample_rate: float) -> da.Array: # noqa: D102
return da.roll(self.signal.sample(n, sample_rate), self.delay)
def __str__(self) -> str:
return f"delay({self.signal}, {self.delay})"
[docs]
@dataclass
class Nodither(Signal):
"""Mark a signal expression as not needing dither.
Parameters
----------
signal
Underlying signal
"""
signal: Signal
def __post_init__(self) -> None:
if self.signal.terminal:
raise TerminalError(self.signal)
[docs]
def sample(self, n: int, sample_rate: float) -> da.Array: # noqa: D102
return self.signal.sample(n, sample_rate)
def __str__(self) -> str:
return f"nodither({self.signal})"
@property
def terminal(self) -> bool:
"""Prevent this signal from being used in expressions."""
return True
def _apply_operator(s: str, loc: int, tokens: pp.ParseResults) -> Signal:
assert len(tokens) == 1
tokens = tokens[0] # infix_operator passes the expression with an extra nesting level
op_map = {"*": operator.mul, "+": operator.add, "-": operator.sub}
result = tokens[0]
for i in range(1, len(tokens), 2):
result = op_map[tokens[i]](result, tokens[i + 1])
return result
[docs]
def parse_signals(prog: str) -> list[Signal]:
"""Generate a set of signals from a domain-specific language.
See :ref:`dsim-dsl` for a description of the language.
"""
var_table = {}
output = []
def assign(s: str, loc: int, tokens: pp.ParseResults) -> None:
var_table[tokens[0]] = tokens[1]
def get_variable(s: str, loc: int, tokens: pp.ParseResults) -> None:
try:
return var_table[tokens[0]]
except KeyError:
raise pp.ParseFatalException("", loc, f"Unknown variable {tokens[0]!r}") from None
lpar = pp.Suppress("(")
rpar = pp.Suppress(")")
comma = pp.Suppress(",")
eq = pp.Suppress("=")
semicolon = pp.Suppress(";")
variable = pp.pyparsing_common.identifier("variable")
real = pp.pyparsing_common.number
integer = pp.pyparsing_common.integer
signed_integer = pp.pyparsing_common.signed_integer
expr = pp.Forward()
# See https://pyparsing-docs.readthedocs.io/en/latest/HowToUsePyparsing.html#expression-subclasses
# for an explanation of + versus - in these rules (it helps give more
# useful errors). I've been conservative with the use of - i.e., there
# may be some +'s that can still be changed to -'s.
cw = pp.Keyword("cw") + lpar - real + comma - real + rpar
cw.set_parse_action(lambda s, loc, tokens: CW(tokens[1], tokens[2]))
multicw = pp.Keyword("multicw") + lpar - integer + comma - real + comma - real + comma - real + comma - real + rpar
multicw.set_parse_action(lambda s, loc, tokens: MultiCW(tokens[1], tokens[2], tokens[3], tokens[4], tokens[5]))
comb = pp.Keyword("comb") + lpar - real + comma - real + rpar
comb.set_parse_action(lambda s, loc, tokens: Comb(tokens[1], tokens[2]))
wgn = pp.Keyword("wgn") + lpar - real + pp.Opt(comma - integer("entropy")) + rpar
wgn.set_parse_action(lambda s, loc, tokens: WGN(tokens[1], tokens.get("entropy")))
delay = pp.Keyword("delay") + lpar - expr + comma - signed_integer + rpar
delay.set_parse_action(lambda s, loc, tokens: Delay(tokens[1], tokens[2]))
nodither = pp.Keyword("nodither") + lpar - expr + rpar
nodither.set_parse_action(lambda s, loc, tokens: Nodither(tokens[1]))
variable_expr = variable.copy()
variable_expr.set_parse_action(get_variable)
real_expr = real.copy()
real_expr.set_parse_action(lambda s, loc, tokens: Constant(float(tokens[0])))
atom = real_expr | cw | multicw | comb | wgn | delay | nodither | variable_expr
expr <<= pp.infix_notation(
atom, [("*", 2, pp.OpAssoc.LEFT, _apply_operator), (pp.one_of("+ -"), 2, pp.OpAssoc.LEFT, _apply_operator)]
)
assignment = variable + eq - expr
assignment.set_parse_action(assign)
expr_statement = expr.copy()
expr_statement.add_parse_action(lambda s, loc, tokens: output.append(tokens[0]))
statement = (assignment | expr_statement) - semicolon
program = statement[...]
program.parse_string(prog, parse_all=True)
return output
def _dither_sample_chunk(seed_seq: Sequence[np.random.SeedSequence], *, chunk_size: int) -> np.ndarray:
"""Produce one chunk for :func:`dither`."""
rng = np.random.default_rng(seed_seq[0])
return rng.random(size=chunk_size, dtype=np.float32) - np.float32(0.5)
[docs]
def make_dither(n_pols: int, n: int, entropy: int | None = None) -> xr.DataArray:
"""Create a set of dither signals to use with :func:`quantise`.
The returned array has ``pol`` and ``data`` axes, and is backed by a Dask
array.
The implementation currently uses a uniform distribution, but that is
subject to change.
"""
seed_seqs = np.random.SeedSequence(entropy).spawn(n_pols)
d = da.stack(
[
_sample_helper_random(n, seed_seq, _dither_sample_chunk, meta=np.array((), np.float32))
for seed_seq in seed_seqs
]
)
return xr.DataArray(d, dims=["pol", "data"])
@numba.njit
def _clip(a, a_min, a_max):
"""Like np.clip, but for scalars.
It's not working in numba: https://github.com/numba/numba/issues/3469.
"""
if a < a_min:
return a_min
elif a > a_max:
return a_max
else:
return a
@numba.njit(nogil=True)
def _quantise_chunk(chunk: np.ndarray, dither: np.ndarray, scale: np.float32) -> np.ndarray:
out = np.empty_like(chunk, dtype=np.int32)
for i in range(chunk.shape[0]):
scaled = chunk[i] * scale + dither[i]
out[i] = np.rint(_clip(scaled, -scale, scale))
return out
[docs]
def quantise(
data: da.Array,
bits: int,
dither: da.Array,
) -> da.Array:
"""Convert floating-point data to fixed-point.
Parameters
----------
data
Array of values, nominally in the range -1 to 1 (values outside the
range are clamped).
bits
Total number of bits per output sample (including the sign bit). The
input values are scaled by :math:`2^{bits-1} - 1`.
dither
Values to add to the data after scaling.
"""
scale = np.float32(2 ** (bits - 1) - 1)
return da.blockwise(_quantise_chunk, "i", data, "i", dither, "i", scale=scale, meta=np.array((), np.int32))
@numba.njit(nogil=True)
def _saturation_counts(data: np.ndarray, saturation_value: np.integer) -> np.ndarray:
out = np.empty((data.shape[0], 1), np.uint64)
for i in range(data.shape[0]):
total = np.uint64(0)
for j in range(data.shape[1]):
total += np.uint64(data[i, j] >= saturation_value or data[i, j] <= -saturation_value)
out[i, 0] = total
return out
[docs]
def saturation_counts(data: da.Array, saturation_value) -> da.Array:
"""Return an array indicating counts of saturated elements of ``data``.
The count is taken along each row of ``data``.
Elements are considered saturated if they exceed `saturation_value` in
absolute value.
"""
assert data.ndim == 2
# Ensure the saturation value is already a numpy scalar
saturation_value = data.dtype.type(saturation_value)
block_sums = da.map_blocks(
_saturation_counts,
data,
meta=np.array((), np.uint64),
chunks=(data.chunks[0], (1,) * len(data.chunks[1])),
saturation_value=saturation_value,
)
return da.sum(block_sums, axis=1)
@numba.njit(nogil=True)
def _packbits(data: np.ndarray, bits: int) -> np.ndarray:
# Note: needs lots of explicit casting to np.uint64, as otherwise
# numba seems to want to infer double precision.
out = np.zeros(data.size * bits // BYTE_BITS, np.uint8)
buf = np.uint64(0)
buf_size = 0
mask = (np.uint64(1) << bits) - np.uint64(1)
out_pos = 0
for v in data:
buf = (buf << bits) | (np.uint64(v) & mask)
buf_size += bits
while buf_size >= BYTE_BITS:
out[out_pos] = buf >> (buf_size - BYTE_BITS)
out_pos += 1
buf_size -= BYTE_BITS
return out
[docs]
def packbits(data: da.Array, bits: int) -> da.Array:
"""Pack integers into bytes.
The least-significant `bits` bits of each integer in `data` is collected
together in big-endian order, and returned as a sequence of bytes. The
total number of bits must form a whole number of bytes.
If the chunks in `data` are not be aligned on byte boundaries then a
slower path is used.
"""
assert data.ndim == 1
if data.shape[0] * bits % BYTE_BITS:
raise ValueError("Total number of bits is not a multiple of 8")
if not all(c * bits % BYTE_BITS == 0 for c in data.chunks[0]):
assert CHUNK_SIZE % BYTE_BITS == 0
data = data.rechunk(CHUNK_SIZE)
out_chunks = (tuple(c * bits // BYTE_BITS for c in data.chunks[0]),)
return da.map_blocks(_packbits, data, dtype=np.uint8, chunks=out_chunks, bits=bits)
[docs]
def sample(
signals: Sequence[Signal],
timestamp: int,
period: int | None,
sample_rate: float,
sample_bits: int,
out: xr.DataArray,
out_saturated: xr.DataArray | None = None,
saturation_group: int = 1,
*,
dither: bool | xr.DataArray = True,
dither_seed: int | None = None,
) -> None:
"""Sample, quantise and pack a set of signals.
The number of samples to generate is determined from the output array.
Parameters
----------
signals
Signals to sample, one per polarisation
timestamp
Timestamp for the first element to return. The signal is rotated by
this amount.
period
Number of samples after which to repeat. This must divide into the
total number of samples to generate. If not specified, uses the
total number of samples.
sample_rate
Passed to :meth:`Signal.sample`
sample_bits
Passed to :func:`quantise` and :func:`packbits`
out
Output array, with a dimension called ``pol`` (which must match the
number of signals). The other dimensions are flattened.
out_saturated
Output array, with the same shape as ``out``, into which saturation
counts are written.
saturation_group
Samples are taken in contiguous groups of this size and each element of
out_saturated is a saturation count for one group. This must divide
into the total number of samples.
dither
If true (default), add uniform random values in the range [-0.5, 0.5)
after scaling to reduce artefacts. It may also be a :class:`xr.DataArray`
with axes called ``pol`` (which must match the number of signals) and
``data`` (which must have length at least equal to ``period``).
"""
n_pols = out.sizes["pol"]
if len(signals) != n_pols:
raise ValueError(f"Expected {n_pols} signals, received {len(signals)}")
n = out.isel(pol=0).data.size * BYTE_BITS // sample_bits
if period is None:
period = n
if n % period:
raise ValueError(f"period {period} does not divide into total samples {n}")
if n % saturation_group:
raise ValueError(f"saturation_group (saturation_group) does not divide into total samples {n}")
if dither is True:
dither = make_dither(len(signals), period, entropy=dither_seed)
elif dither is False:
dither = xr.DataArray(da.zeros((n_pols, period), np.float32, chunks=CHUNK_SIZE), dims=["pol", "data"])
else:
if dither.sizes["pol"] != n_pols:
raise ValueError(f"Expected {n_pols} dither signals, received {dither.sizes['pol']}")
if dither.sizes["data"] < period:
raise ValueError(f"Expected at least {period} dither samples, only found {dither.sizes['data']}")
dither = dither.isel(data=np.s_[:period])
in_arrays = []
out_arrays = []
for i, sig in enumerate(signals):
data = sig.sample(period, sample_rate)
if sig.terminal:
sig_dither = da.zeros(period, np.float32, chunks=CHUNK_SIZE)
else:
sig_dither = dither.isel(pol=i).data
data = quantise(data, sample_bits, sig_dither)
data = da.roll(data, -timestamp)
if period < n:
data = da.tile(data, n // period)
if out_saturated is not None:
saturated = saturation_counts(data.reshape(-1, saturation_group), 2 ** (sample_bits - 1) - 1)
data = packbits(data, sample_bits)
in_arrays.append(data)
out_arrays.append(out.isel(pol=i).data.ravel())
if out_saturated is not None:
in_arrays.append(saturated)
out_arrays.append(out_saturated.isel(pol=i).data.ravel())
# Compute all the pols together, so that common signals are only computed
# once.
da.store(in_arrays, out_arrays, lock=False)
[docs]
class SignalService:
"""Compute signals in a separate process.
The provided arrays must be backed by :class:`.SharedArray`, and each must
have an xarray attribute called ``"shared_array"`` which holds the backing
:class:`.SharedArray`.
Parameters
----------
arrays
All the arrays that might be passed to :meth:`sample`.
sample_bits
Number of bits per sample for all queries.
dither_seed
Seed used to generate a fixed dither.
"""
@dataclass
class _Request:
"""Serialises a request from the main process to the service process."""
signals: Sequence[Signal]
timestamp: int
period: int | None
sample_rate: float
out_idx: int #: Index of the out array in the list of valid arrays
out_saturated_idx: int | None #: Index of the out saturated array in the list of valid arrays
saturation_group: int
@staticmethod
def _run(
array_schemas: list[dict],
sample_bits: int,
dither_seed: int | None,
pipe: multiprocessing.connection.Connection,
) -> None:
"""Run the main service loop for the separate process.
It receives a _Request on the pipe, and replies with either ``None``
or an exception. When the main process wants to shut down, it will
close the pipe.
"""
# Avoid catching Ctrl-C meant for the parent
signal.signal(signal.SIGINT, signal.SIG_IGN)
os.sched_setscheduler(0, os.SCHED_IDLE, os.sched_param(0))
# Load the arrays by fetching the shared array out of the special attribute
arrays = [
xr.DataArray.from_dict({**schema, "data": schema["attrs"]["shared_array"].buffer})
for schema in array_schemas
]
# Generate and pre-compute the dither
n_pols = arrays[0].sizes["pol"]
n = arrays[0].isel(pol=0).data.size * BYTE_BITS // sample_bits
dither = make_dither(n_pols, n, dither_seed)
dither = dither.persist() # Compute now and store results
while True:
try:
req: SignalService._Request = pipe.recv()
except EOFError:
break # Caller has shut down the pipe
try:
sample(
req.signals,
req.timestamp,
req.period,
req.sample_rate,
sample_bits,
arrays[req.out_idx],
out_saturated=arrays[req.out_saturated_idx] if req.out_saturated_idx is not None else None,
saturation_group=req.saturation_group,
dither=dither,
)
except Exception as exc:
pipe.send(exc)
else:
pipe.send(None)
# Not strictly necessary since the process is about to die anyway,
# but might help prevent warnings about leaking file descriptors in
# future.
for array in arrays:
array.attrs["shared_array"].close()
def __init__(self, arrays: Sequence[xr.DataArray], sample_bits: int, dither_seed: int | None = None) -> None:
self.arrays = arrays
# These contain the `shared_array` attribute, which carries the
# reference to the shared memory into the child process.
array_schemas = [array.to_dict(data=False) for array in arrays]
# The default "fork" method seems to cause problems with the unit
# tests (NGC-637). Spawning is slower but ensures we share nothing
# other than what we wish to share explicitly.
ctx = multiprocessing.get_context("spawn")
self._pipe, remote_pipe = ctx.Pipe()
self._process = ctx.Process(
target=self._run,
args=(
array_schemas,
sample_bits,
dither_seed,
remote_pipe,
),
)
self._process.start()
remote_pipe.close() # Ensures the child holds the only reference
[docs]
async def stop(self) -> None:
"""Shut down the process."""
self._pipe.close()
await asyncio.get_running_loop().run_in_executor(None, self._process.join)
def _make_request(self, request: "SignalService._Request") -> None:
self._pipe.send(request)
reply = self._pipe.recv()
if reply is not None:
raise reply
def _array_index(self, out: xr.DataArray) -> int:
for i, array in enumerate(self.arrays):
# Object identity doesn't work well, I think because fetching one
# xr.DataArray from a xr.DataSet creates a new object on the fly.
# So we check if they're referencing the same memory in the same
# way.
if array.data.__array_interface__ == out.data.__array_interface__:
return i
raise ValueError("output was not registered with the constructor")
[docs]
async def sample(
self,
signals: Sequence[Signal],
timestamp: int,
period: int | None,
sample_rate: float,
out: xr.DataArray,
out_saturated: xr.DataArray | None = None,
saturation_group: int = 1,
) -> None:
"""Perform signal sampling in the remote process.
`out` and `out_saturated` must each be one of the arrays passed to the
constructor. Only the first `n` samples will be populated (and this
will be taken as the period).
"""
out_idx = self._array_index(out)
out_saturated_idx = self._array_index(out_saturated) if out_saturated is not None else None
loop = asyncio.get_running_loop()
req = SignalService._Request(
signals, timestamp, period, sample_rate, out_idx, out_saturated_idx, saturation_group
)
await loop.run_in_executor(None, self._make_request, req)
async def __aenter__(self) -> Self:
return self
async def __aexit__(self, *exc_info) -> None:
await self.stop()