// ---------------------------------- btree.h ---------------------------------
/*
    LibCapy - a general purpose library of C functions and elem structures
    Copyright (C) 2021-2025 Pascal Baillehache info@baillehachepascal.dev
    https://baillehachepascal.dev
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef CAPY_BTREE_H
#define CAPY_BTREE_H
#include "externalHeaders.h"
#include "cext.h"

// Description:
// BTree class.

// Generic b-tree structure. Declaration macro for a b-tree structure named
// 'name', containing elements of type 'type'. It declares the node structure
// and the prototypes of the functions defined by CapyDefBTree(name, type).
// A tree is represented by its root node, and all the nodes of a tree have
// the same structure.
// The definition macros access 'elems[i].destruct', hence 'type' must be a
// structure with a 'destruct' member (a pointer to a function taking no
// argument, or NULL).
// Fields:
//   elems: array of the elements held by the node (Create allocates
//          nbMaxElem slots, the first nbElem ones are in use)
//   nbMaxElem: maximum number of elements per node
//   nbElem: number of elements currently held by the node
//   parent: parent node (NULL if the node has no parent)
//   childs: array of pointers to the child nodes (Create allocates
//           nbMaxElem + 1 slots, NULL when there is no child). The lookup
//           functions search childs[i] for the values ordered before
//           elems[i] by 'cmp', and childs[nbElem] for the values ordered
//           after the last element.
//   cmp: comparison function used to sort the elements; the functions below
//        treat a negative result as 'elemA ordered before elemB' and zero as
//        a match
// Methods (destruct, add, remove, removeElem, find, findNode,
// getLeafNodeForElem, split, insertNode, mergeChildren, rebalance):
//   Create sets each of these pointers to the function of the same purpose
//   (name ## Destruct, name ## Add, ...). They don't take the node as an
//   argument: each function reads the node it operates on from 'capyThat'
//   (declared outside of this file, presumably set by the '$' macro -
//   TODO confirm in cext.h).
#define CapyDecBTree(name, type)                           \
typedef struct name name;                                  \
struct name {                                              \
  type* elems;                                             \
  size_t nbMaxElem;                                        \
  size_t nbElem;                                           \
  name* parent;                                            \
  name** childs;                                           \
  void (*destruct)(void);                                  \
  int (*cmp)(                                              \
    type const* const elemA,                               \
    type const* const elemB);                              \
  void (*add)(type const elem);                            \
  void (*remove)(type const elem);                         \
  void (*removeElem)(size_t const iElem);                  \
  type* (*find)(type const elem);                          \
  name* (*findNode)(type const elem);                      \
  name* (*getLeafNodeForElem)(type const elem);            \
  void (*split)(void);                                     \
  name* (*insertNode)(name const node);                    \
  void (*mergeChildren)(void);                             \
  void (*rebalance)(void);                                 \
};                                                         \
name name ## Create(                                       \
  size_t const nb,                                         \
            int(*cmp)(type const*, type const*));          \
name* name ## Alloc(                                       \
  size_t const nb,                                         \
            int(*cmp)(type const*, type const*));          \
void name ## Destruct(void);                               \
void name ## Free(name** const that);                      \
name* name ## GetLeafNodeForElem(type const elem);         \
void name ## Split(void);                                  \
void name ## Add(type const elem);                         \
void name ## Remove(type const elem);                      \
void name ## RemoveElem(size_t const iElem);               \
name* name ## InsertNode(name const node);                 \
type* name ## Find(type const elem);                       \
name* name ## FindNode(type const elem);                   \
void name ## MergeChildren(void);                          \
void name ## Rebalance(void);

// Generic b-tree structure. Definition macro of the constructor for a b-tree
// structure named 'name', containing elements of type 'type'.
// Create a b-tree node
// Input:
//   nb: maximum number of elements per node
//   cmp: comparison function used to sort the elements
// Output:
//   Return a node without element, parent nor child. Its array of elements
//   has 'nb' slots initialised to zero, and its array of childs has 'nb + 1'
//   slots initialised to NULL. The methods are set to the functions defined
//   by the other CapyDefBTree... macros.
#define CapyDefBTreeCreate(name, type)                                \
name name ## Create(                                                  \
  size_t const nb,                                                    \
           int(*cmp)(type const*, type const*)) {                     \
  name tree = {                                                       \
    .elems = NULL,                                                    \
    .nbMaxElem = nb,                                                  \
    .nbElem = 0,                                                      \
    .parent = NULL,                                                   \
    .childs = NULL,                                                   \
    .cmp = cmp,                                                       \
    .destruct = name ## Destruct,                                     \
    .add = name ## Add,                                               \
    .find = name ## Find,                                             \
    .findNode = name ## FindNode,                                     \
    .remove = name ## Remove,                                         \
    .removeElem = name ## RemoveElem,                                 \
    .getLeafNodeForElem = name ## GetLeafNodeForElem,                 \
    .insertNode = name ## InsertNode,                                 \
    .split = name ## Split,                                           \
    .mergeChildren = name ## MergeChildren,                           \
    .rebalance = name ## Rebalance,                                   \
  };                                                                  \
  /* One slot per element, all reset to zero */                       \
  safeMalloc(tree.elems, nb);                                         \
  for(size_t iElem = 0; iElem < nb; ++iElem) {                        \
    tree.elems[iElem] = (type){0};                                    \
  }                                                                   \
  /* One more child than elements, none attached yet */               \
  safeMalloc(tree.childs, nb + 1);                                    \
  for(size_t iChild = 0; iChild < nb + 1; ++iChild) {                 \
    tree.childs[iChild] = NULL;                                       \
  }                                                                   \
  return tree;                                                        \
}

