# from transformers import (GPT2Tokenizer, AutoModelForCausalLM,
#                           GPTNeoXForCausalLM, AutoTokenizer)
# import numpy as np
# import torch
# from transformers import (LogitsProcessor, LogitsProcessorList,
#                           MinLengthLogitsProcessor, TemperatureLogitsWarper,
#                           LogitsWarper)
# from transformers.generation import LogitNormalization
# import torch.nn.functional as F

# class CFGLogits(LogitsWarper):

#     def __init__(self, cfg, control, model, cfg_mode="text", verbose=True):
#         self.cfg = cfg
#         self.control = control.cuda()
#         self.model = model
#         self.out = None
#         self.verbose = verbose
#         self.cfg_mode = cfg_mode

#     def __call__(self, input_ids, logits):
#         # if self.cfg == 1:
#         #     return F.log_softmax(logits, dim=-1)
#         logits = F.log_softmax(logits, dim=-1)
#         if self.out is None:
#             # if self.cfg_mode == "text":
#             self.out = self.model(self.control, use_cache=True)
#             # elif self.cfg_mode == "image":
#             #     # import pdb
#             #     self.out = self.model.forward_manual_inputs_embeds(inputs_embeds=self.control, use_cache=True)
#         else:
#             self.out = self.model(input_ids[:, -1:],
#                                   use_cache=True,
#                                   past_key_values=self.out.past_key_values)
#         # if self.cfg_mode == "text":
#         control_logits = F.log_softmax(self.out.logits[0][-1:], dim=-1).to(logits.device)
#         # elif self.cfg_mode == "image":
#         #     if "logits" not in self.out:
#         #         print(self.out)
#         #         raise ValueError("self.out does not have logits!")
#         #     control_logits = F.log_softmax(self.out.logits[0][-1:], dim=-1)

#         out = self.cfg * (control_logits - logits) + logits
#         out = F.log_softmax(out, dim=-1)

#         # # usually the kept_num is below 200
#         # scores, kept_num, kept_idx = caculate_topp_kept(logits)
#         # control_scores, control_kept_num, control_kept_idx = caculate_topp_kept(control_logits)
#         # out_scores, out_kept_num, out_kept_idx = caculate_topp_kept(out)
#         # print(f"kept num for logits and control_logits: {kept_num} {control_kept_num} {out_kept_num}")
#         return out #0.7 * out + 0.3 * scores


from transformers import (LogitsWarper)
import torch.nn.functional as F
import torch
import logging
import os

def set_logger(cfg, save_dir):
    """Configure INFO-level logging to ``<save_dir>/log/I4_<cfg>.log`` and return a logger.

    The ``log`` sub-directory is created if it does not exist. Configuration is done
    through ``logging.basicConfig`` in append mode. Per the stdlib docs, basicConfig
    does nothing if the root logger already has handlers.

    Args:
        cfg: value embedded in the log file name and in the banner line.
        save_dir: directory under which the ``log`` folder is created.

    Returns:
        The logger for this module, after it has logged a ``cfg`` banner.
    """
    log_dir = os.path.join(save_dir, "log")
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, f'I4_{cfg}.log')

    logging.basicConfig(
        filename=log_path,
        level=logging.INFO,
        filemode='a+',
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )

    module_logger = logging.getLogger(__name__)
    module_logger.info(f"*****cfg: {cfg}*****")
    print(f"save log file to {log_path}")
    return module_logger

