One-dimensional batched FFT

Kokkos-FFT

// SPDX-FileCopyrightText: (C) The Kokkos-FFT development team, see COPYRIGHT.md file
//
// SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception

#include <Kokkos_Core.hpp>
#include <Kokkos_Complex.hpp>
#include <Kokkos_Random.hpp>
#include <KokkosFFT.hpp>

using execution_space = Kokkos::DefaultExecutionSpace;
template <typename T>
using View3D = Kokkos::View<T***, execution_space>;

// Example: one-dimensional FFTs batched over the first two extents of rank-3
// views. Every transform below is taken along the last axis only (axis = -1),
// so n0 * n1 independent 1D transforms of length n2 are performed per call.
int main(int argc, char* argv[]) {
  Kokkos::initialize(argc, argv);
  {
    // Views live in this inner scope so they are destroyed before
    // Kokkos::finalize() is called.
    const int n0 = 128, n1 = 128, n2 = 16;
    // Upper bound passed to fill_random for complex-valued views.
    const Kokkos::complex<double> z(1.0, 1.0);

    // 1D batched C2C FFT (Forward and Backward)
    View3D<Kokkos::complex<double> > xc2c("xc2c", n0, n1, n2);
    View3D<Kokkos::complex<double> > xc2c_hat("xc2c_hat", n0, n1, n2);
    View3D<Kokkos::complex<double> > xc2c_inv("xc2c_inv", n0, n1, n2);

    Kokkos::Random_XorShift64_Pool<> random_pool(12345);
    execution_space exec;
    Kokkos::fill_random(exec, xc2c, random_pool, z);

    KokkosFFT::fft(exec, xc2c, xc2c_hat, KokkosFFT::Normalization::backward,
                   /*axis=*/-1);
    KokkosFFT::ifft(exec, xc2c_hat, xc2c_inv,
                    KokkosFFT::Normalization::backward, /*axis=*/-1);

    // 1D batched R2C FFT: the complex output keeps only n2 / 2 + 1 entries
    // along the transformed axis.
    View3D<double> xr2c("xr2c", n0, n1, n2);
    View3D<Kokkos::complex<double> > xr2c_hat("xr2c_hat", n0, n1, n2 / 2 + 1);
    // Bound given as a double literal to match the view's element type.
    Kokkos::fill_random(exec, xr2c, random_pool, 1.0);

    KokkosFFT::rfft(exec, xr2c, xr2c_hat, KokkosFFT::Normalization::backward,
                    /*axis=*/-1);

    // 1D batched C2R FFT: complex input of length n2 / 2 + 1 along the last
    // axis, real output of length n2. Labels match the variable names so
    // they are not confused with the R2C views above in tooling/profiling.
    View3D<Kokkos::complex<double> > xc2r("xc2r", n0, n1, n2 / 2 + 1);
    View3D<double> xc2r_hat("xc2r_hat", n0, n1, n2);
    Kokkos::fill_random(exec, xc2r, random_pool, z);

    KokkosFFT::irfft(exec, xc2r, xc2r_hat, KokkosFFT::Normalization::backward,
                     /*axis=*/-1);
    // All work above was enqueued on `exec`; wait for it before the views
    // go out of scope.
    exec.fence();
  }
  Kokkos::finalize();

  return 0;
}

numpy

# SPDX-FileCopyrightText: (C) The Kokkos-FFT development team, see COPYRIGHT.md file
#
# SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception

""" Example of batched FFTs with numpy.fft
"""

import numpy as np

if __name__ == '__main__':
    n0, n1, n2 = 128, 128, 16

    # Seeded generator so runs are reproducible; the C++ example above also
    # uses a fixed seed (12345), though the random streams differ.
    rng = np.random.default_rng(12345)

    # 1D batched C2C FFT (Forward and Backward), along the last axis only
    xc2c = rng.random((n0, n1, n2)) + 1j * rng.random((n0, n1, n2))
    xc2c_hat = np.fft.fft(xc2c, axis=-1)
    xc2c_inv = np.fft.ifft(xc2c_hat, axis=-1)

    # 1D batched R2C FFT: output has n2//2 + 1 entries along the last axis
    xr2c = rng.random((n0, n1, n2))
    xr2c_hat = np.fft.rfft(xr2c, axis=-1)

    # 1D batched C2R FFT: the input is complex, matching the complex view
    # filled in the C++ example.
    xc2r = (rng.random((n0, n1, n2//2+1))
            + 1j * rng.random((n0, n1, n2//2+1)))
    # Pass n explicitly: irfft's default output length is 2*(m-1), which only
    # equals n2 when n2 is even. The C++ output view is always length n2.
    xc2r_hat = np.fft.irfft(xc2r, n=n2, axis=-1)