// ---------------------------- policygradient.c ---------------------------
/*
    LibCapy - a general purpose library of C functions and data structures
    Copyright (C) 2021-2025 Pascal Baillehache info@baillehachepascal.dev
    https://baillehachepascal.dev
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "policygradient.h"

// Create a CapyPGTransition
// Input:
//   nbFeature: number of features describing an environment
// Output:
//   Return a CapyPGTransition
CapyPGTransition CapyPGTransitionCreate(size_t const nbFeature) {
  CapyPGTransition transition = {
    .fromState = CapyVecCreate(nbFeature),
    .toState = CapyVecCreate(nbFeature),
  };
  return transition;
}

// Release the two state vectors held by a CapyPGTransition
// Input:
//   that: the transition to destruct (must not be NULL)
// Output:
//   'fromState' then 'toState' are passed to CapyVecDestruct.
void CapyPGTransitionDestruct(CapyPGTransition* const that) {
  CapyVec* const origin = &(that->fromState);
  CapyVec* const destination = &(that->toState);
  CapyVecDestruct(origin);
  CapyVecDestruct(destination);
}

// Default implementation of the 'step' method, installed by
// CapyPGEnvironmentCreate. It performs no step: it only signals that the
// method is undefined (presumably the user is expected to replace it).
// Input:
//   action: the applied action (ignored)
// Output:
//   Raise CapyExc_UndefinedExecution, then fail an assertion when assertions
//   are enabled. If execution continues past both, return a zero-initialised
//   transition. The current state is not modified.
static CapyPGTransition StepEnvironment(size_t const action) {
  (void)action;
  raiseExc(CapyExc_UndefinedExecution);
  assert(false && "PGEnvironment.step is undefined.");
  return (CapyPGTransition){0};
}

// Default implementation of the 'setToInitialState' method, installed by
// CapyPGEnvironmentCreate. It only signals that the method is undefined
// (presumably the user is expected to replace it).
// Output:
//   Raise CapyExc_UndefinedExecution, then fail an assertion when assertions
//   are enabled. The current state is not modified.
static void SetToInitialState(void) {
  raiseExc(CapyExc_UndefinedExecution);
  assert(false && "PGEnvironment.setToInitialState is undefined.");
}

// Default implementation of the 'isEndState' method, installed by
// CapyPGEnvironmentCreate. It only signals that the method is undefined
// (presumably the user is expected to replace it).
// Output:
//   Raise CapyExc_UndefinedExecution, then fail an assertion when assertions
//   are enabled. If execution continues past both, return true, which stops
//   the trajectory loops of the training methods immediately.
static bool IsEndState(void) {
  raiseExc(CapyExc_UndefinedExecution);
  assert(false && "PGEnvironment.isEndState is undefined.");
  return true;
}

// Sample an action for a given state
// Input:
//   state: the state to use for evaluation
// Output:
//   The environment's 'actionsProb' is overwritten with the probabilities
//   evaluated by 'getActionsProb' for 'state'. Return the index drawn by the
//   environment's random number generator ('getIdxGivenProbVec') from these
//   probabilities.
static size_t GetAction(CapyVec const* const state) {
  methodOf(CapyPGEnvironment);
  CapyVec* const probs = &(that->actionsProb);

  // Refresh the probabilities, then draw an index from them
  $(that, getActionsProb)(state, probs);
  return $(&(that->rng), getIdxGivenProbVec)(probs);
}

// Select the most probable action for a given state
// Input:
//   state: the state to use for evaluation
// Output:
//   The environment's 'actionsProb' is overwritten with the probabilities
//   evaluated by 'getActionsProb' for 'state'. Return the index of the
//   largest probability; on ties the lowest index wins (strict comparison).
static size_t GetBestAction(CapyVec const* const state) {
  methodOf(CapyPGEnvironment);
  CapyVec* const probs = &(that->actionsProb);

  // Refresh the probabilities for the requested state
  $(that, getActionsProb)(state, probs);

  // Linear scan keeping the index of the running maximum
  size_t argMax = 0;
  loop(iCandidate, probs->dim) {
    bool const isBetter = probs->vals[iCandidate] > probs->vals[argMax];
    if(isBetter) argMax = iCandidate;
  }
  return argMax;
}

// Default implementation of the 'getActionsProb' method, installed by
// CapyPGEnvironmentCreate. It only signals that the method is undefined
// (presumably the user is expected to replace it).
// Input:
//   state: the state used for evaluation (ignored)
//   actionsProb: the evaluated actions probability (ignored)
// Output:
//   Raise CapyExc_UndefinedExecution, then fail an assertion when assertions
//   are enabled. 'actionsProb' is left untouched.
static void GetActionsProb(
  CapyVec const* const state,
        CapyVec* const actionsProb) {
  (void)state; (void)actionsProb;
  raiseExc(CapyExc_UndefinedExecution);
  assert(false && "PGEnvironment.getActionsProb is undefined.");
}

// Default implementation of the 'getValue' method, installed by
// CapyPGEnvironmentCreate. It only signals that the method is undefined
// (presumably the user is expected to replace it).
// Input:
//   state: the state used for evaluation (ignored)
// Output:
//   Raise CapyExc_UndefinedExecution, then fail an assertion when assertions
//   are enabled. If execution continues past both, return 0.0.
static double GetValue(CapyVec const* const state) {
  (void)state;
  raiseExc(CapyExc_UndefinedExecution);
  assert(false && "PGEnvironment.getValue is undefined.");
  return 0.0;
}

// Numerical gradient of the state value with respect to the value parameters
// Input:
//   state: the state used for evaluation
//   gradValue: the result gradient, written at indices
//              [0, paramValue.dim[ (its size is not checked here)
// Output:
//   'gradValue' is updated. Each component is a central finite difference of
//   'getValue' obtained by shifting one parameter by +/-1e-3. Each parameter
//   is put back to its memorised value before moving to the next one.
static void GetGradientValue(
  CapyVec const* const state,
        CapyVec* const gradValue) {
  methodOf(CapyPGEnvironment);

  // Half-width of the finite difference interval
  double const halfStep = 1e-3;

  loop(iWeight, that->paramValue.dim) {

    // Keep the exact parameter to restore it without rounding drift
    double const savedWeight = that->paramValue.vals[iWeight];

    // Value at (parameter + halfStep)
    that->paramValue.vals[iWeight] += halfStep;
    double const valueAbove = $(that, getValue)(state);

    // Value at (parameter - halfStep)
    that->paramValue.vals[iWeight] -= 2.0 * halfStep;
    double const valueBelow = $(that, getValue)(state);

    // Central difference
    gradValue->vals[iWeight] = (valueAbove - valueBelow) / (2.0 * halfStep);

    // Restore the parameter
    that->paramValue.vals[iWeight] = savedWeight;
  }
}

// Numerical gradient of one action's probability with respect to the action
// parameters
// Input:
//   state: the state used for evaluation
//   iAction: the action to be evaluated
//   gradProb: the result gradient, written at indices
//             [0, paramAction.dim[ (its size is not checked here)
// Output:
//   'gradProb' is updated. Each component is a central finite difference of
//   the probability of 'iAction' obtained by shifting one parameter by
//   +/-1e-3. The environment's 'actionsProb' is used as scratch space and is
//   left holding the last evaluation (made with a shifted parameter).
static void GetGradientActionsProb(
  CapyVec const* const state,
          size_t const iAction,
        CapyVec* const gradProb) {
  methodOf(CapyPGEnvironment);

  // Half-width of the finite difference interval
  double const halfStep = 1e-3;

  loop(iWeight, that->paramAction.dim) {

    // Keep the exact parameter to restore it without rounding drift
    double const savedWeight = that->paramAction.vals[iWeight];

    // Probability at (parameter + halfStep)
    that->paramAction.vals[iWeight] += halfStep;
    $(that, getActionsProb)(state, &(that->actionsProb));
    double const probAbove = that->actionsProb.vals[iAction];

    // Probability at (parameter - halfStep)
    that->paramAction.vals[iWeight] -= 2.0 * halfStep;
    $(that, getActionsProb)(state, &(that->actionsProb));
    double const probBelow = that->actionsProb.vals[iAction];

    // Central difference
    gradProb->vals[iWeight] = (probAbove - probBelow) / (2.0 * halfStep);

    // Restore the parameter
    that->paramAction.vals[iWeight] = savedWeight;
  }
}

// Numerical gradient of one action's log probability with respect to the
// action parameters
// Input:
//   state: the state used for evaluation
//   iAction: the action to be evaluated
//   gradProb: the result gradient, written at indices
//             [0, paramAction.dim[ (its size is not checked here)
// Output:
//   'gradProb' is updated. Each component is a central finite difference of
//   log(probability of 'iAction') obtained by shifting one parameter by
//   +/-1e-3. If either shifted probability is not greater than 1e-9 the
//   component is set to 0.0 instead of taking the log of a near-null value.
//   The environment's 'actionsProb' is used as scratch space.
static void GetGradientActionsLogProb(
  CapyVec const* const state,
          size_t const iAction,
        CapyVec* const gradProb) {
  methodOf(CapyPGEnvironment);

  // Half-width of the finite difference interval, and smallest probability
  // for which the log is evaluated
  double const halfStep = 1e-3;
  double const minProb = 1e-9;

  loop(iWeight, that->paramAction.dim) {

    // Keep the exact parameter to restore it without rounding drift
    double const savedWeight = that->paramAction.vals[iWeight];

    // Component value used when one of the probabilities is too small
    double slope = 0.0;

    // Probability at (parameter + halfStep)
    that->paramAction.vals[iWeight] += halfStep;
    $(that, getActionsProb)(state, &(that->actionsProb));
    double const probAbove = that->actionsProb.vals[iAction];
    if(probAbove > minProb) {

      // Probability at (parameter - halfStep), only evaluated when the
      // first one is usable
      that->paramAction.vals[iWeight] -= 2.0 * halfStep;
      $(that, getActionsProb)(state, &(that->actionsProb));
      double const probBelow = that->actionsProb.vals[iAction];
      if(probBelow > minProb) {
        slope = (log(probAbove) - log(probBelow)) / (2.0 * halfStep);
      }
    }
    gradProb->vals[iWeight] = slope;

    // Restore the parameter
    that->paramAction.vals[iWeight] = savedWeight;
  }
}

// Free the memory used by a CapyPGEnvironment
static void DestructEnvironment(void) {
  methodOf(CapyPGEnvironment);
  $(&(that->rng), destruct)();
  CapyVecDestruct(&(that->paramAction));
  CapyVecDestruct(&(that->paramValue));
  CapyVecDestruct(&(that->curState));
  CapyVecDestruct(&(that->actionsProb));
}

// Build a CapyPGEnvironment with its default methods
// Input:
//   nbFeature: number of features describing an environment
//   nbAction: number of possible actions
//   nbParamAction: number of parameters for actions probability evaluation
//   nbParamValue: number of parameters for value evaluation
//   seed: seed for the random number generator
// Output:
//   Return a CapyPGEnvironment owning four vectors ('paramAction',
//   'paramValue', 'curState', 'actionsProb') and a random number generator.
//   The methods 'setToInitialState', 'step', 'isEndState', 'getActionsProb'
//   and 'getValue' are set to placeholders which raise
//   CapyExc_UndefinedExecution.
CapyPGEnvironment CapyPGEnvironmentCreate(
            size_t const nbFeature,
            size_t const nbAction,
            size_t const nbParamAction,
            size_t const nbParamValue,
  CapyRandomSeed_t const seed) {
  return (CapyPGEnvironment){

    // Data
    .nbAction = nbAction,
    .paramAction = CapyVecCreate(nbParamAction),
    .paramValue = CapyVecCreate(nbParamValue),
    .curState = CapyVecCreate(nbFeature),
    .rng = CapyRandomCreate(seed),
    .actionsProb = CapyVecCreate(nbAction),

    // Methods
    .destruct = DestructEnvironment,
    .setToInitialState = SetToInitialState,
    .step = StepEnvironment,
    .isEndState = IsEndState,
    .getAction = GetAction,
    .getBestAction = GetBestAction,
    .getActionsProb = GetActionsProb,
    .getValue = GetValue,
    .getGradientValue = GetGradientValue,
    .getGradientActionsProb = GetGradientActionsProb,
    .getGradientActionsLogProb = GetGradientActionsLogProb,
  };
}

// Heap-allocate a CapyPGEnvironment and initialise it with
// CapyPGEnvironmentCreate
// Input:
//   nbFeature: number of features describing an environment
//   nbAction: number of possible actions
//   nbParamAction: number of parameters for actions probability evaluation
//   nbParamValue: number of parameters for value evaluation
//   seed: seed for the random number generator
// Output:
//   Return a pointer to the new CapyPGEnvironment, or NULL if the pointer is
//   still null after the allocation. Release it with CapyPGEnvironmentFree.
// Exception:
//   May raise CapyExc_MallocFailed.
CapyPGEnvironment* CapyPGEnvironmentAlloc(
            size_t const nbFeature,
            size_t const nbAction,
            size_t const nbParamAction,
            size_t const nbParamValue,
  CapyRandomSeed_t const seed) {
  CapyPGEnvironment* env = NULL;
  safeMalloc(env, 1);
  if(env != NULL) {
    *env = CapyPGEnvironmentCreate(
      nbFeature, nbAction, nbParamAction, nbParamValue, seed);
  }
  return env;
}

// Destruct and free a heap-allocated CapyPGEnvironment, then reset '*that'
// to NULL
// Input:
//   that: a pointer to the CapyPGEnvironment to free; nothing happens if
//         'that' or '*that' is NULL
void CapyPGEnvironmentFree(CapyPGEnvironment** const that) {
  if(that != NULL && *that != NULL) {
    CapyPGEnvironment* const env = *that;
    $(env, destruct)();
    free(env);
    *that = NULL;
  }
}

// Free the memory used by a CapyPGTransitionRecorder
// Output:
//   The recorded transitions are destructed and the 'transitions' array is
//   freed. 'transitions' is reset to NULL and 'nbTransition' to 0 so that a
//   second call to 'destruct' or 'reset' neither destructs the transitions
//   again nor frees the array twice. The recorder must not be used to record
//   transitions afterwards.
static void DestructTransitionRecorder(void) {
  methodOf(CapyPGTransitionRecorder);

  // The recorder owns the state vectors of the transitions it holds
  loop(i, that->nbTransition) {
    CapyPGTransitionDestruct(that->transitions + i);
  }
  free(that->transitions);

  // Do not leave a dangling pointer nor a stale count behind
  that->transitions = NULL;
  that->nbTransition = 0;
}

// Record one transition
// Input:
//   transition: the transition to be recorded
// Output:
//   A shallow copy of the transition (struct assignment) is added to the end
//   of 'transitions'. The recorder takes ownership of the state vectors
//   referenced by the copy: they are destructed by the recorder's 'reset' and
//   'destruct' methods, so the caller must not destruct them itself.
//   'transitions' is realloced if necessary (capacity doubled, or set to 1
//   if it was 0); 'nbTransition' and 'nbMaxTransition' are updated as
//   necessary.
static void AddTransition(CapyPGTransition const* const transition) {
  methodOf(CapyPGTransitionRecorder);
  if(that->nbTransition == that->nbMaxTransition) {

    // Doubling a null capacity would leave it null and the write below would
    // land outside of the array
    size_t const newCapacity =
      (that->nbMaxTransition == 0 ? 1 : that->nbMaxTransition * 2);
    safeRealloc(that->transitions, newCapacity);

    // Commit the new capacity only once the reallocation has been requested,
    // to keep 'nbMaxTransition' in sync with the array if safeRealloc leaves
    // this function through an exception
    that->nbMaxTransition = newCapacity;
  }
  that->transitions[that->nbTransition] = *transition;
  that->nbTransition += 1;
}

// Reset the recorder
// Output:
//   'nbTransition' is reset to 0.
static void ResetTransitionRecorder(void) {
  methodOf(CapyPGTransitionRecorder);
  loop(i, that->nbTransition) {
    CapyPGTransitionDestruct(that->transitions + i);
  }
  that->nbTransition = 0;
}

// Build an empty CapyPGTransitionRecorder
// Output:
//   Return a CapyPGTransitionRecorder holding no transition, with room
//   requested for 1024 transitions through safeMalloc. Release it with its
//   'destruct' method.
CapyPGTransitionRecorder CapyPGTransitionRecorderCreate(void) {
  CapyPGTransitionRecorder recorder = {
    .transitions = NULL,
    .nbTransition = 0,
    .nbMaxTransition = 1024,
    .addTransition = AddTransition,
    .reset = ResetTransitionRecorder,
    .destruct = DestructTransitionRecorder,
  };

  // Allocate the initial storage
  safeMalloc(recorder.transitions, recorder.nbMaxTransition);
  return recorder;
}

// Kind of loss evaluated by LossActionEval. A single loss function serves
// both training algorithms and switches on this value.
typedef enum ActionLossType {

  // Loss used by the reinforce algorithm
  actionLossType_reinforce = 0,

  // Clipped loss used by the proximal policy optimisation algorithm
  actionLossType_ppoClip = 1,

  // Number of loss types (not a valid type)
  actionLossType_nb = 2,
} ActionLossType;

// Loss function for the action probabilities. An instance is passed (cast to
// CapyMathFun*) as the objective function of the gradient descent on the
// action parameters by CapyPolicyGradientCreate.
typedef struct LossAction {

  // Inherits CapyMathFun
  struct CapyMathFunDef;

  // Environment whose action parameters are evaluated (not owned)
  CapyPGEnvironment* env;

  // Transition currently evaluated; LossActionEval reads its 'fromState' and
  // 'action'. Set by the training methods before each gradient descent step.
  CapyPGTransition const* transition;

  // Type of loss function
  ActionLossType type;
  CapyPad(ActionLossType, type);

  // Advantage value. The training methods set it to
  // (estimated value - discounted sum of reward).
  double advantage;

  // Param value of the old policy for PPO. Only read when 'type' is
  // actionLossType_ppoClip; the array is not freed by LossActionDestruct.
  double* oldParam;

  // Clipping coefficient for PPO (in ]0,+inf[, default: 0.2, the lower the
  // more stable but the slower learning)
  double coeffClipping;

  // Destructor of the inherited CapyMathFun, called by LossActionDestruct
  // (presumably set by CapyInherits - TODO confirm)
  void (*destructCapyMathFun)(void);
} LossAction;

// Evaluate the loss for an action
// Input:
//   in: the action probability parameters to evaluate
//   out: the result loss value
// Output:
//   out is updated. The loss is meant to be minimised (it is the objective
//   function of a gradient descent) and 'advantage' is expected with the
//   sign convention used by the training methods, i.e.
//   (estimated value - discounted sum of reward):
//   - actionLossType_reinforce: log(prob(action)) * advantage
//   - actionLossType_ppoClip: max(r * advantage, clip(r) * advantage) where
//     r is prob(action) under 'in' over prob(action) under 'oldParam', and
//     clip() bounds r to [1 - coeffClipping, 1 + coeffClipping]
//   - any other type: assertion failure, and 0.0 if assertions are disabled
//   The environment's 'actionsProb' is used as scratch space. The
//   environment's action parameters are temporarily swapped with 'in' (and
//   'oldParam') and restored before returning.
static void LossActionEval(
  double const* const in,
        double* const out) {
  methodOf(LossAction);

  // Backup the current action probability parameters of the environment
  double* const backupVals = that->env->paramAction.vals;

  // Set the action probability parameters to those in argument ('in' is only
  // read through this pointer, the cast just matches the type of 'vals')
  that->env->paramAction.vals = (double*)in;

  // Evaluate the action probabilities for the current state
  $(that->env, getActionsProb)(
    &(that->transition->fromState), &(that->env->actionsProb));

  // If the loss function is used for Reinforce algorithm
  if(that->type == actionLossType_reinforce) {

    // Apply log (equivalent to normalised gradient here)
    *out =
      log(that->env->actionsProb.vals[that->transition->action]) *
      that->advantage;

  // If the loss function is used for PPO clip algorithm
  } else if(that->type == actionLossType_ppoClip) {

    // Variable to memorise the ratio of new probability over old probability
    // for the current action
    double r = that->env->actionsProb.vals[that->transition->action];

    // Calculate the probabilities for the old policy
    that->env->paramAction.vals = that->oldParam;
    $(that->env, getActionsProb)(
      &(that->transition->fromState), &(that->env->actionsProb));

    // Update the ratio of new probability over old probability
    r /= that->env->actionsProb.vals[that->transition->action];

    // Get the clipped ratio
    double rClip = r;
    if(rClip < 1.0 - that->coeffClipping) {
      rClip = 1.0 - that->coeffClipping;
    } else if(rClip > 1.0 + that->coeffClipping) {
      rClip = 1.0 + that->coeffClipping;
    }

    // Calculate the loss. PPO maximises min(r * A, rClip * A) with
    // A = (sum of reward - value). Here the loss is minimised and
    // 'advantage' is -A, hence -min(r * A, rClip * A) which is equal to
    // max(r * advantage, rClip * advantage). Taking the min instead would
    // keep the unclipped term exactly when the ratio has already moved too
    // far in the favourable direction, defeating the clipping.
    double const a = r * that->advantage;
    double const b = rClip * that->advantage;
    *out = (a > b ? a : b);

  // Else, unsupported type
  } else {
    assert(false && "Unsupported type for the loss function");

    // Keep the output defined when assertions are compiled out (a constant
    // loss has a null gradient and leaves the parameters unchanged)
    *out = 0.0;
  }

  // Reset the current action probability parameters of the environment
  that->env->paramAction.vals = backupVals;
}

// Free the memory used by a LossAction
// Output:
//   Delegates to the destructor of the inherited CapyMathFun
//   ('destructCapyMathFun'). 'env', 'transition' and 'oldParam' are not
//   freed here.
static void LossActionDestruct(void) {
  methodOf(LossAction);
  $(that, destructCapyMathFun)();
}

// Create a LossAction function for a given environment
// Input:
//   env: the environment
// Output:
//   Return a new LossAction with as many inputs as the environment has
//   action parameters and one output. The loss type defaults to
//   actionLossType_reinforce and the clipping coefficient to 0.2.
//   'transition' and 'oldParam' are NULL and 'advantage' is 0.0 until the
//   training methods set them.
static LossAction LossActionCreate(CapyPGEnvironment* env) {
  LossAction that;

  // Initialise the inherited CapyMathFun part
  CapyInherits(that, CapyMathFun, (env->paramAction.dim, 1));

  // Initialise every own field: the struct is returned by value and must not
  // carry indeterminate members (in particular a wild 'transition' pointer)
  that.env = env;
  that.transition = NULL;
  that.type = actionLossType_reinforce;
  that.advantage = 0.0;
  that.oldParam = NULL;
  that.coeffClipping = 0.2;

  // Override the inherited methods
  that.destruct = LossActionDestruct;
  that.eval = LossActionEval;
  return that;
}

// Destruct and free a heap-allocated LossAction, then reset '*that' to NULL
// Input:
//   that: a pointer to the LossAction to free; nothing happens if 'that' or
//         '*that' is NULL
static void LossActionFree(LossAction** const that) {
  if(that != NULL && *that != NULL) {
    LossAction* const loss = *that;
    $(loss, destruct)();
    free(loss);
    *that = NULL;
  }
}

// Heap-allocate a LossAction and initialise it with LossActionCreate
// Input:
//   env: the environment
// Output:
//   Return a pointer to the new LossAction, or NULL if the pointer is still
//   null after the allocation. Release it with LossActionFree.
// Exception:
//   May raise CapyExc_MallocFailed.
static LossAction* LossActionAlloc(CapyPGEnvironment* env) {
  LossAction* loss = NULL;
  safeMalloc(loss, 1);
  if(loss != NULL) {
    *loss = LossActionCreate(env);
  }
  return loss;
}

// Loss function for the state value. An instance is passed (cast to
// CapyMathFun*) as the objective function of the gradient descent on the
// value parameters by CapyPolicyGradientCreate.
typedef struct LossValue {

  // Inherits CapyMathFun
  struct CapyMathFunDef;

  // Environment whose value parameters are evaluated (not owned)
  CapyPGEnvironment* env;

  // Transition currently evaluated; LossValueEval reads its 'fromState'.
  // Set by the training methods before each gradient descent step.
  CapyPGTransition const* transition;

  // Sum of discounted reward, the target the state value is compared to
  double sumReward;

  // Destructor of the inherited CapyMathFun, called by LossValueDestruct
  // (presumably set by CapyInherits - TODO confirm)
  void (*destructCapyMathFun)(void);
} LossValue;

// Evaluate the loss for a state value
// Input:
//   in: the state value parameters to evaluate
//   out: the result loss value
// Output:
//   out is updated.
static void LossValueEval(
  double const* const in,
        double* const out) {
  methodOf(LossValue);

  // Backup the state value parameters
  double* const backupVals = that->env->paramValue.vals;

  // Set the state value parameters to the given one
  that->env->paramValue.vals = (double*)in;

  // Evalute the loss
  *out =
    $(that->env, getValue)(&(that->transition->fromState)) - that->sumReward;
  *out *= *out;

  // Reset the state value parameters
  that->env->paramValue.vals = backupVals;
}

// Free the memory used by a LossValue
// Output:
//   Delegates to the destructor of the inherited CapyMathFun
//   ('destructCapyMathFun'). 'env' and 'transition' are not freed here.
static void LossValueDestruct(void) {
  methodOf(LossValue);
  $(that, destructCapyMathFun)();
}

// Create a LossValue function for a given environment
// Input:
//   env: the environment
// Output:
//   Return a new LossValue with as many inputs as the environment has value
//   parameters and one output. 'transition' is NULL and 'sumReward' is 0.0
//   until the training methods set them.
static LossValue LossValueCreate(CapyPGEnvironment* env) {
  LossValue that;

  // Initialise the inherited CapyMathFun part
  CapyInherits(that, CapyMathFun, (env->paramValue.dim, 1));

  // Initialise every own field: the struct is returned by value and must not
  // carry indeterminate members (in particular a wild 'transition' pointer)
  that.env = env;
  that.transition = NULL;
  that.sumReward = 0.0;

  // Override the inherited methods
  that.destruct = LossValueDestruct;
  that.eval = LossValueEval;
  return that;
}

// Destruct and free a heap-allocated LossValue, then reset '*that' to NULL
// Input:
//   that: a pointer to the LossValue to free; nothing happens if 'that' or
//         '*that' is NULL
static void LossValueFree(LossValue** const that) {
  if(that != NULL && *that != NULL) {
    LossValue* const loss = *that;
    $(loss, destruct)();
    free(loss);
    *that = NULL;
  }
}

// Heap-allocate a LossValue and initialise it with LossValueCreate
// Input:
//   env: the environment
// Output:
//   Return a pointer to the new LossValue, or NULL if the pointer is still
//   null after the allocation. Release it with LossValueFree.
// Exception:
//   May raise CapyExc_MallocFailed.
static LossValue* LossValueAlloc(CapyPGEnvironment* env) {
  LossValue* loss = NULL;
  safeMalloc(loss, 1);
  if(loss != NULL) {
    *loss = LossValueCreate(env);
  }
  return loss;
}

// Learn the weights of action probabilities and state value functions using
// the reinforce with baseline algorithm
// Inputs:
//   nbEpisode: number of training episode
// Output:
//   The environment's action probabilities parameters and state values
//   parameters are updated. 'avgReward' (sum of rewards per episode),
//   'avgFinalReward' (reward of the last transition per episode) and
//   'avgNbStep' (number of transitions per episode) are reset, then set to
//   their average over the episodes. They are left to 0.0 if 'nbEpisode' is
//   0. An episode which records no transition (the initial state is an end
//   state, or 'nbMaxStep' is 0) contributes 0.0 to 'avgFinalReward'.
static void Reinforce(size_t const nbEpisode) {
  methodOf(CapyPolicyGradient);

  // Create a transition recorder to record the trajectories of the training
  // episodes
  CapyPGTransitionRecorder recorder = CapyPGTransitionRecorderCreate();

  // Reset the average reward and number of step
  that->avgReward = 0.0;
  that->avgFinalReward = 0.0;
  that->avgNbStep = 0.0;

  // Set the loss function type
  ((LossAction*)(that->gdAction->objFun))->type = actionLossType_reinforce;

  // Loop on the training episodes
  loop(iEpisode, nbEpisode) {

    // Select a starting state
    $(that->env, setToInitialState)();

    // Record a new trajectory (the recorder takes ownership of the
    // transitions returned by 'step')
    $(&recorder, reset)();
    bool isEndState = $(that->env, isEndState)();
    size_t nbStep = 0;
    double sumEpisodeReward = 0.0;
    while(isEndState == false && nbStep < that->nbMaxStep) {
      size_t const action = $(that->env, getAction)(&(that->env->curState));
      CapyPGTransition transition = $(that->env, step)(action);
      $(&recorder, addTransition)(&transition);
      sumEpisodeReward += transition.reward;
      isEndState = $(that->env, isEndState)();
      nbStep += 1;
    }

    // Update the average of number of step
    that->avgNbStep += (double)(recorder.nbTransition) / (double)nbEpisode;

    // Update the average of reward. The final reward only exists if at least
    // one transition has been recorded; indexing with (nbTransition - 1) on
    // an empty trajectory would wrap around and read out of the array.
    that->avgReward += sumEpisodeReward;
    if(recorder.nbTransition > 0) {
      that->avgFinalReward +=
        recorder.transitions[recorder.nbTransition - 1].reward;
    }

    // Variable to memorise the sum of reward
    double sumReward = 0.0;

    // Loop on the trajectory backward
    loop(jStep, recorder.nbTransition) {
      size_t const iStep = recorder.nbTransition - 1 - jStep;
      CapyPGTransition const* const transition =
        recorder.transitions + iStep;

      // Update the sum of reward
      sumReward = transition->reward + that->discount * sumReward;

      // Get the sum of reward corrected with the value estimation as a
      // baseline. The sign (value - sum) matches the loss being minimised
      // by the gradient descent.
      double const value = $(that->env, getValue)(&(transition->fromState));
      double const advantage = value - sumReward;

      // Step the gradient descent on value parameters
      ((LossValue*)(that->gdValue->objFun))->transition = transition;
      ((LossValue*)(that->gdValue->objFun))->sumReward = sumReward;
      that->gdValue->learnRate = that->learnRateState;
      $(that->gdValue, step)();
      CapyVecCopy(&(that->gdValue->in), &(that->env->paramValue));

      // Step the gradient descent on action parameters
      ((LossAction*)(that->gdAction->objFun))->transition = transition;
      ((LossAction*)(that->gdAction->objFun))->advantage = advantage;
      that->gdAction->learnRate =
        that->learnRateAction * CapyPowf(that->discount, iStep);
      $(that->gdAction, step)();
      CapyVecCopy(&(that->gdAction->in), &(that->env->paramAction));
    }
  }

  // Update the average of reward (skipped without episode to avoid 0/0)
  if(nbEpisode > 0) {
    that->avgReward /= (double)nbEpisode;
    that->avgFinalReward /= (double)nbEpisode;
  }

  // Free memory
  $(&recorder, destruct)();
}

// Learn the weights of action probabilities and state value functions using
// the proximal policy optimisation algorithm
// Inputs:
//   nbEpisode: number of training episode
// Output:
//   The environment's action probabilities parameters and state values
//   parameters are updated.
static void ProximalPolicyOptimisation(size_t const nbEpisode) {
  methodOf(CapyPolicyGradient);

  // Create a transition recorder to record the trajectories of the training
  // episodes
  CapyPGTransitionRecorder recorder = CapyPGTransitionRecorderCreate();

  // Reset the average reward and number of step
  that->avgFinalReward = 0.0;
  that->avgReward = 0.0;
  that->avgNbStep = 0.0;

  // Set the loss function type
  ((LossAction*)(that->gdAction->objFun))->type = actionLossType_ppoClip;

  // Set the clipping coefficient
  ((LossAction*)(that->gdAction->objFun))->coeffClipping = that->coeffClipping;

  // Variable to memorise the old policy
  CapyVec oldParamAction = CapyVecCreate(that->env->paramAction.dim);
  ((LossAction*)(that->gdAction->objFun))->oldParam = oldParamAction.vals;

  // Loop on the training episodes
  loop(iEpisode, nbEpisode) {

    // Update the old policy
    CapyVecCopy(&(that->env->paramAction), &oldParamAction);

    // Select a starting state
    $(that->env, setToInitialState)();

    // Record a new trajectory with the old policy
    $(&recorder, reset)();
    bool isEndState = $(that->env, isEndState)();
    size_t nbStep = 0;
    double* paramActionVals = that->env->paramAction.vals;
    that->env->paramAction.vals = oldParamAction.vals;
    double avgReward = 0.0;
    while(isEndState == false && nbStep < that->nbMaxStep) {
      size_t const action = $(that->env, getAction)(&(that->env->curState));
      CapyPGTransition transition = $(that->env, step)(action);
      $(&recorder, addTransition)(&transition);
      avgReward += transition.reward;
      isEndState = $(that->env, isEndState)();
      nbStep += 1;
    }
    that->env->paramAction.vals = paramActionVals;

    // Update the average of number of step
    that->avgNbStep += (double)(recorder.nbTransition) / (double)nbEpisode;

    // Update the average of final reward
    that->avgReward += avgReward;
    that->avgFinalReward +=
      recorder.transitions[recorder.nbTransition - 1].reward;

    // Variable to memorise the sum of reward
    double sumReward = 0.0;

    // Loop on the trajectory backward
    loop(jStep, recorder.nbTransition) {
      size_t const iStep = recorder.nbTransition - 1 - jStep;
      CapyPGTransition const* const transition =
        recorder.transitions + iStep;

      // Update the sum of reward
      sumReward = transition->reward + that->discount * sumReward;

      // Get the sum of reward corrected with the value estimation as a baseline
      double const value = $(that->env, getValue)(&(transition->fromState));
      double const advantage = value - sumReward;

      // Step the gradient descent on value parameters
      ((LossValue*)(that->gdValue->objFun))->transition = transition;
      ((LossValue*)(that->gdValue->objFun))->sumReward = sumReward;
      that->gdValue->learnRate = that->learnRateState;
      $(that->gdValue, step)();
      CapyVecCopy(&(that->gdValue->in), &(that->env->paramValue));

      // Step the gradient descent on action parameters
      ((LossAction*)(that->gdAction->objFun))->transition = transition;
      ((LossAction*)(that->gdAction->objFun))->advantage = advantage;
      that->gdAction->learnRate =
        that->learnRateAction * CapyPowf(that->discount, iStep);
      $(that->gdAction, step)();
      CapyVecCopy(&(that->gdAction->in), &(that->env->paramAction));
    }
  }

  // Update the average of reward
  that->avgReward /= (double)nbEpisode;

  // Free memory
  $(&recorder, destruct)();
  CapyVecDestruct(&oldParamAction);
}

// Free the memory used by a CapyPolicyGradient
// Output:
//   The two loss functions allocated by CapyPolicyGradientCreate are freed
//   (the objective function pointers of the gradient descents are reset to
//   NULL by the Free functions), then the two gradient descents are freed.
//   The environment is not freed.
static void Destruct(void) {
  methodOf(CapyPolicyGradient);

  // Each loss function is held by its own gradient descent: the action loss
  // by 'gdAction' and the value loss by 'gdValue'
  LossActionFree((LossAction**)&(that->gdAction->objFun));
  LossValueFree((LossValue**)&(that->gdValue->objFun));
  CapyGradientDescentFree(&(that->gdAction));
  CapyGradientDescentFree(&(that->gdValue));
}

// Build a CapyPolicyGradient trainer for an environment
// Inputs:
//   env: the environment to train (referenced, not copied)
// Output:
//   Return a CapyPolicyGradient with default hyperparameters (learning rates
//   0.01, discount 0.9, at most 1000 steps per episode, clipping coefficient
//   0.2). Two gradient descents are allocated: 'gdAction' (set to the adam
//   type) on a new LossAction starting from the environment's action
//   parameters, and 'gdValue' on a new LossValue starting from the
//   environment's value parameters.
CapyPolicyGradient CapyPolicyGradientCreate(CapyPGEnvironment* const env) {
  CapyPolicyGradient trainer = {
    .env = env,
    .discount = 0.9,
    .learnRateAction = 0.01,
    .learnRateState = 0.01,
    .coeffClipping = 0.2,
    .nbMaxStep = 1000,
    .avgReward = 0.0,
    .avgNbStep = 0.0,
    .reinforce = Reinforce,
    .proximalPolicyOptimisation = ProximalPolicyOptimisation,
    .destruct = Destruct,
  };

  // Objective functions of the two gradient descents
  CapyMathFun* const actionLoss = (CapyMathFun*)LossActionAlloc(env);
  CapyMathFun* const valueLoss = (CapyMathFun*)LossValueAlloc(env);

  // Gradient descent on the action parameters
  trainer.gdAction =
    CapyGradientDescentAlloc(actionLoss, env->paramAction.vals);
  $(trainer.gdAction, setType)(capyGradientDescent_adam);

  // Gradient descent on the value parameters
  trainer.gdValue =
    CapyGradientDescentAlloc(valueLoss, env->paramValue.vals);
  return trainer;
}

// Heap-allocate a CapyPolicyGradient and initialise it with
// CapyPolicyGradientCreate
// Inputs:
//   env: the environment to train
// Output:
//   Return a pointer to the new CapyPolicyGradient, or NULL if the pointer
//   is still null after the allocation. Release it with
//   CapyPolicyGradientFree.
// Exception:
//   May raise CapyExc_MallocFailed.
CapyPolicyGradient* CapyPolicyGradientAlloc(CapyPGEnvironment* const env) {
  CapyPolicyGradient* trainer = NULL;
  safeMalloc(trainer, 1);
  if(trainer != NULL) {
    *trainer = CapyPolicyGradientCreate(env);
  }
  return trainer;
}

// Destruct and free a heap-allocated CapyPolicyGradient, then reset '*that'
// to NULL
// Input:
//   that: a pointer to the CapyPolicyGradient to free; nothing happens if
//         'that' or '*that' is NULL
void CapyPolicyGradientFree(CapyPolicyGradient** const that) {
  if(that != NULL && *that != NULL) {
    CapyPolicyGradient* const trainer = *that;
    $(trainer, destruct)();
    free(trainer);
    *that = NULL;
  }
}
