// ---------------------------- rungekutta.c ---------------------------
/*
    LibCapy - a general purpose library of C functions and data structures
    Copyright (C) 2021-2025 Pascal Baillehache info@baillehachepascal.dev
    https://baillehachepascal.dev
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "rungekutta.h"

// Approximate X(t) using RK4
// Input:
//   t: value for which we want to approximate X(t)
//   vals: values [t0, X(t0)] or [t0, X(t0), X'(t0)]
// Output:
//   Update 'vals' with the approximated values at 't'
static void Eval(
   double const t,
  double* const vals) {
  methodOf(CapyRungeKutta);
  while(vals[0] < t) {
    $(that, step)(vals);
  }
}

// Run one step
// Input:
//   vals: current values [t0, X(t0)] or [t0, X(t0), X'(t0)]
// Output:
//   'vals' is updated with the result of one step of RK4
static void Step(double* const vals) {
  methodOf(CapyRungeKutta);
  double* const d0 = that->vecs[0].vals;
  double* const d1 = that->vecs[1].vals;
  double* const d2 = that->vecs[2].vals;
  double* const d3 = that->vecs[3].vals;
  double* const in = that->vecs[4].vals;
  $(that->derivative, eval)(vals, d0);
  loop(i, that->derivative->dimIn) {
    if(i == 0) {
      in[i] = vals[i] + that->deltaT * 0.5;
    } else if(that->order == 1) {
      in[i] = vals[i] + that->deltaT * 0.5 * d0[i - 1];
    } else if(that->order == 2) {
      if(i <= that->dimVar) {
        in[i] = vals[i] + that->deltaT * 0.5 * vals[that->dimVar + i];
      } else {
        in[i] = vals[i] + that->deltaT * 0.5 * d0[i - 1];
      }
    }
  }
  $(that->derivative, eval)(in, d1);
  loop(i, that->derivative->dimIn) {
    if(i == 0) {
      in[i] = vals[i] + that->deltaT * 0.5;
    } else if(that->order == 1) {
      in[i] = vals[i] + that->deltaT * 0.5 * d1[i - 1];
    } else if(that->order == 2) {
      if(i <= that->dimVar) {
        in[i] = vals[i] + that->deltaT * 0.5 * d0[i - 1];
      } else {
        in[i] = vals[i] + that->deltaT * 0.5 * d1[i - 1];
      }
    }
  }
  $(that->derivative, eval)(in, d2);
  loop(i, that->derivative->dimIn) {
    if(i == 0) {
      in[i] = vals[i] + that->deltaT * 0.5;
    } else if(that->order == 1) {
      in[i] = vals[i] + that->deltaT * d2[i];
    } else if(that->order == 2) {
      if(i <= that->dimVar) {
        in[i] = vals[i] + that->deltaT * d1[i - 1];
      } else {
        in[i] = vals[i] + that->deltaT * d2[i - 1];
      }
    }
  }
  $(that->derivative, eval)(in, d3);
  loop(i, that->derivative->dimIn) {
    if(i == 0) {
      vals[i] += that->deltaT;
    } else if(that->order == 1) {
      vals[i] +=
        that->deltaT / 6.0 *
        (d0[i - 1] + 2.0 * d1[i - 1] + 2.0 * d2[i - 1] + d3[i - 1]);
    } else if(that->order == 2) {
      if(i <= that->dimVar) {
        vals[i] +=
          that->deltaT * (
            vals[i + that->dimVar] +
            that->deltaT / 6.0 * (d0[i - 1] + d1[i - 1] + d2[i - 1]));
      } else {
        vals[i] +=
          that->deltaT / 6.0 * (
            d0[i - that->dimVar - 1] + 2.0 * d1[i - that->dimVar - 1] +
            2.0 * d2[i - that->dimVar - 1] + d3[i - that->dimVar - 1]);
      }
    }
  }
}

// Release the vectors owned by a CapyRungeKutta and reset it to a
// zeroed state (the structure itself is not freed)
static void Destruct(void) {
  methodOf(CapyRungeKutta);

  // Initial value vector first, then the five work vectors used by Step
  CapyVecDestruct(&that->initVal);
  for(size_t iVec = 0; iVec < 5; ++iVec) {
    CapyVecDestruct(&(that->vecs[iVec]));
  }
  *that = (CapyRungeKutta){0};
}

// Create a CapyRungeKutta
// Input:
//   derivative: the derivative function. For a variable X of dimension d,
//               its input is [t, X] (order 1) or [t, X, X'] (order 2), of
//               size d * order + 1, and its output is of size d
//   order: order of the derivative (in {1,2})
// Output:
//   Return a CapyRungeKutta
// Exception:
//   May raise CapyExc_InvalidParameters.
CapyRungeKutta CapyRungeKuttaCreate(
  CapyMathFun* const derivative,
        size_t const order) {
  if(derivative == NULL || order == 0 || order >= 3) {
    raiseExc(CapyExc_InvalidParameters);
  }
  if(derivative->dimIn < 2) {
    raiseExc(CapyExc_InvalidParameters);
  }

  // Validate the dimensions before allocating anything, so that an invalid
  // derivative doesn't leak the vectors
  size_t const dimVar = (derivative->dimIn - 1) / order;
  if(
    derivative->dimOut != dimVar ||
    dimVar * order + 1 != derivative->dimIn
  ) {
    raiseExc(CapyExc_InvalidParameters);
  }
  CapyRungeKutta that = {
    .derivative = derivative,
    .order = order,
    .dimVar = dimVar,
    .initVal = CapyVecCreate(derivative->dimIn),
    .deltaT = 1e-3,
    .eval = Eval,
    .step = Step,
    .destruct = Destruct,
  };

  // Work vectors used by Step (four slopes and the stage input)
  loop(i, 5) that.vecs[i] = CapyVecCreate(derivative->dimIn);
  return that;
}

// Allocate a CapyRungeKutta on the heap and initialise it with
// CapyRungeKuttaCreate
// Input:
//   derivative: the derivative function
//   order: order of the derivative (in {1,2})
// Output:
//   Return a pointer to the new CapyRungeKutta, or NULL if the allocation
//   did not succeed
// Exception:
//   May raise CapyExc_MallocFailed.
CapyRungeKutta* CapyRungeKuttaAlloc(
  CapyMathFun* const derivative,
        size_t const order) {
  CapyRungeKutta* rk = NULL;
  safeMalloc(rk, 1);
  if(rk != NULL) *rk = CapyRungeKuttaCreate(derivative, order);
  return rk;
}

// Destruct and free a heap-allocated CapyRungeKutta, then set the caller's
// pointer to NULL. Does nothing if 'that' or '*that' is NULL.
// Input:
//   that: address of the pointer to the CapyRungeKutta to free
void CapyRungeKuttaFree(CapyRungeKutta** const that) {
  if(that != NULL && *that != NULL) {
    $(*that, destruct)();
    free(*that);
    *that = NULL;
  }
}
