import torch
import numpy as np 
import os
import sympy  # This library is used for primality testing
import random 
import sys
import argparse
from tqdm import tqdm

class FourierEmbedding(torch.nn.Module):
    """Random Fourier-feature embedding of scalar inputs.

    Each scalar ``x[n]`` is mapped to
    ``[cos(2*pi*f*x[n]) for f in freqs] + [sin(2*pi*f*x[n]) for f in freqs]``
    where ``freqs`` is a fixed vector of ``num_channels // 2`` frequencies drawn
    once from ``randn * scale`` at construction time.

    Args:
        num_channels: Requested output width.  Only ``num_channels // 2``
            frequencies are drawn, so the output has ``2 * (num_channels // 2)``
            columns (one fewer than requested when ``num_channels`` is odd).
        scale: Multiplier applied to the standard-normal frequencies.
    """

    def __init__(self, num_channels, scale=16):
        super().__init__()
        # Registered as a buffer (not a Parameter): it is saved in state_dict
        # and moved by .to(), but is never trained.
        self.register_buffer('freqs', torch.randn(num_channels // 2) * scale)

    def forward(self, x):
        """Embed a 1-D tensor ``x`` of shape (N,) into shape (N, 2 * len(freqs)).

        Integer inputs (e.g. ``torch.arange``) are first promoted to the
        buffer's floating dtype.  Casting the frequencies to an integer dtype
        instead would truncate them and destroy the embedding.
        """
        if not torch.is_floating_point(x):
            x = x.to(self.freqs.dtype)
        # Outer product: angles[n, c] = x[n] * 2*pi*freqs[c]
        angles = torch.outer(x, (2 * np.pi * self.freqs).to(x.dtype))
        return torch.cat([angles.cos(), angles.sin()], dim=1)


# Bounds of a signed machine-sized integer; used as the sampling range for the
# random hash parameters and candidate primes.
max_size = sys.maxsize
min_size = ~sys.maxsize  # bitwise NOT: equal to -sys.maxsize - 1

def is_prime(num):
    """Primality test for ``num``; a thin wrapper around ``sympy.isprime``."""
    primality_test = sympy.isprime
    return primality_test(num)

def generate_random_prime(m):
    """Return a random prime ``p`` with ``m < p < sys.maxsize``.

    Candidates are drawn uniformly from the global NumPy RNG (so the result
    depends on its state, e.g. as set by ``set_seed``) and rejected until
    ``is_prime`` accepts one; the result is therefore uniform over the primes
    in that range.

    Args:
        m: Exclusive lower bound; an int or an integral float such as ``10e6``.

    Returns:
        The prime as a Python ``int``.
    """
    low = int(m) + 1
    while True:
        # Explicit int64: NumPy's default integer is 32-bit on Windows
        # (NumPy < 2), where ``high=sys.maxsize`` raises "high is out of bounds".
        candidate = int(np.random.randint(low, max_size, dtype=np.int64))
        if is_prime(candidate):
            return candidate

def set_seed(seed: int = 42) -> None:
    """Seed every RNG this script touches so runs are reproducible.

    Seeds Python's ``random``, NumPy's global RNG, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed`` (current CUDA device), and switches cuDNN to
    deterministic, non-benchmarking mode.  It also exports ``PYTHONHASHSEED``;
    note that this only affects child processes started afterwards, because
    the running interpreter's hash seed is fixed at start-up.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # On the cuDNN backend both flags are needed for deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    os.environ["PYTHONHASHSEED"] = str(seed)


def DHE(x, k, m, encoding_type='uniform'):
    """Encode integer id(s) with ``k`` random hash functions (Deep Hash Embedding).

    Hash function ``i`` is ``h_i(x) = ((a_i * x + b_i) mod p_i) mod m``, where
    ``a_i`` is drawn from ``[min_size, max_size)``, ``b_i`` from
    ``[0, max_size)`` and ``p_i`` is a random prime larger than ``m``.  Each
    hash value ``h_i`` in ``[0, m - 1]`` is then turned into a real number:

    * ``'uniform'``: ``2 * h_i / (m - 1) - 1``, which lies in ``[-1, 1]``;
    * ``'gaussian'``: Box-Muller transform of each consecutive pair
      ``(h_i, h_{i+1})``.

    Any other ``encoding_type`` falls back to the uniform encoding.

    Side effect: ``set_seed(0)`` is called first, re-seeding the global
    ``random``/NumPy/PyTorch RNGs.  Because of this every call draws the same
    ``k`` hash functions, so encoding ids one by one or as a single array
    gives identical values.

    Args:
        x: An integer id, or an array-like of integer ids.
        k: Number of hash functions, i.e. the encoding length.  Must be even
            for ``'gaussian'``.
        m: Number of hash buckets; an integral value >= 2 (an integral float
            such as ``10e6`` is accepted).
        encoding_type: ``'uniform'`` (default) or ``'gaussian'``.

    Returns:
        float64 ``np.ndarray`` of shape ``(k,) + np.atleast_1d(x).shape``,
        i.e. ``(k, 1)`` for a scalar ``x``.

    Raises:
        ValueError: if ``m < 2``, or if ``k`` is odd with ``'gaussian'``.
    """
    bucket_count = int(m)
    if bucket_count < 2:
        raise ValueError("m must be at least 2, got {}".format(m))
    if encoding_type == 'gaussian' and k % 2:
        raise ValueError("gaussian encoding needs an even k, got {}".format(k))

    set_seed(0)
    id_array = np.atleast_1d(np.asarray(x))
    ids = [int(v) for v in id_array.ravel()]

    hashes = []
    for _ in range(k):
        # Same draw order (a, b, p) as before, so the hash parameters are unchanged.
        a = int(np.random.randint(low=min_size, high=max_size, size=(1), dtype=np.int64)[0])
        b = int(np.random.randint(low=0, high=max_size, size=(1), dtype=np.int64)[0])
        p = int(generate_random_prime(m))
        # Python ints give the exact value of a*x + b; NumPy int64 arithmetic
        # silently wraps around (|a| is up to 2**63) and corrupts the hash.
        h_i = np.array([((a * v + b) % p) % bucket_count for v in ids], dtype=np.float64)
        hashes.append(h_i.reshape(id_array.shape))

    if encoding_type == 'gaussian':
        encod = []
        for i in range(0, k, 2):
            # Box-Muller needs u1 in (0, 1] (log(0) is -inf) and u2 in [0, 1).
            u1 = (hashes[i] + 1) / bucket_count
            u2 = hashes[i + 1] / bucket_count
            radius = np.sqrt(-2 * np.log(u1))
            encod.append(radius * np.cos(2 * np.pi * u2))
            encod.append(radius * np.sin(2 * np.pi * u2))
    else:
        # h in [0, m-1] -> [0, 1] -> [-1, 1].  The former (h - 1) / (m - 1)
        # was off by one and could fall below -1.
        encod = [2 * (h / (bucket_count - 1)) - 1 for h in hashes]

    return np.array(encod)



def main():
    """Command-line entry point: build an embedding table and save it with ``torch.save``.

    Example::

        python generate_unique_embeddings.py --emb_type uniform --emb_length 768 --num_unique 3000
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--emb_type", type=str, default='dhe', choices=['dhe', 'uniform', 'fourier'])
    parser.add_argument("--emb_length", type=int, default=None, required=True)
    parser.add_argument("--num_unique", type=int, default=None, required=True)
    parser.add_argument("--save_filename", type=str, default=None)
    args = parser.parse_args()

    save_filename = args.save_filename
    if save_filename is None:
        save_filename = '{}_embeddings_length_{}.pt'.format(args.emb_type, args.emb_length)

    print("file is saved to ", save_filename)

    if args.emb_type == 'dhe':
        # 10e6 in the original is the float 10_000_000.0; an int keeps the
        # hash arithmetic integral.
        k, m = args.emb_length, 10_000_000
        # DHE re-seeds and redraws the same k hash functions on every call, so
        # one vectorized call over all ids yields exactly the per-id codes
        # while searching for k primes instead of num_unique * k.
        ids = np.arange(args.num_unique)
        codes = DHE(ids, k, m, encoding_type='uniform')  # shape (k, num_unique)
        embeddings = torch.from_numpy(np.ascontiguousarray(codes.T))
    elif args.emb_type == 'uniform':
        embeddings = torch.rand(size=(args.num_unique, args.emb_length)) * 2 - 1
        print("max:{} min:{}".format(embeddings.max(), embeddings.min()))
    elif args.emb_type == 'fourier':
        fourier_fn = FourierEmbedding(num_channels=args.emb_length)
        # Float ids: an integer tensor would make FourierEmbedding cast (and
        # truncate) its frequencies to an integer dtype.
        unique_num = torch.arange(args.num_unique, dtype=torch.float32)
        embeddings = fourier_fn(unique_num)
    else:
        raise RuntimeError('embedding type error!')

    torch.save(embeddings, f=save_filename)

    # Round-trip check that the saved file loads.
    embeddings_load = torch.load(f=save_filename)
    print("embeddings shape:", embeddings_load.shape)


if __name__ == "__main__":
    main()
