from tqdm import tqdm
import os
import torch.nn.functional as F
from watermark.old_watermark import BlacklistLogitsProcessor
from watermark.our_watermark import OurBlacklistLogitsProcessor
from watermark.gptwm import GPTWatermarkLogitsWarper
from watermark.watermark_v2 import WatermarkLogitsProcessor
from transformers import  LogitsProcessorList
import logging
from datetime import datetime

def setup_logger(log_file_path=None, console_output=False):
    """Configure and return the shared ``'generate'`` logger.

    The logger level is set to INFO on every call. A UTF-8 file handler is
    attached only when the logger has no handlers yet, so repeated calls do
    not duplicate file output (and a different ``log_file_path`` passed on a
    later call is ignored). The directory that will hold the log file is
    created if it does not exist.

    Args:
        log_file_path: Path of the log file. When ``None``, a timestamped
            name of the form ``./log/generate_YYYYmmdd_HHMMSS.log`` is used.
        console_output: When true, also attach a console (stream) handler,
            unless one is already attached. Defaults to ``False``.

    Returns:
        The configured ``logging.Logger`` named ``'generate'``.
    """
    logger = logging.getLogger('generate')
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Only attach the file handler to a logger that has no handlers yet;
    # otherwise every call would add another handler and duplicate records.
    if not logger.handlers:
        if log_file_path is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            log_file_path = f"./log/generate_{timestamp}.log"

        # FileHandler opens the file immediately and fails if the parent
        # directory is missing, so make sure it exists first.
        log_dir = os.path.dirname(log_file_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        file_handler = logging.FileHandler(log_file_path, encoding='utf-8')
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    if console_output:
        # FileHandler subclasses StreamHandler, so exclude it explicitly when
        # checking whether a console handler is already present.
        has_console_handler = any(
            isinstance(handler, logging.StreamHandler)
            and not isinstance(handler, logging.FileHandler)
            for handler in logger.handlers
        )
        if not has_console_handler:
            console_handler = logging.StreamHandler()
            console_handler.setLevel(logging.INFO)
            console_handler.setFormatter(formatter)
            logger.addHandler(console_handler)

    return logger

# Module-level logger instance, created at import time. With the default
# arguments it logs to a timestamped file only (no console handler).
logger = setup_logger()

class Generator():
    def __init__(self, args, tokenizer, model) -> None:
        self.mode = args.mode # watermark mode
        self.init_seed, self.dyna_seed, self.gamma, \
        self.delta, self.bl_type, self.num_beams, self.sampling_temp = args.initial_seed, args.dynamic_seed, args.gamma, args.delta, args.bl_type, args.num_beams, args.sampling_temp
        self.tokenizer = tokenizer
        self.model = model # language model
        self.model_name = args.model
        
        
        self.all_token_ids = list(tokenizer.get_vocab().values())
        self.vocab_size = len(self.all_token_ids)
        if "qwen" in self.model_name:
            self.vocab_size = 152064
        
        # if self.vocab_size != model.config.padded_vocab_size:
        #     self.vocab_size = model.config.padded_vocab_size
        
        self.bl_processor = BlacklistLogitsProcessor(
                                            bad_words_ids=None, 
                                            eos_token_id=tokenizer.eos_token_id, 
                                            vocab=self.all_token_ids, 
                                            vocab_size=self.vocab_size, 
                                            bl_proportion=1-self.gamma,
                                            bl_logit_bias=self.delta,
                                            bl_type=self.bl_type, 
                                            initial_seed=self.init_seed, 
                                            dynamic_seed=self.dyna_seed)
        self.logit_processor_lst = LogitsProcessorList([self.bl_processor])
        if args.mode == 'new': 
            self.bl_processor = OurBlacklistLogitsProcessor(tokenizer=tokenizer,
                                        bad_words_ids=None, 
                                        eos_token_id=tokenizer.eos_token_id, 
                                        vocab=self.all_token_ids, 
                                        all_vocab_size=self.vocab_size, 
                                        bl_proportion=1-self.gamma,
                                        bl_logit_bias=self.delta,
                                        bl_type=self.bl_type, 
                                        initial_seed=self.init_seed, 
                                        dynamic_seed=self.dyna_seed)
            self.logit_processor_lst = LogitsProcessorList([self.bl_processor])
            
        if args.mode == 'gpt':
            
            watermark_processor = GPTWatermarkLogitsWarper(vocab_size=self.vocab_size,
                                                        fraction=args.gamma,
                                                        strength=args.delta,
                                                        watermark_key=args.wm_key)
            
            self.logit_processor_lst = LogitsProcessorList([watermark_processor])   
            
        if args.mode == 'v2':
            watermark_processor = WatermarkLogitsProcessor(vocab=list(tokenizer.get_vocab().values()),
                                                        gamma=args.gamma,
                                                        delta=args.delta,
                                                        seeding_scheme=args.seeding_scheme,
                                                        select_green_tokens=args.select_green_tokens)
            self.logit_processor_lst = LogitsProcessorList([watermark_processor]) 
            
    def generate(self, input_ids, max_new_tokens):
        if self.mode == 'new':
            example = {}
            
            outputs = self.model.generate(
                input_ids, max_new_tokens=max_new_tokens,
                logits_processor = self.logit_processor_lst,
                do_sample=True,
                top_k=0,
                temperature=self.sampling_temp,
                repetition_penalty=1.0
            )

            example.update({"bl_vocabularys":self.logit_processor_lst[0].get_and_clear_vocabularys()})
            # remove the attached input from output for some model
            scores = outputs.scores
            output_ids = outputs.sequences[0, -len(scores):]


            # compute logprob for each token
            completions_tokens = []
            completions_logprob = 0

            for score, token, vocabulary in zip(scores, output_ids, example["bl_vocabularys"], strict=True):
                logprobs = F.log_softmax(score[0], dim=-1)
                logprob = logprobs[token].item()
                completions_tokens.append({
                    'text': self.tokenizer.decode(token),
                    'logprob': logprob,
                    'vocabulary': vocabulary,
                })
                completions_logprob += logprob
            
            completions_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)
            return completions_text, completions_tokens
        else:    
        
            if self.mode == 'no':
                
                outputs, _ = self.model.generate(
                    input_ids, max_new_tokens=max_new_tokens,
                    return_dict_in_generate=True,
                    output_scores=True,
                )

            elif self.mode == 'old':
                
                outputs, _ = self.model.generate(
                    input_ids, max_new_tokens=max_new_tokens,
                    logits_processor = self.logit_processor_lst,
                    do_sample=True,
                    top_k=0,
                    temperature=self.sampling_temp,
                    return_dict_in_generate=True,
                    output_scores=True,
                    repetition_penalty=1.0
                )


            elif self.mode == 'gpt':
                
                outputs, _ = self.model.generate(
                    input_ids, max_new_tokens=max_new_tokens,
                    logits_processor = self.logit_processor_lst,
                    do_sample=True,
                    top_k=0,
                    top_p=0.9,
                    return_dict_in_generate=True,
                    output_scores=True,
                    repetition_penalty=1.0
                )

            elif self.mode == 'v2':
                outputs, _ = self.model.generate(
                    input_ids, max_new_tokens=max_new_tokens,
                    logits_processor = self.logit_processor_lst,
                    do_sample=True,
                    top_k=0,
                    temperature=self.sampling_temp,
                    return_dict_in_generate=True,
                    output_scores=True,
                    repetition_penalty=1.0
                )
            # remove the attached input from output for some model
            scores = outputs.scores
            output_ids = outputs.sequences[0, -len(scores):]
            # list_data = output_ids.cpu().tolist()
            # compute logprob for each token
            completions_tokens = []
            completions_logprob = 0

            for score, token in zip(scores, output_ids):
                logprobs = F.log_softmax(score[0], dim=-1)
                logprob = logprobs[token].item()
                completions_tokens.append({
                    'text': self.tokenizer.decode(token),
                    'logprob': logprob,
                })
                completions_logprob += logprob
            
            completions_text = self.tokenizer.decode(output_ids, skip_special_tokens=True)
            re_output_ids = self.tokenizer.encode(completions_text, return_tensors="pt", truncation=True, add_special_tokens=False)
            logger.info(f"长度差距:{abs(len(output_ids) - len(re_output_ids[0]))}")
            
            
            return completions_text, completions_tokens