#include "capy.h"
#ifndef FIXTURE
#define FIXTURE
#endif
// Run a 10-fold cross validation of a support vector machine (Gaussian
// kernel) on the iris dataset, using a fold split read from a resource file,
// and compare the reported accuracies with reference values.
CUTEST(test001, "iris dataset, SVM") {

  // Dataset under test
  CapyDataset* irisData = CapyDatasetAlloc();
  $(irisData, loadFromPath)("Resources/iris.csv");

  // Gaussian kernel; it lives on the stack and is handed to the SVM by
  // address, so it is only destructed after the SVM has been freed
  CapySVMKernelGaussian gaussKernel = CapySVMKernelGaussianCreate();
  gaussKernel.gamma = 10.0;

  // Predictor to evaluate
  CapySupportVectorMachine* machine = CapySupportVectorMachineAlloc();
  $(machine, setKernel)((CapySVMKernel*)&gaussKernel);
  machine->verbose = false;
  machine->seed = 0;
  machine->coeffRelax = 100.0;

  // Stream providing the split used by the cross validation
  CapyStreamIo splitFile = CapyStreamIoCreate();
  $(&splitFile, open)("Resources/iris_split.txt", "r");

  // 10-fold cross validation measuring plain accuracy
  CapyKfoldCrossValid* kfold = CapyKfoldCrossValidAlloc(10);
  kfold->accType = capyPredictorAccuracyMeasure_accuracy;
  kfold->verbose = false;
  kfold->streamSplit = splitFile.stream;

  // Evaluate the predictor
  CapyKfoldCrossValidResPredictor cvResult =
    $(kfold, evalPredictor)((CapyPredictor*)machine, irisData);

  // Reference accuracies, indexed like cvResult.accuracy
  // NOTE(review): the three columns look like min/avg/max - confirm against
  // the definition of CapyKfoldCrossValidResPredictor
  double const expected[2][3] = {
    {0.666667, 0.666667, 0.666667},
    {0.533333, 0.613333, 0.666667},
  };
  loop(iRow, 2) loop(iCol, 3) {
    double const want = expected[iRow][iCol];
    CUTEST_ASSERT(
      fabs(cvResult.accuracy[iRow][iCol] - want) < 0.001,
      "evalPredictor failed, [%d][%d]=%lf != %lf",
      iRow, iCol, cvResult.accuracy[iRow][iCol], want);
  }

  // Release everything
  $(&splitFile, destruct)();
  $(&cvResult, destruct)();
  CapyKfoldCrossValidFree(&kfold);
  CapySupportVectorMachineFree(&machine);
  $(&gaussKernel, destruct)();
  CapyDatasetFree(&irisData);
}

// Check the layout of CapyKfoldCrossValidResPredictor: the named fields
// evalTraining and evalValidation must sit at the same offsets as eval[0]
// and eval[1], so that both ways of accessing the data address the same
// memory (i.e. no padding shifts one view relative to the other).
CUTEST(test002, "Padding (1)") {

  // offsetof() yields a size_t, hence the %zu conversion in the messages
  CUTEST_ASSERT(
    offsetof(CapyKfoldCrossValidResPredictor, evalTraining) ==
      offsetof(CapyKfoldCrossValidResPredictor, eval[0]),
    "offset of evalTraining and eval[0] doesn't match %zu!=%zu",
    offsetof(CapyKfoldCrossValidResPredictor, evalTraining),
    offsetof(CapyKfoldCrossValidResPredictor, eval[0]));
  CUTEST_ASSERT(
    offsetof(CapyKfoldCrossValidResPredictor, evalValidation) ==
      offsetof(CapyKfoldCrossValidResPredictor, eval[1]),
    "offset of evalValidation and eval[1] doesn't match %zu!=%zu",
    offsetof(CapyKfoldCrossValidResPredictor, evalValidation),
    offsetof(CapyKfoldCrossValidResPredictor, eval[1]));
}

