#include "capy.h"
#ifndef FIXTURE
#define FIXTURE

// States of the dummy MDP used by the tests. dummyStateNb is not a state
// itself, it is the number of states declared before it.
typedef enum DummyState {
  dummyStatePlaying = 0,
  dummyStateEnd = 1,
  dummyStateNb = 2,
} DummyState;

// Actions of the dummy MDP used by the tests. dummyActionNb is not an action
// itself, it is the number of actions declared before it.
typedef enum DummyAction {
  dummyActionContinue = 0,
  dummyActionQuit = 1,
  dummyActionNb = 2,
} DummyAction;

static CapyMarkovDecisionProcess* CreateMDPDummy1(void) {
  size_t const nbState = dummyStateNb;
  size_t const nbAction = dummyActionNb;
  CapyMarkovDecisionProcess* mdp =
    CapyMarkovDecisionProcessAlloc(nbState, nbAction);
  mdp->flagStartStates[dummyStatePlaying] = true;
  mdp->flagEndStates[dummyStateEnd] = true;
  double const probs[dummyStateNb][dummyActionNb][dummyStateNb] = {
    [dummyStatePlaying] = {
      [dummyActionContinue] = {
        [dummyStatePlaying] = 1.0,
        [dummyStateEnd] = 0.0,
      },
      [dummyActionQuit] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 1.0,
      },
    },
    [dummyStateEnd] = {
      [dummyActionContinue] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 1.0,
      },
      [dummyActionQuit] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 1.0,
      },
    },
  };
  double const rewards[dummyStateNb][dummyActionNb][dummyStateNb] = {
    [dummyStatePlaying] = {
      [dummyActionContinue] = {
        [dummyStatePlaying] = 3.0,
        [dummyStateEnd] = 0.0,
      },
      [dummyActionQuit] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 5.0,
      },
    },
    [dummyStateEnd] = {
      [dummyActionContinue] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 0.0,
      },
      [dummyActionQuit] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 0.0,
      },
    },
  };
  loop(fromState, nbState) loop(action, nbAction) loop(toState, nbState) {
    CapyMDPTransition* const transition =
      $(mdp, getTransition)(fromState, action, toState);
    transition->prob = probs[fromState][action][toState];
    transition->reward = rewards[fromState][action][toState];
  }
  return mdp;
}

static CapyMarkovDecisionProcess* CreateMDPDummy2(void) {
  size_t const nbState = dummyStateNb;
  size_t const nbAction = dummyActionNb;
  CapyMarkovDecisionProcess* mdp =
    CapyMarkovDecisionProcessAlloc(nbState, nbAction);
  mdp->flagStartStates[dummyStatePlaying] = true;
  mdp->flagEndStates[dummyStateEnd] = true;
  double const probs[dummyStateNb][dummyActionNb][dummyStateNb] = {
    [dummyStatePlaying] = {
      [dummyActionContinue] = {
        [dummyStatePlaying] = 0.666666666,
        [dummyStateEnd] = 0.333333334,
      },
      [dummyActionQuit] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 0.0,
      },
    },
    [dummyStateEnd] = {
      [dummyActionContinue] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 1.0,
      },
      [dummyActionQuit] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 1.0,
      },
    },
  };
  double const rewards[dummyStateNb][dummyActionNb][dummyStateNb] = {
    [dummyStatePlaying] = {
      [dummyActionContinue] = {
        [dummyStatePlaying] = 3.0,
        [dummyStateEnd] = 5.0,
      },
      [dummyActionQuit] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 5.0,
      },
    },
    [dummyStateEnd] = {
      [dummyActionContinue] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 0.0,
      },
      [dummyActionQuit] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 0.0,
      },
    },
  };
  loop(fromState, nbState) loop(action, nbAction) loop(toState, nbState) {
    CapyMDPTransition* const transition =
      $(mdp, getTransition)(fromState, action, toState);
    transition->prob = probs[fromState][action][toState];
    transition->reward = rewards[fromState][action][toState];
  }
  return mdp;
}

static CapyMarkovDecisionProcess* CreateMDPDummy3(void) {
  size_t const nbState = dummyStateNb;
  size_t const nbAction = dummyActionNb;
  CapyMarkovDecisionProcess* mdp =
    CapyMarkovDecisionProcessAlloc(nbState, nbAction);
  mdp->flagStartStates[dummyStatePlaying] = true;
  mdp->flagEndStates[dummyStateEnd] = true;
  double const probs[dummyStateNb][dummyActionNb][dummyStateNb] = {
    [dummyStatePlaying] = {
      [dummyActionContinue] = {
        [dummyStatePlaying] = 0.666666666,
        [dummyStateEnd] = 0.333333334,
      },
      [dummyActionQuit] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 1.0,
      },
    },
    [dummyStateEnd] = {
      [dummyActionContinue] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 1.0,
      },
      [dummyActionQuit] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 1.0,
      },
    },
  };
  double const rewards[dummyStateNb][dummyActionNb][dummyStateNb] = {
    [dummyStatePlaying] = {
      [dummyActionContinue] = {
        [dummyStatePlaying] = 3.0,
        [dummyStateEnd] = -10.0,
      },
      [dummyActionQuit] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 5.0,
      },
    },
    [dummyStateEnd] = {
      [dummyActionContinue] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 0.0,
      },
      [dummyActionQuit] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 0.0,
      },
    },
  };
  loop(fromState, nbState) loop(action, nbAction) loop(toState, nbState) {
    CapyMDPTransition* const transition =
      $(mdp, getTransition)(fromState, action, toState);
    transition->prob = probs[fromState][action][toState];
    transition->reward = rewards[fromState][action][toState];
  }
  return mdp;
}

static CapyMarkovDecisionProcess* CreateMDPDummy4(void) {
  size_t const nbState = dummyStateNb;
  size_t const nbAction = dummyActionNb;
  CapyMarkovDecisionProcess* mdp =
    CapyMarkovDecisionProcessAlloc(nbState, nbAction);
  mdp->flagStartStates[dummyStatePlaying] = true;
  mdp->flagEndStates[dummyStateEnd] = true;
  double const probs[dummyStateNb][dummyActionNb][dummyStateNb] = {
    [dummyStatePlaying] = {
      [dummyActionContinue] = {
        [dummyStatePlaying] = 6.0 / 7.0,
        [dummyStateEnd] = 1.0 / 7.0,
      },
      [dummyActionQuit] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 1.0,
      },
    },
    [dummyStateEnd] = {
      [dummyActionContinue] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 1.0,
      },
      [dummyActionQuit] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 1.0,
      },
    },
  };
  double const rewards[dummyStateNb][dummyActionNb][dummyStateNb] = {
    [dummyStatePlaying] = {
      [dummyActionContinue] = {
        [dummyStatePlaying] = 3.0,
        [dummyStateEnd] = -10.0,
      },
      [dummyActionQuit] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 5.0,
      },
    },
    [dummyStateEnd] = {
      [dummyActionContinue] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 0.0,
      },
      [dummyActionQuit] = {
        [dummyStatePlaying] = 0.0,
        [dummyStateEnd] = 0.0,
      },
    },
  };
  loop(fromState, nbState) loop(action, nbAction) loop(toState, nbState) {
    CapyMDPTransition* const transition =
      $(mdp, getTransition)(fromState, action, toState);
    transition->prob = probs[fromState][action][toState];
    transition->reward = rewards[fromState][action][toState];
  }
  return mdp;
}

