Commit 08d5c77d authored by Michael Meyer
Browse files

reference implementation by Castryck et al

parent c00c37e0
# Default target: build the optimised benchmark binary `main` from the C
# sources and the two assembly files (u512.s, fp.s).
# The leading `@` keeps make from echoing the compiler command line.
all:
@cc \
-std=c99 -pedantic \
-Wall -Wextra \
-O2 -funroll-loops \
rng.c \
u512.s fp.s \
mont.c \
csidh.c \
main.c \
-o main
# Debug build: same inputs and the same output name (`main`) as `all`, but
# compiled with -g and without the -O2 -funroll-loops optimisation flags.
# The command is echoed (no `@` prefix).
debug:
cc \
-std=c99 -pedantic \
-Wall -Wextra \
-g \
rng.c \
u512.s fp.s \
mont.c \
csidh.c \
main.c \
-o main
# Remove the built binary; -f makes this succeed even if `main` does not exist.
clean:
rm -f main
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <time.h>
#include <assert.h>
#include "u512.h"
#include "fp.h"
#include "mont.h"
#include "csidh.h"
#include <inttypes.h>
/* Read the x86 time-stamp counter.
 * The RDTSC instruction delivers the low 32 bits in EAX and the high 32 bits
 * in EDX; the two halves are glued together into one 64-bit tick count. */
static __inline__ uint64_t rdtsc(void)
{
    uint32_t low_half, high_half;

    __asm__ __volatile__ ("rdtsc" : "=a"(low_half), "=d"(high_half));

    const uint64_t upper = (uint64_t) high_half << 32;
    return upper | low_half;
}
/* Number of timed iterations of the benchmark loop. */
unsigned long its = 10000;

/*
 * Benchmark driver.
 *
 * Starting from the public key `base`, each iteration draws a fresh private
 * key (outside the timed region) and then times one public-key validation
 * followed by one group action applied in place to `pub`.  Both the
 * time-stamp counter and clock() are sampled around the timed region, and
 * the per-iteration averages are printed at the end.
 *
 * Returns 0 on completion; aborts through assert() if a key fails validation
 * (unless NDEBUG is defined).
 */
int main(void)
{
    /* named cpu_ticks rather than `time` so that it does not shadow time() */
    clock_t t0, t1, cpu_ticks = 0;
    uint64_t c0, c1, cycles = 0;
    private_key priv;
    public_key pub = base;

    for (unsigned long i = 0; i < its; ++i) {

        csidh_private(&priv);

        t0 = clock();
        c0 = rdtsc();

        /**************************************/

        /* Call validate() outside of assert(): with NDEBUG defined an
         * assert() argument is not evaluated at all, which would silently
         * remove the validation from the measured work. */
        const int valid = validate(&pub);
        assert(valid);
        (void) valid; /* keeps NDEBUG builds free of unused-variable warnings */

        action(&pub, &pub, &priv);

        /**************************************/

        c1 = rdtsc();
        t1 = clock();

        cycles += c1 - c0;
        cpu_ticks += t1 - t0;
    }

    printf("iterations: %lu\n", its);
    printf("clock cycles: %" PRIu64 "\n", (uint64_t) cycles / its);
    /* NOTE: clock() reports processor time; the "wall-clock" label is kept
     * unchanged so that the program's output format stays the same. */
    printf("wall-clock time: %.3lf ms\n", 1000. * cpu_ticks / CLOCKS_PER_SEC / its);

    return 0;
}
#include <string.h>
#include <assert.h>
#include "csidh.h"
#include "rng.h"
/* specific to p, should perhaps be somewhere else */
/* The num_primes small odd primes of this parameter set, in increasing order.
 * The list is 3..373 followed by 587 (it is not a run of consecutive primes).
 * Elsewhere in this file these values are multiplied together to form scalar
 * cofactors (on top of a factor 4, described there as the maximal 2-power in
 * p+1) and are passed to xISOG as isogeny degrees. */
