#include "capy.h"
#ifndef FIXTURE
#define FIXTURE
#endif
// Checks the probability density of a 2-D normal distribution built from a
// per-dimension mean and standard deviation, at two events, then checks that
// CapyDistNormalFree resets the caller's pointer.
// The expected values equal the product of the per-dimension normal
// densities: 1/pi at the mean {0, 1}, and exp(-2.5)/pi at {1, 2}.
CUTEST(test001, "normal distribution") {
  size_t const dimEvent = 2;
  double mean[2] = {0.0, 1.0};
  double stdDev[2] = {1.0, 0.5};
  CapyDistNormal* dist = CapyDistNormalAlloc(dimEvent, mean, stdDev);
  double event[2] = {0.0, 1.0};
  // Use dimEvent rather than a literal so the event dimension cannot drift
  // away from the distribution's dimension.
  CapyDistEvt evt = {.vec = {.dim = dimEvent, .vals = event}};
  double prob = $(dist, getProbability)(&evt);
  CUTEST_ASSERT(
    equal(prob, 0.31830988618379069122), "unexpected probability %lf", prob);
  // evt.vec.vals points at event[], so updating event[] changes the query.
  event[0] = 1.0; event[1] = 2.0;
  prob = $(dist, getProbability)(&evt);
  CUTEST_ASSERT(
    equal(prob, 0.02612846656936984591), "unexpected probability %lf", prob);
  CapyDistNormalFree(&dist);
  CUTEST_ASSERT(dist == NULL, "dist not reset");
}

// Checks getSurprise on a fair die: every face has probability 1/nbFace, so
// every face's surprise must equal log(nbFace) (natural log; ln 6 ~= 1.791759).
// The expected value is computed from nbFace rather than hard-coded, so the
// test stays correct if the number of faces changes.
CUTEST(test002, "get surprise") {
  size_t const nbFace = 6;
  double const expected = log((double)nbFace);
  CapyDistDiscrete* dice = CapyDistDiscreteAlloc(nbFace);
  forEach(evt, dice->occ->iter) {
    evtPtr->prob = 1.0 / (double)nbFace;
    evtPtr->evt.id = dice->occ->iter.idx;
  }
  // The distribution is uniform, so check all faces, not only the first.
  for (size_t iFace = 0; iFace < nbFace; ++iFace) {
    CapyDistDiscreteOccurence const occ = $(dice->occ, get)(iFace);
    double const h = $(dice, getSurprise)(&(occ.evt));
    CUTEST_ASSERT(fabs(h - expected) < 1e-6, "face %zu: h=%lf", iFace, h);
  }
  CapyDistDiscreteFree(&dice);
}

// Checks getEntropy on a fair die: with nbFace equiprobable faces the entropy
// is -sum(1/n * log(1/n)) = log(nbFace) (natural log; ln 6 ~= 1.791759).
// The expected value is derived from nbFace instead of a truncated literal.
CUTEST(test003, "get entropy discrete distribution") {
  size_t const nbFace = 6;
  double const expected = log((double)nbFace);
  CapyDistDiscrete* dice = CapyDistDiscreteAlloc(nbFace);
  forEach(evt, dice->occ->iter) {
    evtPtr->prob = 1.0 / (double)nbFace;
    evtPtr->evt.id = dice->occ->iter.idx;
  }
  double const h = $(dice, getEntropy)();
  CUTEST_ASSERT(fabs(h - expected) < 1e-6, "h=%lf", h);
  CapyDistDiscreteFree(&dice);
}

