Source code for katgpucbf.xbgpu.correlation

################################################################################
# Copyright (c) 2020-2024, National Research Foundation (SARAO)
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use
# this file except in compliance with the License. You may obtain a copy
# of the License at
#
#   https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################


"""Module wrapping the ASTRON Tensor-Core Correlation Kernels in the MeerKAT katsdpsigproc framework.

.. todo::

    Eventually modify the classes to support 4- and 16-bit input samples. The
    kernel supports these, but the classes in this module do not yet expose
    them. There is no use case for this at the moment, so it is a low priority.

"""

import importlib.resources

import numpy as np
from katsdpsigproc import accel, cuda
from katsdpsigproc.abc import AbstractContext, AbstractDevice

from .. import COMPLEX, N_POLS

#: Minimum CUDA compute capability needed for the kernel (with 8-bit samples).
#: :func:`device_filter` compares it against ``device.compute_capability``.
MIN_COMPUTE_CAPABILITY = (7, 2)
#: Magic value indicating missing data: a pair of 32-bit integers whose first
#: element is the most negative ``int32`` value.
#: NOTE(review): presumably the (real, imaginary) pair written to the output
#: visibilities for baselines flagged as not present - confirm against the
#: kernel source.
MISSING = np.array([-(2**31), 1], dtype=np.int32)


def device_filter(device: AbstractDevice) -> bool:
    """Determine whether a device is suitable for running the kernel.

    A device qualifies only if it is a CUDA device whose compute capability
    is at least :data:`MIN_COMPUTE_CAPABILITY`.
    """
    # Non-CUDA devices are rejected before their capability is inspected.
    if not isinstance(device, cuda.Device):
        return False
    return device.compute_capability >= MIN_COMPUTE_CAPABILITY