static CapyMarkovDecisionProcess* CreateMDPGambler(void) {
  size_t const nbState = 101;
  size_t const nbAction = 51;
  CapyMarkovDecisionProcess* mdp =
    CapyMarkovDecisionProcessAlloc(nbState, nbAction);
  loop(iState, nbState) {
    if(iState > 0 && iState < 100) {
      mdp->flagStartStates[iState] = true;
      mdp->flagEndStates[iState] = false;
    } else {
      mdp->flagStartStates[iState] = false;
      mdp->flagEndStates[iState] = true;
    }
  }
  loop(fromState, nbState) loop(action, nbAction) loop(toState, nbState) {
    CapyMDPTransition* const transition =
      $(mdp, getTransition)(fromState, action, toState);
    if(action > fromState || action + fromState > 100) {
      if(toState == 0) {
        transition->prob = 1.0;
        transition->reward = -1.0;
      } else {
        transition->prob = 0.0;
        transition->reward = 0.0;
      }
    } else if(fromState == 0 || fromState == 100) {
      if(toState == fromState) {
        transition->prob = 1.0;
        transition->reward = 0.0;
      } else {
        transition->prob = 0.0;
        transition->reward = 0.0;
      }
    } else if(action == 0) {
      if(toState == fromState) {
        transition->prob = 1.0;
        transition->reward = -0.01;
      } else {
        transition->prob = 0.0;
        transition->reward = 0.0;
      }
    } else if(fromState + action == toState){
      transition->prob = 0.4;
      transition->reward = (toState == 100 ? 1.0 : 0.0);
    } else if(fromState - action == toState){
      transition->prob = 0.6;
      transition->reward = (toState == 0 ? -1.0 : 0.0);
    } else {
      transition->prob = 0.0;
      transition->reward = 0.0;
    }
  }
  return mdp;
}

static CapyMarkovDecisionProcess* CreateMDPGamblerWithRewardOnly(void) {
  size_t const nbState = 101;
  size_t const nbAction = 51;
  CapyMarkovDecisionProcess* mdp =
    CapyMarkovDecisionProcessAlloc(nbState, nbAction);
  loop(iState, nbState) {
    if(iState > 0 && iState < 100) {
      mdp->flagStartStates[iState] = true;
      mdp->flagEndStates[iState] = false;
    } else {
      mdp->flagStartStates[iState] = false;
      mdp->flagEndStates[iState] = true;
    }
  }
  loop(fromState, nbState) loop(action, nbAction) loop(toState, nbState) {
    CapyMDPTransition* const transition =
      $(mdp, getTransition)(fromState, action, toState);
    if(action == 0 || action > fromState || action + fromState > 100) {
      transition->reward = -1.0;
    } else if(toState == 0){
      transition->reward = -1.0;
    } else if(toState == 100){
      transition->reward = 1.0;
    } else {
      transition->reward = 0.0;
    }
  }
  return mdp;
}

// Environment for the Gambler problem, used by the Q-Learning test (its
// 'step' method is set to StepGamblerEnvironment by
// GamblerEnvironmentCreate)
typedef struct GamblerEnvironment {

  // Inherits CapyMDPEnvironment (must stay first: the tests cast a
  // GamblerEnvironment* to a CapyMDPEnvironment*)
  struct CapyMDPEnvironmentDef;

  // RNG used by StepGamblerEnvironment to draw the outcome of a bet
  CapyRandom rng;

  // Destructor for the parent class, called by DestructGamblerEnvironment
  void (*destructCapyMDPEnvironment)(void);
} GamblerEnvironment;

// Get the result state for a given state and action in the gambler
// environment (the environment is taken from capyThat)
// Input:
//   fromState: the 'from' state
//   action: the applied action
// Output:
//   Return 0 if the action is 0, is larger than fromState, or brings
//   fromState + action over 100. Else draw a number from the environment's
//   RNG and return fromState - action if it is less than 0.6, or
//   fromState + action otherwise.
static size_t StepGamblerEnvironment(
  size_t const fromState,
  size_t const action) {
  GamblerEnvironment* that = (GamblerEnvironment*)capyThat;
  bool const isBadBet =
    (action == 0 || action > fromState || fromState + action > 100);
  if(isBadBet) return 0;
  double const draw = $(&(that->rng), getDouble)();
  return (draw < 0.6 ? fromState - action : fromState + action);
}

// Free the memory used by a GamblerEnvironment (taken from capyThat): call
// the destructor of the parent class, then the destructor of the RNG.
// NOTE(review): DestructBlackJackEnvironment releases in the opposite order
// (RNG first, then parent); confirm whether the order matters.
static void DestructGamblerEnvironment(void) {
  GamblerEnvironment* that = (GamblerEnvironment*)capyThat;
  $(that, destructCapyMDPEnvironment)();
  $(&(that->rng), destruct)();
}

// Create a new GamblerEnvironment
// Input:
//   seed: seed given to CapyRandomCreate to create the environment's RNG
// Output:
//   Return a GamblerEnvironment whose 'step' and 'destruct' methods are
//   StepGamblerEnvironment and DestructGamblerEnvironment
static GamblerEnvironment GamblerEnvironmentCreate(
  CapyRandomSeed_t const seed) {
  GamblerEnvironment that = {0};
  CapyInherits(that, CapyMDPEnvironment, ());

  // Override the inherited methods, then set the properties specific to
  // this environment
  that.destruct = DestructGamblerEnvironment;
  that.step = StepGamblerEnvironment;
  that.rng = CapyRandomCreate(seed);
  return that;
}

