/*
 * elgamal-lib.c
 *
 * Copyright (C) 2012 Jerônimo C. Pellegrini
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to:
 *   The Free Software Foundation, Inc.,
 *   51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 */


#include <errno.h>
#include <fcntl.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

/* gmp.h must follow stdio.h so its FILE-based I/O prototypes are declared. */
#include <gmp.h>
#include "elgamal.h"

/*
  Prints a message to standard error and exits with EXIT_FAILURE status.
*/
void die (const char *msg) {
	/* msg must never be used as a format string: any '%' in it would be
	   parsed as a conversion specifier (undefined behavior). */
	fputs(msg, stderr);
	exit(EXIT_FAILURE);
}

/*
  Sets r to a random GMP integer of at most the specified number of
  bits, i.e. a value in [0, 2^bits), using bytes read from /dev/random.

  Aborts via die() if memory cannot be allocated or /dev/random cannot
  be opened or fully read.
*/
void get_random_n_bits(mpz_t r, size_t bits)
{
	/* Integer round-up: e.g. 129 bits need 17 bytes, 1 bit needs 1 byte. */
	size_t size = (bits + 7) / 8;
	unsigned char *buffer;
	size_t filled = 0;
	int prg;

	if (size == 0) {
		mpz_set_ui(r, 0);
		return;
	}

	buffer = malloc(size);
	if (buffer == NULL)
		die("Cannot allocate memory for random bytes.\n");

	prg = open("/dev/random", O_RDONLY);
	if (prg < 0)
		die("Cannot open /dev/random.\n");

	/* read() may return fewer bytes than requested, or be interrupted
	   by a signal, so keep reading until the buffer is full. */
	while (filled < size) {
		ssize_t got = read(prg, buffer + filled, size - filled);
		if (got < 0) {
			if (errno == EINTR)
				continue;
			die("Cannot read from /dev/random.\n");
		}
		if (got == 0)
			die("Unexpected end of file on /dev/random.\n");
		filled += (size_t) got;
	}
	close(prg);

	mpz_import(r, size, 1, sizeof(unsigned char), 0, 0, buffer);

	/* The buffer holds 8*size >= bits random bits.  Drop the surplus
	   high-order bits so that r < 2^bits; rejection samplers such as
	   get_random_n rely on the draw having the requested bit length. */
	mpz_tdiv_r_2exp(r, r, bits);

	/* These bytes may become secret key material (gen_keys draws x
	   through here), so wipe them before releasing the buffer.  The
	   volatile access keeps the stores from being optimized away. */
	{
		volatile unsigned char *wipe = buffer;
		size_t i;
		for (i = 0; i < size; i++)
			wipe[i] = 0;
	}
	free(buffer);
}

/*
  Sets r to a random GMP integer in the range [0, max).

  Uses rejection sampling: numbers with as many bits as max are drawn
  via get_random_n_bits and discarded until one is below max.  The
  result is uniform provided get_random_n_bits is uniform.

  max must be positive.  Otherwise no value satisfies the bound and the
  loop could never terminate, so the program aborts via die() instead.
 */
void get_random_n (mpz_t r, mpz_t max) {
	size_t bits;

	if (mpz_sgn(max) <= 0)
		die("get_random_n: upper bound must be positive.\n");

	bits = mpz_sizeinbase(max, 2);
	do {
		get_random_n_bits(r, bits);
	} while (mpz_cmp(r, max) >= 0);
}

/*
  Sets r to a random GMP *prime* integer smaller than max.

  A random number with as many bits as max is drawn and advanced with
  mpz_nextprime.  That function is probabilistic, so r is prime with
  overwhelming probability.  Candidates at or above max are retried.

  No prime lies below 2, so max must be greater than 2.  Otherwise the
  loop could never terminate, so the program aborts via die() instead.

  NOTE(review): primes that follow large prime gaps are selected more
  often, so the distribution over primes below max is not uniform.
 */
void get_random_n_prime (mpz_t r, mpz_t max) {
	size_t bits;

	if (mpz_cmp_ui(max, 2) <= 0)
		die("get_random_n_prime: upper bound must be greater than 2.\n");

	bits = mpz_sizeinbase(max, 2);
	do {
		get_random_n_bits(r, bits);
		mpz_nextprime(r, r);
	} while (mpz_cmp(r, max) >= 0);
}

/*
  Saves an Elgamal public key to the file `name`.

  The members n, g and h are written, in that order, in GMP's raw
  format (mpz_out_raw).  read_pk reads them back in the same order.
  Aborts via die() if the file cannot be opened or fully written.
*/
void save_pk (elg_pk *pk, const char *name) {
	/* Raw GMP output is binary; text mode could corrupt it on platforms
	   that translate line endings. */
	FILE *out = fopen(name, "wb");
	if (out == NULL)
		die("Cannot save public key.\n");

	/* mpz_out_raw returns 0 when an error occurs. */
	if (mpz_out_raw(out, pk->n) == 0 ||
	    mpz_out_raw(out, pk->g) == 0 ||
	    mpz_out_raw(out, pk->h) == 0)
		die("Cannot save public key.\n");

	/* fclose flushes buffered data, so late write errors surface here. */
	if (fclose(out) != 0)
		die("Cannot save public key.\n");
}