// Check the layout of CapyKfoldCrossValidResPredictor: the named fields
// accTraining and accValidation must sit at the same offsets as accuracy[0]
// and accuracy[1], so that both ways of accessing the data address the same
// memory (i.e. no padding shifts one view relative to the other).
CUTEST(test003, "Padding (2)") {

  // offsetof() yields a size_t, hence the %zu conversion in the messages
  CUTEST_ASSERT(
    offsetof(CapyKfoldCrossValidResPredictor, accTraining) ==
      offsetof(CapyKfoldCrossValidResPredictor, accuracy[0]),
    "offset of accTraining and accuracy[0] doesn't match %zu!=%zu",
    offsetof(CapyKfoldCrossValidResPredictor, accTraining),
    offsetof(CapyKfoldCrossValidResPredictor, accuracy[0]));
  CUTEST_ASSERT(
    offsetof(CapyKfoldCrossValidResPredictor, accValidation) ==
      offsetof(CapyKfoldCrossValidResPredictor, accuracy[1]),
    "offset of accValidation and accuracy[1] doesn't match %zu!=%zu",
    offsetof(CapyKfoldCrossValidResPredictor, accValidation),
    offsetof(CapyKfoldCrossValidResPredictor, accuracy[1]));
}

// Run a 10-fold cross validation of a categorical neural network predictor
// (one layer of 8 nodes with SiLU activation) on the iris dataset, using a
// fold split read from a resource file, and compare the reported accuracies
// with reference values.
CUTEST(test005, "iris dataset, NN") {

  // Dataset under test
  CapyDataset* irisData = CapyDatasetAlloc();
  $(irisData, loadFromPath)("Resources/iris.csv");

  // Network definition; the activation lives on the stack and is referenced
  // by address, so it is only destructed after the predictor has been freed
  CapyNNActivationSiLU silu = CapyNNActivationSiLUCreate();
  CapyNNLayerDef netLayers[] = {
    {.nbNode = 8, .activation = (CapyNNActivationFun*)&silu},
  };
  CapyNNModel netModel = {.nbLayer = 1, .layers = netLayers};

  // Predictor to evaluate and its training parameters
  CapyNNPredictor* nnPredictor =
    CapyNNPredictorAlloc(&netModel, capyPredictorType_categorical);
  nnPredictor->verbose = false;
  nnPredictor->seed = 2;
  nnPredictor->timeTraining = 0.0;
  nnPredictor->nbIterTrainMax = 200;
  nnPredictor->batchSize = 150;
  nnPredictor->momentum = 0.01;
  nnPredictor->featureScaling =
    capyPredictorFeatureScaling_minMaxNormalization;

  // Stream providing the split used by the cross validation
  CapyStreamIo splitFile = CapyStreamIoCreate();
  $(&splitFile, open)("Resources/iris_split.txt", "r");

  // 10-fold cross validation measuring plain accuracy
  CapyKfoldCrossValid* kfold = CapyKfoldCrossValidAlloc(10);
  kfold->accType = capyPredictorAccuracyMeasure_accuracy;
  kfold->verbose = false;
  kfold->streamSplit = splitFile.stream;

  // Evaluate the predictor
  CapyKfoldCrossValidResPredictor cvResult =
    $(kfold, evalPredictor)((CapyPredictor*)nnPredictor, irisData);

  // Reference accuracies, indexed like cvResult.accuracy
  // NOTE(review): the three columns look like min/avg/max - confirm against
  // the definition of CapyKfoldCrossValidResPredictor
  double const expected[2][3] = {
    {0.977778, 0.983704, 0.985185},
    {0.933333, 0.973333, 1.0},
  };
  loop(iRow, 2) loop(iCol, 3) {
    double const want = expected[iRow][iCol];
    CUTEST_ASSERT(
      fabs(cvResult.accuracy[iRow][iCol] - want) < 0.001,
      "evalPredictor failed, [%d][%d]=%lf != %lf",
      iRow, iCol, cvResult.accuracy[iRow][iCol], want);
  }

  // Release everything
  $(&splitFile, destruct)();
  $(&cvResult, destruct)();
  CapyKfoldCrossValidFree(&kfold);
  CapyNNPredictorFree(&nnPredictor);
  $(&silu, destruct)();
  CapyDatasetFree(&irisData);
}