static CapyMarkovDecisionProcess* CreateMDPFrozenLake(void) {
  size_t const nbState = 16;
  size_t const nbAction = 4;
  CapyMarkovDecisionProcess* mdp =
    CapyMarkovDecisionProcessAlloc(nbState, nbAction);
  loop(iState, nbState) {
    if(iState == 0) {
      mdp->flagStartStates[iState] = true;
      mdp->flagEndStates[iState] = false;
    } else if(
      iState == 15 ||
      iState == 5 ||
      iState == 7 ||
      iState == 11 ||
      iState == 12
    ) {
      mdp->flagStartStates[iState] = false;
      mdp->flagEndStates[iState] = true;
    } else {
      mdp->flagStartStates[iState] = false;
      mdp->flagEndStates[iState] = false;
    }
  }
  loop(fromState, nbState) loop(action, nbAction) loop(toState, nbState) {
    CapyMDPTransition* const transition =
      $(mdp, getTransition)(fromState, action, toState);
    if(toState == 15) {
      transition->reward = 1.0;
    } else if(
      toState == 5 ||
      toState == 7 ||
      toState == 11 ||
      toState == 12
    ) {
      transition->reward = -1.0;
    } else {
      transition->reward = 0.0;
    }
  }
  return mdp;
}

// Actions for the frozen lake environment (see StepFrozenLakeEnvironment
// for the move associated to each of them). frozenLakeActionNb is not an
// action itself, it is the number of actions declared before it.
typedef enum FrozenLakeAction {
  frozenLakeActionN = 0,
  frozenLakeActionE = 1,
  frozenLakeActionS = 2,
  frozenLakeActionW = 3,
  frozenLakeActionNb = 4,
} FrozenLakeAction;

// Environment for the frozen lake problem, used by the Q-Learning test (its
// 'step' method is set to StepFrozenLakeEnvironment by
// FrozenLakeEnvironmentCreate)
typedef struct FrozenLakeEnvironment {

  // Inherits CapyMDPEnvironment (must stay first: the tests cast a
  // FrozenLakeEnvironment* to a CapyMDPEnvironment*)
  struct CapyMDPEnvironmentDef;

  // Destructor for the parent class, called by
  // DestructFrozenLakeEnvironment
  void (*destructCapyMDPEnvironment)(void);
} FrozenLakeEnvironment;

// Get the result state for a given state and action in the frozen lake
// environment. The moves are deterministic: N subtracts 4 from the state
// (if it is above 3), S adds 4 (if it is below 12), E adds 1 (unless
// fromState % 4 is 3) and W subtracts 1 (unless fromState % 4 is 0). A move
// which is not allowed leaves the state unchanged.
// Input:
//   fromState: the 'from' state
//   action: the applied action
// Output:
//   Return the result state, or 0 if the action is not a FrozenLakeAction
static size_t StepFrozenLakeEnvironment(
  size_t const fromState,
  size_t const action) {
  switch(action) {
    case frozenLakeActionN:
      return (fromState > 3 ? fromState - 4 : fromState);
    case frozenLakeActionE:
      return (fromState % 4 != 3 ? fromState + 1 : fromState);
    case frozenLakeActionS:
      return (fromState < 12 ? fromState + 4 : fromState);
    case frozenLakeActionW:
      return (fromState % 4 != 0 ? fromState - 1 : fromState);
    default:
      return 0;
  }
}

// Free the memory used by a FrozenLakeEnvironment (taken from capyThat).
// The environment has no resource of its own, hence it only calls the
// destructor of the parent class.
static void DestructFrozenLakeEnvironment(void) {
  FrozenLakeEnvironment* that = (FrozenLakeEnvironment*)capyThat;
  $(that, destructCapyMDPEnvironment)();
}

// Create a new FrozenLakeEnvironment (it takes no input: this environment
// is deterministic and uses no RNG)
// Output:
//   Return a FrozenLakeEnvironment whose 'step' and 'destruct' methods are
//   StepFrozenLakeEnvironment and DestructFrozenLakeEnvironment
static FrozenLakeEnvironment FrozenLakeEnvironmentCreate(void) {
  FrozenLakeEnvironment that = {0};
  CapyInherits(that, CapyMDPEnvironment, ());

  // Override the inherited methods
  that.destruct = DestructFrozenLakeEnvironment;
  that.step = StepFrozenLakeEnvironment;
  return that;
}

// Actions for the black jack environment. blackJackActionNb is not an
// action itself, it is the number of actions declared before it.
typedef enum BlackJackAction {
  blackJackActionStick = 0,
  blackJackActionHit = 1,
  blackJackActionNb = 2,
} BlackJackAction;

// Player score values for the black jack environment. The value i stands
// for a score of 12 + i (StepBlackJackEnvironment converts it that way).
// blackJackScoreNb is not a score, it is the number of values declared
// before it.
typedef enum BlackJackScore {
  blackJackScore12 = 0,
  blackJackScore13 = 1,
  blackJackScore14 = 2,
  blackJackScore15 = 3,
  blackJackScore16 = 4,
  blackJackScore17 = 5,
  blackJackScore18 = 6,
  blackJackScore19 = 7,
  blackJackScore20 = 8,
  blackJackScore21 = 9,
  blackJackScoreNb = 10,
} BlackJackScore;

// Opponent score values for the black jack environment. The value i stands
// for a score of i + 1 (StepBlackJackEnvironment converts it that way).
// blackJackOppScoreNb is not a score, it is the number of values declared
// before it.
typedef enum BlackJackOppScore {
  blackJackOppScore01 = 0,
  blackJackOppScore02 = 1,
  blackJackOppScore03 = 2,
  blackJackOppScore04 = 3,
  blackJackOppScore05 = 4,
  blackJackOppScore06 = 5,
  blackJackOppScore07 = 6,
  blackJackOppScore08 = 7,
  blackJackOppScore09 = 8,
  blackJackOppScore10 = 9,
  blackJackOppScoreNb = 10,
} BlackJackOppScore;

// State for the black jack environment (non terminal states only, see
// BlackJackStateToIdx and BlackJackIdxToState for the conversion to and
// from the state index used by the MDP)
typedef struct BlackJackState {

  // Score of the player (index of the score, 0 stands for 12)
  BlackJackScore playerScore;

  // Score of the opponent (index of the score, 0 stands for 1)
  BlackJackOppScore oppScore;

  // Flag for the ace in the player's hand (StepBlackJackEnvironment
  // subtracts 10 from a score over 21 and clears the flag when it is set)
  bool flagAce;

  // NOTE(review): presumably padding following flagAce, CapyPad is defined
  // outside of this file; confirm
  CapyPad(bool, flagAce);
} BlackJackState;

// Indices of the three terminal states
#define BlackJackStateIdxWin 0
#define BlackJackStateIdxLoose 1
#define BlackJackStateIdxDraw 2

