import os, argparse

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from mersenne import mersenne_rng
from fastecdsa import curve, ecdsa, keys

import pyximport
import logging 




# Module-level logger shared by every function below. Its level is INFO; a
# file handler using ``formatter`` is attached in main() once the output
# directory is known.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
log_format = '[%(asctime)s] - %(message)s'
date_format = '%Y/%m/%d %H:%M:%S'
formatter = logging.Formatter(fmt=log_format, datefmt=date_format)

def generate(private_key, tokenizer, model, prompt, vocab_size, n, m, key, d):
    """Generate ``d`` message tokens followed by ``m`` signature-carrying tokens.

    The first ``d`` tokens are drawn with ``exp_sampling`` using the keyed
    pseudo-random table ``xi``. Their decoded text is signed with ECDSA, and
    the signature is written out as 512 bits (256 for ``r``, then 256 for
    ``s``). Each of the next ``m`` tokens carries one of those bits, chosen via
    ``my_sampling``.

    Args:
        private_key: ECDSA private key passed to ``ecdsa.sign``.
        tokenizer: tokenizer used to decode the message tokens before signing.
        model: causal LM; called with ``past_key_values``/``attention_mask``
            after the first step.
        prompt: prompt token ids. Assumed shape is (1, prompt_len); only batch
            row 0 of the message is decoded and signed.
        vocab_size: number of logit columns kept before the softmax.
        n: number of rows in the watermark key table ``xi``.
        m: number of signature-carrying tokens; at most 512.
        key: seed for ``mersenne_rng``.
        d: number of message tokens (0 signs the empty string).

    Returns:
        ``(tokens, signature)``. ``tokens`` is the prompt plus ``m + d``
        generated tokens, as a CPU tensor. ``signature`` is the list of 512
        signature bits, or ``[]`` when ``m == 0`` (nothing is signed).

    Raises:
        ValueError: if ``m`` exceeds the number of signature bits.
    """
    sig_len = 2 * 256  # r and s are each formatted as 256 bits below
    if m > sig_len:
        # Fail before the (expensive) generation instead of hitting an
        # IndexError on signature[...] halfway through.
        raise ValueError(f'm={m} exceeds the {sig_len} available signature bits')

    rng = mersenne_rng(key)
    xi = torch.tensor([rng.rand() for _ in range(n * vocab_size)]).view(n, vocab_size)
    # (1,)-shaped random offset into xi, so each u below keeps a leading dim of 1.
    shift = torch.randint(n, (1,))

    inputs = prompt.to(model.device)
    attn = torch.ones_like(inputs)
    past = None
    # Start from an empty (batch, 0) slice so that d == 0 still has a message.
    msg = inputs[:, :0]
    signature = []
    for i in range(m + d):
        with torch.no_grad():
            if past is not None:
                # Incremental decoding: feed only the newest token plus the cache.
                output = model(inputs[:, -1:], past_key_values=past, attention_mask=attn)
            else:
                output = model(inputs)

        probs = torch.nn.functional.softmax(output.logits[:, -1, :vocab_size], dim=-1).cpu()
        u = xi[(shift + i) % n, :]

        if i < d:
            token = exp_sampling(probs, u).to(model.device)
            msg = torch.cat([msg, token], dim=-1)
        else:
            if i == d:
                # All message tokens exist now: sign them exactly once.
                msg_str = tokenizer.decode(msg[0])
                r, s = ecdsa.sign(msg_str, private_key)
                signature = [int(bit) for bit in format(r, '0256b') + format(s, '0256b')]
                logger.info('msg_str: %s', msg_str)
                logger.info('r: %s', r)
                logger.info('s: %s', s)
                logger.info('signature: %s', signature)
            token = my_sampling(probs, u, signature[i - d]).to(model.device)

        inputs = torch.cat([inputs, token], dim=-1)
        past = output.past_key_values
        attn = torch.cat([attn, attn.new_ones((attn.shape[0], 1))], dim=-1)

    return inputs.detach().cpu(), signature

def exp_sampling(probs, u):
    """Return, for each row, the index that maximizes ``u ** (1 / probs)``.

    If the entries of ``u`` are i.i.d. uniform on [0, 1), index ``i`` wins
    with probability ``probs[i]``. The draw is therefore a sample from
    ``probs``, and it is fully determined by the key values ``u``.

    Args:
        probs: probability rows (reduced along dim 1).
        u: key values, same shape as ``probs``.

    Returns:
        The winning indices with a trailing unit dimension, e.g. (B, 1) for
        (B, V) input.
    """
    keyed_scores = torch.pow(u, 1 / probs)
    winners = keyed_scores.argmax(dim=1)
    return winners[..., None]

def my_sampling(probs, u, b):
    """Pick the highest-ranked token whose id parity equals the bit ``b``.

    Tokens are ranked by the same ``u ** (1 / probs)`` score that
    ``exp_sampling`` maximizes. The first token in descending order with
    ``id % 2 == b`` is returned, so each generated token encodes one bit.
    Parity stands in for a keyed hash, for simplicity.

    Args:
        probs: next-token probabilities; only batch row 0 is used.
        u: key values for this position, same shape as ``probs``.
        b: the bit to embed (0 or 1).

    Returns:
        A (1, 1) integer tensor holding the chosen token id.

    Raises:
        ValueError: if no token id has parity ``b`` (e.g. ``b`` not in {0, 1}).
            Previously this case silently returned ``None``.
    """
    ranked = torch.argsort(u ** (1 / probs), dim=1, descending=True)[0]
    # Boolean indexing preserves the ranking order, so element 0 is exactly
    # the token the scan-until-first-match would have returned.
    candidates = ranked[ranked % 2 == b]
    if candidates.numel() == 0:
        raise ValueError(f'no token id has parity {b!r}')
    return candidates[:1].reshape(1, 1)

