import json
import math
import os
import queue
from collections import Counter

import torch
from torch import nn
from torch.autograd import Variable
from torch.optim import lr_scheduler
from transformers import LogitsProcessor

import numpy as np

def softmax(logits):
    """Return a numerically stable softmax of *logits* as a NumPy array.

    The input is converted with ``np.array`` and normalized over ALL of its
    elements (the maximum and the sum are taken over the flattened array, not
    per row). The global maximum is subtracted before exponentiating so that
    large inputs do not overflow.
    """
    arr = np.array(logits)
    shifted = arr - arr.max()
    weights = np.exp(shifted)
    return weights / weights.sum()

   
            
class BasedOnProbabilityTransferLogits_Main_FP32_Processor(nn.Module):
    """Logits processor that ensembles a main model with assist models.

    Each assist model is mapped into the main model's vocabulary through a
    probability transfer matrix.

    On every decoding step ``__call__`` receives the main model's scores and
    reads one logits tensor per assist model from
    ``assist_model_score_queue_list``. Each assist distribution is projected
    with the matching sparse matrix from
    ``assist_model_probability_transfer_matrix_list`` and averaged with the
    main distribution. The averaging uses either the fixed ``ensemble_weight``
    or per-step weights derived from an uncertainty measure (``unc``). The
    chosen token id is put on ``ensemble_model_output_ids_queue`` and
    log-probabilities are returned.

    Parameter notes (as used below):
        topk: -1 -> full softmax; -2 -> top-``k`` candidates cut with a nucleus
            threshold ``top_p``; 0 -> keep logits >= max - std (at least ``k``);
            > 0 -> plain top-``topk``.
        select: 2 forces a full softmax for the main model, 1 for assist models.
        unc: uncertainty measure for per-model weights ("entropy", "logit",
            "energy", "inv_energy", "margin", "au", "renyi"); any other value
            uses ``ensemble_weight``. ``k`` (> 0) restricts the measure to the
            top-k logits; ``l`` is its exponent / temperature / Renyi order.
        top_p: nucleus threshold for ``topk == -2`` and also the selector of
            the weight-normalization scheme (see ``_normalize_weights``).
        learning_epochs_nums, main_model_probability_transfer_matrix_list,
        assist_model_probability_reverse_transfer_matrix_list: accepted for
            interface compatibility; not used by this class.
    """

    # Uncertainty modes whose per-model weights are normalized and used for the
    # weighted average; any other ``unc`` falls back to ``ensemble_weight``.
    _WEIGHTED_UNC_MODES = ("logit", "entropy", "energy", "inv_energy", "margin", "renyi")

    def __init__(self, learning_rate, learning_epochs_nums, ensemble_weight,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list, assist_model_probability_reverse_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, early_stop_string_list=None, topk=-1, unc="None", k=10, l=1, select=0, top_p=0.9, early_exit=False):
        super().__init__()

        self.learning_rate = learning_rate
        self.ensemble_weight = ensemble_weight
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.assist_model_score_queue_list = assist_model_score_queue_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer_list = assist_model_tokenizer
        self.device = device
        self.device_compute = device_compute
        self.early_stop_string_list = early_stop_string_list
        self.topk = topk
        self.unc = unc
        self.k = k
        # Previously never stored, so every mode reading ``self.l`` crashed.
        self.l = l
        self.select = select
        self.top_p = top_p
        self.early_exit = early_exit

    # ------------------------------------------------------------------ helpers

    def _collect_assist_logits(self):
        """Read one logits tensor from every assist queue (5 s timeout each).

        Returns ``(logits_list, main_only)``. A queue that times out yields
        ``None`` and sets ``main_only``; an empty queue list also sets it.
        """
        logits_list = []
        main_only = False
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                logits_list.append(queue_instance.get(block=True, timeout=5))
            except queue.Empty:
                print(f"Aux model {index} not received")
                logits_list.append(None)
                main_only = True
        if not logits_list:
            main_only = True
        return logits_list, main_only

    def _emit_main_only(self, scores):
        """Publish the main model's greedy token and return ``scores`` unchanged."""
        next_tokens_id = torch.argmax(scores, dim=-1)
        self.ensemble_model_output_ids_queue.put(next_tokens_id)
        return scores

    def _nucleus_keep_mask(self, topk_values):
        """Return ``(probs, keep_mask)`` for a nucleus cut over sorted top-k logits.

        ``probs`` is the softmax of ``topk_values``. ``keep_mask`` keeps every
        candidate up to and including the first one whose cumulative
        probability reaches ``self.top_p``; the top candidate is always kept.
        """
        probs = nn.functional.softmax(topk_values, dim=-1)
        keep_mask = torch.cumsum(probs, dim=-1) < self.top_p
        # Shift right by one so the token that crosses the threshold is kept too.
        keep_mask[:, 1:] = keep_mask[:, :-1].clone()
        keep_mask[:, 0] = True
        return probs, keep_mask

    def _std_threshold_mask(self, logits):
        """Set to -inf every logit below ``max - std``, keeping at least ``self.k`` entries.

        If fewer than ``k`` logits pass the std threshold in any row, the
        threshold drops to the k-th largest logit.
        """
        threshold = torch.max(logits, dim=-1, keepdim=True).values - torch.std(logits, dim=-1, keepdim=True)
        selected_count = torch.sum(logits >= threshold, dim=-1, keepdim=True)
        if selected_count.min().item() < self.k:
            threshold = torch.topk(logits, k=self.k, dim=-1).values[:, -1].unsqueeze(-1)
        return torch.where(logits >= threshold, logits, torch.full_like(logits, float('-inf')))

    def _filter_main_probs(self, logits):
        """Convert the main model's fp32 logits to probabilities per ``topk``/``select``.

        Raises:
            ValueError: if ``topk`` is below -2 (no filtering scheme defined).
        """
        if self.topk == -1 or self.select == 2:
            return nn.functional.softmax(logits, dim=-1)
        if self.topk == -2:
            topk_values, topk_indices = torch.topk(logits, k=self.k, dim=-1)
            _, keep_mask = self._nucleus_keep_mask(topk_values)
            masked_logits = torch.full_like(logits, float('-inf'))
            # Dropped candidates must stay at -inf. Multiplying by the mask (as
            # before) set them to logit 0 and leaked probability mass to them.
            masked_logits.scatter_(dim=-1, index=topk_indices,
                                   src=topk_values.masked_fill(~keep_mask, float('-inf')))
            return nn.functional.softmax(masked_logits, dim=-1)
        if self.topk == 0:
            return nn.functional.softmax(self._std_threshold_mask(logits), dim=-1)
        if self.topk > 0:
            topk_values, topk_indices = torch.topk(logits, k=self.topk, dim=-1)
            masked_logits = torch.full_like(logits, float('-inf'))
            masked_logits.scatter_(dim=-1, index=topk_indices, src=topk_values)
            return nn.functional.softmax(masked_logits, dim=-1)
        raise ValueError(f"unsupported topk value: {self.topk}")

    def _project_assist_probs(self, logits, transfer_matrix):
        """Filter an assist model's fp32 logits and project them into the main vocabulary.

        Computes ``probs @ transfer_matrix`` via ``torch.sparse.mm``, so
        ``transfer_matrix`` must have shape (assist width, main width).
        ``logits`` must already be on ``transfer_matrix.device``.
        """
        if self.topk == -1 or self.select == 1:
            probs = nn.functional.softmax(logits, dim=-1)
        elif self.topk == -2:
            values, indices = torch.topk(logits, k=self.k, dim=-1)
            prob_values, keep_mask = self._nucleus_keep_mask(values)
            selected_probs = prob_values * keep_mask
            selected_probs = selected_probs / selected_probs.sum(dim=-1, keepdim=True)
            probs = torch.zeros_like(logits)
            probs.scatter_(1, indices, selected_probs)
        elif self.topk == 0:
            # Threshold on the ASSIST logits (the main model's logits were used before).
            probs = nn.functional.softmax(self._std_threshold_mask(logits), dim=-1)
        else:
            values, indices = torch.topk(logits, k=self.topk, dim=-1)
            probs = torch.zeros_like(logits)
            probs.scatter_(1, indices, nn.functional.softmax(values, dim=-1))
        return torch.sparse.mm(transfer_matrix.T, probs.T).T

    def _uncertainty_probs(self, logits, fallback_probs):
        """Return the distribution used by entropy-style measures.

        With ``k > 0`` this is the softmax of the top-k logits; otherwise it is
        ``fallback_probs`` (the model's filtered/projected distribution).
        """
        if self.k > 0:
            return nn.functional.softmax(torch.topk(logits, k=self.k, dim=-1).values, dim=-1)
        return fallback_probs

    def _entropy(self, logits, fallback_probs):
        """Return the Shannon entropy (summed over all elements) used by ``unc == "entropy"``."""
        probs = self._uncertainty_probs(logits, fallback_probs)
        return -torch.sum(probs * torch.log(probs + 1e-9))

    def _model_weight(self, logits, fallback_probs):
        """Return this model's raw ensemble weight for ``self.unc``, or None for other modes.

        "entropy"/"renyi" give ``1 / uncertainty``. "logit" gives the L2 norm
        and "energy" the logsumexp, each raised to ``l``. "inv_energy" gives
        ``sigmoid(l * logsumexp(values / l))``. "margin" gives
        ``(first - last) ** l``. "au" gives ``1 / AU``. Scores use the top-k
        logits when ``k > 0``, else all logits.
        """
        unc = self.unc
        if unc == "entropy":
            return (1.0 / (self._entropy(logits, fallback_probs) + 1e-9)).item()
        if unc == "renyi":
            probs = self._uncertainty_probs(logits, fallback_probs)
            return (1.0 / (renyi_entropy(probs, self.l) + 1e-9)).item()
        if unc not in ("logit", "energy", "inv_energy", "margin", "au"):
            return None

        values = torch.topk(logits, k=self.k, dim=-1).values if self.k > 0 else logits
        if unc == "logit":
            return torch.norm(values, p=2).item() ** self.l
        if unc == "energy":
            return torch.logsumexp(values, dim=-1).item() ** self.l
        if unc == "inv_energy":
            energy = -self.l * torch.logsumexp(values / self.l, dim=-1)
            return (1 / (1 + torch.exp(energy))).item()
        if unc == "margin":
            return (values[..., 0] - values[..., -1]) ** self.l
        # "au": Dirichlet-style aleatoric uncertainty, treating logits as alphas.
        # Not in _WEIGHTED_UNC_MODES, so ensemble_weight is used for averaging.
        alpha_0 = torch.sum(values, dim=-1, keepdim=True)
        au = -torch.sum((values / alpha_0) * (torch.digamma(values + 1) - torch.digamma(alpha_0 + 1)), dim=-1)
        return 1.0 / (au + 1e-9)

    def _normalize_weights(self, weights):
        """Normalize raw per-model weights (index 0 = main model) according to ``top_p``.

        - ``top_p == 0``: plain normalization to sum 1.
        - ``top_p >= 1``: normalize; if the main share is below 0.5, pin it to
          0.5 and rescale the assists to share the other 0.5. When ``top_p > 1``
          and the main share is already >= 0.5, use the main model alone.
        - otherwise: main fixed at 0.5, assists rescaled to share 0.5.
        """
        cap = 0.5
        if self.top_p == 0.0:
            total = sum(weights)
            return [w / total for w in weights]
        if self.top_p >= 1.0:
            total = sum(weights)
            weights = [w / total for w in weights]
            if weights[0] < cap:
                assist_total = sum(weights[1:])
                weights = [(w / assist_total) * (1 - cap) for w in weights]
                weights[0] = cap
            elif self.top_p > 1.0:
                weights = [1.0] + [0.0] * (len(weights) - 1)
            return weights
        assist_total = sum(weights[1:])
        weights = [(w / assist_total) * cap for w in weights]
        weights[0] = cap
        return weights

    def _combine(self, probs_list, model_weights, main_probs):
        """Weighted sum of the per-model distributions, truncated to the narrowest width."""
        min_last_dim = min(p.shape[-1] for p in probs_list)
        # Truncate in BOTH branches; the weighted branch used to skip this and
        # crashed when the projected widths differed.
        probs_list = [p[:, :min_last_dim] for p in probs_list]
        average_probs = torch.zeros_like(main_probs[:, :min_last_dim])
        if self.unc in self._WEIGHTED_UNC_MODES:
            weights = self._normalize_weights(model_weights)
            print(weights)
        else:
            weights = self.ensemble_weight
        for weight, probs in zip(weights, probs_list):
            average_probs += weight * probs.to(average_probs.device)
        return average_probs

    # ---------------------------------------------------------------- entry point

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Fuse one decoding step.

        Returns ``scores`` unchanged (after publishing the main model's argmax)
        in main-only mode. That mode applies when an assist queue times out,
        the main model's argmax is EOS, an early-stop string matched, or
        entropy early exit fires. Otherwise it returns the log of the ensemble
        probabilities on ``self.device``. The argmax of the ensemble is always
        put on ``ensemble_model_output_ids_queue``.
        """
        with torch.no_grad():
            ensemble_process_file_path = os.path.join(self.result_save_dir, f"ensemble_lr{self.learning_rate}_log.json")
            # Never populated here, so each ensemble step appends "{}" to the log.
            json_object = {}

            assist_logits_list, main_model_only_flag = self._collect_assist_logits()

            if torch.argmax(scores).item() == self.main_model_tokenizer.eos_token_id:
                main_model_only_flag = True
            cur_text = self.main_model_tokenizer.decode(input_ids.tolist()[0])
            if check_early_stop(cur_text, self.early_stop_string_list):
                scores[:, self.main_model_tokenizer.eos_token_id] = float('inf')
                main_model_only_flag = True
            if main_model_only_flag:
                return self._emit_main_only(scores)

            main_logits = scores.to(self.device_compute).to(torch.float32)
            main_probs = self._filter_main_probs(main_logits)
            probs_list = [main_probs]
            model_weights = []
            main_weight = self._model_weight(main_logits, main_probs)
            if main_weight is not None:
                model_weights.append(main_weight)

            if (self.unc == "entropy" and self.early_exit
                    and self._entropy(main_logits, main_probs) < 0.2 * torch.log(torch.tensor(float(self.k)))):
                return self._emit_main_only(scores)

            for assist_logits, transfer_matrix in zip(assist_logits_list,
                                                      self.assist_model_probability_transfer_matrix_list):
                if assist_logits is None:
                    continue
                # Upcast once so scatter_/sparse.mm never mix dtypes with the fp32 buffers.
                assist_logits = assist_logits.to(transfer_matrix.device).to(torch.float32)
                assist_probs = self._project_assist_probs(assist_logits, transfer_matrix)
                assist_weight = self._model_weight(assist_logits, assist_probs)
                if assist_weight is not None:
                    model_weights.append(assist_weight)
                probs_list.append(assist_probs)

            average_probs = self._combine(probs_list, model_weights, main_probs)

            next_tokens_id = torch.argmax(average_probs, dim=-1)
            self.ensemble_model_output_ids_queue.put(next_tokens_id)

            with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
                process_file.write(json.dumps(json_object, ensure_ascii=False) + "\n")
            eps = torch.finfo(average_probs.dtype).eps
            average_logits = torch.log(average_probs + eps)
            return average_logits.to(self.device).detach()