// Allocate memory for a new b-tree and create it
// Input:
//   nb: maximum number of elements per node
//   cmp: comparison function used to sort the elements
// Output:
//   Return a pointer to a newly allocated b-tree containing elements of type
//   'type', initialised with name ## Create(nb, cmp), or NULL if the
//   allocation left the pointer null.
// Exception:
//   May raise CapyExc_MallocFailed.
#define CapyDefBTreeAlloc(name, type)                        \
name* name ## Alloc(                                         \
  size_t const nb,                                           \
           int(*cmp)(type const*, type const*)) {            \
  name* tree = NULL;                                         \
  safeMalloc(tree, 1);                                       \
  if(tree != NULL) {                                         \
    *tree = name ## Create(nb, cmp);                         \
  }                                                          \
  return tree;                                               \
}

// Free the memory used by a b-tree.
// The b-tree to destruct is read from 'capyThat'.
// Output:
//   The 'destruct' method of each element of the node is called if it isn't
//   NULL, the child nodes are freed (which destructs them in turn), the
//   arrays of elements and childs are freed and the node is reset to zero.
#define CapyDefBTreeDestruct(name, type)                             \
void name ## Destruct(void) {                                        \
  name* node = (name*)capyThat;                                      \
  for(size_t iElem = 0; iElem < node->nbElem; ++iElem) {             \
    if(node->elems[iElem].destruct == NULL) continue;                \
    $(node->elems + iElem, destruct)();                              \
  }                                                                  \
  /* The Free function ignores the slots which are NULL */           \
  for(size_t iChild = 0; iChild < node->nbElem + 1; ++iChild) {      \
    name ## Free(&(node->childs[iChild]));                           \
  }                                                                  \
  free(node->elems);                                                 \
  free(node->childs);                                                \
  *node = (name){0};                                                 \
}

// Destruct a b-tree allocated on the heap, free it, and reset '*that' to NULL
// Input:
//   that: a pointer to the pointer to the b-tree to free
// Output:
//   Nothing happens if 'that' or '*that' is NULL. Else the 'destruct' method
//   of the b-tree is called, the b-tree is freed and '*that' is set to NULL.
#define CapyDefBTreeFree(name, type)             \
void name ## Free(name** const that) {           \
  if(that != NULL && *that != NULL) {            \
    $(*that, destruct)();                        \
    free(*that);                                 \
    *that = NULL;                                \
  }                                              \
}

