// ---------------------------- rsacipher.c ---------------------------
/*
    LibCapy - a general purpose library of C functions and data structures
    Copyright (C) 2021-2025 Pascal Baillehache info@baillehachepascal.dev
    https://baillehachepascal.dev
    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.
    You should have received a copy of the GNU General Public License
    along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "rsacipher.h"
#include "capymath.h"

// Get the public exponent for a given Carmichael totient
// Input:
//   lambda: the Carmichael totient of the modulus
// Output:
//   Return the first of 65537, 257, 17, 5, 3 which is strictly smaller than
//   lambda and coprime with lambda. If none matches, return 3; the caller is
//   then responsible for checking that gcd(e, lambda) == 1.
static uint64_t GetPublicExplonent(uint64_t const lambda) {
  static uint64_t const candidates[] = {65537, 257, 17, 5, 3};
  size_t const nbCandidate = sizeof(candidates) / sizeof(candidates[0]);
  for(size_t iCandidate = 0; iCandidate < nbCandidate; ++iCandidate) {
    uint64_t const e = candidates[iCandidate];

    // All candidates are prime, hence e and lambda are coprime if and only
    // if e does not divide lambda. The 'e < lambda' test is evaluated first
    // and guarantees the modulo is never computed for lambda == 0 issues
    // (e itself is never zero).
    if(e < lambda && (lambda % e) != 0) return e;
  }
  return 3;
}

// Generate keys
// Input:
//   p: the first prime number to be used for key generation
//   q: the second prime number to be used for key generation, must differ
//      from p
// Output:
//   'that->keys' is updated (modulus n, public exponent e, private exponent
//   d). The primality of p and q is not checked, it is the responsibility of
//   the caller.
// Exception:
//   May raise CapyExc_InvalidParameters if p or q is smaller than 2, if p
//   equals q, or if p*q overflows. May raise CapyExc_UndefinedExecution if
//   the public exponent is not coprime with the totient. In both cases
//   'that->keys' is reset to zero before raising.
static void GenerateKeys(
  uint64_t const p,
  uint64_t const q) {
  methodOf(CapyRSACipher);

  // Reject invalid inputs:
  // - p or q below 2 would divide by zero below or give a null totient;
  // - identical primes make the totient below wrong (lambda(p^2) is
  //   p*(p-1), not p-1), which would produce unusable keys;
  // - the modulus p*q must fit in 64 bits.
  // The keys are reset before raising so they are never left stale.
  if(p < 2 || q < 2 || p == q || UINT64_MAX / p < q) {
    that->keys = (CapyRSACipherKey){0};
    raiseExc(CapyExc_InvalidParameters);
    return;
  }

  // Get the modulus
  that->keys.n = p * q;

  // Get the Carmichael totient (valid for two distinct primes)
  uint64_t const lambda = CapyLcm(p - 1, q - 1);

  // Get the public exponent
  that->keys.e = GetPublicExplonent(lambda);

  // Get the modular multiplicative inverse of e modulo lambda from the
  // coefficient x of the GCD decomposition
  int64_t x = 0;
  int64_t y = 0;
  uint64_t const gcd =
    CapyGcdDecompositionUnsignedInput(that->keys.e, lambda, &x, &y);

  // The inverse only exists if e and lambda are coprime
  if(gcd != 1) {
    that->keys = (CapyRSACipherKey){0};
    raiseExc(CapyExc_UndefinedExecution);
    return;
  }

  // Reduce x into [0, lambda) to get a positive private exponent. The
  // magnitude of a negative x is computed as -(x+1)+1 to avoid overflowing
  // on INT64_MIN.
  uint64_t const absX =
    (x < 0 ? (uint64_t)(-(x + 1)) + 1 : (uint64_t)x) % lambda;
  that->keys.d = (x < 0 && absX != 0) ? lambda - absX : absX;
}

// Return the size of a plain text block such as its integer representation
// is always strictly smaller than the given modulus
// Input:
//   modulus: the modulus
// Output:
//   Return the size in bytes, in [0, 7]. Return 0 if the modulus is smaller
//   or equal to 256 (including a null modulus): no block size is then
//   usable and the caller must reject the modulus.
static size_t GetBlockSize(uint64_t const modulus) {

  // Search the smallest number of bytes 'size' (capped at 8) such as
  // 2^(8*size) >= modulus. One byte less is then guaranteed to encode values
  // strictly smaller than the modulus. As size < 8 inside the loop, the
  // shift never reaches the width of the type.
  size_t size = 1;
  while(size < 8 && (1ULL << (8 * size)) < modulus) size += 1;
  return size - 1;
}

// Cipher a single integer message with the public key
// Input:
//   message: the plain value, must be strictly smaller than the modulus
// Output:
//   Return message^e mod n, or 0 if the message is out of range.
// Exception:
//   May raise CapyExc_InvalidParameters if message >= n (which is always
//   the case while the keys have not been generated, n being 0).
static uint64_t Cipher(uint64_t const message) {
  methodOf(CapyRSACipher);
  if(message < that->keys.n) {
    return CapyPowMod(message, that->keys.e, that->keys.n);
  }

  // Out of range value: it can't be recovered after the modular reduction
  raiseExc(CapyExc_InvalidParameters);
  return 0;
}

// Decipher a single integer message with the private key
// Input:
//   message: the ciphered value, must be strictly smaller than the modulus
// Output:
//   Return message^d mod n, or 0 if the message is out of range.
// Exception:
//   May raise CapyExc_InvalidParameters if message >= n (which is always
//   the case while the keys have not been generated, n being 0).
static uint64_t Decipher(uint64_t const message) {
  methodOf(CapyRSACipher);
  if(message < that->keys.n) {
    return CapyPowMod(message, that->keys.d, that->keys.n);
  }

  // A valid ciphered value is always reduced modulo n
  raiseExc(CapyExc_InvalidParameters);
  return 0;
}

// Cipher a message divided into blocks of appropriate sizes
// Input:
//   message: the message to cipher
// Output:
//   Return the ciphered message. Each block of GetBlockSize(n) bytes of the
//   message (the last one possibly partial, completed with null bytes) is
//   ciphered into a block of GetBlockSize(n)+1 bytes. result.data is
//   allocated here and owned by the caller. An empty message gives an empty
//   result (data NULL, size 0).
// Exception:
//   May raise CapyExc_InvalidParameters if the keys are not usable for block
//   ciphering (modulus smaller or equal to 256, or keys not generated) or if
//   the size of the result overflows.
static CapyRSACipherData CipherBlock(CapyRSACipherData const message) {
  methodOf(CapyRSACipher);

  // Variable to memorise the result
  CapyRSACipherData result = {0};

  // Get the size of a block in the original message. It is 0 if the modulus
  // is too small or the keys haven't been generated: no block can be
  // ciphered, and it would lead to a division by zero below.
  size_t const blockSizeIn = GetBlockSize(that->keys.n);
  if(blockSizeIn == 0) {
    raiseExc(CapyExc_InvalidParameters);
    return result;
  }

  // Get the size of a block in the ciphered message
  size_t const blockSizeOut = blockSizeIn + 1;

  // Nothing to cipher in an empty message
  if(message.size == 0) return result;

  // Get the number of blocks, the last one may be partial
  size_t nbBlock = message.size / blockSizeIn;
  size_t const nbByteLeft = message.size - nbBlock * blockSizeIn;
  nbBlock += (nbByteLeft > 0 ? 1 : 0);

  // Ensure the size of the result is representable
  if(nbBlock > SIZE_MAX / blockSizeOut) {
    raiseExc(CapyExc_InvalidParameters);
    return result;
  }

  // Allocate memory for the result
  safeMalloc(result.data, nbBlock * blockSizeOut);
  if(result.data == NULL) return result;
  result.size = nbBlock * blockSizeOut;

  // Loop on the blocks
  loop(iBlock, nbBlock) {

    // Number of bytes available for this block (less than blockSizeIn only
    // for the last, partial, block)
    size_t const offset = iBlock * blockSizeIn;
    size_t const nbRemaining = message.size - offset;
    size_t const nbByte =
      (nbRemaining < blockSizeIn ? nbRemaining : blockSizeIn);

    // Get the input value for the block. The bytes must land in the least
    // significant end of 'in' to ensure in < 2^(8*blockSizeIn) < n, hence
    // the offset on big endian architectures (mirroring the deciphering).
    // Missing bytes of a partial block are left to zero.
    uint64_t in = 0;
    uint8_t* pIn = (uint8_t*)&in;
    if(CapyIsBigEndian()) pIn += sizeof(uint64_t) - blockSizeIn;
    loop(iByte, nbByte) {
      pIn[iByte] = message.data[offset + iByte];
    }

    // Cipher the input value
    uint64_t const out = CapyPowMod(in, that->keys.e, that->keys.n);

    // Write the ciphered value in the result, read from the least
    // significant end of 'out' (out < n < 2^(8*blockSizeOut))
    uint8_t const* pOut = (uint8_t const*)&out;
    if(CapyIsBigEndian()) pOut += sizeof(uint64_t) - blockSizeOut;
    loop(iByte, blockSizeOut) {
      result.data[iBlock * blockSizeOut + iByte] = pOut[iByte];
    }
  }

  // Return the result
  return result;
}

// Decipher a message divided into blocks of appropriate sizes
// Input:
//   message: the message to decipher, made of blocks of GetBlockSize(n)+1
//            bytes
// Output:
//   Return the deciphered message, made of blocks of GetBlockSize(n) bytes.
//   result.data is allocated here and owned by the caller. Trailing bytes of
//   'message' not forming a complete block are ignored. Any padding added
//   to the last block when ciphering is kept and must be trimmed by the
//   caller. An input without any complete block gives an empty result (data
//   NULL, size 0).
// Exception:
//   May raise CapyExc_InvalidParameters if the keys are not usable for block
//   deciphering (modulus smaller or equal to 256, or keys not generated).
static CapyRSACipherData DecipherBlock(CapyRSACipherData const message) {
  methodOf(CapyRSACipher);

  // Variable to memorise the result
  CapyRSACipherData result = {0};

  // Get the size of a block in the deciphered message. It is 0 if the
  // modulus is too small or the keys haven't been generated (n == 0, which
  // would make the modular exponentiation below meaningless).
  size_t const blockSizeOut = GetBlockSize(that->keys.n);
  if(blockSizeOut == 0) {
    raiseExc(CapyExc_InvalidParameters);
    return result;
  }

  // Get the size of a block in the ciphered message
  size_t const blockSizeIn = blockSizeOut + 1;

  // Get the number of complete blocks
  size_t const nbBlock = message.size / blockSizeIn;
  if(nbBlock == 0) return result;

  // Allocate memory for the result (smaller than message.size, no overflow)
  safeMalloc(result.data, nbBlock * blockSizeOut);
  if(result.data == NULL) return result;
  result.size = nbBlock * blockSizeOut;

  // Loop on the blocks
  loop(iBlock, nbBlock) {

    // Get the input value for the block, written in the least significant
    // end of 'in' (hence the offset on big endian architectures)
    uint64_t in = 0;
    uint8_t* pIn = (uint8_t*)&in;
    if(CapyIsBigEndian()) pIn += sizeof(uint64_t) - blockSizeIn;
    loop(iByte, blockSizeIn) {
      pIn[iByte] = message.data[iBlock * blockSizeIn + iByte];
    }

    // Decipher the input value
    uint64_t const out = CapyPowMod(in, that->keys.d, that->keys.n);

    // Write the deciphered value in the result, read from the least
    // significant end of 'out'
    uint8_t const* pOut = (uint8_t const*)&out;
    if(CapyIsBigEndian()) pOut += sizeof(uint64_t) - blockSizeOut;
    loop(iByte, blockSizeOut) {
      result.data[iBlock * blockSizeOut + iByte] = pOut[iByte];
    }
  }

  // Return the result
  return result;
}

// Free the memory used by a CapyRSACipher
static void Destruct(void) {
  methodOf(CapyRSACipher);
  that->keys = (CapyRSACipherKey){0};
}

// Create a CapyRSACipher
// Output:
//   Return a CapyRSACipher
CapyRSACipher CapyRSACipherCreate(void) {
  CapyRSACipher that = {
    .destruct = Destruct,
    .generateKeys = GenerateKeys,
    .cipher = Cipher,
    .decipher = Decipher,
    .cipherBlock = CipherBlock,
    .decipherBlock = DecipherBlock,
  };
  return that;
}

// Allocate memory for a new CapyRSACipher and create it
// Output:
//   Return a pointer to the new CapyRSACipher (to be released with
//   CapyRSACipherFree), or NULL if the allocation failed.
// Exception:
//   May raise CapyExc_MallocFailed.
CapyRSACipher* CapyRSACipherAlloc(void) {
  CapyRSACipher* cipher = NULL;
  safeMalloc(cipher, 1);
  if(cipher != NULL) *cipher = CapyRSACipherCreate();
  return cipher;
}

// Free the memory used by a CapyRSACipher* and reset '*that' to NULL
// Input:
//   that: a pointer to the CapyRSACipher to free. Does nothing if 'that' or
//         '*that' is NULL.
void CapyRSACipherFree(CapyRSACipher** const that) {
  if(that != NULL && *that != NULL) {

    // Wipe the keys before releasing the memory
    $(*that, destruct)();
    free(*that);
    *that = NULL;
  }
}