/*
  Saves an Elgamal secret key to the file `name`.

  The members n, g, h and x are written, in that order, in GMP's raw
  format (mpz_out_raw).  read_sk reads them back in the same order.
  Aborts via die() if the file cannot be opened or fully written.

  NOTE(review): fopen creates the file with default permissions (0666
  masked by the process umask), so the secret key may be readable by
  other users.  Consider creating it with mode 0600.
*/
void save_sk (elg_sk *sk, const char *name) {
	/* Raw GMP output is binary; text mode could corrupt it on platforms
	   that translate line endings. */
	FILE *out = fopen(name, "wb");
	if (out == NULL)
		die("Cannot save private key.\n");

	/* mpz_out_raw returns 0 when an error occurs. */
	if (mpz_out_raw(out, sk->n) == 0 ||
	    mpz_out_raw(out, sk->g) == 0 ||
	    mpz_out_raw(out, sk->h) == 0 ||
	    mpz_out_raw(out, sk->x) == 0)
		die("Cannot save private key.\n");

	/* fclose flushes buffered data, so late write errors surface here. */
	if (fclose(out) != 0)
		die("Cannot save private key.\n");
}

/*
  Reads an Elgamal secret key from the file `name`.

  The members of sk are mpz_init'ed here, so they must not already be
  initialized.  They are then read as n, g, h, x, in the order save_sk
  writes them.  Aborts via die() if the file cannot be opened, or if it
  is truncated or malformed.
*/
void read_sk (elg_sk *sk, const char *name) {
	/* The file holds binary mpz_out_raw data. */
	FILE *in = fopen(name, "rb");
	if (in == NULL)
		die("Cannot read private key.\n");

	mpz_init(sk->n);
	mpz_init(sk->g);
	mpz_init(sk->h);
	mpz_init(sk->x);

	/* mpz_inp_raw returns 0 on error.  Continuing would silently
	   produce a garbage key. */
	if (mpz_inp_raw(sk->n, in) == 0 ||
	    mpz_inp_raw(sk->g, in) == 0 ||
	    mpz_inp_raw(sk->h, in) == 0 ||
	    mpz_inp_raw(sk->x, in) == 0)
		die("Cannot read private key.\n");

	fclose(in);
}

/*
  Reads an Elgamal public key from the file `name`.

  The members of pk are mpz_init'ed here, so they must not already be
  initialized.  They are then read as n, g, h, in the order save_pk
  writes them.  Aborts via die() if the file cannot be opened, or if it
  is truncated or malformed.
*/
void read_pk (elg_pk *pk, const char *name) {
	/* The file holds binary mpz_out_raw data. */
	FILE *in = fopen(name, "rb");
	if (in == NULL)
		die("Cannot read public key.\n");

	mpz_init(pk->n);
	mpz_init(pk->g);
	mpz_init(pk->h);

	/* mpz_inp_raw returns 0 on error.  Continuing would silently
	   produce a garbage key. */
	if (mpz_inp_raw(pk->n, in) == 0 ||
	    mpz_inp_raw(pk->g, in) == 0 ||
	    mpz_inp_raw(pk->h, in) == 0)
		die("Cannot read public key.\n");

	fclose(in);
}

/*
  Generates an Elgamal key pair.
  There is no need to initialize the members of pk and sk before
  calling this function, because this is done internally.  Passing
  already-initialized members leaks their previous storage.

  However, the pk and sk structures must exist (this function receives
  two pointers to already existing structures!)

  The secret key holds n, g, h and x.  The public key holds copies of
  n, g and h, where h = g^x (mod n).

  NOTE(review): a 128-bit modulus is far too small for real security.
  Practical deployments use 2048 bits or more.
  NOTE(review): g is a random prime below n.  It is not checked to be a
  generator of Z_n^* or of a prime-order subgroup.
 */
