// ------------------------------ distribution.c ------------------------------
/*
    LibCapy - a general purpose library of C functions and data structures
    Copyright (C) 2021-2025 Pascal Baillehache baillehache.pascal@gmail.com
    https://baillehachepascal.dev
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "distribution.h"
#include "random.h"

// Create a CapyDistEvt
// Output:
//   Return a CapyDistEvt.
CapyDistEvt CapyDistEvtCreate(void) {
  return (CapyDistEvt){.vec = {.dim = 0, .vals = NULL}, .id = 0, .ptr = NULL};
}

// Default implementation of CapyDist.getProbability. The base CapyDist has
// no probability model, so calling it is an error: it raises
// CapyExc_UndefinedExecution.
// Input:
//    evt: the event (unused)
// Output:
//   Return 0.0, which is reached only if execution continues after the
//   exception and the assertion is compiled out (NDEBUG). Returning a value
//   avoids handing an indeterminate value to the caller in that case.
static double GetProbability(CapyDistEvt const* const evt) {
  (void)evt;
  raiseExc(CapyExc_UndefinedExecution);
  assert (false && "CapyDist.getProbability is undefined.");
  return 0.0;
}

// Default surprise of an event, derived from the distribution's
// getProbability method: h(e) = log(1 / p(e)). The less probable the event,
// the higher its surprise.
// Input:
//    evt: the event
// Output:
//   Return the surprise of the event.
static double GetSurprise(CapyDistEvt const* const evt) {
  methodOf(CapyDist);
  double const invProb = 1.0 / $(that, getProbability)(evt);
  return log(invProb);
}

// Default implementation of CapyDist.getEntropy. The base CapyDist cannot
// compute an entropy (the average of the surprise), so calling it raises
// CapyExc_UndefinedExecution.
// Output:
//   Return 0.0, which is reached only if execution continues after the
//   exception and the assertion is compiled out (NDEBUG). Returning a value
//   avoids handing an indeterminate value to the caller in that case.
static double GetEntropy(void) {
  raiseExc(CapyExc_UndefinedExecution);
  assert (false && "CapyDist.getEntropy is undefined.");
  return 0.0;
}

// Default implementation of CapyDist.isEvtInMostProbable. It is not
// available for the base CapyDist, so calling it raises
// CapyExc_UndefinedExecution.
// Input:
//         evt: the event to check (unused)
//   threshold: the threshold (unused)
// Output:
//   Return false, which is reached only if execution continues after the
//   exception.
static bool IsEvtInMostProbable(
  CapyDistEvt const* const evt,
              double const threshold) {
  (void)threshold;
  (void)evt;
  raiseExc(CapyExc_UndefinedExecution);
  assert (false && "CapyDist.isEvtInMostProbable is undefined.");
  bool const inMostProbable = false;
  return inMostProbable;
}

// Destructor of the base CapyDist. It does not allocate anything, so there
// is nothing to release.
static void Destruct(void) {
}

// Create a CapyDist whose methods are the base-class defaults. Only
// getSurprise is usable as is; the others raise CapyExc_UndefinedExecution,
// and destruct does nothing.
// Input:
//   type: type of distribution
// Output:
//   Return a CapyDist
CapyDist CapyDistCreate(CapyDistType const type) {
  return (CapyDist){
    .type = type,
    .getProbability = GetProbability,
    .getSurprise = GetSurprise,
    .getEntropy = GetEntropy,
    .isEvtInMostProbable = IsEvtInMostProbable,
    .destruct = Destruct,
  };
}

// Free the memory used by a CapyDistContinuous: destruct the parent
// CapyDist, then each per-dimension range and the range array itself.
static void DistContinuousDestruct(void) {
  methodOf(CapyDistContinuous);
  $(that, destructCapyDist)();

  // The range array is NULL if its allocation failed in
  // CapyDistContinuousCreate (which also guards against that case) or if the
  // instance was already destructed. Skip it rather than dereferencing NULL.
  if(that->range != NULL) {
    loop(iDim, that->dimEvt) $(that->range + iDim, destruct)();
    free(that->range);

    // Avoid a dangling pointer (and a double free on a second destruct)
    that->range = NULL;
  }
}

// Create a CapyDistContinuous
// Input:
//   dimEvt: the dimension of the random variable describing an event
// Output:
//   Return a CapyDistContinuous
CapyDistContinuous CapyDistContinuousCreate(size_t const dimEvt) {
  CapyDistContinuous that = {
    .dimEvt = dimEvt,
    .range = NULL,
  };
  CapyInherits(that, CapyDist, (capyDistributionType_continuous));
  that.destruct = DistContinuousDestruct;
  safeMalloc(that.range, dimEvt);
  if(that.range) loop(iDim, dimEvt) {
    that.range[iDim] = CapyRangeDoubleCreate(0.0, 0.0);
  }
  return that;
}

// sqrt(2 * pi), used in the normalisation factor 1 / (stdDev * sqrt(2 * pi))
// of the normal probability density function.
const double normalDistributionDensity_a = 2.5066282746310002;

// Get the probability density of the given event for a normal distribution.
// The density is the product over dimensions of the one-dimensional normal
// densities exp(-0.5 * z^2) / (stdDev * sqrt(2 * pi)), where
// z = (x - mean) / stdDev. This is only valid if the dimensions of the
// random variable are not correlated.
// Input:
//    evt: the random variable describing the event
// Output:
//   Return the probability density of the event
static double DistNormalDensity(CapyDistEvt const* const evt) {
  methodOf(CapyDistNormal);
  double density = 1.0;
  loop(iDim, that->dimEvt) {
    double const sigma = that->stdDev[iDim];
    double const z = (evt->vec.vals[iDim] - that->mean[iDim]) / sigma;
    double const norm = sigma * normalDistributionDensity_a;
    density *= exp(-0.5 * pow(z, 2.0)) / norm;
  }
  return density;
}

// Get the derivative of the probability density of a given event along a
// given axis for a normal distribution, using
// d p / d x_i = -(x_i - mean_i) / stdDev_i^2 * p(x).
// Input:
//      evt: the random variable describing the event
//    iAxis: the derivative axis
// Output:
//   Return the derivative of the probability density of the event
static double DistNormalDensityDerivative(
  CapyDistEvt const* const evt,
              size_t const iAxis) {
  methodOf(CapyDistNormal);
  double const density = $(that, getProbability)(evt);
  double const centered = evt->vec.vals[iAxis] - that->mean[iAxis];
  double const variance = pow(that->stdDev[iAxis], 2.0);
  return density * (-centered / variance);
}

// Free the memory used by a CapyDistNormal: destruct the parent
// CapyDistContinuous, then release the means and standard deviations.
// The pointers are reset to NULL so that a second destruct does not free
// them twice.
static void DistNormalDestruct(void) {
  methodOf(CapyDistNormal);
  $(that, destructCapyDistContinuous)();
  free(that->mean);
  that->mean = NULL;
  free(that->stdDev);
  that->stdDev = NULL;
}

// Create a CapyDistNormal
// Input:
//   dimEvent: the dimension of the random variable describing an event
//       mean: the means in each dimension of the random variable
//     stdDev: the standard deviations in each dimension of the random
//             variable
// Output:
//   Return a CapyDistNormal
CapyDistNormal CapyDistNormalCreate(
         size_t const dimEvent,
  double const* const mean,
  double const* const stdDev) {

  // Create the instance
  CapyDistNormal that;
  CapyInherits(that, CapyDistContinuous, (dimEvent));
  that.type = capyDistributionType_normal;
  that.getProbability = DistNormalDensity;
  that.getProbabilityDerivative = DistNormalDensityDerivative;
  that.destruct = DistNormalDestruct;

  // Copy the means and standard deviations
  safeMalloc(that.mean, dimEvent);
  safeMalloc(that.stdDev, dimEvent);
  if(that.mean) memcpy(that.mean, mean, dimEvent * sizeof(double));
  if(that.stdDev) memcpy(that.stdDev, stdDev, dimEvent * sizeof(double));

  // Set the range of value per dimension
  loop(iDim, dimEvent) {
    that.range[iDim].min = that.mean[iDim] - 4.0 * that.stdDev[iDim];
    that.range[iDim].max = that.mean[iDim] + 4.0 * that.stdDev[iDim];
  }
  return that;
}

// Allocate memory for a new CapyDistNormal and create it
// Input:
//   dimEvent: the dimension of the random variable describing an event
//       mean: the means in each dimension of the random variable
//     stdDev: the standard deviations in each dimension of the random
//             variable
// Output:
//   Return a pointer to the new CapyDistNormal, or NULL if the allocation
//   failed. Release it with CapyDistNormalFree.
CapyDistNormal* CapyDistNormalAlloc(
         size_t const dimEvent,
  double const* const mean,
  double const* const stdDev) {
  CapyDistNormal* dist = NULL;
  safeMalloc(dist, 1);
  if(dist != NULL) *dist = CapyDistNormalCreate(dimEvent, mean, stdDev);
  return dist;
}

// Destruct a CapyDistNormal, free its memory and reset the caller's pointer
// to NULL. Does nothing if 'that' or '*that' is NULL.
// Input:
//   that: the CapyDistNormal to free
void CapyDistNormalFree(CapyDistNormal** const that) {
  if(that != NULL && *that != NULL) {
    $(*that, destruct)();
    free(*that);
    *that = NULL;
  }
}

// Array of CapyDistDiscreteOccurence, holding the events of a
// CapyDistDiscrete together with their probabilities. The CapyDefArray
// macro presumably generates the CapyDistDiscreteOccs type and the
// CapyDistDiscreteOccsAlloc / CapyDistDiscreteOccsFree functions used below
// (confirm against the macro definition).
CapyDefArray(CapyDistDiscreteOccs, CapyDistDiscreteOccurence)

// Get the probability of a given event in a discrete distribution. Events
// are matched on their 'id' and 'ptr' fields, and the first matching
// occurrence gives the probability.
// Input:
//    evt: the event
// Output:
//   Return the probability of the event, or 0.0 if the event is not in the
//   distribution.
static double DistDiscreteGetProbability(CapyDistEvt const* const evt) {
  methodOf(CapyDistDiscrete);
  double const notFound = 0.0;
  forEach(candidate, that->occ->iter) {
    bool const sameId = (candidate.evt.id == evt->id);
    bool const samePtr = (candidate.evt.ptr == evt->ptr);
    if(sameId && samePtr) return candidate.prob;
  }
  return notFound;
}

// Check if an event is within the most probable events up to a given
// cumulative probability threshold.
// The event is considered among the most probable if the sum of the
// probabilities of all the occurrences strictly more probable than it is
// below the threshold. Occurrences with the same probability as the checked
// event are not counted. An event absent from the distribution is treated
// as having probability 0.0.
// Input:
//         evt: the event to check (matched on its 'ptr' and 'id' fields)
//   threshold: the cumulative probability threshold
// Output:
//   Return true if the event is in the most probable events up to the
//   threshold.
static bool DistDiscreteIsEvtInMostProbable(
  CapyDistEvt const* const evt,
              double const threshold) {
  methodOf(CapyDistDiscrete);
  // Look up the probability of the event (stays 0.0 if not found)
  double probEvt = 0.0;
  forEach(occ, that->occ->iter) {
    if(occ.evt.ptr == evt->ptr && occ.evt.id == evt->id) {
      probEvt = occ.prob;
      // Jump to the last element to end the search early.
      // NOTE(review): relies on forEach stopping once the iterator steps past
      // the last element - confirm against the iterator implementation.
      $(&(that->occ->iter), toLast)();
    }
  }
  // Accumulate the probabilities of the strictly more probable occurrences,
  // stopping as soon as the sum reaches the threshold
  double sumProb = 0.0;
  forEach(occ, that->occ->iter) {
    if(occ.prob > probEvt) sumProb += occ.prob;
    if(sumProb >= threshold) return false;
  }
  return (sumProb < threshold);
}

// Get the entropy of the distribution (for discrete distribution).
// Output:
//   Return the entropy H = sum_e p(e) * log(1 / p(e)), i.e. the average of
//   the surprise. The higher the entropy the more uncertain the outcome of
//   drawing a random sample. Occurrences with a null probability contribute
//   0 (limit of p * log(1 / p) when p -> 0).
static double DistDiscreteGetEntropy(void) {
  methodOf(CapyDistDiscrete);
  double entropy = 0.0;
  forEach(occ, that->occ->iter) {
    // Skip null probabilities: 0.0 * log(1.0 / 0.0) evaluates to NaN.
    // (An 'if' is used instead of 'continue' to avoid depending on how
    // forEach advances the iterator.)
    if(occ.prob != 0.0) {
      // The surprise is computed from the occurrence's own probability
      // rather than through getSurprise(). getSurprise() looks the event up
      // with another forEach on the same that->occ->iter, which resets the
      // iterator driving this loop. That can loop forever when several
      // occurrences share the same event (e.g. freshly created ones).
      entropy += occ.prob * log(1.0 / occ.prob);
    }
  }
  return entropy;
}

// Destructor of a CapyDistDiscrete: first destruct the parent CapyDist, then
// free the array of occurrences (through CapyDistDiscreteOccsFree, which
// receives the address of the 'occ' field).
static void DistDiscreteDestruct(void) {
  methodOf(CapyDistDiscrete);
  // Parent part first
  $(that, destructCapyDist)();
  // Then the owned array of occurrences
  CapyDistDiscreteOccsFree(&that->occ);
}

// Create a CapyDistDiscrete
// Input:
//   nbEvt: the number of events in the distribution
// Output:
//   Return a CapyDistDiscrete whose occurrences all have a null probability
//   and an empty event (as created by CapyDistEvtCreate)
CapyDistDiscrete CapyDistDiscreteCreate(size_t const nbEvt) {
  CapyDistDiscrete that;
  CapyInherits(that, CapyDist, (capyDistributionType_discrete));
  that.destruct = DistDiscreteDestruct;
  that.getProbability = DistDiscreteGetProbability;
  that.getEntropy = DistDiscreteGetEntropy;
  that.isEvtInMostProbable = DistDiscreteIsEvtInMostProbable;

  // Allocate the occurrences and reset them. The reset is skipped if the
  // allocation returned NULL, rather than iterating through a NULL pointer
  // (the other *Alloc functions of this file return NULL on failure).
  that.occ = CapyDistDiscreteOccsAlloc(nbEvt);
  if(that.occ != NULL) {
    forEach(occ, that.occ->iter) {
      occPtr->prob = 0.0;
      occPtr->evt = CapyDistEvtCreate();
    }
  }
  return that;
}

// Allocate memory for a CapyDistDiscrete and create it
// Input:
//   nbEvt: the number of events in the distribution
// Output:
//   Return a pointer to the new CapyDistDiscrete, or NULL if the allocation
//   failed. Release it with CapyDistDiscreteFree.
CapyDistDiscrete* CapyDistDiscreteAlloc(size_t const nbEvt) {
  CapyDistDiscrete* dist = NULL;
  safeMalloc(dist, 1);
  if(dist != NULL) *dist = CapyDistDiscreteCreate(nbEvt);
  return dist;
}

// Destruct a CapyDistDiscrete, free its memory and reset the caller's
// pointer to NULL. Does nothing if 'that' or '*that' is NULL.
// Input:
//   that: the CapyDistDiscrete to free
void CapyDistDiscreteFree(CapyDistDiscrete** const that) {
  bool const nothingToFree = (that == NULL) || (*that == NULL);
  if(nothingToFree) return;
  $(*that, destruct)();
  free(*that);
  *that = NULL;
}


// Get the cross entropy of two discrete distributions
// Input:
//   distA: the first distribution
//   distB: the second distribution
// Output:
//   Return the cross entropy H(A, B) = sum_e pA(e) * log(1 / pB(e)) of distA
//   relative to distB. It is greater than or equal to the entropy of distA,
//   and increases with the discrepancy between the probabilities of the two
//   distributions. Events with a null probability in distA contribute 0.
double CapyDistDiscreteGetCrossEntropy(
  CapyDistDiscrete const* const distA,
  CapyDistDiscrete const* const distB) {
  double entropy = 0.0;
  forEach(occ, distA->occ->iter) {
    // Skip null probabilities in distA: 0 * log(1 / pB) is 0 by convention,
    // but evaluates to NaN when pB is also 0.
    if(occ.prob != 0.0) {
      // When both arguments are the same distribution, use the occurrence's
      // own probability. Calling distB's methods would run another forEach
      // on the iterator driving this loop and reset it.
      double const surprise =
        (distA == distB) ?
        log(1.0 / occ.prob) :
        $(distB, getSurprise)(&(occ.evt));
      entropy += occ.prob * surprise;
    }
  }
  return entropy;
}

// Get the KL divergence of two discrete distributions
// Input:
//   distA: the first distribution
//   distB: the second distribution
// Output:
//   Return the KL divergence D(A || B) = sum_e pA(e) * log(pA(e) / pB(e)) of
//   distA relative to distB (equal to the cross entropy of (distA, distB)
//   minus the entropy of distA). It is 0.0 if the distributions are the
//   same, and the more they diverge the higher the returned value. Events
//   with a null probability in distA contribute 0. Events possible in distA
//   but impossible in distB make the result +inf.
double CapyDistDiscreteGetKLDivergence(
  CapyDistDiscrete const* const distA,
  CapyDistDiscrete const* const distB) {
  double divergence = 0.0;
  forEach(occ, distA->occ->iter) {
    // Skip null probabilities in distA: 0 * log(0 / pB) is 0 by convention,
    // but evaluates to NaN.
    if(occ.prob != 0.0) {
      // When both arguments are the same distribution, use the occurrence's
      // own probability. Calling distB's getProbability would run another
      // forEach on the iterator driving this loop and reset it.
      double const probB =
        (distA == distB) ? occ.prob : $(distB, getProbability)(&(occ.evt));
      divergence += occ.prob * log(occ.prob / probB);
    }
  }
  return divergence;
}