// Get the leaf node where an element should be
// The search starts from the node read from 'capyThat'.
// Input:
//   elem: the element
// Output:
//   Traverse the BTree and return the first leaf that accommodates the
//   element: at each node, follow the child located before the first element
//   which is not ordered before 'elem' (or the last child if there is none),
//   until that child is NULL.
#define CapyDefBTreeGetLeafNodeForElem(name, type)               \
name* name ## GetLeafNodeForElem(type const elem) {              \
  name* node = (name*)capyThat;                                  \
  size_t iChild = 0;                                             \
  for(; iChild < node->nbElem; ++iChild) {                       \
    if((node->cmp)(node->elems + iChild, &elem) >= 0) break;     \
  }                                                              \
  /* No child at that position: this node is the leaf */         \
  if(node->childs[iChild] == NULL) return node;                  \
  return $(node->childs[iChild], getLeafNodeForElem)(elem);      \
}

// Split the root node in two
// The node to split is read from 'capyThat'.
// Output:
//   If the root node has less than three elements nothing happens. Else,
//   two new nodes are allocated: one receives the elements located before
//   the median element of the root (and the childs around them), the other
//   receives the elements located after the median element (and the childs
//   around them). The moved childs get the new node as parent. The median
//   element becomes the only element of the root and the new nodes become
//   its prior and posterior childs.
#define CapyDefBTreeSplit(name, type)                                         \
void name ## Split(void) {                                                    \
  name* root = (name*)capyThat;                                               \
  size_t const nbTotal = root->nbElem;                                        \
  if(nbTotal < 3) return;                                                     \
  size_t const iMid = nbTotal / 2;                                            \
  size_t const nbLow = iMid;                                                  \
  size_t const nbHigh = nbTotal - iMid - 1;                                   \
  name* low = name ## Alloc(root->nbMaxElem, root->cmp);                      \
  name* high = name ## Alloc(root->nbMaxElem, root->cmp);                     \
  low->parent = root;                                                         \
  high->parent = root;                                                        \
  low->nbElem = nbLow;                                                        \
  high->nbElem = nbHigh;                                                      \
  /* Copy the elements located on each side of the median */                  \
  for(size_t iElem = 0; iElem < nbLow; ++iElem) {                             \
    low->elems[iElem] = root->elems[iElem];                                   \
  }                                                                           \
  for(size_t iElem = 0; iElem < nbHigh; ++iElem) {                            \
    high->elems[iElem] = root->elems[iMid + 1 + iElem];                       \
  }                                                                           \
  /* Hand over the childs (one more than elements) and reparent them */       \
  for(size_t iChild = 0; iChild <= nbLow; ++iChild) {                         \
    name* sub = root->childs[iChild];                                         \
    root->childs[iChild] = NULL;                                              \
    low->childs[iChild] = sub;                                                \
    if(sub != NULL) sub->parent = low;                                        \
  }                                                                           \
  for(size_t iChild = 0; iChild <= nbHigh; ++iChild) {                        \
    name* sub = root->childs[iMid + 1 + iChild];                              \
    root->childs[iMid + 1 + iChild] = NULL;                                   \
    high->childs[iChild] = sub;                                               \
    if(sub != NULL) sub->parent = high;                                       \
  }                                                                           \
  /* Only the median remains in the root, between the two new nodes */        \
  root->elems[0] = root->elems[iMid];                                         \
  root->nbElem = 1;                                                           \
  root->childs[0] = low;                                                      \
  root->childs[1] = high;                                                     \
}