const unsigned primes[num_primes] = {
3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59,
61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137,
139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227,
229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313,
317, 331, 337, 347, 349, 353, 359, 367, 373, 587,
};
/* Upper bound 4*floor(sqrt(p)) + 3 >= floor(4*sqrt(p)), least-significant
 * 64-bit limb first; limbs not listed are zero.
 *
 * validate() declares a curve supersingular once the proven point order
 * exceeds this bound.  The value stored here previously,
 *   0xa15bc4e676475964 850ae5507965a0b3 4e3433657e18da95 85e2579c786882cf,
 * is floor(sqrt(p)) (p is about 0x65b4... * 2^448, whose square root is
 * about 0xa15b... * 2^192), i.e. a factor 4 too small.  An order that is
 * merely larger than sqrt(p) can have several multiples inside the Hasse
 * interval of width 4*sqrt(p), so it does not force the group order to be
 * p+1.
 *
 * The bound needs 258 bits, hence the fifth limb.
 * NOTE(review): assumes u512 holds eight 64-bit limbs, as the 8-quad
 * constants in fp.s suggest. */
const u512 four_sqrt_p = {{
    0x17895e71e1a20b3f, 0x38d0cd95f8636a56, 0x142b9541e59682cd, 0x856f1399d91d6592,
    0x0000000000000002,
}};
/* Public key of the starting curve: every limb of the coefficient A is zero,
 * which is the same bit pattern as fp_0. */
const public_key base = {0}; /* A = 0 */
/*
 * Fill `priv` with num_primes random exponents, each in
 * [-max_exponent, max_exponent].
 *
 * Random bytes are drawn 64 at a time and interpreted as signed 8-bit
 * values; a byte outside the allowed range is discarded (rejection
 * sampling).  Accepted values are stored as 4-bit two's-complement nibbles,
 * two per byte of priv->e: entry n goes to byte n/2, low nibble when n is
 * even and high nibble when n is odd.
 */
void csidh_private(private_key *priv)
{
    size_t filled = 0;

    memset(&priv->e, 0, sizeof(priv->e));

    while (filled < num_primes) {

        int8_t pool[64];
        randombytes(pool, sizeof(pool));

        for (size_t k = 0; k < sizeof(pool) && filled < num_primes; ++k) {

            const int8_t cand = pool[k];
            if (cand > max_exponent || cand < -max_exponent)
                continue; /* out of range: try the next byte */

            const unsigned shift = filled % 2 * 4; /* 0 or 4 */
            priv->e[filled / 2] |= (cand & 0xf) << shift;
            ++filled;
        }
    }
}
/*
 * Compute [(p+1)/l] P for all l in our list of primes, by divide and
 * conquer (much faster than doing it naively, at the cost of more memory).
 *
 * On entry P[lower] holds the point for the index range [lower, upper).
 * The range is split at `mid`; P[mid] receives P[lower] multiplied by the
 * primes of the first half, P[lower] is multiplied by the primes of the
 * second half, and both halves are processed recursively.  A range of a
 * single index is left untouched.
 */
static void cofactor_multiples(proj *P, const proj *A, size_t lower, size_t upper)
{
    assert(lower < upper);

    const size_t width = upper - lower;
    if (width == 1)
        return;

    const size_t mid = lower + (width + 1) / 2;

    u512 front = u512_1; /* product of primes[lower .. mid) */
    u512 back = u512_1;  /* product of primes[mid .. upper) */

    for (size_t i = lower; i < upper; ++i) {
        u512 *acc = i < mid ? &front : &back;
        u512_mul3_64(acc, acc, primes[i]);
    }

    /* P[mid] has to be produced first: it reads P[lower] before the second
     * multiplication overwrites it. */
    xMUL(&P[mid], A, &P[lower], &front);
    xMUL(&P[lower], A, &P[lower], &back);

    cofactor_multiples(P, A, lower, mid);
    cofactor_multiples(P, A, mid, upper);
}
/*
 * Public-key validation; intended never to accept an invalid key.
 *
 * Repeatedly picks a random x-coordinate, doubles the point twice (the
 * maximal 2-power in p+1) and computes all cofactor multiples.  Walking the
 * primes from last to first, every multiple that is not the point at
 * infinity is multiplied by its prime:
 *   - if the result is still not infinity, the point's order does not
 *     divide p+1 and the key is rejected (returns false);
 *   - otherwise the prime is multiplied into the proven order, and as soon
 *     as that order exceeds four_sqrt_p the key is accepted (returns true).
 * A point whose proven order stays too small is discarded and a fresh random
 * point is tried.
 */
