import torch
from transformers import LogitsWarper
import torch.nn.functional as F
from transformers import AutoTokenizer
from transformers import (LogitsWarper)
import torch.nn.functional as F
import torch
import logging
import os

def set_logger(cfg, save_dir):
    """Set up file logging under ``<save_dir>/log`` and return this module's logger.

    The log file is named ``I4_<cfg>.log`` and is opened in append mode
    (``a+``) at INFO level. The directory is created if it does not exist.

    Args:
        cfg: Value interpolated into the log file name and into the first
            log line.
        save_dir: Parent directory; a ``log`` sub-directory is created in it.

    Returns:
        The ``logging.Logger`` named after this module.

    Note:
        ``logging.basicConfig`` only configures the root logger when it has
        no handlers yet, so a second call (or a call after some other code
        configured logging) does not switch the output file.
    """
    log_dir = os.path.join(save_dir, "log")
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, f'I4_{cfg}.log')

    logging.basicConfig(
        filename=log_path,
        level=logging.INFO,
        filemode='a+',
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )

    module_logger = logging.getLogger(__name__)
    module_logger.info(f"*****cfg: {cfg}*****")
    # Tell the user on stdout where the log ends up.
    print(f"save log file to {log_path}")
    return module_logger


class CFGLogits_v2(LogitsWarper):
    """Logits warper that mixes the main logits with a "control" branch.

    At every decoding step a second forward pass (the control branch) is run
    on ``model`` and the two log-probability distributions are combined as
    ``cfg * (control - origin) + origin`` and re-normalised with
    ``log_softmax``.

    The control branch keeps its own KV cache in ``self.out``; the cache is
    never reset here, so one instance serves a single generation run.

    Args:
        cfg: Guidance scale. ``0.0`` disables the control branch entirely.
        control_input_ids: Passed as ``input_ids`` to ``aligned_vision_model``
            on the first step.
        control_images: Passed as ``images`` to ``aligned_vision_model`` on
            the first step.
        model: Object exposing ``forward_manual_inputs_embeds(...)`` and being
            callable as ``model(ids, use_cache=..., past_key_values=...)``;
            both must return an object with ``.logits`` and
            ``.past_key_values``.
        aligned_vision_model: Callable returning the 5-tuple
            ``(input_ids, attention_mask, past_key_values, inputs_embeds,
            labels)``.
        tokenizer: Only used for verbose logging (``decode`` /
            ``batch_decode``). When ``None``, verbose logging is skipped.
        verbose: Log the top-p candidate tokens of every distribution.
    """

    def __init__(self, cfg, control_input_ids, control_images, model, aligned_vision_model, tokenizer=None, verbose=True):
        self.cfg = cfg
        # NOTE(review): original comment said "input_embeds from vanilla clip
        # branch" — confirm what the caller actually passes here.
        self.control_input_ids = control_input_ids
        self.control_images = control_images
        self.model = model
        self.aligned_vision_model = aligned_vision_model
        self.out = None  # last control-branch model output (holds the KV cache)
        self.verbose = verbose
        self.tokenizer = tokenizer
        # Always available, so toggling ``verbose`` later cannot hit a
        # missing attribute.
        self.logger = logging.getLogger(__name__)
        if self.verbose and self.tokenizer is None:
            self.logger.warning(
                "CFGLogits_v2: verbose=True but no tokenizer was given; "
                "top-p logging is disabled."
            )

    def __call__(self, input_ids, logits):
        """Return guided log-probabilities for the current decoding step.

        Args:
            input_ids: Token ids generated so far; the last column is fed to
                the control branch on every step after the first.
            logits: Next-token scores of the main branch
                (assumed ``(batch, vocab)`` — TODO confirm against caller).

        Returns:
            Log-probabilities with the same shape as ``logits``.
        """
        # Vanilla branch: work in log-probability space throughout.
        logits = F.log_softmax(logits, dim=-1)
        if self.cfg == 0.0:
            return logits

        # Control (cfg) branch.
        if self.out is None:
            # First step: build the control prompt embeddings. Dedicated
            # local names are used so the caller's ``input_ids`` is not
            # clobbered (the returned ids are typically None) and the
            # first step can still be logged.
            ctrl_ids, ctrl_mask, ctrl_past, ctrl_embeds, _ctrl_labels = self.aligned_vision_model(
                input_ids=self.control_input_ids,
                images=self.control_images,
                use_cache=True,
            )
            assert ctrl_embeds is not None, "aligned_vision_model returned no inputs_embeds"
            self.out = self.model.forward_manual_inputs_embeds(
                input_ids=ctrl_ids,
                attention_mask=ctrl_mask,
                past_key_values=ctrl_past,
                inputs_embeds=ctrl_embeds,
                use_cache=True,
            )
        else:
            # Later steps: feed only the newest token and reuse the cache.
            self.out = self.model(
                input_ids[:, -1:],
                use_cache=True,
                past_key_values=self.out.past_key_values,
            )

        # Last position of every row -> (batch, vocab). Assumes the model
        # logits are (batch, seq, vocab). The device move is applied for
        # every batch size; previously the single-row path skipped it.
        control_logits = F.log_softmax(self.out.logits[:, -1, :], dim=-1).to(logits.device)

        out = self.cfg * (control_logits - logits) + logits
        out = F.log_softmax(out, dim=-1)

        # Logging needs the tokenizer to decode candidates; skip without one.
        if self.verbose and self.tokenizer is not None:
            self.log_topp(input_ids, [logits, control_logits, out], ["origin", "control", "output"])

        return out

    def log_topp(self, input_ids, logits_ls, name_ls):
        """Log the decoded input and the top-p (0.8) candidates per distribution.

        Args:
            input_ids: Batch of token ids; ``None`` makes this a no-op.
            logits_ls: Distributions to inspect, each ``(batch, vocab)``.
            name_ls: Display name for each entry of ``logits_ls``.
        """
        # Nothing to log against: bail out before any sorting/decoding work.
        if input_ids is None:
            return

        candidates = [self.caculate_topp_kept(lgt, top_p=0.8)[2] for lgt in logits_ls]

        for i, row_ids in enumerate(input_ids):
            self.logger.info("******input******")
            # Keep only strictly positive ids before decoding (drops e.g.
            # negative placeholder ids and id 0).
            filtered_input_ids = row_ids[row_ids > 0]
            self.logger.info(self.tokenizer.decode(filtered_input_ids, skip_special_tokens=True))
            self.logger.info("******output******")
            for j, (name, detokenize) in enumerate(zip(name_ls, candidates)):
                self.logger.info("%s,%s,%s", i, j, name)
                self.logger.info("detokenize: %s", detokenize[i])
            self.logger.info("")

    def caculate_topp_kept(self, scores, top_p=0.5, min_tokens_to_keep=1, filter_value=-float("Inf")):
        """Apply nucleus (top-p) filtering and decode the surviving tokens.

        Args:
            scores: ``(batch, vocab)`` scores; softmax is applied internally.
            top_p: Cumulative probability mass to keep.
            min_tokens_to_keep: Number of highest-scoring tokens always kept.
            filter_value: Value written over removed entries.

        Returns:
            Tuple ``(scores, kept_num, detokenize)``: the filtered scores, a
            ``(batch,)`` tensor with the number of kept tokens per row, and a
            per-row list of decoded kept tokens (highest score first).
        """
        # Ascending sort: the low-probability tail comes first.
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # Drop the tail whose cumulative mass stays within (1 - top_p).
        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
        # Keep at least min_tokens_to_keep of the highest-scoring tokens.
        sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0

        # Scatter the mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores = scores.masked_fill(indices_to_remove, filter_value)

        # Per-row count of surviving tokens (supports batch size > 1).
        kept_num = (sorted_indices_to_remove == 0).sum(dim=-1)

        detokenize = []
        for row_scores, row_kept in zip(scores, kept_num):
            kept_idx = row_scores.topk(int(row_kept)).indices
            # batch_decode over a 1-D id tensor decodes each token separately.
            detokenize.append(self.tokenizer.batch_decode(kept_idx, skip_special_tokens=True))

        return scores, kept_num, detokenize