// Checks CapyDistDiscreteGetCrossEntropy(p, q) = -sum_i p_i * log(q_i) on a
// fair two-face die and a rigged one, in three argument combinations.
// Expected values are written in closed form (natural log):
//   H(fair, fair)     = log(2)
//   H(fair, rigged)   = -0.5 * (log(0.99) + log(0.01))
//   H(rigged, fair)   = -(0.99 + 0.01) * log(0.5) = log(2)
CUTEST(test004, "get cross entropy discrete distribution") {
  // Face probabilities of the rigged die. The number of faces is derived from
  // this table so the indexing loop below can never run past its end.
  double const riggedProbs[] = {0.99, 0.01};
  size_t const nbFace = sizeof(riggedProbs) / sizeof(riggedProbs[0]);
  CapyDistDiscrete* fairDice = CapyDistDiscreteAlloc(nbFace);
  CapyDistDiscrete* riggedDice = CapyDistDiscreteAlloc(nbFace);
  forEach(evt, fairDice->occ->iter) {
    evtPtr->prob = 1.0 / (double)nbFace;
    evtPtr->evt.id = fairDice->occ->iter.idx;
  }
  forEach(evt, riggedDice->occ->iter) {
    evtPtr->prob = riggedProbs[riggedDice->occ->iter.idx];
    evtPtr->evt.id = riggedDice->occ->iter.idx;
  }
  double const h[3] = {
    CapyDistDiscreteGetCrossEntropy(fairDice, fairDice),
    CapyDistDiscreteGetCrossEntropy(fairDice, riggedDice),
    CapyDistDiscreteGetCrossEntropy(riggedDice, fairDice),
  };
  double const expected[3] = {
    log(2.0),
    -0.5 * (log(0.99) + log(0.01)),
    log(2.0),
  };
  CUTEST_ASSERT(
    fabs(h[0] - expected[0]) < 1e-6 &&
    fabs(h[1] - expected[1]) < 1e-6 &&
    fabs(h[2] - expected[2]) < 1e-6,
    "h={%lf, %lf, %lf}", h[0], h[1], h[2]);
  CapyDistDiscreteFree(&fairDice);
  CapyDistDiscreteFree(&riggedDice);
}

// Checks CapyDistDiscreteGetKLDivergence(p, q) = sum_i p_i * log(p_i / q_i)
// on a fair two-face die and a rigged one. Both argument orders of the mixed
// pair are checked because KL divergence is not symmetric.
// Expected values are written in closed form (natural log):
//   KL(fair || fair)     = 0
//   KL(fair || rigged)   = 0.5 * (log(0.5 / 0.99) + log(0.5 / 0.01))
//   KL(rigged || fair)   = 0.99 * log(0.99 / 0.5) + 0.01 * log(0.01 / 0.5)
CUTEST(test005, "get KL divergence discrete distribution") {
  // Face probabilities of the rigged die. The number of faces is derived from
  // this table so the indexing loop below can never run past its end.
  double const riggedProbs[] = {0.99, 0.01};
  size_t const nbFace = sizeof(riggedProbs) / sizeof(riggedProbs[0]);
  CapyDistDiscrete* fairDice = CapyDistDiscreteAlloc(nbFace);
  CapyDistDiscrete* riggedDice = CapyDistDiscreteAlloc(nbFace);
  forEach(evt, fairDice->occ->iter) {
    evtPtr->prob = 1.0 / (double)nbFace;
    evtPtr->evt.id = fairDice->occ->iter.idx;
  }
  forEach(evt, riggedDice->occ->iter) {
    evtPtr->prob = riggedProbs[riggedDice->occ->iter.idx];
    evtPtr->evt.id = riggedDice->occ->iter.idx;
  }
  double const h[3] = {
    CapyDistDiscreteGetKLDivergence(fairDice, fairDice),
    CapyDistDiscreteGetKLDivergence(fairDice, riggedDice),
    CapyDistDiscreteGetKLDivergence(riggedDice, fairDice),
  };
  double const expected[3] = {
    0.0,
    0.5 * (log(0.5 / 0.99) + log(0.5 / 0.01)),
    0.99 * log(0.99 / 0.5) + 0.01 * log(0.01 / 0.5),
  };
  CUTEST_ASSERT(
    fabs(h[0] - expected[0]) < 1e-6 &&
    fabs(h[1] - expected[1]) < 1e-6 &&
    fabs(h[2] - expected[2]) < 1e-6,
    "h={%lf, %lf, %lf}", h[0], h[1], h[2]);
  CapyDistDiscreteFree(&fairDice);
  CapyDistDiscreteFree(&riggedDice);
}