// Total number of states: the three terminal states, plus one state per
// (player score, opponent score, ace flag) combination
#define BlackJackStateNb (3+2*blackJackOppScoreNb*blackJackScoreNb)

// Convert a BlackJackState into its state index
// Input:
//   state: the state to convert
// Output:
//   Return the index of the state. The indices start after the three
//   terminal states, the ace flag varies fastest, then the opponent score,
//   then the player score.
static size_t BlackJackStateToIdx(BlackJackState const state) {
  return
    3 +
    2 * (state.playerScore * blackJackOppScoreNb + state.oppScore) +
    state.flagAce;
}

// Convert a state index into a BlackJackState (inverse of
// BlackJackStateToIdx)
// Input:
//   idx: the index to convert, it must not be one of the three terminal
//        states (the subtraction of 3 below would wrap around)
// Output:
//   Return the state
static BlackJackState BlackJackIdxToState(size_t const idx) {

  // Number of indices per player score
  enum { nbIdxPerPlayerScore = blackJackOppScoreNb * 2 };

  // Index relative to the first non terminal state
  size_t const offset = idx - 3;
  BlackJackState state = {0};
  state.playerScore = offset / nbIdxPerPlayerScore;
  state.oppScore = (offset - state.playerScore * nbIdxPerPlayerScore) / 2;

  // What remains after removing the scores is the ace flag (0 or 1)
  state.flagAce =
    offset -
    (state.playerScore * blackJackOppScoreNb + state.oppScore) * 2;
  return state;
}

// Black jack environment, used by the Q-Learning test (its 'step' method is
// set to StepBlackJackEnvironment by BlackJackEnvironmentCreate)
typedef struct BlackJackEnvironment {

  // Inherits CapyMDPEnvironment (must stay first: the tests cast a
  // BlackJackEnvironment* to a CapyMDPEnvironment*)
  struct CapyMDPEnvironmentDef;

  // RNG used by StepBlackJackEnvironment to draw the cards
  CapyRandom rng;

  // Destructor for the parent class, called by
  // DestructBlackJackEnvironment
  void (*destructCapyMDPEnvironment)(void);
} BlackJackEnvironment;

// Get the result state for a given state and action in the BlackJack
// environment (the environment is taken from capyThat). Cards are drawn
// from the environment's RNG in the range [1, 13] and values over 10 are
// counted as 10. If the action is 'hit', one card is added to the player's
// score. The dealer plays (and the result is a terminal state) if the
// action is 'stick' or if the player reaches 21 by hitting.
// Input:
//   fromState: the 'from' state
//   action: the applied action
// Output:
//   Return the result state
static size_t StepBlackJackEnvironment(
  size_t const fromState,
  size_t const action) {
  BlackJackEnvironment* that = (BlackJackEnvironment*)capyThat;

  // Terminal states are absorbing
  if(
    fromState == BlackJackStateIdxLoose ||
    fromState == BlackJackStateIdxDraw ||
    fromState == BlackJackStateIdxWin
  ) {
    return fromState;
  }
  BlackJackState state = BlackJackIdxToState(fromState);
  CapyRangeUInt8 const cards = {.min = 1, .max = 13};
  bool flagDealerTurn = false;
  size_t toState = fromState;

  // Convert the score index into the actual score
  size_t playerScore = 12 + state.playerScore;
  if(action == blackJackActionHit) {
    uint8_t card = $(&(that->rng), getUInt8Range)(&cards);
    if(card > 10) card = 10;

    // NOTE(review): playerScore is 12 + state.playerScore at this point,
    // so this condition looks like it can never be true; confirm
    if(playerScore < 11 && card == 1 && state.flagAce == false) {
      card = 11;
      state.flagAce = true;
    }
    playerScore += card;
    if(playerScore == 21) {

      // Provisional win, the dealer's turn below decides between win and
      // draw
      toState = BlackJackStateIdxWin;
      flagDealerTurn = true;
    } else if(playerScore > 21) {
      if(state.flagAce) {

        // Subtract 10 from the score and clear the ace flag instead of
        // going bust
        playerScore -= 10;
        state.flagAce = false;
        state.playerScore = playerScore - 12;
        toState = BlackJackStateToIdx(state);
      } else {
        return BlackJackStateIdxLoose;
      }
    } else {
      state.playerScore = playerScore - 12;
      toState = BlackJackStateToIdx(state);
    }
  } else if(action == blackJackActionStick) {
    flagDealerTurn = true;
  }
  if(flagDealerTurn) {

    // Convert the score index into the actual score
    size_t oppScore = state.oppScore + 1;
    bool flagOppAce = false;
    uint8_t hiddenCard = $(&(that->rng), getUInt8Range)(&cards);
    if(hiddenCard > 10) hiddenCard = 10;

    // The dealer's score of 1 with a hidden card counted 10, or score of 10
    // with a hidden card of 1, ends the game immediately: draw if the
    // player has 21, loss otherwise
    if(
      (oppScore == 1 && hiddenCard == 10) ||
      (oppScore == 10 && hiddenCard == 1)
    ) {
      if(toState == BlackJackStateIdxWin) {
        return BlackJackStateIdxDraw;
      } else {
        return BlackJackStateIdxLoose;
      }
    }
    oppScore += hiddenCard;

    // The dealer draws cards until a terminal state is reached (every exit
    // of this loop is a return)
    while(true) {
      if(oppScore == 21) {
        if(toState == BlackJackStateIdxWin) {
          return BlackJackStateIdxDraw;
        } else {
          return BlackJackStateIdxLoose;
        }
      } else if(oppScore > 21) {
        if(flagOppAce) {
          flagOppAce = false;
          oppScore -= 10;
        } else {

          // NOTE(review): when the dealer plays, toState is either
          // BlackJackStateIdxWin or fromState (which is not terminal), so
          // the comparison to BlackJackStateIdxLoose looks like it can
          // never be true; confirm
          if(toState == BlackJackStateIdxLoose) {
            return BlackJackStateIdxDraw;
          } else {
            return BlackJackStateIdxWin;
          }
        }
      }

      // The dealer stops drawing over 16, the scores are then compared
      if(oppScore > 16) {

        // NOTE(review): same remark as above about the comparison to
        // BlackJackStateIdxLoose; confirm
        if(toState == BlackJackStateIdxLoose) {
          return toState;
        } else {
          if(oppScore > playerScore) {
            return BlackJackStateIdxLoose;
          } else if(oppScore == playerScore) {
            return BlackJackStateIdxDraw;
          } else {
            return BlackJackStateIdxWin;
          }
        }
      }
      uint8_t card = $(&(that->rng), getUInt8Range)(&cards);
      if(card > 10) card = 10;

      // The first card of value 1 drawn while the dealer's score is under
      // 11 is counted as 11
      if(oppScore < 11 && card == 1 && flagOppAce == false) {
        card = 11;
        flagOppAce = true;
      }
      oppScore += card;
    }
  }
  return toState;
}