// Insert a node into the elements of the tree
// The node receiving the insertion ('that') is read from 'capyThat'.
// Input:
//   node: the node to insert, passed by value. Only its first element
//         (node.elems[0]) and its first two childs (node.childs[0] and
//         node.childs[1]) are read. The arrays of 'node' are not freed here.
// Output:
//   The node is inserted. This may trigger a restructuring of the BTree
//   which invalidates the 'that' pointer. Then, this method returns 'that' if
//   it's still valid, or the value returned by the insertion into its parent
//   if 'that' doesn't exist any more.
//   Steps:
//   - If 'that' is full (nbElem >= nbMaxElem), it is split, and 'node' is
//     inserted into the first child of 'that' if the first element of 'that'
//     is ordered after node.elems[0], else into its second child (the value
//     returned by that insertion is ignored). Then, if 'that' has a parent,
//     the slot(s) of the parent pointing to 'that' are reset to NULL, a copy
//     of 'that' is inserted into the parent, the arrays of 'that' and 'that'
//     itself are freed, and the value returned by the insertion into the
//     parent is returned.
//   - Else, the position of the first element of 'that' which is not ordered
//     before node.elems[0] is searched. If there is no child at that
//     position, the following elements and childs are shifted by one,
//     node.elems[0] is stored at that position, and node.childs[0] and
//     node.childs[1] become the childs around it (their parent is set to
//     'that'). If there is a child at that position the insertion is
//     delegated to that child and its result is returned.
//   NOTE(review): if split() leaves a full node unchanged (it does nothing
//   for less than three elements) the childs used right after may be NULL;
//   this looks like it assumes nbMaxElem >= 3 - to be confirmed.
#define CapyDefBTreeInsertNode(name, type)                             \
name* name ## InsertNode(name const node) {                            \
  name* that = (name*)capyThat;                                        \
  if(that->nbElem >= that->nbMaxElem) {                                \
    $(that, split)();                                                  \
    if((that->cmp)(that->elems, node.elems) > 0) {                     \
      $(that->childs[0], insertNode)(node);                            \
    } else {                                                           \
      $(that->childs[1], insertNode)(node);                            \
    }                                                                  \
    if(that->parent != NULL) {                                         \
      size_t iChild = 0;                                               \
      while(iChild <= that->parent->nbElem) {                          \
        if(that->parent->childs[iChild] == that) {                     \
          that->parent->childs[iChild] = NULL;                         \
        }                                                              \
        iChild += 1;                                                   \
      }                                                                \
      name* newThat = $(that->parent, insertNode)(*that);              \
      free(that->childs);                                              \
      free(that->elems);                                               \
      free(that);                                                      \
      that = newThat;                                                  \
    }                                                                  \
  } else {                                                             \
    size_t iElem = 0;                                                  \
    while(                                                             \
      iElem < that->nbElem &&                                          \
      (that->cmp)(that->elems + iElem, node.elems) < 0                 \
    ) {                                                                \
      iElem += 1;                                                      \
    }                                                                  \
    if(that->childs[iElem] == NULL) {                                  \
      size_t jElem = that->nbElem;                                     \
      while(jElem > iElem) {                                           \
        that->childs[jElem + 1] = that->childs[jElem];                 \
        that->childs[jElem] = that->childs[jElem - 1];                 \
        that->elems[jElem] = that->elems[jElem - 1];                   \
        jElem -= 1;                                                    \
      }                                                                \
      that->elems[iElem] = node.elems[0];                              \
      that->childs[iElem] = node.childs[0];                            \
      if(that->childs[iElem] != NULL) {                                \
        that->childs[iElem]->parent = that;                            \
      }                                                                \
      that->childs[iElem + 1] = node.childs[1];                        \
      if(that->childs[iElem + 1] != NULL) {                            \
        that->childs[iElem + 1]->parent = that;                        \
      }                                                                \
      that->nbElem += 1;                                               \
    } else {                                                           \
      that = $(that->childs[iElem], insertNode)(node);                 \
    }                                                                  \
  }                                                                    \
  return that;                                                         \
}

// Add an element to the b-tree
// The b-tree receiving the element is read from 'capyThat'.
// Input:
//   elem: the element to add
// Output:
//   A shallow copy of 'elem' is inserted into the tree, starting from the
//   leaf node returned by getLeafNodeForElem(elem). From then on the element
//   belongs to the tree: its 'destruct' method (if not NULL) is called when
//   it is removed with remove() or when the tree is destructed, and it is
//   not called by this function.
#define CapyDefBTreeAdd(name, type)                                  \
void name ## Add(type const elem) {                                  \
  name* that = (name*)capyThat;                                      \
  /* Temporary single-element node carrying 'elem' to insertNode */  \
  name node = name ## Create(that->nbMaxElem, that->cmp);            \
  node.elems[0] = elem;                                              \
  node.nbElem = 1;                                                   \
  name* leaf = $(that, getLeafNodeForElem)(elem);                    \
  $(leaf, insertNode)(node);                                         \
  /* The element now lives in the tree. Empty the temporary node */  \
  /* before destructing it, else its destructor would call the   */  \
  /* 'destruct' method of the element which was just inserted.   */  \
  node.nbElem = 0;                                                   \
  $(&node, destruct)();                                              \
}

