#include "capy.h"
#ifndef FIXTURE
#define FIXTURE
// Fixture objective: f(x, y) = x^2 + y^2, written into out[0].
// Reads the two inputs in[0] (x) and in[1] (y).
static void FunEval(
  double const* const in,
  double* const out) {
  double const x = in[0];
  double const y = in[1];
  out[0] = x * x + y * y;
}

// Fixture partial derivative of f(x, y) = x^2 + y^2 along input iDim,
// written into out[0]: 2*in[0] when iDim is 0, otherwise 2*in[1]
// (any non-zero iDim selects the second input).
static void FunDerivative(
  double* const in,
  size_t const iDim,
  double* const out) {
  double const coord = (iDim == 0) ? in[0] : in[1];
  out[0] = 2.0 * coord;
}

#endif
CUTEST(test001, "standard") {
  // Standard gradient descent on CapyHimmelblau from the origin:
  // learning rate 0.01, 20 steps, expected to end within 0.01 of (3, 2).
  CapyMathFun himmelblau = CapyMathFunCreate(2, 1);
  himmelblau.eval = CapyHimmelblau;
  double start[2] = {0.0, 0.0};
  CapyGradientDescent* optimizer =
    CapyGradientDescentAlloc(&himmelblau, start);
  optimizer->learnRate = 0.01;
  for(int iStep = 0; iStep < 20; ++iStep) {
    $(optimizer, step)();
  }
  double const x = optimizer->in.vals[0];
  double const y = optimizer->in.vals[1];
  CUTEST_ASSERT(
    fabs(x - 3.0) < 0.01 && fabs(y - 2.0) < 0.01,
    "unexpected minimum [%lf,%lf] != [3,2]",
    x, y);
  // Release the optimizer before the function it references.
  CapyGradientDescentFree(&optimizer);
  $(&himmelblau, destruct)();
}

CUTEST(test002, "momentum") {
  // Momentum gradient descent on CapyHimmelblau from the origin:
  // learning rate 0.01, momentum 0.5, 20 steps, expected to end within
  // 0.01 of (3, 2). The learning rate is assigned before setType and the
  // momentum after it, mirroring the original call order.
  CapyMathFun himmelblau = CapyMathFunCreate(2, 1);
  himmelblau.eval = CapyHimmelblau;
  double start[2] = {0.0, 0.0};
  CapyGradientDescent* optimizer =
    CapyGradientDescentAlloc(&himmelblau, start);
  optimizer->learnRate = 0.01;
  $(optimizer, setType)(capyGradientDescent_momentum);
  optimizer->momentum = 0.5;
  for(int iStep = 0; iStep < 20; ++iStep) {
    $(optimizer, step)();
  }
  double const x = optimizer->in.vals[0];
  double const y = optimizer->in.vals[1];
  CUTEST_ASSERT(
    fabs(x - 3.0) < 0.01 && fabs(y - 2.0) < 0.01,
    "unexpected minimum [%lf,%lf] != [3,2]",
    x, y);
  CapyGradientDescentFree(&optimizer);
  $(&himmelblau, destruct)();
}

CUTEST(test003, "explicit derivative") {
  // Standard gradient descent on the x^2 + y^2 fixture, using the
  // hand-written FunDerivative instead of a default derivative:
  // start (2, 3), learning rate 0.2, 20 steps, expected to end within
  // 0.01 of (0, 0).
  CapyMathFun bowl = CapyMathFunCreate(2, 1);
  bowl.eval = FunEval;
  bowl.evalDerivative = FunDerivative;
  double start[2] = {2.0, 3.0};
  CapyGradientDescent* optimizer = CapyGradientDescentAlloc(&bowl, start);
  optimizer->learnRate = 0.2;
  for(int iStep = 0; iStep < 20; ++iStep) {
    $(optimizer, step)();
  }
  double const x = optimizer->in.vals[0];
  double const y = optimizer->in.vals[1];
  CUTEST_ASSERT(
    fabs(x) < 0.01 && fabs(y) < 0.01,
    "unexpected minimum [%lf,%lf] != [0,0]",
    x, y);
  CapyGradientDescentFree(&optimizer);
  $(&bowl, destruct)();
}