bool validate(public_key const *in)
{
    const proj A = {in->A, fp_1};

    for (;;) {

        proj P[num_primes];
        fp_random(&P[0].x);
        P[0].z = fp_1;

        /* maximal 2-power in p+1 */
        xDBL(P, &A, P);
        xDBL(P, &A, P);

        cofactor_multiples(P, &A, 0, num_primes);

        u512 order = u512_1;

        for (size_t n = num_primes; n-- > 0; ) {

            proj *Q = &P[n];

            /* a multiple that is already infinity tells us nothing */
            if (!memcmp(&Q->z, &fp_0, sizeof(fp)))
                continue;

            u512 scratch;
            u512_set(&scratch, primes[n]);
            xMUL(Q, &A, Q, &scratch);

            /* still not infinity: order does not divide p+1 */
            if (memcmp(&Q->z, &fp_0, sizeof(fp)))
                return false;

            u512_mul3_64(&order, &order, primes[n]);

            /* u512_sub3 returns the borrow, i.e. non-zero iff order > bound */
            if (u512_sub3(&scratch, &four_sqrt_p, &order))
                return true;
        }

        /* this point's order was too small to decide; try another one */
    }
}
/*
 * Evaluate the right-hand side of the Montgomery curve equation,
 * rhs = x^3 + A*x^2 + x, computed as (x^2 + A*x + 1) * x.
 * `rhs` must be a different object from `A` and `x`.
 */
static void montgomery_rhs(fp *rhs, fp const *A, fp const *x)
{
    fp ax;
    fp_mul3(&ax, A, x);     /* ax  = A*x            */

    *rhs = *x;
    fp_sq1(rhs);            /* rhs = x^2            */
    fp_add2(rhs, &ax);      /* rhs = x^2 + A*x      */
    fp_add2(rhs, &fp_1);    /* rhs = x^2 + A*x + 1  */
    fp_mul2(rhs, x);        /* rhs = x^3 + A*x^2 + x */
}
/*
 * Apply the private key `priv` to the curve `in` and store the resulting
 * Montgomery coefficient in `out` (`out` may be the same object as `in`).
 *
 * Totally not constant-time.
 *
 * The packed 4-bit exponents are split by sign into e[0] (positive) and
 * e[1] (negative).  k[s] starts at 4 (the maximal 2-power in p+1) and
 * collects every prime that has no work left on side s.  Each round picks a
 * random x; `sign` is 0 when x^3 + A*x^2 + x is a square and 1 otherwise.
 * The point is multiplied by k[sign], and for every prime with remaining
 * work on that side a kernel point is derived and, if it is not the point at
 * infinity, one isogeny of that degree is applied.  The curve is then
 * normalised back to z = 1.  The loop ends once both sides are done.
 */
