// ---------------------------- springmass.c ---------------------------
/*
    LibCapy - a general purpose library of C functions and data structures
    Copyright (C) 2021-2024 Pascal Baillehache info@baillehachepascal.dev
    https://baillehachepascal.dev
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "springmass.h"

// Add one mass
// Input:
//   id: id of the mass
// Output:
//   If no mass with the given id already exists a new one is added. The
//   mass id is the id of the node in the graph. Return the mass. If created
//   the new mass is initialised with a mass of 1.0 and all other params
//   to 0.0.
static CapySpringMassNode* AddMass(size_t const id) {
  methodOf(CapySpringMass);

  // Try to get the node in the graph
  CapyGraphNode* node = $(that->graph, getNodeById)(id);

  // If a node with the given id doesn't exist
  if(node == NULL) {

    // Create the node in the graph with a newly allocated mass as user data
    CapyGraphNode newNode = {.id = id};
    CapySpringMassNode* mass = NULL;
    safeMalloc(mass, 1);
    newNode.data = mass;
    $(that->graph, addNode)(newNode);

    // Initialise the mass data
    mass->id = id;
    mass->mass = 1.0;
    mass->pos = CapyVecCreate(that->dim);
    mass->speed = CapyVecCreate(that->dim);
    mass->acc = CapyVecCreate(that->dim);
    mass->flagFixed = false;
    mass->data = NULL;

    // Get the new node
    node = $(that->graph, getNodeById)(id);
  }

  // Return the mass (node's user data)
  return (node == NULL ? NULL : (CapySpringMassNode*)(node->data));
}

// Get a mass given its id
// Input:
//   id: the id of the mass (the id of its node in the graph)
// Output:
//   Return the mass stored as user data of the node with that id, or NULL
//   if the graph has no such node
static CapySpringMassNode* GetMass(size_t const id) {
  methodOf(CapySpringMass);
  CapyGraphNode* const found = $(that->graph, getNodeById)(id);
  if(found == NULL) return NULL;
  return (CapySpringMassNode*)(found->data);
}

// Add one spring
// Input:
//   idA: id of the first mass
//   idB: id of the second mass
// Output:
//   Add a new spring, initialised with spring coefficients of 1.0 and
//   rest length of 1.0. Return the new spring. If there is no mass with the
//   given id they are automatically created.
static CapySpringMassLink* AddSpring(
  size_t const idA,
  size_t const idB) {
  methodOf(CapySpringMass);

  // Get the mass (adding them create them if necessary)
  CapySpringMassNode* masses[2];
  masses[0] = $(that, addMass)(idA);
  masses[1] = $(that, addMass)(idB);

  // Try to get the node in the graph
  CapyGraphLink* link = $(that->graph, getLinkBetweenNodes)(idA, idB);

  // If a link with the given ids doesn't exist
  if(link == NULL) {

    // Create the link in the graph
    $(that->graph, linkNodes)(idA, idB);

    // Get the new link
    link = $(that->graph, getLinkBetweenNodes)(idA, idB);
    if(link) {

      // Create the spring data
      CapySpringMassLink* spring = NULL;
      safeMalloc(spring, 1);
      link->data = spring;

      // Initialise the spring data
      spring->restLength = 1.0;
      loop(i, 2) {
        spring->coeffs[i] = 1.0;
        spring->masses[i] = masses[i];
      }
      spring->data = NULL;
    }
  }

  // Return the spring (link's user data)
  return (link == NULL ? NULL : (CapySpringMassLink*)(link->data));
}

// Get a spring given its masses id
// Input:
//   idA: the first mass id
//   idB: the second mass id
// Output:
//   Return the spring stored as user data of the link between idA and idB,
//   or NULL if the graph has no such link
static CapySpringMassLink* GetSpring(
  size_t const idA,
  size_t const idB) {
  methodOf(CapySpringMass);
  CapyGraphLink* const found = $(that->graph, getLinkBetweenNodes)(idA, idB);
  if(found == NULL) return NULL;
  return (CapySpringMassLink*)(found->data);
}

// Step the spring mass system
// Input:
//   deltaT: size of the step
// Output:
//   The accelerations of all the masses are recomputed from the springs'
//   reaction forces, then the position and speed of every non-fixed mass
//   are integrated over deltaT with the Runge-Kutta instance. The speed is
//   multiplied by 'dampening' before integration. A spring whose two masses
//   are at the same position has no defined direction and applies no force.
static void Step(double const deltaT) {
  methodOf(CapySpringMass);

  // Reset all the masses' acceleration
  forEach(node, that->graph->nodes->iter) {
    CapySpringMassNode* mass = (CapySpringMassNode*)(node.data);
    loop(iDim, that->dim) mass->acc.vals[iDim] = 0.0;
  }

  // Accumulate the springs' reaction forces into the masses' acceleration
  forEach(link, that->graph->links->iter) {

    // Get the spring and its current length
    CapySpringMassLink* spring = (CapySpringMassLink*)(link.data);
    double const length = CapyVecGetDistance(
      &(spring->masses[0]->pos),
      &(spring->masses[1]->pos));

    // When both masses are at the same position the unit vector along the
    // spring would be 0/0 = NaN, which would then propagate to every
    // connected mass. Such a spring applies no force.
    if(length > 0.0) {

      // Linear reaction force: coeffs[0] is used when the spring is at or
      // below its rest length, coeffs[1] when it is stretched. Positive
      // when stretched (masses pulled toward each other), negative when
      // compressed (masses pushed apart).
      double const force =
        spring->coeffs[(length > spring->restLength)] *
        (length - spring->restLength);

      // Apply the force along the spring to each non-fixed mass
      loop(iMass, 2) if(spring->masses[iMass]->flagFixed == false) {
        CapySpringMassNode* const mass = spring->masses[iMass];
        CapySpringMassNode* const other = spring->masses[1 - iMass];
        loop(iDim, that->dim) {
          mass->acc.vals[iDim] +=
            force / mass->mass * (
              other->pos.vals[iDim] -
              mass->pos.vals[iDim]) / length;
        }
      }
    }
  }

  // Set the deltaT in the RungeKutta instance
  that->rk->deltaT = deltaT;

  // State vector for the integrator: [time, position(dim), speed(dim)]
  double vals[1 + 2 * that->dim];

  // Integrate the position and speed of each non-fixed mass
  forEach(node, that->graph->nodes->iter) {
    CapySpringMassNode* mass = (CapySpringMassNode*)(node.data);
    if(mass->flagFixed == false) {

      // Each mass is integrated independently from t = 0. The derivative
      // ignores its input (it returns the precomputed accelerations), so
      // the time only needs to be consistent per mass.
      vals[0] = 0.0;
      loop(iDim, that->dim) {
        vals[1 + iDim] = mass->pos.vals[iDim];

        // Dampening is applied to the speed before integration
        vals[1 + iDim + that->dim] = mass->speed.vals[iDim] * that->dampening;
        that->deriv->accs[iDim] = mass->acc.vals[iDim];
      }
      $(that->rk, step)(vals);
      loop(iDim, that->dim) {
        mass->pos.vals[iDim] = vals[1 + iDim];
        mass->speed.vals[iDim] = vals[1 + iDim + that->dim];
      }
    }
  }
}

// Step the spring mass system until it stabilizes
// Input:
//   deltaT: size of the step
// Output:
//   The mass properties are updated.
static void StepToStableState(double const deltaT) {
  methodOf(CapySpringMass);
  size_t step = 0;
  bool isStable = false;
  while(step < that->nbMaxStep && isStable == false) {
    $(that, step)(deltaT);
    isStable = true;
    forEach(node, that->graph->nodes->iter) {
      CapySpringMassNode* mass = (CapySpringMassNode*)(node.data);
      loop(iDim, that->dim) {
        if(
          fabs(mass->speed.vals[iDim]) > that->epsilon ||
          fabs(mass->acc.vals[iDim]) > that->epsilon
        ) {
          isStable = false;
          $(&(that->graph->nodes->iter), toLast)();
        }
      }
    }
    step += 1;
  }
}

// Get the stress of the system
// Output:
//   Return the sum, over all the springs, of the absolute difference
//   between the spring's current length and its rest length
static double GetStress(void) {
  methodOf(CapySpringMass);
  double sumDelta = 0.0;
  forEach(link, that->graph->links->iter) {
    CapySpringMassLink* const spring = (CapySpringMassLink*)(link.data);
    CapySpringMassNode* const from = spring->masses[0];
    CapySpringMassNode* const to = spring->masses[1];
    double const delta =
      CapyVecGetDistance(&(from->pos), &(to->pos)) - spring->restLength;
    sumDelta += fabs(delta);
  }
  return sumDelta;
}

// Get the stress of a link
// Input:
//   idA: the first mass id
//   idB: the second mass id
// Output:
//   Return the absolute difference between the spring's current length and
//   its rest length, or NAN if there is no spring between idA and idB
static double GetSpringStress(
  size_t const idA,
  size_t const idB) {
  methodOf(CapySpringMass);

  // Get the spring. Without one there is no stress to report, and
  // dereferencing it would be undefined behaviour.
  CapySpringMassLink* spring = $(that, getSpring)(idA, idB);
  if(spring == NULL) return NAN;

  // Get the current length of the spring
  double const length = CapyVecGetDistance(
    &(spring->masses[0]->pos),
    &(spring->masses[1]->pos));

  // The stress is the difference between actual length and rest length
  return fabs(length - spring->restLength);
}

// Get the current length of a link
// Input:
//   idA: the first mass id
//   idB: the second mass id
// Output:
//   Return the current distance between the two masses of the spring, or
//   NAN if there is no spring between idA and idB
static double GetSpringLength(
  size_t const idA,
  size_t const idB) {
  methodOf(CapySpringMass);

  // Get the spring. Without one the length is undefined, and dereferencing
  // it would be undefined behaviour.
  CapySpringMassLink* spring = $(that, getSpring)(idA, idB);
  if(spring == NULL) return NAN;

  // Return the current length of the spring
  return CapyVecGetDistance(
    &(spring->masses[0]->pos),
    &(spring->masses[1]->pos));
}

// Free the memory used by a CapySpringMass
// Output:
//   The masses (and their position, speed and acceleration vectors), the
//   springs, the graph, the Runge-Kutta instance and the derivative are
//   freed. The 'data' pointers of masses and springs are not freed; they
//   are left to the caller.
static void Destruct(void) {
  methodOf(CapySpringMass);

  // Free the masses (nodes' user data)
  forEach(node, that->graph->nodes->iter) {
    CapySpringMassNode* mass = (CapySpringMassNode*)(nodePtr->data);
    CapyVecDestruct(&(mass->pos));
    CapyVecDestruct(&(mass->speed));
    CapyVecDestruct(&(mass->acc));
    free(nodePtr->data);
    nodePtr->data = NULL;
  }

  // Free the springs (links' user data), then the graph itself
  forEach(link, that->graph->links->iter) {
    free(linkPtr->data);
    linkPtr->data = NULL;
  }
  CapyGraphFree(&(that->graph));

  // The Runge-Kutta instance was created from the derivative. Free it first
  // (reverse construction order) so it never references a freed derivative.
  CapyRungeKuttaFree(&(that->rk));
  if(that->deriv != NULL) {
    $(that->deriv, destruct)();
    free(that->deriv);
    that->deriv = NULL;
  }
}

// Evaluation function for the derivative
// Input:
//   in: time, position, speed (unused)
//   out: acceleration
// Output:
//   Copy the accelerations stored in the derivative ('accs', set by the
//   spring mass system before each integration) into 'out'
static void EvalDeriv(
  double const* const in,
        double* const out) {
  methodOf(CapySpringMassDeriv);

  // The acceleration doesn't depend on the input state: it was computed
  // beforehand from the springs' forces
  (void)in;
  double const* const src = that->accs;
  for(size_t iOut = 0; iOut < that->dimOut; ++iOut) out[iOut] = src[iOut];
}

// Free the memory used by the properties of a CapySpringMassDeriv
// Output:
//   The inherited CapyMathFun destructor is called and the accelerations
//   array is freed (and reset to NULL). The CapySpringMassDeriv itself is
//   not freed.
static void CapySpringMassDerivDestruct(void) {
  methodOf(CapySpringMassDeriv);
  $(that, destructCapyMathFun)();
  free(that->accs);

  // Avoid a dangling pointer (and a double free if destructed twice)
  that->accs = NULL;
}

// Create a derivative for the spring mass system
// Input:
//   dim: the dimension of the system
// Output:
//   Return the derivative
static CapySpringMassDeriv* CapySpringMassDerivAlloc(size_t const dim) {
  CapySpringMassDeriv* that = NULL;
  safeMalloc(that, 1);
  if(!that) return NULL;
  CapyInherits(*that, CapyMathFun, (1 + dim * 2, dim));
  safeMalloc(that->accs, dim);
  that->eval = EvalDeriv;
  that->destruct = CapySpringMassDerivDestruct;
  return that;
}

// Create a CapySpringMass
// Input:
//   dim: space dimension
// Output:
//   Return a CapySpringMass
CapySpringMass CapySpringMassCreate(size_t const dim) {
  CapySpringMass that = {
    .dim = dim,
    .dampening = 1.0,
    .graph = CapyGraphAlloc(),
    .nbMaxStep = 1000,
    .epsilon = 1e-6,
    .destruct = Destruct,
    .addMass = AddMass,
    .getMass = GetMass,
    .addSpring = AddSpring,
    .getSpring = GetSpring,
    .getSpringStress = GetSpringStress,
    .getSpringLength = GetSpringLength,
    .step = Step,
    .stepToStableState = StepToStableState,
    .getStress = GetStress,
  };
  that.deriv = CapySpringMassDerivAlloc(dim);
  that.rk = CapyRungeKuttaAlloc((CapyMathFun*)(that.deriv), 2);
  return that;
}

// Allocate memory for a new CapySpringMass and create it
// Input:
//   dim: space dimension
// Output:
//   Return a newly allocated CapySpringMass, initialised as by
//   CapySpringMassCreate, or NULL if the allocation failed
// Exception:
//   May raise CapyExc_MallocFailed.
CapySpringMass* CapySpringMassAlloc(size_t const dim) {
  CapySpringMass* sys = NULL;
  safeMalloc(sys, 1);
  if(sys != NULL) *sys = CapySpringMassCreate(dim);
  return sys;
}

// Free the memory used by a CapySpringMass* and reset '*that' to NULL
// Input:
//   that: a pointer to the CapySpringMass to free. Does nothing if 'that'
//         or '*that' is NULL.
void CapySpringMassFree(CapySpringMass** const that) {
  if(that != NULL && *that != NULL) {
    $(*that, destruct)();
    free(*that);
    *that = NULL;
  }
}
