#include <assert.h>
#include <inttypes.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/*
 * Textbook RSA on 64-bit integers (no padding, small moduli), exercised by a
 * one-byte-per-value demo and a block-packing demo. Illustrative only.
 */

/**
 * Greatest common divisor using Stein's binary algorithm.
 * GCD(0, b) == b and GCD(a, 0) == a. Zero operands are handled up front
 * because the halving loops below never terminate on 0.
 */
uint64_t GCD(
  uint64_t a,
  uint64_t b) {
  if(a == 0) {
    return b;
  }
  if(b == 0) {
    return a;
  }
  // Strip the power of two shared by both operands; restored on return.
  uint8_t shift = 0;
  while(((a | b) & 1) == 0) {
    ++shift;
    a >>= 1;
    b >>= 1;
  }
  // From here on a is odd.
  while((a & 1) == 0) {
    a >>= 1;
  }
  do {
    while((b & 1) == 0) {
      b >>= 1;
    }
    // Keep a <= b so the subtraction below cannot wrap.
    if(a > b) {
      uint64_t const t = b;
      b = a;
      a = t;
    }
    b -= a;
  } while(b != 0);
  return a << shift;
}

/**
 * Least common multiple of a and b; returns 0 if either is 0.
 * Divides by the GCD before multiplying so the intermediate value never
 * exceeds the result.
 */
uint64_t LCM(
  uint64_t const a,
  uint64_t const b) {
  if(a == 0 || b == 0) {
    return 0;
  }
  return (a / GCD(a, b)) * b;
}

/**
 * Picks a public exponent e with e < lambda and GCD(e, lambda) == 1, so that
 * e is invertible modulo lambda.
 * Tries 65537, 257, 17, 5, 3 first. Otherwise it returns the smallest
 * coprime odd value >= 7 below lambda, and finally 3. The final 3 may be
 * invalid; GenerateKeys asserts invertibility.
 */
uint64_t GetPublicExplonent(uint64_t const lambda) {
  static uint64_t const candidates[] = {65537, 257, 17, 5, 3};
  size_t const nbCandidates = sizeof candidates / sizeof candidates[0];
  for(size_t iCandidate = 0; iCandidate < nbCandidates; iCandidate += 1) {
    uint64_t const e = candidates[iCandidate];
    if(e < lambda && GCD(e, lambda) == 1) {
      return e;
    }
  }
  // 3 and 5 were already rejected above; continue with the next odd values.
  for(uint64_t e = 7; e < lambda; e += 2) {
    if(GCD(e, lambda) == 1) {
      return e;
    }
  }
  return 3;
}

/**
 * Extended Euclidean algorithm.
 * Returns g = gcd(a, b) and stores in *x and *y coefficients satisfying
 * a * (*x) + b * (*y) == g. Quotients are cast to int64_t, so a and b are
 * expected to be below 2^63.
 */
uint64_t GcdDecompositionUnsignedInput(
  uint64_t a,
  uint64_t b,
  int64_t* x,
  int64_t* y) {
  // Each row (xk, yk, ak) maintains ak == xk * a + yk * b.
  int64_t x1 = 1;
  int64_t y1 = 0;
  uint64_t a1 = a;
  int64_t x0 = 0;
  int64_t y0 = 1;
  uint64_t a2 = b;
  uint64_t q = 0;
  while (a2 != 0) {
    int64_t x2 = x0 - ((int64_t)q) * x1;
    int64_t y2 = y0 - ((int64_t)q) * y1;
    x0 = x1;
    y0 = y1;
    uint64_t a0 = a1;
    x1 = x2;
    y1 = y2;
    a1 = a2;
    q = a0 / a1;
    a2 = a0 - q * a1;
  }
  *x = x1;
  *y = y1;
  return a1;
}

/** RSA key material: modulus n, public exponent e, private exponent d. */
typedef struct RSACipherKey {
  uint64_t n;
  uint64_t e;
  uint64_t d;
} RSACipherKey;

