# Base class for the MoA and SC (self-consistency) tasks; subclass it to adapt to a specific dataset.
# The async generation methods are used by default.
from abc import ABC, abstractmethod
import os
import copy
import numpy as np
from torch.cuda import temperature
np.random.seed(42)
import multiprocessing
from tqdm import tqdm
from transformers import AutoTokenizer
import asyncio
import jsonlines
import torch
import torch.nn.functional as F
import json
from global_utils.utils import generate_general_rm, async_generate_general, generate_general_em,generate_general_rm_mor
import datetime
from pytz import timezone
from moa_api import async_raw_moa_api
from kmeans_pytorch import kmeans
import random
from termcolor import colored
random.seed(42)
import time
from loguru import logger



# Maps a question-bank key to the path of its JSON file. A key matches a
# requested bank name when it occurs anywhere inside that name.
QUESTION_BANK_PATH_MAP = {}


def get_qb_path(qb_name):
    """Return the path registered for the first key contained in ``qb_name``.

    Keys are tried in the insertion order of ``QUESTION_BANK_PATH_MAP``;
    ``None`` is returned when no key is a substring of ``qb_name``.
    """
    matches = (
        path for key, path in QUESTION_BANK_PATH_MAP.items() if key in qb_name
    )
    return next(matches, None)


class BaseTask(ABC):
    """Minimal task interface: a task must know how to set up its result directory."""

    @abstractmethod
    def build_result_dir(self):
        """Prepare the directory that the task's results go to."""