// Rebalance the tree to ensure it has the required minimum number of elements
// The node to rebalance ('that') is read from 'capyThat'.
// Output:
//   Nothing happens if 'that' has no parent. Else:
//   - the position of 'that' among the childs of its parent is searched;
//   - the number of elements of the sibling located just before and of the
//     sibling located just after is read (0 if there is no such sibling);
//   - the sibling with the most elements is chosen as giver (the one located
//     after only if it has strictly more elements than the one before);
//   - if the giver has more than nbMaxElem / 2 elements, one element is
//     rotated through the parent: the element of the parent separating
//     'that' from the giver moves into 'that' (at the front if the giver is
//     the sibling before, at the end if it is the sibling after), the
//     nearest element of the giver replaces it in the parent, and the
//     nearest child of the giver is moved to 'that' (its parent is updated);
//   - else nothing is done: merging 'that', its parent and the sibling is
//     not implemented (the branch only holds a disabled message referring to
//     issue #36).
#define CapyDefBTreeRebalance(name, type)                              \
void name ## Rebalance(void) {                                         \
  name* that = (name*)capyThat;                                        \
  if(that->parent == NULL) return;                                     \
  size_t iChildOfThatInParent = 0;                                     \
  while(                                                               \
    iChildOfThatInParent <= that->parent->nbElem &&                    \
    that->parent->childs[iChildOfThatInParent] != that                 \
  ) {                                                                  \
    iChildOfThatInParent += 1;                                         \
  }                                                                    \
  size_t nbElemInSiblings[2] = {0, 0};                                 \
  if(                                                                  \
    iChildOfThatInParent > 0 &&                                        \
    that->parent->childs[iChildOfThatInParent - 1] != NULL             \
  ) {                                                                  \
    nbElemInSiblings[0] =                                              \
      that->parent->childs[iChildOfThatInParent - 1]->nbElem;          \
  }                                                                    \
  if(                                                                  \
    iChildOfThatInParent < that->parent->nbElem &&                     \
    that->parent->childs[iChildOfThatInParent + 1] != NULL             \
  ) {                                                                  \
    nbElemInSiblings[1] =                                              \
      that->parent->childs[iChildOfThatInParent + 1]->nbElem;          \
  }                                                                    \
  size_t iGiver = 0;                                                   \
  if(nbElemInSiblings[1] > nbElemInSiblings[0]) iGiver = 1;            \
  if(nbElemInSiblings[iGiver] > that->nbMaxElem / 2) {                 \
    size_t iElemTaken = iChildOfThatInParent + iGiver - 1;             \
    type valTaken = that->parent->elems[iElemTaken];                   \
    size_t iSiblingGiver =                                             \
      iChildOfThatInParent + 2 * iGiver - 1;                           \
    name* siblingGiver = that->parent->childs[iSiblingGiver];          \
    if(iGiver == 0) {                                                  \
      size_t iElem = that->nbElem;                                     \
      while(iElem > 0) {                                               \
        that->elems[iElem] = that->elems[iElem - 1];                   \
        that->childs[iElem + 1] = that->childs[iElem];                 \
        iElem -= 1;                                                    \
      }                                                                \
      that->childs[1] = that->childs[0];                               \
      that->nbElem += 1;                                               \
      that->elems[0] = valTaken;                                       \
      that->childs[0] = siblingGiver->childs[siblingGiver->nbElem];    \
      if(that->childs[0] != NULL) {                                    \
        that->childs[0]->parent = that;                                \
      }                                                                \
      that->parent->elems[iElemTaken] =                                \
        siblingGiver->elems[siblingGiver->nbElem - 1];                 \
      siblingGiver->childs[siblingGiver->nbElem] = NULL;               \
      siblingGiver->nbElem -= 1;                                       \
    } else {                                                           \
      that->elems[that->nbElem] = valTaken;                            \
      that->childs[that->nbElem + 1] = siblingGiver->childs[0];        \
      if(that->childs[that->nbElem + 1] != NULL) {                     \
        that->childs[that->nbElem + 1]->parent = that;                 \
      }                                                                \
      that->nbElem += 1;                                               \
      that->parent->elems[iElemTaken] = siblingGiver->elems[0];        \
      siblingGiver->nbElem -= 1;                                       \
      loop(iElem, siblingGiver->nbElem) {                              \
        siblingGiver->elems[iElem] = siblingGiver->elems[iElem + 1];   \
        siblingGiver->childs[iElem] = siblingGiver->childs[iElem + 1]; \
      }                                                                \
      siblingGiver->childs[siblingGiver->nbElem] =                     \
        siblingGiver->childs[siblingGiver->nbElem + 1];                \
      siblingGiver->childs[siblingGiver->nbElem + 1] = NULL;           \
    }                                                                  \
  } else {                                                             \
    if(0) {                                                            \
      printf(                                                          \
        "TODO: 'that' and 'that->parent' and 'siblinGiver' should be " \
        "merged together into one single node to keep the minimum "    \
        "number of element in a node. Cf issue #36.\n");               \
    }                                                                  \
  }                                                                    \
}

