#include "capy.h"
#ifndef FIXTURE
#define FIXTURE
static void FunDivergence(
  double const* const in,
        double* const out) {
  out[0] = 4.0 * in[1] / (in[0] * in[0]);
  out[1] = sin(in[1]);
  out[2] = 3.0;
}

// Test function R^2 -> R^2 used by the integral test:
//   f(in) = (in[0] * in[1], 3)
static void FunIntegral(
  double const* const in,
        double* const out) {
  double const product = in[0] * in[1];
  out[0] = product;
  out[1] = 3.0;
}

// First member of the linear-combination test: R^1 -> R^2,
//   A(in) = (in[0], 2 * in[0])
static void FunCombA(
  double const* const in,
        double* const out) {
  double const x = in[0];
  out[0] = x;
  out[1] = 2.0 * x;
}

// Second member of the linear-combination test: R^1 -> R^2,
//   B(in) = (3 * in[0], 4 * in[0])
static void FunCombB(
  double const* const in,
        double* const out) {
  static double const factors[2] = {3.0, 4.0};
  for (int k = 0; k < 2; ++k) {
    out[k] = in[0] * factors[k];
  }
}

// First member of the linear-regression test: R^2 -> R^1,
//   C(in) = in[0]^2 + 2 * in[1]
static void FunCombC(
  double const* const in,
        double* const out) {
  out[0] = 2.0 * in[1] + in[0] * in[0];
}

// Second member of the linear-regression test: R^2 -> R^1,
//   D(in) = in[1] * in[0] - in[0]
static void FunCombD(
  double const* const in,
        double* const out) {
  out[0] = in[1] * in[0] - in[0];
}

#endif
// Checks that CapyMathFun's named dimension fields dimIn/dimOut sit at the
// same offsets as dims[0]/dims[1], i.e. both spellings address the same
// storage (presumably via a union in the struct definition — confirm in
// capy.h).
CUTEST(test001, "Padding") {
  // offsetof() yields a size_t, so it must be printed with %zu; %lu is a
  // format/argument mismatch on targets where size_t is not unsigned long.
  CUTEST_ASSERT(
    offsetof(CapyMathFun, dimIn) == offsetof(CapyMathFun, dims[0]),
    "offset of dimIn and dims[0] doesn't match %zu!=%zu",
    offsetof(CapyMathFun, dimIn), offsetof(CapyMathFun, dims[0]));
  CUTEST_ASSERT(
    offsetof(CapyMathFun, dimOut) == offsetof(CapyMathFun, dims[1]),
    "offset of dimOut and dims[1] doesn't match %zu!=%zu",
    offsetof(CapyMathFun, dimOut), offsetof(CapyMathFun, dims[1]));
}

// Compares CapyMathFun's divergence of FunDivergence at one point against
// the analytic divergence:
//   d/din0 (4 in1 / in0^2) + d/din1 sin(in1) + d/din2 3
//     = -8 in1 / in0^3 + cos(in1)
CUTEST(test002, "Divergence") {
  CapyMathFun field = CapyMathFunCreate(3, 3);
  field.eval = FunDivergence;
  double point[3] = {1, 0, 0};
  double const computed = $(&field, evalDivergence)(point);
  double const expected =
    -8.0 * point[1] * pow(point[0], -3.0) + cos(point[1]);
  double const error = fabs(computed - expected);
  CUTEST_ASSERT(
    error < 1e-6, "div=%lf != %lf", computed, expected);
  $(&field, destruct)();
}

// Integrates FunIntegral, f(x, y) = (x * y, 3), over [0, 1] x [1, 2] and
// compares with the closed-form results:
//   integral of x * y = (1/2) * (3/2) = 0.75
//   integral of 3     = 3 * 1 * 1     = 3
CUTEST(test003, "EvalIntegral") {
  CapyMathFun integrand = CapyMathFunCreate(2, 2);
  integrand.eval = FunIntegral;
  double result[2] = {0.0, 0.0};
  CapyRangeDouble box[2] = {
    CapyRangeDoubleCreate(0.0, 1.0),
    CapyRangeDoubleCreate(1.0, 2.0),
  };
  $(&integrand, evalIntegral)(box, result);
  double const expected[2] = {0.75, 3.0};
  for (size_t k = 0; k < 2; ++k) {
    CUTEST_ASSERT(
      fabs(result[k] - expected[k]) < 1e-6,
      "integral=%lf != %lf", result[k], expected[k]);
  }
  $(&integrand, destruct)();
}

