from tqdm import tqdm
import os
import torch.nn.functional as F
from watermark.sparseonebit_normalhash_watermark import SparseOneBitNormalHash,SparseOneBitNormalHashDetector
from transformers import  LogitsProcessorList
    
    
class Generator():
    def __init__(self, args, tokenizer, model) -> None:
        self.mode = args.mode # watermark mode
        self.init_seed, self.dyna_seed, self.gamma, \
        self.delta, self.bl_type, self.num_beams, self.sampling_temp = args.initial_seed, args.dynamic_seed, args.gamma, args.delta, args.bl_type, args.num_beams, args.sampling_temp
        self.tokenizer = tokenizer
        self.model = model # language model
        
        self.all_token_ids = list(tokenizer.get_vocab().values())
        self.vocab_size = len(self.all_token_ids)
        # if self.vocab_size != model.config.padded_vocab_size:
        #     self.vocab_size = model.config.padded_vocab_size
        if args.mode == 'onebitsparsenormalhash':
            watermark_processor = SparseOneBitNormalHash(tokenizer=tokenizer,
                                               gamma=args.gamma,
                                                delta=args.delta,
                                                prompt_slice=None,
                                                hard_encode=True if self.bl_type=="hard" else False,
                                                allowed_pos_tag=args.pos_tag
                                                )
            
            print(f"[INFO]:{watermark_processor.hard_encode}")
            self.detector = SparseOneBitNormalHashDetector(
                tokenizer = tokenizer,
                gamma=args.gamma,
                delta=args.delta,
                prompt_slice=None,
                hard_encode=True if self.bl_type=="hard" else False,
                allowed_pos_tag=args.pos_tag
                )
            watermark_processor.init_table()
            self.logit_processor_lst = LogitsProcessorList([watermark_processor]) 
            
    def generate(self, input_ids, max_new_tokens):
        if self.mode == 'new':
            example = {}
            
            outputs = self.model.generate(
                input_ids, max_new_tokens=max_new_tokens,
                logits_processor = self.logit_processor_lst,
                do_sample=True,
                top_k=0,
                temperature=self.sampling_temp
            )

            example.update({"bl_vocabularys":self.logit_processor_lst[0].get_and_clear_vocabularys()})
            # remove the attached input from output for some model
            scores = outputs.scores
            output_ids = outputs.sequences[0, -len(scores):]


            # compute logprob for each token
            completions_tokens = []
            completions_logprob = 0

            for score, token, vocabulary in zip(scores, output_ids, example["bl_vocabularys"], strict=True):
                logprobs = F.log_softmax(score[0], dim=-1)
                logprob = logprobs[token].item()
                completions_tokens.append({
                    'text': self.tokenizer.decode(token),
                    'logprob': logprob,
                    'vocabulary': vocabulary,
                })
                completions_logprob += logprob
            
            completions_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)
            return completions_text, completions_tokens
        else:    
        
            if self.mode == 'onebitsparsenormalhash':
                
                self.logit_processor_lst[0].prompt_slice = len(input_ids[0])
                outputs = self.model.generate(
                    input_ids, max_new_tokens=max_new_tokens,
                    logits_processor = self.logit_processor_lst,
                )
            # remove the attached input from output for some model
            scores = outputs.scores
            output_ids = outputs.sequences[0, -len(scores):]

            # compute logprob for each token
            completions_tokens = []
            completions_logprob = 0

            for score, token in zip(scores, output_ids, strict=True):
                logprobs = F.log_softmax(score[0], dim=-1)
                logprob = logprobs[token].item()
                completions_tokens.append({
                    'text': self.tokenizer.decode(token),
                    'logprob': logprob,
                })
                completions_logprob += logprob
            
            completions_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)
            if 'sparse' in self.mode:
                print(completions_text)
                print(self.detector.detect(completions_text))
                print(self.logit_processor_lst[0].last_input)
            return completions_text, completions_tokens