/**
 * Builds an RSA key from primes p and q (primality is not checked).
 * n = p * q; lambda = lcm(p - 1, q - 1); e from GetPublicExplonent;
 * d = e^-1 mod lambda, reduced into [0, lambda).
 */
RSACipherKey GenerateKeys(
  uint64_t const p,
  uint64_t const q) {
  assert(p > 1 && q > 1);
  assert(UINT64_MAX / p >= q);  // n = p * q must fit in 64 bits
  RSACipherKey keys = {0};
  keys.n = p * q;
  uint64_t const lambda = LCM(p - 1, q - 1);
  keys.e = GetPublicExplonent(lambda);
  int64_t x = 0;
  int64_t y = 0;
  uint64_t const gcd = GcdDecompositionUnsignedInput(keys.e, lambda, &x, &y);
  assert(gcd == 1);
  (void)gcd;  // only read by assert(); avoids an unused warning under NDEBUG
  // e * x + lambda * y == 1, so x is the inverse of e modulo lambda.
  if(x >= 0) {
    keys.d = (uint64_t)x;
  } else {
    keys.d = lambda - (uint64_t)(-x);
  }
  return keys;
}

/**
 * Returns (x + y) mod m for x, y < m without ever overflowing uint64_t.
 */
static uint64_t AddMod(
  uint64_t const x,
  uint64_t const y,
  uint64_t const m) {
  // x + y >= m  <=>  x >= m - y  (m - y > 0 because y < m).
  return (x >= m - y) ? x - (m - y) : x + y;
}

/**
 * Returns (a * b) mod c by double-and-add, valid for any modulus
 * 0 < c <= UINT64_MAX (no 128-bit intermediate needed).
 */
uint64_t MultMod(
  uint64_t a,
  uint64_t b,
  uint64_t c) {
  assert(c != 0);
  uint64_t res = 0;
  a = a % c;
  while (b > 0) {
    if (b & 1) {
      res = AddMod(res, a, c);
    }
    a = AddMod(a, a, c);
    b >>= 1;
  }
  return res;
}

/**
 * Returns a^b mod c using left-to-right square-and-multiply.
 * a^0 mod c is 1 % c. Requires c != 0.
 */
uint64_t PowMod(
  uint64_t a,
  uint64_t b,
  uint64_t c) {
  assert(c != 0);
  if(b == 0) {
    // The leading-bit search below would never terminate for b == 0.
    return 1 % c;
  }
  uint64_t res = 1;
  uint64_t mask = UINT64_C(1) << 63;
  // Skip the exponent's leading zero bits.
  while(!(b & mask)) {
    mask >>= 1;
  }
  while(mask) {
    res = MultMod(res, res, c);
    if(b & mask) {
      res = MultMod(res, a, c);
    }
    mask >>= 1;
  }
  return res;
}

/** Encrypts one value (message < key.n) with the public exponent. */
uint64_t Cipher(
  uint64_t const message,
  RSACipherKey const key) {
  assert(message < key.n);
  return PowMod(message, key.e, key.n);
}

/** Decrypts one value (message < key.n) with the private exponent. */
uint64_t Decipher(
  uint64_t const message,
  RSACipherKey const key) {
  assert(message < key.n);
  return PowMod(message, key.d, key.n);
}

/**
 * calloc wrapper that never returns NULL: aborts on allocation failure.
 * Always requests at least one element, since calloc(0, ...) may return NULL.
 */
static void* XCalloc(
  size_t const count,
  size_t const size) {
  void* const ptr = calloc(count > 0 ? count : 1, size > 0 ? size : 1);
  if(ptr == NULL) {
    fprintf(stderr, "out of memory\n");
    abort();
  }
  return ptr;
}

/**
 * Demo: encrypts each byte of `message` separately, decrypts it back, asserts
 * the round-trip, and prints plaintext, ciphertext bytes and decrypted bytes.
 */