// Merge the children of a node into that node.
// The node ('that') is read from 'capyThat'.
// Output:
//   The children are merged into the node if possible. If not, nothing
//   happens. The merge is possible if the children hold at least one element
//   and the elements of the node and of its children all together fit into
//   one node (at most nbMaxElem elements).
//   When merging, the elements of each child are interleaved in order with
//   the elements of the node, the childs of the children become childs of
//   the node (their parent is updated), the children are freed, and the
//   content of the node is replaced by the merged content. The node keeps
//   its address and its parent; its methods are reset to the ones set by
//   name ## Create.
#define CapyDefBTreeMergeChildren(name, type)                          \
void name ## MergeChildren(void) {                                     \
  name* that = (name*)capyThat;                                        \
  /* Count the elements held by the children */                        \
  size_t nbElemInChildren = 0;                                         \
  size_t iChild = 0;                                                   \
  while(iChild <= that->nbElem) {                                      \
    if(that->childs[iChild] != NULL) {                                 \
      nbElemInChildren += that->childs[iChild]->nbElem;                \
    }                                                                  \
    iChild += 1;                                                       \
  }                                                                    \
  if(                                                                  \
    nbElemInChildren + that->nbElem > that->nbMaxElem ||               \
    nbElemInChildren == 0                                              \
  ) {                                                                  \
    return;                                                            \
  }                                                                    \
  /* Build the merged content into a temporary node */                 \
  name newNode = name ## Create(that->nbMaxElem, that->cmp);           \
  newNode.parent = that->parent;                                       \
  iChild = 0;                                                          \
  while(iChild <= that->nbElem) {                                      \
    name* child = that->childs[iChild];                                \
    if(child != NULL) {                                                \
      loop(iElem, child->nbElem) {                                     \
        newNode.elems[newNode.nbElem] = child->elems[iElem];           \
        newNode.childs[newNode.nbElem] = child->childs[iElem];         \
        /* The merged content ends up in 'that', hence the parent */   \
        if(newNode.childs[newNode.nbElem] != NULL) {                   \
          newNode.childs[newNode.nbElem]->parent = that;               \
        }                                                              \
        newNode.nbElem += 1;                                           \
      }                                                                \
      newNode.childs[newNode.nbElem] = child->childs[child->nbElem];   \
      if(newNode.childs[newNode.nbElem] != NULL) {                     \
        newNode.childs[newNode.nbElem]->parent = that;                 \
      }                                                                \
      /* Free the child without destructing what was moved out */      \
      free(child->elems);                                              \
      free(child->childs);                                             \
      free(child);                                                     \
    }                                                                  \
    /* The separator following that child comes next */                \
    if(iChild < that->nbElem) {                                        \
      newNode.elems[newNode.nbElem] = that->elems[iChild];             \
      newNode.nbElem += 1;                                             \
    }                                                                  \
    iChild += 1;                                                       \
  }                                                                    \
  /* Release the former arrays of the node before they get replaced */ \
  /* by the ones of the merged node, else they would be leaked.     */ \
  free(that->elems);                                                   \
  free(that->childs);                                                  \
  *that = newNode;                                                     \
}