CUTEST(test004, "adam") {
  // Adam on CapyHimmelblau from the origin, with whatever learning rate
  // the library uses by default (none is set here). After 200 steps it is
  // expected to end within 0.01 of (3, 2).
  CapyMathFun himmelblau = CapyMathFunCreate(2, 1);
  himmelblau.eval = CapyHimmelblau;
  double start[2] = {0.0, 0.0};
  CapyGradientDescent* optimizer =
    CapyGradientDescentAlloc(&himmelblau, start);
  $(optimizer, setType)(capyGradientDescent_adam);
  for(int iStep = 0; iStep < 200; ++iStep) {
    $(optimizer, step)();
  }
  double const x = optimizer->in.vals[0];
  double const y = optimizer->in.vals[1];
  CUTEST_ASSERT(
    fabs(x - 3.0) < 0.01 && fabs(y - 2.0) < 0.01,
    "unexpected minimum [%lf,%lf] != [3,2]",
    x, y);
  CapyGradientDescentFree(&optimizer);
  $(&himmelblau, destruct)();
}

CUTEST(test005, "standard (minibatch size 4)") {
  // Standard gradient descent on CapyHimmelblau from the origin with
  // minibatchSize set to 4: learning rate 0.01, 200 steps, expected to
  // end within 0.01 of (3, 2).
  CapyMathFun himmelblau = CapyMathFunCreate(2, 1);
  himmelblau.eval = CapyHimmelblau;
  double start[2] = {0.0, 0.0};
  CapyGradientDescent* optimizer =
    CapyGradientDescentAlloc(&himmelblau, start);
  optimizer->learnRate = 0.01;
  optimizer->minibatchSize = 4;
  for(int iStep = 0; iStep < 200; ++iStep) {
    $(optimizer, step)();
  }
  double const x = optimizer->in.vals[0];
  double const y = optimizer->in.vals[1];
  CUTEST_ASSERT(
    fabs(x - 3.0) < 0.01 && fabs(y - 2.0) < 0.01,
    "unexpected minimum [%lf,%lf] != [3,2]",
    x, y);
  CapyGradientDescentFree(&optimizer);
  $(&himmelblau, destruct)();
}

CUTEST(test006, "momentum (minibatch size 4)") {
  // Momentum gradient descent on CapyHimmelblau from the origin with
  // minibatchSize set to 4: learning rate 0.01, momentum 0.5, 200 steps,
  // expected to end within 0.01 of (3, 2). Field assignments keep the
  // original order around setType.
  CapyMathFun himmelblau = CapyMathFunCreate(2, 1);
  himmelblau.eval = CapyHimmelblau;
  double start[2] = {0.0, 0.0};
  CapyGradientDescent* optimizer =
    CapyGradientDescentAlloc(&himmelblau, start);
  optimizer->learnRate = 0.01;
  optimizer->minibatchSize = 4;
  $(optimizer, setType)(capyGradientDescent_momentum);
  optimizer->momentum = 0.5;
  for(int iStep = 0; iStep < 200; ++iStep) {
    $(optimizer, step)();
  }
  double const x = optimizer->in.vals[0];
  double const y = optimizer->in.vals[1];
  CUTEST_ASSERT(
    fabs(x - 3.0) < 0.01 && fabs(y - 2.0) < 0.01,
    "unexpected minimum [%lf,%lf] != [3,2]",
    x, y);
  CapyGradientDescentFree(&optimizer);
  $(&himmelblau, destruct)();
}

CUTEST(test007, "adam (minibatch size 4)") {
  // Adam on CapyHimmelblau from the origin with minibatchSize set to 4
  // (after setType, as in the original order) and the library's default
  // learning rate. After 600 steps it is expected to end within 0.01 of
  // (3, 2).
  CapyMathFun himmelblau = CapyMathFunCreate(2, 1);
  himmelblau.eval = CapyHimmelblau;
  double start[2] = {0.0, 0.0};
  CapyGradientDescent* optimizer =
    CapyGradientDescentAlloc(&himmelblau, start);
  $(optimizer, setType)(capyGradientDescent_adam);
  optimizer->minibatchSize = 4;
  for(int iStep = 0; iStep < 600; ++iStep) {
    $(optimizer, step)();
  }
  double const x = optimizer->in.vals[0];
  double const y = optimizer->in.vals[1];
  CUTEST_ASSERT(
    fabs(x - 3.0) < 0.01 && fabs(y - 2.0) < 0.01,
    "unexpected minimum [%lf,%lf] != [3,2]",
    x, y);
  CapyGradientDescentFree(&optimizer);
  $(&himmelblau, destruct)();
}

