// ------------------------------ fft.c ------------------------------
/*
    LibCapy - a general purpose library of C functions and data structures
    Copyright (C) 2021-2025 Pascal Baillehache baillehache.pascal@gmail.com
    https://baillehachepascal.dev
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "fft.h"

// Evaluate the Fourier series described by a CapyDFTCoeffs
// Inputs:
//   in: in[0] is the position at which the series is evaluated; it is
//       mapped linearly such that 'range.min' becomes 0 and 'range.max'
//       becomes 1
//   out: out[0] receives the value of the series
// The constant term is the real part of the first coefficient divided by
// the number of coefficients. Each coefficient of index 1 to nb/2-1 then
// adds a cosine whose amplitude is twice its magnitude divided by nb and
// whose phase shift is its argument. Coefficients of index nb/2 and above
// are not used.
static void DFTCoeffsEval(
  double const* const in,
        double* const out) {
  methodOf(CapyDFTCoeffs);
  double const nbCoeffs = (double)(that->nb);
  size_t const nbHarmonics = that->nb / 2;

  // Normalised position of the input in the range of the coefficients
  double const span = that->range.max - that->range.min;
  double const t = (in[0] - that->range.min) / span;

  // Constant term
  *out = creal(that->vals[0]) / nbCoeffs;

  // Add the contribution of each harmonic
  size_t k = 1;
  while(k < nbHarmonics) {
    double complex const coeff = that->vals[k];
    double const phase = (double)k * t * 2.0 * M_PI - carg(coeff);
    *out += 2.0 * cabs(coeff) / nbCoeffs * cos(phase);
    ++k;
  }
}

// Get the vector of amplitudes (magnitude of the complex coefficient
// divided by the number of coefficients) per frequency bin, aka spectrum
// plot.
// Inputs:
//   fold: if true, fold the symmetric frequencies: only the first that->nb/2
//         bins are returned and all of them except the first one are
//         doubled.
// Output:
//   Return the amplitude per frequency bin as a vector. When not folded it
//   has that->nb values; one can then get the single sided Fourier
//   coefficients by taking only the first half of the values multiplied
//   by two.
static CapyVec GetAmpFreqBins(bool const fold) {
  methodOf(CapyDFTCoeffs);
  size_t const nbBins = (fold ? that->nb / 2 : that->nb);
  double const nbCoeffs = (double)(that->nb);

  // Result vector, one value per returned bin
  CapyVec amp = CapyVecCreate(nbBins);
  loop(iBin, nbBins) {
    double magnitude = cabs(that->vals[iBin]);

    // The null frequency has no symmetric counterpart to fold
    if(fold && iBin > 0) magnitude *= 2.0;
    amp.vals[iBin] = magnitude / nbCoeffs;
  }
  return amp;
}

// Destruct a CapyDFTCoeffs: destruct the inherited CapyMathFun and the
// range, release the array of coefficients and reset the instance to an
// empty state (no values, null count).
static void DFTCoeffsDestruct(void) {
  methodOf(CapyDFTCoeffs);

  // Destruct the inherited part and the member object
  $(that, destructCapyMathFun)();
  $(&(that->range), destruct)();

  // Release the coefficients and leave the instance empty
  free(that->vals);
  that->nb = 0;
  that->vals = NULL;
}

// Create a CapyDFTCoeffs
// Input:
//   nb: the number of coefficients
// Output:
//   Return a CapyDFTCoeffs inheriting from a CapyMathFun with one input and
//   one output, holding 'nb' coefficients all set to zero and a range
//   created with CapyRangeDoubleCreate(0, 1).
CapyDFTCoeffs CapyDFTCoeffsCreate(size_t const nb) {
  CapyDFTCoeffs that = {0};
  CapyInherits(that, CapyMathFun, (1, 1));

  // Data members
  that.range = CapyRangeDoubleCreate(0, 1);
  that.nb = nb;
  safeMalloc(that.vals, nb);
  loop(i, nb) that.vals[i] = 0.0;

  // Methods
  that.eval = DFTCoeffsEval;
  that.getAmpFreqBins = GetAmpFreqBins;
  that.destruct = DFTCoeffsDestruct;
  return that;
}

// Allocate memory for a new CapyDFTCoeffs and create it
// Input:
//   nb: the number of coefficients
// Output:
//   Return a pointer to the new CapyDFTCoeffs, to be released with
//   CapyDFTCoeffsFree
CapyDFTCoeffs* CapyDFTCoeffsAlloc(size_t const nb) {
  CapyDFTCoeffs* coeffs = NULL;
  safeMalloc(coeffs, 1);
  assert(coeffs != NULL);
  *coeffs = CapyDFTCoeffsCreate(nb);
  return coeffs;
}

// Free the memory used by a CapyDFTCoeffs* and reset '*that' to NULL
// Input:
//   that: a pointer to the CapyDFTCoeffs to free; nothing happens if
//         'that' or '*that' is NULL
void CapyDFTCoeffsFree(CapyDFTCoeffs** const that) {
  if(that != NULL && *that != NULL) {

    // Destruct the instance before releasing its memory
    $(*that, destruct)();
    free(*that);
    *that = NULL;
  }
}

// Recursive radix-2 fast Fourier transform
// Inputs:
//   n: the number of values to transform, expected to be a power of 2 (the
//      values are split in two halves of n/2 values at each level, so any
//      other size leaves part of the result unset)
//   x: the 'n' input values
//   y: the 'n' output values; for a power-of-2 'n',
//      y[k] = sum_j x[j] * exp(2*pi*I*j*k/n)
// If 'n' is null there is nothing to transform and 'y' is left untouched.
// Note that the exponent is positive here; the negative exponent and the
// 1/n scaling are applied by FastFourierInverseTransformRec.
static void FastFourierTransformRec(
                 size_t const n,
  double complex const* const x,
        double complex* const y) {

  // Without this guard a null size would recurse forever (0 / 2 == 0)
  if(n == 0) return;

  // The transform of a single value is the value itself
  if(n == 1) {
    *y = *x;
    return;
  }
  size_t const nHalf = n / 2;

  // Single work buffer holding the even inputs, the odd inputs and their
  // respective transforms, so that there is only one allocation per level
  double complex* work = NULL;
  safeMalloc(work, 4 * nHalf);
  assert(work != NULL);
  double complex* const xe = work;
  double complex* const xo = work + nHalf;
  double complex* const ye = work + 2 * nHalf;
  double complex* const yo = work + 3 * nHalf;

  // Split the input into its even and odd indexed values and transform
  // each half
  loop(i, nHalf) {
    xe[i] = x[2 * i];
    xo[i] = x[2 * i + 1];
  }
  FastFourierTransformRec(nHalf, xe, ye);
  FastFourierTransformRec(nHalf, xo, yo);

  // Combine the two halves; 'w' runs over the successive powers of 'omega'
  double complex const omega = cexp(2.0 * M_PI * I / (double complex)n);
  double complex w = 1.0;
  loop(i, nHalf) {
    double complex const t = w * yo[i];
    y[i] = ye[i] + t;
    y[i + nHalf] = ye[i] - t;
    w *= omega;
  }

  // Free memory
  free(work);
}

// Recursive radix-2 inverse fast Fourier transform
// Inputs:
//   n: the number of values to transform, expected to be a power of 2 (the
//      values are split in two halves of n/2 values at each level, so any
//      other size leaves part of the result unset)
//   x: the 'n' input values
//   y: the 'n' output values; for a power-of-2 'n',
//      y[k] = (1/n) * sum_j x[j] * exp(-2*pi*I*j*k/n)
// If 'n' is null there is nothing to transform and 'y' is left untouched.
// The 1/n scaling comes from the division by 2 applied at each level of
// the recursion.
static void FastFourierInverseTransformRec(
                 size_t const n,
  double complex const* const x,
        double complex* const y) {

  // Without this guard a null size would recurse forever (0 / 2 == 0)
  if(n == 0) return;

  // The transform of a single value is the value itself
  if(n == 1) {
    *y = *x;
    return;
  }
  size_t const nHalf = n / 2;
  double complex const omega = cexp(-2.0 * M_PI * I / (double)n);

  // Single work buffer holding the even inputs, the odd inputs and their
  // respective transforms, so that there is only one allocation per level
  double complex* work = NULL;
  safeMalloc(work, 4 * nHalf);
  assert(work != NULL);
  double complex* const xe = work;
  double complex* const xo = work + nHalf;
  double complex* const ye = work + 2 * nHalf;
  double complex* const yo = work + 3 * nHalf;

  // Split the input into its even and odd indexed values and transform
  // each half
  loop(i, nHalf) {
    xe[i] = x[2 * i];
    xo[i] = x[2 * i + 1];
  }
  FastFourierInverseTransformRec(nHalf, xe, ye);
  FastFourierInverseTransformRec(nHalf, xo, yo);

  // Combine the two halves; 'w' runs over the successive powers of 'omega'
  double complex w = 1.0;
  loop(i, nHalf) {
    double complex const t = w * yo[i];
    y[i] = (ye[i] + t) / 2.0;
    y[i + nHalf] = (ye[i] - t) / 2.0;
    w *= omega;
  }

  // Free memory
  free(work);
}

// Transform a 1D polynomial from coefficients representation to values
// representation. The polynomial must have a power-of-2 number of
// coefficients.
// Input:
//     poly: the polynomial
// Output:
//   Return the value representation of the polynomial as a CapyDFTCoeffs
//   (the fast Fourier transform of its coefficients)
static CapyDFTCoeffs FftPolyFromCoeffToVal(CapyPolynomial1D const* const poly) {
  size_t const nbCoeffs = poly->coeffs.dim;

  // Convert the coefficients to complex numbers
  double complex* input = NULL;
  safeMalloc(input, nbCoeffs);
  loop(i, nbCoeffs) input[i] = poly->coeffs.vals[i] + 0.0 * I;

  // Apply the fast Fourier transform
  CapyDFTCoeffs values = CapyDFTCoeffsCreate(nbCoeffs);
  FastFourierTransformRec(nbCoeffs, input, values.vals);

  // Free memory
  free(input);
  return values;
}

// Transform a 1D polynomial from values representation to coefficients
// representation. The number of values must be a power-of-2.
// Input:
//   values: the values representation of the polynomial
// Output:
//   Return the polynomial corresponding to the value representation, i.e.
//   the polynomial whose coefficients are the real part of the inverse
//   fast Fourier transform of the values
static CapyPolynomial1D* FftPolyFromValToCoeff(
  CapyDFTCoeffs const* const values) {
  size_t const nb = values->nb;

  // Create the polynomial from a zero-filled vector of the right dimension.
  // The temporary data live on the heap: their size is only bounded by the
  // number of values, which is too much to ask from the stack.
  // NOTE(review): assumes the polynomial keeps its own storage for the
  // coefficients and doesn't refer to 'zeros' afterwards - confirm
  CapyVec zeros = CapyVecCreate(nb);
  loop(i, nb) zeros.vals[i] = 0.0;
  CapyPolynomial1D* poly = CapyPolynomial1DAlloc(&zeros);
  CapyVecDestruct(&zeros);

  // Nothing to transform if there are no values
  if(nb == 0) return poly;

  // Apply the inverse fast Fourier transform
  double complex* x = NULL;
  safeMalloc(x, nb);
  assert(x != NULL);
  FastFourierInverseTransformRec(nb, values->vals, x);

  // The coefficients of the polynomial are the real part of the result
  loop(i, poly->coeffs.dim) poly->coeffs.vals[i] = creal(x[i]);

  // Free memory
  free(x);
  return poly;
}

// Calculate the Discrete Fourier Transform for an input range of a given
// function
// Input:
//   fun: the function (must have one input and one output)
//   range: range of the input of the function for which the DFT is calculated
//   nbSample: the number of samples taken from the function. The index of
//             the sample is mapped by CapyLerp from [0, nbSample-1] to
//             [range->min,
//              range->min + (range->max - range->min) * (nbSample-1)/nbSample]
//             hence the upper bound of 'range' itself is not sampled.
// Output:
//   Return the DFT coefficients as a CapyDFTCoeffs, as calculated by the
//   'fftSamples' method on the samples, with its range set to 'range'
static CapyDFTCoeffs FftFun(
            CapyMathFun* const fun,
  CapyRangeDouble const* const range,
                  size_t const nbSample) {
  methodOf(CapyDFT);

  // Vector to memorise the sample
  CapyVec samples = CapyVecCreate(nbSample);

  // Calculate the sample values
  double x[1];
  CapyRangeDouble fromRange = { .min = 0, .max = (double)(nbSample - 1)};
  CapyRangeDouble toRange;
  toRange.min = range->min;
  toRange.max =
    range->min +
    (range->max - range->min) * (double)(nbSample - 1) / (double)nbSample;
  loop(iSample, nbSample) {

    // With a single sample 'fromRange' is the degenerate range [0, 0],
    // which can't be interpolated; the sample is then taken at the lower
    // bound of 'range' (to which 'toRange' is also reduced)
    if(nbSample > 1) {
      x[0] = CapyLerp((double)iSample, &fromRange, &toRange);
    } else {
      x[0] = range->min;
    }
    $(fun, eval)(x, samples.vals + iSample);
  }

  // Calculate the DFT coefficients
  CapyDFTCoeffs coeffs = $(that, fftSamples)(&samples);
  loop(i, 2) coeffs.range.vals[i] = range->vals[i];

  // Free memory
  CapyVecDestruct(&samples);

  // Return the coefficients
  return coeffs;
}

// Calculate the Discrete Fourier Transform for a given set of samples
// Input:
//   samples: the samples value
// Output:
//   Return the DFT coefficients as a CapyDFTCoeffs, one coefficient per
//   sample, with its range set to [0, samples->dim - 1]
static CapyDFTCoeffs FftSamples(CapyVec const* const samples) {
  size_t const nbSample = samples->dim;

  // Complex copy of the samples, as expected by the transform
  double complex* input = NULL;
  safeMalloc(input, nbSample);
  assert(input != NULL);
  loop(i, nbSample) input[i] = samples->vals[i];

  // Create the result and fill it with the fast Fourier transform of the
  // samples
  CapyDFTCoeffs coeffs = CapyDFTCoeffsCreate(nbSample);
  coeffs.range.min = 0.0;
  coeffs.range.max = (double)(nbSample - 1);
  FastFourierTransformRec(nbSample, input, coeffs.vals);

  // Free memory
  free(input);
  return coeffs;
}

// Free the memory used by a CapyDFT
static void Destruct(void) {

  // Nothing to do
}

// Create a CapyDFT
// Output:
//   Return a CapyDFT
CapyDFT CapyDFTCreate(void) {
  return (CapyDFT){
    .destruct = Destruct,
    .fftPolyFromCoeffToVal = FftPolyFromCoeffToVal,
    .fftPolyFromValToCoeff = FftPolyFromValToCoeff,
    .fftFun = FftFun,
    .fftSamples = FftSamples,
  };
}

// Allocate memory for a new CapyDFT and create it
// Output:
//   Return a pointer to the new CapyDFT, or NULL if no memory was obtained
// Exception:
//   May raise CapyExc_MallocFailed.
CapyDFT* CapyDFTAlloc(void) {
  CapyDFT* dft = NULL;
  safeMalloc(dft, 1);

  // Only initialise the instance if the allocation succeeded
  if(dft != NULL) *dft = CapyDFTCreate();
  return dft;
}

// Free the memory used by a CapyDFT* and reset '*that' to NULL
// Input:
//   that: a pointer to the CapyDFT to free; nothing happens if 'that' or
//         '*that' is NULL
void CapyDFTFree(CapyDFT** const that) {
  if(that != NULL && *that != NULL) {

    // Destruct the instance before releasing its memory
    $(*that, destruct)();
    free(*that);
    *that = NULL;
  }
}

// Evaluate the Fourier series for given DFT2D coefficients
// Inputs:
//   in: the input of the evaluation (currently unused)
//   out: out[0] receives the result of the evaluation
// TODO: not implemented yet. CapyExc_UndefinedExecution is raised, then
// the output is set to zero.
static void DFT2DCoeffsEval(
  double const* const in,
        double* const out) {
  (void)in;
  raiseExc(CapyExc_UndefinedExecution);
  *out = 0.0;
}

// Convert the coefficients into a visualisation of the amplitude
// Inputs:
//   center: if true, center the null frequency (the coefficient of index 0
//           is moved to the pixel of index nb/2 along each axis)
//   logScale: if true, compress the dynamic of the normalised image by
//             raising each value to the power 0.1
// Output:
//   Return a normalised greyscale image (stored as an rgb image whose
//   three channels are equal) with one pixel per coefficient. The value of
//   a pixel before normalisation is the amplitude, i.e. the magnitude, of
//   the corresponding complex coefficient.
static CapyImg* ToAmplitudeImg(
  bool const center,
  bool const logScale) {
  methodOf(CapyDFT2DCoeffs);

  // Create the result image
  CapyImgDims dims = {
    .width = (CapyImgDims_t)(that->nb[0]),
    .height = (CapyImgDims_t)(that->nb[1])
  };
  CapyImg* img = CapyImgAlloc(capyImgMode_rgb, dims);

  // Loop on the pixels
  forEach(pixel, img->iter) {

    // Get the corresponding coefficients, applying centering if required
    size_t idx[2] = {0};
    loop(i, 2) {
      idx[i] = (size_t)(pixel.pos.coords[i]);
      if(center) {
        size_t halfNb = that->nb[i] / 2;
        if(idx[i] < halfNb) idx[i] += that->nb[i] - halfNb;
        else idx[i] -= halfNb;
      }
    }

    // Copy the amplitude of the coefficient in the pixel. The amplitude is
    // the magnitude of the complex number, not its real part which
    // depends on the phase and may be negative.
    double val = cabs(that->vals[idx[1] * that->nb[0] + idx[0]]);
    loop(c, 3) pixel.color->rgb[c] = val;
  }

  // Normalise the image
  $(img, normalise)();

  // Apply the scaling if required
  if(logScale) forEach(pixel, img->iter) {
    loop(c, 3) pixel.color->rgb[c] = pow(pixel.color->rgb[c], 0.1);
  }

  // Return the image
  return img;
}

// Convert the coefficients into a visualisation of the phase
// Inputs:
//   center: if true, center the null frequency (the coefficient of index 0
//           is moved to the pixel of index nb/2 along each axis)
// Output:
//   Return a normalised greyscale image (stored as an rgb image whose
//   three channels are equal) with one pixel per coefficient. The value of
//   a pixel before normalisation is the phase, i.e. the argument in
//   [-pi, pi], of the corresponding complex coefficient.
static CapyImg* ToPhaseImg(bool const center) {
  methodOf(CapyDFT2DCoeffs);

  // Create the result image
  CapyImgDims dims = {
    .width = (CapyImgDims_t)(that->nb[0]),
    .height = (CapyImgDims_t)(that->nb[1])
  };
  CapyImg* img = CapyImgAlloc(capyImgMode_rgb, dims);

  // Loop on the pixels
  forEach(pixel, img->iter) {

    // Get the corresponding coefficients, applying centering if required
    size_t idx[2] = {0};
    loop(i, 2) {
      idx[i] = (size_t)(pixel.pos.coords[i]);
      if(center) {
        size_t halfNb = that->nb[i] / 2;
        if(idx[i] < halfNb) idx[i] += that->nb[i] - halfNb;
        else idx[i] -= halfNb;
      }
    }

    // Copy the phase of the coefficient in the pixel. The phase is the
    // argument of the complex number, not its imaginary part which also
    // depends on the magnitude.
    double val = carg(that->vals[idx[1] * that->nb[0] + idx[0]]);
    loop(c, 3) pixel.color->rgb[c] = val;
  }

  // Normalise the image
  $(img, normalise)();

  // Return the image
  return img;
}

// Convert the coefficients into a periodogram
// Inputs:
//   center: if true, center the null frequency (the coefficient of index 0
//           is moved to the pixel of index nb/2 along each axis)
//   logScale: if true, compress the dynamic of the normalised image by
//             raising each value to the power 0.1
// Output:
//   Return a normalised greyscale image (stored as an rgb image whose
//   three channels are equal) with one pixel per coefficient. The value of
//   a pixel before normalisation is the squared magnitude of the
//   corresponding coefficient.
static CapyImg* ToPeriodogramImg(
  bool const center,
  bool const logScale) {
  methodOf(CapyDFT2DCoeffs);
  size_t const nbCols = that->nb[0];

  // One pixel per coefficient
  CapyImgDims dims = {
    .width = (CapyImgDims_t)(that->nb[0]),
    .height = (CapyImgDims_t)(that->nb[1])
  };
  CapyImg* img = CapyImgAlloc(capyImgMode_rgb, dims);

  // Fill each pixel with the power of its coefficient
  forEach(pixel, img->iter) {

    // Indices of the coefficient along each axis, shifted by half the
    // number of coefficients (with wrapping) if centering is required
    size_t idx[2] = {0};
    loop(i, 2) {
      size_t const coord = (size_t)(pixel.pos.coords[i]);
      size_t const halfNb = that->nb[i] / 2;
      if(!center) idx[i] = coord;
      else if(coord >= halfNb) idx[i] = coord - halfNb;
      else idx[i] = coord + (that->nb[i] - halfNb);
    }

    // Squared magnitude of the coefficient
    double const magnitude = cabs(that->vals[idx[1] * nbCols + idx[0]]);
    double const power = magnitude * magnitude;
    loop(c, 3) pixel.color->rgb[c] = power;
  }

  // Normalise the image, then apply the scaling if required
  $(img, normalise)();
  if(logScale) forEach(pixel, img->iter) {
    loop(c, 3) pixel.color->rgb[c] = pow(pixel.color->rgb[c], 0.1);
  }
  return img;
}

// Destruct a CapyDFT2DCoeffs
// Input:
//   that: a pointer to the CapyDFT2DCoeffs to free
static void DFT2DCoeffsDestruct(void) {
  methodOf(CapyDFT2DCoeffs);
  $(that, destructCapyMathFun)();
  loop(i, 2) $(that->range + i, destruct)();
  free(that->vals);
  that->vals = NULL;
  loop(i, 2) that->nb[i] = 0;
}

// Create a CapyDFTCoeffs
// Input:
//   nb: the number of coefficients
// Output:
//   Return a CapyDFTCoeffs
CapyDFT2DCoeffs CapyDFT2DCoeffsCreate(size_t const nb[2]) {
  CapyDFT2DCoeffs that = {0};
  CapyInherits(that, CapyMathFun, (2, 1));
  that.destruct = DFT2DCoeffsDestruct;
  that.eval = DFT2DCoeffsEval;
  that.toAmplitudeImg = ToAmplitudeImg;
  that.toPhaseImg = ToPhaseImg;
  that.toPeriodogramImg = ToPeriodogramImg;
  loop(i, 2) that.range[i] = CapyRangeDoubleCreate(0, 1);
  safeMalloc(that.vals, nb[0] * nb[1]);
  loop(i, nb[0] * nb[1]) that.vals[i] = 0.0;
  loop(i, 2) that.nb[i] = nb[i];
  return that;
}

// Allocate memory for a new CapyDFT2DCoeffs and create it
// Input:
//   nb: the number of coefficients along each of the two axes
// Output:
//   Return a pointer to the new CapyDFT2DCoeffs, to be released with
//   CapyDFT2DCoeffsFree
CapyDFT2DCoeffs* CapyDFT2DCoeffsAlloc(size_t const nb[2]) {
  CapyDFT2DCoeffs* coeffs = NULL;
  safeMalloc(coeffs, 1);
  assert(coeffs != NULL);
  *coeffs = CapyDFT2DCoeffsCreate(nb);
  return coeffs;
}

// Free the memory used by a CapyDFT2DCoeffs* and reset '*that' to NULL
// Input:
//   that: a pointer to the CapyDFT2DCoeffs to free; nothing happens if
//         'that' or '*that' is NULL
void CapyDFT2DCoeffsFree(CapyDFT2DCoeffs** const that) {
  if(that != NULL && *that != NULL) {

    // Destruct the instance before releasing its memory
    $(*that, destruct)();
    free(*that);
    *that = NULL;
  }
}

// Calculate the 2D Discrete Fourier Transform for a given image
// Input:
//   img: the img to transform. Only the first channel (rgb[0]) of each
//        pixel is used, as 2*v-1. Both dimensions are expected to be
//        powers of 2, as required by the recursive FFT.
// Output:
//   Return the DFT coefficients as a CapyDFT2DCoeffs with one coefficient
//   per pixel, stored row by row (the coefficient of column c and row r is
//   at index r*width+c), and its ranges set to [0, width-1] and
//   [0, height-1]. If the image has a null width or height the
//   CapyDFT2DCoeffs is returned as created, without any transform.
static CapyDFT2DCoeffs FftImage(CapyImg const* const img) {
  size_t const width = img->dims.width;
  size_t const height = img->dims.height;

  // Create the result DFT2DCoeffs
  size_t nbCoeffs[2] = {width, height};
  CapyDFT2DCoeffs coeffs = CapyDFT2DCoeffsCreate(nbCoeffs);
  loop(i, 2) coeffs.range[i].max = (double)(nbCoeffs[i] - 1);

  // An empty image has nothing to transform
  if(width == 0 || height == 0) return coeffs;

  // Variable to memorise the DFT coefficients per column. They are stored
  // contiguously on the heap (the coefficients of column iCol start at
  // index iCol*height) rather than in one allocation per column referenced
  // from a stack array sized by the width of the image.
  double complex* colCoeffs = NULL;
  safeMalloc(colCoeffs, width * height);
  assert(colCoeffs != NULL);
  loop(i, width * height) colCoeffs[i] = 0.0;

  // Input of one transform, large enough for a column as well as a row
  double complex* x = NULL;
  safeMalloc(x, (width > height ? width : height));
  assert(x != NULL);

  // Calculate the DFT per column
  loop(iCol, width) {
    loop(iRow, height) {
      CapyImgPos pos = {.x = (CapyImgPos_t)iCol, .y = (CapyImgPos_t)iRow};
      CapyColorData* color = $(img, getColor)(&pos);
      x[iRow] = 2.0 * color->rgb[0] - 1.0;
    }
    FastFourierTransformRec(height, x, colCoeffs + iCol * height);
  }

  // Calculate the DFT per rows, from the coefficients of the columns
  loop(iRow, height) {
    loop(iCol, width) x[iCol] = colCoeffs[iCol * height + iRow];
    FastFourierTransformRec(width, x, coeffs.vals + iRow * width);
  }

  // Free memory
  free(x);
  free(colCoeffs);

  // Return the coefficients
  return coeffs;
}

// Free the memory used by a CapyDFT2D
static void Destruct2D(void) {

  // Nothing to do
}

// Create a CapyDFT2D
// Output:
//   Return a CapyDFT2D
CapyDFT2D CapyDFT2DCreate(void) {
  return (CapyDFT2D){
    .destruct = Destruct2D,
    .fftImage = FftImage,
  };
}

// Allocate memory for a new CapyDFT2D and create it
// Output:
//   Return a pointer to the new CapyDFT2D, or NULL if no memory was
//   obtained
// Exception:
//   May raise CapyExc_MallocFailed.
CapyDFT2D* CapyDFT2DAlloc(void) {
  CapyDFT2D* dft = NULL;
  safeMalloc(dft, 1);

  // Only initialise the instance if the allocation succeeded
  if(dft != NULL) *dft = CapyDFT2DCreate();
  return dft;
}

// Free the memory used by a CapyDFT2D* and reset '*that' to NULL
// Input:
//   that: a pointer to the CapyDFT2D to free; nothing happens if 'that' or
//         '*that' is NULL
void CapyDFT2DFree(CapyDFT2D** const that) {
  if(that != NULL && *that != NULL) {

    // Destruct the instance before releasing its memory
    $(*that, destruct)();
    free(*that);
    *that = NULL;
  }
}