// Allocates a CapyLinCombFun combining 2 functions from R^1 to R^3 and
// checks that the requested sizes were recorded and that the table of
// combined functions was allocated.
CUTEST(test004, "Alloc/free a CapyLinCombFun") {
  CapyLinCombFun* lcf = CapyLinCombFunAlloc(2, 1, 3);
  int const sizesOk =
    lcf->nbComb == 2 && lcf->dimIn == 1 && lcf->dimOut == 3;
  int const buffersOk =
    lcf->coeff.dim == 2 && lcf->out.dim == 3 && lcf->combFuns != NULL;
  CUTEST_ASSERT(sizesOk && buffersOk, "CapyLinCombFunAlloc failed");
  CapyLinCombFunFree(&lcf);
}

// Evaluates a CapyLinCombFun built from FunCombA, A(x) = (x, 2x), and
// FunCombB, B(x) = (3x, 4x), with coefficients (0.5, 0.25) and bias
// (10, 20). The expected outputs at x = 1 match
//   0.5 * A(1) + 0.25 * B(1) + bias = (11.25, 22).
CUTEST(test005, "Eval a CapyLinCombFun") {
  enum { NB_FUN = 2, DIM_IN = 1, DIM_OUT = 2 };
  CapyLinCombFun* lcf = CapyLinCombFunAlloc(NB_FUN, DIM_IN, DIM_OUT);
  CapyMathFun members[NB_FUN];
  members[0] = CapyMathFunCreate(DIM_IN, DIM_OUT);
  members[0].eval = FunCombA;
  members[1] = CapyMathFunCreate(DIM_IN, DIM_OUT);
  members[1].eval = FunCombB;
  for (size_t k = 0; k < NB_FUN; ++k) {
    lcf->combFuns[k] = &members[k];
  }
  double x[DIM_IN];
  double y[DIM_OUT];
  static double const bias[DIM_OUT] = {10.0, 20.0};
  static double const weights[NB_FUN] = {0.5, 0.25};
  for (size_t k = 0; k < DIM_OUT; ++k) {
    lcf->bias.vals[k] = bias[k];
  }
  for (size_t k = 0; k < NB_FUN; ++k) {
    lcf->coeff.vals[k] = weights[k];
  }
  x[0] = 1.0;
  $(lcf, eval)(x, y);
  int const match =
    fabs(y[0] - 11.25) < 1e-6 && fabs(y[1] - 22) < 1e-6;
  CUTEST_ASSERT(match, "Eval failed");
  $(&members[0], destruct)();
  $(&members[1], destruct)();
  CapyLinCombFunFree(&lcf);
}

// Fits out = coeff[0] * C(x) + coeff[1] * D(x) + bias, with
//   C(x0, x1) = x0^2 + 2 x1        (FunCombC)
//   D(x0, x1) = -x0 + x1 * x0      (FunCombD)
// to the samples in linreg.csv. The test expects coeff = (2, 3) and
// bias = 1 (presumably the parameters the CSV was generated from — confirm
// against the fixture file).
CUTEST(test006, "Linear regression on a CapyLinCombFun") {
  size_t const dimIn = 2;
  size_t const dimOut = 1;
  CapyLinCombFun* combFun = CapyLinCombFunAlloc(2, dimIn, dimOut);
  CapyMathFun funs[2];
  funs[0] = CapyMathFunCreate(dimIn, dimOut);
  funs[0].eval = FunCombC;
  funs[1] = CapyMathFunCreate(dimIn, dimOut);
  funs[1].eval = FunCombD;
  combFun->combFuns[0] = funs + 0;
  combFun->combFuns[1] = funs + 1;
  CapyDataset* dataset = CapyDatasetAlloc();
  $(dataset, loadFromPath)("UnitTests/TestMathFun/linreg.csv");
  // NOTE(review): the second argument presumably selects which dataset
  // field is the regression target — confirm against the CapyLinCombFun API.
  $(combFun, linearRegression)(dataset, 0);
  // Report the fitted values on failure so a mismatch is diagnosable.
  CUTEST_ASSERT(
    fabs(combFun->coeff.vals[0] - 2.0) < 1e-6 &&
    fabs(combFun->coeff.vals[1] - 3.0) < 1e-6 &&
    fabs(combFun->bias.vals[0] - 1.0) < 1e-6,
    "Linear regression failed: coeff=(%lf, %lf) bias=%lf",
    combFun->coeff.vals[0], combFun->coeff.vals[1], combFun->bias.vals[0]);
  $(funs + 0, destruct)();
  $(funs + 1, destruct)();
  CapyLinCombFunFree(&combFun);
  CapyDatasetFree(&dataset);
}