// Remove an element at a given position in a node of a b-tree
// The node ('that') is read from 'capyThat'.
// Input:
//   iElem: the position of the element to remove
// Output:
//   The element is removed. Nothing happens if 'iElem' is not lower than the
//   number of elements in the node. The 'destruct' method of the removed
//   element is not called here (Remove calls it before calling this
//   function).
//   Steps:
//   - If there is a child before the element, the element is overwritten
//     with the last element of the node reached by following the last childs
//     from that child, and that last element is then removed from its node
//     with removeElem.
//   - Else if there is a child after the element, the element is overwritten
//     with the first element of the node reached by following the first
//     childs from that child, and that first element is then removed from
//     its node with removeElem.
//   - Else the following elements and childs are shifted by one position
//     toward the front and the number of elements is decremented. If the
//     node then holds less than nbMaxElem / 2 elements, rebalance is called.
//     If the node ends up empty and has a parent, it is freed through the
//     slot of its parent pointing to it (which resets that slot to NULL).
//   NOTE(review): the search for 'that' among the childs of its parent has
//   no upper bound; it assumes the parent does reference 'that' - to be
//   confirmed.
//   NOTE(review): the shift doesn't reset the child slot located after the
//   last shifted one, which keeps its previous value - confirm this is
//   harmless when that slot isn't NULL.
#define CapyDefBTreeRemoveElem(name, type)                       \
void name ## RemoveElem(size_t const iElem) {                    \
  name* that = (name*)capyThat;                                  \
  if(iElem >= that->nbElem) return;                              \
  if(that->childs[iElem] != NULL) {                              \
    name* child = that->childs[iElem];                           \
    while(child->childs[child->nbElem] != NULL) {                \
      child = child->childs[child->nbElem];                      \
    }                                                            \
    that->elems[iElem] = child->elems[child->nbElem - 1];        \
    $(child, removeElem)(child->nbElem - 1);                     \
  } else if(that->childs[iElem + 1] != NULL) {                   \
    name* child = that->childs[iElem + 1];                       \
    while(child->childs[0] != NULL) child = child->childs[0];    \
    that->elems[iElem] = child->elems[0];                        \
    $(child, removeElem)(0);                                     \
  } else {                                                       \
    size_t jElem = iElem;                                        \
    while(jElem < that->nbElem - 1) {                            \
      that->elems[jElem] = that->elems[jElem + 1];               \
      that->childs[jElem + 1] = that->childs[jElem + 2];         \
      jElem += 1;                                                \
    }                                                            \
    that->nbElem -= 1;                                           \
    if(that->nbElem < that->nbMaxElem / 2) {                     \
      $(that, rebalance)();                                      \
    }                                                            \
    if(that->nbElem == 0 && that->parent != NULL) {              \
      size_t iChildThatInParent = 0;                             \
      while(that->parent->childs[iChildThatInParent] != that) {  \
        iChildThatInParent += 1;                                 \
      }                                                          \
      name ## Free(that->parent->childs + iChildThatInParent);   \
      that = NULL;                                               \
    }                                                            \
  }                                                              \
}