void action(public_key *out, public_key const *in, private_key const *priv)
{
    u512 k[2];
    u512_set(&k[0], 4); /* maximal 2-power in p+1 */
    u512_set(&k[1], 4); /* maximal 2-power in p+1 */

    /* e[0][i] / e[1][i]: isogeny steps still owed for primes[i] in the
     * positive / negative direction */
    uint8_t e[2][num_primes];

    for (size_t i = 0; i < num_primes; ++i) {

        /* Unpack entry i as a sign-extended 4-bit value: high nibble of byte
         * i/2 for even i, low nibble for odd i.  This is done on unsigned
         * values; the earlier expression
         *     (int8_t) (priv->e[i / 2] << i % 2 * 4) >> 4
         * left-shifted a possibly negative int, which is undefined
         * behaviour, and depended on implementation-defined narrowing and
         * right shift.  The decoded values are the same as before.
         * NOTE(review): csidh_private() writes entry i to the opposite
         * nibble (low for even i).  With an even num_primes that only swaps
         * the two exponents inside each byte, so the nibble order is kept
         * here to stay compatible with existing keys. */
        const uint8_t packed = (uint8_t) priv->e[i / 2];
        const uint8_t nibble = i % 2 ? (packed & 0xf) : (packed >> 4);
        const int8_t t = (int8_t) (nibble < 8 ? nibble : nibble - 16);

        if (t > 0) {
            e[0][i] = t;
            e[1][i] = 0;
            u512_mul3_64(&k[1], &k[1], primes[i]);
        }
        else if (t < 0) {
            e[1][i] = -t;
            e[0][i] = 0;
            u512_mul3_64(&k[0], &k[0], primes[i]);
        }
        else {
            e[0][i] = 0;
            e[1][i] = 0;
            u512_mul3_64(&k[0], &k[0], primes[i]);
            u512_mul3_64(&k[1], &k[1], primes[i]);
        }
    }

    proj A = {in->A, fp_1};

    bool done[2] = {false, false};

    do {

        /* every round starts from a normalised curve (z = 1) */
        assert(!memcmp(&A.z, &fp_1, sizeof(fp)));

        proj P;
        fp_random(&P.x);
        P.z = fp_1;

        fp rhs;
        montgomery_rhs(&rhs, &A.x, &P.x);
        bool sign = !fp_issquare(&rhs);

        /* nothing left to do on this side: draw another point */
        if (done[sign])
            continue;

        /* clear the 2-part and all primes not needed on this side */
        xMUL(&P, &A, &P, &k[sign]);

        /* optimistic; cleared again below if any exponent remains */
        done[sign] = true;

        for (size_t i = 0; i < num_primes; ++i) {

            if (e[sign][i]) {

                /* product of the later primes that still have work */
                u512 cof = u512_1;
                for (size_t j = i + 1; j < num_primes; ++j)
                    if (e[sign][j])
                        u512_mul3_64(&cof, &cof, primes[j]);

                proj K;
                xMUL(&K, &A, &P, &cof);

                /* K at infinity: skip this prime for this round */
                if (memcmp(&K.z, &fp_0, sizeof(fp))) {

                    xISOG(&A, &P, &K, primes[i]);

                    /* exponent exhausted: primes[i] joins k[sign] */
                    if (!--e[sign][i])
                        u512_mul3_64(&k[sign], &k[sign], primes[i]);
                }
            }

            done[sign] &= !e[sign][i];
        }

        /* normalise A back to z = 1 */
        fp_inv(&A.z);
        fp_mul2(&A.x, &A.z);
        A.z = fp_1;

    } while (!(done[0] && done[1]));

    out->A = A.x;
}
/*
 * Group action including public-key validation.
 *
 * When `in` passes validate(), the action of `priv` on `in` is written to
 * `out` and true is returned.  Otherwise out->A is filled by fp_random()
 * and false is returned.
 */
bool csidh(public_key *out, public_key const *in, private_key const *priv)
{
    const bool acceptable = validate(in);

    if (acceptable)
        action(out, in, priv);
    else
        fp_random(&out->A);

    return acceptable;
}
#ifndef CSIDH_H
#define CSIDH_H

#include "u512.h"
#include "fp.h"
#include "mont.h"

/* specific to p, should perhaps be somewhere else */
#define num_primes 74
#define max_exponent 5 /* (2*5+1)^74 is roughly 2^256 */

/* Private key: one exponent per prime, each stored as a 4-bit
 * two's-complement value, two per byte. */
typedef struct private_key {
    int8_t e[(num_primes + 1) / 2]; /* packed int4_t */
} private_key;

/* Public key: a single curve coefficient. */
typedef struct public_key {
    fp A; /* Montgomery coefficient: represents y^2 = x^3 + Ax^2 + x */
} public_key;

/* Starting curve (A = 0). */
extern const public_key base;

/* Fill *priv with random exponents in [-max_exponent, max_exponent]. */
void csidh_private(private_key *priv);

/* Public-key validation: returns true when *in is accepted.
 * Declared here because main.c calls it directly; without a prototype the
 * call is an implicit declaration, which is invalid in C99 and treats the
 * bool result as int. */
bool validate(public_key const *in);