// Free the memory used by a BlackJackEnvironment (taken from capyThat):
// call the destructor of the RNG, then the destructor of the parent class.
// NOTE(review): DestructGamblerEnvironment releases in the opposite order
// (parent first, then RNG); confirm whether the order matters.
static void DestructBlackJackEnvironment(void) {
  BlackJackEnvironment* that = (BlackJackEnvironment*)capyThat;
  $(&(that->rng), destruct)();
  $(that, destructCapyMDPEnvironment)();
}

// Create a new BlackJackEnvironment
// Input:
//   seed: seed given to CapyRandomCreate to create the environment's RNG
// Output:
//   Return a BlackJackEnvironment whose 'step' and 'destruct' methods are
//   StepBlackJackEnvironment and DestructBlackJackEnvironment
static BlackJackEnvironment BlackJackEnvironmentCreate(
  CapyRandomSeed_t const seed) {
  BlackJackEnvironment that = {0};
  CapyInherits(that, CapyMDPEnvironment, ());

  // Override the inherited methods, then set the properties specific to
  // this environment
  that.destruct = DestructBlackJackEnvironment;
  that.step = StepBlackJackEnvironment;
  that.rng = CapyRandomCreate(seed);
  return that;
}

static CapyMarkovDecisionProcess* CreateMDPBlackJack(void) {
  size_t const nbState = BlackJackStateNb;
  size_t const nbAction = 2;
  CapyMarkovDecisionProcess* mdp =
    CapyMarkovDecisionProcessAlloc(nbState, nbAction);
  loop(iState, BlackJackStateNb) {
    if(
      iState == BlackJackStateIdxLoose ||
      iState == BlackJackStateIdxDraw ||
      iState == BlackJackStateIdxWin
    ) {
      mdp->flagStartStates[iState] = false;
      mdp->flagEndStates[iState] = true;
    } else {
      mdp->flagStartStates[iState] = true;
      mdp->flagEndStates[iState] = false;
    }
  }
  loop(fromState, nbState) loop(action, nbAction) loop(toState, nbState) {
    CapyMDPTransition* const transition =
      $(mdp, getTransition)(fromState, action, toState);
    if(toState == BlackJackStateIdxLoose) {
      transition->reward = -1.0;
    } else if(toState == BlackJackStateIdxWin) {
      transition->reward = 1.0;
    } else {
      transition->reward = 0.0;
    }
  }
  return mdp;
}

#endif
// Check the properties of a freshly allocated MDP, and that freeing it
// resets the pointer to NULL
CUTEST(test001, "Alloc and free") {
  size_t const nbState = 2;
  size_t const nbAction = 3;
  CapyMarkovDecisionProcess* mdp =
    CapyMarkovDecisionProcessAlloc(nbState, nbAction);

  // Dimensions
  bool isOk = (mdp->nbState == nbState && mdp->nbAction == nbAction);
  isOk = isOk && (mdp->nbTransition == nbState * nbAction * nbState);
  isOk = isOk && (mdp->transitions != NULL);

  // Initial values
  isOk = isOk && (mdp->curState == 0 && mdp->nbStep == 0);
  isOk = isOk && equal(mdp->discount, 0.9) && equal(mdp->epsilon, 1e-6);

  // The properties are checked before freeing, the result is asserted after
  CapyMarkovDecisionProcessFree(&mdp);
  CUTEST_ASSERT(isOk && mdp == NULL, "Alloc/free failed");
}

// Check the dimensions of the dummy example 1 and the content of its first
// four transitions (those from state 0)
CUTEST(test002, "Create dummy example 1") {
  CapyMarkovDecisionProcess* mdp = CreateMDPDummy1();
  CapyMDPTransition* const t = mdp->transitions;
  bool isOk =
    mdp->nbState == dummyStateNb &&
    mdp->nbAction == dummyActionNb;

  // From 0, action 0, to 0
  isOk = isOk &&
    t[0].fromState == 0 && t[0].action == 0 && t[0].toState == 0 &&
    equal(t[0].prob, 1.0) && equal(t[0].reward, 3.0);

  // From 0, action 0, to 1
  isOk = isOk &&
    t[1].fromState == 0 && t[1].action == 0 && t[1].toState == 1 &&
    equal(t[1].prob, 0.0) && equal(t[1].reward, 0.0);

  // From 0, action 1, to 0
  isOk = isOk &&
    t[2].fromState == 0 && t[2].action == 1 && t[2].toState == 0 &&
    equal(t[2].prob, 0.0) && equal(t[2].reward, 0.0);

  // From 0, action 1, to 1
  isOk = isOk &&
    t[3].fromState == 0 && t[3].action == 1 && t[3].toState == 1 &&
    equal(t[3].prob, 1.0) && equal(t[3].reward, 5.0);
  CUTEST_ASSERT(isOk, "Create dummy example failed");
  CapyMarkovDecisionProcessFree(&mdp);
}

// Run 10 steps on the dummy example 1 from the 'playing' state with the RNG
// reset to seed 0, check the state after each step and the sum of rewards
CUTEST(test003, "Run a few steps on dummy example 1") {
  CapyMarkovDecisionProcess* mdp = CreateMDPDummy1();
  $(mdp, setCurState)(dummyStatePlaying);
  $(mdp, resetRng)(0);

  // Expected state after each step
  DummyState const expectedStates[] = {
    dummyStatePlaying, dummyStatePlaying, dummyStatePlaying,
    dummyStatePlaying, dummyStatePlaying,
    dummyStateEnd, dummyStateEnd, dummyStateEnd, dummyStateEnd,
    dummyStateEnd
  };
  int nbStep = 10;
  double sumReward = 0.0;
  bool isOk = true;
  loop(iStep, nbStep) {
    CapyMDPTransition* const transition = $(mdp, step)();
    sumReward += transition->reward;
    isOk = isOk && (mdp->curState == expectedStates[iStep]);
  }
  isOk &= equal(sumReward, 20.0);
  CUTEST_ASSERT(isOk, "Step failed");
  CapyMarkovDecisionProcessFree(&mdp);
}