// Remove an element from the b-tree
// The search starts from the node read from 'capyThat'.
// Input:
//   val: the element to remove
// Output:
//   The element is removed if it exists, else nothing happens. If the
//   element is removed, its 'destruct' method is called beforehand (if it
//   isn't NULL). Nothing happens if the node holds no element.
#define CapyDefBTreeRemove(name, type)                                      \
void name ## Remove(type const val) {                                       \
  name* node = (name*)capyThat;                                             \
  if(node->nbElem == 0) return;                                             \
  /* Stop on the first element which is not ordered before 'val' */         \
  size_t iElem = 0;                                                         \
  int order = -1;                                                           \
  for(; iElem < node->nbElem; ++iElem) {                                    \
    order = (node->cmp)(node->elems + iElem, &val);                         \
    if(order >= 0) break;                                                   \
  }                                                                         \
  if(order == 0) {                                                          \
    /* Found in this node: destruct it then take it out of the node */      \
    if(node->elems[iElem].destruct != NULL) {                               \
      $(node->elems + iElem, destruct)();                                   \
    }                                                                       \
    $(node, removeElem)(iElem);                                             \
    return;                                                                 \
  }                                                                         \
  /* Not in this node: continue in the child at that position, if any */    \
  if(node->childs[iElem] != NULL) {                                         \
    $(node->childs[iElem], remove)(val);                                    \
  }                                                                         \
}

// Find an element in the b-tree
// The search starts from the node read from 'capyThat'.
// Input:
//   val: the element to find
// Output:
//   If an element comparing equal to 'val' is found return a pointer to it
//   (it points into the array of elements of the node holding it), else
//   return NULL. Return NULL if the node holds no element.
#define CapyDefBTreeFind(name, type)                                   \
type* name ## Find(type const val) {                                   \
  name* node = (name*)capyThat;                                        \
  size_t iElem = 0;                                                    \
  for(; iElem < node->nbElem; ++iElem) {                               \
    int const order = (node->cmp)(node->elems + iElem, &val);          \
    if(order == 0) return node->elems + iElem;                         \
    /* First element ordered after 'val': look into the child */       \
    /* located before it                                       */      \
    if(order > 0) break;                                               \
  }                                                                    \
  if(node->nbElem == 0 || node->childs[iElem] == NULL) return NULL;    \
  return $(node->childs[iElem], find)(val);                            \
}

// Find a node containing an element in the b-tree
// The search starts from the node read from 'capyThat'.
// Input:
//   val: the element to find
// Output:
//   If an element comparing equal to 'val' is actually in the tree, return
//   the node containing that element, else return NULL. Return NULL if the
//   node holds no element.
#define CapyDefBTreeFindNode(name, type)                               \
name* name ## FindNode(type const val) {                               \
  name* node = (name*)capyThat;                                        \
  size_t iElem = 0;                                                    \
  for(; iElem < node->nbElem; ++iElem) {                               \
    int const order = (node->cmp)(node->elems + iElem, &val);          \
    if(order == 0) return node;                                        \
    /* First element ordered after 'val': look into the child */       \
    /* located before it                                       */      \
    if(order > 0) break;                                               \
  }                                                                    \
  if(node->nbElem == 0 || node->childs[iElem] == NULL) return NULL;    \
  return $(node->childs[iElem], findNode)(val);                        \
}

// Definition macro expanding all the CapyDefBTree... submacros at once for a
// BTree named 'name' containing elements of type 'type'. The generated
// functions reference each other, hence their prototypes (declared by
// CapyDecBTree(name, type)) must be visible where this macro is expanded.
#define CapyDefBTree(name, type)                 \
  CapyDefBTreeCreate(name, type)                 \
  CapyDefBTreeAlloc(name, type)                  \
  CapyDefBTreeDestruct(name, type)               \
  CapyDefBTreeFree(name, type)                   \
  CapyDefBTreeGetLeafNodeForElem(name, type)     \
  CapyDefBTreeSplit(name, type)                  \
  CapyDefBTreeInsertNode(name, type)             \
  CapyDefBTreeAdd(name, type)                    \
  CapyDefBTreeRebalance(name, type)              \
  CapyDefBTreeMergeChildren(name, type)          \
  CapyDefBTreeRemoveElem(name, type)             \
  CapyDefBTreeRemove(name, type)                 \
  CapyDefBTreeFind(name, type)                   \
  CapyDefBTreeFindNode(name, type)
#endif