/* Apply *priv to *in without validating it; the result is written to *out.
 * Not constant-time.  Declared for the same reason as validate(). */
void action(public_key *out, public_key const *in, private_key const *priv);

/* validate() followed by action(); returns false, with a random *out, when
 * validation fails. */
bool csidh(public_key *out, public_key const *in, private_key const *priv);

#endif
#ifndef FP_H
#define FP_H

#include "u512.h"

/*
 * Field element, held in the Montgomery domain.  Reading the limbs as a
 * plain integer should never make sense, so the value is wrapped in its own
 * struct: the compiler then warns when u512 and fp are mixed up.
 */
typedef struct fp {
    u512 x;
} fp;

/* ---- constants --------------------------------------------------------- */

extern const fp fp_0;   /* all limbs zero */
extern const fp fp_1;   /* 2^512 mod p, i.e. the value one in Montgomery form */

/* ---- conversion -------------------------------------------------------- */

/* store the integer y in *x and convert it to Montgomery representation */
void fp_set(fp *x, uint64_t y);
/* encode to Montgomery representation */
void fp_enc(fp *x, const u512 *y);
/* decode from Montgomery representation */
void fp_dec(u512 *x, const fp *y);

/* ---- data movement ----------------------------------------------------- */

/* exchange *x and *y when c is true; the choice is made with a mask, not a
 * branch */
void fp_cswap(fp *x, fp *y, bool c);

/* ---- arithmetic -------------------------------------------------------- */
/* Suffix = number of operands: the "3" forms write y (op) z to x, the "2"
 * forms update x in place with y. */

void fp_add3(fp *x, const fp *y, const fp *z);  /* x = y + z */
void fp_add2(fp *x, const fp *y);               /* x = x + y */

void fp_sub3(fp *x, const fp *y, const fp *z);  /* x = y - z */
void fp_sub2(fp *x, const fp *y);               /* x = x - y */

void fp_mul3(fp *x, const fp *y, const fp *z);  /* x = y * z */
void fp_mul2(fp *x, const fp *y);               /* x = x * y */

void fp_sq2(fp *x, const fp *y);                /* presumably x = y^2 */
void fp_sq1(fp *x);                             /* x = x^2 */

/* in-place inversion (callers multiply by the result to divide) */
void fp_inv(fp *x);

/* ---- predicates and sampling ------------------------------------------- */

/* true when *x is a square in the field */
bool fp_issquare(const fp *x);

/* overwrite *x with a random field element */
void fp_random(fp *x);

#endif
/* Field constants.  Every multi-limb value below is stored as eight 64-bit
 * quads, least-significant limb first (the carry chains in this file start
 * at offset 0). */
.intel_syntax noprefix
.section .rodata
.set pbits, 511
/* the field modulus p (pbits = 511 bits) */
p:
.quad 0x1b81b90533c6c87b, 0xc2721bf457aca835, 0x516730cc1f0b4f25, 0xa7aac6c567f35507
.quad 0x5afbfcc69322c9cd, 0xb42d083aedc88c42, 0xfc8ab0d15e3e4c4a, 0x65b48e8f740f89bf
/* zero; exported to C as fp_0 */
.global fp_0
fp_0: .quad 0, 0, 0, 0, 0, 0, 0, 0
/* one in Montgomery form; exported to C as fp_1 */
.global fp_1
fp_1: /* 2^512 mod p */
.quad 0xc8fc8df598726f0a, 0x7b1bc81750a6af95, 0x5d319e67c1e961b4, 0xb0aa7275301955f1
.quad 0x4a080672d9ba6c64, 0x97a5ef8a246ee77b, 0x06ea9e5d4383676a, 0x3496e2e117e0ec80
/* (2^512)^2 mod p */
/* fp_enc multiplies by this value to convert into Montgomery form */
.r_squared_mod_p:
.quad 0x36905b572ffc1724, 0x67086f4525f1f27d, 0x4faf3fbfd22370ca, 0x192ea214bcc584b1
.quad 0x5dae03ee2f5de3d0, 0x1e9248731776b371, 0xad5f166e20e4f52d, 0x4ed759aea6f3917e
/* -p^-1 mod 2^64 */
/* single-limb constant read by the multiplication step macro */
.inv_min_p_mod_r:
.quad 0x66c1301f632e294d
.section .text
/* fp_copy(rdi = destination, rsi = source): copy one 8-quad (64-byte) value.
 * Overwrites rcx and advances rdi and rsi. */