// Run 10 steps on the dummy example 2 from the 'playing' state with the RNG
// reset to seed 0, check the state after each step and the sum of rewards
CUTEST(test004, "Run a few steps on dummy example 2") {
  CapyMarkovDecisionProcess* mdp = CreateMDPDummy2();
  $(mdp, setCurState)(dummyStatePlaying);
  $(mdp, resetRng)(0);

  // Expected state after each step
  DummyState const expectedStates[] = {
    dummyStatePlaying, dummyStatePlaying, dummyStatePlaying,
    dummyStatePlaying, dummyStatePlaying,
    dummyStateEnd, dummyStateEnd, dummyStateEnd, dummyStateEnd,
    dummyStateEnd
  };
  int nbStep = 10;
  double sumReward = 0.0;
  bool isOk = true;
  loop(iStep, nbStep) {
    CapyMDPTransition* const transition = $(mdp, step)();
    sumReward += transition->reward;
    isOk = isOk && (mdp->curState == expectedStates[iStep]);
  }
  isOk &= equal(sumReward, 20.0);
  CUTEST_ASSERT(isOk, "Step failed");
  CapyMarkovDecisionProcessFree(&mdp);
}

// Check that the expected reward from the 'playing' state of the dummy
// example 2 is within 0.1 of 11.0
CUTEST(test005, "Get expected reward on dummy example 2") {
  CapyMarkovDecisionProcess* mdp = CreateMDPDummy2();
  $(mdp, resetRng)(0);
  double const expReward =
    $(mdp, getExpRewardFromState)(dummyStatePlaying, 100000);
  double const error = fabs(expReward - 11.0);
  bool const isCloseEnough = (error < 1e-1);
  CUTEST_ASSERT(isCloseEnough, "expReward=%lf", expReward);
  CapyMarkovDecisionProcessFree(&mdp);
}

// Check the optimal policy found on the dummy example 2: 'continue' in both
// states, with a value of 9.166665440343482629 for the 'playing' state and
// 0 for the 'end' state
CUTEST(test006, "Get optimal policy on dummy example 2") {
  CapyMarkovDecisionProcess* mdp = CreateMDPDummy2();
  $(mdp, searchOptimalPolicy)();
  bool isOk =
    equal(mdp->optimalPolicy.values[0], 9.166665440343482629) &&
    equal(mdp->optimalPolicy.values[1], 0.0) &&
    mdp->optimalPolicy.actions[dummyStatePlaying] == dummyActionContinue &&
    mdp->optimalPolicy.actions[dummyStateEnd] == dummyActionContinue;

  // The actions are printed with %zu, the conversion for size_t values
  CUTEST_ASSERT(
    isOk, "values={%lf, %lf}, actions={%zu, %zu}",
    mdp->optimalPolicy.values[0],
    mdp->optimalPolicy.values[1],
    mdp->optimalPolicy.actions[0],
    mdp->optimalPolicy.actions[1]);
  CapyMarkovDecisionProcessFree(&mdp);
}

// Run 10 steps on the dummy example 2 following its optimal policy, from
// the 'playing' state with the RNG reset to seed 0, check the state after
// each step and the sum of rewards
CUTEST(test007, "Run a few steps on dummy example 2 using optimal policy") {
  CapyMarkovDecisionProcess* mdp = CreateMDPDummy2();
  $(mdp, searchOptimalPolicy)();
  $(mdp, setCurState)(dummyStatePlaying);
  $(mdp, resetRng)(0);

  // Expected state after each step
  DummyState const expectedStates[] = {
    dummyStatePlaying, dummyStatePlaying, dummyStatePlaying,
    dummyStatePlaying, dummyStatePlaying,
    dummyStateEnd, dummyStateEnd, dummyStateEnd, dummyStateEnd,
    dummyStateEnd
  };
  int nbStep = 10;
  double sumReward = 0.0;
  bool isOk = true;
  loop(iStep, nbStep) {
    CapyMDPTransition* const transition =
      $(mdp, stepPolicy)(&(mdp->optimalPolicy));
    sumReward += transition->reward;
    isOk = isOk && (mdp->curState == expectedStates[iStep]);
  }
  isOk &= equal(sumReward, 20.0);
  CUTEST_ASSERT(isOk, "StepPolicy failed");
  CapyMarkovDecisionProcessFree(&mdp);
}

// Check that the expected reward from the 'playing' state of the dummy
// example 2, following its optimal policy, is within 0.1 of 11.0
CUTEST(test008, "Get expected reward using optimal policy on dummy 2") {
  CapyMarkovDecisionProcess* mdp = CreateMDPDummy2();
  $(mdp, searchOptimalPolicy)();
  $(mdp, resetRng)(0);
  double const expReward = $(mdp, getExpRewardFromStateForPolicy)(
    dummyStatePlaying, 100000, &(mdp->optimalPolicy));
  double const error = fabs(expReward - 11.0);
  bool const isCloseEnough = (error < 1e-1);
  CUTEST_ASSERT(isCloseEnough, "expReward=%lf", expReward);
  CapyMarkovDecisionProcessFree(&mdp);
}

// Check that the expected reward from the 'playing' state of the dummy
// example 3 is within 0.1 of 2.57
CUTEST(test009, "Get expected reward on dummy example 3") {
  CapyMarkovDecisionProcess* mdp = CreateMDPDummy3();
  $(mdp, resetRng)(0);
  double const expReward =
    $(mdp, getExpRewardFromState)(dummyStatePlaying, 1000);
  double const error = fabs(expReward - 2.57);
  bool const isCloseEnough = (error < 0.1);
  CUTEST_ASSERT(isCloseEnough, "expReward=%lf", expReward);
  CapyMarkovDecisionProcessFree(&mdp);
}

// Check the optimal policy found on the dummy example 3: 'quit' in the
// 'playing' state (value 5) and 'continue' in the 'end' state (value 0)
CUTEST(test010, "Get optimal policy on dummy example 3") {
  CapyMarkovDecisionProcess* mdp = CreateMDPDummy3();
  $(mdp, searchOptimalPolicy)();
  bool isOk =
    equal(mdp->optimalPolicy.values[0], 5.0) &&
    equal(mdp->optimalPolicy.values[1], 0.0) &&
    mdp->optimalPolicy.actions[dummyStatePlaying] == dummyActionQuit &&
    mdp->optimalPolicy.actions[dummyStateEnd] == dummyActionContinue;

  // The actions are printed with %zu, the conversion for size_t values
  CUTEST_ASSERT(
    isOk, "values={%lf, %lf}, actions={%zu, %zu}",
    mdp->optimalPolicy.values[0],
    mdp->optimalPolicy.values[1],
    mdp->optimalPolicy.actions[0],
    mdp->optimalPolicy.actions[1]);
  CapyMarkovDecisionProcessFree(&mdp);
}