class CorrelationTemplate:
    r"""Template class for the Tensor-Core correlation kernel.

    The template creates a :class:`Correlation` that will run the compiled
    kernel. The parameters are used to compile the kernel and by the
    :class:`Correlation` to specify the shape of the memory buffers connected
    to this kernel.

    The number of baselines calculated here is not the canonical way that it
    is done in radio astronomy:

    .. math::

        n_{baselines} = \frac{n_{ants} * (n_{ants} + 1)}{2}

    Because we have a dual-polarised telescope, we calculate four 'baselines'
    for each canonical baseline as calculated above, namely :math:`h_1 h_2`,
    :math:`h_1 v_2`, :math:`v_1 h_2`, and :math:`v_1 v_2`. So the list of
    baselines appears four times as long as you might expect.

    Parameters
    ----------
    context
        The context used to compile the kernel source.
    n_ants
        The number of antennas that will be correlated. Each antenna is
        expected to produce two polarisations.
    n_channels_per_substream
        The number of frequency channels to be processed.
    n_spectra_per_heap
        The number of time samples to be processed per frequency channel.
    input_sample_bits
        The number of bits per input sample. Only 8 bits is supported at the
        moment.

    Raises
    ------
    ValueError
        If `input_sample_bits` is not a supported bit width, if any of
        `n_ants`, `n_channels_per_substream` or `n_spectra_per_heap` is not
        positive, or if `n_spectra_per_heap` is not a multiple of the number
        of time samples the kernel loads per block.
    """

    def __init__(
        self,
        context: AbstractContext,
        n_ants: int,
        n_channels_per_substream: int,
        n_spectra_per_heap: int,
        input_sample_bits: int,
    ) -> None:
        # Validate the bit width before it is used as a divisor below, so that
        # a bad value (such as 0) is reported as a ValueError rather than
        # surfacing as a ZeroDivisionError.
        valid_bitwidths = [4, 8, 16]
        if input_sample_bits not in valid_bitwidths:
            raise ValueError(
                f"input_sample_bits must equal either 4, 8 or 16, currently equal to {input_sample_bits}."
            )
        elif input_sample_bits == 4 or input_sample_bits == 16:
            raise ValueError(
                f"Sample bitwidth of {input_sample_bits} "
                "will eventually be supported but has not yet been implemented."
            )

        # Degenerate sizes would otherwise be accepted here and only fail
        # later (an endless loop or a division by zero in Correlation).
        for name, value in (
            ("n_ants", n_ants),
            ("n_channels_per_substream", n_channels_per_substream),
            ("n_spectra_per_heap", n_spectra_per_heap),
        ):
            if value <= 0:
                raise ValueError(f"{name} must be positive, currently equal to {value}.")

        self.n_ants = n_ants
        self.n_channels_per_substream = n_channels_per_substream
        self.n_spectra_per_heap = n_spectra_per_heap
        self.n_baselines = self.n_ants * (self.n_ants + 1) // 2
        self.input_sample_bits = input_sample_bits
        # hardcoded to 8 upstream
        self._n_ants_per_block = 32  # Hardcoded to 32 for now, but can be set to 32/48/64.

        # This 128 is hardcoded in the original Tensor-Core kernel. It loads
        # each block as two int4's, which is 256 bits (the extra factor of 2
        # is because input_sample_bits only counts the real part of a complex
        # number).
        self.n_times_per_block = 128 // self.input_sample_bits

        if self.n_spectra_per_heap % self.n_times_per_block != 0:
            raise ValueError(f"spectra_per_heap must be divisible by {self.n_times_per_block}.")

        n_blocks_1d = accel.divup(self.n_ants, self._n_ants_per_block)
        if self._n_ants_per_block in {32, 48}:
            self.n_blocks = n_blocks_1d * (n_blocks_1d + 1) // 2
        elif self._n_ants_per_block == 64:
            self.n_blocks = n_blocks_1d * n_blocks_1d
        else:
            raise ValueError(
                f"ants_per_block must equal either 32, 48 or 64, currently equal to {self._n_ants_per_block}."
            )

        source = (importlib.resources.files(__package__) / "kernels" / "tensor_core_correlation_kernel.cu").read_text()
        program = context.compile(
            source,
            [
                f"-DNR_RECEIVERS={self.n_ants}",
                f"-DNR_RECEIVERS_PER_BLOCK={self._n_ants_per_block}",
                f"-DNR_BITS={self.input_sample_bits}",
                f"-DNR_CHANNELS={self.n_channels_per_substream}",
                f"-DNR_SAMPLES_PER_CHANNEL={self.n_spectra_per_heap}",
                f"-DNR_POLARIZATIONS={N_POLS}",
                "-DCUSTOM_STORE_VISIBILITY=1",
                # Suppress "pointless comparison of unsigned integer with zero"
                "-Xcudafe",
                "--diag_suppress=186",
            ],
        )
        self.correlate_kernel = program.get_kernel("correlate")
        self.reduce_kernel = program.get_kernel("reduce")

    def instantiate(self, command_queue: accel.AbstractCommandQueue, n_batches: int) -> "Correlation":
        """Create a :class:`Correlation` using this template to build the kernel."""
        return Correlation(self, command_queue, n_batches)