void TestCipherDecipher(
  char const* const message,
  uint64_t const p,
  uint64_t const q) {
  printf("Test Cipher/Decipher (p=%" PRIu64 ", q=%" PRIu64 ") \"%s\"\n", p, q, message);
  RSACipherKey const key = GenerateKeys(p, q);
  size_t const len = strlen(message);
  // Heap buffers: a zero-length VLA (empty message) is undefined behaviour.
  uint64_t* const cipheredMessage = XCalloc(len, sizeof *cipheredMessage);
  uint8_t* const decipheredMessage = XCalloc(len, sizeof *decipheredMessage);
  for(size_t i = 0; i < len; i += 1) {
    // Via unsigned char so bytes >= 0x80 are not sign-extended beyond n.
    uint8_t const byte = (uint8_t)(unsigned char)message[i];
    cipheredMessage[i] = Cipher(byte, key);
    decipheredMessage[i] = (uint8_t)Decipher(cipheredMessage[i], key);
    assert(decipheredMessage[i] == byte);
  }
  printf("message: ");
  for(size_t i = 0; i < len; i += 1) {
    printf("%u ", (unsigned)(unsigned char)message[i]);
  }
  printf("\n");
  printf("cipher: ");
  // Raw in-memory bytes of each 64-bit ciphertext (host byte order).
  uint8_t const* const cipherBytes = (uint8_t const*)cipheredMessage;
  for(size_t i = 0; i < len * sizeof *cipheredMessage; i += 1) {
    printf("%u ", (unsigned)cipherBytes[i]);
  }
  printf("\n");
  printf("decipher: ");
  for(size_t i = 0; i < len; i += 1) {
    printf("%u ", (unsigned)decipheredMessage[i]);
  }
  printf("\n");
  free(cipheredMessage);
  free(decipheredMessage);
}

/**
 * Plaintext bytes per block for `modulus`: the largest k (capped at 7) with
 * 256^k < modulus, so every k-byte value is below the modulus. A ciphertext
 * block needs k + 1 bytes. Returns 0 when modulus <= 256.
 */
size_t GetBlockSize(uint64_t const modulus) {
  uint64_t size = 1;
  while(size < 8 && (1ULL << (8 * size)) < modulus) {
    size += 1;
  }
  assert(size <= 8);
  return (size_t)(size - 1);
}

/**
 * True when the host stores multi-byte integers most-significant byte first.
 * The block functions below do not depend on it (they serialize with shifts).
 */
bool IsBigEndian(void) {
  uint16_t const a = 0x0100;
  return (*((uint8_t*)&a) == 1);
}

/**
 * A byte buffer. Buffers returned by CipherBlock/DecipherBlock are
 * heap-allocated and must be released by the caller with free(data).
 */
typedef struct RSACipherData {
  size_t size;
  uint8_t* data;
} RSACipherData;

/** Reads `count` (<= 8) bytes as a little-endian unsigned integer. */
static uint64_t LoadLittleEndian(
  uint8_t const* const bytes,
  size_t const count) {
  uint64_t value = 0;
  for(size_t i = 0; i < count; i += 1) {
    value |= (uint64_t)bytes[i] << (8 * i);
  }
  return value;
}

/** Writes the low `count` (<= 8) bytes of `value` in little-endian order. */
static void StoreLittleEndian(
  uint64_t const value,
  uint8_t* const bytes,
  size_t const count) {
  for(size_t i = 0; i < count; i += 1) {
    bytes[i] = (uint8_t)(value >> (8 * i));
  }
}

/**
 * Encrypts `message` in blocks of GetBlockSize(keys.n) bytes.
 * Each block is read as a little-endian integer (the last block is
 * zero-padded), encrypted, and written as GetBlockSize(keys.n) + 1
 * little-endian bytes. The byte format is the same on every host.
 * The caller frees result.data.
 */
