// ---------------------------- graphplotter.c ---------------------------
/*
    LibCapy - a general purpose library of C functions and data structures
    Copyright (C) 2021-2025 Pascal Baillehache baillehache.pascal@gmail.com
    https://baillehachepascal.dev
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "graphplotter.h"

// Compute the rectangle, in pixel coordinates, where the data of a graph is
// drawn: the image minus 'margin' on every side, minus an extra
// 'marginLegend' on the left and bottom sides (room for the legend).
// The far corner is clamped to be non negative and the near corner is
// clamped to never exceed the far corner, so the rectangle is never
// inverted even for tiny images.
// Input:
//   img: the image
// Output:
//   Return a newly allocated CapyRectangle (to be freed by the caller with
//   CapyRectangleFree).
static CapyRectangle* GetBoundingBoxDataArea(CapyImg const* const img) {
  methodOf(CapyGraphPlotter);
  CapyRectangle* area = CapyRectangleAlloc();
  CapyImgPos_t const imgWidth = (CapyImgPos_t)(img->dims.width);
  CapyImgPos_t const imgHeight = (CapyImgPos_t)(img->dims.height);

  // Far (right/bottom) corner, never below 0
  area->corners[1].x = imgWidth - that->margin;
  area->corners[1].y = imgHeight - (that->margin + that->marginLegend);
  if(area->corners[1].x < 0) area->corners[1].x = 0;
  if(area->corners[1].y < 0) area->corners[1].y = 0;

  // Near (left/top) corner, never beyond the far corner
  area->corners[0].x = that->margin + that->marginLegend;
  area->corners[0].y = that->margin;
  if(area->corners[0].x > area->corners[1].x) {
    area->corners[0].x = area->corners[1].x;
  }
  if(area->corners[0].y > area->corners[1].y) {
    area->corners[0].y = area->corners[1].y;
  }
  return area;
}

// Draw the legend of a graph: the frame of the data area, plus the titles of
// the abciss and of the ordinate when they are defined.
// Input:
//   img: the image on which to draw
//   legend: data about the legend (a NULL title is not drawn)
// Output:
//   The image is updated with the legend, drawn with that->penLegend.
static void DrawLegend(
                           CapyImg* const img,
  CapyGraphPlotterLegendData const* const legend) {
  methodOf(CapyGraphPlotter);

  // Frame the data area
  CapyRectangle* area = $(that, getBoundingBoxDataArea)(img);
  $(&(that->penLegend), drawRectangle)(area, img);

  // Font shared by both titles, scaled to half the legend margin
  double const halfMarginLegend = that->marginLegend * 0.5;
  CapyFont font = CapyFontCreate();
  loop(i, 2) {
    font.scale.vals[i] = halfMarginLegend;
    font.spacing.vals[i] = 1.1;
  }

  // Abciss title: written left to right, below the data area
  if(legend->titleAbciss != NULL) {
    font.forward.vals[0] = 1;
    font.forward.vals[1] = 0;
    double pos[2] = {
      area->corners[0].x,
      area->corners[1].y + halfMarginLegend
    };
    $(&(that->penLegend), drawText)(pos, legend->titleAbciss, &font, img);
  }

  // Ordinate title: written bottom to top, on the left of the data area
  if(legend->titleOrdinate != NULL) {
    font.forward.vals[0] = 0;
    font.forward.vals[1] = -1;
    double pos[2] = {
      area->corners[0].x - that->marginLegend,
      area->corners[1].y
    };
    $(&(that->penLegend), drawText)(pos, legend->titleOrdinate, &font, img);
  }

  // Release the font and the frame
  $(&font, destruct)();
  CapyRectangleFree(&area);
}

// Plot an histogram.
// Input:
//   vals: values to plot, one bin per value (the array is not modified)
//   dims: dimension of the result image
//   legend: data for the legend
// Output:
//   Return an image of the plot of the histogram. Bin i spans [i, i+1] on the
//   abciss and goes from the minimum value up to its own value on the
//   ordinate. If that->logScaleY is set, strictly positive values (and
//   strictly positive bounds of the value range) are converted with log10.
static CapyImg* PlotHistogram(
               CapyArrDouble const* const vals,
                 CapyImgDims const* const dims,
  CapyGraphPlotterLegendData const* const legend) {
  methodOf(CapyGraphPlotter);

  // Create the result image
  CapyImg* img = CapyImgAlloc(capyImgMode_rgb, *dims);

  // Set the background to white
  $(img, fillWithColor)(&capyColorRGBAWhite);

  // Get the bounding box of the data area
  CapyRectangle* bbox = $(that, getBoundingBoxDataArea)(img);
  CapyRangeDouble rangeXFrom = {.min = 0.0, .max = (double)(vals->size)};
  CapyRangeDouble rangeXTo = {
    .min = bbox->corners[0].x, .max = bbox->corners[1].x
  };

  // Get the range of values. getMin/getMax return pointers into 'vals': copy
  // the values into locals, because converting them with log10 through
  // those pointers would modify the caller's array (and the min and max bins
  // would then be converted a second time in the plotting loop below).
  double minVal = 0.0;
  double maxVal = 0.0;
  CapyComparator cmp = CapyComparatorCreate();
  cmp.eval = (CapyCmpFun)CapyCmpDoubleInc;
  if(vals->size > 0) {
    minVal = *($(vals, getMin)(&cmp));
    maxVal = *($(vals, getMax)(&cmp));
  }
  $(&cmp, destruct)();
  if(that->logScaleY) {
    if(minVal > 0.0) minVal = log10(minVal);
    if(maxVal > 0.0) maxVal = log10(maxVal);
  }
  CapyRangeDouble rangeYFrom = {.min = minVal, .max = maxVal};

  // Avoid a degenerate (zero width) range when all the values are equal or
  // there are no values, which would otherwise break the interpolation
  if(rangeYFrom.max <= rangeYFrom.min) rangeYFrom.max = rangeYFrom.min + 1.0;
  CapyRangeDouble rangeYTo = {
    .min = bbox->corners[1].y, .max = bbox->corners[0].y
  };

  // Plot the bins
  CapyRectangle rectBin = CapyRectangleCreate();
  loop(iVal, vals->size) {
    double val = $(vals, get)(iVal);

    // Same rule as for the range: only strictly positive values are
    // converted, which avoids -inf (empty bins) or NaN (negative values)
    if(that->logScaleY && val > 0.0) val = log10(val);
    loop(i, (size_t)2) {
      rectBin.corners[i].x =
        CapyLerp((double)(iVal + i), &rangeXFrom, &rangeXTo);
      rectBin.corners[i].y =
        CapyLerp((i == 0 ? val : minVal), &rangeYFrom, &rangeYTo);
    }
    $(&(that->penData), drawFilledRectangle)(&rectBin, img);
    $(&(that->penLegend), drawRectangle)(&rectBin, img);
  }
  $(&rectBin, destruct)();
  $(that, drawLegend)(img, legend);

  // Free memory
  CapyRectangleFree(&bbox);

  // Return the result image
  return img;
}

// Plot the density distribution for one field from a CapyDataset, as an
// histogram.
// Input:
//   dataset: the CapyDataset
//   iField: the index of the plotted field
//   dims: the dimension of the graph
//   nbBin: number of bins for the density distribution
// Output:
//   Return a CapyImg. The abciss corresponds to the bins from the min to the
//   max value of the plotted field (its label is used as the abciss title)
//   and the ordinate corresponds to the number of records in the dataset for
//   that bin.
static CapyImg* PlotDensityDistributions(
  CapyDataset const* const dataset,
              size_t const iField,
  CapyImgDims const* const dims,
              size_t const nbBin) {
  methodOf(CapyGraphPlotter);

  // Get the bins
  CapyArrSize* bins = $(dataset, getDistAsBins)(iField, nbBin);

  // Convert to an array of double, as expected by plotHistogram
  CapyArrDouble* vals = CapyArrDoubleAlloc(bins->size);
  loop(i, bins->size) {
    double val = (double)($(bins, get)(i));
    $(vals, set)(i, &val);
  }

  // The bins are not needed anymore once converted (they were leaked before)
  CapyArrSizeFree(&bins);

  // Create the legend
  CapyGraphPlotterLegendData legend = {
    .titleAbciss = dataset->fields[iField].label,
  };

  // Plot the histogram
  CapyImg* img = $(that, plotHistogram)(vals, dims, &legend);

  // Free memory
  CapyArrDoubleFree(&vals);

  // Return the result image
  return img;
}

// Plot the density distribution, as an histogram, of one field from a
// CapyDataset, using only the records whose categorical field has a given
// value.
// Input:
//   dataset: the CapyDataset
//   iCatField: the index of the categorical field for filtering
//   valCatField: the value of the categorical field for filtering
//   iField: the index of the plotted field
//   dims: the dimension of the graph
//   nbBin: number of bins for the density distribution
// Output:
//   Return a CapyImg. The abciss corresponds to the bins from the min to the
//   max value of the plotted field (its label is used as the abciss title)
//   and the ordinate corresponds to the number of records in the dataset for
//   that bin (valCatField is used as the ordinate title). Records are
//   filtered on the value of the categorical field.
static CapyImg* PlotDensityDistributionsGivenCatValue(
  CapyDataset const* const dataset,
              size_t const iCatField,
         char const* const valCatField,
              size_t const iField,
  CapyImgDims const* const dims,
              size_t const nbBin) {
  methodOf(CapyGraphPlotter);

  // Get the bins
  CapyArrSize* bins =
    $(dataset, getDistAsBinsGivenCatValue)(
      iField, nbBin, iCatField, valCatField);

  // Convert to an array of double, as expected by plotHistogram
  CapyArrDouble* vals = CapyArrDoubleAlloc(bins->size);
  loop(i, bins->size) {
    double val = (double)($(bins, get)(i));
    $(vals, set)(i, &val);
  }

  // The bins are not needed anymore once converted (they were leaked before)
  CapyArrSizeFree(&bins);

  // Create the legend
  CapyGraphPlotterLegendData legend = {
    .titleAbciss = dataset->fields[iField].label,
    .titleOrdinate = valCatField,
  };

  // Plot the histogram
  CapyImg* img = $(that, plotHistogram)(vals, dims, &legend);

  // Free memory
  CapyArrDoubleFree(&vals);

  // Return the result image
  return img;
}

// Plot, as a grid of histograms, the density distributions of all the fields
// of a CapyDataset (except the given categorical one), for each value of a
// given categorical field.
// Input:
//   dataset: the CapyDataset
//   iCatField: the index of the categorical field
//   dims: the dimension of one graph
//   nbBin: number of bins for the density distribution
// Output:
//   Return a CapyImg containing the result graphs, each of size 'dims'.
//   There is one row per value of the categorical field (top-down, in the
//   order of catField->categoryVals) and one column per other field
//   (left-right, in the order of the fields, skipping the categorical one).
//   Each graph is produced by plotDensityDistributionsGivenCatValue: its
//   abciss corresponds to the bins from the min to the max value of the
//   field and its ordinate to the number of records in that bin.
static CapyImg* PlotAllDensityDistributionsGivenCatField(
  CapyDataset const* const dataset,
              size_t const iCatField,
  CapyImgDims const* const dims,
              size_t const nbBin) {
  methodOf(CapyGraphPlotter);

  // Allocate the grid: (nbField - 1) columns, nbCategoryVal rows
  CapyDatasetFieldDesc const* const catField = &(dataset->fields[iCatField]);
  CapyImgDims const gridDims = {
    .width = dims->width * (CapyImgDims_t)(dataset->nbField - 1),
    .height = dims->height * (CapyImgDims_t)(catField->nbCategoryVal)
  };
  CapyImg* grid = CapyImgAlloc(capyImgMode_rgb, gridDims);

  loop(iField, dataset->nbField) {

    // The categorical field has no column of its own
    if(iField == iCatField) continue;

    // Fields after the categorical one are shifted left to fill the gap
    size_t const iCol = (size_t)((iField < iCatField) ? iField : iField - 1);

    loop(iCatVal, catField->nbCategoryVal) {

      // Histogram of the field for the records having that category value
      CapyImg* cell = $(that, plotDensityDistributionsGivenCatValue)(
        dataset, iCatField, catField->categoryVals[iCatVal], iField,
        dims, nbBin);

      // Copy it into its cell of the grid, then release it
      CapyImgPos cellPos = {
        .x = (CapyImgPos_t)(dims->width * iCol),
        .y = (CapyImgPos_t)(dims->height * iCatVal)
      };
      $(cell, pasteInto)(grid, &cellPos);
      CapyImgFree(&cell);
    }
  }
  return grid;
}

// Plot the records of a dataset on a 2D graph using the values of two
// given fields.
// Input:
//   dataset: the CapyDataset
//   iField: the index of the field in abciss
//   jField: the index of the field in ordinate
//   dims: the dimension the graph
// Output:
//   Return a CapyImg containing the result graph. Each record with valid
//   values for both fields is drawn as a point with that->penData.
//   Values of categorical fields are centered in their slot (+0.5). If
//   that->logScales[0] (resp. [1]) is set, the strictly positive values, and
//   the strictly positive bounds of the range, of the abciss (resp.
//   ordinate) field are converted with log10.
static CapyImg* PlotValuesForGivenPairOfFields(
  CapyDataset const* const dataset,
              size_t const iField,
              size_t const jField,
  CapyImgDims const* const dims) {
  methodOf(CapyGraphPlotter);

  // Create the result image
  CapyImg* img = CapyImgAlloc(capyImgMode_rgb, *dims);

  // Set the background to white
  $(img, fillWithColor)(&capyColorRGBAWhite);

  // Draw the legend
  CapyGraphPlotterLegendData legendData = {
    .titleAbciss = dataset->fields[iField].label,
    .titleOrdinate = dataset->fields[jField].label,
  };
  $(that, drawLegend)(img, &legendData);

  // Flag to memorise if we center values (for categorical fields)
  bool flagCenter[2] = {false, false};
  flagCenter[0] = (dataset->fields[iField].type == capyDatasetFieldType_cat);
  flagCenter[1] = (dataset->fields[jField].type == capyDatasetFieldType_cat);

  // Get the bounding box of the data area
  CapyRectangle* bbox = $(that, getBoundingBoxDataArea)(img);

  // Ranges of the field values. They go through the same conversions as the
  // values themselves (log10 first, then centering), otherwise log-scaled
  // values would be mapped with the linear range of the field.
  CapyRangeDouble rangeXFrom = dataset->fields[iField].range;
  CapyRangeDouble rangeYFrom = dataset->fields[jField].range;
  CapyRangeDouble* rangesFrom[2] = {&rangeXFrom, &rangeYFrom};
  loop(i, 2) {
    if(that->logScales[i]) {
      if(rangesFrom[i]->min > 0.0) {
        rangesFrom[i]->min = log10(rangesFrom[i]->min);
      }
      if(rangesFrom[i]->max > 0.0) {
        rangesFrom[i]->max = log10(rangesFrom[i]->max);
      }
    }
    if(flagCenter[i]) rangesFrom[i]->max += 1.0;
  }
  CapyRangeDouble rangeXTo = {
    .min = bbox->corners[0].x, .max = bbox->corners[1].x
  };
  CapyRangeDouble rangeYTo = {
    .min = bbox->corners[1].y, .max = bbox->corners[0].y
  };

  // Loop on the records
  loop(iRow, dataset->nbRow) {

    // Get the pair of values in the row
    bool areValueValid = true;
    double pos[2] = {0, 0};
    try {
      pos[0] = $(dataset, getValAsNum)(iRow, iField);
      pos[1] = $(dataset, getValAsNum)(iRow, jField);
      loop(i, 2) if(that->logScales[i] && pos[i] > 0.0) pos[i] = log10(pos[i]);
    } catch(CapyExc_InvalidParameters) {
      areValueValid = false;
    } endCatch;

    // If the requested fields have non null values in the row
    if(areValueValid) {

      // Plot the row in the graph
      loop(i, 2) if(flagCenter[i]) pos[i] += 0.5;
      pos[0] = CapyLerp(pos[0], &rangeXFrom, &rangeXTo);
      pos[1] = CapyLerp(pos[1], &rangeYFrom, &rangeYTo);
      $(&(that->penData), drawPoint)(pos, img);
    }
  }

  // Free memory
  CapyRectangleFree(&bbox);

  // Return the result image
  return img;
}

// Plot the correlation graph between all pairs of fields.
// Input:
//   dataset: the CapyDataset
//   dims: the dimension of one graph
//   typeCorrelation: type of correlation
// Output:
//   Return a CapyImg made of nbField x nbField cells of size 'dims'. The cell
//   at column iField and row jField contains:
//   - if iField < jField (bottom left): a plot of the records of the dataset,
//     iField in abciss and jField in ordinate
//   - if iField > jField (top right): a rectangle filled with a color
//     depending on the correlation between the two fields, with the labels
//     of the two fields and the value of the correlation
//   - if iField == jField (diagonal): nothing for now
// Exception:
//   Raise CapyExc_UndefinedExecution if typeCorrelation is neither
//   capyGraphPlotterTypeCorrelation_pearson nor
//   capyGraphPlotterTypeCorrelation_distance.
static CapyImg* PlotAllCorrelationGraph(
               CapyDataset const* const dataset,
               CapyImgDims const* const dims,
  CapyGraphPlotterTypeCorrelation const typeCorrelation) {
  methodOf(CapyGraphPlotter);

  // Check the type of correlation once, before allocating anything: raising
  // from inside the loop would leak the result image and the vectors
  bool const isPearson =
    (typeCorrelation == capyGraphPlotterTypeCorrelation_pearson);
  bool const isDistance =
    (typeCorrelation == capyGraphPlotterTypeCorrelation_distance);
  if(!isPearson && !isDistance) raiseExc(CapyExc_UndefinedExecution);

  // Create the result image
  CapyImgDims resDims = {
    .width = dims->width * (CapyImgDims_t)(dataset->nbField),
    .height = dims->height * (CapyImgDims_t)(dataset->nbField)
  };
  CapyImg* img = CapyImgAlloc(capyImgMode_rgb, resDims);

  // Loop on the pairs of field
  loop(iField, dataset->nbField) loop(jField, dataset->nbField) {

    // If we are in the bottom left of the plot
    if(iField < jField) {

      // Plot the dataset records for the pair of fields
      CapyImg* imgPlot =
        $(that, plotValuesForGivenPairOfFields)(dataset, iField, jField, dims);

      // Paste the plot in the result image
      CapyImgPos posPlot = {
        .x = (CapyImgPos_t)(dims->width * iField),
        .y = (CapyImgPos_t)(dims->height * jField)
      };
      $(imgPlot, pasteInto)(img, &posPlot);

      // Free memory
      CapyImgFree(&imgPlot);

    // Else, if we are in the top right of the plot
    } else if(iField > jField) {

      // Get the correlation between the two fields, and the first color
      // component of the cell: the pearson correlation (in [-1, 1]) is
      // remapped to [0, 1], the distance correlation (in [0, 1]) is used as is
      CapyVec u = {0};
      CapyVec v = {0};
      $(dataset, getValuesFromTwoFieldsAsVectors)(iField, jField, &u, &v);
      double correlation = 0.0;
      double colorFirstComp = 0.0;
      if(isPearson) {
        correlation = CapyVecGetPearsonCorrelation(&u, &v);
        colorFirstComp = 0.5 + 0.5 * correlation;
      } else if(isDistance) {
        correlation = CapyVecGetDistanceCorrelation(&u, &v);
        colorFirstComp = correlation;
      }

      // Fill the cell with the color: components 0 and 2 vary in opposite
      // directions with the correlation, the others are those of black
      CapyRectangle rectBin = CapyRectangleCreate();
      loop(i, (size_t)2) {
        rectBin.corners[i].x = (double)(dims->width * (iField + i));
        rectBin.corners[i].y = (double)(dims->height * (jField + i));
      }
      CapyColorData backupColor = that->penData.color;
      that->penData.color = capyColorRGBABlack;
      that->penData.color.vals[0] = colorFirstComp;
      that->penData.color.vals[2] = 1.0 - that->penData.color.vals[0];
      $(&(that->penData), drawFilledRectangle)(&rectBin, img);
      that->penData.color = backupColor;
      $(&(that->penLegend), drawRectangle)(&rectBin, img);

      // Write the labels of the fields and the correlation in white
      backupColor = that->penLegend.color;
      that->penLegend.color = capyColorRGBAWhite;
      char* text = strCreate(
        "%s\n%s\n%.2f",
        dataset->fields[iField].label,
        dataset->fields[jField].label,
        correlation);
      double pos[2] = {
        (double)(dims->width) * (0.1 + (double)iField),
        (double)(dims->height) * (0.1 + (double)jField),
      };
      CapyFont font = CapyFontCreate();
      loop(i, 2) font.scale.vals[i] = that->marginLegend * 0.5;
      font.spacing.vals[0] = 1.2;
      font.spacing.vals[1] = 1.5;
      $(&(that->penLegend), drawText)(pos, text, &font, img);
      that->penLegend.color = backupColor;

      // Free memory
      $(&font, destruct)();
      free(text);
      $(&rectBin, destruct)();
      CapyVecDestruct(&u);
      CapyVecDestruct(&v);

    // Else, we are on the diagonal of the plot, nothing here for now
    } else {
    }
  }

  // Return the result image
  return img;
}

// Release the resources owned by a CapyGraphPlotter, i.e. its data pen and
// its legend pen. The CapyGraphPlotter structure itself is not freed.
static void Destruct(void) {
  methodOf(CapyGraphPlotter);
  $(&(that->penLegend), destruct)();
  $(&(that->penData), destruct)();
}

// Create a CapyGraphPlotter with its default settings: a margin of 10 and a
// legend margin of 20 (pixels), data drawn in blue and legend in black.
// Fields not listed in the initializer (e.g. the log scale flags) are
// zero-initialized.
// Output:
//   Return a CapyGraphPlotter (to be released with its destruct method)
CapyGraphPlotter CapyGraphPlotterCreate(void) {
  CapyGraphPlotter plotter = {

    // Geometry
    .margin = 10,
    .marginLegend = 20,

    // Pens for the data and for the legend
    .penData = CapyPenCreate(),
    .penLegend = CapyPenCreate(),

    // Methods
    .destruct = Destruct,
    .getBoundingBoxDataArea = GetBoundingBoxDataArea,
    .drawLegend = DrawLegend,
    .plotHistogram = PlotHistogram,
    .plotDensityDistributions = PlotDensityDistributions,
    .plotDensityDistributionsGivenCatValue =
      PlotDensityDistributionsGivenCatValue,
    .plotAllDensityDistributionsGivenCatField =
      PlotAllDensityDistributionsGivenCatField,
    .plotValuesForGivenPairOfFields = PlotValuesForGivenPairOfFields,
    .plotAllCorrelationGraph = PlotAllCorrelationGraph,
  };

  // Default colors of the pens
  plotter.penLegend.color = capyColorRGBBlack;
  plotter.penData.color = capyColorRGBBlue;
  return plotter;
}

// Allocate a new CapyGraphPlotter on the heap and initialise it with
// CapyGraphPlotterCreate.
// Output:
//   Return the new CapyGraphPlotter (to be freed with CapyGraphPlotterFree),
//   or NULL if the allocation did not succeed.
// Exception:
//   May raise CapyExc_MallocFailed.
CapyGraphPlotter* CapyGraphPlotterAlloc(void) {
  CapyGraphPlotter* plotter = NULL;
  safeMalloc(plotter, 1);
  if(plotter == NULL) return NULL;
  *plotter = CapyGraphPlotterCreate();
  return plotter;
}

// Destruct and free a heap-allocated CapyGraphPlotter, then reset the
// caller's pointer to NULL. Does nothing if 'that' or '*that' is NULL.
// Input:
//   that: address of the pointer to the CapyGraphPlotter to free
void CapyGraphPlotterFree(CapyGraphPlotter** const that) {
  if(that != NULL && *that != NULL) {
    $(*that, destruct)();
    free(*that);
    *that = NULL;
  }
}