.global fp_copy
fp_copy:
cld /* clear the direction flag so the string copy moves forward */
mov rcx, 8 /* repeat count: 8 quadwords */
rep movsq /* copy rcx quadwords from [rsi] to [rdi] */
ret
/* fp_set(rdi = x, rsi = y): call u512_set with the same arguments, then
 * encode *x to Montgomery form in place. */
.global fp_set
fp_set:
push rdi /* keep the destination pointer across the call */
call u512_set /* defined elsewhere; presumably stores the integer y in *x */
pop rdi
mov rsi, rdi /* source = destination, so the encoding happens in place */
jmp fp_enc /* tail call; fp_enc's callee returns to our caller */
/* fp_cswap(rdi = x, rsi = y, dl = c): exchange the 8-limb values *x and *y
 * when c is 1, leave them unchanged when c is 0.  The choice is made with an
 * all-ones / all-zero mask; the same instructions run in both cases.
 * Overwrites rax, rcx, rdx and r8. */
.global fp_cswap
fp_cswap:
movzx rax, dl /* rax = c */
neg rax /* rax = 0 when c = 0, all ones when c = 1 */
.set k, 0
.rept 8
mov rcx, [rdi + 8*k] /* rcx = limb k of x */
mov rdx, [rsi + 8*k] /* rdx = limb k of y */
mov r8, rcx
xor r8, rdx /* r8 = x ^ y */
and r8, rax /* keep the difference only when swapping */
xor rcx, r8 /* becomes y's limb if swapping, else unchanged */
xor rdx, r8 /* becomes x's limb if swapping, else unchanged */
mov [rdi + 8*k], rcx
mov [rsi + 8*k], rdx
.set k, k+1
.endr
ret
/* .reduce_once(rdi = pointer to an 8-limb value x): conditional subtraction
 * of p.  x - p is computed into registers; if that subtraction does not
 * borrow (x >= p) the difference is written back, otherwise *x is left
 * unchanged.  The selection is done with a mask instead of a branch.
 * rbp is saved and restored; rax, rcx, rdx, rsi, rdi and r8-r11 are
 * overwritten. */
.reduce_once:
push rbp
mov rbp, rdi /* rbp = x; rdi is reused below for limb 0 */
mov rdi, [rbp + 0]
sub rdi, [rip + p + 0] /* limb 0 of x - p; starts the borrow chain */
mov rsi, [rbp + 8]
sbb rsi, [rip + p + 8]
mov rdx, [rbp + 16]
sbb rdx, [rip + p + 16]
mov rcx, [rbp + 24]
sbb rcx, [rip + p + 24]
mov r8, [rbp + 32]
sbb r8, [rip + p + 32]
mov r9, [rbp + 40]
sbb r9, [rip + p + 40]
mov r10, [rbp + 48]
sbb r10, [rip + p + 48]
mov r11, [rbp + 56]
sbb r11, [rip + p + 56]
setnc al /* al = 1 when the full subtraction did not borrow */
movzx rax, al
neg rax /* rax = all ones if x >= p, else 0 */
/* cswap2 r, m: when rax is all ones, memory operand m receives the value of
 * register r; when rax is 0, m keeps its value.  r is clobbered. */
.macro cswap2, r, m
xor \r, \m
and \r, rax
xor \m, \r
.endm
cswap2 rdi, [rbp + 0]
cswap2 rsi, [rbp + 8]
cswap2 rdx, [rbp + 16]
cswap2 rcx, [rbp + 24]
cswap2 r8, [rbp + 32]
cswap2 r9, [rbp + 40]
cswap2 r10, [rbp + 48]
cswap2 r11, [rbp + 56]
pop rbp
ret
/* fp_add3(rdi = x, rsi = y, rdx = z): call u512_add3 with the same
 * arguments, then subtract p once from *x if the result is >= p.
 * The value u512_add3 returns in rax is not examined. */