void gen_keys(elg_pk *pk, elg_sk *sk) {
	/* Bit length of the prime modulus n. */
	const size_t modulus_bits = 128;
	mpz_t x_bound;

	mpz_inits(sk->n, sk->g, sk->h, sk->x, NULL);
	mpz_inits(pk->n, pk->g, pk->h, NULL);
	mpz_init(x_bound);

	/* n is a large prime.  Setting the top bit guarantees n really has
	   modulus_bits bits instead of occasionally being much shorter. */
	get_random_n_bits(sk->n, modulus_bits);
	mpz_setbit(sk->n, modulus_bits - 1);
	mpz_nextprime(sk->n, sk->n);

	/* Secret exponent x in [1, n-2].  x = 0 would give the trivial key
	   h = 1, and mpz_powm_sec requires a positive exponent. */
	mpz_sub_ui(x_bound, sk->n, 2);
	get_random_n(sk->x, x_bound);
	mpz_add_ui(sk->x, sk->x, 1);

	/* g: a random prime below n (see the NOTE above). */
	get_random_n_prime(sk->g, sk->n);

	/* h = g^x (mod n) */
	mpz_powm_sec(sk->h, sk->g, sk->x, sk->n);

	mpz_set(pk->n, sk->n);
	mpz_set(pk->g, sk->g);
	mpz_set(pk->h, sk->h);

	mpz_clear(x_bound);
}

/*
  Decrypts an Elgamal ciphertext (c1, c2) with the secret key sk and
  stores the plaintext in msg:

      s   = c1^x       (mod n)
      msg = c2 * s^-1  (mod n)

  msg must already be initialized by the caller; it is not mpz_init'ed
  here.  Aborts via die() if s has no inverse modulo n.  For a prime n
  this happens exactly when c1 is congruent to 0 mod n, i.e. the
  ciphertext is malformed.
 */
void elg_dec(mpz_t msg, mpz_t c1, mpz_t c2,
	     elg_sk *sk) {
	mpz_t s, inv_s;
	mpz_inits(s, inv_s, NULL);

	/* s = c1^x (mod n).  The shared secret and the plaintext are
	   deliberately never printed, because that would leak them. */
	mpz_powm_sec(s, c1, sk->x, sk->n);

	/* inv_s = s^{-1} (mod n).  mpz_invert returns 0 when no inverse
	   exists, and inv_s is then undefined. */
	if (mpz_invert(inv_s, s, sk->n) == 0)
		die("Cannot decrypt: c1 is not invertible modulo n.\n");

	/* msg = c2 * inv_s (mod n) */
	mpz_mul(msg, c2, inv_s);
	mpz_mod(msg, msg, sk->n);

	/* Release memory allocated by libgmp */
	mpz_clears(s, inv_s, NULL);
}

/*
  Encrypts msg under the public key pk, producing the ciphertext (c1, c2):

      y  = random value in [1, n-2]
      s  = h^y      (mod n)
      c1 = g^y      (mod n)
      c2 = msg * s  (mod n)

  c1 and c2 are mpz_init'ed here.  The caller must therefore pass
  uninitialized variables (initialized ones would leak their old
  storage) and release them later with mpz_clear.  c2 is reduced modulo
  n, so msg must lie in [0, n-1] for elg_dec to recover it exactly.
 */
void elg_enc(mpz_t c1, mpz_t c2,
	mpz_t msg, elg_pk *pk) {
	mpz_t y, s, y_bound;
	mpz_inits(y, s, y_bound, NULL);
	mpz_init(c1);
	mpz_init(c2);

	/* Ephemeral exponent y in [1, n-2].  y = 0 would give c1 = 1 and
	   c2 = msg, i.e. the plaintext in the clear.  mpz_powm_sec also
	   requires a positive exponent. */
	mpz_sub_ui(y_bound, pk->n, 2);
	get_random_n(y, y_bound);
	mpz_add_ui(y, y, 1);

	/* s = h^y (mod n).  This shared secret is deliberately never printed. */
	mpz_powm_sec(s, pk->h, y, pk->n);

	/* c1 = g^y (mod n) */
	mpz_powm_sec(c1, pk->g, y, pk->n);

	/* c2 = msg * s (mod n) */
	mpz_mul(c2, msg, s);
	mpz_mod(c2, c2, pk->n);

	/* Release memory allocated by libgmp */
	mpz_clears(s, y, y_bound, NULL);
}

/*
 * Releases the GMP storage held by the n, g and h members of a public key.
 *
 * NOTE(review): pk is received by value.  If the mpz_t members are stored
 * inline (as mpz_t normally is), this copy shares its limb storage with
 * the caller's key, so the caller's key must not be used afterwards.
 */
void clear_elg_pk(elg_pk pk)
{
	mpz_clear(pk.n);
	mpz_clear(pk.g);
	mpz_clear(pk.h);
}

/*
 * Releases the GMP storage held by the n, g, h and x members of a
 * secret key.
 *
 * NOTE(review): sk is received by value.  If the mpz_t members are stored
 * inline (as mpz_t normally is), this copy shares its limb storage with
 * the caller's key, so the caller's key must not be used afterwards.
 */
void clear_elg_sk(elg_sk sk)
{
	mpz_clear(sk.n);
	mpz_clear(sk.g);
	mpz_clear(sk.h);
	mpz_clear(sk.x);
}
