// ---------------------------- kfoldCrossValid.c ---------------------------
/*
    LibCapy - a general purpose library of C functions and data structures
    Copyright (C) 2021-2025 Pascal Baillehache baillehache.pascal@gmail.com
    https://baillehachepascal.dev
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "kfoldCrossValid.h"
#include "random.h"

// Free the memory used by a CapyKfoldCrossValidResPredictor
// The per-split evaluations (training and validation) are freed, then the
// arrays holding them. The arrays are reset to NULL afterward so that calling
// the destructor a second time does nothing.
static void KfoldCrossValidResPredictorDestruct(void) {
  methodOf(CapyKfoldCrossValidResPredictor);
  loop(i, 2) {

    // The array may be NULL if its allocation failed at creation
    if(that->eval[i] == NULL) continue;

    // NOTE(review): assumes CapyPredictorEvaluationFree ignores NULL entries
    // (splits which were never evaluated), like the other *Free functions of
    // the library - TODO confirm
    loop(iSplit, that->k) {
      CapyPredictorEvaluationFree(that->eval[i] + iSplit);
    }
    free(that->eval[i]);
    that->eval[i] = NULL;
  }
}

// Create a CapyKfoldCrossValidResPredictor
// Input:
//   k: the number of fold
// Output:
//   Return a CapyKfoldCrossValidResPredictor
CapyKfoldCrossValidResPredictor CapyKfoldCrossValidResPredictorCreate(
  size_t const k) {
  CapyKfoldCrossValidResPredictor that;
  that.k = k;
  that.destruct = KfoldCrossValidResPredictorDestruct;
  loop(i, 2) safeMalloc(that.eval[i], k);
  loop(i, 2) loop(j, 3) that.accuracy[i][j] = 0.0;
  return that;
}

// Split a dataset into k pairs of training dataset and evaluation dataset
// Input:
//      that: the CapyKfoldCrossValid
//   dataset: the original dataset
// Output:
//   Return an array of 2*k datasets which are copies of the original dataset
//   with the nbRow and rows properties modified to reflect the splitting. The
//   array is arranged as follows: arr[2i] is the training dataset for split
//   'i' and arr[2i+1] is the evaluation dataset for split 'i'. Each
//   arr[j].rows is a newly allocated array (NULL if the split is empty) whose
//   elements are copies of the original rows. The caller must free arr[j].rows
//   and arr.
//   If that->streamSplit is NULL, rows are attributed uniformly to the splits
//   and then shuffled with a generator seeded with that->seed. Otherwise the
//   splits are read from that->streamSplit.
//   Return NULL on failure (k == 0, invalid split stream, memory allocation
//   failure). In that case nothing allocated here is leaked.
// Exception:
//   May raise CapyExc_InvalidStream.
static CapyDataset* Split(
  CapyKfoldCrossValid const* const that,
          CapyDataset const* const dataset) {

  // Without folds, 'iRow % that->k' below would be a division by zero
  if(that->k == 0) return NULL;

  // Allocate memory for the split datasets
  size_t const nbSplitDataset = 2 * that->k;
  CapyDataset* splitDataset = NULL;
  safeMalloc(splitDataset, nbSplitDataset);
  if(!splitDataset) return NULL;

  // Clone the split datasets. Their rows are reset to NULL until they are
  // allocated below, so that the cleanup on failure only frees memory
  // allocated by this function.
  loop(i, nbSplitDataset) {
    splitDataset[i] = *dataset;
    splitDataset[i].rows = NULL;
  }

  // Create the array of split indices for each row in the dataset.
  CapyArrSize idxSplit = CapyArrSizeCreate(dataset->nbRow);

  // If there is no splits definition
  if(that->streamSplit == NULL) {

    // First, attribute a split to each row uniformly, then shuffle the indices.
    loop(iRow, dataset->nbRow) {
      size_t idx = iRow % that->k;
      $(&idxSplit, set)(iRow, &idx);
    }
    CapyRandom rng = CapyRandomCreate(that->seed);
    $(&idxSplit, shuffle)(&rng);
    $(&rng, destruct)();

  // Else, there is a splits definition
  } else {

    // Read the header. The values are initialised so that a failed read
    // is caught by the consistency check rather than reading garbage.
    size_t header[3] = {0, 0, 0};
    safeFScanf(
      that->streamSplit, "%zu %zu %zu\n", header, header + 1, header + 2);
    if(header[0] != that->k || (header[1] + header[2]) != dataset->nbRow) {
      raiseExc(CapyExc_InvalidStream);
      goto failure;
    }

    // Skip the 9 following lines
    char buffer = ' ';
    loop(i, 9) {
      while(buffer != '\n' && !feof(that->streamSplit)) {
        safeFScanf(that->streamSplit, "%c", &buffer);
      }
      buffer = ' ';
    }

    // Load the indices from the last line. The i-th index read is the row
    // attributed to split (i % k).
    loop(iRow, dataset->nbRow) {

      // Initialised to an invalid index so that a failed read is detected
      size_t idxRow = dataset->nbRow;
      safeFScanf(that->streamSplit, "%zu ", &idxRow);
      if(idxRow >= dataset->nbRow) {
        raiseExc(CapyExc_InvalidStream);
        goto failure;
      }
      size_t idx = iRow % that->k;
      $(&idxSplit, set)(idxRow, &idx);
    }
  }

  // Loop on the split
  loop(iSplit, that->k) {

    // Count the number of training/validating rows for that split
    size_t nb[2] = {0, 0};
    loop(iRow, dataset->nbRow) {
      if(idxSplit.data[iRow] != iSplit) ++(nb[0]); else ++(nb[1]);
    }

    // Allocate memory for the rows of that split's training/validating
    // dataset
    loop(i, (size_t)2) {
      splitDataset[2 * iSplit + i].nbRow = nb[i];

      // An empty split (k greater than the number of rows) keeps NULL rows:
      // malloc(0) may legitimately return NULL, which must not be mistaken
      // for an allocation failure
      if(nb[i] == 0) continue;
      safeMalloc(splitDataset[2 * iSplit + i].rows, nb[i]);
      if(!(splitDataset[2 * iSplit + i].rows)) goto failure;
    }

    // Clone the rows
    size_t jRow[2] = {0, 0};
    loop(iRow, dataset->nbRow) {
      if(idxSplit.data[iRow] != iSplit) {
        splitDataset[2 * iSplit].rows[jRow[0]] = dataset->rows[iRow];
        ++(jRow[0]);
      } else {
        splitDataset[2 * iSplit + 1].rows[jRow[1]] = dataset->rows[iRow];
        ++(jRow[1]);
      }
    }
  }

  // Free memory
  $(&idxSplit, destruct)();

  // Return the split datasets
  return splitDataset;

  // Release everything allocated by this function on failure
failure:
  $(&idxSplit, destruct)();
  loop(i, nbSplitDataset) free(splitDataset[i].rows);
  free(splitDataset);
  return NULL;
}

// Argument of EvalPredictorThread, the worker used by EvalPredictor.
// EvalPredictor runs the k-fold cross validation for a predictor and a
// dataset: the dataset is split into k folds (the original dataset is not
// modified), the predictor is trained on all combinations of (k-1) folds and
// evaluated on the remaining fold, with one thread per split.
// Fields:
//           that: the CapyKfoldCrossValid running the validation
//         iSplit: index of the split processed by the thread
//      predictor: the predictor the thread trains and evaluates
//   splitDataset: the array of split datasets (see Split)
//         result: where the thread stores the evaluations of its split
//      semaphore: serialises the verbose output of the threads
struct EvalPredictorArg {
  CapyKfoldCrossValid* that;
  size_t iSplit;
  CapyPredictor* predictor;
  CapyDataset* splitDataset;
  CapyKfoldCrossValidResPredictor* result;
  sem_t* semaphore;
};
typedef struct EvalPredictorArg EvalPredictorArg;

static void* EvalPredictorThread(void* arg) {

  // Cast the argument
  CapyKfoldCrossValid* that = ((EvalPredictorArg*)arg)->that;
  size_t iSplit = ((EvalPredictorArg*)arg)->iSplit;
  CapyPredictor* predictor = ((EvalPredictorArg*)arg)->predictor;
  CapyDataset* splitDataset = ((EvalPredictorArg*)arg)->splitDataset;
  CapyKfoldCrossValidResPredictor* result = ((EvalPredictorArg*)arg)->result;
  sem_t* semaphore = ((EvalPredictorArg*)arg)->semaphore;

  // Train and evaluate the predictor on the training splits
  $(predictor, train)(splitDataset + 2 * iSplit);
  result->evalTraining[iSplit] =
    $(predictor, evaluate)(splitDataset + 2 * iSplit);

  // Evaluate the trained predictor on the evaluation split
  result->evalValidation[iSplit] =
    $(predictor, evaluate)(splitDataset + 2 * iSplit + 1);

  // If we are in verbose mode, print some info
  if(that->verbose && that->stream) {
    sem_wait(semaphore);
    fprintf(
      that->stream,
      "%lu-fold cross valid, split #%lu, "
      "acc. train %.3lf, valid %.3lf\n",
      that->k, iSplit,
      result->evalTraining[iSplit]->accuracies[that->accType],
      result->evalValidation[iSplit]->accuracies[that->accType]);
    fflush(that->stream);
    sem_post(semaphore);
  }

  // Free memory
  free(splitDataset[2 * iSplit].rows);
  free(splitDataset[2 * iSplit + 1].rows);

  // Nothing to return but necessary
  return NULL;
}

// Run the k-fold cross validation for a predictor and a dataset
// Input:
//   predictor: the predictor
//     dataset: the dataset
// Output:
//   The dataset is split into k folds (the original dataset is not modified).
//   A clone of the predictor is trained on each combination of (k-1) folds
//   and evaluated on the remaining fold, one thread per split. The predictor
//   itself is not trained.
//   Return the per-split evaluations and, for the training (index 0) and
//   evaluation (index 1) splits, the accuracy (measured with that->accType)
//   as [min, mean, max]. If k is 0 or the dataset can't be split, the
//   returned result contains no evaluation and all accuracies are 0.0.
// Exception:
//   May raise CapyExc_ForkFailed, CapyExc_InvalidStream,
//   CapyExc_MallocFailed.
static CapyKfoldCrossValidResPredictor EvalPredictor(
      CapyPredictor* const predictor,
  CapyDataset const* const dataset) {
  methodOf(CapyKfoldCrossValid);

  // Allocate memory for the result
  CapyKfoldCrossValidResPredictor result =
    CapyKfoldCrossValidResPredictorCreate(that->k);

  // Nothing to evaluate without folds. This also avoids zero-length VLAs
  // below and a division by zero when averaging the accuracies.
  if(that->k == 0) return result;

  // Split the dataset
  CapyDataset* splitDataset = Split(that, dataset);
  if(!splitDataset) return result;

  // Clone the predictor to be able to do multithreading without the training
  // leaking from one thread to the other
  size_t nbThread = that->k;
  CapyPredictor** clone = NULL;
  safeMalloc(clone, nbThread);
  if(!clone) {
    loop(i, 2 * nbThread) free(splitDataset[i].rows);
    free(splitDataset);
    return result;
  }
  loop(i, nbThread) clone[i] = $(predictor, clone)();

  // Process the splits in parallel
  pthread_t thread[nbThread];
  bool isThreadCreated[nbThread];
  sem_t semaphore;

  // pshared = 0: the semaphore is only shared by the threads of this process
  // (it lives on this function's stack, not in shared memory)
  sem_init(&semaphore, 0, 1);
  EvalPredictorArg threadArgs[nbThread];
  loop(iSplit, nbThread) {
    threadArgs[iSplit] = (EvalPredictorArg){
      .that = that,
      .iSplit = iSplit,
      .predictor = clone[iSplit],
      .splitDataset = splitDataset,
      .result = &result,
      .semaphore = &semaphore,
    };
    int ret =
      pthread_create(
        thread + iSplit, NULL, EvalPredictorThread, threadArgs + iSplit);
    isThreadCreated[iSplit] = (ret == 0);
    if(ret != 0) {
      raiseExc(CapyExc_ForkFailed);

      // Process the split in the calling thread instead, so that its results
      // are available and its rows are freed
      EvalPredictorThread(threadArgs + iSplit);
    }
  }

  // Wait for the threads to terminate. Only threads which were actually
  // created can be joined.
  loop(iSplit, nbThread) {
    if(!isThreadCreated[iSplit]) continue;
    int ret = pthread_join(thread[iSplit], NULL);
    if(ret != 0) raiseExc(CapyExc_ForkFailed);
  }

  // Update the accuracy: [0] is the min, [1] the mean, [2] the max
  loop(iSplit, that->k) {
    loop(i, 2) {
      double acc = $(result.eval[i][iSplit], getAccuracy)(that->accType);
      if(iSplit == 0 || result.accuracy[i][0] > acc) {
        result.accuracy[i][0] = acc;
      }
      result.accuracy[i][1] += acc;
      if(iSplit == 0 || result.accuracy[i][2] < acc) {
        result.accuracy[i][2] = acc;
      }
    }
  }
  loop(i, 2) result.accuracy[i][1] /= (double)(that->k);

  // Free memory (the rows of the split datasets were freed by the threads)
  free(splitDataset);
  sem_destroy(&semaphore);
  loop(i, that->k) CapyPredictorFree(clone + i);
  free(clone);

  // Return the result
  return result;
}

// Destructor of a CapyKfoldCrossValid.
// CapyKfoldCrossValidCreate allocates nothing, so there is nothing to
// release here. The streams referenced by the instance are left untouched.
static void Destruct(void) {
  // Intentionally empty
}

// Create a CapyKfoldCrossValid
// Input:
//   k: the number of fold
// Output:
//   Return a CapyKfoldCrossValid
CapyKfoldCrossValid CapyKfoldCrossValidCreate(size_t const k) {
  return (CapyKfoldCrossValid){
    .k = k,
    .verbose = false,
    .stream = stdout,
    .streamSplit = NULL,
    .seed = 0,
    .accType = capyPredictorAccuracyMeasure_mae,
    .destruct = Destruct,
    .evalPredictor = EvalPredictor,
  };
}

// Allocate memory for a new CapyKfoldCrossValid and create it
// Input:
//   k: the number of fold
// Output:
//   Return the newly allocated CapyKfoldCrossValid, or NULL if the
//   allocation failed. The caller frees it with CapyKfoldCrossValidFree.
// Exception:
//   May raise CapyExc_MallocFailed.
CapyKfoldCrossValid* CapyKfoldCrossValidAlloc(size_t const k) {
  CapyKfoldCrossValid* kfold = NULL;
  safeMalloc(kfold, 1);
  if(kfold != NULL) *kfold = CapyKfoldCrossValidCreate(k);
  return kfold;
}

// Free the memory used by a CapyKfoldCrossValid* and reset '*that' to NULL
// Input:
//   that: a pointer to the CapyKfoldCrossValid to free
// Output:
//   Does nothing if 'that' or '*that' is NULL. Otherwise destructs the
//   instance, frees it and sets '*that' to NULL.
void CapyKfoldCrossValidFree(CapyKfoldCrossValid** const that) {
  if(that != NULL && *that != NULL) {
    $(*that, destruct)();
    free(*that);
    *that = NULL;
  }
}