class Correlation(accel.Operation):
    """Tensor-Core correlation kernel.

    Specifies the shape of the input sample and output visibility buffers
    required by the kernel. The parameters specified in the
    :class:`CorrelationTemplate` object are used to determine the shape of the
    buffers.

    There is an outer-most dimension called "batches", over which the
    operation is parallelised. Not all batches need to be processed every
    time: set the ``first_batch`` and ``last_batch`` attributes to control
    which batches will be computed.

    The input sample buffer must have the shape:
    ``[n_batches][n_ants][channels][spectra_per_heap][polarisations]``

    There is an alignment requirement for ``spectra_per_heap`` due to the
    implementation details of the kernel. For 8-bit input mode,
    ``spectra_per_heap`` must be a multiple of 16.

    Each input element is a complex 8-bit integer sample. :mod:`numpy` does
    not support 8-bit complex numbers, so the dimensionality is extended by 1,
    with the last dimension sized ``2`` to represent the complexity.

    With 8-bit input samples, the value -128i is not supported by the kernel
    as there is no 8-bit complex conjugate representation of this number.
    Passing ``-128i`` into the kernel will produce incorrect values at the
    output.

    The output visibility buffer must have the shape
    ``[channels][baselines][COMPLEX]``. In 8-bit mode, each element in this
    visibility matrix is a 32-bit integer value.

    Calling this object does not directly update the output. Instead, it
    updates an intermediate buffer (called ``mid_visibilities``). To produce
    the output, call :meth:`reduce`. This function can also flag data that was
    missing during the accumulation, by writing a special value. This is
    controlled by the ``present_baselines`` slot, which has one boolean entry
    per baseline (antenna pair).

    Currently only 8-bit input mode is supported.

    Parameters
    ----------
    template
        The template holding the compiled kernels and problem dimensions.
    command_queue
        The command queue on which the kernels are enqueued.
    n_batches
        Size of the outer-most ("batches") dimension of the input buffer.

    Raises
    ------
    ValueError
        If the template describes an empty problem (no channels or no antenna
        blocks), or if the total number of visibilities is 2^31 or more.
    """

    def __init__(
        self, template: CorrelationTemplate, command_queue: accel.AbstractCommandQueue, n_batches: int
    ) -> None:
        super().__init__(command_queue)
        self.template = template

        # Determine how many accumulators to use. Fewer is better for both
        # memory usage and I/O throughput, but too few means there will not
        # be enough parallelism to saturate the GPU. Aim for 1024-2048
        # work-groups, while sticking to powers of 2 since that's likely to
        # give an even division of work across them.
        work_groups_per_accumulator = self.template.n_channels_per_substream * self.template.n_blocks
        if work_groups_per_accumulator <= 0:
            # Without this guard the doubling loop below would never
            # terminate, because the product stays at zero.
            raise ValueError("n_channels_per_substream and n_blocks must both be positive")
        n_mid = 1
        while n_mid * work_groups_per_accumulator < 1024:
            n_mid *= 2

        input_data_dimensions = (
            accel.Dimension(n_batches),
            accel.Dimension(self.template.n_ants, exact=True),
            accel.Dimension(self.template.n_channels_per_substream, exact=True),
            accel.Dimension(self.template.n_spectra_per_heap, exact=True),
            accel.Dimension(N_POLS, exact=True),
            accel.Dimension(COMPLEX, exact=True),
        )
        mid_data_dimensions = (
            accel.Dimension(n_mid),
            accel.Dimension(self.template.n_channels_per_substream, exact=True),
            accel.Dimension(self.template.n_baselines * N_POLS * N_POLS, exact=True),
            accel.Dimension(COMPLEX, exact=True),
        )

        # TODO: NGC-1104 update this once 4-bit correlation is supported
        assert self.template.input_sample_bits == 8, (
            f"{self.template.input_sample_bits}-bit mode not supported yet, only 8-bit."
        )
        self.slots["in_samples"] = accel.IOSlot(dimensions=input_data_dimensions, dtype=np.int8)
        self.slots["mid_visibilities"] = accel.IOSlot(dimensions=mid_data_dimensions, dtype=np.int64)
        # The output has the same layout as one accumulator of mid_visibilities.
        self.slots["out_visibilities"] = accel.IOSlot(dimensions=mid_data_dimensions[1:], dtype=np.int32)
        self.slots["out_saturated"] = accel.IOSlot(dimensions=(), dtype=np.uint32)
        self.slots["present_baselines"] = accel.IOSlot(dimensions=(self.template.n_baselines,), dtype=np.uint8)

        if n_batches * self.template.n_channels_per_substream * self.template.n_baselines * N_POLS * N_POLS >= 2**31:
            # Can probably go higher, but rather keep it low to reduce the risk
            # of indexing bugs.
            raise ValueError("2^31 or more visibilities are not currently supported")

        self.first_batch = 0
        self.last_batch = n_batches
        self.n_batches = n_batches

    def _run(self) -> None:
        """Run the correlation kernel and add the generated values to internal buffer.

        Raises
        ------
        ValueError
            If ``first_batch`` and ``last_batch`` do not describe a non-empty
            range within ``[0, n_batches]``.
        """
        if not 0 <= self.first_batch < self.last_batch <= self.n_batches:
            raise ValueError("Invalid batch range")
        in_samples_buffer = self.buffer("in_samples")
        mid_visibilities_buffer = self.buffer("mid_visibilities")

        n_z = mid_visibilities_buffer.shape[0]
        n_batches = self.last_batch - self.first_batch  # Number of batches for this launch
        n_time_blocks_per_batch = self.template.n_spectra_per_heap // self.template.n_times_per_block
        n_time_blocks = n_batches * n_time_blocks_per_batch
        n_time_blocks_per_z = accel.divup(n_time_blocks, n_z)
        # The rounding up of n_time_blocks_per_z may leave some z values with
        # no work. So recompute n_z to avoid launching them at all.
        n_z = accel.divup(n_time_blocks, n_time_blocks_per_z)
        first_time_block = self.first_batch * n_time_blocks_per_batch

        self.command_queue.enqueue_kernel(
            self.template.correlate_kernel,
            [
                mid_visibilities_buffer.buffer,
                in_samples_buffer.buffer,
                np.uint32(first_time_block),
                np.uint32(n_time_blocks),
                np.uint32(n_time_blocks_per_z),
            ],
            # NOTE: Even though we are using CUDA, we follow OpenCL's grid/block
            # conventions. As such we need to multiply the number of
            # blocks(global_size) by the block size(local_size) in order to
            # specify global threads not global blocks.
            global_size=(32 * self.template.n_blocks, 2 * self.template.n_channels_per_substream, 2 * n_z),
            local_size=(32, 2, 2),
        )

    def reduce(self) -> None:
        """Finalise computation of the output visibilities from the internal buffer."""
        self.ensure_all_bound()
        mid_visibilities_buffer = self.buffer("mid_visibilities")
        out_visibilities_buffer = self.buffer("out_visibilities")
        out_saturated_buffer = self.buffer("out_saturated")
        present_baselines_buffer = self.buffer("present_baselines")
        wgs = 128  # TODO: could be tuned. But this kernel costs a tiny amount
        # The saturation count is zeroed before every reduce kernel launch.
        out_saturated_buffer.zero(self.command_queue)
        self.command_queue.enqueue_kernel(
            self.template.reduce_kernel,
            [
                out_visibilities_buffer.buffer,
                out_saturated_buffer.buffer,
                mid_visibilities_buffer.buffer,
                present_baselines_buffer.buffer,
                np.uint32(mid_visibilities_buffer.shape[0]),
            ],
            # One work-item per element of out_visibilities, rounded up to a
            # whole number of work-groups.
            global_size=(accel.roundup(int(np.prod(out_visibilities_buffer.shape)), wgs), 1, 1),
            local_size=(wgs, 1, 1),
        )

    def zero_visibilities(self) -> None:
        """Zero all the values in the internal buffer."""
        self.ensure_bound("mid_visibilities")
        self.buffer("mid_visibilities").zero(self.command_queue)

    @staticmethod
    def get_baseline_index(ant1: int, ant2: int) -> int:
        r"""Get index in the visibilities matrix for baseline (ant1, ant2).

        The visibilities matrix indexing is as follows:

        .. code::

                   ant2 = 0  1  2  3  4
                        +---------------
               ant1 = 0 | 00 01 03 06 10
                      1 |    02 04 07 11
                      2 |       05 08 12
                      3 |          09 13
                      4 |             14

        This function requires that :math:`ant2 \ge ant1 \ge 0`.

        Raises
        ------
        ValueError
            If `ant1` is negative or greater than `ant2`.
        """
        if ant1 < 0:
            # A negative antenna would yield a wrong (possibly negative) index
            # that could silently address the wrong element.
            raise ValueError("Antenna indices must be non-negative")
        if ant1 > ant2:
            raise ValueError("It is required that ant2 >= ant1 in all cases")
        return ant2 * (ant2 + 1) // 2 + ant1