// Check that the expected reward from the 'playing' state of the dummy
// example 3, following its optimal policy, is within 1.0 of 5.0
CUTEST(test011, "Get expected reward using optimal policy on dummy 3") {
  CapyMarkovDecisionProcess* mdp = CreateMDPDummy3();
  $(mdp, searchOptimalPolicy)();
  $(mdp, resetRng)(0);
  double const expReward = $(mdp, getExpRewardFromStateForPolicy)(
    dummyStatePlaying, 100000, &(mdp->optimalPolicy));
  double const error = fabs(expReward - 5.0);
  bool const isCloseEnough = (error < 1.0);
  CUTEST_ASSERT(isCloseEnough, "expReward=%lf", expReward);
  CapyMarkovDecisionProcessFree(&mdp);
}

// Check the optimal policy found on the dummy example 4: 'continue' in both
// states, with a value of 5 for the 'playing' state and 0 for the 'end'
// state
CUTEST(test012, "Get optimal policy on dummy example 4") {
  CapyMarkovDecisionProcess* mdp = CreateMDPDummy4();
  $(mdp, searchOptimalPolicy)();
  bool isOk =
    equal(mdp->optimalPolicy.values[0], 5.0) &&
    equal(mdp->optimalPolicy.values[1], 0.0) &&
    mdp->optimalPolicy.actions[dummyStatePlaying] == dummyActionContinue &&
    mdp->optimalPolicy.actions[dummyStateEnd] == dummyActionContinue;

  // The actions are printed with %zu, the conversion for size_t values
  CUTEST_ASSERT(
    isOk, "values={%lf, %lf}, actions={%zu, %zu}",
    mdp->optimalPolicy.values[0],
    mdp->optimalPolicy.values[1],
    mdp->optimalPolicy.actions[0],
    mdp->optimalPolicy.actions[1]);
  CapyMarkovDecisionProcessFree(&mdp);
}

// Search the optimal policy of the gambler example with a discount of 1.0
// and check the action and value (within 1e-6) found for each state
CUTEST(test013, "Get optimal policy on gambler example") {
  CapyMarkovDecisionProcess* mdp = CreateMDPGambler();
  $(mdp, resetRng)(0);
  mdp->discount = 1.0;
  $(mdp, searchOptimalPolicy)();

  // Expected value per state
  double const checkValues[101] = {
    0.000000, -0.995869, -0.989672, -0.981549, -0.974180,
    -0.965229, -0.953873, -0.944372, -0.935449, -0.924630,
    -0.913073, -0.899291, -0.884681, -0.869521, -0.860929,
    -0.851137, -0.838623, -0.826778, -0.811575, -0.793713,
    -0.782682, -0.768067, -0.748228, -0.732840, -0.711704,
    -0.680000, -0.673803, -0.664508, -0.652323, -0.641269,
    -0.627844, -0.610809, -0.596558, -0.583174, -0.566945,
    -0.549609, -0.528937, -0.507022, -0.484282, -0.471394,
    -0.456706, -0.437935, -0.420167, -0.397362, -0.370569,
    -0.354023, -0.332100, -0.302341, -0.279260, -0.247556,
    -0.200000, -0.193803, -0.184508, -0.172323, -0.161269,
    -0.147844, -0.130809, -0.116558, -0.103174, -0.086945,
    -0.069609, -0.048937, -0.027022, -0.004282, 0.008606,
    0.023294, 0.042065, 0.059833, 0.082638, 0.109431,
    0.125977, 0.147900, 0.177659, 0.200740, 0.232444,
    0.280000, 0.289296, 0.303239, 0.321515, 0.338096,
    0.358235, 0.383787, 0.405164, 0.425239, 0.449583,
    0.475586, 0.506595, 0.539467, 0.573577, 0.592909,
    0.614941, 0.643098, 0.669750, 0.703957, 0.744147,
    0.768965, 0.801850, 0.846488, 0.881110, 0.928666,
    0.000000,
  };

  // Expected action per state
  size_t const checkActions[101] = {
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
    10, 11, 12, 12, 11, 10, 9, 8, 7, 6,
    5, 4, 3, 2, 1, 25, 1, 2, 3, 4,
    5, 6, 7, 8, 9, 10, 11, 12, 12, 11,
    10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
    50, 1, 2, 3, 4, 5, 6, 7, 8, 9,
    10, 11, 12, 12, 11, 10, 9, 8, 7, 6,
    5, 4, 3, 2, 1, 25, 1, 2, 3, 4,
    5, 6, 7, 8, 9, 10, 11, 12, 12, 11,
    10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
    0,
  };
  bool isOk = true;
  loop(iState, mdp->nbState) {
    double const error =
      fabs(mdp->optimalPolicy.values[iState] - checkValues[iState]);
    bool const isSameAction =
      (mdp->optimalPolicy.actions[iState] == checkActions[iState]);
    isOk = isOk && (error < 1e-6) && isSameAction;
  }
  CUTEST_ASSERT(isOk, "unexpected optimal policy for gambler example");
  CapyMarkovDecisionProcessFree(&mdp);
}

// Run Q-Learning on the gambler example in 8 rounds, doubling the number of
// episodes after each round, and check the expected reward of the resulting
// policy after each round. The expected values depend on the sequence of
// numbers drawn from the RNGs (both reset/created with seed 0).
CUTEST(
  test014, "Get optimal policy on gambler example with Q-Learning") {
  GamblerEnvironment env = GamblerEnvironmentCreate(0);
  CapyMarkovDecisionProcess* mdp = CreateMDPGamblerWithRewardOnly();
  $(mdp, resetRng)(0);
  mdp->discount = 1.0;
  mdp->environment = (CapyMDPEnvironment*)&env;

  // Initialise the policy with a random action per state, drawn in
  // [0, min(fromState, 100 - fromState)]
  loop(fromState, mdp->nbState) {
    CapyRangeUInt8 range = {
      .min = 0,
      .max = (uint8_t)(fromState <= 50 ? fromState : 100 - fromState)
    };
    mdp->optimalPolicy.actions[fromState] =
      $(&(mdp->rng), getUInt8Range)(&range);
  }
  double const epsilon = 0.1;
  double const learningRate = 0.1;
  size_t nbEpisode = 100;

  // Expected reward after each round (only the first 8 values are compared
  // by the loop below)
  double checks[15] = {
    -0.975400, -0.984800, -0.982200, -0.928600, -0.494200,
    -0.299000, -0.321400, -0.298000, -0.308600, -0.298000,
    -0.309400, -0.226600, -0.220600, -0.207600, -0.207200,
  };
  bool isOk = true;
  loop(i, 8) {
    $(mdp, qLearning)(epsilon, learningRate, nbEpisode);
    double const expReward =
      $(mdp, getExpRewardForPolicy)(10000, &(mdp->optimalPolicy));
    isOk &= (fabs(expReward - checks[i]) < 1e-6);
    nbEpisode *= 2;
  }
  CUTEST_ASSERT(isOk, "unexpected reward");
  CapyMarkovDecisionProcessFree(&mdp);
  $(&env, destruct)();
}

