import json
import logging
import math
import os
import pdb
import queue

import torch
from torch import nn
from torch.autograd import Variable
from torch.optim import lr_scheduler
from transformers import LogitsProcessor

from src.calculate.block_cosine_similarity import kl_divergence



class YiPPLBasedOnProbabilityTransferLogitsPIQAProcessor(LogitsProcessor):
    """Force the lower-perplexity answer token (" A" or " B") of a two-option question.

    Each call logs the main model's top-10 scores, then waits up to 5 seconds per queue
    for one score tensor from every auxiliary model in ``assist_model_score_queue_list``.
    When all auxiliary scores arrive and ``learning_rate`` is not (near) zero, the main
    model's logits are refined first:
      1. Every model's softmax probabilities are projected through its probability transfer matrix.
      2. The projections are averaged with equal weights.
      3. The main model's logits take ``learning_epochs_nums`` AdamW steps that minimise
         ``KLDivLoss`` between the main projection and that average.
    Otherwise the raw main-model scores are used. Finally the answer token with the lowest
    perplexity is forced for this generation step.
    """

    # Main-model vocabulary ids of the "▁A" and "▁B" answer tokens, in option order.
    # NOTE(review): presumably ids from the Yi tokenizer — confirm against main_model_tokenizer.
    _OPTION_TOKEN_IDS = (647, 690)

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        # AdamW learning rate for the logit refinement; |lr| <= 1e-6 disables ensembling.
        self.learning_rate = learning_rate
        # Not read by this class's own methods.
        self.anchor_point_count = anchor_point_count
        # One queue per auxiliary model; each delivers that model's scores for the current step.
        self.assist_model_score_queue_list = assist_model_score_queue_list
        # Number of AdamW steps applied to the main model's logits.
        self.learning_epochs_nums = learning_epochs_nums
        # Not read by this class's own methods.
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        # Only element [0] is used, as the main model's transfer matrix.
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        # Paired positionally with assist_model_score_queue_list.
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        # Directory receiving the per-step ensemble log file.
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        # Device of the returned scores and of the one-hot option targets.
        self.device = device
        # Device on which the ensemble refinement runs.
        self.device_compute = device_compute
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    def calculate_ppl(self, logits, labels):
        """Return the perplexity of ``labels`` under ``logits`` as a Python float.

        ``cross_entropy`` applies log-softmax to ``logits`` itself. It accepts either class
        indices or per-class target probabilities as ``labels``; this class passes one-hot
        rows. The loss is averaged over the batch before being exponentiated, so a batch
        yields a single perplexity.
        """
        neg_log_likelihood = torch.nn.functional.cross_entropy(logits, labels, reduction='mean')
        return torch.exp(neg_log_likelihood).item()

    @staticmethod
    def _append_log(log_path, text):
        """Append ``text`` to the UTF-8 ensemble log at ``log_path``."""
        with open(log_path, "a+", encoding="utf-8") as process_file:
            process_file.write(text)

    def _collect_assist_scores(self, log_path):
        """Fetch one score tensor from each auxiliary model's queue, logging its top-10.

        Returns ``(assist_scores_list, all_received)``. A queue that stays empty for
        5 seconds contributes ``None`` and makes ``all_received`` False.
        """
        assist_scores_list = []
        all_received = True
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                self._append_log(log_path, f"aux model{index}【not received】\n")
                assist_scores_list.append(None)
                all_received = False
                continue
            assist_scores_list.append(value)
            self._append_log(log_path, f"aux model{index} top10:\n" + str(torch.topk(value, 10)) + "\n")
        return assist_scores_list, all_received

    def _ensemble_main_logits(self, scores, assist_scores_list, log_path):
        """Return the main model's logits after pulling them towards the ensemble average.

        The refinement runs on ``self.device_compute`` and the detached result is returned
        there. It uses a cosine learning-rate schedule with ``T_max=10`` whose floor is a
        quarter of ``learning_rate``.
        """
        main_matrix = self.main_model_probability_transfer_matrix_list[0]

        # Project every model's next-token distribution through its transfer matrix.
        with torch.no_grad():
            main_probs = nn.functional.softmax(scores, dim=-1).float()
            relative_probs_list = [torch.mm(main_probs, main_matrix).to(self.device_compute)]
            for assist_scores, assist_matrix in zip(assist_scores_list,
                                                    self.assist_model_probability_transfer_matrix_list):
                assist_probs = nn.functional.softmax(assist_scores, dim=-1).float()
                relative_probs_list.append(torch.mm(assist_probs, assist_matrix).to(self.device_compute))

        # Equal-weight average of all projections, main model included.
        model_count = len(relative_probs_list)
        model_weights = [1 / model_count] * model_count
        average_probs = torch.zeros_like(relative_probs_list[0])
        for weight, probs in zip(model_weights, relative_probs_list):
            average_probs += weight * probs
        self._append_log(log_path, f"model_weights:{model_weights}\n")

        main_matrix = main_matrix.to(self.device_compute)
        logits = scores.detach().to(self.device_compute).clone().requires_grad_(True)
        criterion = nn.KLDivLoss()
        optimizer = torch.optim.AdamW(params=[logits], lr=self.learning_rate, betas=(0.9, 0.999))
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=self.learning_rate / 4)
        # enable_grad restores the caller's grad mode on exit, even if a step raises.
        # A bare torch.set_grad_enabled(True)/(False) pair would not.
        with torch.enable_grad():
            for _ in range(self.learning_epochs_nums):
                probs = nn.functional.softmax(logits, dim=-1).float()
                # NOTE(review): torch.log yields -inf (and NaN gradients) if any projected
                # probability is exactly 0.
                log_relative_probs = torch.log(torch.mm(probs, main_matrix))
                loss = criterion(log_relative_probs, average_probs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()
        return logits.detach()

    def _force_lowest_ppl_option(self, logits, scores):
        """Return scores that force the answer token with the lowest perplexity under ``logits``.

        Each option is scored against a one-hot target on ``self.device``; ties go to the
        earlier option. The result has the shape of ``scores``, lives on ``self.device``,
        and holds 0 at the chosen token and -inf everywhere else. That is the convention of
        transformers' forced-token processors, so greedy search, sampling and beam search
        all pick it. A +inf entry on a zero background would become NaN under softmax.
        """
        ppl_list = []
        for token_id in self._OPTION_TOKEN_IDS:
            target = torch.zeros_like(scores).to(self.device)
            target[:, token_id] = 1
            ppl_list.append(self.calculate_ppl(logits, target))
        next_tokens_id = self._OPTION_TOKEN_IDS[ppl_list.index(min(ppl_list))]
        output = torch.full_like(scores, -math.inf).to(self.device)
        output[:, next_tokens_id] = 0
        return output

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return scores forcing the lowest-perplexity answer token for this step.

        Args:
            input_ids: Tokens generated so far; not used by this processor.
            scores: The main model's next-token scores for this step.

        Returns:
            A tensor shaped like ``scores`` on ``self.device``. It holds 0 at the chosen
            answer token and -inf elsewhere.
        """
        # NOTE(review): the file name always says "learning_epochs_nums_5" whatever
        # self.learning_epochs_nums is. Kept so existing log consumers still find the file.
        ensemble_process_file_path = os.path.join(self.result_save_dir,
                                                  f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')
        self._append_log(ensemble_process_file_path, "main model top10:\n" + str(torch.topk(scores, 10)) + "\n")

        assist_scores_list, all_received = self._collect_assist_scores(ensemble_process_file_path)
        # Fall back to the main model alone if any auxiliary model is missing, or if the
        # learning rate is (near) zero.
        main_model_only = (not all_received) or math.fabs(self.learning_rate) <= 1e-6

        if main_model_only:
            logits = scores
        else:
            logits = self._ensemble_main_logits(scores, assist_scores_list,
                                                ensemble_process_file_path).to(self.device)
        return self._force_lowest_ppl_option(logits, scores)
class LLaMAPPLBasedOnProbabilityTransferLogitsPIQAProcessor(LogitsProcessor):
    """Force the lower-perplexity answer token (" A" or " B") of a two-option question.

    Each call logs the main model's top-10 scores, then waits up to 5 seconds per queue
    for one score tensor from every auxiliary model in ``assist_model_score_queue_list``.
    When all auxiliary scores arrive and ``learning_rate`` is not (near) zero, the main
    model's logits are refined first:
      1. Every model's softmax probabilities are projected through its probability transfer matrix.
      2. The projections are averaged with equal weights.
      3. The main model's logits take ``learning_epochs_nums`` AdamW steps that minimise
         ``KLDivLoss`` between the main projection and that average.
    Otherwise the raw main-model scores are used. Finally the answer token with the lowest
    perplexity is forced for this generation step.
    """

    # Main-model vocabulary ids of the "▁A" and "▁B" answer tokens, in option order.
    # NOTE(review): presumably ids from the LLaMA tokenizer — confirm against main_model_tokenizer.
    _OPTION_TOKEN_IDS = (319, 350)

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        # AdamW learning rate for the logit refinement; |lr| <= 1e-6 disables ensembling.
        self.learning_rate = learning_rate
        # Not read by this class's own methods.
        self.anchor_point_count = anchor_point_count
        # One queue per auxiliary model; each delivers that model's scores for the current step.
        self.assist_model_score_queue_list = assist_model_score_queue_list
        # Number of AdamW steps applied to the main model's logits.
        self.learning_epochs_nums = learning_epochs_nums
        # Not read by this class's own methods.
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        # Only element [0] is used, as the main model's transfer matrix.
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        # Paired positionally with assist_model_score_queue_list.
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        # Directory receiving the per-step ensemble log file.
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        # Device of the returned scores and of the one-hot option targets.
        self.device = device
        # Device on which the ensemble refinement runs.
        self.device_compute = device_compute
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    def calculate_ppl(self, logits, labels):
        """Return the perplexity of ``labels`` under ``logits`` as a Python float.

        ``cross_entropy`` applies log-softmax to ``logits`` itself. It accepts either class
        indices or per-class target probabilities as ``labels``; this class passes one-hot
        rows. The loss is averaged over the batch before being exponentiated, so a batch
        yields a single perplexity.
        """
        neg_log_likelihood = torch.nn.functional.cross_entropy(logits, labels, reduction='mean')
        return torch.exp(neg_log_likelihood).item()

    @staticmethod
    def _append_log(log_path, text):
        """Append ``text`` to the UTF-8 ensemble log at ``log_path``."""
        with open(log_path, "a+", encoding="utf-8") as process_file:
            process_file.write(text)

    def _collect_assist_scores(self, log_path):
        """Fetch one score tensor from each auxiliary model's queue, logging its top-10.

        Returns ``(assist_scores_list, all_received)``. A queue that stays empty for
        5 seconds contributes ``None`` and makes ``all_received`` False.
        """
        assist_scores_list = []
        all_received = True
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                self._append_log(log_path, f"aux model{index}【not received】\n")
                assist_scores_list.append(None)
                all_received = False
                continue
            assist_scores_list.append(value)
            self._append_log(log_path, f"aux model{index} top10:\n" + str(torch.topk(value, 10)) + "\n")
        return assist_scores_list, all_received

    def _ensemble_main_logits(self, scores, assist_scores_list, log_path):
        """Return the main model's logits after pulling them towards the ensemble average.

        The refinement runs on ``self.device_compute`` and the detached result is returned
        there. It uses a cosine learning-rate schedule with ``T_max=10`` whose floor is a
        quarter of ``learning_rate``.
        """
        main_matrix = self.main_model_probability_transfer_matrix_list[0]

        # Project every model's next-token distribution through its transfer matrix.
        with torch.no_grad():
            main_probs = nn.functional.softmax(scores, dim=-1).float()
            relative_probs_list = [torch.mm(main_probs, main_matrix).to(self.device_compute)]
            for assist_scores, assist_matrix in zip(assist_scores_list,
                                                    self.assist_model_probability_transfer_matrix_list):
                assist_probs = nn.functional.softmax(assist_scores, dim=-1).float()
                relative_probs_list.append(torch.mm(assist_probs, assist_matrix).to(self.device_compute))

        # Equal-weight average of all projections, main model included.
        model_count = len(relative_probs_list)
        model_weights = [1 / model_count] * model_count
        print(model_weights)
        average_probs = torch.zeros_like(relative_probs_list[0])
        for weight, probs in zip(model_weights, relative_probs_list):
            average_probs += weight * probs
        self._append_log(log_path, f"model_weights:{model_weights}\n")

        main_matrix = main_matrix.to(self.device_compute)
        logits = scores.detach().to(self.device_compute).clone().requires_grad_(True)
        criterion = nn.KLDivLoss()
        optimizer = torch.optim.AdamW(params=[logits], lr=self.learning_rate, betas=(0.9, 0.999))
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=self.learning_rate / 4)
        # enable_grad restores the caller's grad mode on exit, even if a step raises.
        # A bare torch.set_grad_enabled(True)/(False) pair would not.
        with torch.enable_grad():
            for _ in range(self.learning_epochs_nums):
                probs = nn.functional.softmax(logits, dim=-1).float()
                # NOTE(review): torch.log yields -inf (and NaN gradients) if any projected
                # probability is exactly 0.
                log_relative_probs = torch.log(torch.mm(probs, main_matrix))
                loss = criterion(log_relative_probs, average_probs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()
        return logits.detach()

    def _force_lowest_ppl_option(self, logits, scores):
        """Return scores that force the answer token with the lowest perplexity under ``logits``.

        Each option is scored against a one-hot target on ``self.device``; ties go to the
        earlier option. The result has the shape of ``scores``, lives on ``self.device``,
        and holds 0 at the chosen token and -inf everywhere else. That is the convention of
        transformers' forced-token processors, so greedy search, sampling and beam search
        all pick it. A +inf entry on a zero background would become NaN under softmax.
        """
        ppl_list = []
        for token_id in self._OPTION_TOKEN_IDS:
            target = torch.zeros_like(scores).to(self.device)
            target[:, token_id] = 1
            ppl_list.append(self.calculate_ppl(logits, target))
        next_tokens_id = self._OPTION_TOKEN_IDS[ppl_list.index(min(ppl_list))]
        output = torch.full_like(scores, -math.inf).to(self.device)
        output[:, next_tokens_id] = 0
        return output

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return scores forcing the lowest-perplexity answer token for this step.

        Args:
            input_ids: Tokens generated so far; not used by this processor.
            scores: The main model's next-token scores for this step.

        Returns:
            A tensor shaped like ``scores`` on ``self.device``. It holds 0 at the chosen
            answer token and -inf elsewhere.
        """
        # NOTE(review): the file name always says "learning_epochs_nums_5" whatever
        # self.learning_epochs_nums is. Kept so existing log consumers still find the file.
        ensemble_process_file_path = os.path.join(self.result_save_dir,
                                                  f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')
        self._append_log(ensemble_process_file_path, "main model top10:\n" + str(torch.topk(scores, 10)) + "\n")

        assist_scores_list, all_received = self._collect_assist_scores(ensemble_process_file_path)
        # Fall back to the main model alone if any auxiliary model is missing, or if the
        # learning rate is (near) zero.
        main_model_only = (not all_received) or math.fabs(self.learning_rate) <= 1e-6

        if main_model_only:
            logits = scores
        else:
            logits = self._ensemble_main_logits(scores, assist_scores_list,
                                                ensemble_process_file_path).to(self.device)
        return self._force_lowest_ppl_option(logits, scores)

class YiPPLBasedOnProbabilityTransferLogitsProcessor(LogitsProcessor):
    """Force the lowest-perplexity answer token (" A" to " D") of a four-option question.

    Each call logs the main model's top-10 scores, then waits up to 5 seconds per queue
    for one score tensor from every auxiliary model in ``assist_model_score_queue_list``.
    When all auxiliary scores arrive and ``learning_rate`` is not (near) zero, the main
    model's logits are refined first:
      1. Every model's softmax probabilities are projected through its probability transfer matrix.
      2. The projections are averaged with equal weights. Each model's cosine similarity and
         KL divergence to that average are logged.
      3. The main model's logits take ``learning_epochs_nums`` AdamW steps that minimise
         ``KLDivLoss`` between the main projection and the average.
    Otherwise the raw main-model scores are used. Finally the answer token with the lowest
    perplexity is forced for this generation step.
    """

    # Main-model vocabulary ids of the "▁A", "▁B", "▁C" and "▁D" answer tokens, in option order.
    # NOTE(review): presumably ids from the Yi tokenizer — confirm against main_model_tokenizer.
    _OPTION_TOKEN_IDS = (647, 690, 650, 723)

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        # AdamW learning rate for the logit refinement; |lr| <= 1e-6 disables ensembling.
        self.learning_rate = learning_rate
        # Not read by this class's own methods.
        self.anchor_point_count = anchor_point_count
        # One queue per auxiliary model; each delivers that model's scores for the current step.
        self.assist_model_score_queue_list = assist_model_score_queue_list
        # Number of AdamW steps applied to the main model's logits.
        self.learning_epochs_nums = learning_epochs_nums
        # Not read by this class's own methods.
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        # Only element [0] is used, as the main model's transfer matrix.
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        # Paired positionally with assist_model_score_queue_list.
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        # Directory receiving the per-step ensemble log file.
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        # Device of the returned scores and of the one-hot option targets.
        self.device = device
        # Device on which the ensemble refinement runs.
        self.device_compute = device_compute
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    def calculate_ppl(self, logits, labels):
        """Return the perplexity of ``labels`` under ``logits`` as a Python float.

        ``cross_entropy`` applies log-softmax to ``logits`` itself. It accepts either class
        indices or per-class target probabilities as ``labels``; this class passes one-hot
        rows. The loss is averaged over the batch before being exponentiated, so a batch
        yields a single perplexity.
        """
        neg_log_likelihood = torch.nn.functional.cross_entropy(logits, labels, reduction='mean')
        return torch.exp(neg_log_likelihood).item()

    @staticmethod
    def _append_log(log_path, text):
        """Append ``text`` to the UTF-8 ensemble log at ``log_path``."""
        with open(log_path, "a+", encoding="utf-8") as process_file:
            process_file.write(text)

    def _collect_assist_scores(self, log_path):
        """Fetch one score tensor from each auxiliary model's queue, logging its top-10.

        Returns ``(assist_scores_list, all_received)``. A queue that stays empty for
        5 seconds contributes ``None`` and makes ``all_received`` False.
        """
        assist_scores_list = []
        all_received = True
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                self._append_log(log_path, f"aux model{index}【not received】\n")
                assist_scores_list.append(None)
                all_received = False
                continue
            assist_scores_list.append(value)
            self._append_log(log_path, f"aux model{index} top10:\n" + str(torch.topk(value, 10)) + "\n")
        return assist_scores_list, all_received

    def _ensemble_main_logits(self, scores, assist_scores_list, log_path):
        """Return the main model's logits after pulling them towards the ensemble average.

        The refinement runs on ``self.device_compute`` and the detached result is returned
        there. It uses a cosine learning-rate schedule with ``T_max=10`` whose floor is a
        quarter of ``learning_rate``.
        """
        main_matrix = self.main_model_probability_transfer_matrix_list[0]

        # Project every model's next-token distribution through its transfer matrix.
        with torch.no_grad():
            main_probs = nn.functional.softmax(scores, dim=-1).float()
            relative_probs_list = [torch.mm(main_probs, main_matrix).to(self.device_compute)]
            for assist_scores, assist_matrix in zip(assist_scores_list,
                                                    self.assist_model_probability_transfer_matrix_list):
                assist_probs = nn.functional.softmax(assist_scores, dim=-1).float()
                relative_probs_list.append(torch.mm(assist_probs, assist_matrix).to(self.device_compute))

        # Equal-weight average of all projections, main model included.
        model_count = len(relative_probs_list)
        model_weights = [1 / model_count] * model_count
        average_probs = torch.zeros_like(relative_probs_list[0])
        for weight, probs in zip(model_weights, relative_probs_list):
            average_probs += weight * probs

        # Diagnostics: how far each model's projection is from the average.
        # NOTE(review): .item() requires a single-row batch here.
        cosine_similarity_list = []
        kl_div_list = []
        for probs in relative_probs_list:
            cosine_similarity_list.append(torch.cosine_similarity(probs, average_probs).item())
            kl_div_list.append(torch.nn.functional.kl_div(probs.log(), average_probs).item())
        self._append_log(log_path,
                         f"model_weights:{model_weights}\n"
                         f"cosine_similarity_list:{cosine_similarity_list}\n"
                         f"kl_div_list:{kl_div_list}\n")

        main_matrix = main_matrix.to(self.device_compute)
        logits = scores.detach().to(self.device_compute).clone().requires_grad_(True)
        criterion = nn.KLDivLoss()
        optimizer = torch.optim.AdamW(params=[logits], lr=self.learning_rate, betas=(0.9, 0.999))
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=self.learning_rate / 4)
        # enable_grad restores the caller's grad mode on exit, even if a step raises.
        # A bare torch.set_grad_enabled(True)/(False) pair would not.
        with torch.enable_grad():
            for _ in range(self.learning_epochs_nums):
                probs = nn.functional.softmax(logits, dim=-1).float()
                # NOTE(review): torch.log yields -inf (and NaN gradients) if any projected
                # probability is exactly 0.
                log_relative_probs = torch.log(torch.mm(probs, main_matrix))
                loss = criterion(log_relative_probs, average_probs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()
        return logits.detach()

    def _force_lowest_ppl_option(self, logits, scores):
        """Return scores that force the answer token with the lowest perplexity under ``logits``.

        Each option is scored against a one-hot target on ``self.device``; ties go to the
        earlier option. The result has the shape of ``scores``, lives on ``self.device``,
        and holds 0 at the chosen token and -inf everywhere else. That is the convention of
        transformers' forced-token processors, so greedy search, sampling and beam search
        all pick it. A +inf entry on a zero background would become NaN under softmax.
        """
        ppl_list = []
        for token_id in self._OPTION_TOKEN_IDS:
            target = torch.zeros_like(scores).to(self.device)
            target[:, token_id] = 1
            ppl_list.append(self.calculate_ppl(logits, target))
        next_tokens_id = self._OPTION_TOKEN_IDS[ppl_list.index(min(ppl_list))]
        output = torch.full_like(scores, -math.inf).to(self.device)
        output[:, next_tokens_id] = 0
        return output

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return scores forcing the lowest-perplexity answer token for this step.

        Args:
            input_ids: Tokens generated so far; not used by this processor.
            scores: The main model's next-token scores for this step.

        Returns:
            A tensor shaped like ``scores`` on ``self.device``. It holds 0 at the chosen
            answer token and -inf elsewhere.
        """
        # NOTE(review): the file name always says "learning_epochs_nums_5" whatever
        # self.learning_epochs_nums is. Kept so existing log consumers still find the file.
        ensemble_process_file_path = os.path.join(self.result_save_dir,
                                                  f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')
        self._append_log(ensemble_process_file_path, "main model top10:\n" + str(torch.topk(scores, 10)) + "\n")

        assist_scores_list, all_received = self._collect_assist_scores(ensemble_process_file_path)
        # Fall back to the main model alone if any auxiliary model is missing, or if the
        # learning rate is (near) zero.
        main_model_only = (not all_received) or math.fabs(self.learning_rate) <= 1e-6

        if main_model_only:
            logits = scores
        else:
            logits = self._ensemble_main_logits(scores, assist_scores_list,
                                                ensemble_process_file_path).to(self.device)
        return self._force_lowest_ppl_option(logits, scores)


class LLaMAPPLBasedOnProbabilityTransferLogitsProcessor(LogitsProcessor):
    """Four-option (A/B/C/D) logits processor that ensembles a LLaMA main model with auxiliary models.

    On each call the main model's ``scores`` are optionally refined with AdamW so that,
    after projection through the main model's probability-transfer matrix, they move
    toward the uniform average of every model's projected distribution. The option token
    whose one-hot label has the lowest perplexity under the (possibly refined) logits is
    then forced by returning a tensor that is ``+inf`` at that token id and ``0`` elsewhere.

    The ensemble step is skipped (main model only) when any auxiliary queue stays empty for
    5 seconds, or when ``abs(learning_rate) <= 1e-6``.
    """

    # Option-letter token ids in A, B, C, D order. The original code labels them "▁A ▁B";
    # presumably LLaMA SentencePiece ids for "▁A".."▁D" -- verify against the tokenizer used.
    OPTION_TOKEN_IDS = (319, 350, 315, 360)

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        """Store the processor configuration.

        Args:
            learning_rate: AdamW learning rate for refining the main logits; an absolute
                value <= 1e-6 disables the ensemble step.
            anchor_point_count: stored only; not read by this class.
            learning_epochs_nums: number of optimisation steps per call.
            ensemble_model_output_ids_queue: stored only; not read by this class.
            assist_model_score_queue_list: one queue per auxiliary model; every call takes
                one score tensor from each queue.
            main_model_probability_transfer_matrix_list: element ``[0]`` is right-multiplied
                (``torch.mm``) onto the main model's probabilities.
            assist_model_probability_transfer_matrix_list: one projection matrix per
                auxiliary model, paired with the queues by position.
            result_save_dir: directory that receives the ensemble process log.
            main_model_tokenizer, assist_model_tokenizer: stored only.
            device: device of the returned tensor and of the PPL computation.
            device_compute: device used for the ensemble optimisation.
            forced_eos_token_id, early_stop_string_list: stored only.
        """
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count
        self.assist_model_score_queue_list = assist_model_score_queue_list
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        self.device = device
        self.device_compute = device_compute
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    def calculate_ppl(self, logits, labels):
        """Return the perplexity of ``labels`` under ``logits`` as a Python float.

        ``labels`` holds class probabilities with the same shape as ``logits`` (here one-hot
        rows), so the cross-entropy is the batch mean of ``-log softmax(logits)`` at the
        labelled class, and the perplexity is ``exp`` of that mean.
        """
        # cross_entropy applies log-softmax internally, so raw logits are passed in.
        neg_log_likelihood = torch.nn.functional.cross_entropy(logits, labels, reduction='mean')
        return torch.exp(neg_log_likelihood).item()

    @staticmethod
    def _append_log(file_path, *texts):
        """Append each of ``texts`` verbatim to the UTF-8 text file at ``file_path``."""
        with open(file_path, "a+", encoding="utf-8") as process_file:
            for text in texts:
                process_file.write(text)

    def _collect_assist_model_logits(self, log_path):
        """Take one score tensor from every auxiliary-model queue.

        Returns:
            ``(logits_list, any_missing)``: one entry per queue (``None`` for a queue that
            stayed empty for 5 seconds), and whether at least one queue timed out.
        """
        logits_list = []
        any_missing = False
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                self._append_log(log_path, f"aux model{index}【not received】\n")
                logits_list.append(None)
                any_missing = True
                continue
            logits_list.append(value)
            self._append_log(log_path, f"aux model{index} top10:\n", str(torch.topk(value, 10)) + "\n")
        return logits_list, any_missing

    def _ensemble_refined_logits(self, scores, assist_logits_list, log_path):
        """Return a detached copy of ``scores`` optimised toward the ensemble average.

        All models' softmax distributions are projected with their transfer matrices and
        averaged uniformly. A copy of the main logits on ``device_compute`` is then updated
        for ``learning_epochs_nums`` AdamW steps to minimise KL(average || projected main).
        """
        main_matrix = self.main_model_probability_transfer_matrix_list[0]
        with torch.no_grad():
            main_probs = nn.functional.softmax(scores, dim=-1).float()
            main_relative_probs = torch.mm(main_probs, main_matrix).to(self.device_compute)
            relative_probs_list = [main_relative_probs]
            for assist_logits, assist_matrix in zip(assist_logits_list,
                                                    self.assist_model_probability_transfer_matrix_list):
                assist_probs = nn.functional.softmax(assist_logits, dim=-1).float()
                relative_probs_list.append(torch.mm(assist_probs, assist_matrix).to(self.device_compute))

            n = len(relative_probs_list)
            model_weights = [1 / n] * n
            average_probs = torch.zeros_like(main_relative_probs)
            for weight, probs in zip(model_weights, relative_probs_list):
                average_probs += weight * probs

            # Diagnostics only; ``.item()`` needs a single value, i.e. batch size 1.
            cosine_similarity_list = [torch.cosine_similarity(probs, average_probs).item()
                                      for probs in relative_probs_list]
            kl_div_list = [torch.nn.functional.kl_div(probs.log(), average_probs).item()
                           for probs in relative_probs_list]

        self._append_log(log_path,
                         f"model_weights:{model_weights}\n",
                         f"cosine_similarity_list:{cosine_similarity_list}\n",
                         f"kl_div_list:{kl_div_list}\n")

        logits = scores.detach().to(self.device_compute).clone()
        logits.requires_grad_(True)
        main_matrix = main_matrix.to(self.device_compute)
        # NOTE(review): default reduction='mean' averages over every element; PyTorch documents
        # 'batchmean' as the true KL. Kept as-is to preserve the tuned behaviour.
        criterion = nn.KLDivLoss()
        optimizer = torch.optim.AdamW(params=[logits], lr=self.learning_rate, betas=(0.9, 0.999))
        # NOTE(review): T_max is fixed at 10 independently of learning_epochs_nums.
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=self.learning_rate / 4)

        # enable_grad() restores the caller's grad mode on exit, even if a step raises.
        with torch.enable_grad():
            for _ in range(self.learning_epochs_nums):
                probs = nn.functional.softmax(logits, dim=-1).float()
                relative_probs = torch.mm(probs, main_matrix)
                # log() yields -inf if any projected probability is exactly 0.
                loss = criterion(torch.log(relative_probs), average_probs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()
        return logits.detach()

    def _select_lowest_ppl_option(self, logits, scores):
        """Return a ``scores``-shaped tensor forcing the lowest-perplexity option token."""
        ppl_list = []
        for token_id in self.OPTION_TOKEN_IDS:
            one_hot = torch.zeros_like(scores).to(self.device)
            one_hot[:, token_id] = 1
            ppl_list.append(self.calculate_ppl(logits, one_hot))
        # list.index returns the first minimum, so ties resolve in A, B, C, D order.
        next_tokens_id = self.OPTION_TOKEN_IDS[ppl_list.index(min(ppl_list))]
        output = torch.zeros_like(scores).to(self.device)
        # +inf at the chosen id and 0 elsewhere, so greedy (argmax) selection picks it.
        output[:, next_tokens_id] = float('inf')
        return output

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Force the option token (A/B/C/D) with the lowest perplexity for this step.

        Args:
            input_ids: not used.
            scores: main-model next-token logits, indexed as ``[:, token_id]`` (2-D); the
                ensemble diagnostics additionally require batch size 1.

        Returns:
            Tensor shaped like ``scores`` on ``self.device``: ``+inf`` at the chosen option
            token id and ``0`` everywhere else.
        """
        # NOTE(review): the file name hard-codes "learning_epochs_nums_5" whatever the real value is.
        ensemble_process_file_path = os.path.join(self.result_save_dir,
                                                  f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')
        self._append_log(ensemble_process_file_path, "main model top10:\n", str(torch.topk(scores, 10)) + "\n")
        assist_logits_list, any_missing = self._collect_assist_model_logits(ensemble_process_file_path)

        # Use the main model alone if an auxiliary model is missing or learning is disabled.
        if any_missing or math.fabs(self.learning_rate) <= 1e-6:
            return self._select_lowest_ppl_option(scores, scores)

        refined_logits = self._ensemble_refined_logits(scores, assist_logits_list, ensemble_process_file_path)
        return self._select_lowest_ppl_option(refined_logits.to(self.device), scores)


class LLaMAPPLBasedOnProbabilityTransferLogitsPIQAReWeightProcessor(LogitsProcessor):
    """Two-option (A/B) logits processor that ensembles LLaMA with auxiliary models using fixed weights.

    On each call the main model's ``scores`` are optionally refined with AdamW so that,
    after projection through the main model's probability-transfer matrix, they move
    toward a weighted average of every model's projected distribution. The weights come
    from ``MODEL_WEIGHTS_BY_COUNT`` (uniform for sizes not listed). The option token whose
    one-hot label has the lowest perplexity is then forced by returning a tensor that is
    ``+inf`` at that token id and ``0`` elsewhere.

    The ensemble step is skipped (main model only) when any auxiliary queue stays empty for
    5 seconds, or when ``abs(learning_rate) <= 1e-6``.
    """

    # Option-letter token ids in A, B order. The original code labels them "▁A ▁B";
    # presumably LLaMA SentencePiece ids -- verify against the tokenizer used.
    OPTION_TOKEN_IDS = (319, 350)

    # Fixed ensemble weights keyed by the number of models (main model first, then the
    # auxiliary models in queue order). Sizes not listed fall back to uniform weights.
    MODEL_WEIGHTS_BY_COUNT = {
        2: [0.5078, 0.4922],
        3: [0.3496, 0.3389, 0.3115],
        4: [0.2721, 0.2637, 0.2424, 0.2218],
        5: [0.2240, 0.2171, 0.1996, 0.1826, 0.1766],
        6: [0.1945, 0.1885, 0.1733, 0.1585, 0.1533, 0.1318],
        7: [0.1750, 0.1696, 0.1559, 0.1426, 0.1379, 0.1185, 0.1005],
        8: [0.1597, 0.1547, 0.1423, 0.1301, 0.1259, 0.1082, 0.0917, 0.0875],
        9: [0.1510, 0.1464, 0.1345, 0.1231, 0.1190, 0.1023, 0.0868, 0.0827, 0.0542],
    }

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        """Store the processor configuration.

        Args:
            learning_rate: AdamW learning rate for refining the main logits; an absolute
                value <= 1e-6 disables the ensemble step.
            anchor_point_count: stored only; not read by this class.
            learning_epochs_nums: number of optimisation steps per call.
            ensemble_model_output_ids_queue: stored only; not read by this class.
            assist_model_score_queue_list: one queue per auxiliary model; every call takes
                one score tensor from each queue.
            main_model_probability_transfer_matrix_list: element ``[0]`` is right-multiplied
                (``torch.mm``) onto the main model's probabilities.
            assist_model_probability_transfer_matrix_list: one projection matrix per
                auxiliary model, paired with the queues by position.
            result_save_dir: directory that receives the ensemble process log.
            main_model_tokenizer, assist_model_tokenizer: stored only.
            device: device of the returned tensor and of the PPL computation.
            device_compute: device used for the ensemble optimisation.
            forced_eos_token_id, early_stop_string_list: stored only.
        """
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count
        self.assist_model_score_queue_list = assist_model_score_queue_list
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        self.device = device
        self.device_compute = device_compute
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    def calculate_ppl(self, logits, labels):
        """Return the perplexity of ``labels`` under ``logits`` as a Python float.

        ``labels`` holds class probabilities with the same shape as ``logits`` (here one-hot
        rows), so the cross-entropy is the batch mean of ``-log softmax(logits)`` at the
        labelled class, and the perplexity is ``exp`` of that mean.
        """
        # cross_entropy applies log-softmax internally, so raw logits are passed in.
        neg_log_likelihood = torch.nn.functional.cross_entropy(logits, labels, reduction='mean')
        return torch.exp(neg_log_likelihood).item()

    @staticmethod
    def _append_log(file_path, *texts):
        """Append each of ``texts`` verbatim to the UTF-8 text file at ``file_path``."""
        with open(file_path, "a+", encoding="utf-8") as process_file:
            for text in texts:
                process_file.write(text)

    def _collect_assist_model_logits(self, log_path):
        """Take one score tensor from every auxiliary-model queue.

        Returns:
            ``(logits_list, any_missing)``: one entry per queue (``None`` for a queue that
            stayed empty for 5 seconds), and whether at least one queue timed out.
        """
        logits_list = []
        any_missing = False
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                self._append_log(log_path, f"aux model{index}【not received】\n")
                logits_list.append(None)
                any_missing = True
                continue
            logits_list.append(value)
            self._append_log(log_path, f"aux model{index} top10:\n", str(torch.topk(value, 10)) + "\n")
        return logits_list, any_missing

    def _ensemble_refined_logits(self, scores, assist_logits_list):
        """Return a detached copy of ``scores`` optimised toward the weighted ensemble average.

        All models' softmax distributions are projected with their transfer matrices and
        combined with ``MODEL_WEIGHTS_BY_COUNT``. A copy of the main logits on
        ``device_compute`` is then updated for ``learning_epochs_nums`` AdamW steps to
        minimise KL(average || projected main).
        """
        main_matrix = self.main_model_probability_transfer_matrix_list[0]
        with torch.no_grad():
            main_probs = nn.functional.softmax(scores, dim=-1).float()
            main_relative_probs = torch.mm(main_probs, main_matrix).to(self.device_compute)
            relative_probs_list = [main_relative_probs]
            for assist_logits, assist_matrix in zip(assist_logits_list,
                                                    self.assist_model_probability_transfer_matrix_list):
                assist_probs = nn.functional.softmax(assist_logits, dim=-1).float()
                relative_probs_list.append(torch.mm(assist_probs, assist_matrix).to(self.device_compute))

            n = len(relative_probs_list)
            # Copy so the class-level table is never mutated through this list.
            model_weights = list(self.MODEL_WEIGHTS_BY_COUNT.get(n, [1 / n] * n))
            print(model_weights)
            assert len(model_weights) == n, "权重数和logits必须相同"
            average_probs = torch.zeros_like(main_relative_probs)
            for weight, probs in zip(model_weights, relative_probs_list):
                average_probs += weight * probs

        logits = scores.detach().to(self.device_compute).clone()
        logits.requires_grad_(True)
        main_matrix = main_matrix.to(self.device_compute)
        # NOTE(review): default reduction='mean' averages over every element; PyTorch documents
        # 'batchmean' as the true KL. Kept as-is to preserve the tuned behaviour.
        criterion = nn.KLDivLoss()
        optimizer = torch.optim.AdamW(params=[logits], lr=self.learning_rate, betas=(0.9, 0.999))
        # NOTE(review): T_max is fixed at 10 independently of learning_epochs_nums.
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=self.learning_rate / 4)

        # enable_grad() restores the caller's grad mode on exit, even if a step raises.
        with torch.enable_grad():
            for _ in range(self.learning_epochs_nums):
                probs = nn.functional.softmax(logits, dim=-1).float()
                relative_probs = torch.mm(probs, main_matrix)
                # log() yields -inf if any projected probability is exactly 0.
                loss = criterion(torch.log(relative_probs), average_probs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()
        return logits.detach()

    def _select_lowest_ppl_option(self, logits, scores):
        """Return a ``scores``-shaped tensor forcing the lowest-perplexity option token."""
        ppl_list = []
        for token_id in self.OPTION_TOKEN_IDS:
            one_hot = torch.zeros_like(scores).to(self.device)
            one_hot[:, token_id] = 1
            ppl_list.append(self.calculate_ppl(logits, one_hot))
        # list.index returns the first minimum, so ties resolve in favour of A.
        next_tokens_id = self.OPTION_TOKEN_IDS[ppl_list.index(min(ppl_list))]
        output = torch.zeros_like(scores).to(self.device)
        # +inf at the chosen id and 0 elsewhere, so greedy (argmax) selection picks it.
        output[:, next_tokens_id] = float('inf')
        return output

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Force the option token (A/B) with the lowest perplexity for this step.

        Args:
            input_ids: not used.
            scores: main-model next-token logits, indexed as ``[:, token_id]`` (2-D).

        Returns:
            Tensor shaped like ``scores`` on ``self.device``: ``+inf`` at the chosen option
            token id and ``0`` everywhere else.
        """
        # NOTE(review): the file name hard-codes "learning_epochs_nums_5" whatever the real value is.
        ensemble_process_file_path = os.path.join(self.result_save_dir,
                                                  f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')
        self._append_log(ensemble_process_file_path, "main model top10:\n", str(torch.topk(scores, 10)) + "\n")
        assist_logits_list, any_missing = self._collect_assist_model_logits(ensemble_process_file_path)

        # Use the main model alone if an auxiliary model is missing or learning is disabled.
        if any_missing or math.fabs(self.learning_rate) <= 1e-6:
            return self._select_lowest_ppl_option(scores, scores)

        refined_logits = self._ensemble_refined_logits(scores, assist_logits_list)
        return self._select_lowest_ppl_option(refined_logits.to(self.device), scores)


class MistralPPLBasedOnProbabilityTransferLogitsProcessor(LogitsProcessor):
    """Four-option (A/B/C/D) logits processor that ensembles a Mistral main model with auxiliary models.

    On each call the main model's ``scores`` are optionally refined with AdamW so that,
    after projection through the main model's probability-transfer matrix, they move
    toward the uniform average of every model's projected distribution. The option token
    whose one-hot label has the lowest perplexity under the (possibly refined) logits is
    then forced by returning a tensor that is ``+inf`` at that token id and ``0`` elsewhere.

    The ensemble step is skipped (main model only) when any auxiliary queue stays empty for
    5 seconds, or when ``abs(learning_rate) <= 1e-6``.
    """

    # Option-letter token ids in A, B, C, D order. The original code labels them "▁A ▁B";
    # presumably Mistral tokenizer ids for "▁A".."▁D" -- verify against the tokenizer used.
    OPTION_TOKEN_IDS = (330, 365, 334, 384)

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        """Store the processor configuration.

        Args:
            learning_rate: AdamW learning rate for refining the main logits; an absolute
                value <= 1e-6 disables the ensemble step.
            anchor_point_count: stored only; not read by this class.
            learning_epochs_nums: number of optimisation steps per call.
            ensemble_model_output_ids_queue: stored only; not read by this class.
            assist_model_score_queue_list: one queue per auxiliary model; every call takes
                one score tensor from each queue.
            main_model_probability_transfer_matrix_list: element ``[0]`` is right-multiplied
                (``torch.mm``) onto the main model's probabilities.
            assist_model_probability_transfer_matrix_list: one projection matrix per
                auxiliary model, paired with the queues by position.
            result_save_dir: directory that receives the ensemble process log.
            main_model_tokenizer, assist_model_tokenizer: stored only.
            device: device of the returned tensor and of the PPL computation.
            device_compute: device used for the ensemble optimisation.
            forced_eos_token_id, early_stop_string_list: stored only.
        """
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count
        self.assist_model_score_queue_list = assist_model_score_queue_list
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        self.device = device
        self.device_compute = device_compute
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    def calculate_ppl(self, logits, labels):
        """Return the perplexity of ``labels`` under ``logits`` as a Python float.

        ``labels`` holds class probabilities with the same shape as ``logits`` (here one-hot
        rows), so the cross-entropy is the batch mean of ``-log softmax(logits)`` at the
        labelled class, and the perplexity is ``exp`` of that mean.
        """
        # cross_entropy applies log-softmax internally, so raw logits are passed in.
        neg_log_likelihood = torch.nn.functional.cross_entropy(logits, labels, reduction='mean')
        return torch.exp(neg_log_likelihood).item()

    @staticmethod
    def _append_log(file_path, *texts):
        """Append each of ``texts`` verbatim to the UTF-8 text file at ``file_path``."""
        with open(file_path, "a+", encoding="utf-8") as process_file:
            for text in texts:
                process_file.write(text)

    def _collect_assist_model_logits(self, log_path):
        """Take one score tensor from every auxiliary-model queue.

        Returns:
            ``(logits_list, any_missing)``: one entry per queue (``None`` for a queue that
            stayed empty for 5 seconds), and whether at least one queue timed out.
        """
        logits_list = []
        any_missing = False
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                self._append_log(log_path, f"aux model{index}【not received】\n")
                logits_list.append(None)
                any_missing = True
                continue
            logits_list.append(value)
            self._append_log(log_path, f"aux model{index} top10:\n", str(torch.topk(value, 10)) + "\n")
        return logits_list, any_missing

    def _ensemble_refined_logits(self, scores, assist_logits_list, log_path):
        """Return a detached copy of ``scores`` optimised toward the ensemble average.

        All models' softmax distributions are projected with their transfer matrices and
        averaged uniformly. A copy of the main logits on ``device_compute`` is then updated
        for ``learning_epochs_nums`` AdamW steps to minimise KL(average || projected main).
        """
        main_matrix = self.main_model_probability_transfer_matrix_list[0]
        with torch.no_grad():
            main_probs = nn.functional.softmax(scores, dim=-1).float()
            main_relative_probs = torch.mm(main_probs, main_matrix).to(self.device_compute)
            relative_probs_list = [main_relative_probs]
            for assist_logits, assist_matrix in zip(assist_logits_list,
                                                    self.assist_model_probability_transfer_matrix_list):
                assist_probs = nn.functional.softmax(assist_logits, dim=-1).float()
                relative_probs_list.append(torch.mm(assist_probs, assist_matrix).to(self.device_compute))

            n = len(relative_probs_list)
            model_weights = [1 / n] * n
            average_probs = torch.zeros_like(main_relative_probs)
            for weight, probs in zip(model_weights, relative_probs_list):
                average_probs += weight * probs

            # Diagnostics only; ``.item()`` needs a single value, i.e. batch size 1.
            cosine_similarity_list = [torch.cosine_similarity(probs, average_probs).item()
                                      for probs in relative_probs_list]
            kl_div_list = [torch.nn.functional.kl_div(probs.log(), average_probs).item()
                           for probs in relative_probs_list]

        self._append_log(log_path,
                         f"model_weights:{model_weights}\n",
                         f"cosine_similarity_list:{cosine_similarity_list}\n",
                         f"kl_div_list:{kl_div_list}\n")

        logits = scores.detach().to(self.device_compute).clone()
        logits.requires_grad_(True)
        main_matrix = main_matrix.to(self.device_compute)
        # NOTE(review): default reduction='mean' averages over every element; PyTorch documents
        # 'batchmean' as the true KL. Kept as-is to preserve the tuned behaviour.
        criterion = nn.KLDivLoss()
        optimizer = torch.optim.AdamW(params=[logits], lr=self.learning_rate, betas=(0.9, 0.999))
        # NOTE(review): T_max is fixed at 10 independently of learning_epochs_nums.
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=self.learning_rate / 4)

        # enable_grad() restores the caller's grad mode on exit, even if a step raises.
        with torch.enable_grad():
            for _ in range(self.learning_epochs_nums):
                probs = nn.functional.softmax(logits, dim=-1).float()
                relative_probs = torch.mm(probs, main_matrix)
                # log() yields -inf if any projected probability is exactly 0.
                loss = criterion(torch.log(relative_probs), average_probs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()
        return logits.detach()

    def _select_lowest_ppl_option(self, logits, scores):
        """Return a ``scores``-shaped tensor forcing the lowest-perplexity option token."""
        ppl_list = []
        for token_id in self.OPTION_TOKEN_IDS:
            one_hot = torch.zeros_like(scores).to(self.device)
            one_hot[:, token_id] = 1
            ppl_list.append(self.calculate_ppl(logits, one_hot))
        # list.index returns the first minimum, so ties resolve in A, B, C, D order.
        next_tokens_id = self.OPTION_TOKEN_IDS[ppl_list.index(min(ppl_list))]
        output = torch.zeros_like(scores).to(self.device)
        # +inf at the chosen id and 0 elsewhere, so greedy (argmax) selection picks it.
        output[:, next_tokens_id] = float('inf')
        return output

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Force the option token (A/B/C/D) with the lowest perplexity for this step.

        Args:
            input_ids: not used.
            scores: main-model next-token logits, indexed as ``[:, token_id]`` (2-D); the
                ensemble diagnostics additionally require batch size 1.

        Returns:
            Tensor shaped like ``scores`` on ``self.device``: ``+inf`` at the chosen option
            token id and ``0`` everywhere else.
        """
        # NOTE(review): the file name hard-codes "learning_epochs_nums_5" whatever the real value is.
        ensemble_process_file_path = os.path.join(self.result_save_dir,
                                                  f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')
        self._append_log(ensemble_process_file_path, "main model top10:\n", str(torch.topk(scores, 10)) + "\n")
        assist_logits_list, any_missing = self._collect_assist_model_logits(ensemble_process_file_path)

        # Use the main model alone if an auxiliary model is missing or learning is disabled.
        if any_missing or math.fabs(self.learning_rate) <= 1e-6:
            return self._select_lowest_ppl_option(scores, scores)

        refined_logits = self._ensemble_refined_logits(scores, assist_logits_list, ensemble_process_file_path)
        return self._select_lowest_ppl_option(refined_logits.to(self.device), scores)


class YiPPLBasedOnProbabilityTransferLogitsReWeightProcessor(LogitsProcessor):
    """Force one of four answer-option tokens after a weighted model ensemble.

    Each call:

    1. waits up to 5 s per queue for one logits tensor from every assist
       model in ``assist_model_score_queue_list``;
    2. maps the main and assist softmax distributions through their
       probability-transfer matrices and takes a weighted average (weights
       from ``_MODEL_WEIGHTS_BY_COUNT``, uniform for other model counts);
    3. runs ``learning_epochs_nums`` AdamW steps on a copy of the main
       logits, minimising ``nn.KLDivLoss`` between the main model's projected
       distribution and that average;
    4. returns ``zeros_like(scores)`` (on ``self.device``) with ``inf`` at
       whichever option token in ``_OPTION_TOKEN_IDS`` has the lowest
       :meth:`calculate_ppl`.

    Steps 2-3 are skipped, and the raw ``scores`` are used in step 4, when an
    assist model times out or ``abs(learning_rate) <= 1e-6``.

    NOTE: the log file name always ends in ``learning_epochs_nums_5.log``,
    whatever ``learning_epochs_nums`` is.
    """

    # Vocabulary ids of the option tokens "▁A", "▁B", "▁C", "▁D", in option
    # order. Hard-coded; presumably ids in the main model's tokenizer --
    # verify if the tokenizer changes.
    _OPTION_TOKEN_IDS = (647, 690, 650, 723)

    # Hand-tuned ensemble weights keyed by the number of models (main model
    # first). Entries are not normalised (e.g. the 2-model pair sums to
    # 0.9647) and the average is not renormalised afterwards. Any other model
    # count falls back to uniform weights.
    _MODEL_WEIGHTS_BY_COUNT = {
        2: [0.5014, 0.4633],
        3: [0.3412, 0.3196, 0.3152],
        4: [0.2585, 0.2422, 0.2389, 0.1912],
        5: [0.2087, 0.2075, 0.1955, 0.1928, 0.1544],
        6: [0.1787, 0.1777, 0.1674, 0.1674, 0.1652, 0.1322],
        7: [0.1579, 0.1570, 0.1479, 0.1479, 0.1459, 0.1268, 0.1168],
    }

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        """Store the configuration; every argument becomes a same-named attribute.

        ``learning_rate`` and ``learning_epochs_nums`` drive the AdamW
        refinement (a learning rate within 1e-6 of zero disables ensembling).
        ``assist_model_score_queue_list`` supplies one logits tensor per assist
        model per call. The transfer-matrix lists hold the matrices the
        softmax distributions are multiplied by; only index 0 of the main list
        is used. ``device`` receives the returned tensor and ``device_compute``
        hosts the optimisation.
        """
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count
        self.assist_model_score_queue_list = assist_model_score_queue_list
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        self.device = device
        self.device_compute = device_compute
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    def calculate_ppl(self, logits, labels):
        """Return ``exp(mean cross-entropy)`` of ``labels`` under ``logits``.

        ``labels`` are one-hot float targets here, which
        ``torch.nn.functional.cross_entropy`` treats as class probabilities.
        With more than one row the mean is taken over rows, so a single
        option is chosen for all rows.
        """
        neg_log_likelihood = torch.nn.functional.cross_entropy(logits, labels, reduction='mean')
        ppl = torch.exp(neg_log_likelihood)
        return ppl.item()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return an ``inf`` mask selecting the best answer option.

        Args:
            input_ids: Tokens generated so far (unused).
            scores: Main-model next-token logits; must be 2-D because it is
                fed to ``torch.mm`` and indexed as ``[:, token_id]``.

        Returns:
            ``zeros_like(scores)`` on ``self.device`` with the chosen option
            token set to ``inf`` in every row.
        """
        ensemble_process_file_path = os.path.join(
            self.result_save_dir,
            f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')

        with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
            process_file.write("main model top10:\n")
            process_file.write(str(torch.topk(scores, 10)) + "\n")

        assist_model_generate_ids_logits_list, main_model_only_flag = self._collect_assist_logits(
            ensemble_process_file_path)
        # A (near-)zero learning rate means the refinement would be a no-op.
        if math.fabs(self.learning_rate) <= 1e-6:
            main_model_only_flag = True

        if main_model_only_flag:
            return self._force_best_option(scores, scores)

        ensemble_logits = self._ensemble_main_logits(scores, assist_model_generate_ids_logits_list)
        return self._force_best_option(ensemble_logits.to(self.device), scores)

    def _collect_assist_logits(self, ensemble_process_file_path):
        """Fetch one logits tensor per assist queue (5 s timeout each).

        Returns ``(logits_list, any_missing)``. A timed-out queue contributes
        ``None`` and sets ``any_missing`` so the caller falls back to the
        main model alone.
        """
        assist_model_generate_ids_logits_list = []
        any_missing = False
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
                    process_file.write(f"aux model{index}【not received】\n")
                assist_model_generate_ids_logits_list.append(None)
                any_missing = True
                continue

            assist_model_generate_ids_logits_list.append(value)
            with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
                process_file.write(f"aux model{index} top10:\n")
                process_file.write(str(torch.topk(value, 10)) + "\n")
        return assist_model_generate_ids_logits_list, any_missing

    def _ensemble_main_logits(self, scores, assist_model_generate_ids_logits_list):
        """Refine a copy of ``scores`` toward the weighted ensemble distribution.

        Returns the refined logits, detached, on ``self.device_compute``.
        ``scores`` itself is never modified.
        """
        with torch.no_grad():
            main_model_generate_ids_probs = nn.functional.softmax(scores, dim=-1).float()
            local_main_model_relative_representation_matrix = self.main_model_probability_transfer_matrix_list[0]
            main_model_relative_representation_probs = torch.mm(
                main_model_generate_ids_probs,
                local_main_model_relative_representation_matrix).to(self.device_compute)

            model_relative_representation_probs_list = [main_model_relative_representation_probs]
            # zip() stops at the shorter list, so unmatched queues/matrices are ignored.
            for assist_model_generate_ids_logits, assist_model_probability_transfer_matrix in zip(
                    assist_model_generate_ids_logits_list, self.assist_model_probability_transfer_matrix_list):
                assist_model_generate_ids_probs = nn.functional.softmax(assist_model_generate_ids_logits,
                                                                        dim=-1).float()
                model_relative_representation_probs_list.append(
                    torch.mm(assist_model_generate_ids_probs,
                             assist_model_probability_transfer_matrix).to(self.device_compute))

            model_count = len(model_relative_representation_probs_list)
            model_weights = list(self._MODEL_WEIGHTS_BY_COUNT.get(model_count, [1 / model_count] * model_count))
            assert len(model_weights) == model_count, "权重数和logits必须相同"
            average_probs = torch.zeros_like(main_model_relative_representation_probs)
            for weight, probs in zip(model_weights, model_relative_representation_probs_list):
                average_probs += weight * probs

        # enable_grad() restores the caller's grad mode on exit, including on
        # exceptions, instead of forcing grad globally on and then off.
        with torch.enable_grad():
            main_model_generate_ids_logits = scores.to(self.device_compute).detach().clone()
            main_model_generate_ids_logits.requires_grad_(True)
            local_main_model_relative_representation_matrix = local_main_model_relative_representation_matrix.to(
                self.device_compute)
            # Default reduction='mean' is kept; PyTorch documents
            # 'batchmean' as the form matching the mathematical KL definition.
            criterion = nn.KLDivLoss()
            optimizer = torch.optim.AdamW(params=[main_model_generate_ids_logits],
                                          lr=self.learning_rate,
                                          betas=(0.9, 0.999))
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=self.learning_rate / 4)

            for _ in range(self.learning_epochs_nums):
                main_model_generate_ids_probs = nn.functional.softmax(main_model_generate_ids_logits, dim=-1).float()
                main_model_relative_representation_probs = torch.mm(main_model_generate_ids_probs,
                                                                    local_main_model_relative_representation_matrix)
                # KLDivLoss expects log-probabilities as input; an exact zero gives -inf.
                log_main_probs = torch.log(main_model_relative_representation_probs)
                loss = criterion(log_main_probs, average_probs)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

        return main_model_generate_ids_logits.detach()

    def _force_best_option(self, logits, scores):
        """Return ``zeros_like(scores)`` with ``inf`` at the lowest-perplexity option.

        Each option is scored against a one-hot target shaped like ``scores``
        on ``self.device``. Ties resolve to the earliest entry of
        ``_OPTION_TOKEN_IDS``.
        """
        option_ppl_list = []
        for token_id in self._OPTION_TOKEN_IDS:
            one_hot_target = torch.zeros_like(scores).to(self.device)
            one_hot_target[:, token_id] = 1
            option_ppl_list.append(self.calculate_ppl(logits, one_hot_target))

        next_tokens_id = self._OPTION_TOKEN_IDS[option_ppl_list.index(min(option_ppl_list))]
        output = torch.zeros_like(scores).to(self.device)
        output[:, next_tokens_id] = float('inf')
        return output


class YiPPLBasedOnProbabilityTransferLogitsPIQAReWrightByDevProcessor(LogitsProcessor):
    """Force one of two answer-option tokens after a weighted model ensemble.

    Each call:

    1. waits up to 5 s per queue for one logits tensor from every assist
       model in ``assist_model_score_queue_list``;
    2. maps the main and assist softmax distributions (computed in float32)
       through their probability-transfer matrices and takes a weighted
       average (weights from ``_MODEL_WEIGHTS_BY_COUNT``, uniform for other
       model counts);
    3. runs ``learning_epochs_nums`` AdamW steps on a float32 copy of the main
       logits, minimising ``nn.KLDivLoss`` between the main model's projected
       distribution and that average;
    4. returns ``zeros_like(scores)`` (on ``self.device``) with ``inf`` at
       whichever option token in ``_OPTION_TOKEN_IDS`` has the lowest
       :meth:`calculate_ppl`.

    Steps 2-3 are skipped, and the raw ``scores`` are used in step 4, when an
    assist model times out or ``abs(learning_rate) <= 1e-6``.

    NOTE: the log file name always ends in ``learning_epochs_nums_5.log``,
    whatever ``learning_epochs_nums`` is.
    """

    # Vocabulary ids of the option tokens "▁A", "▁B", in option order.
    # Hard-coded; presumably ids in the main model's tokenizer -- verify if
    # the tokenizer changes.
    _OPTION_TOKEN_IDS = (647, 690)

    # Hand-tuned ensemble weights keyed by the number of models (main model
    # first). Any other model count falls back to uniform weights.
    _MODEL_WEIGHTS_BY_COUNT = {
        2: [0.5327, 0.4673],
        3: [0.3633, 0.3187, 0.3180],
        # Earlier 4-model candidate: [0.2868, 0.2516, 0.2511, 0.2106]
        4: [0.3486, 0.3051, 0.1862, 0.1601],
        5: [0.2365, 0.2075, 0.2070, 0.1754, 0.1737],
        6: [0.2100, 0.1842, 0.1838, 0.1557, 0.1542, 0.1122],
        7: [0.1900, 0.1667, 0.1664, 0.1409, 0.1395, 0.1015, 0.0949],
        8: [0.1748, 0.1533, 0.1530, 0.1296, 0.1283, 0.0934, 0.0873, 0.0803],
        9: [0.1657, 0.1453, 0.1450, 0.1228, 0.1217, 0.0885, 0.0828, 0.0761, 0.0521],
    }

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        """Store the configuration; every argument becomes a same-named attribute.

        ``learning_rate`` and ``learning_epochs_nums`` drive the AdamW
        refinement (a learning rate within 1e-6 of zero disables ensembling).
        ``assist_model_score_queue_list`` supplies one logits tensor per assist
        model per call. The transfer-matrix lists hold the matrices the
        softmax distributions are multiplied by; only index 0 of the main list
        is used. ``device`` receives the returned tensor and ``device_compute``
        hosts the optimisation.
        """
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count
        self.assist_model_score_queue_list = assist_model_score_queue_list
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        self.device = device
        self.device_compute = device_compute
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    def calculate_ppl(self, logits, labels):
        """Return ``exp(mean cross-entropy)`` of ``labels`` under ``logits``.

        ``labels`` are one-hot float targets here, which
        ``torch.nn.functional.cross_entropy`` treats as class probabilities.
        With more than one row the mean is taken over rows, so a single
        option is chosen for all rows.
        """
        neg_log_likelihood = torch.nn.functional.cross_entropy(logits, labels, reduction='mean')
        ppl = torch.exp(neg_log_likelihood)
        return ppl.item()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return an ``inf`` mask selecting the better of the two options.

        Args:
            input_ids: Tokens generated so far (unused).
            scores: Main-model next-token logits; must be 2-D because it is
                fed to ``torch.mm`` and indexed as ``[:, token_id]``.

        Returns:
            ``zeros_like(scores)`` on ``self.device`` with the chosen option
            token set to ``inf`` in every row.
        """
        ensemble_process_file_path = os.path.join(
            self.result_save_dir,
            f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')

        with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
            process_file.write("main model top10:\n")
            process_file.write(str(torch.topk(scores, 10)) + "\n")

        assist_model_generate_ids_logits_list, main_model_only_flag = self._collect_assist_logits(
            ensemble_process_file_path)
        # A (near-)zero learning rate means the refinement would be a no-op.
        if math.fabs(self.learning_rate) <= 1e-6:
            main_model_only_flag = True

        if main_model_only_flag:
            return self._force_best_option(scores, scores)

        ensemble_logits = self._ensemble_main_logits(scores, assist_model_generate_ids_logits_list,
                                                     ensemble_process_file_path)
        return self._force_best_option(ensemble_logits.to(self.device), scores)

    def _collect_assist_logits(self, ensemble_process_file_path):
        """Fetch one logits tensor per assist queue (5 s timeout each).

        Returns ``(logits_list, any_missing)``. A timed-out queue contributes
        ``None`` and sets ``any_missing`` so the caller falls back to the
        main model alone.
        """
        assist_model_generate_ids_logits_list = []
        any_missing = False
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
                    process_file.write(f"aux model{index}【not received】\n")
                assist_model_generate_ids_logits_list.append(None)
                any_missing = True
                continue

            assist_model_generate_ids_logits_list.append(value)
            with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
                process_file.write(f"aux model{index} top10:\n")
                process_file.write(str(torch.topk(value, 10)) + "\n")
        return assist_model_generate_ids_logits_list, any_missing

    def _ensemble_main_logits(self, scores, assist_model_generate_ids_logits_list, ensemble_process_file_path):
        """Refine a float32 copy of ``scores`` toward the weighted ensemble.

        Logs the weights used and the refined top-10, then returns the refined
        logits, detached, on ``self.device_compute``. ``scores`` itself is
        never modified.
        """
        with torch.no_grad():
            main_model_generate_ids_probs = nn.functional.softmax(scores, dim=-1).to(torch.float32)
            local_main_model_relative_representation_matrix = self.main_model_probability_transfer_matrix_list[0]
            main_model_relative_representation_probs = torch.mm(
                main_model_generate_ids_probs,
                local_main_model_relative_representation_matrix).to(self.device_compute)

            model_relative_representation_probs_list = [main_model_relative_representation_probs]
            # zip() stops at the shorter list, so unmatched queues/matrices are ignored.
            for assist_model_generate_ids_logits, assist_model_probability_transfer_matrix in zip(
                    assist_model_generate_ids_logits_list, self.assist_model_probability_transfer_matrix_list):
                # Cast before softmax so reduced-precision assist logits are normalised in float32.
                assist_model_generate_ids_probs = nn.functional.softmax(
                    assist_model_generate_ids_logits.to(torch.float32), dim=-1).float()
                model_relative_representation_probs_list.append(
                    torch.mm(assist_model_generate_ids_probs,
                             assist_model_probability_transfer_matrix).to(self.device_compute))

            model_count = len(model_relative_representation_probs_list)
            model_weights = list(self._MODEL_WEIGHTS_BY_COUNT.get(model_count, [1 / model_count] * model_count))
            assert len(model_weights) == model_count, "权重数和logits必须相同"
            average_probs = torch.zeros_like(main_model_relative_representation_probs)
            for weight, probs in zip(model_weights, model_relative_representation_probs_list):
                average_probs += weight * probs

        with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
            process_file.write(f"model_weights:{model_weights}\n")

        # enable_grad() restores the caller's grad mode on exit, including on
        # exceptions, instead of forcing grad globally on and then off.
        with torch.enable_grad():
            main_model_generate_ids_logits = scores.to(self.device_compute).detach().clone().to(torch.float32)
            main_model_generate_ids_logits.requires_grad_(True)
            local_main_model_relative_representation_matrix = local_main_model_relative_representation_matrix.to(
                self.device_compute)
            # Default reduction='mean' is kept; PyTorch documents
            # 'batchmean' as the form matching the mathematical KL definition.
            criterion = nn.KLDivLoss()
            optimizer = torch.optim.AdamW(params=[main_model_generate_ids_logits],
                                          lr=self.learning_rate,
                                          betas=(0.9, 0.999))
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=self.learning_rate / 4)

            for _ in range(self.learning_epochs_nums):
                main_model_generate_ids_probs = nn.functional.softmax(main_model_generate_ids_logits, dim=-1).float()
                main_model_relative_representation_probs = torch.mm(main_model_generate_ids_probs,
                                                                    local_main_model_relative_representation_matrix)
                # KLDivLoss expects log-probabilities as input; an exact zero gives -inf.
                log_main_probs = torch.log(main_model_relative_representation_probs)
                loss = criterion(log_main_probs, average_probs)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

        main_model_generate_ids_logits = main_model_generate_ids_logits.detach()
        with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
            process_file.write(f"ensemble_result:\n{main_model_generate_ids_logits.topk(10)}\n")
        return main_model_generate_ids_logits

    def _force_best_option(self, logits, scores):
        """Return ``zeros_like(scores)`` with ``inf`` at the lowest-perplexity option.

        Each option is scored against a one-hot target shaped like ``scores``
        on ``self.device``. Ties resolve to the earliest entry of
        ``_OPTION_TOKEN_IDS``.
        """
        option_ppl_list = []
        for token_id in self._OPTION_TOKEN_IDS:
            one_hot_target = torch.zeros_like(scores).to(self.device)
            one_hot_target[:, token_id] = 1
            option_ppl_list.append(self.calculate_ppl(logits, one_hot_target))

        next_tokens_id = self._OPTION_TOKEN_IDS[option_ppl_list.index(min(option_ppl_list))]
        output = torch.zeros_like(scores).to(self.device)
        output[:, next_tokens_id] = float('inf')
        return output


class YiPPLBasedOnProbabilityTransferLogitsPIQAMinCEProcessor(LogitsProcessor):
    """Force one of two option tokens using an equal-weight probability ensemble.

    No gradient-based refinement happens here. When every assist model
    answers and ``abs(learning_rate) > 1e-6``, the main model's softmax
    distribution and each assist model's transfer-matrix-projected
    distribution are averaged with equal weights. That average, multiplied
    by 15, serves as pseudo-logits. Otherwise the raw ``scores`` are used.

    The option token with the lowest :meth:`calculate_ppl` wins. It is
    returned as the only ``inf`` entry per row of an otherwise zero tensor.

    NOTE(review): ``main_model_probability_transfer_matrix_list`` is never
    read here. The assist matrices presumably project into the main model's
    vocabulary, since the result is added to the main softmax -- confirm.
    """

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        """Keep every constructor argument as an attribute of the same name."""
        # Numeric knobs (learning_rate also names the log file and gates ensembling).
        self.learning_rate = learning_rate
        self.learning_epochs_nums = learning_epochs_nums
        self.anchor_point_count = anchor_point_count
        # Queue handles.
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.assist_model_score_queue_list = assist_model_score_queue_list
        # Probability-transfer matrices.
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        # Tokenizers.
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        # Placement, logging and stopping configuration.
        self.device = device
        self.device_compute = device_compute
        self.result_save_dir = result_save_dir
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    @staticmethod
    def _append_log(path, *chunks):
        """Append each string in ``chunks`` to the UTF-8 log file at ``path``."""
        with open(path, "a+", encoding="utf-8") as handle:
            for chunk in chunks:
                handle.write(chunk)

    def calculate_ppl(self, logits, labels):
        """Return ``exp`` of the mean cross-entropy of ``labels`` under ``logits``."""
        return torch.nn.functional.cross_entropy(logits, labels, reduction='mean').exp().item()

    def _uniform_ensemble_logits(self, scores, received_logits, log_path):
        """Build pseudo-logits as 15 x the equal-weight mean of all model distributions."""
        with torch.no_grad():
            components = [nn.functional.softmax(scores, dim=-1).float().to(self.device_compute)]
            for assist_logits, transfer in zip(received_logits, self.assist_model_probability_transfer_matrix_list):
                projected = torch.mm(nn.functional.softmax(assist_logits, dim=-1).float(), transfer)
                components.append(projected.to(self.device_compute))

        share = 1 / len(components)
        weights = [share] * len(components)
        assert len(weights) == len(components), "权重数和logits必须相同"
        blended = torch.zeros_like(components[0])
        for w, component in zip(weights, components):
            blended += w * component

        self._append_log(log_path, f"model_weights:{weights}\n")
        pseudo_logits = blended * 15
        self._append_log(log_path, f"ensemble_result:\n{pseudo_logits.topk(10)}\n")
        return pseudo_logits

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``zeros_like(scores)`` on ``self.device`` with ``inf`` at the chosen option."""
        log_path = os.path.join(self.result_save_dir,
                                f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')
        self._append_log(log_path, "main model top10:\n", str(torch.topk(scores, 10)) + "\n")

        received_logits = []
        use_main_only = False
        for slot, score_queue in enumerate(self.assist_model_score_queue_list):
            try:
                incoming = score_queue.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{slot}【not received】\n")
                self._append_log(log_path, f"aux model{slot}【not received】\n")
                received_logits.append(None)
                use_main_only = True
            else:
                received_logits.append(incoming)
                self._append_log(log_path, f"aux model{slot} top10:\n", str(torch.topk(incoming, 10)) + "\n")

        use_main_only = use_main_only or math.fabs(self.learning_rate) <= 1e-6

        # Hard-coded option token ids ("▁A", "▁B"); presumably main-tokenizer ids.
        candidate_ids = (647, 690)
        targets = []
        for cid in candidate_ids:
            target = torch.zeros_like(scores).to(self.device)
            target[:, cid] = 1
            targets.append(target)

        if use_main_only:
            decision_logits = scores
        else:
            decision_logits = self._uniform_ensemble_logits(scores, received_logits, log_path).to(self.device)

        perplexities = [self.calculate_ppl(decision_logits, target) for target in targets]
        winner = candidate_ids[perplexities.index(min(perplexities))]
        forced = torch.zeros_like(scores).to(self.device)
        forced[:, winner] = float('inf')
        return forced


class InternLMPPLBasedOnProbabilityTransferLogitsPIQAReWrightByDevProcessor(LogitsProcessor):
    """Choose between answer tokens A and B by perplexity, after ensembling with assist models.

    Each call appends diagnostics to a log file under ``result_save_dir`` and reads one
    score tensor from every queue in ``assist_model_score_queue_list`` (5 s timeout each).
    It then takes one of two paths:

    * Ensemble path. The main and assist scores are softmaxed and projected through their
      probability-transfer matrices. They are combined with the fixed weights in
      ``_MODEL_WEIGHTS``. A copy of the main logits is tuned with AdamW so that its
      projected distribution approaches that weighted average (``nn.KLDivLoss``). The
      answer tokens are then scored with the tuned logits.
    * Main-model-only path. Used when some assist queue timed out or when
      ``|learning_rate| <= 1e-6``. The answer tokens are scored with ``scores`` directly.

    The returned tensor is zero everywhere except ``+inf`` at the lowest-perplexity answer
    token, so argmax-based decoding emits that token.
    NOTE(review): softmax over a row containing ``+inf`` yields NaN, so this output looks
    intended for greedy decoding only. Confirm before using it with sampling.
    """

    # Token ids for answer options A and B (marked "▁A ▁B" in this file).
    # Presumably InternLM vocabulary ids; confirm against the tokenizer.
    _ANSWER_TOKEN_IDS = (493, 556)
    # Fixed ensemble weights: main model first, then one per assist model in queue
    # order. Their count must equal 1 + the number of assist models.
    _MODEL_WEIGHTS = (0.5856, 0.3159, 0.0583, 0.0402)

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        """Store the ensemble configuration.

        ``__call__`` reads these attributes: ``learning_rate``, ``learning_epochs_nums``,
        ``assist_model_score_queue_list``, both transfer-matrix lists, ``result_save_dir``,
        ``device`` and ``device_compute``. The remaining arguments are stored but not used
        by this class.
        """
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count
        self.assist_model_score_queue_list = assist_model_score_queue_list
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        self.device = device
        self.device_compute = device_compute
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    def calculate_ppl(self, logits, labels):
        """Return the perplexity ``exp(mean cross-entropy)`` of ``logits`` against ``labels``.

        ``labels`` is passed unchanged as the ``cross_entropy`` target. In this class it is
        a one-hot tensor shaped like ``logits`` (a class-probability target).
        """
        neg_log_likelihood = torch.nn.functional.cross_entropy(logits, labels, reduction='mean')
        return torch.exp(neg_log_likelihood).item()

    @staticmethod
    def _append_log(path, *chunks):
        """Append every string in ``chunks`` to the UTF-8 log file at ``path``."""
        with open(path, "a+", encoding="utf-8") as process_file:
            for chunk in chunks:
                process_file.write(chunk)

    def _collect_assist_scores(self, log_path):
        """Read one score tensor from each assist-model queue, logging what arrived.

        Returns ``(scores_list, missing)``. ``scores_list`` has one entry per queue, with
        ``None`` for a queue that stayed empty for 5 s. ``missing`` is True if any queue
        timed out.
        """
        scores_list = []
        missing = False
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                self._append_log(log_path, f"aux model{index}【not received】\n")
                scores_list.append(None)
                missing = True
                continue
            scores_list.append(value)
            self._append_log(log_path, f"aux model{index} top10:\n", str(torch.topk(value, 10)) + "\n")
        return scores_list, missing

    def _force_lowest_ppl_option(self, logits, scores):
        """Score each answer token against ``logits`` and force the lowest-perplexity one.

        The one-hot targets and the returned tensor are shaped like ``scores`` and placed
        on ``self.device``. On ties, the earliest id in ``_ANSWER_TOKEN_IDS`` wins.
        """
        ppl_list = []
        for token_id in self._ANSWER_TOKEN_IDS:
            one_hot = torch.zeros_like(scores).to(self.device)
            one_hot[:, token_id] = 1
            ppl_list.append(self.calculate_ppl(logits, one_hot))
        next_tokens_id = self._ANSWER_TOKEN_IDS[ppl_list.index(min(ppl_list))]
        output = torch.zeros_like(scores).to(self.device)
        output[:, next_tokens_id] = float('inf')
        return output

    def _ensemble_and_tune(self, scores, assist_scores_list, log_path):
        """Return main-model logits tuned toward the weighted ensemble distribution.

        The returned tensor lives on ``device_compute`` and still has
        ``requires_grad=True``. Read it under ``torch.no_grad()``.
        """
        # Detached float32 copy (as in the sibling InternLM processor) so small AdamW
        # steps are not rounded away if ``scores`` arrives in half precision.
        main_logits = scores.detach().to(torch.float32)

        with torch.no_grad():
            main_probs = nn.functional.softmax(main_logits, dim=-1).float()
            main_transfer_matrix = self.main_model_probability_transfer_matrix_list[0]
            main_relative_probs = torch.mm(main_probs, main_transfer_matrix).to(self.device_compute)

            relative_probs_list = [main_relative_probs]
            for assist_scores, assist_transfer_matrix in zip(
                    assist_scores_list, self.assist_model_probability_transfer_matrix_list):
                assist_probs = nn.functional.softmax(assist_scores, dim=-1).float()
                relative_probs_list.append(
                    torch.mm(assist_probs, assist_transfer_matrix).to(self.device_compute))

            model_weights = list(self._MODEL_WEIGHTS)
            assert len(model_weights) == len(relative_probs_list), "权重数和logits必须相同"
            average_probs = torch.zeros_like(main_relative_probs)
            for weight, probs in zip(model_weights, relative_probs_list):
                average_probs += weight * probs

            # Diagnostics only. ``.item()`` requires a single row, i.e. batch size 1.
            cosine_similarity_list = [torch.cosine_similarity(probs, average_probs).item()
                                      for probs in relative_probs_list]
            # F.kl_div(log q, p) is KL(p || q); default 'mean' reduction kept as before.
            kl_div_list = [torch.nn.functional.kl_div(probs.log(), average_probs).item()
                           for probs in relative_probs_list]

        self._append_log(log_path,
                         f"model_weights:{model_weights}\n",
                         f"cosine_similarity_list:{cosine_similarity_list}\n",
                         f"kl_div_list:{kl_div_list}\n")

        # enable_grad() restores the caller's grad mode on exit (also on error). The
        # previous global set_grad_enabled(True/False) toggling did not.
        with torch.enable_grad():
            tuned_logits = main_logits.to(self.device_compute).detach().clone()
            tuned_logits.requires_grad_(True)
            main_transfer_matrix = main_transfer_matrix.to(self.device_compute)
            criterion = nn.KLDivLoss()  # minimises KL(average || projected main), 'mean' reduction
            optimizer = torch.optim.AdamW(params=[tuned_logits],
                                          lr=self.learning_rate,
                                          betas=(0.9, 0.999))
            # NOTE: T_max is fixed at 10 regardless of learning_epochs_nums.
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=self.learning_rate / 4)

            for _ in range(self.learning_epochs_nums):
                probs = nn.functional.softmax(tuned_logits, dim=-1).float()
                relative_probs = torch.mm(probs, main_transfer_matrix)
                loss = criterion(torch.log(relative_probs), average_probs)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()
        return tuned_logits

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return a score tensor forcing the lowest-perplexity answer token (A or B).

        ``input_ids`` is not used. ``scores`` are the main model's next-token scores.
        """
        # NOTE: the "_5" suffix is literal text, not self.learning_epochs_nums.
        ensemble_process_file_path = os.path.join(self.result_save_dir,
                                                  f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')
        self._append_log(ensemble_process_file_path,
                         "main model top10:\n",
                         str(torch.topk(scores, 10)) + "\n")

        assist_scores_list, assist_missing = self._collect_assist_scores(ensemble_process_file_path)
        main_model_only_flag = assist_missing or math.fabs(self.learning_rate) <= 1e-6

        if main_model_only_flag:
            return self._force_lowest_ppl_option(scores, scores)

        tuned_logits = self._ensemble_and_tune(scores, assist_scores_list, ensemble_process_file_path)
        with torch.no_grad():
            return self._force_lowest_ppl_option(tuned_logits.to(self.device), scores)


class InternLMPPLBasedOnProbabilityTransferLogitsProcessor(LogitsProcessor):
    """Choose among answer tokens A/B/C/D by perplexity, after ensembling with assist models.

    Each call appends diagnostics to a log file under ``result_save_dir`` and reads one
    score tensor from every queue in ``assist_model_score_queue_list`` (5 s timeout each).
    It then takes one of two paths:

    * Ensemble path. The main and assist scores are softmaxed and projected through their
      probability-transfer matrices. They are averaged with uniform weights. A float32
      copy of the main logits is tuned with AdamW so that its projected distribution
      approaches that average (``nn.KLDivLoss``). The answer tokens are then scored with
      the tuned logits.
    * Main-model-only path. Used when some assist queue timed out or when
      ``|learning_rate| <= 1e-6``. The answer tokens are scored with ``scores`` directly.

    The returned tensor is zero everywhere except ``+inf`` at the lowest-perplexity answer
    token, so argmax-based decoding emits that token.
    NOTE(review): softmax over a row containing ``+inf`` yields NaN, so this output looks
    intended for greedy decoding only. Confirm before using it with sampling.
    """

    # Token ids for answer options A, B, C, D. Presumably InternLM vocabulary ids;
    # confirm against the tokenizer.
    _ANSWER_TOKEN_IDS = (493, 556, 487, 553)

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        """Store the ensemble configuration.

        ``__call__`` reads these attributes: ``learning_rate``, ``learning_epochs_nums``,
        ``assist_model_score_queue_list``, both transfer-matrix lists, ``result_save_dir``,
        ``device`` and ``device_compute``. The remaining arguments are stored but not used
        by this class.
        """
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count
        self.assist_model_score_queue_list = assist_model_score_queue_list
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        self.device = device
        self.device_compute = device_compute
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    def calculate_ppl(self, logits, labels):
        """Return the perplexity ``exp(mean cross-entropy)`` of ``logits`` against ``labels``.

        ``labels`` is passed unchanged as the ``cross_entropy`` target. In this class it is
        a one-hot tensor shaped like ``logits`` (a class-probability target).
        """
        neg_log_likelihood = torch.nn.functional.cross_entropy(logits, labels, reduction='mean')
        return torch.exp(neg_log_likelihood).item()

    @staticmethod
    def _append_log(path, *chunks):
        """Append every string in ``chunks`` to the UTF-8 log file at ``path``."""
        with open(path, "a+", encoding="utf-8") as process_file:
            for chunk in chunks:
                process_file.write(chunk)

    def _collect_assist_scores(self, log_path):
        """Read one score tensor from each assist-model queue, logging what arrived.

        Returns ``(scores_list, missing)``. ``scores_list`` has one entry per queue, with
        ``None`` for a queue that stayed empty for 5 s. ``missing`` is True if any queue
        timed out.
        """
        scores_list = []
        missing = False
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                self._append_log(log_path, f"aux model{index}【not received】\n")
                scores_list.append(None)
                missing = True
                continue
            scores_list.append(value)
            self._append_log(log_path, f"aux model{index} top10:\n", str(torch.topk(value, 10)) + "\n")
        return scores_list, missing

    def _force_lowest_ppl_option(self, logits, scores):
        """Score each answer token against ``logits`` and force the lowest-perplexity one.

        The one-hot targets and the returned tensor are shaped like ``scores`` and placed
        on ``self.device``. On ties, the earliest id in ``_ANSWER_TOKEN_IDS`` wins.
        """
        ppl_list = []
        for token_id in self._ANSWER_TOKEN_IDS:
            one_hot = torch.zeros_like(scores).to(self.device)
            one_hot[:, token_id] = 1
            ppl_list.append(self.calculate_ppl(logits, one_hot))
        next_tokens_id = self._ANSWER_TOKEN_IDS[ppl_list.index(min(ppl_list))]
        output = torch.zeros_like(scores).to(self.device)
        output[:, next_tokens_id] = float('inf')
        return output

    def _ensemble_and_tune(self, scores, assist_scores_list, log_path):
        """Return main-model logits tuned toward the uniform ensemble distribution.

        The returned float32 tensor lives on ``device_compute`` and still has
        ``requires_grad=True``. Read it under ``torch.no_grad()``.
        """
        main_logits = scores.detach().to(torch.float32)

        with torch.no_grad():
            main_probs = nn.functional.softmax(main_logits, dim=-1).float()
            main_transfer_matrix = self.main_model_probability_transfer_matrix_list[0]
            main_relative_probs = torch.mm(main_probs, main_transfer_matrix).to(self.device_compute)

            relative_probs_list = [main_relative_probs]
            for assist_scores, assist_transfer_matrix in zip(
                    assist_scores_list, self.assist_model_probability_transfer_matrix_list):
                assist_probs = nn.functional.softmax(assist_scores, dim=-1).float()
                relative_probs_list.append(
                    torch.mm(assist_probs, assist_transfer_matrix).to(self.device_compute))

            n = len(relative_probs_list)
            model_weights = [1 / n] * n
            average_probs = torch.zeros_like(main_relative_probs)
            for weight, probs in zip(model_weights, relative_probs_list):
                average_probs += weight * probs

            # Diagnostics only. ``.item()`` requires a single row, i.e. batch size 1.
            cosine_similarity_list = [torch.cosine_similarity(probs, average_probs).item()
                                      for probs in relative_probs_list]
            # F.kl_div(log q, p) is KL(p || q); default 'mean' reduction kept as before.
            kl_div_list = [torch.nn.functional.kl_div(probs.log(), average_probs).item()
                           for probs in relative_probs_list]

        self._append_log(log_path,
                         f"model_weights:{model_weights}\n",
                         f"cosine_similarity_list:{cosine_similarity_list}\n",
                         f"kl_div_list:{kl_div_list}\n")

        # enable_grad() restores the caller's grad mode on exit (also on error). The
        # previous global set_grad_enabled(True/False) toggling did not.
        with torch.enable_grad():
            tuned_logits = main_logits.to(self.device_compute).detach().clone().to(torch.float32)
            tuned_logits.requires_grad_(True)
            main_transfer_matrix = main_transfer_matrix.to(self.device_compute)
            criterion = nn.KLDivLoss()  # minimises KL(average || projected main), 'mean' reduction
            optimizer = torch.optim.AdamW(params=[tuned_logits],
                                          lr=self.learning_rate,
                                          betas=(0.9, 0.999))
            # NOTE: T_max is fixed at 10 regardless of learning_epochs_nums.
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=self.learning_rate / 4)

            for _ in range(self.learning_epochs_nums):
                probs = nn.functional.softmax(tuned_logits, dim=-1).float()
                relative_probs = torch.mm(probs, main_transfer_matrix)
                loss = criterion(torch.log(relative_probs), average_probs)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()
        return tuned_logits

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return a score tensor forcing the lowest-perplexity answer token (A/B/C/D).

        ``input_ids`` is not used. ``scores`` are the main model's next-token scores.
        """
        # NOTE: the "_5" suffix is literal text, not self.learning_epochs_nums.
        ensemble_process_file_path = os.path.join(self.result_save_dir,
                                                  f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')
        self._append_log(ensemble_process_file_path,
                         "main model top10:\n",
                         str(torch.topk(scores, 10)) + "\n")

        assist_scores_list, assist_missing = self._collect_assist_scores(ensemble_process_file_path)
        main_model_only_flag = assist_missing or math.fabs(self.learning_rate) <= 1e-6

        if main_model_only_flag:
            return self._force_lowest_ppl_option(scores, scores)

        tuned_logits = self._ensemble_and_tune(scores, assist_scores_list, ensemble_process_file_path)
        with torch.no_grad():
            return self._force_lowest_ppl_option(tuned_logits.to(self.device), scores)


class InternLMPPLBasedOnProbabilityTransferLogitsMinCEProcessor(LogitsProcessor):
    """Choose among answer tokens A/B/C/D by perplexity under an averaged ensemble distribution.

    Each call appends diagnostics to a log file under ``result_save_dir`` and reads one
    score tensor from every queue in ``assist_model_score_queue_list`` (5 s timeout each).
    It then takes one of two paths:

    * Ensemble path. The main model's softmax probabilities are used as-is, with no
      transfer matrix. Each assist model's probabilities are projected through its
      transfer matrix. Everything is averaged with uniform weights, the average is
      multiplied by ``_PROB_TO_LOGIT_SCALE``, and the result is scored as logits. No
      optimisation happens here; ``learning_rate`` only selects the path and names the
      log file.
    * Main-model-only path. Used when some assist queue timed out or when
      ``|learning_rate| <= 1e-6``. The answer tokens are scored with ``scores`` directly.

    The returned tensor is zero everywhere except ``+inf`` at the lowest-perplexity answer
    token, so argmax-based decoding emits that token.
    NOTE(review): softmax over a row containing ``+inf`` yields NaN, so this output looks
    intended for greedy decoding only. Confirm before using it with sampling.
    """

    # Token ids for answer options A, B, C, D. Presumably InternLM vocabulary ids;
    # confirm against the tokenizer.
    _ANSWER_TOKEN_IDS = (493, 556, 487, 553)
    # Factor applied to the averaged probabilities before they are scored as logits.
    _PROB_TO_LOGIT_SCALE = 15

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        """Store the ensemble configuration.

        ``__call__`` reads these attributes: ``learning_rate``,
        ``assist_model_score_queue_list``, ``assist_model_probability_transfer_matrix_list``,
        ``result_save_dir``, ``device`` and ``device_compute``. The remaining arguments are
        stored but not used by this class.
        """
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count
        self.assist_model_score_queue_list = assist_model_score_queue_list
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        self.device = device
        self.device_compute = device_compute
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    def calculate_ppl(self, logits, labels):
        """Return the perplexity ``exp(mean cross-entropy)`` of ``logits`` against ``labels``.

        ``labels`` is passed unchanged as the ``cross_entropy`` target. In this class it is
        a one-hot tensor shaped like ``logits`` (a class-probability target).
        """
        neg_log_likelihood = torch.nn.functional.cross_entropy(logits, labels, reduction='mean')
        return torch.exp(neg_log_likelihood).item()

    @staticmethod
    def _append_log(path, *chunks):
        """Append every string in ``chunks`` to the UTF-8 log file at ``path``."""
        with open(path, "a+", encoding="utf-8") as process_file:
            for chunk in chunks:
                process_file.write(chunk)

    def _collect_assist_scores(self, log_path):
        """Read one score tensor from each assist-model queue, logging what arrived.

        Returns ``(scores_list, missing)``. ``scores_list`` has one entry per queue, with
        ``None`` for a queue that stayed empty for 5 s. ``missing`` is True if any queue
        timed out.
        """
        scores_list = []
        missing = False
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                self._append_log(log_path, f"aux model{index}【not received】\n")
                scores_list.append(None)
                missing = True
                continue
            scores_list.append(value)
            self._append_log(log_path, f"aux model{index} top10:\n", str(torch.topk(value, 10)) + "\n")
        return scores_list, missing

    def _force_lowest_ppl_option(self, logits, scores):
        """Score each answer token against ``logits`` and force the lowest-perplexity one.

        The one-hot targets and the returned tensor are shaped like ``scores`` and placed
        on ``self.device``. On ties, the earliest id in ``_ANSWER_TOKEN_IDS`` wins.
        """
        ppl_list = []
        for token_id in self._ANSWER_TOKEN_IDS:
            one_hot = torch.zeros_like(scores).to(self.device)
            one_hot[:, token_id] = 1
            ppl_list.append(self.calculate_ppl(logits, one_hot))
        next_tokens_id = self._ANSWER_TOKEN_IDS[ppl_list.index(min(ppl_list))]
        output = torch.zeros_like(scores).to(self.device)
        output[:, next_tokens_id] = float('inf')
        return output

    def _ensemble_logits(self, scores, assist_scores_list):
        """Return ``_PROB_TO_LOGIT_SCALE`` times the uniform average of all model probabilities.

        The result lives on ``device_compute``.
        """
        with torch.no_grad():
            # A plain detached float32 view replaces the deprecated
            # ``Variable(scores, requires_grad=True)``, whose grad flag was never used.
            main_probs = nn.functional.softmax(scores.detach().to(torch.float32),
                                               dim=-1).float().to(self.device_compute)
            probs_list = [main_probs]
            for assist_scores, assist_transfer_matrix in zip(
                    assist_scores_list, self.assist_model_probability_transfer_matrix_list):
                assist_probs = nn.functional.softmax(assist_scores, dim=-1).float()
                probs_list.append(torch.mm(assist_probs, assist_transfer_matrix).to(self.device_compute))

        n = len(probs_list)
        model_weights = [1 / n] * n
        average_probs = torch.zeros_like(main_probs)
        for weight, probs in zip(model_weights, probs_list):
            average_probs += weight * probs
        return average_probs * self._PROB_TO_LOGIT_SCALE

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return a score tensor forcing the lowest-perplexity answer token (A/B/C/D).

        ``input_ids`` is not used. ``scores`` are the main model's next-token scores.
        """
        # NOTE: the "_5" suffix is literal text, not self.learning_epochs_nums.
        ensemble_process_file_path = os.path.join(self.result_save_dir,
                                                  f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')
        self._append_log(ensemble_process_file_path,
                         "main model top10:\n",
                         str(torch.topk(scores, 10)) + "\n")

        assist_scores_list, assist_missing = self._collect_assist_scores(ensemble_process_file_path)
        main_model_only_flag = assist_missing or math.fabs(self.learning_rate) <= 1e-6

        if main_model_only_flag:
            return self._force_lowest_ppl_option(scores, scores)

        ensemble_logits = self._ensemble_logits(scores, assist_scores_list)
        return self._force_lowest_ppl_option(ensemble_logits.to(self.device), scores)


class YiPPLBasedOnProbabilityTransferLogitsMinCEProcessor(LogitsProcessor):
    """Choose among answer tokens A/B/C/D by perplexity under an averaged ensemble distribution.

    Each call appends diagnostics to a log file under ``result_save_dir`` and reads one
    score tensor from every queue in ``assist_model_score_queue_list`` (5 s timeout each).
    It then takes one of two paths:

    * Ensemble path. The main model's softmax probabilities are used as-is, with no
      transfer matrix. Each assist model's probabilities are projected through its
      transfer matrix. Everything is averaged with uniform weights, the average is
      multiplied by ``_PROB_TO_LOGIT_SCALE``, and the result is scored as logits. No
      optimisation happens here; ``learning_rate`` only selects the path and names the
      log file.
    * Main-model-only path. Used when some assist queue timed out or when
      ``|learning_rate| <= 1e-6``. The answer tokens are scored with ``scores`` directly.

    The returned tensor is zero everywhere except ``+inf`` at the lowest-perplexity answer
    token, so argmax-based decoding emits that token.
    NOTE(review): softmax over a row containing ``+inf`` yields NaN, so this output looks
    intended for greedy decoding only. Confirm before using it with sampling.
    """

    # Token ids for answer options A, B, C, D. Presumably Yi vocabulary ids;
    # confirm against the tokenizer.
    _ANSWER_TOKEN_IDS = (647, 690, 650, 723)
    # Factor applied to the averaged probabilities before they are scored as logits.
    _PROB_TO_LOGIT_SCALE = 15

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        """Store the ensemble configuration.

        ``__call__`` reads these attributes: ``learning_rate``,
        ``assist_model_score_queue_list``, ``assist_model_probability_transfer_matrix_list``,
        ``result_save_dir``, ``device`` and ``device_compute``. The remaining arguments are
        stored but not used by this class.
        """
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count
        self.assist_model_score_queue_list = assist_model_score_queue_list
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        self.device = device
        self.device_compute = device_compute
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    def calculate_ppl(self, logits, labels):
        """Return the perplexity ``exp(mean cross-entropy)`` of ``logits`` against ``labels``.

        ``labels`` is passed unchanged as the ``cross_entropy`` target. In this class it is
        a one-hot tensor shaped like ``logits`` (a class-probability target).
        """
        neg_log_likelihood = torch.nn.functional.cross_entropy(logits, labels, reduction='mean')
        return torch.exp(neg_log_likelihood).item()

    @staticmethod
    def _append_log(path, *chunks):
        """Append every string in ``chunks`` to the UTF-8 log file at ``path``."""
        with open(path, "a+", encoding="utf-8") as process_file:
            for chunk in chunks:
                process_file.write(chunk)

    def _collect_assist_scores(self, log_path):
        """Read one score tensor from each assist-model queue, logging what arrived.

        Returns ``(scores_list, missing)``. ``scores_list`` has one entry per queue, with
        ``None`` for a queue that stayed empty for 5 s. ``missing`` is True if any queue
        timed out.
        """
        scores_list = []
        missing = False
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                self._append_log(log_path, f"aux model{index}【not received】\n")
                scores_list.append(None)
                missing = True
                continue
            scores_list.append(value)
            self._append_log(log_path, f"aux model{index} top10:\n", str(torch.topk(value, 10)) + "\n")
        return scores_list, missing

    def _force_lowest_ppl_option(self, logits, scores):
        """Score each answer token against ``logits`` and force the lowest-perplexity one.

        The one-hot targets and the returned tensor are shaped like ``scores`` and placed
        on ``self.device``. On ties, the earliest id in ``_ANSWER_TOKEN_IDS`` wins.
        """
        ppl_list = []
        for token_id in self._ANSWER_TOKEN_IDS:
            one_hot = torch.zeros_like(scores).to(self.device)
            one_hot[:, token_id] = 1
            ppl_list.append(self.calculate_ppl(logits, one_hot))
        next_tokens_id = self._ANSWER_TOKEN_IDS[ppl_list.index(min(ppl_list))]
        output = torch.zeros_like(scores).to(self.device)
        output[:, next_tokens_id] = float('inf')
        return output

    def _ensemble_logits(self, scores, assist_scores_list):
        """Return ``_PROB_TO_LOGIT_SCALE`` times the uniform average of all model probabilities.

        The result lives on ``device_compute``.
        """
        with torch.no_grad():
            # A plain detached float32 view replaces the deprecated
            # ``Variable(scores, requires_grad=True)``, whose grad flag was never used.
            main_probs = nn.functional.softmax(scores.detach().to(torch.float32),
                                               dim=-1).float().to(self.device_compute)
            probs_list = [main_probs]
            for assist_scores, assist_transfer_matrix in zip(
                    assist_scores_list, self.assist_model_probability_transfer_matrix_list):
                assist_probs = nn.functional.softmax(assist_scores, dim=-1).float()
                probs_list.append(torch.mm(assist_probs, assist_transfer_matrix).to(self.device_compute))

        n = len(probs_list)
        model_weights = [1 / n] * n
        average_probs = torch.zeros_like(main_probs)
        for weight, probs in zip(model_weights, probs_list):
            average_probs += weight * probs
        return average_probs * self._PROB_TO_LOGIT_SCALE

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return a score tensor forcing the lowest-perplexity answer token (A/B/C/D).

        ``input_ids`` is not used. ``scores`` are the main model's next-token scores.
        """
        # NOTE: the "_5" suffix is literal text, not self.learning_epochs_nums.
        ensemble_process_file_path = os.path.join(self.result_save_dir,
                                                  f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')
        self._append_log(ensemble_process_file_path,
                         "main model top10:\n",
                         str(torch.topk(scores, 10)) + "\n")

        assist_scores_list, assist_missing = self._collect_assist_scores(ensemble_process_file_path)
        main_model_only_flag = assist_missing or math.fabs(self.learning_rate) <= 1e-6

        if main_model_only_flag:
            return self._force_lowest_ppl_option(scores, scores)

        ensemble_logits = self._ensemble_logits(scores, assist_scores_list)
        return self._force_lowest_ppl_option(ensemble_logits.to(self.device), scores)


class InternLMPPLBasedOnProbabilityTransferLogitsPIQAProcessor(LogitsProcessor):
    """Pick the PIQA answer token ("▁A" or "▁B") for the main model by perplexity.

    On every call the processor:

    1. collects next-token logits from each assist model through its queue;
    2. maps the main-model and assist-model probabilities through their
       probability transfer matrices and averages them with equal weights;
    3. tunes a copy of the main model's logits with AdamW so that their mapped
       distribution moves toward that average (KL-divergence loss);
    4. scores each option token by the perplexity of the tuned logits and
       returns scores that are ``inf`` at the winning token and 0 elsewhere.

    Steps 2-3 are skipped, and the option is chosen from the raw ``scores``,
    when any assist model does not deliver logits within the queue timeout or
    when the learning rate is (near) zero. Diagnostics (top-10 logits, model
    weights, cosine similarities, KL divergences) are appended to a log file
    under ``result_save_dir``.
    """

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        # AdamW learning rate for the logit-tuning loop; |lr| <= 1e-6 disables ensembling.
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count
        # One queue per assist model; each call takes one logits tensor from each.
        self.assist_model_score_queue_list = assist_model_score_queue_list
        # Number of AdamW steps taken per call.
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        # Only element [0] is used: the main model's transfer matrix.
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        # Zipped position-by-position with the assist model logits.
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        # Device of the returned scores and the option targets.
        self.device = device
        # Device on which the averaged distribution and the tuning loop live.
        self.device_compute = device_compute
        # anchor_point_count, ensemble_model_output_ids_queue, the tokenizers,
        # forced_eos_token_id and early_stop_string_list are stored but not
        # read by this processor.
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    def calculate_ppl(self, logits, labels):
        """Return the perplexity of ``labels`` under ``logits`` as a Python float.

        ``labels`` is passed straight to ``torch.nn.functional.cross_entropy``.
        In this class it is always a one-hot tensor shaped like ``logits``,
        which cross_entropy treats as class probabilities. The batch-mean
        negative log-likelihood is exponentiated.
        """
        neg_log_likelihood = torch.nn.functional.cross_entropy(logits, labels, reduction='mean')
        return torch.exp(neg_log_likelihood).item()

    def _choose_option(self, logits, option_targets, option_token_ids, scores):
        """Return scores forcing the option whose target has the lowest perplexity under ``logits``.

        Ties go to the earliest option. The result is shaped like ``scores``,
        lives on ``self.device``, and is ``inf`` at the chosen token id in every
        row and 0 elsewhere.
        """
        ppl_list = [self.calculate_ppl(logits, target) for target in option_targets]
        next_tokens_id = option_token_ids[ppl_list.index(min(ppl_list))]
        output = torch.zeros_like(scores).to(self.device)
        output[:, next_tokens_id] = float('inf')
        return output

    def _tune_main_logits(self, scores, main_matrix, target_probs):
        """Run AdamW on a copy of ``scores`` so its mapped distribution approaches ``target_probs``.

        The copy lives on ``self.device_compute``. Each of the
        ``self.learning_epochs_nums`` steps minimises
        ``KLDivLoss(log(softmax(logits) @ main_matrix), target_probs)``, with a
        cosine-annealed learning rate (T_max=10, eta_min=lr/4).

        Returns:
            The tuned logits, detached from the autograd graph.
        """
        logits = scores.to(self.device_compute).detach().clone()
        logits.requires_grad_(True)
        main_matrix = main_matrix.to(self.device_compute)
        # NOTE(review): KLDivLoss() uses reduction='mean' (averages over every
        # element), not 'batchmean'; kept so tuned results stay unchanged.
        criterion = nn.KLDivLoss()
        optimizer = torch.optim.AdamW(params=[logits], lr=self.learning_rate, betas=(0.9, 0.999))
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=self.learning_rate / 4)

        # Generation usually runs with autograd disabled; turn it on only for
        # this loop. The context manager restores the caller's grad mode on
        # exit, including when a step raises.
        with torch.enable_grad():
            for _ in range(self.learning_epochs_nums):
                mapped_probs = torch.mm(nn.functional.softmax(logits, dim=-1).float(), main_matrix)
                loss = criterion(torch.log(mapped_probs), target_probs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()
        return logits.detach()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Choose "▁A" or "▁B" as the next token.

        Args:
            input_ids: Tokens generated so far (not used).
            scores: Main-model next-token logits. They are indexed as
                ``scores[:, token_id]`` and fed to ``torch.mm``, so 2-D
                (batch, vocab). The cosine-similarity logging calls ``.item()``
                on a per-row result, so a batch size of 1 is assumed.

        Returns:
            A tensor shaped like ``scores`` on ``self.device`` that is ``inf`` at
            the chosen option token and 0 elsewhere.
        """
        # NOTE(review): the anchor count and epoch count in this file name are
        # hard-coded ("all", "5") rather than read from self; left unchanged so
        # existing log locations do not move.
        ensemble_process_file_path = os.path.join(self.result_save_dir,
                                                  f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')

        with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
            process_file.write(f"main model top10:\n")
            process_file.write(str(torch.topk(scores, 10)) + "\n")

        # Wait up to 5 s per assist model; any miss falls back to the main model alone.
        main_model_only_flag = False
        assist_model_generate_ids_logits_list = []
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
                    process_file.write(f"aux model{index}【not received】\n")
                assist_model_generate_ids_logits_list.append(None)
                main_model_only_flag = True
                continue
            assist_model_generate_ids_logits_list.append(value)
            with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
                process_file.write(f"aux model{index} top10:\n")
                process_file.write(str(torch.topk(value, 10)) + "\n")

        if math.fabs(self.learning_rate) <= 1e-6:
            main_model_only_flag = True

        # Token ids of "▁A" and "▁B" in the main model's vocabulary (PIQA has two options).
        token_ABCD_index_list = [493, 556]
        # One-hot targets, one per option, used as cross-entropy labels.
        option_targets = []
        for token_index in token_ABCD_index_list:
            target = torch.zeros_like(scores).to(self.device)
            target[:, token_index] = 1
            option_targets.append(target)

        if main_model_only_flag:
            return self._choose_option(scores, option_targets, token_ABCD_index_list, scores)

        with torch.no_grad():
            local_main_model_relative_representation_matrix = self.main_model_probability_transfer_matrix_list[0]
            main_model_generate_ids_probs = nn.functional.softmax(scores, dim=-1).float()
            main_model_relative_representation_probs = torch.mm(
                main_model_generate_ids_probs,
                local_main_model_relative_representation_matrix).to(self.device_compute)

            model_relative_representation_probs_list = [main_model_relative_representation_probs]
            for assist_logits, assist_matrix in zip(assist_model_generate_ids_logits_list,
                                                    self.assist_model_probability_transfer_matrix_list):
                assist_probs = nn.functional.softmax(assist_logits, dim=-1).float()
                model_relative_representation_probs_list.append(
                    torch.mm(assist_probs, assist_matrix).to(self.device_compute))

            # Equal weight for the main model and every assist model.
            n = len(model_relative_representation_probs_list)
            model_weights = [1 / n] * n
            average_probs = torch.zeros_like(main_model_relative_representation_probs)
            for weight, probs in zip(model_weights, model_relative_representation_probs_list):
                average_probs += weight * probs

            # Diagnostics only: how far each model's mapped distribution is from the average.
            cosine_similarity_list = [
                torch.cosine_similarity(probs, average_probs).item()
                for probs in model_relative_representation_probs_list
            ]
            kl_div_list = [
                torch.nn.functional.kl_div(probs.log(), average_probs).item()
                for probs in model_relative_representation_probs_list
            ]

        with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
            process_file.write(f"model_weights:{model_weights}\n")
            process_file.write(f"cosine_similarity_list:{cosine_similarity_list}\n")
            process_file.write(f"kl_div_list:{kl_div_list}\n")

        tuned_logits = self._tune_main_logits(scores, local_main_model_relative_representation_matrix,
                                              average_probs)
        return self._choose_option(tuned_logits.to(self.device), option_targets, token_ABCD_index_list, scores)


class Baichuan2PPLBasedOnProbabilityTransferLogitsProcessor(LogitsProcessor):
    """Pick one of four option tokens (A/B/C/D) for the main model by perplexity.

    On every call the processor:

    1. collects next-token logits from each assist model through its queue;
    2. maps the main-model and assist-model probabilities through their
       probability transfer matrices; the main model gets half the weight and
       the mean of the assist models gets the other half;
    3. tunes a copy of the main model's logits with AdamW so that their mapped
       distribution moves toward that average (KL-divergence loss);
    4. scores each option token by the perplexity of the tuned logits and
       returns scores that are ``inf`` at the winning token and 0 elsewhere.

    Steps 2-3 are skipped, and the option is chosen from the raw ``scores``,
    in any of these cases:

    * any assist model does not deliver logits within the queue timeout;
    * there are no assist models;
    * the learning rate is (near) zero;
    * the main model's top token is ``forced_eos_token_id``.
    """

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        # AdamW learning rate for the logit-tuning loop; |lr| <= 1e-6 disables ensembling.
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count
        # One queue per assist model; each call takes one logits tensor from each.
        self.assist_model_score_queue_list = assist_model_score_queue_list
        # Number of AdamW steps taken per call.
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        # Only element [0] is used: the main model's transfer matrix.
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        # Zipped position-by-position with the assist model logits.
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        # Device of the returned scores and the option targets.
        self.device = device
        # Device on which the averaged distribution and the tuning loop live.
        self.device_compute = device_compute
        # When the main model's argmax equals this id, ensembling is skipped.
        self.forced_eos_token_id = forced_eos_token_id
        # anchor_point_count, ensemble_model_output_ids_queue, result_save_dir,
        # the tokenizers and early_stop_string_list are stored but not read by
        # this processor.
        self.early_stop_string_list = early_stop_string_list

    def calculate_ppl(self, logits, labels):
        """Return the perplexity of ``labels`` under ``logits`` as a Python float.

        ``labels`` is passed straight to ``torch.nn.functional.cross_entropy``.
        In this class it is always a one-hot tensor shaped like ``logits``,
        which cross_entropy treats as class probabilities. The batch-mean
        negative log-likelihood is exponentiated.
        """
        neg_log_likelihood = torch.nn.functional.cross_entropy(logits, labels, reduction='mean')
        return torch.exp(neg_log_likelihood).item()

    def _choose_option(self, logits, option_targets, option_token_ids, scores):
        """Return scores forcing the option whose target has the lowest perplexity under ``logits``.

        Ties go to the earliest option. The result is shaped like ``scores``,
        lives on ``self.device``, and is ``inf`` at the chosen token id in every
        row and 0 elsewhere.
        """
        ppl_list = [self.calculate_ppl(logits, target) for target in option_targets]
        next_tokens_id = option_token_ids[ppl_list.index(min(ppl_list))]
        output = torch.zeros_like(scores).to(self.device)
        output[:, next_tokens_id] = float('inf')
        return output

    def _tune_main_logits(self, scores, main_matrix, target_probs):
        """Run AdamW on a copy of ``scores`` so its mapped distribution approaches ``target_probs``.

        The copy lives on ``self.device_compute``. Each of the
        ``self.learning_epochs_nums`` steps minimises
        ``KLDivLoss(log(softmax(logits) @ main_matrix), target_probs)``, with a
        cosine-annealed learning rate (T_max=10, eta_min=lr/4).

        Returns:
            The tuned logits, detached from the autograd graph.
        """
        logits = scores.to(self.device_compute).detach().clone()
        logits.requires_grad_(True)
        main_matrix = main_matrix.to(self.device_compute)
        # NOTE(review): KLDivLoss() uses reduction='mean' (averages over every
        # element), not 'batchmean'; kept so tuned results stay unchanged.
        criterion = nn.KLDivLoss()
        optimizer = torch.optim.AdamW(params=[logits], lr=self.learning_rate, betas=(0.9, 0.999))
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=self.learning_rate / 4)

        # Generation usually runs with autograd disabled; turn it on only for
        # this loop. The context manager restores the caller's grad mode on
        # exit, including when a step raises.
        with torch.enable_grad():
            for _ in range(self.learning_epochs_nums):
                mapped_probs = torch.mm(nn.functional.softmax(logits, dim=-1).float(), main_matrix)
                loss = criterion(torch.log(mapped_probs), target_probs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()
        return logits.detach()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Choose one of the four option tokens as the next token.

        Args:
            input_ids: Tokens generated so far (not used).
            scores: Main-model next-token logits. They are indexed as
                ``scores[:, token_id]`` and fed to ``torch.mm``, so 2-D
                (batch, vocab).

        Returns:
            A tensor shaped like ``scores`` on ``self.device`` that is ``inf`` at
            the chosen option token and 0 elsewhere.
        """
        # Wait up to 1.5 s per assist model; any miss falls back to the main model alone.
        main_model_only_flag = False
        assist_model_generate_ids_logits_list = []
        for queue_instance in self.assist_model_score_queue_list:
            try:
                value = queue_instance.get(block=True, timeout=1.5)
            except queue.Empty:
                value = None
                main_model_only_flag = True
            assist_model_generate_ids_logits_list.append(value)

        # With no assist models there is nothing to ensemble, and torch.stack
        # below would raise on an empty list, so use the main model alone.
        if not assist_model_generate_ids_logits_list:
            main_model_only_flag = True
        if math.fabs(self.learning_rate) <= 1e-6:
            main_model_only_flag = True
        # argmax over the flattened scores; for batch size 1 this is the top token.
        # NOTE(review): for batch > 1 only row 0 can match — confirm batch is always 1.
        if torch.argmax(scores).item() == self.forced_eos_token_id:
            main_model_only_flag = True

        # Token ids for options A, B, C and D in the main model's vocabulary.
        token_ABCD_index_list = [1401, 1432, 1399, 1443]
        # One-hot targets, one per option, used as cross-entropy labels.
        option_targets = []
        for token_index in token_ABCD_index_list:
            target = torch.zeros_like(scores).to(self.device)
            target[:, token_index] = 1
            option_targets.append(target)

        if main_model_only_flag:
            return self._choose_option(scores, option_targets, token_ABCD_index_list, scores)

        with torch.no_grad():
            local_main_model_relative_representation_matrix = self.main_model_probability_transfer_matrix_list[0]
            main_model_generate_ids_probs = nn.functional.softmax(scores, dim=-1).float()
            main_model_relative_representation_probs = torch.mm(
                main_model_generate_ids_probs,
                local_main_model_relative_representation_matrix).to(self.device_compute)

            assist_model_relative_representation_probs_list = [
                torch.mm(nn.functional.softmax(assist_logits, dim=-1).float(), assist_matrix).to(self.device_compute)
                for assist_logits, assist_matrix in zip(assist_model_generate_ids_logits_list,
                                                        self.assist_model_probability_transfer_matrix_list)
            ]

            # Main model gets weight 1/2; the assist models share the other 1/2 equally.
            average_assist_model_relative_representation_probs = torch.mean(
                torch.stack(assist_model_relative_representation_probs_list), dim=0)
            average_probs = (main_model_relative_representation_probs +
                             average_assist_model_relative_representation_probs) / 2

        tuned_logits = self._tune_main_logits(scores, local_main_model_relative_representation_matrix,
                                              average_probs)
        return self._choose_option(tuned_logits.to(self.device), option_targets, token_ABCD_index_list, scores)

