import json
import logging
import math
import os
import pdb
import queue
from collections import Counter

import torch
from torch import nn
from torch.autograd import Variable
from torch.optim import lr_scheduler
from transformers import LogitsProcessor




class BasedOnProbabilityTransferLogits_Loacal_NQ_Reweight_Liner_Processor(LogitsProcessor):
    """Logits processor that ensembles the main model with auxiliary models.

    On every decoding step, the main model's ``scores`` and the logits received
    from each auxiliary-model queue are converted to probabilities and
    projected through the matching probability-transfer matrix into a shared
    "relative representation" space. Those projections are mixed with fixed
    weights chosen by the number of models (``_MODEL_WEIGHTS_BY_COUNT``).
    The main model's logits are then optimised with AdamW so that their own
    projection approaches the mixture under a KL-divergence loss. The argmax
    token id of the result is put on ``ensemble_model_output_ids_queue``.

    The ensemble step is skipped, and ``scores`` is returned as-is, when:

    - an auxiliary model does not answer within 5 seconds;
    - ``learning_rate`` is effectively zero;
    - the main model's top token is ``forced_eos_token_id``;
    - the generated tokens end with one of ``early_stop_string_list``. In
      this case the EOS score is also forced to ``inf``.

    NOTE(review): only row 0 of ``input_ids`` is checked for early stop, and
    the EOS check uses a flattened argmax. This presumably assumes batch
    size 1; confirm with the caller.
    """

    # Mixing weights keyed by the total number of models: the main model
    # first, then the auxiliary models in queue order. Any other count falls
    # back to uniform weights.
    _MODEL_WEIGHTS_BY_COUNT = {
        2: (0.5217, 0.4783),
        3: (0.3543, 0.3249, 0.3208),
        4: (0.2818, 0.2584, 0.2552, 0.2046),
        5: (0.2320, 0.2127, 0.2100, 0.1770, 0.1684),
        6: (0.1968, 0.1804, 0.1781, 0.1501, 0.1428, 0.1428),
        7: (0.1755, 0.1609, 0.1589, 0.1338, 0.1273, 0.1273, 0.1163),
    }

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        """Store the ensemble configuration.

        Args:
            learning_rate: AdamW learning rate for the logit optimisation. An
                absolute value <= 1e-6 disables ensembling.
            anchor_point_count: Stored only; not used by this class.
            learning_epochs_nums: Number of AdamW steps per decoding step.
            ensemble_model_output_ids_queue: Receives the chosen token id
                tensor on every call.
            assist_model_score_queue_list: One queue per auxiliary model,
                from which that model's logits are read.
            main_model_probability_transfer_matrix_list: Only element 0 is
                used, as the main model's transfer matrix.
            assist_model_probability_transfer_matrix_list: Transfer matrices
                paired positionally with the auxiliary queues.
            result_save_dir: Directory of the per-step log file.
            main_model_tokenizer: Used to tokenize the early-stop strings.
            assist_model_tokenizer: Stored only; not used by this class.
            device: Device the returned ensembled logits are moved to.
            device_compute: Device on which mixing and optimisation run.
            forced_eos_token_id: Token id treated as end-of-sequence.
            early_stop_string_list: Optional strings which, when they end the
                generated sequence, force EOS.
        """
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count
        self.assist_model_score_queue_list = assist_model_score_queue_list
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        self.device = device
        self.device_compute = device_compute
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return the (possibly ensembled) next-token scores for this step.

        Args:
            input_ids: Tokens generated so far; only row 0 is inspected.
            scores: Main-model next-token logits. May be modified in place by
                the early-stop check.

        Returns:
            The optimised logits moved to ``self.device``. When the ensemble
            is skipped, returns the ``scores`` object itself.
        """
        ensemble_process_file_path = os.path.join(
            self.result_save_dir,
            f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')
        logger = logging.getLogger(__name__)

        self._append_log(ensemble_process_file_path,
                         "main model top10:\n" + str(torch.topk(scores, 10)) + "\n")
        logger.debug("assist model queue count: %d", len(self.assist_model_score_queue_list))
        assist_model_generate_ids_logits_list, main_model_only_flag = self._collect_assist_logits(
            ensemble_process_file_path)

        if math.fabs(self.learning_rate) <= 1e-6:
            main_model_only_flag = True
        if torch.argmax(scores).item() == self.forced_eos_token_id:
            main_model_only_flag = True
        if self._ends_with_early_stop_string(input_ids):
            scores[:, self.forced_eos_token_id] = float('inf')
            main_model_only_flag = True

        if main_model_only_flag:
            self._append_log(ensemble_process_file_path, "no ensemble\n")
            self.ensemble_model_output_ids_queue.put(torch.argmax(scores, dim=-1))
            return scores

        main_matrix = self.main_model_probability_transfer_matrix_list[0]
        model_relative_representation_probs_list = [self._relative_representation_probs(scores, main_matrix)]
        for assist_logits, assist_matrix in zip(assist_model_generate_ids_logits_list,
                                                self.assist_model_probability_transfer_matrix_list):
            model_relative_representation_probs_list.append(
                self._relative_representation_probs(assist_logits, assist_matrix))

        model_weights = self._select_model_weights(len(model_relative_representation_probs_list))
        logger.debug("model weights: %s", model_weights)
        average_probs = torch.zeros_like(model_relative_representation_probs_list[0])
        for weight, probs in zip(model_weights, model_relative_representation_probs_list):
            average_probs += weight * probs

        ensembled_logits = self._optimize_logits(scores, main_matrix, average_probs)
        next_tokens_id = torch.argmax(ensembled_logits, dim=-1)
        self._append_log(ensemble_process_file_path,
                         "ensemble_result top10:\n" + str(torch.topk(ensembled_logits, 10)) + "\n")
        self.ensemble_model_output_ids_queue.put(next_tokens_id)
        return ensembled_logits.to(self.device)

    @staticmethod
    def _append_log(path, text):
        """Append ``text`` to the UTF-8 log file at ``path``."""
        with open(path, "a+", encoding="utf-8") as process_file:
            process_file.write(text)

    @classmethod
    def _select_model_weights(cls, model_count):
        """Return the mixing weights for ``model_count`` models (preset or uniform)."""
        preset = cls._MODEL_WEIGHTS_BY_COUNT.get(model_count)
        if preset is not None:
            return list(preset)
        return [1 / model_count] * model_count

    def _collect_assist_logits(self, log_path):
        """Read one logits tensor from every auxiliary queue.

        Returns:
            ``(logits_list, missing)``. ``logits_list`` holds ``None`` for
            each queue that timed out. ``missing`` is True if any queue timed
            out.
        """
        collected = []
        missing = False
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】!!!\n")
                self._append_log(log_path, f"aux model{index}【not received】\n")
                collected.append(None)
                missing = True
                continue
            collected.append(value)
            self._append_log(log_path, f"aux model{index} top10:\n" + str(torch.topk(value, 10)) + "\n")
        return collected, missing

    def _ends_with_early_stop_string(self, input_ids):
        """Return True if row 0 of ``input_ids`` ends with any early-stop string's tokens."""
        if self.early_stop_string_list is None:
            return False
        generated = input_ids.tolist()[0]
        for early_stop_string in self.early_stop_string_list:
            # The first token id is dropped. Presumably this is a leading
            # prefix/space token added by the tokenizer; TODO confirm.
            stop_ids = self.main_model_tokenizer(early_stop_string, return_tensors="pt",
                                                 add_special_tokens=False).input_ids.tolist()[0][1:]
            # Skip empty sequences: generated[-0:] would be the whole list.
            if stop_ids and generated[-len(stop_ids):] == stop_ids:
                return True
        return False

    def _relative_representation_probs(self, logits, transfer_matrix):
        """Project softmax(``logits``) through ``transfer_matrix`` onto ``device_compute``, without autograd."""
        with torch.no_grad():
            probs = nn.functional.softmax(logits, dim=-1).float()
            return torch.mm(probs, transfer_matrix).to(self.device_compute)

    def _optimize_logits(self, scores, main_matrix, target_probs):
        """Optimise a copy of ``scores`` so its projection approaches ``target_probs``.

        Runs ``learning_epochs_nums`` AdamW steps with a KL-divergence loss.
        The learning rate follows a cosine schedule with ``T_max=10``
        (independent of the step count). Autograd is enabled only inside
        ``torch.enable_grad()``, so the caller's global grad mode is restored
        afterwards even if an exception is raised.

        Returns:
            The optimised logits on ``device_compute``, detached.
        """
        main_matrix = main_matrix.to(self.device_compute)
        with torch.enable_grad():
            logits = scores.detach().to(self.device_compute).clone().requires_grad_(True)
            criterion = nn.KLDivLoss(reduction="mean")
            optimizer = torch.optim.AdamW(params=[logits], lr=self.learning_rate, betas=(0.9, 0.999))
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=self.learning_rate / 4)
            for _ in range(self.learning_epochs_nums):
                probs = nn.functional.softmax(logits, dim=-1).float()
                relative_probs = torch.mm(probs, main_matrix)
                # Clamp before the log. A projection that is exactly 0 (for
                # example an all-zero matrix column, or softmax underflow)
                # would give log(0) = -inf, whose NaN gradients corrupt
                # every logit.
                log_relative_probs = torch.log(relative_probs.clamp_min(torch.finfo(relative_probs.dtype).tiny))
                loss = criterion(log_relative_probs, target_probs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()
        return logits.detach()


class BasedOnProbabilityTransferLogits_Loacal_TriviaQA_Reweight_Liner_Processor(LogitsProcessor):
    """Logits processor that ensembles the main model with auxiliary models.

    On every decoding step, the main model's ``scores`` and the logits received
    from each auxiliary-model queue are converted to probabilities and
    projected through the matching probability-transfer matrix into a shared
    "relative representation" space. The auxiliary projections are weighted
    by ``assist_model_weights``, and the main projection gets the remaining
    weight ``1 - sum(assist_model_weights)``. The main model's logits are then
    optimised with AdamW so that their projection approaches the mixture
    under a KL-divergence loss. The argmax token id of the result is put on
    ``ensemble_model_output_ids_queue``.

    The ensemble step is skipped, and ``scores`` is returned as-is, when:

    - an auxiliary model does not answer within 5 seconds;
    - ``learning_rate`` is effectively zero;
    - the main model's top token is ``forced_eos_token_id``;
    - the generated tokens end with one of ``early_stop_string_list``. In
      this case the EOS score is also forced to ``inf``.

    NOTE(review): only row 0 of ``input_ids`` is checked for early stop, and
    the EOS check uses a flattened argmax. This presumably assumes batch
    size 1; confirm with the caller.
    """

    # Default weights for exactly three auxiliary models, in queue order.
    _DEFAULT_ASSIST_MODEL_WEIGHTS = (0.26, 0.2384, 0.2332)

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None,
                 assist_model_weights=None):
        """Store the ensemble configuration.

        Args:
            learning_rate: AdamW learning rate for the logit optimisation. An
                absolute value <= 1e-6 disables ensembling.
            anchor_point_count: Stored only; not used by this class.
            learning_epochs_nums: Number of AdamW steps per decoding step.
            ensemble_model_output_ids_queue: Receives the chosen token id
                tensor on every call.
            assist_model_score_queue_list: One queue per auxiliary model,
                from which that model's logits are read.
            main_model_probability_transfer_matrix_list: Only element 0 is
                used, as the main model's transfer matrix.
            assist_model_probability_transfer_matrix_list: Transfer matrices
                paired positionally with the auxiliary queues.
            result_save_dir: Directory of the per-step log file.
            main_model_tokenizer: Used to tokenize the early-stop strings.
            assist_model_tokenizer: Stored only; not used by this class.
            device: Device the returned ensembled logits are moved to.
            device_compute: Device on which mixing and optimisation run.
            forced_eos_token_id: Token id treated as end-of-sequence.
            early_stop_string_list: Optional strings which, when they end the
                generated sequence, force EOS.
            assist_model_weights: Optional per-auxiliary-model weights.
                Defaults to ``(0.26, 0.2384, 0.2332)``, intended for three
                auxiliary models.
        """
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count
        self.assist_model_score_queue_list = assist_model_score_queue_list
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        self.device = device
        self.device_compute = device_compute
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list
        if assist_model_weights is None:
            assist_model_weights = self._DEFAULT_ASSIST_MODEL_WEIGHTS
        self.assist_model_weights = list(assist_model_weights)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return the (possibly ensembled) next-token scores for this step.

        Args:
            input_ids: Tokens generated so far; only row 0 is inspected.
            scores: Main-model next-token logits. May be modified in place by
                the early-stop check.

        Returns:
            The optimised logits moved to ``self.device``. When the ensemble
            is skipped, returns the ``scores`` object itself.
        """
        ensemble_process_file_path = os.path.join(
            self.result_save_dir,
            f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')
        logger = logging.getLogger(__name__)

        self._append_log(ensemble_process_file_path,
                         "main model top10:\n" + str(torch.topk(scores, 10)) + "\n")
        assist_model_generate_ids_logits_list, main_model_only_flag = self._collect_assist_logits(
            ensemble_process_file_path)

        if math.fabs(self.learning_rate) <= 1e-6:
            main_model_only_flag = True
        if torch.argmax(scores).item() == self.forced_eos_token_id:
            main_model_only_flag = True
        if self._ends_with_early_stop_string(input_ids):
            scores[:, self.forced_eos_token_id] = float('inf')
            main_model_only_flag = True

        if main_model_only_flag:
            self._append_log(ensemble_process_file_path, "no ensemble\n")
            self.ensemble_model_output_ids_queue.put(torch.argmax(scores, dim=-1))
            return scores

        main_matrix = self.main_model_probability_transfer_matrix_list[0]
        main_model_relative_representation_probs = self._relative_representation_probs(scores, main_matrix)
        assist_model_relative_representation_probs_list = [
            self._relative_representation_probs(assist_logits, assist_matrix)
            for assist_logits, assist_matrix in zip(assist_model_generate_ids_logits_list,
                                                    self.assist_model_probability_transfer_matrix_list)
        ]

        assist_model_weights = self.assist_model_weights
        if len(assist_model_weights) != len(assist_model_relative_representation_probs_list):
            # zip() below pairs only the shorter of the two, so the mixture
            # may not sum to 1. Warn instead of failing, keeping the
            # best-effort behaviour.
            logger.warning("got %d assist model weights for %d assist models; unmatched entries are ignored",
                           len(assist_model_weights), len(assist_model_relative_representation_probs_list))
        main_model_weight = 1 - sum(assist_model_weights)
        logger.debug("main model weight: %s, assist model weights: %s", main_model_weight, assist_model_weights)

        weighted_assist_probs_sum = torch.zeros_like(main_model_relative_representation_probs)
        for weight, probs in zip(assist_model_weights, assist_model_relative_representation_probs_list):
            weighted_assist_probs_sum += weight * probs
        average_probs = main_model_weight * main_model_relative_representation_probs + weighted_assist_probs_sum

        ensembled_logits = self._optimize_logits(scores, main_matrix, average_probs)
        next_tokens_id = torch.argmax(ensembled_logits, dim=-1)
        self._append_log(ensemble_process_file_path,
                         "ensemble_result top10:\n" + str(torch.topk(ensembled_logits, 10)) + "\n")
        self.ensemble_model_output_ids_queue.put(next_tokens_id)
        return ensembled_logits.to(self.device)

    @staticmethod
    def _append_log(path, text):
        """Append ``text`` to the UTF-8 log file at ``path``."""
        with open(path, "a+", encoding="utf-8") as process_file:
            process_file.write(text)

    def _collect_assist_logits(self, log_path):
        """Read one logits tensor from every auxiliary queue.

        Returns:
            ``(logits_list, missing)``. ``logits_list`` holds ``None`` for
            each queue that timed out. ``missing`` is True if any queue timed
            out.
        """
        collected = []
        missing = False
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                self._append_log(log_path, f"aux model{index}【not received】\n")
                collected.append(None)
                missing = True
                continue
            collected.append(value)
            self._append_log(log_path, f"aux model{index} top10:\n" + str(torch.topk(value, 10)) + "\n")
        return collected, missing

    def _ends_with_early_stop_string(self, input_ids):
        """Return True if row 0 of ``input_ids`` ends with any early-stop string's tokens."""
        if self.early_stop_string_list is None:
            return False
        generated = input_ids.tolist()[0]
        for early_stop_string in self.early_stop_string_list:
            # The first token id is dropped. Presumably this is a leading
            # prefix/space token added by the tokenizer; TODO confirm.
            stop_ids = self.main_model_tokenizer(early_stop_string, return_tensors="pt",
                                                 add_special_tokens=False).input_ids.tolist()[0][1:]
            # Skip empty sequences: generated[-0:] would be the whole list.
            if stop_ids and generated[-len(stop_ids):] == stop_ids:
                return True
        return False

    def _relative_representation_probs(self, logits, transfer_matrix):
        """Project softmax(``logits``) through ``transfer_matrix`` onto ``device_compute``, without autograd."""
        with torch.no_grad():
            probs = nn.functional.softmax(logits, dim=-1).float()
            return torch.mm(probs, transfer_matrix).to(self.device_compute)

    def _optimize_logits(self, scores, main_matrix, target_probs):
        """Optimise a copy of ``scores`` so its projection approaches ``target_probs``.

        Runs ``learning_epochs_nums`` AdamW steps with a KL-divergence loss.
        The learning rate follows a cosine schedule with ``T_max=10``
        (independent of the step count). Autograd is enabled only inside
        ``torch.enable_grad()``, so the caller's global grad mode is restored
        afterwards even if an exception is raised.

        Returns:
            The optimised logits on ``device_compute``, detached.
        """
        main_matrix = main_matrix.to(self.device_compute)
        with torch.enable_grad():
            logits = scores.detach().to(self.device_compute).clone().requires_grad_(True)
            criterion = nn.KLDivLoss(reduction="mean")
            optimizer = torch.optim.AdamW(params=[logits], lr=self.learning_rate, betas=(0.9, 0.999))
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=self.learning_rate / 4)
            for _ in range(self.learning_epochs_nums):
                probs = nn.functional.softmax(logits, dim=-1).float()
                relative_probs = torch.mm(probs, main_matrix)
                # Clamp before the log. A projection that is exactly 0 would
                # give log(0) = -inf, whose NaN gradients corrupt every logit.
                log_relative_probs = torch.log(relative_probs.clamp_min(torch.finfo(relative_probs.dtype).tiny))
                loss = criterion(log_relative_probs, target_probs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()
        return logits.detach()


class BasedOnProbabilityTransferLogits_Loacal_Reweight_Processor(LogitsProcessor):
    """Logits processor that ensembles the main model with auxiliary models.

    On every decoding step, the main model's ``scores`` and the logits received
    from each auxiliary-model queue are converted to probabilities and
    projected through the matching probability-transfer matrix into a shared
    "relative representation" space. The projections are mixed with
    ``model_weights``: the main model first, then the auxiliary models in
    queue order. The main model's logits are then optimised with AdamW so
    that their projection approaches the mixture under a KL-divergence loss.
    The argmax token id of the result is put on
    ``ensemble_model_output_ids_queue``.

    The ensemble step is skipped, and ``scores`` is returned as-is, when:

    - an auxiliary model does not answer within 5 seconds;
    - ``learning_rate`` is effectively zero;
    - the main model's top token is ``forced_eos_token_id``;
    - the generated tokens end with one of ``early_stop_string_list``. In
      this case the EOS score is also forced to ``inf``.

    NOTE(review): only row 0 of ``input_ids`` is checked for early stop, and
    the EOS check uses a flattened argmax. This presumably assumes batch
    size 1; confirm with the caller.
    """

    # Default: uniform weights for the main model plus three auxiliary models.
    _DEFAULT_MODEL_WEIGHTS = (0.25, 0.25, 0.25, 0.25)

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None,
                 model_weights=None):
        """Store the ensemble configuration.

        Args:
            learning_rate: AdamW learning rate for the logit optimisation. An
                absolute value <= 1e-6 disables ensembling.
            anchor_point_count: Stored only; not used by this class.
            learning_epochs_nums: Number of AdamW steps per decoding step.
            ensemble_model_output_ids_queue: Receives the chosen token id
                tensor on every call.
            assist_model_score_queue_list: One queue per auxiliary model,
                from which that model's logits are read.
            main_model_probability_transfer_matrix_list: Only element 0 is
                used, as the main model's transfer matrix.
            assist_model_probability_transfer_matrix_list: Transfer matrices
                paired positionally with the auxiliary queues.
            result_save_dir: Directory of the per-step log file.
            main_model_tokenizer: Used to tokenize the early-stop strings.
            assist_model_tokenizer: Stored only; not used by this class.
            device: Device the returned ensembled logits are moved to.
            device_compute: Device on which mixing and optimisation run.
            forced_eos_token_id: Token id treated as end-of-sequence.
            early_stop_string_list: Optional strings which, when they end the
                generated sequence, force EOS.
            model_weights: Optional mixing weights, one per model, with the
                main model first. Defaults to four equal weights of 0.25.
        """
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count
        self.assist_model_score_queue_list = assist_model_score_queue_list
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer = assist_model_tokenizer
        self.device = device
        self.device_compute = device_compute
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list
        if model_weights is None:
            model_weights = self._DEFAULT_MODEL_WEIGHTS
        self.model_weights = list(model_weights)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return the (possibly ensembled) next-token scores for this step.

        Args:
            input_ids: Tokens generated so far; only row 0 is inspected.
            scores: Main-model next-token logits. May be modified in place by
                the early-stop check.

        Returns:
            The optimised logits moved to ``self.device``. When the ensemble
            is skipped, returns the ``scores`` object itself.

        Raises:
            AssertionError: If the number of weights differs from the number
                of models (main + auxiliary) on an ensemble step.
        """
        ensemble_process_file_path = os.path.join(
            self.result_save_dir,
            f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')
        logger = logging.getLogger(__name__)

        self._append_log(ensemble_process_file_path,
                         "main model top10:\n" + str(torch.topk(scores, 10)) + "\n")
        assist_model_generate_ids_logits_list, main_model_only_flag = self._collect_assist_logits(
            ensemble_process_file_path)

        if math.fabs(self.learning_rate) <= 1e-6:
            main_model_only_flag = True
        if torch.argmax(scores).item() == self.forced_eos_token_id:
            main_model_only_flag = True
        if self._ends_with_early_stop_string(input_ids):
            scores[:, self.forced_eos_token_id] = float('inf')
            main_model_only_flag = True

        if main_model_only_flag:
            self._append_log(ensemble_process_file_path, "no ensemble\n")
            self.ensemble_model_output_ids_queue.put(torch.argmax(scores, dim=-1))
            return scores

        main_matrix = self.main_model_probability_transfer_matrix_list[0]
        model_relative_representation_probs_list = [self._relative_representation_probs(scores, main_matrix)]
        for assist_logits, assist_matrix in zip(assist_model_generate_ids_logits_list,
                                                self.assist_model_probability_transfer_matrix_list):
            model_relative_representation_probs_list.append(
                self._relative_representation_probs(assist_logits, assist_matrix))

        model_weights = self.model_weights
        logger.debug("model weights: %s", model_weights)
        if len(model_weights) != len(model_relative_representation_probs_list):
            # "The number of weights must equal the number of logits".
            # Raised explicitly rather than via `assert`, so the check
            # survives `python -O`.
            raise AssertionError("权重数和logits必须相同")
        average_probs = torch.zeros_like(model_relative_representation_probs_list[0])
        for weight, probs in zip(model_weights, model_relative_representation_probs_list):
            average_probs += weight * probs

        ensembled_logits = self._optimize_logits(scores, main_matrix, average_probs)
        next_tokens_id = torch.argmax(ensembled_logits, dim=-1)
        self._append_log(ensemble_process_file_path,
                         "ensemble_result top10:\n" + str(torch.topk(ensembled_logits, 10)) + "\n")
        self.ensemble_model_output_ids_queue.put(next_tokens_id)
        return ensembled_logits.to(self.device)

    @staticmethod
    def _append_log(path, text):
        """Append ``text`` to the UTF-8 log file at ``path``."""
        with open(path, "a+", encoding="utf-8") as process_file:
            process_file.write(text)

    def _collect_assist_logits(self, log_path):
        """Read one logits tensor from every auxiliary queue.

        Returns:
            ``(logits_list, missing)``. ``logits_list`` holds ``None`` for
            each queue that timed out. ``missing`` is True if any queue timed
            out.
        """
        collected = []
        missing = False
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                self._append_log(log_path, f"aux model{index}【not received】\n")
                collected.append(None)
                missing = True
                continue
            collected.append(value)
            self._append_log(log_path, f"aux model{index} top10:\n" + str(torch.topk(value, 10)) + "\n")
        return collected, missing

    def _ends_with_early_stop_string(self, input_ids):
        """Return True if row 0 of ``input_ids`` ends with any early-stop string's tokens."""
        if self.early_stop_string_list is None:
            return False
        generated = input_ids.tolist()[0]
        for early_stop_string in self.early_stop_string_list:
            # The first token id is dropped. Presumably this is a leading
            # prefix/space token added by the tokenizer; TODO confirm.
            stop_ids = self.main_model_tokenizer(early_stop_string, return_tensors="pt",
                                                 add_special_tokens=False).input_ids.tolist()[0][1:]
            # Skip empty sequences: generated[-0:] would be the whole list.
            if stop_ids and generated[-len(stop_ids):] == stop_ids:
                return True
        return False

    def _relative_representation_probs(self, logits, transfer_matrix):
        """Project softmax(``logits``) through ``transfer_matrix`` onto ``device_compute``, without autograd."""
        with torch.no_grad():
            probs = nn.functional.softmax(logits, dim=-1).float()
            return torch.mm(probs, transfer_matrix).to(self.device_compute)

    def _optimize_logits(self, scores, main_matrix, target_probs):
        """Optimise a copy of ``scores`` so its projection approaches ``target_probs``.

        Runs ``learning_epochs_nums`` AdamW steps with a KL-divergence loss.
        The learning rate follows a cosine schedule with ``T_max=10``
        (independent of the step count). Autograd is enabled only inside
        ``torch.enable_grad()``, so the caller's global grad mode is restored
        afterwards even if an exception is raised.

        Returns:
            The optimised logits on ``device_compute``, detached.
        """
        main_matrix = main_matrix.to(self.device_compute)
        with torch.enable_grad():
            logits = scores.detach().to(self.device_compute).clone().requires_grad_(True)
            criterion = nn.KLDivLoss(reduction="mean")
            optimizer = torch.optim.AdamW(params=[logits], lr=self.learning_rate, betas=(0.9, 0.999))
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=self.learning_rate / 4)
            for _ in range(self.learning_epochs_nums):
                probs = nn.functional.softmax(logits, dim=-1).float()
                relative_probs = torch.mm(probs, main_matrix)
                # Clamp before the log. A projection that is exactly 0 would
                # give log(0) = -inf, whose NaN gradients corrupt every logit.
                log_relative_probs = torch.log(relative_probs.clamp_min(torch.finfo(relative_probs.dtype).tiny))
                loss = criterion(log_relative_probs, target_probs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()
        return logits.detach()


class BasedOnProbabilityTransferLogits_Loacal_FP32_Processor(LogitsProcessor):
    """Ensemble logits processor that pulls the main model towards the aux models in a shared space.

    On every generation step the main model's probabilities and each aux model's probabilities are
    projected with their probability-transfer matrices, the projections are averaged with equal
    weights, and a float32 copy of the main model's logits is optimised with AdamW so that its own
    projection matches that average under ``nn.KLDivLoss``. The refined logits are returned and
    their argmax is pushed to ``ensemble_model_output_ids_queue``. One JSON line describing the
    step is appended to a log file under ``result_save_dir``.

    The ensemble is skipped (``scores`` returned unchanged, apart from an early-stop EOS override)
    when an aux model's logits do not arrive within 5 s, ``|learning_rate| <= 1e-6``, the main
    model's top token is ``forced_eos_token_id``, or the generated ids end with an early-stop string.

    NOTE(review): results are read with ``.tolist()[0]`` and ``torch.argmax(scores)`` has no dim,
    so the code presumably assumes batch size 1 -- TODO confirm with callers.
    """

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count  # stored; not read by this class
        self.assist_model_score_queue_list = assist_model_score_queue_list
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        # Indexed per aux model in __call__ (``self.assist_model_tokenizer_list[index]``).
        self.assist_model_tokenizer_list = assist_model_tokenizer
        self.device = device  # device of the returned logits
        self.device_compute = device_compute  # device on which the projection/optimisation runs
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    def _receive_assist_logits(self):
        """Wait (up to 5 s per queue) for one logits tensor from every aux-model queue.

        Returns:
            ``(logits_list, missing)``: one entry per queue (``None`` for a queue that timed out),
            and ``True`` in ``missing`` if at least one queue timed out.
        """
        logits_list = []
        missing = False
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                logits_list.append(None)
                missing = True
                continue
            logits_list.append(value)
        return logits_list, missing

    def _apply_early_stop(self, input_ids, scores):
        """Force EOS when the generated ids end with one of ``early_stop_string_list``.

        On a match, ``scores[:, forced_eos_token_id]`` is set to +inf in place.

        Returns:
            True if at least one early-stop string matched.
        """
        if self.early_stop_string_list is None:
            return False
        generated_ids = input_ids.tolist()[0]
        matched = False
        for early_stop_string in self.early_stop_string_list:
            # [1:] drops the first id of the fresh encoding -- presumably a prefix/space token the
            # tokenizer inserts at the start of a string; TODO confirm for the tokenizer in use.
            early_stop_token = self.main_model_tokenizer(early_stop_string, return_tensors="pt",
                                                         add_special_tokens=False).input_ids.tolist()[0][1:]
            if not early_stop_token:
                # ``ids[-0:]`` would be the whole history, never a meaningful suffix match.
                continue
            if generated_ids[-len(early_stop_token):] == early_stop_token:
                scores[:, self.forced_eos_token_id] = float('inf')
                matched = True
        return matched

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ensembled logits for the next token and push its argmax to the output queue.

        Args:
            input_ids: ids generated so far (only row 0 is inspected for early stopping).
            scores: the main model's next-token logits; may be modified in place by early stop.

        Returns:
            The optimised float32 logits on ``self.device``, or ``scores`` when the ensemble is
            skipped (see the class docstring).
        """
        ensemble_process_file_path = os.path.join(self.result_save_dir,
                                                  f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')

        assist_model_generate_ids_logits_list, main_model_only_flag = self._receive_assist_logits()

        if math.fabs(self.learning_rate) <= 1e-6:
            main_model_only_flag = True
        if torch.argmax(scores).item() == self.forced_eos_token_id:
            main_model_only_flag = True
        # Must always run: it may write the EOS override into ``scores``.
        if self._apply_early_stop(input_ids, scores):
            main_model_only_flag = True

        if main_model_only_flag:
            next_tokens_id = torch.argmax(scores, dim=-1)
            self.ensemble_model_output_ids_queue.put(next_tokens_id)
            return scores

        json_object = {}

        # Detached float32 copy of the main logits (no gradient needed while building the target).
        main_model_generate_ids_logits = scores.detach().to(torch.float32)
        with torch.no_grad():
            main_model_generate_ids_probs = nn.functional.softmax(main_model_generate_ids_logits, dim=-1).float()

            _, main_model_generate_ids_probs_indices = torch.topk(main_model_generate_ids_probs, k=10)
            json_object['origin_main_top_tokens'] = self.main_model_tokenizer.convert_ids_to_tokens(
                main_model_generate_ids_probs_indices.tolist()[0])

            main_model_relative_representation_probs = torch.mm(
                main_model_generate_ids_probs,
                self.main_model_probability_transfer_matrix_list[0]).to(self.device_compute)

            main_model_relative_values, main_model_relative_indices = torch.topk(
                main_model_relative_representation_probs, k=10)
            json_object['main_rel_values'] = main_model_relative_values.tolist()[0]
            json_object['main_rel_indices'] = main_model_relative_indices.tolist()[0]

            model_relative_representation_probs_list = [main_model_relative_representation_probs]

            for index, (assist_model_generate_ids_logits, assist_model_probability_transfer_matrix) in enumerate(
                    zip(assist_model_generate_ids_logits_list,
                        self.assist_model_probability_transfer_matrix_list)):
                assist_model_generate_ids_probs = nn.functional.softmax(
                    assist_model_generate_ids_logits.to(torch.float32), dim=-1).float()

                _, indices = torch.topk(assist_model_generate_ids_probs, k=10)
                json_object[f'origin_aux_{index}_top_tokens'] = self.assist_model_tokenizer_list[
                    index].convert_ids_to_tokens(indices.tolist()[0])

                assist_model_relative_representation_probs = torch.mm(
                    assist_model_generate_ids_probs,
                    assist_model_probability_transfer_matrix).to(self.device_compute)

                assist_model_relative_values, assist_model_relative_indices = torch.topk(
                    assist_model_relative_representation_probs, k=10)
                json_object[f'aux_rel_values_{index}'] = assist_model_relative_values.tolist()[0]
                json_object[f'aux_rel_indices_{index}'] = assist_model_relative_indices.tolist()[0]

                model_relative_representation_probs_list.append(assist_model_relative_representation_probs)

        # Equal weight for the main model and every aux model.
        model_weights = [1 / len(model_relative_representation_probs_list)] * len(
            model_relative_representation_probs_list)
        json_object['model_weights'] = model_weights

        assert len(model_weights) == len(model_relative_representation_probs_list), "权重数和logits必须相同"
        average_probs = torch.zeros_like(main_model_relative_representation_probs)
        for weight, probs in zip(model_weights, model_relative_representation_probs_list):
            average_probs += weight * probs

        average_relative_probs_values, average_relative_probs_indices = torch.topk(average_probs, k=10)
        json_object['average_rel_probs_values'] = average_relative_probs_values.tolist()[0]
        json_object['average_rel_probs_indices'] = average_relative_probs_indices.tolist()[0]

        # Fresh leaf tensor on the compute device that AdamW updates directly.
        main_model_generate_ids_logits = main_model_generate_ids_logits.to(self.device_compute).detach().clone().to(
            torch.float32)
        main_model_generate_ids_logits.requires_grad_(True)
        local_main_model_relative_representation_matrix = self.main_model_probability_transfer_matrix_list[0].to(
            self.device_compute)
        local_learning_rate = self.learning_rate
        # NOTE(review): the default reduction='mean' averages over batch *and* vocabulary; PyTorch
        # recommends 'batchmean' for the mathematical KL divergence. Kept to preserve the loss scale.
        criterion = nn.KLDivLoss()

        optimizer = torch.optim.AdamW(params=[main_model_generate_ids_logits],
                                      lr=local_learning_rate,
                                      betas=(0.9, 0.999))
        # NOTE(review): T_max is fixed at 10, independent of learning_epochs_nums.
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=local_learning_rate / 4)

        # The caller (e.g. ``generate``) typically runs under no_grad; enable autograd only for the
        # optimisation and restore the caller's mode on exit -- also when an exception escapes.
        with torch.enable_grad():
            for i in range(self.learning_epochs_nums):
                main_model_generate_ids_probs = nn.functional.softmax(main_model_generate_ids_logits, dim=-1).float()
                main_model_relative_representation_probs = torch.mm(main_model_generate_ids_probs,
                                                                    local_main_model_relative_representation_matrix)

                # KLDivLoss takes log-probabilities as input and probabilities as target.
                log_main_probs = torch.log(main_model_relative_representation_probs)
                loss = criterion(log_main_probs, average_probs)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

                # Logging only: keep it out of the autograd graph.
                with torch.no_grad():
                    step_probs_values, step_indices = torch.topk(
                        torch.nn.functional.softmax(main_model_generate_ids_logits, dim=-1), k=10)
                json_object[f'main_model_generate_ids_logits_probs_values_{i}'] = step_probs_values.tolist()[0]
                json_object[f'main_model_generate_ids_logits_indices_{i}'] = \
                    self.main_model_tokenizer.convert_ids_to_tokens(step_indices.tolist()[0])

        with torch.no_grad():
            next_tokens_id = torch.argmax(main_model_generate_ids_logits, dim=-1)
            final_probs_values, final_indices = torch.topk(
                torch.nn.functional.softmax(main_model_generate_ids_logits, dim=-1), k=10)
        json_object['main_model_generate_ids_logits_probs_values_final'] = final_probs_values.tolist()[0]
        json_object['main_model_generate_ids_logits_indices_final'] = \
            self.main_model_tokenizer.convert_ids_to_tokens(final_indices.tolist()[0])

        self.ensemble_model_output_ids_queue.put(next_tokens_id)
        with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
            process_file.write(json.dumps(json_object, ensure_ascii=False) + '\n')

        return main_model_generate_ids_logits.to(self.device).detach()


class BasedOnProbabilityTransferLogits_MinED_Processor(LogitsProcessor):
    """Ensemble processor that averages probabilities in the main model's vocabulary.

    The main model's softmax probabilities are used directly. Each aux model's probabilities are
    mapped through its probability-transfer matrix (the mapped indices are decoded with the main
    tokenizer, so presumably the matrix maps aux vocab -> main vocab; TODO confirm). The
    equal-weight average, multiplied by ``probs_to_logits_scale``, is returned as the new logits
    and its argmax is pushed to ``ensemble_model_output_ids_queue``. One JSON line per ensembled
    step is appended to a log file under ``result_save_dir``.

    The ensemble is skipped (``scores`` returned, apart from an early-stop EOS override) when an
    aux model's logits do not arrive within 5 s, ``|learning_rate| <= 1e-6``, the main model's top
    token is ``forced_eos_token_id``, or the generated ids end with an early-stop string.

    NOTE(review): ``.tolist()[0]`` reads only row 0, so batch size 1 is presumably assumed.
    """

    # Factor that turns the averaged probabilities back into logit-like scores. Any positive value
    # leaves the greedy argmax unchanged; it only controls how peaked a downstream softmax is.
    # Defined on the class so it can be tuned per instance/subclass without a constructor change.
    probs_to_logits_scale = 15

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        # learning_rate is only used as an on/off switch and in the log file name here;
        # anchor_point_count, learning_epochs_nums and the main transfer matrices are stored
        # for interface parity but not read by this class.
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count
        self.assist_model_score_queue_list = assist_model_score_queue_list
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        # Indexed per aux model in __call__ (``self.assist_model_tokenizer_list[index]``).
        self.assist_model_tokenizer_list = assist_model_tokenizer
        self.device = device  # device of the returned logits
        self.device_compute = device_compute  # device on which averaging happens
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    def _receive_assist_logits(self):
        """Wait (up to 5 s per queue) for one logits tensor from every aux-model queue.

        Returns:
            ``(logits_list, missing)``: one entry per queue (``None`` for a queue that timed out),
            and ``True`` in ``missing`` if at least one queue timed out.
        """
        logits_list = []
        missing = False
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                logits_list.append(None)
                missing = True
                continue
            logits_list.append(value)
        return logits_list, missing

    def _apply_early_stop(self, input_ids, scores):
        """Force EOS when the generated ids end with one of ``early_stop_string_list``.

        On a match, ``scores[:, forced_eos_token_id]`` is set to +inf in place.

        Returns:
            True if at least one early-stop string matched.
        """
        if self.early_stop_string_list is None:
            return False
        generated_ids = input_ids.tolist()[0]
        matched = False
        for early_stop_string in self.early_stop_string_list:
            # [1:] drops the first id of the fresh encoding -- presumably a prefix/space token the
            # tokenizer inserts at the start of a string; TODO confirm for the tokenizer in use.
            early_stop_token = self.main_model_tokenizer(early_stop_string, return_tensors="pt",
                                                         add_special_tokens=False).input_ids.tolist()[0][1:]
            if not early_stop_token:
                # ``ids[-0:]`` would be the whole history, never a meaningful suffix match.
                continue
            if generated_ids[-len(early_stop_token):] == early_stop_token:
                scores[:, self.forced_eos_token_id] = float('inf')
                matched = True
        return matched

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ensembled logits for the next token and push its argmax to the output queue.

        Args:
            input_ids: ids generated so far (only row 0 is inspected for early stopping).
            scores: the main model's next-token logits; may be modified in place by early stop.

        Returns:
            ``average_probs * probs_to_logits_scale`` on ``self.device``, or ``scores`` when the
            ensemble is skipped.
        """
        ensemble_process_file_path = os.path.join(self.result_save_dir,
                                                  f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')
        json_object = {}

        _, indices = torch.topk(scores, k=10)
        json_object['origin_main_top_tokens'] = self.main_model_tokenizer.convert_ids_to_tokens(indices.tolist()[0])

        assist_model_generate_ids_logits_list, main_model_only_flag = self._receive_assist_logits()

        if math.fabs(self.learning_rate) <= 1e-6:
            main_model_only_flag = True
        if torch.argmax(scores).item() == self.forced_eos_token_id:
            main_model_only_flag = True
        # Must always run: it may write the EOS override into ``scores``.
        if self._apply_early_stop(input_ids, scores):
            main_model_only_flag = True

        if main_model_only_flag:
            next_tokens_id = torch.argmax(scores, dim=-1)
            self.ensemble_model_output_ids_queue.put(next_tokens_id)
            return scores

        with torch.no_grad():
            # Main-model probabilities stay in the main vocabulary (no transfer matrix applied).
            main_model_generate_ids_probs = nn.functional.softmax(
                scores.detach().to(torch.float32), dim=-1).float().to(self.device_compute)
            model_relative_representation_probs_list = [main_model_generate_ids_probs]

            for index, (assist_model_generate_ids_logits, assist_model_probability_transfer_matrix) in enumerate(
                    zip(assist_model_generate_ids_logits_list,
                        self.assist_model_probability_transfer_matrix_list)):
                assist_model_generate_ids_probs = nn.functional.softmax(
                    assist_model_generate_ids_logits.to(torch.float32), dim=-1)

                _, indices = torch.topk(assist_model_generate_ids_probs, k=10)
                json_object[f'origin_aux_{index}_top_tokens'] = self.assist_model_tokenizer_list[
                    index].convert_ids_to_tokens(indices.tolist()[0])

                assist_model_relative_representation_probs = torch.mm(
                    assist_model_generate_ids_probs,
                    assist_model_probability_transfer_matrix).to(self.device_compute)

                values, indices = torch.topk(assist_model_relative_representation_probs, k=10)
                json_object[f'mapping_aux_{index}_top_tokens_probs'] = values.tolist()[0]
                json_object[f'mapping_aux_{index}_top_tokens'] = self.main_model_tokenizer.convert_ids_to_tokens(
                    indices.tolist()[0])

                model_relative_representation_probs_list.append(assist_model_relative_representation_probs)

        # Equal weight for the main model and every aux model.
        model_weights = [1 / len(model_relative_representation_probs_list)] * len(
            model_relative_representation_probs_list)

        assert len(model_weights) == len(model_relative_representation_probs_list), "权重数和logits必须相同"
        average_probs = torch.zeros_like(main_model_generate_ids_probs)
        for weight, probs in zip(model_weights, model_relative_representation_probs_list):
            average_probs += weight * probs

        main_model_generate_ids_logits = average_probs * self.probs_to_logits_scale

        values, indices = torch.topk(average_probs, k=10)
        json_object['final_average_top_tokens_probs'] = values.tolist()[0]
        json_object['final_average_top_tokens'] = self.main_model_tokenizer.convert_ids_to_tokens(
            indices.tolist()[0])

        next_tokens_id = torch.argmax(main_model_generate_ids_logits, dim=-1)
        self.ensemble_model_output_ids_queue.put(next_tokens_id)
        with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
            process_file.write(json.dumps(json_object, ensure_ascii=False) + '\n')

        return main_model_generate_ids_logits.to(self.device).detach()


class BasedOnProbabilityTransferLogits_Loacal_FP32_digit_vote_Processor(LogitsProcessor):
    """Relative-space ensemble processor with a majority vote for digit tokens.

    Every step:

    * The main model's top-10 and each received aux model's top-10 are appended to a text log
      under ``result_save_dir``.
    * Digit vote: when the main model's top token decodes to a digit string, the decoded top
      tokens of the main model and of every aux model whose logits arrived are majority-voted
      (``Counter.most_common``; ties go to the earliest candidate, i.e. the main model). A digit
      winner is forced by returning a score row that is +inf at that token and 0 elsewhere.
      This vote runs even when the ensemble below is disabled.
    * Otherwise, unless disabled, the main and aux probabilities are projected with their
      transfer matrices and averaged with equal weights. The main model's float32 logits are then
      optimised with AdamW so that their projection matches the average under ``nn.KLDivLoss``.

    The ensemble is disabled (``scores`` returned) when an aux model's logits do not arrive within
    5 s, ``|learning_rate| <= 1e-6``, the main model's top token is ``forced_eos_token_id``, or the
    generated ids end with an early-stop string (which also forces EOS in ``scores``).

    NOTE(review): ``.tolist()[0]`` and dim-less ``argmax`` suggest batch size 1 is assumed.
    """

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count  # stored; not read by this class
        self.assist_model_score_queue_list = assist_model_score_queue_list
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        # Indexed per aux model in __call__ (``self.assist_model_tokenizer_list[index]``).
        self.assist_model_tokenizer_list = assist_model_tokenizer
        self.device = device  # device of the returned logits
        self.device_compute = device_compute  # device on which the projection/optimisation runs
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    def _receive_assist_logits(self, ensemble_process_file_path):
        """Wait (up to 5 s per queue) for one logits tensor from every aux-model queue.

        Each received tensor's top-10 (or a "not received" line) is appended to the log file.

        Returns:
            ``(logits_list, missing)``: one entry per queue (``None`` for a queue that timed out),
            and ``True`` in ``missing`` if at least one queue timed out.
        """
        logits_list = []
        missing = False
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
                    process_file.write(f"aux model{index}【not received】\n")
                logits_list.append(None)
                missing = True
                continue
            logits_list.append(value)
            with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
                process_file.write(f"aux model{index} top10:\n")
                process_file.write(str(torch.topk(value, 10)) + "\n")
        return logits_list, missing

    def _apply_early_stop(self, input_ids, scores):
        """Force EOS when the generated ids end with one of ``early_stop_string_list``.

        On a match, ``scores[:, forced_eos_token_id]`` is set to +inf in place.

        Returns:
            True if at least one early-stop string matched.
        """
        if self.early_stop_string_list is None:
            return False
        generated_ids = input_ids.tolist()[0]
        matched = False
        for early_stop_string in self.early_stop_string_list:
            # [1:] drops the first id of the fresh encoding -- presumably a prefix/space token the
            # tokenizer inserts at the start of a string; TODO confirm for the tokenizer in use.
            early_stop_token = self.main_model_tokenizer(early_stop_string, return_tensors="pt",
                                                         add_special_tokens=False).input_ids.tolist()[0][1:]
            if not early_stop_token:
                # ``ids[-0:]`` would be the whole history, never a meaningful suffix match.
                continue
            if generated_ids[-len(early_stop_token):] == early_stop_token:
                scores[:, self.forced_eos_token_id] = float('inf')
                matched = True
        return matched

    def _digit_vote(self, scores, assist_model_generate_ids_logits_list, ensemble_process_file_path):
        """Majority-vote the next token when the main model's top token decodes to a digit string.

        Returns:
            The forced score row (0 everywhere, +inf at the winning token) when the vote winner is a
            digit string. Returns ``None`` when no vote applies, in which case nothing is queued.
        """
        main_top_text = self.main_model_tokenizer.decode(torch.argmax(scores).item())
        if not main_top_text.isdigit():
            return None

        candidates_list = [main_top_text]
        for index, logits in enumerate(assist_model_generate_ids_logits_list):
            # An aux model whose logits did not arrive this step has no vote; calling
            # torch.argmax on its None placeholder would raise TypeError.
            if logits is None:
                continue
            candidates_list.append(self.assist_model_tokenizer_list[index].decode(torch.argmax(logits).item()))
        counter = Counter(candidates_list)

        print(f"本轮数字投票集成:{candidates_list}")
        most_common_number = counter.most_common(1)[0][0]
        if not most_common_number.isdigit():
            return None

        vote_result_token_id = self.main_model_tokenizer.convert_tokens_to_ids(f"{most_common_number}")
        self.ensemble_model_output_ids_queue.put(torch.tensor([vote_result_token_id]).to(self.device))
        with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
            process_file.write(f"本轮数字投票集成:")
            process_file.write(str(candidates_list))
            process_file.write(f"投票结果:{vote_result_token_id}\n")
        output = torch.zeros_like(scores).to(self.device)
        output[:, vote_result_token_id] = float('inf')
        return output

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return the next-token logits (digit vote, ensemble, or unchanged) and queue the chosen id.

        Args:
            input_ids: ids generated so far (only row 0 is inspected for early stopping).
            scores: the main model's next-token logits; may be modified in place by early stop.
        """
        ensemble_process_file_path = os.path.join(self.result_save_dir,
                                                  f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')

        with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
            process_file.write(f"main model top10:\n")
            process_file.write(str(torch.topk(scores, 10)) + "\n")

        assist_model_generate_ids_logits_list, main_model_only_flag = self._receive_assist_logits(
            ensemble_process_file_path)

        if math.fabs(self.learning_rate) <= 1e-6:
            main_model_only_flag = True
        if torch.argmax(scores).item() == self.forced_eos_token_id:
            main_model_only_flag = True
        # Must always run: it may write the EOS override into ``scores``.
        if self._apply_early_stop(input_ids, scores):
            main_model_only_flag = True

        vote_output = self._digit_vote(scores, assist_model_generate_ids_logits_list, ensemble_process_file_path)
        if vote_output is not None:
            return vote_output

        if main_model_only_flag:
            with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
                process_file.write(f"no ensemble\n")
            next_tokens_id = torch.argmax(scores, dim=-1)
            self.ensemble_model_output_ids_queue.put(next_tokens_id)
            return scores

        # Detached float32 copy of the main logits (no gradient needed while building the target).
        main_model_generate_ids_logits = scores.detach().to(torch.float32)
        local_main_model_relative_representation_matrix = self.main_model_probability_transfer_matrix_list[0]
        with torch.no_grad():
            main_model_generate_ids_probs = nn.functional.softmax(main_model_generate_ids_logits, dim=-1).float()
            main_model_relative_representation_probs = torch.mm(
                main_model_generate_ids_probs,
                local_main_model_relative_representation_matrix).to(self.device_compute)
            model_relative_representation_probs_list = [main_model_relative_representation_probs]

            for assist_model_generate_ids_logits, assist_model_probability_transfer_matrix in zip(
                    assist_model_generate_ids_logits_list, self.assist_model_probability_transfer_matrix_list):
                assist_model_generate_ids_probs = nn.functional.softmax(
                    assist_model_generate_ids_logits.to(torch.float32), dim=-1).float()
                assist_model_relative_representation_probs = torch.mm(
                    assist_model_generate_ids_probs,
                    assist_model_probability_transfer_matrix).to(self.device_compute)
                model_relative_representation_probs_list.append(assist_model_relative_representation_probs)

        # Equal weight for the main model and every aux model.
        model_weights = [1 / len(model_relative_representation_probs_list)] * len(
            model_relative_representation_probs_list)

        assert len(model_weights) == len(model_relative_representation_probs_list), "权重数和logits必须相同"
        average_probs = torch.zeros_like(main_model_relative_representation_probs)
        for weight, probs in zip(model_weights, model_relative_representation_probs_list):
            average_probs += weight * probs

        # Fresh leaf tensor on the compute device that AdamW updates directly.
        main_model_generate_ids_logits = main_model_generate_ids_logits.to(self.device_compute).detach().clone().to(
            torch.float32)
        main_model_generate_ids_logits.requires_grad_(True)
        local_main_model_relative_representation_matrix = local_main_model_relative_representation_matrix.to(
            self.device_compute)
        local_learning_rate = self.learning_rate
        # NOTE(review): the default reduction='mean' averages over batch *and* vocabulary; PyTorch
        # recommends 'batchmean' for the mathematical KL divergence. Kept to preserve the loss scale.
        criterion = nn.KLDivLoss()

        optimizer = torch.optim.AdamW(params=[main_model_generate_ids_logits],
                                      lr=local_learning_rate,
                                      betas=(0.9, 0.999))
        # NOTE(review): T_max is fixed at 10, independent of learning_epochs_nums.
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=local_learning_rate / 4)

        # The caller (e.g. ``generate``) typically runs under no_grad; enable autograd only for the
        # optimisation and restore the caller's mode on exit -- also when an exception escapes.
        with torch.enable_grad():
            for _ in range(self.learning_epochs_nums):
                main_model_generate_ids_probs = nn.functional.softmax(main_model_generate_ids_logits, dim=-1).float()
                main_model_relative_representation_probs = torch.mm(main_model_generate_ids_probs,
                                                                    local_main_model_relative_representation_matrix)

                # KLDivLoss takes log-probabilities as input and probabilities as target.
                log_main_probs = torch.log(main_model_relative_representation_probs)
                loss = criterion(log_main_probs, average_probs)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

        # Computed without autograd so the logged repr carries no grad_fn.
        with torch.no_grad():
            next_tokens_id = torch.argmax(main_model_generate_ids_logits, dim=-1)
            ensemble_top10 = torch.topk(main_model_generate_ids_logits, 10)
        with open(ensemble_process_file_path, "a+", encoding="utf-8") as process_file:
            process_file.write(f"ensemble_result top10:\n")
            process_file.write(str(ensemble_top10) + "\n")
        self.ensemble_model_output_ids_queue.put(next_tokens_id)
        print(next_tokens_id)
        return main_model_generate_ids_logits.to(self.device).detach()


class BasedOnProbabilityTransferLogits_Loacal_FP32_Reweight_by_dev_digit_vote_Processor(LogitsProcessor):
    """Logits processor that ensembles a main model with several assist models.

    On every generation step it:

    1. Pulls one logits tensor per assist model from
       ``assist_model_score_queue_list``, waiting up to 5 s per queue.
    2. If the main model's greedy token decodes to a digit string, takes a
       majority vote over the greedy tokens of the main model and of every
       assist model that delivered logits. If the winner is itself a digit
       string, that token is forced.
    3. Otherwise, unless ensembling is disabled for this step, does the
       following:
       - Projects each model's softmax distribution through its
         probability-transfer matrix (``torch.mm(probs, matrix)``).
       - Averages the projections with fixed weights.
       - Runs ``learning_epochs_nums`` AdamW steps on a float32 copy of the
         main-model logits to minimise ``KLDivLoss`` between the main model's
         projected distribution and that average.

    Every call pushes exactly one chosen token-id tensor to
    ``ensemble_model_output_ids_queue``. Progress is appended to a log file
    under ``result_save_dir``.
    """

    def __init__(self, learning_rate, anchor_point_count, learning_epochs_nums,
                 ensemble_model_output_ids_queue, assist_model_score_queue_list,
                 main_model_probability_transfer_matrix_list,
                 assist_model_probability_transfer_matrix_list, result_save_dir, main_model_tokenizer,
                 assist_model_tokenizer, device, device_compute, forced_eos_token_id, early_stop_string_list=None):
        """Store the ensemble configuration.

        ``assist_model_tokenizer`` is indexed by assist-model position and
        ``.decode`` is called on each entry. It therefore has to be a sequence
        of tokenizers aligned with ``assist_model_score_queue_list``.
        A ``learning_rate`` with absolute value <= 1e-6 disables the
        gradient-based ensemble.
        """
        self.learning_rate = learning_rate
        self.anchor_point_count = anchor_point_count
        self.assist_model_score_queue_list = assist_model_score_queue_list
        self.learning_epochs_nums = learning_epochs_nums
        self.ensemble_model_output_ids_queue = ensemble_model_output_ids_queue
        self.main_model_probability_transfer_matrix_list = main_model_probability_transfer_matrix_list
        self.assist_model_probability_transfer_matrix_list = assist_model_probability_transfer_matrix_list
        self.result_save_dir = result_save_dir
        self.main_model_tokenizer = main_model_tokenizer
        self.assist_model_tokenizer_list = assist_model_tokenizer
        self.device = device
        self.device_compute = device_compute
        self.forced_eos_token_id = forced_eos_token_id
        self.early_stop_string_list = early_stop_string_list

    @staticmethod
    def _append_log(log_path, *parts):
        """Append ``parts`` verbatim (no separators) to the UTF-8 log file."""
        with open(log_path, "a+", encoding="utf-8") as process_file:
            process_file.writelines(parts)

    def _collect_assist_logits(self, log_path):
        """Fetch one logits tensor from every assist-model queue.

        Returns:
            ``(logits_list, missing)``.
            - ``logits_list`` has one entry per queue. The entry is ``None``
              when that queue produced nothing within 5 s.
            - ``missing`` is True when at least one queue timed out.
        """
        logits_list = []
        missing = False
        for index, queue_instance in enumerate(self.assist_model_score_queue_list):
            try:
                value = queue_instance.get(block=True, timeout=5)
            except queue.Empty:
                print(f"aux model{index}【not received】\n")
                self._append_log(log_path, f"aux model{index}【not received】\n")
                logits_list.append(None)
                missing = True
                continue
            logits_list.append(value)
            self._append_log(log_path, f"aux model{index} top10:\n", str(torch.topk(value, 10)) + "\n")
        return logits_list, missing

    def _digit_vote(self, scores, main_top_token_text, assist_logits_list, log_path):
        """Majority-vote the next token among models whose greedy pick is compared as text.

        Candidate order is main model first, then assist models in queue
        order. ``Counter.most_common`` breaks ties by first occurrence, so the
        main model wins ties.

        Returns:
            A tensor shaped like ``scores``: zeros everywhere, with ``inf`` at
            the voted token. Returns ``None`` when the winning string is not a
            digit string; the caller then falls through to the normal path.
        """
        candidates_list = [main_top_token_text]
        for index, logits in enumerate(assist_logits_list):
            if logits is None:
                # This assist model timed out this step; it cannot vote
                # (argmax on None would raise).
                continue
            candidates_list.append(self.assist_model_tokenizer_list[index].decode(torch.argmax(logits).item()))

        print(f"本轮数字投票集成:{candidates_list}")
        most_common_number = Counter(candidates_list).most_common(1)[0][0]
        if not most_common_number.isdigit():
            return None

        # NOTE(review): the decoded text is looked up as a raw vocabulary token
        # in the MAIN tokenizer. Tokenizers whose digit tokens carry a prefix
        # marker may map this to a different or unknown id — confirm for the
        # tokenizers in use.
        vote_result_token_id = self.main_model_tokenizer.convert_tokens_to_ids(f"{most_common_number}")
        self.ensemble_model_output_ids_queue.put(torch.tensor([vote_result_token_id]).to(self.device))
        self._append_log(log_path, "本轮数字投票集成:", str(candidates_list), f"投票结果:{vote_result_token_id}\n")

        output = torch.zeros_like(scores).to(self.device)
        output[:, vote_result_token_id] = float('inf')
        return output

    @staticmethod
    def _model_weights(model_count):
        """Return per-model averaging weights, main model first.

        The 2- and 4-model weights are hard-coded constants (each set sums to
        1). Their provenance is not visible here — presumably tuned on a dev
        set, per the class name. Any other count gets uniform weights.
        """
        if model_count == 2:
            return [0.5174, 0.4826]
        if model_count == 4:
            return [0.2821, 0.2631, 0.2298, 0.2250]
        return [1 / model_count] * model_count

    def _ensemble(self, scores, assist_logits_list, log_path):
        """Optimise a float32 copy of ``scores`` toward the weighted ensemble distribution.

        Assumes every entry of ``assist_logits_list`` is a tensor. The caller
        only reaches this path when no assist queue timed out.
        """
        main_logits = scores.detach().to(torch.float32)
        # Only the first main-model transfer matrix is used.
        main_matrix = self.main_model_probability_transfer_matrix_list[0]

        with torch.no_grad():
            main_relative_probs = torch.mm(nn.functional.softmax(main_logits, dim=-1).float(),
                                           main_matrix).to(self.device_compute)
            relative_probs_list = [main_relative_probs]
            for assist_logits, assist_matrix in zip(assist_logits_list,
                                                    self.assist_model_probability_transfer_matrix_list):
                assist_probs = nn.functional.softmax(assist_logits.to(torch.float32), dim=-1).float()
                relative_probs_list.append(torch.mm(assist_probs, assist_matrix).to(self.device_compute))

        model_weights = self._model_weights(len(relative_probs_list))
        assert len(model_weights) == len(relative_probs_list), "权重数和logits必须相同"
        average_probs = torch.zeros_like(main_relative_probs)
        for weight, probs in zip(model_weights, relative_probs_list):
            average_probs += weight * probs

        # enable_grad restores the caller's autograd state on exit, even if an
        # exception is raised. Generation normally runs with grad disabled.
        with torch.enable_grad():
            trainable_logits = main_logits.to(self.device_compute).detach().clone().to(torch.float32)
            trainable_logits.requires_grad_(True)
            main_matrix = main_matrix.to(self.device_compute)

            criterion = nn.KLDivLoss()
            optimizer = torch.optim.AdamW(params=[trainable_logits],
                                          lr=self.learning_rate,
                                          betas=(0.9, 0.999))
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=self.learning_rate / 4)

            for _ in range(self.learning_epochs_nums):
                main_probs = nn.functional.softmax(trainable_logits, dim=-1).float()
                log_main_probs = torch.log(torch.mm(main_probs, main_matrix))
                loss = criterion(log_main_probs, average_probs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

        final_logits = trainable_logits.detach()
        next_tokens_id = torch.argmax(final_logits, dim=-1)
        self._append_log(log_path, "ensemble_result top10:\n", str(torch.topk(final_logits, 10)) + "\n")
        self.ensemble_model_output_ids_queue.put(next_tokens_id)
        print(next_tokens_id)
        return final_logits.to(self.device)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Process one generation step.

        Args:
            input_ids: Token ids generated so far. Only row 0 is inspected
                for early-stop strings.
            scores: Main-model next-token logits for this step. The EOS
                column is set to ``inf`` in place when an early-stop string
                matches.

        Returns:
            - ``scores`` itself when ensembling is skipped.
            - A zeros/``inf`` tensor when the digit vote picks a token.
            - Otherwise, the optimised float32 logits on ``self.device``.
        """
        ensemble_process_file_path = os.path.join(self.result_save_dir,
                                                  f'ensemble_lr{self.learning_rate}_anchor_point_count_all_learning_epochs_nums_5.log')

        self._append_log(ensemble_process_file_path, "main model top10:\n", str(torch.topk(scores, 10)) + "\n")

        # Assist models that time out disable the gradient ensemble for this step.
        assist_logits_list, main_model_only_flag = self._collect_assist_logits(ensemble_process_file_path)

        if math.fabs(self.learning_rate) <= 1e-6:
            main_model_only_flag = True
        if torch.argmax(scores).item() == self.forced_eos_token_id:
            main_model_only_flag = True

        if self.early_stop_string_list is not None:
            generated_ids = input_ids.tolist()[0]
            for early_stop_string in self.early_stop_string_list:
                # The first id is dropped — presumably a prefix/space piece the
                # tokenizer prepends; confirm for the main tokenizer.
                early_stop_token = self.main_model_tokenizer(early_stop_string, return_tensors="pt",
                                                             add_special_tokens=False).input_ids.tolist()[0][1:]
                if not early_stop_token:
                    # Nothing to match; also avoids `[-0:]` slicing the whole history.
                    continue
                if generated_ids[-len(early_stop_token):] == early_stop_token:
                    scores[:, self.forced_eos_token_id] = float('inf')
                    main_model_only_flag = True

        main_top_token_text = self.main_model_tokenizer.decode(torch.argmax(scores).item())
        if main_top_token_text.isdigit():
            vote_output = self._digit_vote(scores, main_top_token_text, assist_logits_list,
                                           ensemble_process_file_path)
            if vote_output is not None:
                return vote_output

        if main_model_only_flag:
            self._append_log(ensemble_process_file_path, "no ensemble\n")
            next_tokens_id = torch.argmax(scores, dim=-1)
            self.ensemble_model_output_ids_queue.put(next_tokens_id)
            return scores

        return self._ensemble(scores, assist_logits_list, ensemble_process_file_path)