// Run Q-Learning on the frozen lake example and check the action and value
// (within 1e-6) of the resulting policy for each of the 16 states, as well
// as its expected reward
CUTEST(
  test015, "Get optimal policy on frozen lake example with Q-Learning") {
  FrozenLakeEnvironment env = FrozenLakeEnvironmentCreate();
  CapyMarkovDecisionProcess* mdp = CreateMDPFrozenLake();
  $(mdp, resetRng)(0);
  mdp->discount = 0.9;
  mdp->environment = (CapyMDPEnvironment*)&env;
  double const epsilon = 0.1;
  double const learningRate = 0.1;
  size_t const nbEpisode = 1000; // NOTE(review): was 100000 earlier? confirm
  bool isOk = true;
  $(mdp, qLearning)(epsilon, learningRate, nbEpisode);

  // Expected action per state (values of FrozenLakeAction)
  size_t const checkActions[16] = {
    1, 1, 2, 3,
    0, 0, 2, 0,
    0, 1, 2, 0,
    0, 1, 1, 0,
  };

  // Expected value per state
  double const checkValues[16] = {
    0.590490, 0.656100, 0.729000, 0.629400,
    0.510987, 0.000000, 0.810000, 0.000000,
    0.113865, 0.730225, 0.900000, 0.000000,
    0.000000, 0.856827, 1.000000, 0.000000,
  };
  loop(iState, 16) {
    isOk &= (mdp->optimalPolicy.actions[iState] == checkActions[iState]);
    isOk &=
      (fabs(mdp->optimalPolicy.values[iState] - checkValues[iState]) < 1e-6);
  }

  // NOTE(review): the first argument is presumably the number of runs, one
  // being enough as the frozen lake environment is deterministic; confirm
  double const checkReward = 1.0;
  double const expReward =
    $(mdp, getExpRewardForPolicy)(1, &(mdp->optimalPolicy));
  isOk &= (fabs(expReward - checkReward) < 1e-6);
  CUTEST_ASSERT(isOk, "unexpected policy");
  CapyMarkovDecisionProcessFree(&mdp);
  $(&env, destruct)();
}

// Run Q-Learning on the black jack example in 15 rounds, doubling the
// number of episodes after each round, and check the expected reward of the
// resulting policy after each round, then the action of the final policy
// for each non terminal state. The expected values depend on the sequence
// of numbers drawn from the RNGs (both reset/created with seed 0).
CUTEST(
  test016, "Get optimal policy on black jack example with Q-Learning") {
  BlackJackEnvironment env = BlackJackEnvironmentCreate(0);
  CapyMarkovDecisionProcess* mdp = CreateMDPBlackJack();
  $(mdp, resetRng)(0);
  mdp->discount = 1.0;
  mdp->environment = (CapyMDPEnvironment*)&env;

  // Initialise the policy to 'hit' in every state
  loop(fromState, mdp->nbState) {
    mdp->optimalPolicy.actions[fromState] = blackJackActionHit;
  }
  double const epsilon = 0.1;
  double const learningRate = 1.0 / 5000.0;
  size_t nbEpisode = 100;

  // Expected reward after each round (only the first 15 values are compared
  // by the loop below)
  double checks[20] = {
    -0.240300, -0.092900, -0.077400, -0.025800, 0.004700,
    -0.010100, 0.022300, 0.047800, 0.046200, 0.064300,
    0.065700, 0.067400, 0.074600, 0.069200, 0.082300,
    0.102200, 0.098700, 0.090700, 0.091200, 0.110200,
  };
  bool isOk = true;
  loop(i, 15) {
    $(mdp, qLearning)(epsilon, learningRate, nbEpisode);
    double const expReward =
      $(mdp, getExpRewardForPolicy)(10000, &(mdp->optimalPolicy));
    isOk &= (fabs(expReward - checks[i]) < 1e-6);
    nbEpisode *= 2;
  }

  // Expected action (0: stick, 1: hit), indexed as
  // [flagAce][oppScore][playerScore]
  size_t checkActions[2][blackJackOppScoreNb][blackJackScoreNb] = {
    {
      {1, 1, 1, 1, 1, 0, 0, 0, 0, 0, },
      {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, },
      {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, },
      {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, },
      {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, },
      {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, },
      {1, 1, 1, 1, 0, 0, 0, 0, 0, 0, },
      {1, 1, 0, 1, 0, 0, 0, 0, 0, 0, },
      {1, 1, 1, 1, 1, 0, 0, 0, 0, 0, },
      {1, 1, 1, 1, 1, 0, 0, 0, 0, 0, },
    }, {
      {1, 1, 1, 1, 1, 1, 1, 1, 0, 0, },
      {1, 1, 1, 1, 1, 1, 0, 0, 0, 0, },
      {1, 1, 1, 1, 1, 1, 0, 0, 0, 0, },
      {1, 1, 1, 1, 1, 1, 0, 0, 0, 0, },
      {1, 1, 1, 1, 1, 1, 0, 0, 0, 0, },
      {1, 1, 1, 1, 1, 1, 0, 0, 0, 0, },
      {1, 1, 1, 1, 1, 1, 0, 0, 0, 0, },
      {1, 1, 1, 1, 1, 1, 0, 0, 0, 0, },
      {1, 1, 1, 1, 1, 1, 1, 0, 0, 0, },
      {1, 1, 1, 1, 1, 1, 1, 0, 0, 0, },
    },
  };

  // Compare the action of the policy for each non terminal state
  loop(iAce, 2) {
    loop(oppScore, blackJackOppScoreNb) {
      loop(playerScore, blackJackScoreNb){
        BlackJackState state = {
          .playerScore=playerScore, .oppScore=oppScore, .flagAce=iAce
        };
        size_t idx = BlackJackStateToIdx(state);
        size_t action = mdp->optimalPolicy.actions[idx];
        isOk &= (action == checkActions[iAce][oppScore][playerScore]);
      }
    }
  }
  CUTEST_ASSERT(isOk, "unexpected policy");
  CapyMarkovDecisionProcessFree(&mdp);
  $(&env, destruct)();
}
