// ------------------- colorCorrectionBezier.c --------------------
/*
    LibCapy - a general purpose library of C functions and data structures
    Copyright (C) 2021-2025 Pascal Baillehache baillehache.pascal@gmail.com
    https://baillehachepascal.dev
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "colorCorrectionBezier.h"
#include "pointCloud.h"

// Correct one pixel's values with the Bezier surface of the correction.
// The values are first mapped onto the surface's input domain, the surface
// is evaluated, and its output is mapped back to the operating color space:
//   - sRGB: values are passed through unchanged in both directions;
//   - L*a*b*: L is scaled by 0.01 (and by 100 on the way back), a and b are
//     mapped with v * 0.005 + 0.5 (and (v - 0.5) * 200 on the way back).
// For any other color space the surface is evaluated at (0, 0, 0) and
// 'corrVals' is left untouched.
// All of 'vals' is read before 'corrVals' is written, so both may point to
// the same array (in-place correction).
// Input:
//   that: the CapyColorCorrBezier
//   vals: the pixel's values
//   corrVals: the corrected pixel's values
// Output:
//   'corrVals' is updated with the corrected values
static void CorrectValues(
  CapyColorCorrBezier* const that,
         double const* const vals,
               double* const corrVals) {

  // Input of the Bezier surface, defaulting to the origin for unsupported
  // color spaces
  double in[3] = {0.0, 0.0, 0.0};
  if(that->colorSpace == capyColorSpace_sRGB) {
    for(int iComp = 0; iComp < 3; ++iComp) in[iComp] = vals[iComp];
  } else if(that->colorSpace == capyColorSpace_LAB) {
    in[0] = vals[0] * 0.01;
    for(int iComp = 1; iComp < 3; ++iComp) {
      in[iComp] = vals[iComp] * 0.005 + 0.5;
    }
  }

  // Evaluate the correction
  double out[3];
  $(that->bezier, eval)(in, out);

  // Bring the result back into the operating color space
  if(that->colorSpace == capyColorSpace_sRGB) {
    for(int iComp = 0; iComp < 3; ++iComp) corrVals[iComp] = out[iComp];
  } else if(that->colorSpace == capyColorSpace_LAB) {
    corrVals[0] = out[0] * 100.0;
    for(int iComp = 1; iComp < 3; ++iComp) {
      corrVals[iComp] = (out[iComp] - 0.5) * 200.0;
    }
  }
}

// Apply the color correction to an image, in place. The image is temporarily
// converted to the operating color space of the correction, each pixel is
// corrected, and the image is converted back to the color space it had on
// entry.
// Input:
//   img: the image to be corrected
static void Apply(CapyImg* const img) {
  methodOf(CapyColorCorrBezier);

  // Switch the image to the operating color space, remembering where it
  // came from
  CapyColorSpace const initialSpace = img->colorSpace;
  $(img, convertToColorSpace)(that->colorSpace);

  // Correct every pixel in place
  size_t const nbPix = $(img, getNbPixels)();
  for(size_t iPix = 0; iPix < nbPix; ++iPix) {
    double* const pixVals = img->pixels[iPix].vals;
    CorrectValues(that, pixVals, pixVals);
  }

  // Restore the image's original color space
  $(img, convertToColorSpace)(initialSpace);
}

// Find the correction to convert from a given color chart to
// a reference color chart
// Input:
//      chart: the original color chart
//   refChart: the reference color chart
// Output:
//   Update the initialFitness and finalFitness properties.
static void Match(
  CapyColorChart const* const chart,
  CapyColorChart const* const refChart) {
  methodOf(CapyColorCorrBezier);

  // Reset the result of matching
  that->initialFitness = 0.0;
  that->finalFitness = 0.0;

  // Ensure the Bezier surface doesn't exist already
  CapyBezierFree(&(that->bezier));

  // Create copy of the charts in the operating color space of the
  // color correction
  CapyColorChart* safeChart = CapyColorChartClone(chart);
  CapyColorChart* safeRefChart = CapyColorChartClone(refChart);
  $(safeChart, convertToColorSpace)(that->colorSpace);
  $(safeRefChart, convertToColorSpace)(that->colorSpace);

  // Calculate the initial fitness
  loop(iRow, chart->nbRow) loop(iCol, chart->nbCol) {
    size_t iSwatch = iRow * safeChart->nbCol + iCol;
    if(safeChart->isInsideSRGBGamut[iSwatch]) {
      double v = 0.0;
      loop(i, 3) {
        double diff =
          safeRefChart->colors[iSwatch].vals[i] -
          safeChart->colors[iSwatch].vals[i];
        v += pow(diff, 2.0);
      }
      that->initialFitness -= sqrt(v);
    }
  }
  that->finalFitness = that->initialFitness;

  // Create the point cloud representing the pairs (image swatch, ref swatch).
  CapyPointCloud* pointCloud = CapyPointCloudAlloc(6);

  // There are as many points as there are swatches, plus anchors artificially
  // added out of the gamut to prevent the correction from diverging. Swatches
  // not in the sRGB gamut aren't used as they aren't reliable.
  size_t nbSwatches = chart->nbCol * chart->nbRow;
  size_t nbValidSwatches = 0;
  loop(iSwatch, nbSwatches) if(chart->isInsideSRGBGamut[iSwatch]) {
    ++nbValidSwatches;
  }
  size_t nbAnchors = 0;
  if(that->nbAnchorPerAxis >= 2) {
    nbAnchors =
      (6 * that->nbAnchorPerAxis * that->nbAnchorPerAxis + 8) -
      12 * that->nbAnchorPerAxis;
  }
  pointCloud->size = nbValidSwatches + nbAnchors;
  safeMalloc(pointCloud->points, pointCloud->size);
  if(!(pointCloud->points)) return;

  // Calculate the order just as there are as many control point as possible
  // without having more control points than valid color swatches, or use
  // teh requested order.
  CapyBezierOrder_t order = 0;
  if(that->order != 0) order = that->order;
  else {
    order =
      (CapyBezierOrder_t)floor(pow((double)nbValidSwatches, 1.0 / 3.0)) - 1;
  }
  if(order < 1) order = 1;

  // Set the points from the swatches
  size_t iPoint = 0;
  loop(iRow, chart->nbRow) loop(iCol, chart->nbCol) {
    size_t iSwatch = iRow * chart->nbCol + iCol;
    if(chart->isInsideSRGBGamut[iSwatch]) {
      pointCloud->points[iPoint] = CapyVecCreate(6);
      loop(i, 3) {
        double vals[2] = {0.0, 0.0};
        if(that->colorSpace == capyColorSpace_sRGB) {
          vals[0] = safeChart->colors[iSwatch].RGB[i];
          vals[1] = safeRefChart->colors[iSwatch].RGB[i];
        } else if(that->colorSpace == capyColorSpace_LAB) {
          if(i == 0) {
            vals[0] = safeChart->colors[iSwatch].LAB[i] * 0.01;
            vals[1] = safeRefChart->colors[iSwatch].LAB[i] * 0.01;
          } else {
            vals[0] = safeChart->colors[iSwatch].LAB[i] * 0.005 + 0.5;
            vals[1] = safeRefChart->colors[iSwatch].LAB[i] * 0.005 + 0.5;
          }
        }
        pointCloud->points[iPoint].vals[i] = vals[0];
        pointCloud->points[iPoint].vals[i + 3] = vals[1];
      }
      ++iPoint;
    }
  }

  // Set the anchor points
  size_t n = that->nbAnchorPerAxis;
  loop(i, n) loop(j, n) loop(k, n) {
    if(
      i == 0 || i == (n - 1) ||
      j == 0 || j == (n - 1) ||
      k == 0 || k == (n - 1)
    ) {
      pointCloud->points[iPoint] = CapyVecCreate(6);
      pointCloud->points[iPoint].vals[0] =
        (double)(3 * i) / (double)(n - 1) - 1.0;
      pointCloud->points[iPoint].vals[1] =
        (double)(3 * j) / (double)(n - 1) - 1.0;
      pointCloud->points[iPoint].vals[2] =
        (double)(3 * k) / (double)(n - 1) - 1.0;
      pointCloud->points[iPoint].vals[3] = pointCloud->points[iPoint].vals[0];
      pointCloud->points[iPoint].vals[4] = pointCloud->points[iPoint].vals[1];
      pointCloud->points[iPoint].vals[5] = pointCloud->points[iPoint].vals[2];
      ++iPoint;
    }
  }

  // Create the approximating bezier, which will be used as the corrector
  that->bezier = $(pointCloud, getApproxBezier)(3, order);

  // Calculate the final fitness
  that->finalFitness = 0.0;
  if(that->bezier != NULL) loop(iRow, chart->nbRow) loop(iCol, chart->nbCol) {
    size_t iSwatch = iRow * safeChart->nbCol + iCol;
    if(chart->isInsideSRGBGamut[iSwatch]) {
      double v = 0.0;
      double out[3] = {0.0, 0.0, 0.0};
      CorrectValues(that, safeChart->colors[iSwatch].vals, out);
      loop(i, 3) {
        double diff = safeRefChart->colors[iSwatch].vals[i] - out[i];
        v += pow(diff, 2.0);
      }
      that->finalFitness -= sqrt(v);
    }
  }

  // Free memory
  CapyPointCloudFree(&pointCloud);
  CapyColorChartFree(&safeChart);
  CapyColorChartFree(&safeRefChart);
}

// Save the correction to a file at a given path, in binary mode.
// Format: the raw bytes of the operating color space (sizeof(CapyColorSpace)
// bytes), followed by the Bezier surface as written by its 'save' method
// (the content written by the 'saveToStream' method).
// Input:
//   path: the path
// Exceptions:
//   May raise CapyExc_StreamOpenError, CapyExc_StreamWriteError
static void SaveToPath(char const* const path) {
  methodOf(CapyColorCorrBezier);

  // Open the stream
  CapyStreamIo stream = CapyStreamIoCreate();
  $(&stream, open)(path, "wb");

  // Save to the stream
  $(that, saveToStream)(&stream);

  // Close the stream
  $(&stream, close)();
}

// Save the correction to a stream (in binary mode).
// Format: the raw bytes of the operating color space (sizeof(CapyColorSpace)
// bytes), followed by the Bezier surface as written by its 'save' method.
// Input:
//   stream: the stream
// Exceptions:
//   May raise CapyExc_StreamWriteError
static void SaveToStream(CapyStreamIo* const stream) {
  methodOf(CapyColorCorrBezier);

  // Write the operating color space, then the Bezier surface.
  // NOTE(review): 'that->bezier' is used unconditionally, but it is NULL
  // after creation and after a failed 'match'; confirm callers only save a
  // correction that has been matched or loaded.
  $(stream, writeBytes)(&(that->colorSpace), sizeof(CapyColorSpace));
  $(that->bezier, save)(stream->stream);
}

// Load the correction from a file at a given path, in binary mode.
// Format: the raw bytes of the operating color space (sizeof(CapyColorSpace)
// bytes), followed by the Bezier surface (the content read by the
// 'loadFromStream' method).
// Input:
//   path: the path
// Exceptions:
//   May raise CapyExc_StreamOpenError, CapyExc_StreamReadError,
//   CapyExc_MallocFailed.
static void LoadFromPath(char const* const path) {
  methodOf(CapyColorCorrBezier);

  // Open the stream
  CapyStreamIo stream = CapyStreamIoCreate();
  $(&stream, open)(path, "rb");

  // Load from the stream
  $(that, loadFromStream)(&stream);

  // Close the stream
  $(&stream, close)();
}

// Load the correction from a stream (in binary mode), in the format written
// by SaveToStream: the raw bytes of the operating color space followed by
// the Bezier surface. The current Bezier surface, if any, is freed and
// replaced by the loaded one.
// Input:
//   stream: the stream
// Exceptions:
//   May raise CapyExc_StreamReadError, CapyExc_MallocFailed.
static void LoadFromStream(CapyStreamIo* const stream) {
  methodOf(CapyColorCorrBezier);

  // Read the operating color space
  $(stream, readBytes)(
    &(that->colorSpace),
    sizeof(CapyColorSpace));

  // Replace the current Bezier surface with the one read from the stream
  CapyBezierFree(&(that->bezier));
  that->bezier = CapyBezierLoad(stream->stream);
}

// Release the resources held by a CapyColorCorrBezier: first those of the
// parent CapyColorCorr, then the Bezier surface owned by the correction.
static void Destruct(void) {
  methodOf(CapyColorCorrBezier);

  // Parent class resources
  $(that, destructCapyColorCorr)();

  // Bezier surface used as the corrector
  CapyBezierFree(&that->bezier);
}

// Create a CapyColorCorrBezier, with no Bezier surface yet, an automatically
// chosen order (0) and 3 anchors per axis.
// Output:
//   Return a CapyColorCorrBezier
CapyColorCorrBezier CapyColorCorrBezierCreate(void) {
  CapyColorCorrBezier that = {
    .nbAnchorPerAxis = 3,
    .order = 0,
    .bezier = NULL,
  };

  // Inherit from CapyColorCorr, then override its methods
  CapyInherits(that, CapyColorCorr, ());

  // Fitting and application of the correction
  that.match = Match;
  that.apply = Apply;

  // Serialisation
  that.saveToPath = SaveToPath;
  that.saveToStream = SaveToStream;
  that.loadFromPath = LoadFromPath;
  that.loadFromStream = LoadFromStream;

  // Destructor
  that.destruct = Destruct;
  return that;
}

// Allocate a new CapyColorCorrBezier on the heap and initialise it with
// CapyColorCorrBezierCreate.
// Output:
//   Return the new CapyColorCorrBezier, or NULL if the allocation failed
// Exception:
//   May raise CapyExc_MallocFailed.
CapyColorCorrBezier* CapyColorCorrBezierAlloc(void) {
  CapyColorCorrBezier* corr = NULL;
  safeMalloc(corr, 1);
  if(corr != NULL) *corr = CapyColorCorrBezierCreate();
  return corr;
}

// Destruct and free a heap-allocated CapyColorCorrBezier, then set the
// caller's pointer to NULL. Does nothing if 'that' or '*that' is NULL.
// Input:
//   that: a pointer to the CapyColorCorrBezier to free
void CapyColorCorrBezierFree(CapyColorCorrBezier** const that) {
  if(that != NULL && *that != NULL) {
    $(*that, destruct)();
    free(*that);
    *that = NULL;
  }
}
