// ---------------------------- quaternion.c ---------------------------
/*
    LibCapy - a general purpose library of C functions and data structures
    Copyright (C) 2021-2025 Pascal Baillehache baillehache.pascal@gmail.com
    https://baillehachepascal.dev
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "quaternion.h"

// Convert the quaternion 'that' to a 3x3 rotation matrix
// The quaternion is read as (x, y, z, w) = vals[0..3] and is assumed to be
// of unit length (no normalisation is performed here).
// Input:
//   res: the result 3x3 rotation matrix, its 9 values are written as
//        vals[3 * row + col]
// Output:
//   'res' is updated.
static void ToRotMat(CapyMat* const res) {
  methodOf(CapyQuaternion);
  double const* const q = that->vals;

  // Doubled vector components
  double const tx = q[0] * 2.0;
  double const ty = q[1] * 2.0;
  double const tz = q[2] * 2.0;

  // Pairwise products, each already carrying the factor 2
  double const pxx = q[0] * tx;
  double const pyy = q[1] * ty;
  double const pzz = q[2] * tz;
  double const pxy = q[0] * ty;
  double const pxz = q[0] * tz;
  double const pyz = q[1] * tz;
  double const pwx = q[3] * tx;
  double const pwy = q[3] * ty;
  double const pwz = q[3] * tz;

  // Row 0
  res->vals[0] = 1.0 - (pyy + pzz);
  res->vals[1] = pxy - pwz;
  res->vals[2] = pxz + pwy;

  // Row 1
  res->vals[3] = pxy + pwz;
  res->vals[4] = 1.0 - (pxx + pzz);
  res->vals[5] = pyz - pwx;

  // Row 2
  res->vals[6] = pxz - pwy;
  res->vals[7] = pyz + pwx;
  res->vals[8] = 1.0 - (pxx + pyy);
}

// Calculate the quaternion equivalent to the rotation of 'that' followed by
// the rotation of 'tho', the latter being expressed in the frame already
// rotated by 'that'. The result is the Hamilton product that * tho, so in
// terms of apply(): apply(result, v) == apply(that, apply(tho, v)).
// Inputs:
//   tho: the second quaternion
//   res: the result quaternion, can be same as 'that' or 'tho'
// Output:
//   'res' is updated.
static void Compose(
  CapyQuaternion const* const tho,
        CapyQuaternion* const res) {
  methodOf(CapyQuaternion);
  double const* const a = that->vals;
  double const* const b = tho->vals;

  // The whole product is computed before 'res' is written, which is what
  // makes aliasing 'res' with 'that' or 'tho' safe.
  double prod[4] = {
    a[0] * b[3] + a[3] * b[0] + a[1] * b[2] - a[2] * b[1],
    a[1] * b[3] + a[3] * b[1] + a[2] * b[0] - a[0] * b[2],
    a[2] * b[3] + a[3] * b[2] + a[0] * b[1] - a[1] * b[0],
    a[3] * b[3] - a[0] * b[0] - a[1] * b[1] - a[2] * b[2],
  };
  for(int k = 0; k < 4; ++k) {
    res->vals[k] = prod[k];
  }
}

// Calculate the quaternion equivalent to the rotation necessary to convert
// 'that' into 'tho'
// Inputs:
//   tho: the target quaternion
//   res: the result quaternion, can be same as 'that' or 'tho'
// Output:
//   'res' is updated with tho * that^-1 (that^-1 as computed by inverse()),
//   so that tho = compose(difference(that, tho), that) up to sign. The
//   result is negated if needed so that its real part is non-negative
//   (for a unit quaternion: a rotation angle of at most pi).
static void Difference(
  CapyQuaternion const* const tho,
        CapyQuaternion* const res) {
  methodOf(CapyQuaternion);
  // Compute that^-1 into a local rather than into 'res': if 'res' aliases
  // 'tho', writing the inverse into 'res' would overwrite 'tho' before
  // compose() gets to read it.
  CapyQuaternion inv = CapyQuaternionCreate();
  $(that, inverse)(&inv);
  $(tho, compose)(&inv, res);
  $(&inv, destruct)();
  // q and -q encode the same rotation, pick the one with w >= 0
  if(res->vals[3] < 0.0) loop(i, 4) res->vals[i] *= -1.0;
}

// Calculate the inverse quaternion of the quaternion 'that'
// Input:
//   res: the result quaternion, can be same as 'that'
// Output:
//   'res' is updated with conj(that) / |that|^2. For a unit quaternion this
//   is exactly the conjugate. The zero quaternion has no inverse; in that
//   case 'res' is set to its conjugate (i.e. zero) instead of dividing by 0.
static void Inverse(CapyQuaternion* const res) {
  methodOf(CapyQuaternion);
  // Squared norm is computed before 'res' is written, so 'res' may alias
  // 'that'
  double const norm2 =
    that->vals[0] * that->vals[0] + that->vals[1] * that->vals[1] +
    that->vals[2] * that->vals[2] + that->vals[3] * that->vals[3];
  double const denom = (norm2 > 0.0) ? norm2 : 1.0;
  loop(i, 3) res->vals[i] = -that->vals[i] / denom;
  res->vals[3] = that->vals[3] / denom;
}

// Return true if 'that' and 'tho' are equal, false otherwise
// Components are compared one by one with equald(). Note that q and -q,
// which encode the same rotation, are not considered equal here.
// Input:
//   tho: the quaternion to compare to
// Output:
//   Return true if the quaternions are equal
static bool IsEqualTo(CapyQuaternion const* const tho) {
  methodOf(CapyQuaternion);
  loop(i, 4) {
    if(!equald(that->vals[i], tho->vals[i])) return false;
  }
  return true;
}

// Rotate the vector 'v' by the quaternion 'that'
// Inputs:
//   v: the vector to be rotated
//   res: the result vector, can be same as 'v'
// Output:
//   'res' is updated.
static void Apply(
  double const* const v,
        double* const res) {
  methodOf(CapyQuaternion);
  CapyQuaternion p = CapyQuaternionCreate();
  loop(i, 3) p.vals[i] = v[i];
  p.vals[3] = 0.0;
  CapyQuaternion inv = CapyQuaternionCreate();
  $(that, compose)(&p, &p);
  $(that, inverse)(&inv);
  $(&p, compose)(&inv, &p);
  loop(i, 3) res[i] = p.vals[i];
  $(&p, destruct)();
  $(&inv, destruct)();
}

// Normalise the quaternion
// Output:
//   The quaternion is scaled to unit length. The zero quaternion has no
//   direction and is left unchanged, rather than being filled with NaN
//   (which 0 * (1 / 0) would produce).
static void Normalise(void) {
  methodOf(CapyQuaternion);
  double const norm = sqrt(
    that->vals[0] * that->vals[0] + that->vals[1] * that->vals[1] +
    that->vals[2] * that->vals[2] + that->vals[3] * that->vals[3]);
  if(norm == 0.0) return;
  double const n = 1.0 / norm;
  loop(i, 4) that->vals[i] *= n;
}

// Get the rotation axis of the quaternion 'that'
// Input:
//   res: the result rotation axis (3 values)
// Output:
//   'res' is set to the unit rotation axis, i.e. the normalised vector part
//   (vals[0..2]) of 'that'. For a unit quaternion this is consistent with
//   getRotAngle() (angle in [0, 2*pi]): q = (axis * sin(a/2), cos(a/2)).
//   If the vector part is zero (identity rotation, any axis is valid),
//   'res' is set to (1, 0, 0).
static void GetRotAxis(double* const res) {
  methodOf(CapyQuaternion);
  // Dividing by sin(angle/2) >= 0 and then normalising is the same as
  // normalising the vector part directly. Doing it directly avoids the
  // division by zero at w = 0 or +/-1, and the axis flip for w < 0 that the
  // previous (sqrt(1 - w) * w) scale factor introduced.
  double const norm = sqrt(
    that->vals[0] * that->vals[0] + that->vals[1] * that->vals[1] +
    that->vals[2] * that->vals[2]);
  if(norm == 0.0) {
    res[0] = 1.0;
    res[1] = 0.0;
    res[2] = 0.0;
    return;
  }
  loop(i, 3) res[i] = that->vals[i] / norm;
}

// Get the rotation angle (in radians) of the quaternion 'that'
// Output:
//   Return the rotation angle, in [0, 2*pi]
static double GetRotAngle(void) {
  methodOf(CapyQuaternion);
  double const vecNorm = sqrt(
    that->vals[0] * that->vals[0] + that->vals[1] * that->vals[1] +
    that->vals[2] * that->vals[2]);
  // For a unit quaternion atan2(|v|, w) == acos(w). Unlike acos it never
  // returns NaN when rounding pushes |w| slightly above 1, it stays accurate
  // near w = +/-1, and it also handles non-normalised quaternions.
  return atan2(vecNorm, that->vals[3]) * 2.0;
}

// Free the memory used by a CapyQuaternion
// A CapyQuaternion owns no dynamically allocated resources, so there is
// nothing to release. The method exists so that destruct() can be called
// uniformly on every quaternion.
static void Destruct(void) {
}

// Create a CapyQuaternion
// Output:
//   Return a CapyQuaternion
CapyQuaternion CapyQuaternionCreate(void) {
  CapyQuaternion that = {
    .vals = {0.0, 0.0, 0.0, 1.0},
    .destruct = Destruct,
    .toRotMat = ToRotMat,
    .compose = Compose,
    .difference = Difference,
    .inverse = Inverse,
    .isEqualTo = IsEqualTo,
    .apply = Apply,
    .normalise = Normalise,
    .getRotAxis = GetRotAxis,
    .getRotAngle = GetRotAngle,
  };
  return that;
}

// Allocate memory for a new CapyQuaternion and create it
// Output:
//   Return a pointer to a new CapyQuaternion initialised as by
//   CapyQuaternionCreate(), or NULL if the allocation failed. Release it
//   with CapyQuaternionFree().
// Exception:
//   May raise CapyExc_MallocFailed.
CapyQuaternion* CapyQuaternionAlloc(void) {
  CapyQuaternion* quat = NULL;
  safeMalloc(quat, 1);
  if(quat != NULL) *quat = CapyQuaternionCreate();
  return quat;
}

// Create a new static quaternion from the 3x3 rotation matrix 'rotMat'
// Input:
//   rotMat: the rotation matrix, laid out as produced by toRotMat()
//           (vals[3 * row + col])
// Output:
//   Return the corresponding quaternion (of unit length if 'rotMat' is a
//   proper rotation matrix)
CapyQuaternion CapyQuaternionCreateFromRotMat(CapyMat const* const rotMat) {
  CapyQuaternion that = CapyQuaternionCreate();
  double const* const m = rotMat->vals;
  // 1 + trace = 4 w^2 for a rotation matrix
  double const sumDiag = 1.0 + m[0] + m[4] + m[8];
  if(sumDiag > 0.0) {
    // w is non-zero: derive the other components from it
    double const s = sqrt(sumDiag) * 2.0;
    that.vals[0] = (m[7] - m[5]) / s;
    that.vals[1] = (m[2] - m[6]) / s;
    that.vals[2] = (m[3] - m[1]) / s;
    that.vals[3] = 0.25 * s;
  } else if(m[0] > m[4] && m[0] > m[8]) {
    // m[0] is the largest diagonal element, so |x| is the largest of
    // |x|, |y|, |z| and is the safest divisor. The previous test compared
    // m[0] with itself, which made this branch unreachable.
    double const s = sqrt(1.0 + m[0] - m[4] - m[8]) * 2.0;
    that.vals[0] = 0.25 * s;
    that.vals[1] = (m[3] + m[1]) / s;
    that.vals[2] = (m[2] + m[6]) / s;
    that.vals[3] = (m[7] - m[5]) / s;
  } else if(m[4] > m[8]) {
    // m[4] is the largest diagonal element: |y| is the largest component
    double const s = sqrt(1.0 - m[0] + m[4] - m[8]) * 2.0;
    that.vals[0] = (m[3] + m[1]) / s;
    that.vals[1] = 0.25 * s;
    that.vals[2] = (m[7] + m[5]) / s;
    that.vals[3] = (m[2] - m[6]) / s;
  } else {
    // m[8] is the largest diagonal element: |z| is the largest component
    double const s = sqrt(1.0 - m[0] - m[4] + m[8]) * 2.0;
    that.vals[0] = (m[2] + m[6]) / s;
    that.vals[1] = (m[7] + m[5]) / s;
    that.vals[2] = 0.25 * s;
    that.vals[3] = (m[3] - m[1]) / s;
  }
  return that;
}

// Allocate memory and create a new Quaternion from the 3x3 rotation matrix
// 'rotMat'
// Input:
//   rotMat: the rotation matrix (see CapyQuaternionCreateFromRotMat)
// Output:
//   Return a pointer to the new quaternion, or NULL if the allocation
//   failed. Release it with CapyQuaternionFree().
CapyQuaternion* CapyQuaternionAllocFromRotMat(CapyMat const* const rotMat) {
  CapyQuaternion* quat = NULL;
  safeMalloc(quat, 1);
  if(quat != NULL) *quat = CapyQuaternionCreateFromRotMat(rotMat);
  return quat;
}

// Create a new static quaternion corresponding to the rotation around
// 'axis' (must be normalized) by 'theta' (in radians)
CapyQuaternion CapyQuaternionCreateFromRotAxis(
  double const* const axis,
         double const theta) {
  CapyQuaternion that = {
    .vals = {0.0, 0.0, 0.0, 1.0},
    .destruct = Destruct,
    .toRotMat = ToRotMat,
    .compose = Compose,
    .difference = Difference,
    .inverse = Inverse,
    .isEqualTo = IsEqualTo,
    .apply = Apply,
    .normalise = Normalise,
    .getRotAxis = GetRotAxis,
    .getRotAngle = GetRotAngle,
  };
  double s = sin(theta / 2.0);
  that.vals[0] = axis[0] * s;
  that.vals[1] = axis[1] * s;
  that.vals[2] = axis[2] * s;
  that.vals[3] = cos(theta / 2.0);
  return that;
}

// Allocate memory and create a new Quaternion corresponding to the rotation
// around 'axis' (must be normalized) by 'theta' (in radians)
// Inputs:
//   axis: the rotation axis (3 values), must be of unit length
//   theta: the rotation angle, in radians
// Output:
//   Return a pointer to the new quaternion, or NULL if the allocation
//   failed. Release it with CapyQuaternionFree().
CapyQuaternion* CapyQuaternionAllocFromRotAxis(
  double const* const axis,
         double const theta) {
  CapyQuaternion* quat = NULL;
  safeMalloc(quat, 1);
  if(quat != NULL) *quat = CapyQuaternionCreateFromRotAxis(axis, theta);
  return quat;
}

// Create a new static quaternion corresponding to the rotation bringing the
// vector 'from' to the vector 'to'
CapyQuaternion CapyQuaternionCreateRotFromVecToVec(
  double const* const from,
  double const* const to) {
  CapyVec normFrom = CapyVecCreateLocal3D;
  CapyVec normTo = CapyVecCreateLocal3D;
  loop(i, 3) {
    normFrom.vals[i] = from[i];
    normTo.vals[i] = to[i];
  }
  CapyVecNormalise(&normFrom);
  CapyVecNormalise(&normTo);
  double d = 0.0;
  CapyVecDot(&normFrom, &normTo, &d);
  if(fabs(d - 1.0) < 1e-6) {
    return CapyQuaternionCreateFromRotAxis((double[3]){1.0, 0.0, 0.0}, 0.0);
  }
  if(fabs(d + 1.0) < 1e-6) {
    CapyVec ortho = CapyVecCreateLocal3D;
    CapyVec3DGetOrtho(normFrom.vals, ortho.vals);
    return CapyQuaternionCreateFromRotAxis(ortho.vals, M_PI);
  }
  CapyQuaternion that = CapyQuaternionCreate();
  CapyVec u = CapyVecCreateLocal3D;
  CapyVecCross(&normFrom, &normTo, &u);
  CapyVecDot(&u, &u, that.vals + 3);
  that.vals[3] = d + sqrt(d * d + that.vals[3]);
  loop(i, 3) that.vals[i] = u.vals[i];
  $(&that, normalise)();
  return that;
}

// Allocate memory and create a new quaternion corresponding to the rotation
// bringing the vector 'from' to the vector 'to'
// Inputs:
//   from: the source 3D vector
//   to: the target 3D vector
// Output:
//   Return a pointer to the new quaternion, or NULL if the allocation
//   failed. Release it with CapyQuaternionFree().
CapyQuaternion* CapyQuaternionAllocRotFromVecToVec(
  double const* const from,
  double const* const to) {
  CapyQuaternion* quat = NULL;
  safeMalloc(quat, 1);
  if(quat != NULL) *quat = CapyQuaternionCreateRotFromVecToVec(from, to);
  return quat;
}


// Free the memory used by a CapyQuaternion* and reset '*that' to NULL
// Does nothing if 'that' or '*that' is NULL.
// Input:
//   that: a pointer to the CapyQuaternion to free
void CapyQuaternionFree(CapyQuaternion** const that) {
  if(that != NULL && *that != NULL) {
    $(*that, destruct)();
    free(*that);
    *that = NULL;
  }
}
