"""PPL FederatedInferencer"""

from typing import List, Union, Optional, Dict, Any
import tqdm
from tqdm import trange
from collections import Counter
import json
import copy
import faiss

import torch
from torch.utils.data import DataLoader

from datasets import concatenate_datasets, load_dataset, DatasetDict, Dataset
from transformers import PretrainedConfig
from transformers import AutoTokenizer

from sentence_transformers import SentenceTransformer
from accelerate import Accelerator

from openicl.icl_dataset_reader import DatasetEncoder
from openicl.icl_retriever import BaseRetriever, TopkRetriever
from openicl.utils.collators import DataCollatorWithPaddingAndCuda
from openicl import PromptTemplate, PPLInferencer, DatasetReader
from openicl.icl_retriever import *
from openicl.icl_evaluator import *
from openicl.icl_inferencer.icl_base_inferencer import PPLInferencerOutputHandler

from openicl.utils.logging import get_logger
from openicl.utils.api_service import *

from util.concatenate import (
    collect_samples,
    simple_concatenate,
    merge_concatenate,
    random_concatenate,
)
from util.icl_ppl_fed_inferencer import PPLFedInferencer
from util.retriever import get_retriever, get_server_retriever
from util.misc import save_json, AverageMeter, setup_seed
from util.dataset import QueryBudgetDataset
from models import get_model


logger = get_logger(__name__)