RSACipherData CipherBlock(
  RSACipherData const message,
  RSACipherKey const keys) {
  RSACipherData result = {0};
  size_t const blockSizeIn = GetBlockSize(keys.n);
  assert(blockSizeIn > 0 && "modulus too small for block ciphering");
  if(blockSizeIn == 0) {
    return result;  // avoid dividing by zero when asserts are disabled
  }
  size_t const blockSizeOut = blockSizeIn + 1;
  size_t const nbByteLeft = message.size % blockSizeIn;
  size_t const nbBlock = message.size / blockSizeIn + (nbByteLeft > 0 ? 1 : 0);
  result.data = XCalloc(nbBlock, blockSizeOut);
  result.size = nbBlock * blockSizeOut;
  for(size_t iBlock = 0; iBlock < nbBlock; iBlock += 1) {
    size_t const start = iBlock * blockSizeIn;
    size_t const remaining = message.size - start;
    size_t const count = remaining < blockSizeIn ? remaining : blockSizeIn;
    uint64_t const inp = LoadLittleEndian(message.data + start, count);
    uint64_t const out = PowMod(inp, keys.e, keys.n);
    StoreLittleEndian(out, result.data + iBlock * blockSizeOut, blockSizeOut);
  }
  return result;
}

/**
 * Inverse of CipherBlock: decrypts blocks of GetBlockSize(keys.n) + 1
 * little-endian bytes into GetBlockSize(keys.n) bytes each. A trailing
 * partial input block is ignored. The output keeps any zero padding added
 * to the last block. The caller frees result.data.
 */
RSACipherData DecipherBlock(
  RSACipherData const message,
  RSACipherKey const keys) {
  RSACipherData result = {0};
  size_t const blockSizeOut = GetBlockSize(keys.n);
  assert(blockSizeOut > 0 && "modulus too small for block ciphering");
  if(blockSizeOut == 0) {
    return result;
  }
  size_t const blockSizeIn = blockSizeOut + 1;
  size_t const nbBlock = message.size / blockSizeIn;
  result.data = XCalloc(nbBlock, blockSizeOut);
  result.size = nbBlock * blockSizeOut;
  for(size_t iBlock = 0; iBlock < nbBlock; iBlock += 1) {
    uint64_t const inp = LoadLittleEndian(message.data + iBlock * blockSizeIn, blockSizeIn);
    uint64_t const out = PowMod(inp, keys.d, keys.n);
    StoreLittleEndian(out, result.data + iBlock * blockSizeOut, blockSizeOut);
  }
  return result;
}

/**
 * Demo: block-encrypts `str`, decrypts it back, prints all three buffers,
 * and asserts the plaintext prefix round-trips.
 */
void TestCipherDecipherBlock(
  char const* const str,
  uint64_t const p,
  uint64_t const q) {
  printf("Test Cipher/Decipher block (p=%" PRIu64 ", q=%" PRIu64 ") \"%s\"\n", p, q, str);
  RSACipherKey const keys = GenerateKeys(p, q);
  // CipherBlock only reads message.data; the cast merely fits the struct.
  RSACipherData const message = {
    .data = (uint8_t*)str,
    .size = strlen(str),
  };
  RSACipherData cipheredMessage = CipherBlock(message, keys);
  RSACipherData decipheredMessage = DecipherBlock(cipheredMessage, keys);
  printf("message: ");
  for(size_t i = 0; i < message.size; i += 1) {
    printf("%u ", (unsigned)message.data[i]);
  }
  printf("\n");
  printf("cipher: ");
  for(size_t i = 0; i < cipheredMessage.size; i += 1) {
    printf("%u ", (unsigned)cipheredMessage.data[i]);
  }
  printf("\n");
  printf("decipher: ");
  for(size_t i = 0; i < decipheredMessage.size; i += 1) {
    printf("%u ", (unsigned)decipheredMessage.data[i]);
  }
  printf("\n");
  // The deciphered buffer may be longer because of last-block zero padding.
  assert(decipheredMessage.size >= message.size);
  assert(memcmp(message.data, decipheredMessage.data, message.size) == 0);
  free(cipheredMessage.data);
  free(decipheredMessage.data);
}

int main(void) {
  char const* const str = "Hello world!";
  TestCipherDecipher(str, 61, 53);
  TestCipherDecipher(str, 11922649, 74112287);
  TestCipherDecipherBlock(str, 61, 53);
  TestCipherDecipherBlock(str, 11922649, 74112287);
  return 0;
}