class CFGLogits(LogitsWarper):
    """Classifier-free-guidance logits warper driven by a separate control prompt.

    At every generation step the wrapped ``model`` is also run on a control
    sequence. On the first step this is ``control`` together with ``images``.
    Afterwards it is only the newest generated token, reusing the cached
    ``past_key_values``.

    Let ``cond`` be the log-softmaxed scores of the main branch and ``ctrl``
    those of the control branch. The warper returns::

        log_softmax(cfg * (ctrl - cond) + cond)

    So ``cfg == 0`` leaves the main distribution unchanged, and ``cfg == 1``
    yields the control distribution. ``cfg > 1`` extrapolates past the control
    distribution, away from the main one.

    The control branch's last model output, including its KV cache, is kept in
    ``self.out`` across calls. An instance is therefore tied to a single
    generation run; create a fresh instance for every ``generate`` call.
    """

    def __init__(self, cfg, control, images, model, tokenizer=None, verbose=True):
        """
        Args:
            cfg: guidance scale used in the combination formula above.
            control: token ids of the control prompt, passed to ``model`` as
                ``input_ids`` on the first step.
            images: image inputs passed to ``model`` together with ``control``
                on the first step.
            model: callable returning an object with ``.logits`` and
                ``.past_key_values``.
            tokenizer: used only for verbose logging (``decode`` / ``batch_decode``).
            verbose: log the top-p nucleus of the original, control and guided
                distributions at every step. Requires ``tokenizer``; without one,
                logging is skipped and a warning is emitted.
        """
        self.cfg = cfg
        # Keep the original behaviour (move to GPU) when CUDA exists, but do not
        # crash on CPU-only hosts where .cuda() would raise.
        self.control = control.cuda() if torch.cuda.is_available() else control
        self.images = images
        self.model = model
        self.out = None  # control-branch output of the previous step (holds the KV cache)
        self.verbose = verbose
        self.tokenizer = tokenizer
        self.logger = logging.getLogger(__name__)
        if verbose and tokenizer is None:
            # Verbose logging decodes token ids; without a tokenizer it would crash.
            self.logger.warning(
                "CFGLogits: verbose=True requires a tokenizer; top-p logging is disabled."
            )

    def __call__(self, input_ids, logits):
        """Return CFG-adjusted next-token log-probabilities.

        Args:
            input_ids: ids generated so far by the main branch, indexed
                ``[batch, seq]``. After the first step only the last column is fed
                to the control branch.
            logits: next-token scores of the main branch. Assumed
                ``[batch, vocab]`` (TODO confirm against the generate loop).

        Returns:
            Normalized log-probabilities with the same shape as ``logits``.
        """
        logits = F.log_softmax(logits, dim=-1)
        if self.out is None:
            # First step: encode the full control prompt. It may contain the
            # IMAGE_TOKEN_INDEX == -200 placeholder that ``images`` fills in.
            self.out = self.model(input_ids=self.control, images=self.images, use_cache=True)
        else:
            # Later steps: feed only the newest token and reuse the control KV cache.
            self.out = self.model(input_ids[:, -1:],
                                  use_cache=True,
                                  past_key_values=self.out.past_key_values)

        # self.out.logits is [batch_size, seq_len, vocab_size]. Taking the last
        # position gives [batch_size, vocab_size] for batch size 1 and >1 alike.
        # Moving to logits.device keeps the arithmetic below on one device.
        control_logits = F.log_softmax(self.out.logits[:, -1, :], dim=-1).to(logits.device)

        out = self.cfg * (control_logits - logits) + logits
        out = F.log_softmax(out, dim=-1)

        if self.verbose and self.tokenizer is not None:
            self.log_topp(input_ids,
                          [logits, control_logits, out],
                          ["origin", "control", "output"])

        return out

    def log_topp(self, input_ids, logits_ls, name_ls):
        """Log, per batch row, the decoded input and the top-p (p=0.8) nucleus of each distribution.

        Args:
            input_ids: ``[batch, seq]`` token ids. Non-positive ids (e.g. the -200
                image placeholder) are dropped before decoding.
            logits_ls: log-probability tensors, each indexed ``[batch, vocab]``.
            name_ls: labels for ``logits_ls``, used in the log lines.
        """
        stats = [self.caculate_topp_kept(lgt, top_p=0.8) for lgt in logits_ls]

        for i, row_ids in enumerate(input_ids):
            self.logger.info("******input******")
            self.logger.info(self.tokenizer.decode(row_ids[row_ids > 0], skip_special_tokens=True))
            self.logger.info("******output******")
            for j, (name, (scores, kept_num, detokenize)) in enumerate(zip(name_ls, stats)):
                k = int(kept_num[i])
                self.logger.info(f"{i},{j},{name}, scores: {scores[i].topk(k)}")
                self.logger.info(f"detokenize: {detokenize[i]}")
                # scores are masked log-probabilities, so exp() yields the kept
                # probabilities (masked entries become 0).
                self.logger.info(f"{torch.exp(scores[i]).topk(k)}")
            self.logger.info("")

    def caculate_topp_kept(self, scores, top_p=0.5, min_tokens_to_keep=1, filter_value=-float("Inf")):
        """Apply top-p (nucleus) filtering to ``scores`` and report the surviving tokens.

        Args:
            scores: ``[batch, vocab]`` scores. They are softmaxed internally to
                obtain probabilities.
            top_p: cumulative probability mass to keep.
            min_tokens_to_keep: the most likely tokens always kept per row.
            filter_value: value written into removed positions.

        Returns:
            ``(filtered_scores, kept_num, detokenize)``:
            - ``filtered_scores``: ``scores`` with removed positions set to
              ``filter_value``.
            - ``kept_num``: ``[batch]`` count of kept tokens per row.
            - ``detokenize``: per row, the kept tokens decoded with
              ``tokenizer.batch_decode``, in descending score order.
        """
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # Sorted ascending: drop the low-probability head whose cumulative mass is
        # at most 1 - top_p. This matches the rule used by transformers' top-p warper.
        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
        # Always keep the min_tokens_to_keep most likely tokens (the tail). The guard
        # avoids `[..., -0:]`, which would select, and therefore keep, every token.
        if min_tokens_to_keep > 0:
            sorted_indices_to_remove[..., -min_tokens_to_keep:] = False

        # Map the removal mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        scores = scores.masked_fill(indices_to_remove, filter_value)

        kept_num = (~sorted_indices_to_remove).sum(dim=-1)

        # Kept entries are finite and removed ones are filter_value, so the top
        # kept_num entries of each row are exactly the kept tokens.
        detokenize = [
            self.tokenizer.batch_decode(row.topk(int(k)).indices, skip_special_tokens=True)
            for row, k in zip(scores, kept_num)
        ]
        return scores, kept_num, detokenize

    # Correctly spelled alias; the original (misspelled) name is kept for callers.
    calculate_topp_kept = caculate_topp_kept

