// ------------------------------ kmeans.c ------------------------------
/*
    LibCapy - a general purpose library of C functions and data structures
    Copyright (C) 2021-2025 Pascal Baillehache baillehache.pascal@gmail.com
    https://baillehachepascal.dev
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "kmeans.h"
#include "random.h"

// Initialise the clusters' centers for the K-means algorithm (rand version):
// every coordinate of every center is drawn at random within the point
// cloud's range on that dimension.
// Input:
//         that: the KMeans, whose clusters are already allocated
//   pointCloud: the point cloud
// Output:
//   The centers of that->clusters receive their initial values. The ranges
//   of pointCloud are refreshed (updateRange) before being used.
static void InitRand(
            CapyKMeans* const that,
  CapyPointCloud const* const pointCloud) {

  // Refresh the per-dimension ranges the coordinates are drawn from
  $(pointCloud, updateRange)();

  // Random generator, created with the number of points (presumably used as
  // its seed)
  CapyRandom rng = CapyRandomCreate(pointCloud->size);
  size_t const nbCluster = that->clusters->size;
  for(size_t iCluster = 0; iCluster < nbCluster; ++iCluster) {
    double* const center = that->clusters->points[iCluster].vals;
    for(size_t iDim = 0; iDim < pointCloud->dim; ++iDim) {
      center[iDim] = $(&rng, getDoubleRangeFast)(pointCloud->range + iDim);
    }
  }
  $(&rng, destruct)();
}

// Initialise the clusters' centers for the K-means algorithm (Forgy version):
// each center is a copy of a point of the cloud picked at random. Picks are
// made in strictly increasing index order, so the centers are distinct
// points of the cloud.
// Input:
//         that: the KMeans, whose clusters are already allocated
//   pointCloud: the point cloud
// Output:
//   The centers of that->clusters receive their initial values.
// NOTE(review): expects pointCloud->size >= that->clusters->size; otherwise
// the upper bound of the candidate range wraps around (unsigned arithmetic).
static void InitForgy(
            CapyKMeans* const that,
  CapyPointCloud const* const pointCloud) {
  CapyRandom rng = CapyRandomCreate(pointCloud->size);
  size_t const nbCluster = that->clusters->size;

  // Lowest index the next pick may take (just after the previous pick)
  size_t firstCandidate = 0;
  for(size_t iCluster = 0; iCluster < nbCluster; ++iCluster) {

    // Leave enough points after the pick for the remaining clusters
    size_t const lastCandidate = pointCloud->size - nbCluster + iCluster;
    CapyRangeSize candidates = {.vals = {firstCandidate, lastCandidate}};
    size_t const iPicked = $(&rng, getSizeRange)(&candidates);

    // Copy the picked point into the center
    double* const center = that->clusters->points[iCluster].vals;
    double const* const picked = pointCloud->points[iPicked].vals;
    for(size_t iDim = 0; iDim < pointCloud->dim; ++iDim) {
      center[iDim] = picked[iDim];
    }
    firstCandidate = iPicked + 1;
  }
  $(&rng, destruct)();
}

// Update the discrete distribution used by InitPlusPlus to draw the next
// center: each point gets a weight equal to the squared distance to its
// nearest center among the centers already initialised (indices 0 to
// iCluster), as in k-means++.
// Input:
//         that: the KMeans
//     iCluster: index of the last initialised center
//         dist: the distribution to update (one occurence per point)
//   pointCloud: the point cloud
// Output:
//   The occurences of dist are updated. that->clusters->size is restored
//   to its original value before returning.
static void InitPlusPlusUpdateDistribution(
      CapyKMeans const* const that,
                 size_t const iCluster,
      CapyDistDiscrete* const dist,
  CapyPointCloud const* const pointCloud) {

  // Temporarily restrict the clusters to the ones already initialised so
  // that getClusterOfPoint ignores the centers not chosen yet
  size_t const nbCluster = that->clusters->size;
  that->clusters->size = iCluster + 1;
  loop(iPoint, pointCloud->size) {
    CapyKMeansClusterOfPoint cluster =
      $(that, getClusterOfPoint)(pointCloud->points + iPoint);

    // The default getClusterOfPoint already returns the squared euclidean
    // distance, which is exactly the D(x)^2 weight of k-means++. Squaring it
    // again would weight points by D(x)^4 and over-favour outliers.
    CapyDistDiscreteOccurence occ = {
      .prob = cluster.dist,
      .evt = {.id = iPoint},
    };
    $(dist->occ, set)(iPoint, &occ);
  }
  that->clusters->size = nbCluster;
}

// Initialise the clusters' centers for the K-means algorithm (k-means++
// version): the first center is a point drawn uniformly from the cloud, each
// following center is a point drawn with a probability proportional to the
// weight computed by InitPlusPlusUpdateDistribution.
// Input:
//         that: the KMeans, whose clusters are already allocated
//   pointCloud: the point cloud
// Output:
//   The centers of that->clusters receive their initial values. If the point
//   cloud is empty, the centers are left untouched.
static void InitPlusPlus(
            CapyKMeans* const that,
  CapyPointCloud const* const pointCloud) {

  // No point to pick from (and 'size - 1' below would wrap around)
  if(pointCloud->size == 0) return;
  CapyRandom rng = CapyRandomCreate(pointCloud->size);
  CapyDistDiscrete dist = CapyDistDiscreteCreate(pointCloud->size);
  size_t const nbCluster = that->clusters->size;
  loop(iCluster, nbCluster) {

    // Index of the point used as the iCluster-th center
    size_t iPoint = 0;
    if(iCluster == 0) {

      // The first center is drawn uniformly among all the points
      CapyRangeSize rangePoint = {.vals = {0, pointCloud->size - 1}};
      iPoint = $(&rng, getSizeRange)(&rangePoint);
    } else {

      // Following centers are drawn from the weighted distribution
      CapyDistEvt evt = $(&rng, getDistEvt)((CapyDist*)&dist);
      iPoint = evt.id;
    }
    loop(iDim, pointCloud->dim) {
      that->clusters->points[iCluster].vals[iDim] =
        pointCloud->points[iPoint].vals[iDim];
    }

    // The distribution is only needed to draw the next center, skip the
    // (costly, one pass over all points) update after the last one
    if(iCluster + 1 < nbCluster) {
      InitPlusPlusUpdateDistribution(that, iCluster, &dist, pointCloud);
    }
  }
  $(&rng, destruct)();
  $(&dist, destruct)();
}

// Cluster a point cloud with the K-means (Lloyd) algorithm.
// Input:
//   pointCloud: the point cloud to be clustered
//            k: the number of clusters
// Output:
//   Update that->clusters. The clusters of a previous run are freed first.
//   The k centers are initialised according to that->typeInit, then each
//   point is assigned to its nearest center and each center is moved to the
//   mean of its points, repeatedly, until no center coordinate moves by more
//   than DBL_EPSILON. A center with no point assigned keeps its position.
//   If k is 0 or the point cloud is empty, that->clusters is left NULL.
// Exception:
//   May raise CapyExc_MallocFailed.
static void Run(
  CapyPointCloud const* const pointCloud,
                 size_t const k) {
  methodOf(CapyKMeans);

  // Release the clusters of a previous run, if any
  CapyPointCloudFree(&(that->clusters));

  // If there is no cluster requested, nothing to do. An empty point cloud is
  // rejected too: the Forgy and plusplus initialisations pick points from
  // the cloud and would index out of bounds.
  if(k == 0 || pointCloud->size == 0) return;

  // Create the clusters
  that->clusters = CapyPointCloudAlloc(pointCloud->dim);
  if(!(that->clusters)) return;
  safeMalloc(that->clusters->points, k);
  if(!(that->clusters->points)) {

    // Don't leave a half-built clusters object behind
    CapyPointCloudFree(&(that->clusters));
    return;
  }
  that->clusters->size = k;
  loop(i, k) that->clusters->points[i] = CapyVecCreate(pointCloud->dim);

  // Initialise the centers (an unknown type leaves them as created)
  switch(that->typeInit) {
    case capyKMeanInit_rand:
      InitRand(that, pointCloud);
      break;
    case capyKMeanInit_forgy:
      InitForgy(that, pointCloud);
      break;
    case capyKMeanInit_plusplus:
      InitPlusPlus(that, pointCloud);
      break;
    default:
      break;
  }

  // Temporary variables: per cluster sum of the assigned points, and number
  // of assigned points
  CapyPointCloud sums = CapyPointCloudCreate(that->clusters->dim);
  safeMalloc(sums.points, k);
  if(!sums.points) return;
  sums.size = k;
  loop(iCluster, k) sums.points[iCluster] = CapyVecCreate(sums.dim);
  size_t* nb = NULL;
  safeMalloc(nb, k);
  if(!nb) {
    $(&sums, destruct)();
    return;
  }

  // Apply the K-means algorithm
  bool flagContinue = true;
  while(flagContinue) {

    // Reset the accumulators
    loop(iCluster, k) {
      loop(iDim, sums.dim) sums.points[iCluster].vals[iDim] = 0.0;
      nb[iCluster] = 0;
    }

    // Assignment step: accumulate each point into its nearest cluster
    loop(iPoint, pointCloud->size) {
      CapyKMeansClusterOfPoint cluster =
        $(that, getClusterOfPoint)(pointCloud->points + iPoint);
      loop(iDim, pointCloud->dim) {
        sums.points[cluster.id].vals[iDim] +=
          pointCloud->points[iPoint].vals[iDim];
      }
      ++(nb[cluster.id]);
    }

    // Update step: move each non-empty cluster's center to the mean of its
    // points. Empty clusters are skipped to avoid a 0/0 division.
    flagContinue = false;
    loop(iCluster, k) {
      if(nb[iCluster] > 0) {
        double* const center = that->clusters->points[iCluster].vals;
        loop(iDim, sums.dim) {
          double const v =
            sums.points[iCluster].vals[iDim] / (double)nb[iCluster];
          if(fabs(center[iDim] - v) > DBL_EPSILON) {
            flagContinue = true;
            center[iDim] = v;
          }
        }
      }
    }
  }

  // Free memory
  $(&sums, destruct)();
  free(nb);
}

// Get the cluster whose center is the nearest to a given point
// Input:
//   v: the point (its first v->dim coordinates are compared with the
//      centers)
// Output:
//   Return the index of the nearest center (the lowest index in case of tie)
//   and the SQUARED euclidean distance to it (no square root is taken).
//   If there is no cluster, return {0, 0.0}. that->clusters must exist.
static CapyKMeansClusterOfPoint GetClusterOfPoint(CapyVec const* const v) {
  methodOf(CapyKMeans);
  CapyKMeansClusterOfPoint best = {0, 0.0};
  size_t const nbCluster = that->clusters->size;
  for(size_t iCluster = 0; iCluster < nbCluster; ++iCluster) {
    double const* const center = that->clusters->points[iCluster].vals;
    double sqrDist = 0.0;
    for(size_t iDim = 0; iDim < v->dim; ++iDim) {
      sqrDist += pow(v->vals[iDim] - center[iDim], 2.0);
    }

    // Strict comparison keeps the earliest cluster on ties
    if(iCluster == 0 || sqrDist < best.dist) {
      best.id = iCluster;
      best.dist = sqrDist;
    }
  }
  return best;
}

// Apply the clustering to an image color: every pixel color is replaced with
// the center of its nearest cluster.
// Input:
//   img: the image
// Output:
//   The image colors are replaced with the color of their cluster. Only the
//   first min(4, clusters' dimension) channels are compared and replaced, so
//   clusters of lower dimension (e.g. RGB only) never read past the end of
//   their centers. If there is no cluster (run not called, or called with
//   k == 0), the image is left unchanged.
static void ApplyToImgColor(CapyImg* const img) {
  methodOf(CapyKMeans);
  if(that->clusters == NULL || that->clusters->size == 0) return;

  // A pixel color has 4 channels
  size_t const nbChannel = (that->clusters->dim < 4) ? that->clusters->dim : 4;
  forEach(pixel, img->iter) {
    CapyVec v = {.dim = nbChannel, .vals = pixel.color->vals};
    CapyKMeansClusterOfPoint cluster = $(that, getClusterOfPoint)(&v);
    loop(iDim, nbChannel) {
      pixel.color->vals[iDim] = that->clusters->points[cluster.id].vals[iDim];
    }
  }
}

// Free the memory used by a CapyKMeans
static void Destruct(void) {
  methodOf(CapyKMeans);
  CapyPointCloudFree(&(that->clusters));
}

// Create a CapyKMeans
// Output:
//   Return a CapyKMeans
CapyKMeans CapyKMeansCreate(void) {
  return (CapyKMeans){
    .typeInit = capyKMeanInit_plusplus,
    .clusters = NULL,
    .destruct = Destruct,
    .run = Run,
    .getClusterOfPoint = GetClusterOfPoint,
    .applyToImgColor = ApplyToImgColor,
  };
}

// Allocate memory for a new CapyKMeans and create it
// Output:
//   Return a CapyKMeans, or NULL if the allocation failed
// Exception:
//   May raise CapyExc_MallocFailed.
CapyKMeans* CapyKMeansAlloc(void) {
  CapyKMeans* kmeans = NULL;
  safeMalloc(kmeans, 1);
  if(kmeans != NULL) *kmeans = CapyKMeansCreate();
  return kmeans;
}

// Free the memory used by a CapyKMeans* (its clusters and the structure
// itself) and reset '*that' to NULL
// Input:
//   that: a pointer to the CapyKMeans to free; NULL, or a pointer to NULL,
//         is accepted and does nothing
void CapyKMeansFree(CapyKMeans** const that) {
  if(that != NULL && *that != NULL) {
    $(*that, destruct)();
    free(*that);
    *that = NULL;
  }
}