class MoaBase(BaseTask):
    def __init__(self, model, model_list_str, model_List_map, data_path, max_tokens, mode='raw_moa', use_sc=False,
                 use_rm=False, rm_model=None, N=1, max_process=8, sc_posi='agg', ref_sample='all',
                 ppl_coef=0.0, agg_max_tokens=32768, ref_token_cut=False, ref_clean_think=False, question_bank='5d_32k', dataset='', exp_suffix='',rm_qb_path='',rm_model_list=None,mor_batch=1,weight_method='softem',score_method='average',temperature=1.0,rtk='r',qb_lis=None):
        """Store the run configuration and load the question bank required by ``mode``.

        Args:
            model: aggregator model name.
            model_list_str: key (stringified) selecting the reference-model list
                in ``model_List_map``.
            model_List_map: mapping from that key to the list of reference models.
            data_path: dataset location, stored as-is.
            max_tokens: a single budget shared by every model, or a list laid out
                as ``[aggregator, ref_1, ..., ref_n]``. Slot 0 is always replaced
                by ``agg_max_tokens``; a list argument is copied, not mutated.
            mode: run mode. Modes containing ``rag``/``scale``/``router``/``analyze``
                (and not ``mor``) load the model-accuracy bank found through
                ``question_bank``; modes containing ``mor`` load the reward-model
                bank at ``rm_qb_path``.
            question_bank: bank name; a trailing ``_<float>`` is read as the
                fraction of bank entries to keep (randomly sampled).
            rm_model_list: optional list of reward models (defaults to a new
                empty list).
            qb_lis: values of the ``src`` field that are kept when loading the
                reward-model bank (defaults to an empty list, i.e. keep nothing).

        The remaining arguments are stored unchanged on the instance.
        """
        # Use fresh containers per instance rather than shared mutable defaults.
        rm_model_list = [] if rm_model_list is None else rm_model_list
        qb_lis = [] if qb_lis is None else qb_lis

        self.model = model
        self.model_list_str = str(model_list_str)
        self.model_list = model_List_map[self.model_list_str]
        self.use_sc = use_sc
        self.use_rm = use_rm
        self.mode = mode
        self.rm_model = rm_model
        self.N = N
        self.rm_model_list=rm_model_list
        self.data_path = data_path
        self.rtk=rtk
        self.max_process = max_process
        # Copy a caller-supplied list so overwriting slot 0 does not leak back to the caller.
        if isinstance(max_tokens, list):
            self.max_tokens_list = list(max_tokens)
        else:
            self.max_tokens_list = [max_tokens]*(len(self.model_list)+1)
        self.max_tokens_list[0] = agg_max_tokens
        self.sc_posi = sc_posi
        self.ref_sample = ref_sample
        self.weight_method=weight_method
        self.score_method=score_method
        self.temperature=temperature
        self.qb_name = question_bank
        self.qb_path = get_qb_path(question_bank)
        self.rm_qb_path=rm_qb_path
        self.dataset = dataset
        self.exp_suffix = exp_suffix
        # Attributes that the build_* / run methods are expected to initialise later.
        self.result_dir = None
        self.test_data = None
        self.val_data = None
        self.test_data_num = None
        self.cache_dict = None
        self.output_res_path = None
        self.output_summary_path = None
        self.done_list = []
        self.done_question_id = []
        self.done_sum_dict = None
        self.ppl_coef = ppl_coef
        self.ref_token_cut = ref_token_cut
        self.ref_clean_think = ref_clean_think
        self.models_profiles = None
        self.question_keywords = None
        self.mor_batch=mor_batch
        self.all_keywords_embedding = None
        self.data_level_profile = None
        if ("rag" in self.mode or "scale" in self.mode or 'router' in self.mode or 'analyze' in self.mode) and 'mor' not in self.mode:
            # A bank name such as "<name>_0.5" keeps a random 50% of the bank;
            # a non-numeric suffix (e.g. "5d_32k") keeps everything.
            try:
                bank_keep_rate = float(self.qb_name.split('_')[-1])
            except (AttributeError, ValueError):
                bank_keep_rate = 1.0
            with open(self.qb_path, 'r') as f:
                self.raw_question_bank = json.load(f)
            if bank_keep_rate != 1.0:
                bank_keep_num = int(len(self.raw_question_bank) * bank_keep_rate)
                # random.sample needs a sequence: dict views are rejected on Python >= 3.11.
                kept_keys = random.sample(list(self.raw_question_bank.keys()), bank_keep_num)
                # Re-key the kept entries as "0", "1", ... so the '0' lookup below still works.
                self.raw_question_bank = {str(i): self.raw_question_bank[k] for i, k in enumerate(kept_keys)}
            # build model pred dict and embedding bank
            model_list = list(self.raw_question_bank['0']['model_res'].keys())
            model_pred_dict = {m: [] for m in model_list}
            embedding_bank = []
            for i, q in self.raw_question_bank.items():
                embedding_bank.append(q['embedding'])
                for m in model_list:
                    model_pred_dict[m].append(q['model_res'][m]['is_correct'])
            # convert to tensor; rows are L2-normalised
            embedding_bank =  F.normalize(torch.tensor(embedding_bank), p=2, dim=1)
            for m in model_pred_dict:
                model_pred_dict[m] = torch.tensor(model_pred_dict[m])
            self.question_bank = {
                'embedding_bank': embedding_bank.cuda(),
                'model_pred_dict': model_pred_dict
            }
        elif "mor" in self.mode:
            print(colored(self.rm_qb_path+'-----------------------', 'green'))
            print(colored(str(qb_lis)+'-----------------------', 'blue'))

            with open(self.rm_qb_path, 'r') as f:
                rm_js = json.load(f)
            rm_model_acc_dict = []
            embedding_bank = []
            for q in rm_js:
                # Skip entries with an accuracy of -1 and entries whose source is not selected.
                if -1 in list(q['reward_group_acc'].values()) or q['src'] not in qb_lis:
                    continue
                embedding_bank.append(q['embedding'])
                # Per question: (reward model, accuracy) pairs, best accuracy first.
                rm_model_acc_dict.append(sorted(q['reward_group_acc'].items(), key=lambda item: item[1], reverse=True))
            # convert to tensor; rows are L2-normalised
            embedding_bank =  F.normalize(torch.tensor(embedding_bank), p=2, dim=1)
            print('embedding bank shape:',embedding_bank.shape)
            print('model acc list len',len(rm_model_acc_dict))
            self.question_bank = {
                'embedding_bank': embedding_bank.cuda(),
                'model_pred_dict': rm_model_acc_dict
            }
        else:
            self.raw_question_bank = None
            self.question_bank = None
        if 'analyze' in self.mode:
            # NOTE(review): the path is an empty string, so this open() cannot succeed as written;
            # the intended file needs to be filled in.
            with open("", 'r') as f:
                self.raw_question_bank = json.load(f)
        # load llama tokenizer (best effort: ref_cut / clean_think need it, other paths do not)
        try:
            self.simple_tokenizer = AutoTokenizer.from_pretrained('/fs-computility/mabasic/shared/models/Llama-3.3-70B-Instruct', use_fast=True, trust_remote_code=True)
        except Exception as err:
            # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
            print(f'simple_tokenizer could not be loaded, continuing without it: {err!r}')
            self.simple_tokenizer = None

    # for different datasets, it should be given
    def build_result_dir(self):
        max_token_ref = self.max_tokens_list[-1]
        if self.ref_token_cut:
            exp_root = 'main_exp_8k_cut'
        elif self.ref_clean_think:
            exp_root = 'main_exp_8k_clean_think'
        else:
            exp_root = 'main_exp_new_8k' if max_token_ref == 8192 else 'main_exp_new'
        mode = self.mode
        if self.use_sc:
            mode += f'_SC{self.N}'
            mode += f'_posi_{self.sc_posi}'
        if self.use_rm:
            mode += f'_Bo{self.N}_{self.rm_model}'
        if self.ref_sample != 'all':
            mode += f'_ref_sample_{self.ref_sample}'
        if self.ppl_coef != 0:
            mode += f'_ppl_{self.ppl_coef}'
        if self.qb_name == '5d_32k':
            qb_suffix = ''
        else:
            qb_suffix = '_'+self.qb_name
        if os.getcwd().split('/')[-1] == self.dataset:
            result_dir = os.path.join('result', exp_root,
                                      self.model + '_' + self.model_list_str + '_' + mode + qb_suffix)
        else:
            result_dir = os.path.join(self.dataset, 'result', exp_root, self.model + '_' + self.model_list_str + '_' + mode + qb_suffix)
        if self.exp_suffix != '':
            result_dir = result_dir + f'_{self.exp_suffix}'
        if 'gs' in self.mode:
            result_dir=result_dir+'_'+self.rm_model

        if not os.path.exists(result_dir):
            os.makedirs(result_dir, exist_ok=True)
        
        self.result_dir = result_dir
        print('--------res dir',result_dir)

    # Implementations should also set self.test_data_num.
    @abstractmethod
    def build_dataset(self):
        pass

    @abstractmethod
    def build_res_and_sum_file(self):
        pass

    @abstractmethod
    def build_cache(self):
        pass

    # The following methods may be customized by subclasses.
    @abstractmethod
    def build_messages(self, each):
        pass

    @abstractmethod
    def extract_answer(self, response):
        pass

    @abstractmethod
    def get_question_id(self, each):
        pass

    @abstractmethod
    def get_question(self, each):
        pass

    @abstractmethod
    async def async_generate_general_cache(self, question_id, model, messages, max_tokens, temperature, streaming):
        pass

    @abstractmethod
    def run(self):
        pass

    def record_sc_memory(self, sc_memory, responses):
        response_pred_list = [self.extract_answer(r) for r in responses]
        for ref_pred in response_pred_list:
            sc_memory[ref_pred] = 1 if ref_pred not in sc_memory.keys() else sc_memory[ref_pred] + 1

    def build_sub_ref(self, mode, references, N, rag_score_list=None):
        sample_n = int(mode.split('_')[-1])
        if 'random' in mode:
            sub_ref_index = np.array([np.random.choice(range(len(references)), sample_n, replace=False) for _ in range(N)])
        elif 'k-means' in mode:
            embedding_ref = generate_general_em('Linq-Embed-Mistral', sentences=references,
                                                tasks=['']*len(references), batch_size=2, max_length=8192)
            cluster_ids_x, cluster_centers = kmeans(
                X=torch.tensor(embedding_ref), num_clusters=sample_n, distance='cosine', device='cpu'
            )
            sub_ref_index = []
            for _ in range(N):
                sampled_indices = []
                for cluster_id in range(sample_n):
                    indices = torch.where(cluster_ids_x == cluster_id)[0]
                    if len(indices) > 0:
                        random_index = indices[torch.randint(0, len(indices), (1,))].item()
                        sampled_indices.append(random_index)
                sub_ref_index.append(sampled_indices)
            sub_ref_index = np.array(sub_ref_index)
        elif 'prior' in mode:
            rag_score_list = np.array(rag_score_list)
            rag_score_list_diff = rag_score_list - rag_score_list.min()
            rag_score_list_diff_norm = np.exp((rag_score_list_diff - rag_score_list_diff.mean()) / ((rag_score_list_diff ** 2).sum()/len(rag_score_list_diff)) ** 0.5)
            p = rag_score_list_diff_norm / rag_score_list_diff_norm.sum()
            sub_ref_index = np.array(
                [np.random.choice(range(len(references)), sample_n, replace=False, p=p) for _ in range(N)])
        result_ref = []
        for each_ref_index in sub_ref_index:
            result_ref.append(np.array(references)[each_ref_index].tolist())
        return result_ref

    def combine_sc_with_reward(self, sc_memory, reward_score, pred_list):
        reward_score = torch.tensor(reward_score)
        reward_score_norm = (reward_score - reward_score.min()) / (reward_score.max() - reward_score.min())
        sc_memory_all = sum([v for v in sc_memory.values()])
        return torch.tensor([sc_memory[pred] / sc_memory_all for pred in pred_list]) + reward_score_norm

    def ref_cut(self, references, max_tokens):
        new_references = []
        for r in references:
            r_tokens = self.simple_tokenizer.encode(r)
            if len(r_tokens) > max_tokens:
                new_references.append(self.simple_tokenizer.decode(r_tokens[-max_tokens:]))
            else:
                new_references.append(r)
        return new_references


    def clean_think(self, references, max_tokens):
        new_references = []
        for r in references:
            r_tokens = self.simple_tokenizer.encode(r)
            if len(r_tokens) > max_tokens:
                r = r.split('</think>')[-1]
                r = r.split('## Final Response')[-1]
                r = r.split('</thought>')[-1]
            new_references.append(r)
        return new_references


    def wrap_raw_moa_test(self, data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef,rm_model, N, sc_posi='agg', ref_sample='all',rm_model_list=None):
        """Answer one dataset item with raw MoA and return the selected response.

        The reference models in ``model_list`` each answer the question, the
        aggregator ``model`` produces one or more answers from those references,
        and the answer with the best combined vote-share / perplexity score wins.

        Args:
            data_id: index of the item in ``test_data``.
            max_tokens: budgets laid out as ``[aggregator, ref_1, ..., ref_n]``.
            use_sc, sc_posi: enable self-consistency on the references
                (``'ref'`` in ``sc_posi``) and/or on the aggregator
                (``'agg'`` in ``sc_posi``).
            use_rm: without ``use_sc`` it also makes the aggregator sample ``N`` times.
            ppl_coef: weight of the perplexity score in the total score.
            N: number of samples per stage.
            ref_sample: ``'all'`` or a sub-sampling mode understood by ``build_sub_ref``.
            dev_df, rm_model, rm_model_list: accepted for signature compatibility;
                not used here.

        Returns:
            dict with keys ``ref_dict``, ``sc_memory``, ``response``, ``pred``,
            ``n_response`` and ``select_score``.
        """
        async def wrap_raw_moa_test_(data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef,rm_model, N, sc_posi, ref_sample,rm_model_list):
            return_dict = {}
            each = test_data[data_id]
            question_id = self.get_question_id(each)
            # vote count of every answer produced by the aggregator
            sc_memory = {}

            messages = self.build_messages(each)
            tasks = [self.async_generate_general_cache(question_id, m, messages, mt, 0.7, False) for m, mt in zip(model_list, max_tokens[1:])]
            references = await asyncio.gather(*tasks)
            if self.ref_token_cut:
                references = self.ref_cut(references, max_tokens[-1])
            if self.ref_clean_think:
                references = self.clean_think(references, max_tokens[-1])
            # self-consistency on the reference side
            #TODO: correct ref sc
            if 'ref' in sc_posi and use_sc:
                sc_memory_ref = {m: {} for m in model_list}
                sc_ref_most_pred = {m: None for m in model_list}
                # Create a fresh coroutine for every extra sample: repeating the same
                # coroutine objects (list * k) makes asyncio.gather run each one only
                # once and hand back the same result k times.
                tasks_ref_sc = [async_generate_general(m, messages, mt, 0.7, False)
                                for _ in range(N - 1)
                                for m, mt in zip(model_list, max_tokens[1:])]
                references_ref_sc = await asyncio.gather(*tasks_ref_sc)
                references_ref_sc += references
                # Responses are laid out in repeated model_list order, so the
                # owning model of response i is model_list[i % len(model_list)].
                for i, r in enumerate(references_ref_sc):
                    m_ref = model_list[i % len(model_list)]
                    pred_ref = self.extract_answer(r)
                    sc_memory_ref[m_ref][pred_ref] = sc_memory_ref[m_ref].get(pred_ref, 0) + 1

                # majority answer per reference model (first one wins a tie)
                for m in sc_memory_ref.keys():
                    max_cnt_sc = 0
                    most_pred = None
                    for pred_temp in sc_memory_ref[m].keys():
                        if sc_memory_ref[m][pred_temp] > max_cnt_sc:
                            max_cnt_sc = sc_memory_ref[m][pred_temp]
                            most_pred = pred_temp
                    sc_ref_most_pred[m] = most_pred

                # keep, per model, the first response that carries its majority answer
                collected_m = []
                references = []
                for i, r in enumerate(references_ref_sc):
                    m_ref = model_list[i % len(model_list)]
                    if m_ref in collected_m:
                        continue
                    pred_ref = self.extract_answer(r)
                    if pred_ref == sc_ref_most_pred[m_ref]:
                        collected_m.append(m_ref)
                        references.append(r)

            ref_dict = {m: r for r, m in zip(references, model_list)}
            return_dict['ref_dict'] = ref_dict
            agg_N = N if ('agg' in sc_posi and use_sc) or (not use_sc and use_rm) else 1
            if ref_sample == 'all' or agg_N == 1:
                agg_tasks = [async_raw_moa_api(model, messages, None, 0.7, max_tokens[0], 1, references, 1) for _ in range(agg_N)]
            else:
                sub_ref = self.build_sub_ref(mode=ref_sample, references=references, N=agg_N)
                agg_tasks = [async_raw_moa_api(model, messages, None, 0.7, max_tokens[0], 1, sub_ref_i, 1) for sub_ref_i in
                             sub_ref]
            raw_responses = await asyncio.gather(*agg_tasks)
            select_score = {i: {'ppl_score': 0.0, 'sc_score': 0.0, 'total_score': 0.0} for i in range(len(raw_responses))}
            if not isinstance(raw_responses[0], str):
                mean_cumulative_logprob = [r['cumulative_logprob'] for r in raw_responses]
                responses = [r['response'] for r in raw_responses]
                ppl_score = 1 - np.exp(mean_cumulative_logprob)
            else:
                # Plain-string responses carry no log-probabilities: score them 0.0
                # so selection relies on the vote share alone instead of crashing
                # on a missing ppl score.
                responses = raw_responses
                ppl_score = [0.0] * len(responses)

            response_pred_list = [self.extract_answer(r) for r in responses]
            self.record_sc_memory(sc_memory, responses)
            for i, pred_i in enumerate(response_pred_list):
                sc_score = sc_memory[pred_i] / agg_N
                select_score[i]['ppl_score'] = ppl_score[i]
                select_score[i]['sc_score'] = sc_score
                select_score[i]['total_score'] = sc_score + ppl_coef * ppl_score[i]
            # highest total score wins; the stable sort keeps the earliest response on ties
            response = responses[sorted(select_score, key=lambda x: select_score[x]['total_score'], reverse=True)[0]]
            pred = self.extract_answer(response)
            return_dict['sc_memory'] = sc_memory
            return_dict['response'] = response
            return_dict['pred'] = pred
            return_dict['n_response'] = responses
            return_dict['select_score'] = select_score
            return return_dict

        return asyncio.run(
            wrap_raw_moa_test_(data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm,ppl_coef, rm_model, N, sc_posi, ref_sample,rm_model_list))

    def wrap_moa_greedy_search_mor(self, data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef,rm_model, N, sc_posi='agg', ref_sample='all',rm_model_list=[]):
        async def wrap_moa_greedy_search_mor_(data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef,rm_model, N, sc_posi, ref_sample,rm_model_list):
            all_time_st=time.time()
            try:
                rag_num = int(self.mode.split('_')[-1])
            except:
                rag_num = None
            weighted_score = self.mode.split('_')[-2] == 'weighted'
            k = 100
            process_id = str(os.getpid())
            
            return_dict = {}
            each = test_data[data_id]
            # print('data id',data_id)
            question_id = self.get_question_id(each)
            question = self.get_question(each)
            # print(question)
            rag_task = 'Given a question, find the one with the highest subject and semantic similarity from the question bank.'
            question_embedding = F.normalize(torch.tensor(generate_general_em('Linq-Embed-Mistral', [question], [rag_task], 8192, 1)).cuda())
            sc_memory = {}
            messages = self.build_messages(each)
            # find_k_near
            scores = (question_embedding @ self.question_bank['embedding_bank'].T) * 100
            
            scores_topk_value, scores_topk = scores[0].topk(len(scores[0]))
            threshold_bound = scores_topk_value[k] * 0.95
            threshold_bound=threshold_bound if threshold_bound>60 else 60
            top=scores_topk[0]
            scores_topk = scores_topk[scores_topk_value > threshold_bound].cpu()
            if len(scores_topk)==0:
                scores_topk=torch.tensor([top])
            
            # print([self.question_bank['model_pred_dict'][idc] for idc in scores_topk])

            top_acc_lis=self.question_bank['model_pred_dict'][scores_topk[0]]
            topa=top_acc_lis[0][1]
            top_acc=[]
            for md in top_acc_lis:
                if md[1]==topa:
                    top_acc.append(md)
            
            # print(top_acc)
            if len(top_acc)>1:
                # print(colored('going top acc >1---------','red'))
                acc_compare=[]
                temp_search=[dict(self.question_bank['model_pred_dict'][idx]) for idx in scores_topk.tolist()]
                # print(temp_search)
                simi=scores[0][scores_topk]
                assert len(temp_search)==len(simi), 'Search len mismatch with simi len, check...'
                for t in top_acc:
                    if len(t[0].split('_'))==1:
                        temp_acc=0
                        for weight,dic in zip(simi,temp_search):
                            temp_acc+=weight/100*dic[t[0]]
                        acc_compare.append(temp_acc/len(scores_topk))
                    else:
                        temp_lis=[]
                        for rmm in t[0].split('_')[:-1]:
                            temp_acc=0
                            for weight,dic in zip(simi,temp_search):
                                temp_acc+=weight/100*dic[rmm]
                            temp_lis.append(temp_acc/len(scores_topk))
                        # print(temp_lis)
                        acc_compare.append(max(temp_lis))
                # print(acc_compare)
                acc_compare=torch.tensor(acc_compare)
                va,tk=acc_compare.topk(len(acc_compare))
                co_acc=[x for x in acc_compare if x==acc_compare[tk[0]]]
                if len(co_acc)==1:
                    top_acc=[top_acc[tk[0]]]
                else:
                    for i in range(len(co_acc)):
                        if len(top_acc[tk[i]][0].split('_'))==1:
                            top_acc=[top_acc[tk[i]]]
                            break
                

            str_lis=top_acc[0][0].split('_')

            
            if len(str_lis)>1:
                # print('before',top_acc)
                # print('temp search',temp_search)
                temp_search=dict(self.question_bank['model_pred_dict'][scores_topk[0]])
                # print(temp_search)
                a=[temp_search[rml] for rml in str_lis[:-1]]

                top_acc.append(a)
                
                top_ac=max(a)
                for ac in a:
                    if abs(top_ac-ac)>0.15:
                        logger.info('top ac and ac surpass the thre...')
                        logger.info(f'Before: {top_acc}')
                        top_acc=[(str_lis[:-1][a.index(top_ac)],top_ac)]
                        logger.info(f'After: {top_acc}')
                        break
                
            

            
            residual = True
            bi_greedy = True
            use_sc = True
            drop_greedy = True

            return_dict = {}
            sc_memory = {}
            search_history = [str([i]) for i in range(len(model_list))]
            each = test_data[data_id]
            question_id = self.get_question_id(each)
            messages = self.build_messages(each)
            tasks = [self.async_generate_general_cache(question_id, m, messages, mt, 0.7, False) for m, mt in
                     zip(model_list, max_tokens[1:])]
            references = await asyncio.gather(*tasks)
            # print(references)
            ref_dict = {m: r for r, m in zip(references, model_list)}
            return_dict['ref_dict'] = ref_dict
            # log the self consistence
            ref_pred_list = [self.extract_answer(r) for r in references]
            for ref_pred in ref_pred_list:
                sc_memory[ref_pred] = 1 if ref_pred not in sc_memory.keys() else sc_memory[ref_pred] + 1
            question = messages[-1]["content"]
            # print('temperature: ',self.temperature)
            # compute the reward of each response
            # print('---------------------RUN MOR START 1---------------------------')
            # print('len(references)',len(references))
            logger.info(f'len(references): {str(len(references))}')
            # print(colored(f'top acc len {str(len(top_acc))}','red'))
            print(colored(f'process id {process_id}','red'),top_acc)
            start_time=time.time()
            reward_score = generate_general_rm_mor(
                model=top_acc,
                question=[question] * len(references),
                response=references,
                batch_size=self.mor_batch)
            end_time=time.time()
            print('Reward score',reward_score)
            
            # print(colored(f'---------------------DATA ID: {str(data_id)} MOR Thread {process_id} Final time cost 1: '+str(end_time-start_time)+' s---------------------------','green'))
            logger.info(colored(f'---------------------DATA ID: {str(data_id)} MOR Thread {process_id} time cost 1: '+str(end_time-start_time)+' s---------------------------','green'))
            reward_score = torch.tensor(reward_score)
            chosen_index = []
            responses = []
            win_responses = []
            # begin to search greedily
            # for residual, the reward of ref will be added to
            for j in range(2 ** len(reward_score)):
                if j == 0:
                    # the initial subset of greedy search
                    if drop_greedy:
                        chosen_index = list(range(len(model_list)))
                    else:
                        chosen_index = sorted(reward_score.topk(2)[1].tolist())
                    pre_response = await async_raw_moa_api(model, messages, reference_models=None, temperature=0.7,
                                                           max_tokens=max_tokens[0],
                                                           references=np.array(references)[chosen_index].tolist())
                    pre_pred = self.extract_answer(pre_response)
                    # print('---------------------RUN MOR START pre_iter_reward 2---------------------------')
                    start_time=time.time()
                    pre_iter_reward = generate_general_rm_mor(top_acc, [question], [pre_response], batch_size=1)[0]
                    end_time=time.time()
                    # print(colored(f'---------------------DATA ID: {str(data_id)} MOR Thread {process_id} time cost 2: '+str(end_time-start_time)+' s---------------------------','blue'))
                    logger.info(colored(f'---------------------DATA ID: {str(data_id)} MOR Thread {process_id} time cost 2: '+str(end_time-start_time)+' s---------------------------','blue'))

                    # print(f'---------------------RUN MOR END Final time cost 2: '+str(end_time-start_time)+' s---------------------------')
                    chosen_index_temp = chosen_index
                    search_history.append(str(chosen_index))
                    sc_memory[pre_pred] = 1 if pre_pred not in sc_memory.keys() else sc_memory[pre_pred] + 1
                    responses.append(pre_response)
                    win_responses.append(pre_response)
                else:
                    # prepare the tasks for multi-processing
                    temp_chosen_reference_list, temp_chosen_index_list = [], []
                    # for add the agent
                    for append_k in set(range(len(reward_score))) - set(chosen_index):
                        temp_chosen_index = copy.deepcopy(chosen_index)
                        temp_chosen_index.append(append_k)
                        temp_chosen_index = sorted(temp_chosen_index)
                        temp_chosen_reference = np.array(references)[temp_chosen_index].tolist()
                        # avoid the ring search
                        if str(temp_chosen_index) not in search_history:
                            temp_chosen_reference_list.append(temp_chosen_reference)
                            temp_chosen_index_list.append(temp_chosen_index)
                            search_history.append(str(temp_chosen_index))
                    # for reduce the agent
                    if bi_greedy and len(chosen_index) > 1:
                        for reduce_k in chosen_index:
                            temp_chosen_index = copy.deepcopy(chosen_index)
                            temp_chosen_index.remove(reduce_k)
                            temp_chosen_index = sorted(temp_chosen_index)
                            temp_chosen_reference = np.array(references)[temp_chosen_index].tolist()
                            # avoid the ring search
                            if str(temp_chosen_index) not in search_history:
                                temp_chosen_reference_list.append(temp_chosen_reference)
                                temp_chosen_index_list.append(temp_chosen_index)
                                search_history.append(str(temp_chosen_index))
                    if len(temp_chosen_index_list) == 0:
                        break
                    tasks = [async_raw_moa_api(model, messages, None, 0.7, max_tokens[0], 1, temp_chosen_reference)
                             for temp_chosen_reference in temp_chosen_reference_list]
                    moa_temp_chosen_reference_res = await asyncio.gather(*tasks)
                    # log the self consistence
                    moa_temp_pred_list = [self.extract_answer(r) for r in moa_temp_chosen_reference_res]
                    for pred_r in moa_temp_pred_list:
                        sc_memory[pred_r] = 1 if pred_r not in sc_memory.keys() else sc_memory[pred_r] + 1
                    # print('---------------------RUN MOR START moa_temp_chosen_reference_res_reward_score 3---------------------------')
                    start_time=time.time()
                    
                    moa_temp_chosen_reference_res_reward_score = generate_general_rm_mor(
                        model=top_acc,
                        question=[question] * len(moa_temp_chosen_reference_res),
                        response=moa_temp_chosen_reference_res,
                        batch_size=self.mor_batch)
                    end_time=time.time()
                    # print(colored(f'---------------------DATA ID: {str(data_id)} MOR Thread {process_id} Final time cost 3: '+str(end_time-start_time)+' s---------------------------','cyan'))
                    logger.info(colored(f'---------------------DATA ID: {str(data_id)} MOR Thread {process_id} time cost 3: '+str(end_time-start_time)+' s---------------------------','cyan'))
                    # print('---------------------RUN MOR END Final time cost 3: '+str(end_time-start_time)+' s---------------------------')
                    if not use_sc:
                        get_improvement = (
                                torch.tensor(moa_temp_chosen_reference_res_reward_score) > pre_iter_reward).any().item()
                        max_reward_index = torch.tensor(moa_temp_chosen_reference_res_reward_score).argmax().item()
                    else:
                        reward_score_cat = [pre_iter_reward] + moa_temp_chosen_reference_res_reward_score
                        pred_list_cat = [pre_pred] + moa_temp_pred_list
                        add_final_score = self.combine_sc_with_reward(sc_memory, reward_score_cat, pred_list_cat)
                        try:
                            get_improvement = add_final_score[1:].max() > add_final_score[0]
                        except:
                            print(1)
                        max_reward_index = add_final_score[1:].argmax().item()
                    if get_improvement:
                        pre_response = moa_temp_chosen_reference_res[max_reward_index]
                        pre_iter_reward = max(moa_temp_chosen_reference_res_reward_score)
                        pre_pred = moa_temp_pred_list[max_reward_index]
                        chosen_index_temp = temp_chosen_index_list[max_reward_index]
                    chosen_index = chosen_index_temp
                    responses.extend(moa_temp_chosen_reference_res)
                    if not get_improvement:
                        break
                    win_responses.append(pre_response)
            response = pre_response

            logger.info(colored(f'DATA ID: {str(data_id)} Search Step: {str(j)}','red'))
            if residual:
                if not use_sc:
                    max_reward_index_ref = reward_score.argmax().item()
                    max_reward_ref = reward_score.max().item()
                    if max_reward_ref > pre_iter_reward:
                        response = references[max_reward_index_ref]
                else:
                    pred_list_cat = [pre_pred] + ref_pred_list
                    reward_score_cat = [pre_iter_reward] + reward_score.tolist()
                    add_final_score = self.combine_sc_with_reward(sc_memory, reward_score_cat, pred_list_cat)
                    max_reward_index_ref = add_final_score[1:].argmax().item()
                    max_reward_ref = add_final_score[1:].max().item()
                    if max_reward_ref > add_final_score[0]:
                        response = references[max_reward_index_ref]
                win_responses.append(response)
            pred = self.extract_answer(response)
            return_dict['sc_memory'] = sc_memory
            return_dict['response'] = response
            return_dict['pred'] = pred
            return_dict['n_response'] = responses
            return_dict['win_response'] = win_responses
            all_time_end=time.time()
            # print(colored(f'DATA ID: {str(data_id)} Final {process_id} Thread time:'+str(all_time_end-all_time_st), 'yellow'))
            logger.info(colored(f'DATA ID: {str(data_id)} Final {process_id} Thread time:'+str(all_time_end-all_time_st), 'yellow'))

            return return_dict

        return asyncio.run(
            wrap_moa_greedy_search_mor_(data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm,ppl_coef, rm_model, N, sc_posi, ref_sample,rm_model_list))


    def wrap_moa_greedy_search(self, data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef, rm_model,N, sc_posi='agg', ref_sample='all',rm_model_list=[]):
        """Greedily search for the subset of references whose aggregation scores best.

        One reference answer is generated per entry of ``model_list`` for the
        sample ``test_data[data_id]`` and scored with ``generate_general_rm``.
        Starting from an initial subset, each step aggregates (via
        ``async_raw_moa_api``) every not-yet-visited subset that differs from
        the current one by one added or one removed reference, and moves to
        the best candidate as long as it beats the current answer.

        Args:
            data_id: Index into ``test_data``.
            test_data: Indexable collection of samples.
            model_list: Reference models; paired positionally with ``max_tokens[1:]``.
            model: Aggregator handed to ``async_raw_moa_api``.
            max_tokens: ``max_tokens[0]`` is used for aggregation calls and
                ``max_tokens[1:]`` for the reference models.
            use_sc: Accepted for signature compatibility only; the search
                forces self-consistency scoring on (see the flags below).
            rm_model: Reward model handed to ``generate_general_rm``.
            dev_df, use_rm, ppl_coef, N, sc_posi, ref_sample, rm_model_list:
                Not read by this method.

        Returns:
            dict with the keys ``ref_dict`` (model -> reference), ``sc_memory``
            (extracted answer -> count), ``response`` (final answer text),
            ``pred`` (answer extracted from it), ``n_response`` (every
            aggregated response produced) and ``win_response`` (the winner
            recorded after each accepted step, plus the final pick).
        """
        async def wrap_moa_greedy_search_(data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef,rm_model, N, sc_posi, ref_sample,rm_model_list):
            process_id = str(os.getpid())
            all_time_st = time.time()
            # Search switches are hard-coded here; note that use_sc shadows the argument.
            residual = True     # let a raw reference beat the aggregated answer at the end
            bi_greedy = True    # also try subsets with one reference removed
            use_sc = True       # rank with combine_sc_with_reward instead of raw reward
            drop_greedy = True  # start from the full reference set
            print('---------------------RUN GS STAR---------------------------')
            return_dict = {}
            sc_memory = {}

            def count_pred(pred_value):
                # Self-consistency memory: how often each extracted answer was seen.
                sc_memory[pred_value] = sc_memory.get(pred_value, 0) + 1

            # Single-reference subsets are pre-marked as visited, so they are never tried.
            search_history = [str([i]) for i in range(len(model_list))]
            each = test_data[data_id]
            question_id = self.get_question_id(each)
            messages = self.build_messages(each)
            tasks = [self.async_generate_general_cache(question_id, m, messages, mt, 0.7, False) for m, mt in
                     zip(model_list, max_tokens[1:])]
            references = await asyncio.gather(*tasks)
            return_dict['ref_dict'] = {m: r for r, m in zip(references, model_list)}
            # log the self consistence
            ref_pred_list = [self.extract_answer(r) for r in references]
            for ref_pred in ref_pred_list:
                count_pred(ref_pred)
            question = messages[-1]["content"]
            # compute the reward of each response
            start_time = time.time()
            reward_score = generate_general_rm(
                model=rm_model,
                question=[question] * len(references),
                response=references,
                batch_size=self.mor_batch)
            end_time = time.time()
            print('Reward score',reward_score)
            print(colored(f'---------------------DATA ID: {str(data_id)} MOR Thread {process_id} Final time cost 1: '+str(end_time-start_time)+' s---------------------------','green'))
            reward_score = torch.tensor(reward_score)
            chosen_index = []
            responses = []
            win_responses = []
            # begin to search greedily; 2 ** n bounds the number of distinct subsets
            for j in range(2 ** len(reward_score)):
                if j == 0:
                    # the initial subset of greedy search
                    if drop_greedy:
                        chosen_index = list(range(len(model_list)))
                    else:
                        chosen_index = sorted(reward_score.topk(2)[1].tolist())
                    pre_response = await async_raw_moa_api(model, messages, reference_models=None, temperature=0.7,
                                                           max_tokens=max_tokens[0],
                                                           references=np.array(references)[chosen_index].tolist())
                    pre_pred = self.extract_answer(pre_response)
                    start_time = time.time()
                    pre_iter_reward = generate_general_rm(rm_model, [question], [pre_response], batch_size=1)[0]
                    end_time = time.time()
                    print(colored(f'---------------------DATA ID: {str(data_id)} MOR Thread {process_id} Final time cost 2: '+str(end_time-start_time)+' s---------------------------','blue'))
                    chosen_index_temp = chosen_index
                    search_history.append(str(chosen_index))
                    count_pred(pre_pred)
                    responses.append(pre_response)
                    win_responses.append(pre_response)
                    continue

                # Collect every unvisited neighbour of the current subset.
                temp_chosen_reference_list, temp_chosen_index_list = [], []

                def register_candidate(candidate_index):
                    # avoid the ring search: each subset is aggregated at most once
                    if str(candidate_index) not in search_history:
                        temp_chosen_reference_list.append(np.array(references)[candidate_index].tolist())
                        temp_chosen_index_list.append(candidate_index)
                        search_history.append(str(candidate_index))

                # for add the agent
                for append_k in set(range(len(reward_score))) - set(chosen_index):
                    register_candidate(sorted(chosen_index + [append_k]))
                # for reduce the agent
                if bi_greedy and len(chosen_index) > 1:
                    for reduce_k in chosen_index:
                        register_candidate(sorted(idx for idx in chosen_index if idx != reduce_k))
                if len(temp_chosen_index_list) == 0:
                    break
                tasks = [async_raw_moa_api(model, messages, None, 0.7, max_tokens[0], 1, temp_chosen_reference)
                         for temp_chosen_reference in temp_chosen_reference_list]
                moa_temp_chosen_reference_res = await asyncio.gather(*tasks)
                # log the self consistence
                moa_temp_pred_list = [self.extract_answer(r) for r in moa_temp_chosen_reference_res]
                for pred_r in moa_temp_pred_list:
                    count_pred(pred_r)
                start_time = time.time()
                moa_temp_chosen_reference_res_reward_score = generate_general_rm(
                    model=rm_model,
                    question=[question] * len(moa_temp_chosen_reference_res),
                    response=moa_temp_chosen_reference_res,
                    batch_size=self.mor_batch)
                end_time = time.time()
                print(colored(f'---------------------DATA ID: {str(data_id)} MOR Thread {process_id} Final time cost 3: '+str(end_time-start_time)+' s---------------------------','cyan'))
                if not use_sc:
                    candidate_rewards = torch.tensor(moa_temp_chosen_reference_res_reward_score)
                    get_improvement = (candidate_rewards > pre_iter_reward).any().item()
                    max_reward_index = candidate_rewards.argmax().item()
                else:
                    # Slot 0 is the current answer, slots 1.. are the candidates.
                    reward_score_cat = [pre_iter_reward] + moa_temp_chosen_reference_res_reward_score
                    pred_list_cat = [pre_pred] + moa_temp_pred_list
                    add_final_score = self.combine_sc_with_reward(sc_memory, reward_score_cat, pred_list_cat)
                    # Errors are no longer swallowed here: a failure used to print "1"
                    # and continue with get_improvement unset or stale.
                    get_improvement = add_final_score[1:].max() > add_final_score[0]
                    max_reward_index = add_final_score[1:].argmax().item()
                if get_improvement:
                    pre_response = moa_temp_chosen_reference_res[max_reward_index]
                    # NOTE(review): with SC scoring the winner is picked by the combined
                    # score, yet the stored reward is the maximum raw reward, which may
                    # belong to another candidate. Kept as-is; confirm the intent.
                    pre_iter_reward = max(moa_temp_chosen_reference_res_reward_score)
                    pre_pred = moa_temp_pred_list[max_reward_index]
                    chosen_index_temp = temp_chosen_index_list[max_reward_index]
                chosen_index = chosen_index_temp
                responses.extend(moa_temp_chosen_reference_res)
                if not get_improvement:
                    break
                win_responses.append(pre_response)
            response = pre_response
            if residual:
                # Give the raw references a final chance against the aggregated answer.
                if not use_sc:
                    max_reward_index_ref = reward_score.argmax().item()
                    max_reward_ref = reward_score.max().item()
                    if max_reward_ref > pre_iter_reward:
                        response = references[max_reward_index_ref]
                else:
                    pred_list_cat = [pre_pred] + ref_pred_list
                    reward_score_cat = [pre_iter_reward] + reward_score.tolist()
                    add_final_score = self.combine_sc_with_reward(sc_memory, reward_score_cat, pred_list_cat)
                    max_reward_index_ref = add_final_score[1:].argmax().item()
                    max_reward_ref = add_final_score[1:].max().item()
                    if max_reward_ref > add_final_score[0]:
                        response = references[max_reward_index_ref]
                win_responses.append(response)
            pred = self.extract_answer(response)
            return_dict['sc_memory'] = sc_memory
            return_dict['response'] = response
            return_dict['pred'] = pred
            return_dict['n_response'] = responses
            return_dict['win_response'] = win_responses
            all_time_end = time.time()
            print(colored(f'DATA ID: {str(data_id)} Final {process_id} Thread time:'+str(all_time_end-all_time_st), 'yellow'))
            return return_dict

        return asyncio.run(
            wrap_moa_greedy_search_(data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm,ppl_coef, rm_model, N, sc_posi, ref_sample,rm_model_list))

    def wrap_rag_moa_test(self, data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef, rm_model, N, sc_posi='agg', ref_sample='all',rm_model_list=[]):
        """Select reference models by question-bank retrieval, then aggregate.

        The question of ``test_data[data_id]`` is embedded and compared with
        ``self.question_bank['embedding_bank']``. Each model in ``model_list``
        gets a profile score summed over the retrieved neighbours from
        ``self.question_bank['model_pred_dict']``; the best-scoring models
        produce references, which are aggregated with ``async_raw_moa_api``.
        The final answer maximises ``sc_score + ppl_coef * ppl_score``.

        ``self.mode`` is parsed by splitting on ``'_'``: a numeric last token
        is the number of models to keep, ``'avg'`` keeps models scoring above
        the mean, and a second-to-last token of ``'weighted'`` weights each
        neighbour by its similarity.

        Args:
            data_id: Index into ``test_data``.
            model_list: Candidate reference models, aligned with ``max_tokens[1:]``.
            model: Aggregator handed to ``async_raw_moa_api``.
            max_tokens: ``[0]`` for aggregation, ``[1:]`` per reference model,
                ``[-1]`` as the limit for reference cutting/cleaning.
            use_sc, use_rm, N, sc_posi: Together decide how many aggregation
                samples (``agg_N``) are drawn.
            ppl_coef: Weight of the perplexity term in the selection score.
            ref_sample: ``'all'`` feeds every reference to each aggregation;
                any other value is forwarded to ``self.build_sub_ref``.
            dev_df, rm_model, rm_model_list: Not read by this method.

        Returns:
            dict with keys ``ref_dict``, ``sc_memory``, ``response``, ``pred``,
            ``n_response``, ``rag_model`` and ``select_score``.
        """
        async def wrap_rag_moa_test_(data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef,
                                     rm_model, N, sc_posi, ref_sample,rm_model_list):
            mode_parts = self.mode.split('_')
            try:
                rag_num = int(mode_parts[-1])
            except ValueError:
                # Non-numeric suffix ('avg', 'norm', ...): resolved once the profiles exist.
                rag_num = None
            weighted_score = mode_parts[-2] == 'weighted'
            k = 400
            return_dict = {}
            each = test_data[data_id]
            question_id = self.get_question_id(each)
            question = self.get_question(each)
            rag_task = 'Given a question, find the one with the highest semantic similarity and subject similarity from the question bank.'
            question_embedding = F.normalize(torch.tensor(generate_general_em('Linq-Embed-Mistral', [question], [rag_task], 8192, 1)).cuda())
            sc_memory = {}
            messages = self.build_messages(each)
            # find_k_near: keep every bank entry scoring above 95% of the k-th best score
            scores = (question_embedding @ self.question_bank['embedding_bank'].T) * 100
            scores_topk_value, scores_topk = scores[0].topk(len(scores[0]))
            # Clamp so that a bank with at most k entries does not raise IndexError.
            kth = min(k, len(scores_topk_value) - 1)
            threshold_bound = scores_topk_value[kth] * 0.95
            scores_topk = scores_topk[scores_topk_value > threshold_bound].cpu()

            model_profile = {}
            for m in model_list:
                neighbour_hits = self.question_bank['model_pred_dict'][m][scores_topk]
                if weighted_score:
                    model_profile[m] = (scores_topk_value[:len(scores_topk)].cpu() / 100 * neighbour_hits).sum().item()
                else:
                    model_profile[m] = neighbour_hits.sum().item()
            if rag_num is None:
                rag_mode = mode_parts[-1]
                if rag_mode == 'avg':
                    model_profile_score_np = np.array(list(model_profile.values()))
                    # NOTE(review): this is 0 when all profiles are equal, which makes
                    # the unpacking below fail on an empty selection.
                    rag_num = (model_profile_score_np > model_profile_score_np.mean()).sum().item()
                elif rag_mode == 'norm':
                    # Placeholder: 'norm' is not implemented, so rag_num stays None
                    # and every model is kept.
                    print(1)
            # Rank (position, (model, score)) by score and keep the first rag_num entries.
            model_profile_sorted = sorted(enumerate(model_profile.items()), key=lambda x: x[1][1], reverse=True)[:rag_num]
            model_index_sorted, model_profile_sorted_zip = list(zip(*model_profile_sorted))
            rag_max_tokens = np.array(max_tokens[1:])[np.array(model_index_sorted)]
            rag_model_list, rag_score_list = list(zip(*model_profile_sorted_zip))
            tasks = [self.async_generate_general_cache(question_id, m, messages, mt, 0.7, False) for m, mt in
                     zip(rag_model_list, rag_max_tokens)]
            references = await asyncio.gather(*tasks)
            if self.ref_token_cut:
                references = self.ref_cut(references, max_tokens[-1])
            if self.ref_clean_think:
                references = self.clean_think(references, max_tokens[-1])
            return_dict['ref_dict'] = {m: r for r, m in zip(references, rag_model_list)}
            agg_N = N if ('agg' in sc_posi and use_sc) or (not use_sc and use_rm) else 1
            if ref_sample == 'all' or agg_N == 1:
                agg_tasks = [async_raw_moa_api(model, messages, None, 0.7, max_tokens[0], 1, references, 1) for _ in range(agg_N)]
            else:
                sub_ref = self.build_sub_ref(mode=ref_sample, references=references, N=agg_N, rag_score_list=rag_score_list)
                agg_tasks = [async_raw_moa_api(model, messages, None, 0.7, max_tokens[0], 1, sub_ref_i, 1) for sub_ref_i in
                             sub_ref]
            raw_responses = await asyncio.gather(*agg_tasks)
            select_score = {i: {'ppl_score': 0.0, 'sc_score': 0.0, 'total_score': 0.0} for i in
                            range(len(raw_responses))}
            if not isinstance(raw_responses[0], str):
                # Responses carrying log-probabilities: score = 1 - exp(cumulative_logprob).
                ppl_score = 1 - np.exp([r['cumulative_logprob'] for r in raw_responses])
                responses = [r['response'] for r in raw_responses]
            else:
                responses = raw_responses
                ppl_score = None
            response_pred_list = [self.extract_answer(r) for r in responses]
            self.record_sc_memory(sc_memory, responses)
            # for self consistency
            for i, pred_i in enumerate(response_pred_list):
                sc_score = sc_memory[pred_i] / agg_N
                select_score[i]['sc_score'] = sc_score
                if ppl_score is None:
                    # Plain-text responses have no log-probabilities; previously this
                    # path indexed None and raised TypeError.
                    select_score[i]['total_score'] = sc_score
                else:
                    select_score[i]['ppl_score'] = ppl_score[i]
                    select_score[i]['total_score'] = sc_score + ppl_coef * ppl_score[i]
            # Stable sort: ties resolve to the earliest response.
            response = responses[sorted(select_score, key=lambda x: select_score[x]['total_score'], reverse=True)[0]]
            pred = self.extract_answer(response)
            return_dict['sc_memory'] = sc_memory
            return_dict['response'] = response
            return_dict['pred'] = pred
            return_dict['n_response'] = responses
            return_dict['rag_model'] = rag_model_list
            return_dict['select_score'] = select_score
            return return_dict

        return asyncio.run(
            wrap_rag_moa_test_(data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef, rm_model, N, sc_posi, ref_sample,rm_model_list))

    def wrap_scale_exp_test(self, data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef, rm_model, N, sc_posi='agg', ref_sample='all'):
        """Grow a reference-model pool by profile score, then aggregate.

        Every model in ``model_list`` gets a similarity-weighted profile score
        from the question bank. The first ``explore_num`` entries of
        ``self.model_list`` (``explore_num`` is the numeric last token of
        ``self.mode``) are then explored in order: the first two are always
        kept, and each later model joins the pool only if its profile score is
        at least the current pool average. The pooled models produce
        references that are aggregated with ``async_raw_moa_api``; the final
        answer maximises ``sc_score + ppl_coef * ppl_score``.

        Args:
            data_id: Index into ``test_data``.
            model_list: Models to profile, aligned with ``max_tokens[1:]``.
            model: Aggregator handed to ``async_raw_moa_api``.
            max_tokens: ``[0]`` for aggregation, ``[1:]`` per reference model,
                ``[-1]`` as the limit for reference cutting/cleaning.
            use_sc, use_rm, N, sc_posi: Together decide how many aggregation
                samples (``agg_N``) are drawn.
            ppl_coef: Weight of the perplexity term in the selection score.
            ref_sample: ``'all'`` feeds every reference to each aggregation;
                any other value is forwarded to ``self.build_sub_ref``.
            dev_df, rm_model: Not read by this method.

        Returns:
            dict with keys ``ref_dict``, ``sc_memory``, ``response``, ``pred``,
            ``n_response``, ``rag_model`` and ``select_score``.
        """
        async def wrap_scale_exp_test_(data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef,
                                     rm_model, N, sc_posi, ref_sample):
            explore_num = int(self.mode.split('_')[-1])
            weighted_score = True
            k = 400
            return_dict = {}
            each = test_data[data_id]
            question_id = self.get_question_id(each)
            question = self.get_question(each)
            rag_task = 'Given a question, find the one with the highest semantic similarity and subject similarity from the question bank.'
            question_embedding = F.normalize(
                torch.tensor(generate_general_em('Linq-Embed-Mistral', [question], [rag_task], 8192, 1)).cuda())
            sc_memory = {}
            messages = self.build_messages(each)
            # find_k_near: keep every bank entry scoring above 95% of the k-th best score
            scores = (question_embedding @ self.question_bank['embedding_bank'].T) * 100
            scores_topk_value, scores_topk = scores[0].topk(len(scores[0]))
            # Clamp so that a bank with at most k entries does not raise IndexError.
            kth = min(k, len(scores_topk_value) - 1)
            threshold_bound = scores_topk_value[kth] * 0.95
            scores_topk = scores_topk[scores_topk_value > threshold_bound].cpu()
            # build the model profile
            model_profile = {}
            for m in model_list:
                neighbour_hits = self.question_bank['model_pred_dict'][m][scores_topk]
                if weighted_score:
                    model_profile[m] = (scores_topk_value[:len(scores_topk)].cpu() / 100 *
                                        neighbour_hits).sum().item()
                else:
                    model_profile[m] = neighbour_hits.sum().item()
            # build the explore process
            all_model_pool = self.model_list[:explore_num]
            now_model_pool = all_model_pool[:2]  # slicing copies, self.model_list is untouched
            for candidate in all_model_pool[2:]:
                now_avg = np.array([model_profile[m] for m in now_model_pool]).mean().item()
                if model_profile[candidate] >= now_avg:
                    now_model_pool.append(candidate)
            rag_model_list = now_model_pool
            rag_score_list = [model_profile[m] for m in rag_model_list]
            # Look the budget up per model. The budgets used to be ordered by profile
            # ranking while the pool is in exploration order, so zipping the two gave
            # models someone else's token budget.
            token_budget = dict(zip(model_list, max_tokens[1:]))
            rag_max_tokens = [token_budget[m] for m in rag_model_list]

            tasks = [self.async_generate_general_cache(question_id, m, messages, mt, 0.7, False) for m, mt in
                     zip(rag_model_list, rag_max_tokens)]
            references = await asyncio.gather(*tasks)
            if self.ref_token_cut:
                references = self.ref_cut(references, max_tokens[-1])
            if self.ref_clean_think:
                references = self.clean_think(references, max_tokens[-1])
            return_dict['ref_dict'] = {m: r for r, m in zip(references, rag_model_list)}
            agg_N = N if ('agg' in sc_posi and use_sc) or (not use_sc and use_rm) else 1
            if ref_sample == 'all' or agg_N == 1:
                agg_tasks = [async_raw_moa_api(model, messages, None, 0.7, max_tokens[0], 1, references, 1) for _ in
                             range(agg_N)]
            else:
                sub_ref = self.build_sub_ref(mode=ref_sample, references=references, N=agg_N,
                                             rag_score_list=rag_score_list)
                agg_tasks = [async_raw_moa_api(model, messages, None, 0.7, max_tokens[0], 1, sub_ref_i, 1)
                             for sub_ref_i in sub_ref]
            raw_responses = await asyncio.gather(*agg_tasks)
            select_score = {i: {'ppl_score': 0.0, 'sc_score': 0.0, 'total_score': 0.0} for i in
                            range(len(raw_responses))}
            if not isinstance(raw_responses[0], str):
                # Responses carrying log-probabilities: score = 1 - exp(cumulative_logprob).
                ppl_score = 1 - np.exp([r['cumulative_logprob'] for r in raw_responses])
                responses = [r['response'] for r in raw_responses]
            else:
                responses = raw_responses
                ppl_score = None
            response_pred_list = [self.extract_answer(r) for r in responses]
            self.record_sc_memory(sc_memory, responses)
            # for self consistency
            for i, pred_i in enumerate(response_pred_list):
                sc_score = sc_memory[pred_i] / agg_N
                select_score[i]['sc_score'] = sc_score
                if ppl_score is None:
                    # Plain-text responses have no log-probabilities; previously this
                    # path indexed None and raised TypeError.
                    select_score[i]['total_score'] = sc_score
                else:
                    select_score[i]['ppl_score'] = ppl_score[i]
                    select_score[i]['total_score'] = sc_score + ppl_coef * ppl_score[i]
            # Stable sort: ties resolve to the earliest response.
            response = responses[sorted(select_score, key=lambda x: select_score[x]['total_score'], reverse=True)[0]]
            pred = self.extract_answer(response)
            return_dict['sc_memory'] = sc_memory
            return_dict['response'] = response
            return_dict['pred'] = pred
            return_dict['n_response'] = responses
            return_dict['rag_model'] = rag_model_list
            return_dict['select_score'] = select_score
            return return_dict

        return asyncio.run(
            wrap_scale_exp_test_(data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef,
                               rm_model, N, sc_posi, ref_sample))


    def wrap_majority_voting(self, data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef, rm_model, N, sc_posi='agg', ref_sample='all'):
        """Answer one sample by plurality vote over the reference models.

        Each model in ``model_list`` (paired with ``max_tokens[1:]``) answers
        ``test_data[data_id]``. ``self.record_sc_memory`` is expected to fill
        the vote table keyed by extracted answer (it is indexed that way right
        after); a response scores the share of responses with the same
        extracted answer, and the earliest top-scoring response wins.

        Only ``data_id``, ``test_data``, ``model_list`` and ``max_tokens`` are
        read; the remaining parameters exist for signature compatibility.

        Returns:
            dict with keys ``ref_dict``, ``sc_memory``, ``response``, ``pred``,
            ``n_response`` and ``select_score``.
        """
        async def majority_voting_(data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef,
                                     rm_model, N, sc_posi, ref_sample):
            sample = test_data[data_id]
            qid = self.get_question_id(sample)
            chat = self.build_messages(sample)

            # One generation per reference model, run concurrently.
            candidates = await asyncio.gather(*[
                self.async_generate_general_cache(qid, name, chat, budget, 0.7, False)
                for name, budget in zip(model_list, max_tokens[1:])
            ])

            # Vote table: extracted answer -> number of occurrences.
            votes = {}
            answers = [self.extract_answer(text) for text in candidates]
            self.record_sc_memory(votes, candidates)

            total = len(candidates)
            select_score = {}
            for idx, answer in enumerate(answers):
                share = votes[answer] / total
                select_score[idx] = {'ppl_score': 0.0, 'sc_score': share, 'total_score': share}

            # Stable descending sort keeps the earliest response among ties.
            ranking = sorted(select_score, key=lambda idx: select_score[idx]['total_score'], reverse=True)
            winner = candidates[ranking[0]]
            return {
                'ref_dict': dict(zip(model_list, candidates)),
                'sc_memory': votes,
                'response': winner,
                'pred': self.extract_answer(winner),
                'n_response': candidates,
                'select_score': select_score,
            }

        return asyncio.run(
            majority_voting_(data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef,
                               rm_model, N, sc_posi, ref_sample))

    def wrap_simple_router(self, data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef, rm_model, N, sc_posi='agg', ref_sample='all'):
        """Pick the final answer for one sample from the reference models' outputs.

        Despite the name, no routing happens in this implementation: it runs
        the same plurality vote as ``wrap_majority_voting``. Every model in
        ``model_list`` (paired with ``max_tokens[1:]``) answers
        ``test_data[data_id]``, and each response is scored by the fraction of
        responses sharing its extracted answer. The earliest best-scoring
        response is returned.

        Only ``data_id``, ``test_data``, ``model_list`` and ``max_tokens`` are
        read; the remaining parameters exist for signature compatibility.

        Returns:
            dict with keys ``ref_dict``, ``sc_memory``, ``response``, ``pred``,
            ``n_response`` and ``select_score``.
        """
        async def simple_router_(data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef,
                                     rm_model, N, sc_posi, ref_sample):
            item = test_data[data_id]
            item_id = self.get_question_id(item)
            prompt = self.build_messages(item)

            pending = []
            for ref_model, ref_budget in zip(model_list, max_tokens[1:]):
                pending.append(self.async_generate_general_cache(item_id, ref_model, prompt, ref_budget, 0.7, False))
            outputs = await asyncio.gather(*pending)

            result = {'ref_dict': {ref_model: text for text, ref_model in zip(outputs, model_list)}}

            # the sc memory counts how often each extracted answer occurs
            tally = {}
            extracted = [self.extract_answer(text) for text in outputs]
            self.record_sc_memory(tally, outputs)

            n_outputs = len(outputs)
            select_score = {}
            for position, answer in enumerate(extracted):
                agreement = tally[answer] / n_outputs
                select_score[position] = {'ppl_score': 0.0, 'sc_score': agreement, 'total_score': agreement}

            # Descending stable sort: on ties the earliest response stays first.
            order = sorted(select_score, key=lambda position: select_score[position]['total_score'], reverse=True)
            best = outputs[order[0]]

            result['sc_memory'] = tally
            result['response'] = best
            result['pred'] = self.extract_answer(best)
            result['n_response'] = outputs
            result['select_score'] = select_score
            return result

        return asyncio.run(
            simple_router_(data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef,
                               rm_model, N, sc_posi, ref_sample))

    def wrap_analyze_prior(self, data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef, rm_model, N, sc_posi='agg', ref_sample='all'):
        """Collect retrieval diagnostics for one sample without aggregating.

        The question of ``test_data[data_id]`` is embedded and compared with
        ``self.question_bank['embedding_bank']``. For the retrieved neighbours
        the method reports their ``'src'`` field from ``self.raw_question_bank``
        and a similarity-weighted profile score for every model in
        ``model_list``. Reference answers are generated as well and returned
        for inspection.

        Args:
            data_id: Index into ``test_data``.
            model_list: Models to profile, aligned with ``max_tokens[1:]``.
            max_tokens: Only ``max_tokens[1:]`` is read (one budget per model).
            dev_df, model, use_sc, use_rm, ppl_coef, rm_model, N, sc_posi,
            ref_sample: Not read by this method.

        Returns:
            dict with keys ``ref_dict`` (model -> reference), ``relative_src``
            (one entry per retrieved neighbour) and ``model_profile``
            (model -> profile score).
        """
        async def analyze_prior_(data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef,
                                   rm_model, N, sc_posi, ref_sample):
            # analyze the rag
            weighted_score = True
            k = 400
            return_dict = {}
            each = test_data[data_id]
            question_id = self.get_question_id(each)
            question = self.get_question(each)
            rag_task = 'Given a question, find the one with the highest semantic similarity and subject similarity from the question bank.'
            question_embedding = F.normalize(
                torch.tensor(generate_general_em('Linq-Embed-Mistral', [question], [rag_task], 8192, 1)).cuda())
            messages = self.build_messages(each)
            tasks = [self.async_generate_general_cache(question_id, m, messages, mt, 0.7, False) for m, mt in
                     zip(model_list, max_tokens[1:])]
            references = await asyncio.gather(*tasks)
            return_dict['ref_dict'] = {m: r for r, m in zip(references, model_list)}
            # find_k_near: keep every bank entry scoring above 95% of the k-th best score
            scores = (question_embedding @ self.question_bank['embedding_bank'].T) * 100
            scores_topk_value, scores_topk = scores[0].topk(len(scores[0]))
            # Clamp so that a bank with at most k entries does not raise IndexError.
            kth = min(k, len(scores_topk_value) - 1)
            threshold_bound = scores_topk_value[kth] * 0.95
            scores_topk = scores_topk[scores_topk_value > threshold_bound].cpu()
            # The raw question bank is keyed by the stringified bank index.
            relative_src = [self.raw_question_bank[f'{j.item()}']['src'] for j in scores_topk]
            # build the model profile
            model_profile = {}
            for m in model_list:
                neighbour_hits = self.question_bank['model_pred_dict'][m][scores_topk]
                if weighted_score:
                    model_profile[m] = (scores_topk_value[:len(scores_topk)].cpu() / 100 *
                                        neighbour_hits).sum().item()
                else:
                    model_profile[m] = neighbour_hits.sum().item()
            return_dict['relative_src'] = relative_src
            return_dict['model_profile'] = model_profile
            return return_dict


        return asyncio.run(
            analyze_prior_(data_id, test_data, dev_df, model_list, model, max_tokens, use_sc, use_rm, ppl_coef,
                             rm_model, N, sc_posi, ref_sample))

    def standard_moa(self):
        """Run ``self.wrap_raw_moa_test`` over the test set in process-pool batches.

        Sample ids ``0 .. self.test_data_num - 1`` are split into consecutive
        batches of at most ``self.max_process`` ids, and each batch is
        dispatched through its own ``multiprocessing.Pool``.

        NOTE(review): in the code visible here the per-batch results
        (``response_pred_list``) and the two id lookups are never used and
        nothing is returned; this looks like an unfinished driver -- confirm.
        """
        # Batch boundaries: 0, max_process, 2 * max_process, ..., test_data_num.
        data_index = list(range(0, self.test_data_num, self.max_process)) + [self.test_data_num]
        # Fetched but not used below.
        test_data_id = self.get_test_data_id()
        has_test_id = self.get_has_test_id()
        for i in tqdm(range(len(data_index) - 1)):
            # build the task for data parallel
            data_id_range = range(data_index[i], data_index[i + 1])
            # One positional-argument list per sample, unpacked by starmap.
            tasks = [[data_id, self.test_data, self.val_data, self.model_list, self.model, self.max_tokens_list,
                      self.use_sc, self.use_rm, self.rm_model, self.N] for data_id in data_id_range]
            # A new pool is created for every batch, sized to the batch.
            with multiprocessing.Pool(processes=min(len(tasks), self.max_process)) as pool:
                response_pred_list = pool.starmap(self.wrap_raw_moa_test, tasks)
            # Debug output left in place; prints once per batch.
            print(1)


