// -------------------------- gradientDescent.c --------------------------
/*
    LibCapy - a general purpose library of C functions and data structures
    Copyright (C) 2021-2025 Pascal Baillehache baillehache.pascal@gmail.com
    https://baillehachepascal.dev
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "gradientDescent.h"

// Destructor of a CapyGradientDescent
// Output:
//   The five vectors owned by the instance (in, gradient,
//   minibatchGradient, m, v) are released with CapyVecDestruct
static void Destruct(void) {
  methodOf(CapyGradientDescent);

  // Buffers used by the momentum and adam variants
  CapyVecDestruct(&that->v);
  CapyVecDestruct(&that->m);

  // Gradient buffers
  CapyVecDestruct(&that->minibatchGradient);
  CapyVecDestruct(&that->gradient);

  // Current position in the input space
  CapyVecDestruct(&that->in);
}

// Perform one step of the gradient descent method (standard)
// Output:
//   that.in and that.gradient are updated
static void Step(void) {
  methodOf(CapyGradientDescent);

  // Increment the step counter
  that->iStep += 1;

  // Calculate the gradient of the objective function and update the minibatch
  // gradient
  loop(i, that->in.dim) {
    $(that->objFun, evalDerivative)(that->in.vals, i, that->gradient.vals + i);
    that->minibatchGradient.vals[i] += that->gradient.vals[i];
  }

  // Step the minibatch
  that->iMinibatch += 1;
  if(that->iMinibatch >= that->minibatchSize) {

    // Update the input with the minibatch gradient
    loop(i, that->in.dim) {
      that->minibatchGradient.vals[i] /= (double)(that->iMinibatch);
      that->in.vals[i] -= that->learnRate * that->minibatchGradient.vals[i];
      that->minibatchGradient.vals[i] = 0.0;
    }
    that->iMinibatch = 0;
  }
}

// Perform one step of the gradient descent method (momentum)
// Output:
//   that.iStep is incremented. that.gradient receives the partial
//   derivatives of the objective function at the current that.in, and these
//   are accumulated into that.minibatchGradient. Once the number of
//   accumulated steps reaches that.minibatchSize, the velocity that.m is
//   updated as
//     m = that.momentum * m - that.learnRate * (averaged minibatch gradient)
//   then added to that.in, and the accumulator and the minibatch counter
//   are reset.
static void StepMomentum(void) {
  methodOf(CapyGradientDescent);

  // Increment the step counter
  that->iStep += 1;

  // Calculate the gradient of the objective function and update the minibatch
  // gradient
  loop(i, that->in.dim) {
    $(that->objFun, evalDerivative)(that->in.vals, i, that->gradient.vals + i);
    that->minibatchGradient.vals[i] += that->gradient.vals[i];
  }

  // Step the minibatch
  that->iMinibatch += 1;
  if(that->iMinibatch >= that->minibatchSize) {

    // Update the input with the minibatch gradient and momentum
    loop(i, that->in.dim) {

      // Average the gradient accumulated over the minibatch
      that->minibatchGradient.vals[i] /= (double)(that->iMinibatch);

      // The velocity must be driven by the averaged minibatch gradient, as
      // in the standard and adam steps. that->gradient only holds the
      // gradient of the last sample, and using it would discard the rest of
      // the minibatch.
      that->m.vals[i] =
        that->momentum * that->m.vals[i] -
        that->learnRate * that->minibatchGradient.vals[i];
      that->in.vals[i] += that->m.vals[i];

      // Reset the accumulator for the next minibatch
      that->minibatchGradient.vals[i] = 0.0;
    }
    that->iMinibatch = 0;
  }
}

// Perform one step of the gradient descent method (adam)
// Output:
//   that.in and that.gradient are updated
static void StepAdam(void) {
  methodOf(CapyGradientDescent);

  // Increment the step counter
  that->iStep += 1;

  // Loop on the input dimension
  loop(iDim, that->in.dim) {

    // Calculate the gradient of the objective function and update the
    // minibatch gradient
    double d = 0.0;
    $(that->objFun, evalDerivative)(that->in.vals, iDim, &d);
    that->minibatchGradient.vals[iDim] += d;
  }

  // Step the minibatch
  that->iMinibatch += 1;
  if(that->iMinibatch >= that->minibatchSize) {

    // Loop on the input dimension
    loop(iDim, that->in.dim) {
      that->minibatchGradient.vals[iDim] /= (double)(that->iMinibatch);

      // Update the moments
      that->m.vals[iDim] =
        that->decayRates[0] * that->m.vals[iDim] +
        (1.0 - that->decayRates[0]) * that->minibatchGradient.vals[iDim];
      that->v.vals[iDim] =
        that->decayRates[1] * that->v.vals[iDim] +
        (1.0 - that->decayRates[1]) * that->minibatchGradient.vals[iDim] *
        that->minibatchGradient.vals[iDim];

      // Calculate the bias corrected moments
      double const mBiasCorr =
        that->m.vals[iDim] /
        (1.0 - pow(that->decayRates[0], (double)(that->iStep)));
      double const vBiasCorr =
        that->v.vals[iDim] /
        (1.0 - pow(that->decayRates[1], (double)(that->iStep)));

      // Update the adam gradient
      that->gradient.vals[iDim] = mBiasCorr / (sqrt(vBiasCorr) + that->epsilon);

      // Update the input with the adam gradient
      that->in.vals[iDim] -= that->learnRate * that->gradient.vals[iDim];
      that->minibatchGradient.vals[iDim] = 0.0;
    }
    that->iMinibatch = 0;
  }
}

// Select the variant of gradient descent used by the 'step' method
// Input:
//   type: the type of gradient descent
// Output:
//   that.step is set to the step function matching 'type'
// Exception:
//   Raise CapyExc_UndefinedExecution if 'type' is none of
//   capyGradientDescent_standard, capyGradientDescent_momentum,
//   capyGradientDescent_adam (that.step is then left unchanged by this
//   function).
static void SetType(CapyGradientDescentType const type) {
  methodOf(CapyGradientDescent);

  // Known types: install the matching step function and leave
  if(type == capyGradientDescent_standard) {
    that->step = Step;
    return;
  }
  if(type == capyGradientDescent_momentum) {
    that->step = StepMomentum;
    return;
  }
  if(type == capyGradientDescent_adam) {
    that->step = StepAdam;
    return;
  }

  // Unknown type
  raiseExc(CapyExc_UndefinedExecution);
}

// Create a CapyGradientDescent
// Input:
//   objFun: objective function (must have dimOut = 1)
//   initIn: initial position in input space (must be of dimension objFun.dimIn)
// Output:
//   Return a CapyGradientDescent using the standard step, with all its
//   vectors of dimension objFun.dimIn, 'in' initialised with a copy of
//   'initIn', learnRate 0.1, momentum 0.0, decayRates {0.9, 0.999},
//   epsilon 1e-8 and minibatchSize 1.
// NOTE(review): minibatchGradient, m and v are used as accumulators by the
// step functions and are not explicitly zeroed here; presumably
// CapyVecCreate returns zero-filled vectors - confirm.
CapyGradientDescent CapyGradientDescentCreate(
  CapyMathFun* const objFun,
  double const* const initIn) {
  CapyGradientDescent that = {

    // Objective function to minimise
    .objFun = objFun,

    // Working vectors, all of the dimension of the input space
    .in = CapyVecCreate(objFun->dimIn),
    .gradient = CapyVecCreate(objFun->dimIn),
    .minibatchGradient = CapyVecCreate(objFun->dimIn),
    .m = CapyVecCreate(objFun->dimIn),
    .v = CapyVecCreate(objFun->dimIn),

    // Default hyperparameters
    .learnRate = 0.1,
    .momentum = 0.0,
    .decayRates = {0.9, 0.999},
    .epsilon = 1e-8,
    .minibatchSize = 1,

    // Counters
    .iStep = 0,
    .iMinibatch = 0,

    // Methods, the standard step is the default
    .destruct = Destruct,
    .step = Step,
    .setType = SetType,
  };

  // Start the descent from the given initial position
  loop(iDim, that.in.dim) {
    that.in.vals[iDim] = initIn[iDim];
  }
  return that;
}

// Allocate memory for a new CapyGradientDescent and create it
// Input:
//   objFun: objective function (must have dimOut = 1)
//   initIn: initial position in input space (must be of dimension objFun.dimIn)
// Output:
//   Return a pointer to a newly allocated CapyGradientDescent (of standard
//   type) initialised with CapyGradientDescentCreate, or NULL if the
//   pointer is still null after safeMalloc.
// Exception:
//   May raise CapyExc_MallocFailed.
CapyGradientDescent* CapyGradientDescentAlloc(
  CapyMathFun* const objFun,
  double const* const initIn) {
  CapyGradientDescent* that = NULL;
  safeMalloc(that, 1);

  // Initialise the instance only if the allocation succeeded
  if(that != NULL) {
    *that = CapyGradientDescentCreate(objFun, initIn);
  }
  return that;
}

// Destruct and free a heap allocated CapyGradientDescent, then reset '*that'
// to NULL
// Input:
//   that: a pointer to the CapyGradientDescent to free; nothing is done if
//         'that' or '*that' is NULL
void CapyGradientDescentFree(CapyGradientDescent** const that) {
  if(that != NULL && *that != NULL) {
    CapyGradientDescent* const instance = *that;

    // Release the resources owned by the instance, then the instance itself
    $(instance, destruct)();
    free(instance);

    // Prevent the caller from reusing the dangling pointer
    *that = NULL;
  }
}