def greedy(probs):
    """Return the most probable index of each row, with a trailing unit dim.

    For (B, V) input the result has shape (B, 1).
    """
    most_likely = probs.argmax(dim=1)
    return most_likely[..., None]



def _bits_to_int(bits):
    """Interpret a list of 0/1 ints as a big-endian binary number (empty -> 0)."""
    value = 0
    for bit in bits:
        value = (value << 1) | bit
    return value


def extract(watermarked_tokens, tokenizer, m, d, signature):
    """Recover the signed message and the ECDSA ``(r, s)`` pair from tokens.

    The last ``m + d`` tokens are split into ``d`` message tokens followed by
    ``m`` signature tokens. Each signature token contributes one bit, its
    parity (token id % 2), which is the mapping ``my_sampling`` embeds.

    Args:
        watermarked_tokens: 1-D tensor of token ids (prompt + generated text).
        tokenizer: tokenizer used to decode the message tokens.
        m: number of signature-carrying tokens at the end of the sequence.
        d: number of message tokens immediately preceding them.
        signature: the bits embedded at generation time. They are used only
            to log whether extraction recovered them exactly.

    Returns:
        ``(msg, r, s)``: the decoded message text, plus the integers rebuilt
        from the first 256 bits and from the remaining bits.

    Raises:
        ValueError: if the sequence holds fewer than ``m + d`` tokens.
    """
    logger.info('Extraction...')
    total = len(watermarked_tokens)
    logger.info('len_watermarked_tokens: %s', total)
    if total < m + d:
        raise ValueError(f'need at least m + d = {m + d} tokens, got {total}')

    # Absolute indices instead of [-m-d:-m] / [-m:]: with m == 0 those
    # negative slices become [-d:0] (always empty) and [0:] (everything).
    sig_start = total - m
    msg_tokens = watermarked_tokens[sig_start - d:sig_start]
    msg = tokenizer.decode(msg_tokens)
    logger.info('len_msg: %s', len(msg_tokens))
    logger.info('msg: %s', msg)

    signature_bits = (watermarked_tokens[sig_start:] % 2).tolist()
    logger.info('signature_bits: %s', signature_bits)
    # Debug comparison against the ground truth; logged instead of printed
    # so it does not pollute the script's stdout.
    logger.info('signature_match: %s', signature == signature_bits)

    r = _bits_to_int(signature_bits[:256])
    s = _bits_to_int(signature_bits[256:])
    logger.info('r: %s', r)
    logger.info('s: %s', s)
    return msg, r, s



def main(args):
    """Run the generate -> extract -> verify pipeline for one prompt.

    Writes INFO logs to ``args.out_dir/args.logfile``. Prints the watermarked
    text and the verification outcome to stdout.
    """
    # makedirs also creates missing parent directories, and exist_ok avoids
    # the race between an exists() check and mkdir().
    os.makedirs(args.out_dir, exist_ok=True)
    logfile = os.path.join(args.out_dir, args.logfile)

    file_handler = logging.FileHandler(logfile)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    try:
        logger.info(args)

        torch.manual_seed(args.seed)
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        tokenizer = AutoTokenizer.from_pretrained(args.model)
        model = AutoModelForCausalLM.from_pretrained(args.model).to(device)

        tokens = tokenizer.encode(args.prompt, return_tensors='pt', add_special_tokens=True)

        # Fresh P-256 key pair per run; verification below needs only the public key.
        private_key = keys.gen_private_key(curve.P256)
        public_key = keys.get_public_key(private_key, curve.P256)

        watermarked_tokens, signature = generate(
            private_key, tokenizer, model, tokens, len(tokenizer),
            args.n, args.m, args.key, args.d)
        logger.info('shape:%s', watermarked_tokens.shape)
        watermarked_text = tokenizer.decode(watermarked_tokens[0])
        logger.info('watermark_text: %s', watermarked_text)
        print(watermarked_text)

        msg, r, s = extract(watermarked_tokens[0], tokenizer, args.m, args.d, signature)
        valid = ecdsa.verify((r, s), msg, public_key)
        if valid:
            print('The verification is successful')
        else:
            # Previously a failed verification produced no output at all.
            print('The verification failed')
    finally:
        # Detach and close the handler so repeated calls neither duplicate
        # log lines nor leak the file descriptor.
        logger.removeHandler(file_handler)
        file_handler.close()


if __name__ == '__main__':
    # Command-line entry point. Total generated length is d + m tokens: d
    # message tokens, then m tokens that each carry one signature bit.
    parser = argparse.ArgumentParser(description='generate text watermarked with a key')
    parser.add_argument('--model', default='huggyllama/llama-7b', type=str,
                        help='a HuggingFace model id of the model to generate from')
    parser.add_argument('--prompt', default='How to cook ', type=str,
                        help='an optional prompt for generation')
    parser.add_argument('--m', default=512, type=int,
                        help='the number of generated tokens that carry signature bits '
                             '(the signature has 512 bits)')
    parser.add_argument('--n', default=80, type=int,
                        help='the length of the watermark sequence')
    parser.add_argument('--d', default=4, type=int,
                        help='the number of leading generated tokens used as the signed message')
    parser.add_argument('--key', default=42, type=int,
                        help='a key for generating the random watermark sequence')
    parser.add_argument('--seed', default=0, type=int,
                        help='a seed for reproducible randomness')
    parser.add_argument('--out_dir', default='./log_test',
                        help='directory for saving results')
    parser.add_argument('--logfile', default='opt_asym.log',
                        help='name of the log file created inside --out_dir')

    main(parser.parse_args())