.global fp_add3
fp_add3:
push rdi /* keep the destination pointer across the call */
call u512_add3 /* defined elsewhere; presumably *x = *y + *z */
pop rdi
jmp .reduce_once /* tail call on x */
/* fp_add2(rdi = x, rsi = y): in-place addition, forwarded as
 * fp_add3(x, y, x). */
.global fp_add2
fp_add2:
mov rdx, rdi /* third operand = destination */
jmp fp_add3
/* fp_sub3(rdi = x, rsi = y, rdx = z): call u512_sub3 with the same
 * arguments, then add p to *x when that call reports a borrow in rax
 * (i.e. the plain difference went negative).  When there is no borrow, zero
 * is added instead, so the same instructions run in both cases.
 * Overwrites rax, rcx, rdx, rsi and r8-r11. */
.global fp_sub3
fp_sub3:
push rdi /* keep the destination pointer across the call */
call u512_sub3 /* defined elsewhere; rax = borrow */
pop rdi
/* addend limbs 1..7 default to zero (limb 0 lives in rax) */
xor rsi, rsi
xor rdx, rdx
xor rcx, rcx
xor r8, r8
xor r9, r9
xor r10, r10
xor r11, r11
test rax, rax /* ZF = 0 when a borrow occurred */
/* on borrow the addend becomes p; otherwise rax is already 0 and the other
 * registers stay 0 */
cmovnz rax, [rip + p + 0]
cmovnz rsi, [rip + p + 8]
cmovnz rdx, [rip + p + 16]
cmovnz rcx, [rip + p + 24]
cmovnz r8, [rip + p + 32]
cmovnz r9, [rip + p + 40]
cmovnz r10, [rip + p + 48]
cmovnz r11, [rip + p + 56]
/* *x += addend, carrying through all eight limbs */
add [rdi + 0], rax
adc [rdi + 8], rsi
adc [rdi + 16], rdx
adc [rdi + 24], rcx
adc [rdi + 32], r8
adc [rdi + 40], r9
adc [rdi + 48], r10
adc [rdi + 56], r11
ret
/* fp_sub2(rdi = x, rsi = y): in-place subtraction, forwarded as
 * fp_sub3(x, x, y). */
.global fp_sub2
fp_sub2:
mov rdx, rdi /* rdx = x */
xchg rsi, rdx /* now rsi = x (minuend) and rdx = y (subtrahend) */
jmp fp_sub3
/* Montgomery arithmetic */
/* fp_enc(rdi = x, rsi = y): encode to Montgomery representation by
 * forwarding to fp_mul3(x, y, (2^512)^2 mod p). */
.global fp_enc
fp_enc:
lea rdx, [rip + .r_squared_mod_p] /* third operand = R^2 mod p */
jmp fp_mul3
/* fp_dec(rdi = x, rsi = y): decode from Montgomery representation by
 * forwarding to fp_mul3(x, y, u512_1); u512_1 is defined elsewhere. */
.global fp_dec
fp_dec:
lea rdx, [rip + u512_1] /* third operand = the constant u512_1 */
jmp fp_mul3
.global fp_mul3
fp_mul3:
push rbp
push rbx
push r12
push r13
push r14
push r15
push rdi
mov rdi, rsi
mov rsi, rdx
xor r8, r8
xor r9, r9
xor r10, r10
xor r11, r11
xor r12, r12
xor r13, r13
xor r14, r14
xor r15, r15
xor rbp, rbp
/* flags are already cleared */
.macro MULSTEP, k, r0, r1, r2, r3, r4, r5, r6, r7, r8
mov rdx, [rsi + 0]
mulx rcx, rdx, [rdi + 8*\k]
add rdx, \r0
mulx rcx, rdx, [rip + .inv_min_p_mod_r]
xor rax, rax /* clear flags */
mulx rbx, rax, [rip + p + 0]
adox \r0, rax
mulx rcx, rax, [rip + p + 8]
adcx \r1, rbx
adox \r1, rax
mulx rbx, rax, [rip + p + 16]
adcx \r2, rcx
adox \r2, rax