// ------------------------- hierarchicalclustering.c ------------------------
/*
    LibCapy - a general purpose library of C functions and data structures
    Copyright (C) 2021-2025 Pascal Baillehache info@baillehachepascal.dev
    https://baillehachepascal.dev
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "hierarchicalclustering.h"

// Clustering as a tree of dataset rows: instantiation of the tree type
// CapyDatasetRowCluster whose nodes carry a CapyListSize. In this file the
// list of a node holds the indices of the dataset rows belonging to the
// cluster represented by that node.
CapyDefTree(CapyDatasetRowCluster, CapyListSize)

// Find the cluster with the largest diameter in a tree of clusters
// Inputs:
//   clusters: the tree of clusters
//   mat: the dataset converted to a numerical matrix (one matrix row per
//        dataset row; the distances use the first 'nbCol - 1' values of
//        each matrix row)
// Output:
//   Return the node holding the pair of rows with the largest distance,
//   among the nodes containing at least two rows. Return NULL if no node
//   contains at least two rows.
static CapyDatasetRowCluster* FindLargestCluster(
  CapyDatasetRowCluster* const clusters,
          CapyMat const* const mat) {

  // Best candidate found so far and the distance which elected it
  CapyDatasetRowCluster* widestCluster = NULL;
  double widestDist = 0.0;

  // Visit each node of the tree
  forEach(rows, clusters->iter) {

    // A pair of rows is needed to measure a distance
    if($(&rows, getSize)() >= 2) {

      // Enumerate the pairs (rowA, rowB) with rowB located after rowA in
      // the list
      $(&rows, initIterator)();
      CapyListSizeIterator iterFirst = rows.iter;
      forEach(rowA, iterFirst) {

        // View on the first row of the pair, shared by all the pairs
        // starting with that row
        CapyVec vecA = {
          .dim = mat->nbCol - 1,
          .vals = mat->vals + rowA * mat->nbCol,
        };

        // The second iterator starts one step after the first one
        CapyListSizeIterator iterSecond = iterFirst;
        for(
          $(&iterSecond, next)();
          $(&iterSecond, isActive)();
          $(&iterSecond, next)()
        ) {
          size_t const rowB = iterSecond.curElem->data;

          // View on the second row of the pair
          CapyVec vecB = {
            .dim = mat->nbCol - 1,
            .vals = mat->vals + rowB * mat->nbCol,
          };

          // Elect the node currently visited if this pair is the most
          // distant one seen so far (or the very first one)
          double const pairDist = CapyVecGetDistance(&vecA, &vecB);
          if(widestCluster == NULL || pairDist > widestDist) {
            widestCluster = clusters->iter.steps->node;
            widestDist = pairDist;
          }
        }
      }
    }
  }
  return widestCluster;
}

// Find the row with most dissimilarity in a cluster
// Input:
//   cluster: the cluster to be split
//   mat: the dataset converted to a numerical matrix (the distances use the
//        first 'nbCol - 1' values of each matrix row)
// Output:
//   Return the index (as stored in the list of the cluster, i.e. the index
//   of the row in 'mat') of the row which has the maximum average distance
//   with the other rows in the cluster. In case of tie the first row met is
//   returned. Return 0 if the cluster is empty.
static size_t FindMostDissimilar(
  CapyDatasetRowCluster* const cluster,
          CapyMat const* const mat) {

  // Variables to memorise the most dissimilar row (maxDist stays NaN until
  // a first row has been evaluated)
  size_t mostDissimilarRow = 0;
  double maxDist = NAN;

  // Loop on the rows, through a copy of the list iterator so that the inner
  // loop can scan the same list independently
  $(&(cluster->data), initIterator)();
  CapyListSizeIterator iterA = cluster->data.iter;
  forEach(indexA, iterA) {

    // View on the current row, shared by all the distances below
    CapyVec u = {
      .dim = mat->nbCol - 1,
      .vals = mat->vals + indexA * mat->nbCol,
    };

    // Variables to calculate the average distance
    double avgDist = 0.0;
    size_t nb = 0;

    // Loop on the other rows, through a second rewound copy of the list
    // iterator (the iterator of the tree node has nothing to do with this
    // scan and is left untouched)
    CapyListSizeIterator iterB = cluster->data.iter;
    $(&iterB, reset)();
    forEach(indexB, iterB) {
      if(indexA != indexB) {

        // Get the distance between the two rows
        CapyVec v = {
          .dim = mat->nbCol - 1,
          .vals = mat->vals + indexB * mat->nbCol,
        };

        // Update the sum of distances
        avgDist += CapyVecGetDistance(&u, &v);
        nb += 1;
      }
    }

    // Convert the sum into an average (a cluster with a single row has no
    // other row, avoid 0/0 in that case)
    if(nb > 0) avgDist /= (double)(nb);

    // Update the most dissimilar row
    if(isnan(maxDist) || maxDist < avgDist) {
      mostDissimilarRow = indexA;
      maxDist = avgDist;
    }
  }
  return mostDissimilarRow;
}

// Migrate rows from a cluster to another according to their similarity
// Input:
//   rowsFrom: the rows of the origin cluster
//   rowsTo: the rows of the destination cluster
//   mat: the dataset converted to a numerical matrix (the distances use the
//        first 'nbCol - 1' values of each matrix row)
// Output:
//   One row at a time, the row of 'rowsFrom' with the largest score
//   (average distance to the other rows of 'rowsFrom') minus (average
//   distance to the rows of 'rowsTo') is removed from 'rowsFrom' and added
//   to 'rowsTo'. The migration stops when that largest score is negative or
//   when a single row remains in 'rowsFrom'.
static void MigrateRows(
   CapyListSize* const rowsFrom,
   CapyListSize* const rowsTo,
  CapyMat const* const mat) {

  // Score of the best candidate of the previous pass, initialised to a
  // positive value to enter the loop
  double bestScore = 1.0;

  // One pass per migrated row
  size_t nbRowLeft = $(rowsFrom, getSize)();
  while(bestScore >= 0.0 && nbRowLeft > 1) {

    // Best candidate of this pass (bestScore stays NaN until a first row
    // has been scored)
    size_t candidate = 0;
    bestScore = NAN;

    // Score each row of the origin cluster
    forEach(row, rowsFrom->iter) {

      // View on the scored row, shared by all the distances below
      CapyVec rowVec = {
        .dim = mat->nbCol - 1,
        .vals = mat->vals + row * mat->nbCol,
      };

      // Sum the distances to the other rows of the origin cluster, scanned
      // through a rewound copy of the iterator in use by the outer loop
      double sumFrom = 0.0;
      size_t nbFrom = 0;
      CapyListSizeIterator iterPeer = rowsFrom->iter;
      $(&iterPeer, reset)();
      forEach(peer, iterPeer) {
        if(peer != row) {
          CapyVec peerVec = {
            .dim = mat->nbCol - 1,
            .vals = mat->vals + peer * mat->nbCol,
          };
          sumFrom += CapyVecGetDistance(&rowVec, &peerVec);
          nbFrom += 1;
        }
      }

      // Sum the distances to the rows of the destination cluster
      double sumTo = 0.0;
      size_t nbTo = 0;
      forEach(target, rowsTo->iter) {
        CapyVec targetVec = {
          .dim = mat->nbCol - 1,
          .vals = mat->vals + target * mat->nbCol,
        };
        sumTo += CapyVecGetDistance(&rowVec, &targetVec);
        nbTo += 1;
      }

      // The score is high for a row far from its own cluster and near the
      // destination cluster
      double const avgFrom = sumFrom / (double)nbFrom;
      double const avgTo = sumTo / (double)nbTo;
      double const score = avgFrom - avgTo;

      // Memorise the row with the highest score
      if(isnan(bestScore) || score > bestScore) {
        bestScore = score;
        candidate = row;
      }
    }

    // Migrate the best candidate unless its score is negative (it is then
    // more similar to the origin cluster and stays there)
    if(bestScore >= 0.0) {
      $(rowsFrom, remove)(&capyComparatorSizeInc, &candidate);
      $(rowsTo, add)(candidate);
    }

    // Refresh the number of rows remaining in the origin cluster
    nbRowLeft = $(rowsFrom, getSize)();
  }
}

// Split a cluster in two child clusters (according to the DIANA algorithm)
// Input:
//   cluster: the cluster to be split
//   mat: the dataset converted to a numerical matrix
// Output:
//   The most dissimilar row of 'cluster' seeds a new list of rows, toward
//   which the rows of 'cluster' similar to it are migrated. That new list
//   becomes the first child of 'cluster', the rows which have not migrated
//   become its second child, and 'cluster' is left with a new empty list.
static void SplitClusterDiana(
  CapyDatasetRowCluster* const cluster,
          CapyMat const* const mat) {

  // Select the seed of the splinter group
  size_t const seedRow = FindMostDissimilar(cluster, mat);

  // Create the splinter group and transfer the seed from the cluster to it
  CapyListSize splinter = CapyListSizeCreate();
  $(&splinter, initIterator)();
  $(&splinter, add)(seedRow);
  $(&(cluster->data), remove)(&capyComparatorSizeInc, &seedRow);

  // Let the rows nearer to the splinter group than to the rest of the
  // cluster join the splinter group
  MigrateRows(&(cluster->data), &splinter, mat);

  // First child: the splinter group
  $(cluster, addChild)(&splinter);

  // Second child: the rows which have stayed in the cluster. The cluster
  // itself restarts with a fresh empty list.
  $(cluster, addChild)(&(cluster->data));
  cluster->data = CapyListSizeCreate();
  $(&(cluster->data), initIterator)();
}

// Cluster a dataset using the DIANA algorithm
// Input:
//   dataset: the dataset to cluster
// Output:
//   Return the clustering as a tree of dataset's rows. The clustering
//   occurs based on input fields converted to numerical values. The dataset
//   can have output fields. At each step the cluster with the largest
//   diameter is split in two, until there are as many leaves as rows in
//   the dataset. If 'verboseStream' is set, the progress is printed on it
//   as "<nb of clusters>/<nb of rows>" before each split.
static CapyDatasetRowCluster* DianaClustering(
  CapyDataset const* const dataset) {
  methodOf(CapyHierarchicalClustering);

  // Create the result clustering
  CapyListSize defaultCluster = CapyListSizeCreate();
  CapyDatasetRowCluster* clusters = CapyDatasetRowClusterAlloc(&defaultCluster);

  // Convert the dataset to a matrix
  CapyMat mat = $(dataset, cvtToMatForSingleCatPredictor)(0);

  // Initialise the first cluster with all rows
  loop(iRow, dataset->nbRow) {
    $(&(clusters->data), push)(iRow);
  }

  // Loop until the number of clusters equals the number of rows
  size_t nbCluster = $(clusters, getNbLeaf)();
  while(nbCluster < dataset->nbRow) {
    if(that->verboseStream) {

      // size_t values are printed with the 'z' length modifier
      fprintf(that->verboseStream, "%zu/%zu\n", nbCluster, dataset->nbRow);
      fflush(that->verboseStream);
    }

    // Find the cluster with maximum diameter
    CapyDatasetRowCluster* splitCluster = FindLargestCluster(clusters, &mat);

    // Defensive stop: without a cluster of at least two rows there is
    // nothing left to split (and NULL must not be given to the split)
    if(splitCluster == NULL) break;

    // Split the cluster
    SplitClusterDiana(splitCluster, &mat);

    // Update the number of clusters
    nbCluster = $(clusters, getNbLeaf)();
  }

  // Free memory
  CapyMatDestruct(&mat);

  // Return the result clustering
  return clusters;
}

// Recursive function for the fast version of DIANA
// Input:
//   that: the CapyHierarchicalClustering (used for its verbose stream)
//   dataset: the clustered dataset (used for its number of rows in the
//            verbose output)
//   mat: the dataset converted to a numerical matrix
//   clusters: the node of the tree of clusters to split
//   nbCluster: the current number of clusters
// Output:
//   If 'clusters' contains at least two rows it is split in two childs,
//   '*nbCluster' is incremented, the progress is printed on the verbose
//   stream if any, and the two childs are recursively split. Else nothing
//   happens.
static void FastDianaClusteringRec(
  CapyHierarchicalClustering* const that,
           CapyDataset const* const dataset,
               CapyMat const* const mat,
       CapyDatasetRowCluster* const clusters,
                      size_t* const nbCluster) {

  // If the cluster has at least two elements
  size_t const nbElem = $(&(clusters->data), getSize)();
  if(nbElem >= 2) {

    // Split the cluster
    SplitClusterDiana(clusters, mat);

    // Increment the number of clusters (one cluster replaced by two)
    *nbCluster += 1;
    if(that->verboseStream) {

      // size_t values are printed with the 'z' length modifier
      fprintf(that->verboseStream, "%zu/%zu\n", *nbCluster, dataset->nbRow);
      fflush(that->verboseStream);
    }

    // Split the two childs created by the split above
    FastDianaClusteringRec(that, dataset, mat, clusters->child, nbCluster);
    FastDianaClusteringRec(
      that, dataset, mat, clusters->child->brother, nbCluster);
  }
}

// Cluster a dataset using a slightly faster version of the DIANA algorithm
// Input:
//   dataset: the dataset to cluster
// Output:
//   Return the clustering as a tree of dataset's rows. The clustering
//   occurs based on input fields converted to numerical values. The dataset
//   can have output fields. Instead of searching for the cluster to split
//   at each step, each cluster is recursively split until it contains less
//   than two rows.
static CapyDatasetRowCluster* FastDianaClustering(
  CapyDataset const* const dataset) {
  methodOf(CapyHierarchicalClustering);

  // The root of the tree starts from an empty list of rows
  CapyListSize rootRows = CapyListSizeCreate();
  CapyDatasetRowCluster* root = CapyDatasetRowClusterAlloc(&rootRows);

  // Numerical version of the dataset used to measure the distances
  CapyMat mat = $(dataset, cvtToMatForSingleCatPredictor)(0);

  // Put all the rows of the dataset in the root
  loop(iRow, dataset->nbRow) {
    $(&(root->data), push)(iRow);
  }

  // Recursively split the root, which counts as the first cluster
  size_t nbCluster = 1;
  FastDianaClusteringRec(that, dataset, &mat, root, &nbCluster);

  // The matrix is not needed any more
  CapyMatDestruct(&mat);

  // The root gives access to the whole clustering
  return root;
}

// Free the memory used by a CapyHierarchicalClustering
static void Destruct(void) {

  // Nothing to free
}

// Create a CapyHierarchicalClustering
// Output:
//   Return a CapyHierarchicalClustering
CapyHierarchicalClustering CapyHierarchicalClusteringCreate(void) {
  CapyHierarchicalClustering that = {
    .verboseStream = NULL,
    .destruct = Destruct,
    .dianaClustering = DianaClustering,
    .fastDianaClustering = FastDianaClustering,
  };
  return that;
}

// Allocate memory for a new CapyHierarchicalClustering and create it
// Output:
//   Return a pointer to the new CapyHierarchicalClustering, or NULL if the
//   pointer is still null after the allocation.
// Exception:
//   May raise CapyExc_MallocFailed.
CapyHierarchicalClustering* CapyHierarchicalClusteringAlloc(void) {
  CapyHierarchicalClustering* that = NULL;
  safeMalloc(that, 1);

  // Initialise the instance only if the allocation gave some memory
  if(that != NULL) *that = CapyHierarchicalClusteringCreate();
  return that;
}

// Free the memory used by a CapyHierarchicalClustering* and reset '*that'
// to NULL
// Input:
//   that: a pointer to the CapyHierarchicalClustering to free (nothing
//         happens if 'that' or '*that' is NULL)
void CapyHierarchicalClusteringFree(CapyHierarchicalClustering** const that) {
  if(that != NULL && *that != NULL) {

    // Release the content of the instance, then the instance itself
    $(*that, destruct)();
    free(*that);
    *that = NULL;
  }
}
