// -------------------------- markovdecisionprocess.c -------------------------
/*
    LibCapy - a general purpose library of C functions and data structures
    Copyright (C) 2021-2025 Pascal Baillehache info@baillehachepascal.dev
    https://baillehachepascal.dev
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "markovdecisionprocess.h"
#include "range.h"

// Get the action for a given state
// Input:
//   state: the state
// Output:
//   Return the action
static size_t GetAction(size_t const state) {
  methodOf(CapyMDPPolicy);
  return that->actions[state];
}

// Probability that the (deterministic) policy selects a given action in a
// given state
// Input:
//   state: the state
//   action: the action
// Output:
//   Return 1.0 if 'action' is the policy's action for 'state', else 0.0.
//   If 'state' is out of range, raise CapyExc_InvalidStateIdx and return 0.0.
static double GetProbAction(
  size_t const state,
  size_t const action) {
  methodOf(CapyMDPPolicy);
  bool const isValidState = (state < that->nbState);
  if(!isValidState) {
    raiseExc(CapyExc_InvalidStateIdx);
    return 0.0;
  }
  if(that->actions[state] != action) return 0.0;
  return 1.0;
}

// Free the memory used by a CapyMDPPolicy
// Output:
//   'values' and 'actions' are freed and reset to NULL, so calling the
//   destructor a second time is harmless instead of a double free.
static void DestructPolicy(void) {
  methodOf(CapyMDPPolicy);
  free(that->values);
  that->values = NULL;
  free(that->actions);
  that->actions = NULL;
}

// Create a CapyMDPPolicy
// Input:
//   nbState: the number of state
// Output:
//   Return a CapyMDPPolicy whose values are all 0.0 and whose actions are
//   all 0
CapyMDPPolicy CapyMDPPolicyCreate(size_t const nbState) {
  CapyMDPPolicy that = {
    .nbState = nbState,
    .getAction = GetAction,
    .getProbAction = GetProbAction,
    .destruct = DestructPolicy,
  };

  // One value and one action per state
  safeMalloc(that.values, that.nbState);
  safeMalloc(that.actions, that.nbState);
  for(size_t i = 0; i < nbState; ++i) {
    that.values[i] = 0.0;
    that.actions[i] = 0;
  }
  return that;
}

// Allocate memory for a new CapyMDPPolicy and create it
// Input:
//   nbState: the number of state
// Output:
//   Return a pointer to the new CapyMDPPolicy, or NULL if the allocation
//   failed
// Exception:
//   May raise CapyExc_MallocFailed.
CapyMDPPolicy* CapyMDPPolicyAlloc(size_t const nbState) {
  CapyMDPPolicy* that = NULL;
  safeMalloc(that, 1);
  if(that != NULL) *that = CapyMDPPolicyCreate(nbState);
  return that;
}

// Free the memory used by a CapyMDPPolicy* and reset '*that' to NULL
// Input:
//   that: a pointer to the CapyMDPPolicy to free (NULL, or a pointer to
//         NULL, is accepted and ignored)
void CapyMDPPolicyFree(CapyMDPPolicy** const that) {
  if(that != NULL && *that != NULL) {
    $(*that, destruct)();
    free(*that);
    *that = NULL;
  }
}

// Get the action for a given state
// Input:
//   state: the state
// Output:
//   Return the action
static size_t GetActionMDPPolicyEpsilonSoft(size_t const state) {
  methodOf(CapyMDPPolicyEpsilonSoft);
  size_t action = that->actions[state];
  double const p = $(&(that->rng), getDouble)();
  if(p < that->epsilon) {
    CapyRangeSize range = {.min = 0, .max = that->nbAction - 1};
    do {
      action = $(&(that->rng), getSizeRange)(&range);
    } while(action == that->actions[state]);
  }
  return action;
}

// Probability that the epsilon-soft policy selects a given action in a
// given state
// Input:
//   state: the state
//   action: the action
// Output:
//   Return the probability in [0,1]: 1 - epsilon for the greedy action,
//   epsilon / (nbAction - 1) for any other action, and 1.0 / 0.0 for
//   action 0 / other actions when there is a single action. If 'state' is
//   out of range, raise CapyExc_InvalidStateIdx and return 0.0.
static double GetProbActionMDPPolicyEpsilonSoft(
  size_t const state,
  size_t const action) {
  methodOf(CapyMDPPolicyEpsilonSoft);
  if(state >= that->nbState) {
    raiseExc(CapyExc_InvalidStateIdx);
    return 0.0;
  }
  if(that->nbAction == 1) return (action == 0) ? 1.0 : 0.0;
  bool const isGreedy = (that->actions[state] == action);
  return isGreedy ?
    1.0 - that->epsilon :
    that->epsilon / (double)(that->nbAction - 1);
}

// Free the memory used by a CapyMDPPolicyEpsilonSoft: its own random
// generator, then the inherited CapyMDPPolicy part
static void DestructPolicyEpsilonSoft(void) {
  methodOf(CapyMDPPolicyEpsilonSoft);
  $(&(that->rng), destruct)();
  $(that, destructCapyMDPPolicy)();
}

// Create a new CapyMDPPolicyEpsilonSoft
// Input:
//   nbState: the number of state
//   nbAction: the number of action
//   epsilon: the epsilon constant for the action selection
// Output:
//   Return a CapyMDPPolicyEpsilonSoft. Its random generator is created with
//   the fixed seed 0.
CapyMDPPolicyEpsilonSoft CapyMDPPolicyEpsilonSoftCreate(
  size_t const nbState,
  size_t const nbAction,
  double const epsilon) {
  CapyMDPPolicyEpsilonSoft that = {0};

  // The parent must be set up first, the methods are overridden afterwards
  CapyInherits(that, CapyMDPPolicy, (nbState));
  that.epsilon = epsilon;
  that.nbAction = nbAction;
  that.getAction = GetActionMDPPolicyEpsilonSoft;
  that.getProbAction = GetProbActionMDPPolicyEpsilonSoft;
  that.destruct = DestructPolicyEpsilonSoft;
  that.rng = CapyRandomCreate(0);
  return that;
}

// Allocate memory for a new CapyMDPPolicyEpsilonSoft and create it
// Input:
//   nbState: the number of state
//   nbAction: the number of action
//   epsilon: the epsilon constant for the action selection
// Output:
//   Return a pointer to the new CapyMDPPolicyEpsilonSoft, or NULL if the
//   allocation failed
// Exception:
//   May raise CapyExc_MallocFailed.
CapyMDPPolicyEpsilonSoft* CapyMDPPolicyEpsilonSoftAlloc(
  size_t const nbState,
  size_t const nbAction,
  double const epsilon) {
  CapyMDPPolicyEpsilonSoft* that = NULL;
  safeMalloc(that, 1);
  if(that != NULL) {
    *that = CapyMDPPolicyEpsilonSoftCreate(nbState, nbAction, epsilon);
  }
  return that;
}

// Free the memory used by a CapyMDPPolicyEpsilonSoft* and reset '*that' to
// NULL
// Input:
//   that: a pointer to the CapyMDPPolicyEpsilonSoft to free (NULL, or a
//         pointer to NULL, is accepted and ignored)
void CapyMDPPolicyEpsilonSoftFree(CapyMDPPolicyEpsilonSoft** const that) {
  if(that != NULL && *that != NULL) {
    $(*that, destruct)();
    free(*that);
    *that = NULL;
  }
}

// Default 'step' of a CapyMDPEnvironment: it has no implementation and
// must be overridden by the user
// Input:
//   fromState: the 'from' state (unused)
//   action: the applied action (unused)
// Output:
//   Raise CapyExc_UndefinedExecution, fail the assertion in debug builds,
//   and return 0
static size_t StepEnvironment(
  size_t const fromState,
  size_t const action) {
  (void)fromState;
  (void)action;
  raiseExc(CapyExc_UndefinedExecution);
  assert(false && "MDPEnvironment.step is undefined.");
  return 0;
}

// Free the memory used by a CapyMDPEnvironment (the base environment owns
// no dynamically allocated resources, so this is a no-op)
static void DestructEnvironment(void) {}

// Create a CapyMDPEnvironment
// Output:
//   Return a CapyMDPEnvironment
CapyMDPEnvironment CapyMDPEnvironmentCreate(void) {
  CapyMDPEnvironment that = {
    .destruct = DestructEnvironment,
    .step = StepEnvironment,
  };
  return that;
}

// Allocate memory for a new CapyMDPEnvironment and create it
// Output:
//   Return a pointer to the new CapyMDPEnvironment, or NULL if the
//   allocation failed
// Exception:
//   May raise CapyExc_MallocFailed.
CapyMDPEnvironment* CapyMDPEnvironmentAlloc(void) {
  CapyMDPEnvironment* that = NULL;
  safeMalloc(that, 1);
  if(that != NULL) *that = CapyMDPEnvironmentCreate();
  return that;
}

// Free the memory used by a CapyMDPEnvironment* and reset '*that' to NULL
// Input:
//   that: a pointer to the CapyMDPEnvironment to free (NULL, or a pointer
//         to NULL, is accepted and ignored)
void CapyMDPEnvironmentFree(CapyMDPEnvironment** const that) {
  if(that != NULL && *that != NULL) {
    $(*that, destruct)();
    free(*that);
    *that = NULL;
  }
}

// Free the memory used by a CapyMarkovDecisionProcess: its optimal policy,
// its random generator, and its transitions / start / end state arrays
static void Destruct(void) {
  methodOf(CapyMarkovDecisionProcess);
  $(&(that->optimalPolicy), destruct)();
  $(&(that->rng), destruct)();
  free(that->flagEndStates);
  free(that->flagStartStates);
  free(that->transitions);
}

// Get a transition
// Input:
//   fromState: index of the origin state
//   action: index of the action
//   toState: index of the termination state
// Output:
//   Return a reference to the transition, stored at index
//   (fromState * nbAction + action) * nbState + toState. On an invalid
//   state index raise CapyExc_InvalidStateIdx, on an invalid action index
//   raise CapyExc_InvalidActionIdx, and return NULL in both cases.
static CapyMDPTransition* GetTransition(
  size_t const fromState,
  size_t const action,
  size_t const toState) {
  methodOf(CapyMarkovDecisionProcess);
  bool const areStatesValid =
    (fromState < that->nbState) && (toState < that->nbState);
  if(!areStatesValid) {
    raiseExc(CapyExc_InvalidStateIdx);
    return NULL;
  }
  if(action >= that->nbAction) {
    raiseExc(CapyExc_InvalidActionIdx);
    return NULL;
  }
  size_t const idx =
    (fromState * that->nbAction + action) * that->nbState + toState;
  return &(that->transitions[idx]);
}

// Free the memory used by a transition recorder (its buffer of recorded
// transitions)
static void DestructTransitionRecorder(void) {
  methodOf(CapyMDPTransitionRecorder);

  // The recorder owns the 'transitions' buffer allocated at creation and
  // grown by 'addTransition'
  free(that->transitions);
}

// Record one transition
// Input:
//   transition: the transition to be recorded
// Output:
//   A copy of the transition is added to the end of 'transitions' which
//   is realloced if necessary, 'nbTransition' and 'nbMaxTransition' are
//   updated as necessary. If 'transition' is NULL, raise
//   CapyExc_InvalidParameters and record nothing.
static void AddTransition(CapyMDPTransition const* const transition) {
  methodOf(CapyMDPTransitionRecorder);
  if(transition == NULL) {
    raiseExc(CapyExc_InvalidParameters);
    return;
  }
  if(that->nbTransition >= that->nbMaxTransition) {

    // Doubling a zero capacity (e.g. a zero-initialised recorder) would
    // leave it at zero and the write below would be out of bounds, so
    // restart from the default capacity used at creation
    that->nbMaxTransition =
      (that->nbMaxTransition == 0 ? 1024 : 2 * that->nbMaxTransition);
    safeRealloc(that->transitions, that->nbMaxTransition);
  }
  that->transitions[that->nbTransition] = *transition;
  that->nbTransition += 1;
}

// Reset the recorder (the buffer is kept for reuse)
// Output:
//   'nbTransition' is set back to 0; 'transitions' and 'nbMaxTransition'
//   are untouched.
static void ResetTransitionRecorder(void) {
  methodOf(CapyMDPTransitionRecorder);
  that->nbTransition = 0;
}

// Create a CapyMDPTransitionRecorder
// Output:
//   Return an empty CapyMDPTransitionRecorder with an initial capacity of
//   1024 transitions
CapyMDPTransitionRecorder CapyMDPTransitionRecorderCreate(void) {
  CapyMDPTransitionRecorder that = {
    .transitions = NULL,
    .nbTransition = 0,
    .nbMaxTransition = 1024,
    .addTransition = AddTransition,
    .reset = ResetTransitionRecorder,
    .destruct = DestructTransitionRecorder,
  };
  safeMalloc(that.transitions, that.nbMaxTransition);
  return that;
}

// Set the current state
// Input:
//   state: index of the current state
// Output:
//   The current state is set and the number of step is reset
static void SetCurState(size_t const state) {
  methodOf(CapyMarkovDecisionProcess);
  that->curState = state;
  that->nbStep = 0;
}

// Get the current state
// Output:
//   Return the index of the state the MDP is currently in
static size_t GetCurState(void) {
  methodOf(CapyMarkovDecisionProcess);
  size_t const state = that->curState;
  return state;
}

// Get the number of step
// Output:
//   Return the number of steps performed since the current state was last
//   set with 'setCurState'
static size_t GetNbStep(void) {
  methodOf(CapyMarkovDecisionProcess);
  size_t const nbStep = that->nbStep;
  return nbStep;
}

// Step the MDP according to its transitions definition
// Output:
//   The current state and the number of step are updated. Return the
//   transition.
static CapyMDPTransition* Step(void) {
  methodOf(CapyMarkovDecisionProcess);

  // Variable to memorise the selected transition
  CapyMDPTransition* transition = NULL;

  // Calculate the sum of probabilities of transition from the current state
  double sum = 0.0;
  loop(action, that->nbAction) loop(toState, that->nbState) {
    CapyMDPTransition const* const trans =
      $(that, getTransition)(that->curState, action, toState);
    sum += trans->prob;
  }

  // Get the random value to select the transition
  CapyRangeDouble range = {.vals = {0.0, sum}};
  double const x = $(&(that->rng), getDoubleRange)(&range);

  // Get the selected transition
  sum = 0.0;
  loop(action, that->nbAction) loop(toState, that->nbState) {
    CapyMDPTransition* const trans =
      $(that, getTransition)(that->curState, action, toState);
    if(sum + trans->prob >= x) {
      transition = trans;
      action = that->nbAction;
      toState = that->nbState;
    }
    sum += trans->prob;
  }
  if(transition == NULL) {
    raiseExc(CapyExc_UndefinedExecution);
    assert(false);
  }

  // Update the current state and number of step
  that->curState = transition->toState;
  that->nbStep += 1;

  // Return the selected transition
  return transition;
}

// Step the MDP according to a given policy
// Input:
//   policy: the policy giving the action to apply in the current state
// Output:
//   The current state and the number of step are updated and the
//   transition is returned. If the MDP's environment is known it is used
//   to get the result state, else the result state is selected at random
//   with a probability proportional to the transitions' 'prob' for the
//   policy's action (null probabilities are never selected).
//   On failure an exception is raised and NULL is returned:
//   CapyExc_InvalidStateIdx if the current state (or the environment's
//   result state) is invalid, CapyExc_InvalidActionIdx if the policy
//   returns an invalid action, CapyExc_UndefinedExecution if no transition
//   can be selected.
static CapyMDPTransition* StepPolicy(
  CapyMDPPolicy const* const policy) {
  methodOf(CapyMarkovDecisionProcess);
  if(that->curState >= that->nbState) {
    raiseExc(CapyExc_InvalidStateIdx);
    return NULL;
  }

  // Get the action according to the policy
  size_t const action = $(policy, getAction)(that->curState);
  if(action >= that->nbAction) {
    raiseExc(CapyExc_InvalidActionIdx);
    return NULL;
  }

  // Variable to memorise the selected transition
  CapyMDPTransition* transition = NULL;

  // If the environment is known
  if(that->environment != NULL) {

    // Get the result state according to the environment. getTransition
    // raises and returns NULL if that state is invalid.
    size_t const toState = $(that->environment, step)(that->curState, action);
    transition = $(that, getTransition)(that->curState, action, toState);
    if(transition == NULL) return NULL;

  // Else, the environment is not known
  } else {

    // Calculate the sum of probabilities of transition from the current state
    // for the policy action
    double sum = 0.0;
    loop(toState, that->nbState) {
      CapyMDPTransition const* const trans =
        $(that, getTransition)(that->curState, action, toState);
      sum += trans->prob;
    }

    // Get the random value to select the transition
    CapyRangeDouble range = {.vals = {0.0, sum}};
    double const x = $(&(that->rng), getDoubleRange)(&range);

    // Get the selected transition, skipping impossible transitions
    sum = 0.0;
    for(
      size_t toState = 0;
      toState < that->nbState && transition == NULL;
      ++toState
    ) {
      CapyMDPTransition* const trans =
        $(that, getTransition)(that->curState, action, toState);
      if(trans->prob > 0.0 && sum + trans->prob >= x) {
        transition = trans;
      } else {
        sum += trans->prob;
      }
    }
    if(transition == NULL) {
      raiseExc(CapyExc_UndefinedExecution);
      return NULL;
    }
  }

  // Update the current state and number of step
  that->curState = transition->toState;
  that->nbStep += 1;

  // Return the selected transition
  return transition;
}

// Reinitialise the pseudo random generator of the MDP
// Input:
//   seed: the seed
// Output:
//   The current generator is destructed and replaced by a new one created
//   with 'seed'.
static void ResetRng(CapyRandomSeed_t const seed) {
  methodOf(CapyMarkovDecisionProcess);

  // Release the resources of the old generator before overwriting it
  $(&(that->rng), destruct)();
  that->rng = CapyRandomCreate(seed);
}

// Search the optimal policy (given that the MDP's transitions are all set
// with the correct transitions probabilities and rewards)
// Output:
//   Calculate the optimal policy, update 'optimalPolicy' which is also
//   used as the initial policy for the search. Iterates until the largest
//   change of a state value is at most 'that->epsilon' and the policy is
//   stable. Raise CapyExc_InfiniteLoop if that is not reached within
//   'that->nbMaxStep' iterations.
static void SearchOptimalPolicy(void) {
  methodOf(CapyMarkovDecisionProcess);

  // Create a buffer CapyMDPPolicy for computation and initialise it with the
  // current optimal policy
  CapyMDPPolicy bufferPolicy = CapyMDPPolicyCreate(that->nbState);
  loop(iState, that->nbState) {
    bufferPolicy.values[iState] = that->optimalPolicy.values[iState];
    bufferPolicy.actions[iState] = that->optimalPolicy.actions[iState];
  }

  // Variables for convergence condition
  double deltaValue = 2.0 * that->epsilon;
  bool flagPolicyChanged = true;

  // Loop until convergence
  size_t nbStep = 0;
  while(
    (deltaValue > that->epsilon || flagPolicyChanged) &&
    nbStep < that->nbMaxStep
  ) {
    nbStep += 1;
    deltaValue = 0.0;
    flagPolicyChanged = false;

    // Update the states' value
    // In "policy iteration" we would loop over this update with the same
    // policy until the values converge.
    // In "value iteration" we do this update only once between each policy
    // update.
    loop(iState, that->nbState) {
      bufferPolicy.values[iState] = 0.0;
      size_t const iAction = that->optimalPolicy.actions[iState];
      loop(jState, that->nbState) {
        CapyMDPTransition const* const transition =
          $(that, getTransition)(iState, iAction, jState);
        bufferPolicy.values[iState] +=
          transition->prob * (
            transition->reward +
            that->discount * that->optimalPolicy.values[jState]);
      }
      double const delta =
        fabs(bufferPolicy.values[iState] - that->optimalPolicy.values[iState]);
      if(deltaValue < delta) deltaValue = delta;
    }

    // Update the states' optimal action (the 1e-9 margin keeps the first
    // best action on near ties)
    loop(iState, that->nbState) {
      size_t bestAction = 0;
      double bestValue = 0.0;
      loop(iAction, that->nbAction) {
        double value = 0.0;
        loop(jState, that->nbState) {
          CapyMDPTransition const* const transition =
            $(that, getTransition)(iState, iAction, jState);
          value +=
            transition->prob * (
              transition->reward +
              that->discount * bufferPolicy.values[jState]);
        }
        if(iAction == 0 || bestValue + 1e-9 < value) {
          bestAction = iAction;
          bestValue = value;
        }
      }
      flagPolicyChanged |= (that->optimalPolicy.actions[iState] != bestAction);
      bufferPolicy.actions[iState] = bestAction;
    }

    // Copy the buffer policy into the current optimal policy
    loop(iState, that->nbState) {
      that->optimalPolicy.values[iState] = bufferPolicy.values[iState];
      that->optimalPolicy.actions[iState] = bufferPolicy.actions[iState];
    }
  }

  // Converged if the loop stopped because of its convergence condition (not
  // merely because nbStep reached nbMaxStep: convergence on the very last
  // allowed iteration is still a success)
  bool const hasConverged =
    !(deltaValue > that->epsilon || flagPolicyChanged);

  // Free the buffer policy before raising, so it is not leaked if the
  // exception unwinds the stack
  $(&bufferPolicy, destruct)();
  if(!hasConverged) raiseExc(CapyExc_InfiniteLoop);
}

// Record a trajectory through the MDP given an initial state and a policy
// Input:
//   recorder: the recorder
//   startState: the initial state of the trajectory
//   policy: the policy used to select transitions
// Output:
//   The recorder is reset and updated with the trajectory. The trajectory
//   stops when encountering an end state, or when it reaches
//   'that->nbMaxStep' steps, in which case CapyExc_InfiniteLoop is raised
//   (unless the last step reached an end state). If 'startState' is
//   invalid, raise CapyExc_InvalidStateIdx and record nothing. The current
//   state of the MDP is modified.
static void RecordTrajectory(
  CapyMDPTransitionRecorder* const recorder,
                      size_t const startState,
        CapyMDPPolicy const* const policy) {
  methodOf(CapyMarkovDecisionProcess);

  // Reset the recorder
  $(recorder, reset)();

  // Check the initial state, it is used to index 'flagEndStates'
  if(startState >= that->nbState) {
    raiseExc(CapyExc_InvalidStateIdx);
    return;
  }

  // Set the initial state
  $(that, setCurState)(startState);

  // Loop until we reach an end state or the maximum number of step
  size_t nbStep = 0;
  while(
    that->flagEndStates[that->curState] == false &&
    nbStep < that->nbMaxStep
  ) {

    // Get the transition according to the policy, and stop if none could
    // be selected
    CapyMDPTransition const* const transition = $(that, stepPolicy)(policy);
    if(transition == NULL) return;

    // Memorise the transition
    $(recorder, addTransition)(transition);
    nbStep += 1;
  }

  // Only an error if the step limit cut the trajectory before an end state
  if(that->flagEndStates[that->curState] == false) {
    raiseExc(CapyExc_InfiniteLoop);
  }
}

// Update a policy given the current action value of the MDP's transitions and
// a given 'from' state (greedy version)
// Input:
//   that: the MDP
//   fromState: the state whose policy entry is updated
//   policy: the policy to update
// Output:
//   For each action, its value is the sum over 'to' states of the
//   transition value weighted by nbOccurence / nbVisit (transitions never
//   visited contribute 0). The action with the highest value is stored in
//   'policy->actions[fromState]' and its value in
//   'policy->values[fromState]'. Ties are broken uniformly at random among
//   the tied actions (reservoir sampling).
static void UpdatePolicyGreedy(
  CapyMarkovDecisionProcess const* const that,
                            size_t const fromState,
                    CapyMDPPolicy* const policy) {

  // Variables to memorize the best action, and the number of actions
  // currently tied for the best value (0 until the first action is seen)
  size_t bestAction = 0;
  double bestValue = 0.0;
  size_t nbTie = 0;

  // Loop on the actions
  loop(iAction, that->nbAction) {

    // Variable to calculate the weighted value over all 'to' state
    double value = 0.0;

    // Loop on the "to" state
    loop(toState, that->nbState) {

      // Get the transition for (fromState, action, toState)
      CapyMDPTransition const* const transition =
        $(that, getTransition)(fromState, iAction, toState);

      // Update the average of return
      if(transition->nbVisit > 0) {
        value +=
          transition->value *
          (double)(transition->nbOccurence) / (double)(transition->nbVisit);
      }
    }

    // Select the action. NOTE(review): the 1e-18 tie tolerance is absolute
    // and effectively means exact equality for values of usual magnitude.
    bool flagUpdate = false;
    if(nbTie == 0) {
      flagUpdate = true;
      nbTie = 1;
    } else if(fabs(bestValue - value) < 1e-18) {

      // The k-th tied action replaces the current one with probability 1/k,
      // so every tied action ends up selected with the same probability
      nbTie += 1;
      flagUpdate =
        ($(&(that->rng), getDouble)() < 1.0 / (double)nbTie);
    } else if(bestValue < value) {
      flagUpdate = true;
      nbTie = 1;
    }
    if(flagUpdate) {
      bestAction = iAction;
      bestValue = value;
    }
  }

  // Update the policy with the best action
  policy->values[fromState] = bestValue;
  policy->actions[fromState] = bestAction;
}

// Get the max action value for a given state
// Input:
//   that: the MDP
//   state: the state
// Output:
//   Return the largest, over the actions having at least one visited
//   transition from 'state', of the sum of the transition values weighted
//   by nbOccurence / nbVisit. Return 0.0 if no action has been visited.
static double GetActionMaxValue(
  CapyMarkovDecisionProcess const* const that,
                            size_t const state) {
  double best = 0.0;
  bool hasBest = false;
  for(size_t iAction = 0; iAction < that->nbAction; ++iAction) {
    double actionValue = 0.0;
    bool isVisited = false;
    for(size_t toState = 0; toState < that->nbState; ++toState) {
      CapyMDPTransition const* const trans =
        $(that, getTransition)(state, iAction, toState);
      if(trans->nbVisit > 0) {
        isVisited = true;
        actionValue +=
          trans->value *
          (double)(trans->nbOccurence) / (double)(trans->nbVisit);
      }
    }

    // Unvisited actions carry no information and are ignored
    if(!isVisited) continue;
    if(!hasBest || actionValue > best) {
      best = actionValue;
      hasBest = true;
    }
  }
  return best;
}

// Search the optimal policy using Q-Learning (converges to the optimal
// policy by exploring the environment instead of using transitions
// probabilities, only needs the transition rewards; uses an
// epsilon-soft policy to explore the transitions)
// Input:
//   epsilon: exploration coefficient (in ]0, 1])
//   alpha: learning rate (in ]0, 1])
//   nbEpisode: number of training episodes
// Output:
//   Calculate the optimal policy, update 'optimalPolicy' which is also
//   used as the initial policy for the search.
static void QLearning(
  double const epsilon,
  double const alpha,
  size_t const nbEpisode) {
  methodOf(CapyMarkovDecisionProcess);

  // If the environment is not set the search can't be performed
  if(that->environment == NULL) {
    raiseExc(CapyExc_InvalidParameters);
    return;
  }

  // Reset the counters in the transition according to their probability
  loop(fromState, that->nbState) loop(action, that->nbAction) {
    loop(toState, that->nbState) {
      CapyMDPTransition* const transition =
        $(that, getTransition)(fromState, action, toState);
      transition->nbOccurence = 0;
      transition->nbVisit = 0;
    }
  }
  loop(fromState, that->nbState) loop(action, that->nbAction) {
    loop(toState, that->nbState) {
      CapyMDPTransition* const transition =
        $(that, getTransition)(fromState, action, toState);
      if(transition->prob > 0.0) {
        transition->nbOccurence = (size_t)(transition->prob * 100.0);
        loop(toOtherState, that->nbState) {
          CapyMDPTransition* const otherTransition =
            $(that, getTransition)(fromState, action, toOtherState);
          otherTransition->nbVisit = 100;
        }
      }
    }
  }

  // Create an epsilon soft policy for training, initialised with the
  // current optimal policy
  CapyMDPPolicyEpsilonSoft softPolicy =
    CapyMDPPolicyEpsilonSoftCreate(that->nbState, that->nbAction, epsilon);
  loop(iState, that->nbState) {
    softPolicy.values[iState] = that->optimalPolicy.values[iState];
    softPolicy.actions[iState] = that->optimalPolicy.actions[iState];
  }

  // Loop on the episodes
  loop(iEpisode, nbEpisode) {

    // Get a random initial state
    size_t const startState = $(that, getRndStartState)();

    // Set the initial state
    $(that, setCurState)(startState);

    // Loop until we reach an end state or the maximum number of state
    size_t nbStep = 0;
    while(
      that->flagEndStates[that->curState] == false &&
      nbStep < that->nbMaxStep
    ) {
      nbStep += 1;

      // Get the transition according to the soft policy
      CapyMDPTransition* transition =
        $(that, stepPolicy)((CapyMDPPolicy*)&softPolicy);

      // Update the counters
      transition->nbOccurence += 1;
      loop(toState, that->nbState) {
        CapyMDPTransition* const otherTransition =
          $(that, getTransition)(
            transition->fromState, transition->action, toState);
        otherTransition->nbVisit += 1;
      }

      // Get the max action value for the result state
      double const maxNextActionValue =
        GetActionMaxValue(that, transition->toState);

      // Update the action value
      transition->value += alpha * (
        transition->reward +
        that->discount * maxNextActionValue -
        transition->value);

      // Update the soft policy
      UpdatePolicyGreedy(
        that, transition->fromState, (CapyMDPPolicy*)&softPolicy);

      // Update the current state
      that->curState = transition->toState;
    }
  }

  // Update the optimal policy with the soft policy
  loop(iState, that->nbState) {
    that->optimalPolicy.values[iState] = softPolicy.values[iState];
    that->optimalPolicy.actions[iState] = softPolicy.actions[iState];
  }

  // Update the transitions probability according to their counter
  loop(fromState, that->nbState) loop(action, that->nbAction) {
    loop(toState, that->nbState) {
      CapyMDPTransition* const transition =
        $(that, getTransition)(fromState, action, toState);
      if(transition->nbVisit > 0) {
        transition->prob =
          (double)(transition->nbOccurence) / (double)(transition->nbVisit);
      } else {
        transition->prob = 0.0;
      }
    }
  }

  // Free memory
  $(&softPolicy, destruct)();
}

// Get the expected sum of reward
// Input:
//   nbRun: number of run used to calculate the expected reward
// Output:
//   Return the average over 'nbRun' runs of the sum of rewards collected
//   until an end state is reached (0.0 if 'nbRun' is 0). Each run starts
//   from a random start state and randomly selects the transitions
//   according to their probabilities. If a run doesn't reach an end state
//   within 'that->nbMaxStep' steps, raise CapyExc_InfiniteLoop and return
//   0.0. If a step fails (no transition returned), return 0.0.
static double GetExpReward(size_t const nbRun) {
  methodOf(CapyMarkovDecisionProcess);
  if(nbRun == 0) return 0.0;
  double expSumReward = 0.0;
  double const avgCoeff = 1.0 / (double)nbRun;
  loop(iRun, nbRun) {
    size_t const fromState = $(that, getRndStartState)();
    $(that, setCurState)(fromState);
    double sumReward = 0.0;

    // 'step' increments 'that->nbStep' itself, it must not be incremented
    // a second time here
    while(
      that->flagEndStates[that->curState] == false &&
      that->nbStep < that->nbMaxStep
    ) {
      CapyMDPTransition const* const trans = $(that, step)();
      if(trans == NULL) return 0.0;
      sumReward += trans->reward;
    }

    // Only an error if the step limit cut the run before an end state
    if(that->flagEndStates[that->curState] == false) {
      raiseExc(CapyExc_InfiniteLoop);
      return 0.0;
    }
    expSumReward += sumReward * avgCoeff;
  }
  return expSumReward;
}

// Get the expected sum of reward from a given start state
// Input:
//   fromState: the start state
//   nbRun: number of run used to calculate the expected reward
// Output:
//   Return the expected sum of reward, or 0.0 and raise
//   CapyExc_UndefinedExecution if the MDP can't reach an end state within
//   that->nbMaxIter. Randomly select the transitions according to their
//   probabilities.
static double GetExpRewardFromState(
  size_t const fromState,
  size_t const nbRun) {
  methodOf(CapyMarkovDecisionProcess);
  double expSumReward = 0.0;
  double const avgCoeff = 1.0 / (double)nbRun;
  loop(iRun, nbRun) {
    $(that, setCurState)(fromState);
    double sumReward = 0.0;
    while(
      that->flagEndStates[that->curState] == false &&
      that->nbStep < that->nbMaxStep
    ) {
      CapyMDPTransition* const trans = $(that, step)();
      sumReward += trans->reward;
      that->nbStep += 1;
    }
    if(that->nbStep >= that->nbMaxStep) {
      raiseExc(CapyExc_InfiniteLoop);
      return 0.0;
    }
    expSumReward += sumReward * avgCoeff;
  }
  return expSumReward;
}

// Get the expected sum of reward using a given policy
// Input:
//   nbRun: number of run used to calculate the expected reward
//   policy: the policy
// Output:
//   Return the average over 'nbRun' runs of the sum of rewards collected
//   until an end state is reached (0.0 if 'nbRun' is 0). Each run starts
//   from a random start state and selects the transitions according to
//   the given policy. If a run doesn't reach an end state within
//   'that->nbMaxStep' steps, raise CapyExc_InfiniteLoop and return 0.0. If
//   a step fails (no transition returned), return 0.0.
static double GetExpRewardForPolicy(
                size_t const nbRun,
  CapyMDPPolicy const* const policy) {
  methodOf(CapyMarkovDecisionProcess);
  if(nbRun == 0) return 0.0;
  double expSumReward = 0.0;
  double const avgCoeff = 1.0 / (double)nbRun;
  loop(iRun, nbRun) {
    size_t const fromState = $(that, getRndStartState)();
    $(that, setCurState)(fromState);
    double sumReward = 0.0;

    // 'stepPolicy' increments 'that->nbStep' itself, it must not be
    // incremented a second time here
    while(
      that->flagEndStates[that->curState] == false &&
      that->nbStep < that->nbMaxStep
    ) {
      CapyMDPTransition const* const trans =
        $(that, stepPolicy)(policy);
      if(trans == NULL) return 0.0;
      sumReward += trans->reward;
    }

    // Only an error if the step limit cut the run before an end state
    if(that->flagEndStates[that->curState] == false) {
      raiseExc(CapyExc_InfiniteLoop);
      return 0.0;
    }
    expSumReward += sumReward * avgCoeff;
  }
  return expSumReward;
}

// Get the expected sum of reward from a given start state using a given
// policy
// Input:
//   fromState: the start state
//   nbRun: number of run used to calculate the expected reward
//   policy: the policy
// Output:
//   Return the average over 'nbRun' runs starting from 'fromState' of the
//   sum of rewards collected until an end state is reached (0.0 if 'nbRun'
//   is 0). Select the transitions according to the given policy. If
//   'fromState' is invalid, raise CapyExc_InvalidStateIdx and return 0.0.
//   If a run doesn't reach an end state within 'that->nbMaxStep' steps,
//   raise CapyExc_InfiniteLoop and return 0.0. If a step fails (no
//   transition returned), return 0.0.
static double GetExpRewardFromStateForPolicy(
                size_t const fromState,
                size_t const nbRun,
  CapyMDPPolicy const* const policy) {
  methodOf(CapyMarkovDecisionProcess);
  if(fromState >= that->nbState) {
    raiseExc(CapyExc_InvalidStateIdx);
    return 0.0;
  }
  if(nbRun == 0) return 0.0;
  double expSumReward = 0.0;
  double const avgCoeff = 1.0 / (double)nbRun;
  loop(iRun, nbRun) {
    $(that, setCurState)(fromState);
    double sumReward = 0.0;

    // 'stepPolicy' increments 'that->nbStep' itself, it must not be
    // incremented a second time here
    while(
      that->flagEndStates[that->curState] == false &&
      that->nbStep < that->nbMaxStep
    ) {
      CapyMDPTransition const* const trans =
        $(that, stepPolicy)(policy);
      if(trans == NULL) return 0.0;
      sumReward += trans->reward;
    }

    // Only an error if the step limit cut the run before an end state
    if(that->flagEndStates[that->curState] == false) {
      raiseExc(CapyExc_InfiniteLoop);
      return 0.0;
    }
    expSumReward += sumReward * avgCoeff;
  }
  return expSumReward;
}

// Get a random start state
// Output:
//   Return one of the start states. If there are no start states, return 0 by
//   default.
static size_t GetRndStartState(void) {
  methodOf(CapyMarkovDecisionProcess);

  // Count the start states
  size_t nbStartState = 0;
  loop(iState, that->nbState) {
    nbStartState += (that->flagStartStates[iState] ? 1 : 0);
  }

  // If there are no start states, return 0 by default
  if(nbStartState == 0) return 0;

  // Pick one at random and return it
  CapyRangeSize const range = {.min = 1, .max = nbStartState};
  size_t idxStartState = $(&(that->rng), getSizeRange)(&range);
  size_t jState = 0;
  loop(iState, that->nbState) {
    jState += (that->flagStartStates[iState] ? 1 : 0);
    if(idxStartState == jState) return iState;
  }

  // Should never reach here
  raiseExc(CapyExc_UndefinedExecution);
  return 0;
}

// Create a CapyMarkovDecisionProcess
// Input:
//   nbState: the number of state
//   nbAction: the number of action
// Output:
//   Return a CapyMarkovDecisionProcess with nbState * nbAction * nbState
//   transitions (all probabilities, rewards and values set to 0), no start
//   or end state, a random generator seeded with the current time,
//   discount 0.9, epsilon 1e-6, nbMaxStep 1e9 and no environment.
CapyMarkovDecisionProcess CapyMarkovDecisionProcessCreate(
  size_t const nbState,
  size_t const nbAction) {
  CapyMarkovDecisionProcess that = {
    .nbState = nbState,
    .nbAction = nbAction,
    .nbTransition = nbState * nbAction * nbState,
    .curState = 0,
    .nbStep = 0,
    .nbMaxStep = 1e9,
    .rng = CapyRandomCreate((CapyRandomSeed_t)time(NULL)),
    .optimalPolicy = CapyMDPPolicyCreate(nbState),
    .discount = 0.9,
    .epsilon = 1e-6,
    .flagEveryVisit = false,
    .environment = NULL,
    .destruct = Destruct,
    .getTransition = GetTransition,
    .setCurState = SetCurState,
    .getCurState = GetCurState,
    .getNbStep = GetNbStep,
    .step = Step,
    .stepPolicy = StepPolicy,
    .resetRng = ResetRng,
    .searchOptimalPolicy = SearchOptimalPolicy,
    .recordTrajectory = RecordTrajectory,
    .qLearning = QLearning,
    .getExpReward = GetExpReward,
    .getExpRewardForPolicy = GetExpRewardForPolicy,
    .getExpRewardFromState = GetExpRewardFromState,
    .getExpRewardFromStateForPolicy = GetExpRewardFromStateForPolicy,
    .getRndStartState = GetRndStartState,
  };
  safeMalloc(that.transitions, that.nbTransition);
  safeMalloc(that.flagStartStates, that.nbState);
  safeMalloc(that.flagEndStates, that.nbState);
  for(size_t iState = 0; iState < nbState; ++iState) {
    that.flagStartStates[iState] = false;
    that.flagEndStates[iState] = false;
  }

  // Transitions are stored as [fromState][action][toState], so a running
  // index visits them in storage order (same layout as GetTransition)
  size_t iTrans = 0;
  for(size_t fromState = 0; fromState < nbState; ++fromState) {
    for(size_t action = 0; action < nbAction; ++action) {
      for(size_t toState = 0; toState < nbState; ++toState) {

        // Members not listed (e.g. the counters) are zero-initialised
        that.transitions[iTrans] = (CapyMDPTransition){
          .fromState = fromState,
          .action = action,
          .toState = toState,
          .prob = 0.0,
          .reward = 0.0,
          .value = 0.0,
        };
        iTrans += 1;
      }
    }
  }
  return that;
}

// Allocate memory for a new CapyMarkovDecisionProcess and create it
// Input:
//   nbState: the number of state
//   nbAction: the number of action
// Output:
//   Return a pointer to the new CapyMarkovDecisionProcess, or NULL if the
//   allocation failed
// Exception:
//   May raise CapyExc_MallocFailed.
CapyMarkovDecisionProcess* CapyMarkovDecisionProcessAlloc(
  size_t const nbState,
  size_t const nbAction) {
  CapyMarkovDecisionProcess* that = NULL;
  safeMalloc(that, 1);
  if(that != NULL) {
    *that = CapyMarkovDecisionProcessCreate(nbState, nbAction);
  }
  return that;
}

// Free the memory used by a CapyMarkovDecisionProcess* and reset '*that'
// to NULL
// Input:
//   that: a pointer to the CapyMarkovDecisionProcess to free (NULL, or a
//         pointer to NULL, is accepted and ignored)
void CapyMarkovDecisionProcessFree(CapyMarkovDecisionProcess** const that) {
  if(that != NULL && *that != NULL) {
    $(*that, destruct)();
    free(*that);
    *that = NULL;
  }
}
