// ------------------------------ gradientDescent.h -----------------------------
/*
    LibCapy - a general purpose library of C functions and data structures
    Copyright (C) 2021-2025 Pascal Baillehache baillehache.pascal@gmail.com
    https://baillehachepascal.dev
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef CAPY_GRADIENT_DESCENT_H
#define CAPY_GRADIENT_DESCENT_H
#include "externalHeaders.h"
#include "cext.h"
#include "capymath.h"
#include "mathfun.h"

// Description:
// Class implementing the gradient descent method, with standard, momentum
// and adam variants, and minibatch support.

// Variant of the gradient descent method used by a CapyGradientDescent
// (values are consecutive and start at 0)
typedef enum CapyGradientDescentType {

  // Standard gradient descent
  capyGradientDescent_standard = 0,

  // Gradient descent with momentum (see CapyGradientDescent.momentum)
  capyGradientDescent_momentum = 1,

  // Adam (see CapyGradientDescent.decayRates and .epsilon)
  capyGradientDescent_adam = 2
} CapyGradientDescentType;

// GradientDescent object
// Holds the state of a gradient descent run on the objective function
// 'objFun': current position, gradients, moment estimates, hyper-parameters,
// minibatch bookkeeping, and the methods operating on them.
typedef struct CapyGradientDescent {

  // Index of the current step
  // (presumably incremented by 'step' — TODO confirm)
  size_t iStep;

  // Current position in the input space (of dimension objFun.dimIn)
  CapyVec in;

  // Learn rate (default: 0.1)
  double learnRate;

  // Decay rates for adam (default: [0.9, 0.999])
  // The higher decayRates[0], the more the recent moment is taken into account
  // The higher decayRates[1], the more stable the estimate of the variance
  // decayRates[1] should be as near to 1 as possible for stable convergence
  // NOTE(review): in the usual Adam formulation, decayRates[0] and
  // decayRates[1] are the exponential decay rates of the first ('m') and
  // second ('v') moment estimates. A higher decayRates[0] gives more weight
  // to the accumulated past moment than to the latest gradient. Confirm that
  // the wording above matches the implementation.
  double decayRates[2];

  // Epsilon for adam (default: 1e-8)
  // (presumably a small constant that keeps the adam update from dividing
  // by zero — confirm)
  double epsilon;

  // Momentum (default: 0)
  // (presumably used by the capyGradientDescent_momentum variant)
  double momentum;

  // Current first moments
  // (moment estimates used by adam; confirm whether the momentum variant
  // also uses them)
  CapyVec m;

  // Current second moments (moment estimates used by adam)
  CapyVec v;

  // Current gradient of the objective function (updated by 'step')
  CapyVec gradient;

  // Current minibatch gradient
  // (presumably accumulates the gradients over the current minibatch —
  // confirm)
  CapyVec minibatchGradient;

  // Reference to the objective function (must have dimOut = 1)
  // (presumably not owned by this object)
  CapyMathFun* objFun;

  // Minibatch size
  // Default: 1, which is equivalent to pure stochastic gradient descent.
  // Suggested range: [2, 32]. Larger values make convergence slower, but
  // more stable and better overall.
  size_t minibatchSize;

  // Index to manage the minibatch
  // (presumably the position within the current minibatch — confirm)
  size_t iMinibatch;

  // Destructor
  // The instance is accessed implicitly as 'that', like the other methods
  void (*destruct)(void);

  // Perform one step of the gradient descent method
  // Output:
  //   that.in and that.gradient are updated
  void (*step)(void);

  // Set the type of the gradient descent
  // Input:
  //   type: the type of gradient descent
  // Output:
  //   The gradient descent variant of the instance is changed to 'type'
  // NOTE(review): no field of this struct stores the type; presumably
  // setType rebinds 'step' to the variant's implementation — confirm.
  void (*setType)(CapyGradientDescentType const type);
} CapyGradientDescent;

// Create a CapyGradientDescent, returned by value
// Input:
//   objFun: objective function (must have dimOut = 1). The created object
//           keeps a reference to it, so it must outlive the created object.
//   initIn: initial position in the input space (must be of dimension
//           objFun.dimIn)
// Output:
//   Return a CapyGradientDescent of standard type
//   (capyGradientDescent_standard)
CapyGradientDescent CapyGradientDescentCreate(
  CapyMathFun* const objFun,
  double const* const initIn);

// Allocate memory for a new CapyGradientDescent and create it
// Input:
//   objFun: objective function (must have dimOut = 1). The created object
//           keeps a reference to it, so it must outlive the created object.
//   initIn: initial position in the input space (must be of dimension
//           objFun.dimIn)
// Output:
//   Return a pointer to a newly allocated CapyGradientDescent of standard
//   type (capyGradientDescent_standard). Release it with
//   CapyGradientDescentFree.
// Exception:
//   May raise CapyExc_MallocFailed.
CapyGradientDescent* CapyGradientDescentAlloc(
  CapyMathFun* const objFun,
  double const* const initIn);

// Free the memory used by a CapyGradientDescent* and reset '*that' to NULL
// Input:
//   that: address of the pointer to the CapyGradientDescent to free (as
//         returned by CapyGradientDescentAlloc)
void CapyGradientDescentFree(
  CapyGradientDescent** const that);
#endif