class PPLFedNeighborQueryInferencer(PPLFedInferencer):
    def __init__(
        self,
        model_name: Optional[str] = "gpt2-xl",
        tokenizer_name: Optional[str] = None,
        max_model_token_num: Optional[int] = None,
        model_config: Optional[PretrainedConfig] = None,
        batch_size: Optional[int] = 1,
        accelerator: Optional[Accelerator] = None,
        output_json_filepath: Optional[str] = "./server_icl_inference_output",
        output_json_filename: Optional[str] = "predictions",
        api_name: Optional[str] = None,
        labels: Optional[List] = None,
        model_parallel: Optional[bool] = False,
        args=None,
        sentence_transformers_model_name: Optional[str] = "all-mpnet-base-v2",
        device: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Initialise the inferencer by delegating to the parent constructor.

        Every named argument is handed to the parent ``__init__`` positionally,
        in the order it is declared here; any extra keyword arguments are
        passed through untouched. The only state added by this subclass is
        the ``budget_model`` attribute, which starts out as ``None``.
        """
        # Kept positional on purpose: the parent's parameter names are not
        # visible from this file, only the order the arguments are passed in.
        parent_positional = (
            model_name,
            tokenizer_name,
            max_model_token_num,
            model_config,
            batch_size,
            accelerator,
            output_json_filepath,
            output_json_filename,
            api_name,
            labels,
            model_parallel,
            args,
            sentence_transformers_model_name,
            device,
        )
        super().__init__(*parent_positional, **kwargs)

        # No budget model is attached at construction time.
        self.budget_model = None

    def inference(
        self,
        retrievers: List[BaseRetriever],
        query_dataset: Union[Dataset, DatasetDict],
        query_split: Optional[str] = None,
        ice_template: Optional[PromptTemplate] = None,
        prompt_template: Optional[PromptTemplate] = None,
        output_json_filepath: Optional[str] = None,
        output_json_filename: Optional[str] = None,
        concat: Optional[str] = "simple",
        normalizing_str: Optional[str] = None,
        args=None,
    ):
        """Predict a label for every query from neighbor-derived server ICE.

        The queries are embedded with ``retrievers[0]``, a per-query server
        ICE dataset is built by ``_construct_server_ice_dataset``, every
        sample of that dataset is turned into in-context examples (no
        server-side re-ranking), and each candidate label is scored with
        ``__get_ppl``. The label with the lowest score is the prediction.

        Args:
            retrievers: One retriever per client. ``retrievers[0]`` is also
                used to embed the queries and supplies the dataset reader's
                input/output column names.
            query_dataset: Queries to classify; must expose an ``"idx"``
                column. A ``DatasetDict`` is indexed with ``query_split``.
            query_split: Split to read when ``query_dataset`` is a
                ``DatasetDict``.
            ice_template: Template forwarded to ICE / prompt generation.
            prompt_template: Template forwarded to prompt generation; its
                ``sep_token`` is preferred over ``ice_template``'s when
                ``normalizing_str`` is given.
            output_json_filepath: Output directory; defaults to
                ``self.output_json_filepath``.
            output_json_filename: Output file name; defaults to
                ``self.output_json_filename``.
            concat: Concatenation mode forwarded to
                ``_construct_server_ice_dataset``.
            normalizing_str: When given, each prompt is split at the separator
                token and the score is the difference between the prompt score
                and the score of ``normalizing_str`` followed by the answer.
            args: Run configuration; must provide ``server_ice_num`` and is
                handed to ``get_server_retriever``. Defaults to ``self.args``.

        Returns:
            The ``"prediction"`` entry of every record in the output
            handler's ``results_dict``.

        Raises:
            ValueError: If ``retrievers`` is empty.
        """
        if not retrievers:
            raise ValueError("`retrievers` must contain at least one retriever.")
        if args is None:
            # `args` is dereferenced below; fall back to the configuration
            # stored on the instance instead of failing on `None`.
            args = self.args

        num_clients = len(retrievers)
        if isinstance(query_dataset, DatasetDict):
            query_dataset = query_dataset[query_split]
        query_num = len(query_dataset)

        # 1. Preparation for output logs
        output_handler = PPLInferencerOutputHandler(self.accelerator)
        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # Embed the queries with the first client's retriever utilities.
        # Each element is {'embed': np.array,
        #                  'metadata': {'id': int, 'len': int, 'text': str},
        #                  'idx': int}
        reader = retrievers[0].dataset_reader
        query_datalist = reader.generate_input_field_corpus(query_dataset)
        query_dataloader = retrievers[0].create_dataloader(query_datalist)
        query_res_list = retrievers[0].forward(
            self.retriever_model,
            query_dataloader,
            orig_idxs=query_dataset["idx"],
            process_bar=True,
            information="Embedding test data queries...",
        )

        # 2. Build each query's server ICE dataset from the ICE retrieved for
        #    its nearest neighbor query: [q1_ice_Dataset, q2_ice_Dataset, ...]
        server_ice_datasets = self._construct_server_ice_dataset(
            retrievers=retrievers,
            query_res_list=query_res_list,
            concat=concat,
            output_json_filepath=output_json_filepath,
        )

        # 3. Collect the label set. All clients are consulted because a
        #    non-IID partition may hide some classes from a single client.
        if self.labels is None:
            labels = []
            for cid in range(num_clients):
                labels.extend(
                    retrievers[cid].get_labels(
                        ice_template=ice_template, prompt_template=prompt_template
                    )
                )
            # De-duplicate in first-seen order so the label order (and hence
            # tie-breaking in step 6) does not depend on set iteration order.
            labels = list(dict.fromkeys(labels))
        else:
            labels = self.labels

        # 4. Generate the in-context examples for every query
        ice = []
        ice_dataset_retrievers = []
        server_sample_orig_idx = {}
        for qid in trange(
            query_num,
            disable=not self.is_main_process,
            desc="Server side generating ICE: ",
        ):
            data_dict = DatasetDict(
                {
                    "train": server_ice_datasets[qid],
                    "test": Dataset.from_list([query_dataset[qid]]),
                }
            )
            data = DatasetReader(
                data_dict,
                input_columns=reader.input_columns,
                output_column=reader.output_column,
            )
            per_query_retriever = get_server_retriever(args)(
                data, ice_num=args.server_ice_num
            )

            # Use every sample of the server ICE dataset, in its given order.
            ice_idxs = list(range(len(server_ice_datasets[qid])))

            q_orig_idx = per_query_retriever.test_ds[0]["idx"]
            server_sample_orig_idx[q_orig_idx] = per_query_retriever.index_ds.select(
                ice_idxs
            )["idx"]
            ice.append(
                per_query_retriever.generate_ice(ice_idxs, ice_template=ice_template)
            )
            ice_dataset_retrievers.append(per_query_retriever)

        output_handler.save_ice(ice)
        server_ice_index_file = f"{output_json_filepath}/server_ice_indices.json"
        save_json(server_sample_orig_idx, server_ice_index_file)

        # 5. Calculating PPL for prompts in each label's class
        use_normalizing = normalizing_str is not None
        if use_normalizing:
            # Both values are the same for every label and query.
            if prompt_template is not None:
                sep_token = prompt_template.sep_token
            else:
                sep_token = ice_template.sep_token
            normalizing_str_len = self.get_input_token_num(normalizing_str)

        ppl = []
        for label in labels:
            index = 0
            prompt_list = []
            sub_ppl_list = []
            normalizing_prompt_list = []
            context_length_list = []

            # 5.1 Generate prompts of current label.
            # TODO: truncation against `max_model_token_num` is not applied
            # here (it was temporarily removed).
            for qid in range(query_num):
                prompt = ice_dataset_retrievers[qid].generate_label_prompt(
                    0,
                    ice[qid],
                    label,
                    ice_template=ice_template,
                    prompt_template=prompt_template,
                    remain_sep=use_normalizing,
                )

                if use_normalizing:
                    # Split at the first separator: everything before it is
                    # context, everything after it (separators stripped) is
                    # the answer.
                    sep_pos = prompt.find(sep_token)
                    context = prompt[0:sep_pos]
                    answer = prompt[sep_pos:].replace(sep_token, "")
                    prompt = context + answer

                    context_length_list.append(self.get_input_token_num(context))
                    normalizing_prompt_list.append(normalizing_str + answer)
                prompt_list.append(prompt)

            # 5.2 Get PPL: loop through all queries with the current label
            logger.info(f"Calculating PPL for prompts labeled '{label}'")
            for idx in trange(
                0, len(prompt_list), self.batch_size, disable=not self.is_main_process
            ):
                batch = slice(idx, idx + self.batch_size)
                sub_prompt_list = prompt_list[batch]

                # NOTE(review): `self.__get_ppl` is name-mangled to this
                # class, so this class itself must define `__get_ppl`
                # (a parent's private `__get_ppl` is not reachable this way).
                with torch.no_grad():
                    if use_normalizing:
                        res1 = self.__get_ppl(
                            input_texts=sub_prompt_list,
                            mask_length=context_length_list[batch],
                        )
                        res2 = self.__get_ppl(
                            input_texts=normalizing_prompt_list[batch],
                            mask_length=[normalizing_str_len] * len(sub_prompt_list),
                        )
                        # Convert like the other branch so both yield a plain
                        # list of scores.
                        sub_res = (res1 - res2).tolist()
                    else:
                        sub_res = self.__get_ppl(
                            sub_prompt_list
                        ).tolist()  # list of scores, length is batch size

                for batch_offset, (res, prompt) in enumerate(
                    zip(sub_res, sub_prompt_list)
                ):
                    sub_ppl_list.append(res)
                    output_handler.save_prompt_and_ppl(
                        label,
                        prompt[len(ice[idx + batch_offset]) :],
                        prompt,
                        res,
                        index,
                    )
                    index = index + 1
            ppl.append(sub_ppl_list)

        # 6. Get lowest PPL class as predictions (one score tuple per query,
        #    ordered like `labels`)
        sub_predictions = [
            labels[scores.index(min(scores))] for scores in zip(*ppl)
        ]
        output_handler.save_predictions(sub_predictions)

        # 7. Output
        output_handler.subprocess_write_to_json(
            output_json_filepath, output_json_filename
        )
        if self.accelerator is not None:
            self.accelerator.wait_for_everyone()
        output_handler.merge_to_main_process(output_json_filepath, output_json_filename)
        output_handler.write_to_json(output_json_filepath, output_json_filename)

        return [sample["prediction"] for sample in output_handler.results_dict.values()]

    def proxy_data_opt_budget(
        self,
        retrievers: List[BaseRetriever],
        query_dataset: Union[Dataset, DatasetDict],
        query_split: Optional[str] = None,
        ice_template: Optional[PromptTemplate] = None,
        prompt_template: Optional[PromptTemplate] = None,
        output_json_filepath: Optional[str] = None,
        output_json_filename: Optional[str] = "proxy_predictions",
        concat: Optional[str] = "reorder",
        normalizing_str: Optional[str] = None,
        inference: Optional[bool] = False,
        args=None,
        prefix="proxy",
        local_ice_num=None,
    ):
        """Count, per proxy query, how many server ICE samples each client supplies.

        For every proxy query a server ICE pool is built by
        ``_construct_server_ice_dataset``, the server retriever selects ICE
        from that pool, and the ``"cid"`` of each selected sample is counted.
        When ``inference`` is true, labels are additionally predicted by
        picking the label with the lowest ``__get_ppl`` score.

        Args:
            retrievers: One retriever per client. ``retrievers[0]`` is also
                used to embed the queries and supplies the dataset reader's
                input/output column names.
            query_dataset: Proxy queries; must expose an ``"idx"`` column. A
                ``DatasetDict`` is indexed with ``query_split``.
            query_split: Split to read when ``query_dataset`` is a
                ``DatasetDict``.
            ice_template: Template forwarded to ICE / prompt generation.
            prompt_template: Template forwarded to prompt generation.
            output_json_filepath: Output directory; defaults to
                ``self.output_json_filepath``.
            output_json_filename: Prediction file name; ``None`` is replaced
                by ``"proxy_predictions"``.
            concat: Accepted for interface compatibility only; the server ICE
                pool is always built with ``"simple"`` concatenation and then
                ranked by the server retriever.
            normalizing_str: When given, scores are normalised as in
                ``inference``.
            inference: Whether to also run label prediction.
            args: Run configuration; must provide ``server_ice_num`` and is
                handed to ``get_server_retriever``. Defaults to ``self.args``.
            prefix: Prefix of the JSON files written by this method.
            local_ice_num: ``ice_num`` forwarded to
                ``_construct_server_ice_dataset``; defaults to
                ``args.server_ice_num``.

        Returns:
            A tuple ``(ice_client_source, opt_per_client_budget,
            proxy_predictions)``. The first two are dicts keyed by the query's
            original index: the client id of each selected ICE sample, and the
            per-client counts (list of length ``len(retrievers)``).
            ``proxy_predictions`` is ``None`` unless ``inference`` is true.

        Raises:
            ValueError: If ``retrievers`` is empty.
        """
        if not retrievers:
            raise ValueError("`retrievers` must contain at least one retriever.")
        if args is None:
            # `args` is dereferenced below; fall back to the configuration
            # stored on the instance instead of failing on `None`.
            args = self.args

        num_clients = len(retrievers)
        if isinstance(query_dataset, DatasetDict):
            query_dataset = query_dataset[query_split]
        query_num = len(query_dataset)

        # 1. Preparation for output logs
        output_handler = PPLInferencerOutputHandler(self.accelerator)
        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = "proxy_predictions"

        # 2. Get results of retrieval process for the whole proxy dataset.
        # The default applies only when the caller did not choose a value
        # (the original test was inverted and overwrote explicit values).
        if local_ice_num is None:
            local_ice_num = args.server_ice_num

        # `_construct_server_ice_dataset` works on query embeddings, not on
        # the raw dataset, so embed the proxy queries first.
        reader = retrievers[0].dataset_reader
        query_datalist = reader.generate_input_field_corpus(query_dataset)
        query_dataloader = retrievers[0].create_dataloader(query_datalist)
        query_res_list = retrievers[0].forward(
            self.retriever_model,
            query_dataloader,
            orig_idxs=query_dataset["idx"],
            process_bar=True,
            information="Embedding proxy data queries...",
        )
        server_ice_datasets = self._construct_server_ice_dataset(
            retrievers=retrievers,
            query_res_list=query_res_list,
            ice_num=local_ice_num,
            concat="simple",
            output_json_filepath=output_json_filepath,
            prefix=prefix,
        )  # [q1_ice_Dataset, q2_ice_Dataset, ...]

        # 3. Collect the label set (only needed when predicting). All clients
        #    are consulted because a non-IID partition may hide some classes
        #    from a single client.
        if inference:
            if self.labels is None:
                labels = []
                for cid in range(num_clients):
                    labels.extend(
                        retrievers[cid].get_labels(
                            ice_template=ice_template, prompt_template=prompt_template
                        )
                    )
                # De-duplicate in first-seen order so the label order does not
                # depend on set iteration order.
                labels = list(dict.fromkeys(labels))
            else:
                labels = self.labels

        # 4. Select server ICE per query and record where each sample came from
        ice = []
        ice_dataset_retrievers = []
        server_sample_orig_idx = {}
        ice_client_source = {}  # source client id of each selected ICE sample
        opt_per_client_budget = {}
        for qid in trange(
            query_num,
            disable=not self.is_main_process,
            desc=f"Server calculates each client's Top-{args.server_ice_num} ICE contribution",
        ):
            data_dict = DatasetDict(
                {
                    "train": server_ice_datasets[qid],
                    "test": Dataset.from_list([query_dataset[qid]]),
                }
            )
            # Take the column names from the dataset reader (as `inference`
            # does) rather than assuming "sentence"/"label".
            data = DatasetReader(
                data_dict,
                input_columns=reader.input_columns,
                output_column=reader.output_column,
            )
            per_query_retriever = get_server_retriever(args)(
                data, ice_num=args.server_ice_num
            )

            # Let the server retriever pick (and order) the ICE for the single
            # test query held by this retriever.
            ice_idxs = per_query_retriever.retrieve(
                model=self.retriever_model, use_trange=False
            )[0]
            selected = per_query_retriever.index_ds.select(ice_idxs)

            q_orig_idx = per_query_retriever.test_ds[0]["idx"]
            server_sample_orig_idx[q_orig_idx] = selected["idx"]
            per_query_ice_cids = selected["cid"]
            ice_client_source[q_orig_idx] = per_query_ice_cids

            # Count each client's contribution to the server ICE set.
            counter = Counter(per_query_ice_cids)
            opt_per_client_budget[q_orig_idx] = [
                counter.get(k, 0) for k in range(num_clients)
            ]

            ice.append(
                per_query_retriever.generate_ice(ice_idxs, ice_template=ice_template)
            )
            ice_dataset_retrievers.append(per_query_retriever)

        # Save proxy dataset server side ICE
        output_handler.save_ice(ice)
        server_ice_index_file = (
            f"{output_json_filepath}/{prefix}_server_ice_indices.json"
        )
        save_json(server_sample_orig_idx, server_ice_index_file)

        # 5. Calculating PPL for prompts in each label's class
        if inference:
            use_normalizing = normalizing_str is not None
            if use_normalizing:
                # Both values are the same for every label and query.
                if prompt_template is not None:
                    sep_token = prompt_template.sep_token
                else:
                    sep_token = ice_template.sep_token
                normalizing_str_len = self.get_input_token_num(normalizing_str)

            ppl = []
            for label in labels:
                index = 0
                prompt_list = []
                sub_ppl_list = []
                normalizing_prompt_list = []
                context_length_list = []

                # 5.1 Generate prompts of current label.
                # TODO: truncation against `max_model_token_num` is not
                # applied here (it was temporarily removed).
                for qid in range(query_num):
                    prompt = ice_dataset_retrievers[qid].generate_label_prompt(
                        0,
                        ice[qid],
                        label,
                        ice_template=ice_template,
                        prompt_template=prompt_template,
                        remain_sep=use_normalizing,
                    )

                    if use_normalizing:
                        # Split at the first separator: everything before it
                        # is context, everything after it (separators
                        # stripped) is the answer.
                        sep_pos = prompt.find(sep_token)
                        context = prompt[0:sep_pos]
                        answer = prompt[sep_pos:].replace(sep_token, "")
                        prompt = context + answer

                        context_length_list.append(self.get_input_token_num(context))
                        normalizing_prompt_list.append(normalizing_str + answer)
                    prompt_list.append(prompt)

                # 5.2 Get PPL: loop through all queries with the current label
                logger.info(f"Calculating PPL for prompts labeled '{label}'")
                for idx in trange(
                    0,
                    len(prompt_list),
                    self.batch_size,
                    disable=not self.is_main_process,
                ):
                    batch = slice(idx, idx + self.batch_size)
                    sub_prompt_list = prompt_list[batch]

                    # NOTE(review): `self.__get_ppl` is name-mangled to this
                    # class, so this class itself must define `__get_ppl`
                    # (a parent's private `__get_ppl` is not reachable this
                    # way).
                    with torch.no_grad():
                        if use_normalizing:
                            res1 = self.__get_ppl(
                                input_texts=sub_prompt_list,
                                mask_length=context_length_list[batch],
                            )
                            res2 = self.__get_ppl(
                                input_texts=normalizing_prompt_list[batch],
                                mask_length=[normalizing_str_len]
                                * len(sub_prompt_list),
                            )
                            # Convert like the other branch so both yield a
                            # plain list of scores.
                            sub_res = (res1 - res2).tolist()
                        else:
                            sub_res = self.__get_ppl(
                                sub_prompt_list
                            ).tolist()  # list of scores, length is batch size

                    for batch_offset, (res, prompt) in enumerate(
                        zip(sub_res, sub_prompt_list)
                    ):
                        sub_ppl_list.append(res)
                        output_handler.save_prompt_and_ppl(
                            label,
                            prompt[len(ice[idx + batch_offset]) :],
                            prompt,
                            res,
                            index,
                        )
                        index = index + 1
                ppl.append(sub_ppl_list)

            # 6. Get lowest PPL class as predictions (one score tuple per
            #    query, ordered like `labels`)
            sub_predictions = [
                labels[scores.index(min(scores))] for scores in zip(*ppl)
            ]
            output_handler.save_predictions(sub_predictions)

            # 7. Output
            output_handler.subprocess_write_to_json(
                output_json_filepath, output_json_filename
            )
            if self.accelerator is not None:
                self.accelerator.wait_for_everyone()
            output_handler.merge_to_main_process(
                output_json_filepath, output_json_filename
            )
            output_handler.write_to_json(output_json_filepath, output_json_filename)

            proxy_predictions = [
                sample["prediction"] for sample in output_handler.results_dict.values()
            ]
        else:
            proxy_predictions = None

        save_json(
            ice_client_source,
            os.path.join(output_json_filepath, f"{prefix}_ice_client_source.json"),
        )
        save_json(
            opt_per_client_budget,
            os.path.join(output_json_filepath, f"{prefix}_opt_client_budget.json"),
        )

        return ice_client_source, opt_per_client_budget, proxy_predictions

    def _construct_server_ice_dataset(
        self,
        retrievers: List[BaseRetriever],
        query_res_list: List[Dict[Any, Any]],
        ice_num: Optional[int] = None,
        concat: Optional[str] = "simple",
        output_json_filepath: Optional[str] = None,
        prefix: Optional[str] = None,
        neighbor_num: int = 16,
    ):
        """Build each query's server ICE dataset from its nearest neighbor query.

        For every query, the most similar *other* query (inner product of the
        embeddings) is looked up, and each client retrieves its local ICE for
        that neighbor instead of for the query itself. The per-client samples
        are then concatenated into one server-side dataset per query.

        Side effects: writes ``neighbor_query_orig_idxs.json`` under
        ``self.args.log_dir`` and ``local_ice_indices.json`` under
        ``output_json_filepath`` (both prefixed with ``"<prefix>_"`` when
        ``prefix`` is given).

        Args:
            retrievers: One retriever per client.
            query_res_list: Embedded queries; each element provides
                ``"embed"``, ``"idx"`` (original index) and
                ``"metadata"["id"]``.
            ice_num: Number of ICE samples requested from every client for
                every query. Defaults to ``self.args.local_ice_num``.
            concat: ``"simple"`` / ``"reorder"`` (both use
                ``simple_concatenate``), ``"random"`` or ``"merge"``.
            output_json_filepath: Directory for ``local_ice_indices.json``.
                Defaults to ``self.output_json_filepath``.
            prefix: Optional prefix for the files written by this method.
            neighbor_num: How many neighbor queries are recorded per query.
                Only the closest one is used for retrieval.

        Returns:
            A list with one concatenated ICE dataset per query, in the order
            of ``query_res_list``. Empty when ``query_res_list`` is empty.

        Raises:
            ValueError: If ``concat`` is not a supported mode or
                ``neighbor_num`` is smaller than 1.
        """
        # Validate up front so a bad argument fails before any retrieval or
        # file output happens.
        concat_funcs = {
            "simple": simple_concatenate,
            "reorder": simple_concatenate,
            "random": random_concatenate,
            "merge": merge_concatenate,
        }
        if concat not in concat_funcs:
            raise ValueError(
                "concat should be one of 'simple', 'reorder', 'random', or "
                f"'merge', rather than '{concat}'."
            )
        concat_func = concat_funcs[concat]
        if neighbor_num < 1:
            raise ValueError(f"neighbor_num must be at least 1, got {neighbor_num}.")

        num_clients = len(retrievers)
        query_num = len(query_res_list)
        if query_num == 0:
            return []

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        file_prefix = f"{prefix}_" if prefix is not None else ""

        # Same fixed ICE budget for every client and every query.
        per_client_budget = self.args.local_ice_num if ice_num is None else ice_num
        ice_budgets = [[per_client_budget] * query_num for _ in range(num_clients)]

        # ----- build a search index over the query embeddings
        query_index = faiss.IndexIDMap(
            faiss.IndexFlatIP(self.retriever_model.get_sentence_embedding_dimension())
        )
        query_ids = np.array(
            [res["metadata"]["id"] for res in query_res_list], dtype=np.int64
        )
        query_embeds = np.stack([res["embed"] for res in query_res_list])
        query_index.add_with_ids(query_embeds, query_ids)
        # The index returns ids, not list positions; map them back explicitly.
        pos_by_id = {int(qid): pos for pos, qid in enumerate(query_ids)}

        # One extra hit leaves room for the query matching itself; never ask
        # for more hits than there are queries.
        search_k = min(neighbor_num + 1, query_num)

        neighbor_query_res_list = []
        neighbor_query_orig_idxs = {}
        for entry in tqdm.tqdm(
            query_res_list,
            desc=f"Retrieve top-{neighbor_num} similar test query for each test query",
        ):
            own_id = entry["metadata"]["id"]
            q_orig_idx = entry["idx"]
            embed = np.expand_dims(entry["embed"], axis=0)
            found_ids = query_index.search(embed, search_k)[1][0].tolist()
            # Drop the query itself wherever it ranks, and the -1 padding used
            # for missing hits.
            near_ids = [
                found for found in found_ids if found != -1 and found != own_id
            ][:neighbor_num]

            # Retrieve with the nearest neighbor's embedding but keep this
            # query's id; with no other query available, use the query itself.
            source = query_res_list[pos_by_id[near_ids[0]]] if near_ids else entry
            near_query_res = copy.deepcopy(source)
            near_query_res["metadata"]["id"] = own_id
            neighbor_query_res_list.append(near_query_res)

            neighbor_query_orig_idxs[q_orig_idx] = [
                query_res_list[pos_by_id[near_id]]["idx"] for near_id in near_ids
            ]

        neighbor_query_file_name = f"{file_prefix}neighbor_query_orig_idxs.json"
        save_json(
            neighbor_query_orig_idxs,
            os.path.join(self.args.log_dir, neighbor_query_file_name),
        )

        ice_idx_dict = {}
        for cid in range(num_clients):
            logger.info(f"Client {cid} retrieves local ICE samples...")
            ice_idx_dict[cid] = retrievers[cid].retrieve_with_budget(
                neighbor_query_res_list,
                ice_budgets=ice_budgets[cid],
                model=self.retriever_model,
            )

        samples_for_all_query = collect_samples(ice_idx_dict, retrievers)

        # Original training-set index of every local ICE sample, per query and
        # per client.
        local_sample_orig_idx = {}
        for qid in range(query_num):
            q_orig_idx = query_res_list[qid]["idx"]
            local_sample_orig_idx[q_orig_idx] = {
                cid: samples_for_all_query[qid][cid]["idx"]
                for cid in range(num_clients)
            }

        local_ice_index_file = (
            f"{output_json_filepath}/{file_prefix}local_ice_indices.json"
        )
        save_json(local_sample_orig_idx, local_ice_index_file)

        # For each query, concatenate all clients' local ICE samples into the
        # server side ICE dataset.
        server_ice = [
            concat_func(
                samples_for_all_query[qid],
                chunk_sample_num=self.args.server_ice_num,
                seed=self.args.seed,
            )
            for qid in trange(
                query_num,
                disable=False,
                desc="Construct server side ICE dataset for each query: ",
            )
        ]

        return server_ice

    # def train_local_budget_model(
    #     self,
    #     retrievers: List[BaseRetriever],
    #     query_dataset: Union[Dataset, DatasetDict],
    #     query_split: Optional[str] = None,
    #     opt_client_budget: Dict[int, List[int]] = None,
    #     model_name: Optional[str] = "SMLP",
    #     model_width: Optional[int] = 380,
    #     epochs: Optional[int] = 50,
    #     lr: Optional[float] = 0.02,
    #     batch_size: Optional[int] = 8,
    #     train_ratio: Optional[float] = 1,
    #     seed: Optional[int] = 0,
    # ):
    #     if isinstance(query_dataset, DatasetDict):
    #         query_dataset = query_dataset[query_split]

    #     assert train_ratio > 0
    #     assert train_ratio <= 1

    #     query_num = len(query_dataset)

    #     if self.args.retriever != "topk":
    #         raise ValueError("Currently only support 'topk' retriever.")

    #     self.retriever_model.eval()

    #     # transform proxy data query to embedding vectors
    #     query_orig_idxs = query_dataset["idx"]
    #     query_datalist = retrievers[0].dataset_reader.generate_input_field_corpus(
    #         query_dataset
    #     )  # TODO: ugly code because of `dataset_reader.generate_input_field_corpus`
    #     query_dataloader = retrievers[0].create_dataloader(query_datalist)

    #     res_list = retrievers[0].forward(
    #         self.retriever_model,
    #         query_dataloader,
    #         orig_idxs=query_orig_idxs,
    #         process_bar=True,
    #         information="Embedding proxy data queries...",
    #     )  # each element is {'embed': np.array, 'metadata': {'id': int, 'len': int, 'text': sample_text}, 'idx': int}

    #     # construct dataset for budget model training
    #     # split dataset into train-eval
    #     train_num = int(query_num * train_ratio)
    #     eval_num = query_num - train_num
    #     rng = np.random.default_rng(seed=seed)
    #     train_set_idxs = sorted(rng.choice(query_num, train_num, replace=False))
    #     eval_set_idxs = [i for i in range(query_num) if i not in train_set_idxs]

    #     # train data
    #     train_embeds = [res_list[idx]["embed"] for idx in train_set_idxs]
    #     train_budgets = [
    #         opt_client_budget[res_list[idx]["idx"]] for idx in train_set_idxs
    #     ]
    #     train_dataset = QueryBudgetDataset(embeds=train_embeds, budgets=train_budgets)
    #     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    #     # eval data
    #     if eval_num > 0:
    #         eval_embeds = [res_list[idx]["embed"] for idx in eval_set_idxs]
    #         eval_budgets = [
    #             opt_client_budget[res_list[idx]["idx"]] for idx in eval_set_idxs
    #         ]
    #         eval_dataset = QueryBudgetDataset(embeds=eval_embeds, budgets=eval_budgets)
    #         eval_loader = DataLoader(eval_dataset, batch_size=16)

    #     # prepare the model
    #     setup_seed(seed)
    #     embedding_size = res_list[0]["embed"].shape  # size is (dim,)
    #     model = get_model(
    #         model_name=model_name,
    #         output_size=self.args.num_clients,
    #         data_shape=embedding_size,
    #         width=model_width,
    #     )
    #     model = model.to(self.device)

    #     optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    #     criterion = torch.nn.MSELoss(reduction="none")

    #     train_loss_hist = []
    #     eval_loss_hist = []
    #     for epoch in range(epochs):
    #         # train
    #         train_loss = self._epoch_train_budget_model(
    #             model, criterion, optimizer, train_loader
    #         )
    #         train_loss_hist.append(train_loss)

    #         # eval
    #         if eval_num > 0:
    #             eval_loss = self._eval_budget_model(model, criterion, eval_loader)
    #         else:
    #             eval_loss = None
    #         eval_loss_hist.append(eval_loss)

    #         if (epoch + 1) % 10 == 0:
    #             logger.info(
    #                 f"Epoch [{epoch+1}/{epochs}], train_loss={train_loss :.10f}, eval_loss={eval_loss:f}"
    #             )

    #     self.train_loss_hist = train_loss_hist
    #     self.eval_loss_hist = eval_loss_hist
    #     self.budget_model = model
    #     torch.save(
    #         model.state_dict(), os.path.join(self.args.log_dir, "budget_model.pt")
    #     )
    #     self.budget_model_path = os.path.join(self.args.log_dir, "budget_model.pt")
    #     return train_loss_hist, eval_loss_hist, res_list  # train_dataset, eval_dataset

    # def _epoch_train_budget_model(self, model, criterion, optimizer, data_loader):
    #     loss_stat = AverageMeter()
    #     model.train()
    #     for batch_idx, (embed, budget) in enumerate(data_loader):
    #         batch_size = budget.shape[0]

    #         embed = embed.to(self.device)
    #         budget = budget.to(self.device)
    #         outputs = model(embed)

    #         tmp_loss = criterion(outputs, budget)
    #         loss = tmp_loss.sum(dim=1).mean()

    #         optimizer.zero_grad()
    #         loss.backward()
    #         optimizer.step()

    #         loss_stat.update(loss.item(), batch_size)

    #     return loss_stat.avg

    # def _eval_budget_model(self, model, criterion, data_loader):
    #     model.eval()
    #     loss_stat = AverageMeter()
    #     with torch.no_grad():
    #         for batch_idx, (embed, budget) in enumerate(data_loader):
    #             batch_size = budget.shape[0]

    #             embed = embed.to(self.device)
    #             budget = budget.to(self.device)
    #             outputs = model(embed)

    #             tmp_loss = criterion(outputs, budget)
    #             loss = tmp_loss.sum(dim=1).mean()

    #             loss_stat.update(loss.item(), batch_size)

    #     return loss_stat.avg

    # def evaluate_local_budget_model(
    #     self,
    #     retrievers: List[BaseRetriever],
    #     query_dataset: Union[Dataset, DatasetDict],
    #     query_split: Optional[str] = None,
    #     model: Optional[torch.nn.Module] = None,
    #     opt_client_budget: Dict[int, List[int]] = None,
    # ):
    #     if isinstance(query_dataset, DatasetDict):
    #         query_dataset = query_dataset[query_split]

    #     assert model is not None

    #     self.retriever_model.eval()

    #     # transform proxy data query to embedding vectors
    #     query_orig_idxs = query_dataset["idx"]
    #     query_datalist = retrievers[0].dataset_reader.generate_input_field_corpus(
    #         query_dataset
    #     )  # TODO: ugly code because of `dataset_reader.generate_input_field_corpus`
    #     query_dataloader = retrievers[0].create_dataloader(query_datalist)

    #     res_list = retrievers[0].forward(
    #         self.retriever_model,
    #         query_dataloader,
    #         orig_idxs=query_orig_idxs,
    #         process_bar=True,
    #         information="Embedding query data queries...",
    #     )  # each element is {'embed': np.array, 'metadata': {'id': int, 'len': int, 'text': sample_text}, 'idx': int}

    #     # construct dataset for budget model training
    #     embeds = [res["embed"] for res in res_list]
    #     budgets = [opt_client_budget[res["idx"]] for res in res_list]
    #     budget_dataset = QueryBudgetDataset(embeds=embeds, budgets=budgets)
    #     embed_loader = DataLoader(budget_dataset, batch_size=16)

    #     model = model.to(self.device)  # TODO: check this
    #     criterion = torch.nn.MSELoss(reduction="none")

    #     model.eval()  # TODO: check this
    #     loss_stat = AverageMeter()
    #     with torch.no_grad():
    #         for batch_idx, (embed, budget) in enumerate(embed_loader):
    #             batch_size = budget.shape[0]

    #             embed = embed.to(self.device)
    #             budget = budget.to(self.device)
    #             outputs = model(embed)

    #             tmp_loss = criterion(outputs, budget)
    #             loss = tmp_loss.sum(dim=1).mean()

    #             loss_stat.update(loss.item(), batch_size)

    #     print(f"Eval loss={loss_stat.avg:.10f}")
    #     return loss_stat.avg

    # def _get_budgets_model_prediction(self, query_res_list):
    #     total_budget = self.args.server_ice_num
    #     query_num = len(query_res_list)
    #     num_clients = self.args.num_clients
    #     self.budget_model.eval()
    #     model_outputs = []
    #     with torch.no_grad():
    #         for query_res in tqdm.tqdm(
    #             query_res_list,
    #             desc="Predict local budgets based on query embedding",
    #         ):
    #             embed = torch.tensor(
    #                 query_res["embed"], device=self.device, requires_grad=False
    #             ).reshape(1, -1)
    #             output = self.budget_model(embed)
    #             pred = output.detach().cpu().numpy()[0] * total_budget
    #             model_outputs.append(pred.tolist())

    #     budgets_prediction = []
    #     for qid in range(query_num):
    #         tmp = [int(budget) for budget in model_outputs[qid]]
    #         # TODO: randomly make up budget number
    #         if sum(tmp) < total_budget:
    #             selected_cids = sorted(
    #                 np.random.choice(
    #                     num_clients, total_budget - sum(tmp), replace=False
    #                 )
    #             )
    #             for cid in selected_cids:
    #                 tmp[cid] += 1

    #         budgets_prediction.append(tmp)

    #     return budgets_prediction

    def __get_ppl(self, input_texts: List[str], mask_length=None):
        """Forward a ``__get_ppl`` call to the ``PPLInferencer`` implementation.

        Python mangles double-underscore names per class, so a
        ``self.__get_ppl(...)`` call written inside this class resolves to
        this method rather than to the one defined in ``PPLInferencer``.
        This wrapper bridges the two by calling the base implementation
        through its mangled name, ``_PPLInferencer__get_ppl``.

        Args:
            input_texts: Texts handed unchanged to the base implementation.
            mask_length: Optional value forwarded unchanged as the second
                positional argument (defaults to ``None``).

        Returns:
            Whatever ``_PPLInferencer__get_ppl`` returns -- presumably
            per-text perplexity values; confirm against ``PPLInferencer``.
        """
        # Look the base-class method up by its already-mangled name; a
        # single leading underscore means no further mangling happens here.
        base_get_ppl = self._PPLInferencer__get_ppl
        return base_get_ppl(input_texts, mask_length)

    # def _get_ppl(self, input_texts: List[str], mask_length=None):
    #     # TODO: only for debug
    #     return self._PPLInferencer__get_ppl(input_texts, mask_length)
