"""PPL FederatedInferencer"""

import copy
import json
import os
from collections import Counter
from typing import List, Union, Optional, Dict, Any

import numpy as np
import torch
import tqdm
from torch.utils.data import DataLoader
from tqdm import trange

from datasets import concatenate_datasets, load_dataset, DatasetDict, Dataset
from transformers import PretrainedConfig
from transformers import AutoTokenizer

from sentence_transformers import SentenceTransformer
from accelerate import Accelerator

from openicl.icl_dataset_reader import DatasetEncoder
from openicl.icl_retriever import BaseRetriever, TopkRetriever
from openicl.utils.collators import DataCollatorWithPaddingAndCuda
from openicl import PromptTemplate, PPLInferencer, DatasetReader
from openicl.icl_retriever import *
from openicl.icl_evaluator import *
from openicl.icl_inferencer.icl_base_inferencer import PPLInferencerOutputHandler

from openicl.utils.logging import get_logger
from openicl.utils.api_service import *

from util.concatenate import (
    collect_samples,
    simple_concatenate,
    merge_concatenate,
    random_concatenate,
)
from util.icl_ppl_fed_inferencer import PPLFedInferencer
from util.retriever import get_retriever, get_server_retriever
from util.misc import save_json, AverageMeter, setup_seed
from util.dataset import QueryBudgetDataset
from models import get_model


logger = get_logger(__name__)


class PPLFedOptBudgetInferencer(PPLFedInferencer):
    def __init__(
        self,
        model_name: Optional[str] = "gpt2-xl",
        tokenizer_name: Optional[str] = None,
        max_model_token_num: Optional[int] = None,
        model_config: Optional[PretrainedConfig] = None,
        batch_size: Optional[int] = 1,
        accelerator: Optional[Accelerator] = None,
        output_json_filepath: Optional[str] = "./server_icl_inference_output",
        output_json_filename: Optional[str] = "predictions",
        api_name: Optional[str] = None,
        labels: Optional[List] = None,
        model_parallel: Optional[bool] = False,
        args=None,
        sentence_transformers_model_name: Optional[str] = "all-mpnet-base-v2",
        device: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Create the inferencer; all arguments are forwarded to ``PPLFedInferencer``.

        The budget model starts out unset: ``train_local_budget_model`` assigns it,
        and ``inference`` refuses to run until it has been assigned.
        """
        # Forwarded positionally, in exactly the order the parent expects.
        forwarded = (
            model_name,
            tokenizer_name,
            max_model_token_num,
            model_config,
            batch_size,
            accelerator,
            output_json_filepath,
            output_json_filename,
            api_name,
            labels,
            model_parallel,
            args,
            sentence_transformers_model_name,
            device,
        )
        super().__init__(*forwarded, **kwargs)

        # Populated later by `train_local_budget_model`.
        self.budget_model = None

    def inference(
        self,
        retrievers: List[BaseRetriever],
        query_dataset: Union[Dataset, DatasetDict],
        query_split: Optional[str] = None,
        ice_template: Optional[PromptTemplate] = None,
        prompt_template: Optional[PromptTemplate] = None,
        output_json_filepath: Optional[str] = None,
        output_json_filename: Optional[str] = None,
        concat: Optional[str] = "simple",
        normalizing_str: Optional[str] = None,
        args=None,
    ):
        """Run federated PPL inference using per-client ICE budgets from the budget model.

        Pipeline:
            1. Embed every test query with ``retrievers[0]`` and ``self.retriever_model``.
            2. Predict a budget per (query, client) pair with the budget model
               (``_get_budgets_model_prediction``); every client retrieves that many
               local samples, and they are concatenated into one server-side ICE
               dataset per query.
            3. If ``concat == "reorder"`` a server-side retriever selects from each
               query's ICE dataset; otherwise all samples are used in order.
            4. For every candidate label, compute the PPL of each query's prompt and
               predict the label with the lowest PPL.

        Args:
            retrievers: one retriever per client. ``retrievers[0]`` is also used to
                embed the queries and provides the dataset reader's column names.
            query_dataset: test queries; must contain an ``"idx"`` column. When a
                ``DatasetDict`` is given, ``query_split`` selects the split.
            query_split: split name used when ``query_dataset`` is a ``DatasetDict``.
            ice_template: template used to render in-context examples.
            prompt_template: template used to render the final prompts.
            output_json_filepath: output directory; defaults to the constructor value.
            output_json_filename: output file name; defaults to the constructor value.
            concat: how client samples are merged per query ('simple', 'reorder',
                'random' or 'merge').
            normalizing_str: if set, the PPL of ``normalizing_str + answer`` is
                subtracted from the PPL of the prompt (context lengths are passed
                to the PPL computation as ``mask_length``).
            args: experiment arguments for the server retriever; falls back to
                ``self.args`` when omitted.

        Returns:
            The list of predicted labels, read back from the output handler.
        """
        assert self.budget_model is not None
        if args is None:
            # `args=None` used to crash in `get_server_retriever(args)`; the
            # constructor-time args are what the helper methods use anyway.
            args = self.args

        num_clients = len(retrievers)
        if isinstance(query_dataset, DatasetDict):
            query_dataset = query_dataset[query_split]

        query_num = len(query_dataset)

        # 1. Preparation for output logs
        output_handler = PPLInferencerOutputHandler(self.accelerator)

        sub_predictions = []
        ppl = []
        ice = []  # rendered ICE string per query

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # Embed the test queries with the shared retriever encoder.
        query_orig_idxs = query_dataset["idx"]
        query_datalist = retrievers[0].dataset_reader.generate_input_field_corpus(
            query_dataset
        )  # TODO: ugly code because of `dataset_reader.generate_input_field_corpus`
        query_dataloader = retrievers[0].create_dataloader(query_datalist)
        query_res_list = retrievers[0].forward(
            self.retriever_model,
            query_dataloader,
            orig_idxs=query_orig_idxs,
            process_bar=True,
            information="Embedding test data queries...",
        )  # each element is {'embed': np.array, 'metadata': {'id': int, 'len': int, 'text': sample_text}, 'idx': int}

        # Predict one budget per (query, client) pair, then transpose so that each
        # client gets a list holding its budget for every query.
        budgets_prediction = self._get_budgets_model_prediction(query_res_list)
        local_budgets = [list(content) for content in zip(*budgets_prediction)]

        # 2. Get results of retrieval process
        server_ice_datasets = self._construct_server_ice_dataset_with_budgets(
            retrievers=retrievers,
            query_res_list=query_res_list,
            ice_budgets=local_budgets,
            concat=concat,
            output_json_filepath=output_json_filepath,
        )  # [q1_ice_Dataset, q2_ice_Dataset, ...]

        # 3. Get labels of all the classes (need to check all client's local
        # dataset since NonIID partition). `dict.fromkeys` dedups while keeping a
        # deterministic order (a `set` order depends on the hash seed).
        if self.labels is None:
            labels = []
            for cid in range(num_clients):
                labels.extend(
                    retrievers[cid].get_labels(
                        ice_template=ice_template, prompt_template=prompt_template
                    )
                )
            labels = list(dict.fromkeys(labels))
        else:
            labels = self.labels

        # 4. Generate in-context examples for every test input
        input_columns = retrievers[0].dataset_reader.input_columns
        output_column = retrievers[0].dataset_reader.output_column
        ice_dataset_retrievers = []
        server_sample_orig_idx = {}  # query orig idx -> orig idxs of its ICE
        for qid in trange(
            query_num,
            disable=not self.is_main_process,
            desc="Server side generating ICE: ",
        ):
            data_dict = DatasetDict(
                {
                    "train": server_ice_datasets[qid],
                    "test": Dataset.from_list([query_dataset[qid]]),
                }
            )
            data = DatasetReader(
                data_dict,
                input_columns=input_columns,
                output_column=output_column,
            )
            per_query_retriever = get_server_retriever(args)(
                data, ice_num=args.server_ice_num
            )

            # Only 'reorder' lets the server retriever pick from the ICE dataset.
            if concat == "reorder":
                ice_idxs = per_query_retriever.retrieve(
                    model=self.retriever_model, use_trange=False
                )[0]
            else:
                ice_idxs = list(range(len(server_ice_datasets[qid])))

            q_orig_idx = per_query_retriever.test_ds[0]["idx"]
            server_sample_orig_idx[q_orig_idx] = per_query_retriever.index_ds.select(
                ice_idxs
            )["idx"]
            ice.append(
                per_query_retriever.generate_ice(ice_idxs, ice_template=ice_template)
            )
            ice_dataset_retrievers.append(per_query_retriever)

        output_handler.save_ice(ice)
        server_ice_index_file = f"{output_json_filepath}/server_ice_indices.json"
        save_json(server_sample_orig_idx, server_ice_index_file)

        # Loop-invariant parts of the optional normalization.
        if normalizing_str is not None:
            if prompt_template is not None:
                sep_token = prompt_template.sep_token
            else:
                sep_token = ice_template.sep_token
            normalizing_str_len = self.get_input_token_num(normalizing_str)

        # 5. Calculating PPL for prompts in each label's class
        for label in labels:
            index = 0
            prompt_list = []
            sub_ppl_list = []
            normalizing_prompt_list = []
            context_length_list = []

            # 5.1 Generate the prompt of every query for the current label.
            # TODO: truncation to `self.max_model_token_num` (generation tasks)
            # is temporarily disabled.
            for qid in range(query_num):
                prompt = ice_dataset_retrievers[qid].generate_label_prompt(
                    0,
                    ice[qid],
                    label,
                    ice_template=ice_template,
                    prompt_template=prompt_template,
                    remain_sep=normalizing_str is not None,
                )

                if normalizing_str is not None:
                    # Split at the separator: context before it, answer after it.
                    sep_pos = prompt.find(sep_token)
                    context = prompt[0:sep_pos]
                    answer = prompt[sep_pos:].replace(sep_token, "")
                    prompt = context + answer

                    context_length_list.append(self.get_input_token_num(context))
                    normalizing_prompt_list.append(normalizing_str + answer)
                prompt_list.append(prompt)

            # 5.2 Get PPL: loop through all queries with current given label
            logger.info(f"Calculating PPL for prompts labeled '{label}'")
            for idx in trange(
                0, len(prompt_list), self.batch_size, disable=not self.is_main_process
            ):
                sub_prompt_list = prompt_list[idx : idx + self.batch_size]
                if normalizing_str is not None:
                    sub_context_length_list = context_length_list[
                        idx : idx + self.batch_size
                    ]
                    sub_normalizing_prompt_list = normalizing_prompt_list[
                        idx : idx + self.batch_size
                    ]

                with torch.no_grad():
                    if normalizing_str is not None:
                        res1 = self.__get_ppl(
                            input_texts=sub_prompt_list,
                            mask_length=sub_context_length_list,
                        )
                        res2 = self.__get_ppl(
                            input_texts=sub_normalizing_prompt_list,
                            mask_length=[
                                normalizing_str_len for _ in range(len(sub_prompt_list))
                            ],
                        )
                        sub_res = res1 - res2
                    else:
                        sub_res = self.__get_ppl(
                            sub_prompt_list
                        ).tolist()  # list of scores, length is batch size

                for batch_offset, (res, prompt) in enumerate(
                    zip(sub_res, sub_prompt_list)
                ):
                    sub_ppl_list.append(res)
                    output_handler.save_prompt_and_ppl(
                        label,
                        prompt[len(ice[idx + batch_offset]) :],
                        prompt,
                        res,
                        index,
                    )
                    index = index + 1
            ppl.append(sub_ppl_list)

        # 6. Get lowest PPL class as predictions (ppl is transposed to per-query rows)
        ppl = list(zip(*ppl))
        for single_ppl in ppl:
            sub_predictions.append(labels[single_ppl.index(min(single_ppl))])
        output_handler.save_predictions(sub_predictions)

        # 7. Output
        output_handler.subprocess_write_to_json(
            output_json_filepath, output_json_filename
        )
        if self.accelerator is not None:
            self.accelerator.wait_for_everyone()
        output_handler.merge_to_main_process(output_json_filepath, output_json_filename)
        output_handler.write_to_json(output_json_filepath, output_json_filename)

        return [sample["prediction"] for sample in output_handler.results_dict.values()]

    def _construct_server_ice_dataset_with_budgets(
        self,
        retrievers: List[BaseRetriever],
        query_res_list: List[Dict[str, Any]],
        ice_budgets: List[List[int]],
        concat: Optional[str] = "simple",
        output_json_filepath: Optional[str] = None,
    ):
        """Build one server-side ICE dataset per query from budgeted client retrieval.

        Each client ``cid`` retrieves local ICE samples for every query in
        ``query_res_list`` using ``ice_budgets[cid]``. The samples are grouped per
        query with ``collect_samples``. Their original indices are written to
        ``{output_json_filepath}/local_ice_indices.json``. The clients' samples for
        each query are then merged with the function selected by ``concat``.

        Args:
            retrievers: one retriever per client.
            query_res_list: embedded queries; each element has at least an
                ``"idx"`` key with the query's original index.
            ice_budgets: one list per client, holding that client's budget for each
                query (same order as ``query_res_list``).
            concat: 'simple' or 'reorder' (simple concatenation; any reordering is
                done by the caller), 'random', or 'merge'.
            output_json_filepath: directory for ``local_ice_indices.json``.

        Returns:
            A list with one merged ICE dataset per query, in ``query_res_list`` order.

        Raises:
            ValueError: if ``concat`` is not a supported mode. This is checked before
                any retrieval runs.
        """
        concat_funcs = {
            "simple": simple_concatenate,
            "reorder": simple_concatenate,
            "random": random_concatenate,
            "merge": merge_concatenate,
        }
        # Fail fast: previously an invalid mode only surfaced after the costly
        # per-client retrieval and the JSON dump.
        if concat not in concat_funcs:
            raise ValueError(
                f"concat should be one of 'simple', 'reorder', 'random', or 'merge', rather than '{concat}'."
            )
        concat_func = concat_funcs[concat]

        num_clients = len(retrievers)
        query_num = len(query_res_list)

        ice_idx_dict = {}
        for cid in range(num_clients):
            print(f"Client {cid} retrieves local ICE samples...")
            ice_idx_dict[cid] = retrievers[cid].retrieve_with_budget(
                query_res_list, ice_budgets=ice_budgets[cid], model=self.retriever_model
            )

        samples_for_all_query = collect_samples(ice_idx_dict, retrievers)

        # Original index (in the original training set) of every retrieved ICE,
        # keyed by query original index, then by client id.
        local_sample_orig_idx = {
            query_res_list[qid]["idx"]: {
                cid: samples_for_all_query[qid][cid]["idx"]
                for cid in range(num_clients)
            }
            for qid in range(query_num)
        }
        local_ice_index_file = f"{output_json_filepath}/local_ice_indices.json"
        save_json(local_sample_orig_idx, local_ice_index_file)

        # For each test query, merge all clients' local ICE into the server ICE dataset.
        server_ice = []
        for qid in trange(
            query_num,
            disable=False,
            desc="Construct server side ICE dataset for each query: ",
        ):
            server_ice.append(
                concat_func(
                    samples_for_all_query[qid],
                    chunk_sample_num=self.args.server_ice_num,
                    seed=self.args.seed,
                )
            )

        return server_ice

    def proxy_data_opt_budget(
        self,
        retrievers: List[BaseRetriever],
        query_dataset: Union[Dataset, DatasetDict],
        query_split: Optional[str] = None,
        ice_template: Optional[PromptTemplate] = None,
        prompt_template: Optional[PromptTemplate] = None,
        output_json_filepath: Optional[str] = None,
        output_json_filename: Optional[str] = "proxy_predictions",
        concat: Optional[str] = "reorder",
        normalizing_str: Optional[str] = None,
        inference: Optional[bool] = False,
        args=None,
        prefix="proxy",
        local_ice_num=None,
        cal_contribute=True,
    ):
        """Measure, on a proxy dataset, how many server-selected ICE each client supplies.

        For every proxy query, each client contributes ``local_ice_num`` candidates
        via ``_construct_server_ice_dataset`` (always in "reorder" mode). A server
        retriever then selects ICE from these candidates. The ``"cid"`` values of the
        selected samples are recorded per query. With ``cal_contribute`` they are also
        counted into a per-client budget vector. Optionally (``inference=True``) a
        full PPL inference pass is run on the proxy queries.

        Args:
            retrievers: one retriever per client; ``retrievers[0]`` supplies the
                dataset reader's column names.
            query_dataset: proxy queries (``DatasetDict`` + ``query_split`` accepted).
            query_split: split name used when ``query_dataset`` is a ``DatasetDict``.
            ice_template: template used to render in-context examples.
            prompt_template: template used to render the final prompts.
            output_json_filepath: output directory; defaults to the constructor value.
            output_json_filename: predictions file name (default "proxy_predictions").
            concat: currently ignored; server-side reorder is always used.
            normalizing_str: optional PPL normalization string (see ``inference``).
            inference: also run PPL inference and write predictions.
            args: experiment arguments; falls back to ``self.args`` when omitted.
            prefix: prefix for all JSON files written by this method.
            local_ice_num: candidates per client per query; defaults to
                ``args.server_ice_num``.
            cal_contribute: also compute per-client selection counts.

        Returns:
            A tuple ``(ice_client_source, opt_per_client_budget, proxy_predictions)``.
            ``ice_client_source`` maps query orig idx to the client ids of the selected
            ICE. ``opt_per_client_budget`` maps query orig idx to a count per client id
            (empty unless ``cal_contribute``). ``proxy_predictions`` is None unless
            ``inference``.
        """
        if args is None:
            args = self.args

        num_clients = len(retrievers)
        if isinstance(query_dataset, DatasetDict):
            query_dataset = query_dataset[query_split]

        query_num = len(query_dataset)

        # 1. Preparation for output logs
        output_handler = PPLInferencerOutputHandler(self.accelerator)

        sub_predictions = []
        ppl = []
        ice = []  # rendered ICE string per query

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = "proxy_predictions"

        # 2. Get results of retrieval process for the whole proxy dataset.
        # Fixed inverted check: previously a caller-supplied value was overwritten
        # and the default None was passed through.
        if local_ice_num is None:
            local_ice_num = args.server_ice_num
        server_ice_datasets = self._construct_server_ice_dataset(
            retrievers=retrievers,
            query_dataset=query_dataset,
            ice_num=local_ice_num,
            concat="reorder",  # the `concat` argument is ignored here
            output_json_filepath=output_json_filepath,
            prefix=prefix,
        )  # [q1_ice_Dataset, q2_ice_Dataset, ...]

        # 3. Get labels of all the classes (need to check all client's local dataset
        # since NonIID partition); only needed when `inference` is True.
        # `dict.fromkeys` dedups with a deterministic order.
        if inference:
            if self.labels is None:
                labels = []
                for cid in range(num_clients):
                    labels.extend(
                        retrievers[cid].get_labels(
                            ice_template=ice_template, prompt_template=prompt_template
                        )
                    )
                labels = list(dict.fromkeys(labels))
            else:
                labels = self.labels

        # 4. Server-side ICE selection and per-client contribution counting.
        # Use the same column configuration as the client readers (previously
        # hard-coded to ["sentence"] / "label").
        input_columns = retrievers[0].dataset_reader.input_columns
        output_column = retrievers[0].dataset_reader.output_column
        ice_dataset_retrievers = []
        server_sample_orig_idx = {}  # query orig idx -> orig idxs of selected ICE
        ice_client_source = {}  # query orig idx -> source client id of each selected ICE
        opt_per_client_budget = {}  # query orig idx -> selection count per client id
        for qid in trange(
            query_num,
            disable=not self.is_main_process,
            desc=f"Server calculates each client's Top-{args.server_ice_num} ICE contribution",
        ):
            data_dict = DatasetDict(
                {
                    "train": server_ice_datasets[qid],
                    "test": Dataset.from_list([query_dataset[qid]]),
                }
            )
            data = DatasetReader(
                data_dict, input_columns=input_columns, output_column=output_column
            )
            per_query_retriever = get_server_retriever(args)(
                data, ice_num=args.server_ice_num
            )

            # perform reorder on server side
            ice_idxs = per_query_retriever.retrieve(
                model=self.retriever_model, use_trange=False
            )[0]

            selected = per_query_retriever.index_ds.select(ice_idxs)
            q_orig_idx = per_query_retriever.test_ds[0]["idx"]
            server_sample_orig_idx[q_orig_idx] = selected["idx"]
            per_query_ice_cids = selected["cid"]  # client ID of each selected ICE
            ice_client_source[q_orig_idx] = per_query_ice_cids

            # count each client's contribution in the server ICE set
            if cal_contribute:
                counter = Counter(per_query_ice_cids)
                opt_per_client_budget[q_orig_idx] = [
                    counter.get(k, 0) for k in range(num_clients)
                ]

            ice.append(
                per_query_retriever.generate_ice(ice_idxs, ice_template=ice_template)
            )
            ice_dataset_retrievers.append(per_query_retriever)

        # save proxy dataset server side ICE
        output_handler.save_ice(ice)
        server_ice_index_file = (
            f"{output_json_filepath}/{prefix}_server_ice_indices.json"
        )
        save_json(server_sample_orig_idx, server_ice_index_file)

        # 5. Calculating PPL for prompts in each label's class
        if inference:
            if normalizing_str is not None:
                if prompt_template is not None:
                    sep_token = prompt_template.sep_token
                else:
                    sep_token = ice_template.sep_token
                normalizing_str_len = self.get_input_token_num(normalizing_str)

            for label in labels:
                index = 0
                prompt_list = []
                sub_ppl_list = []
                normalizing_prompt_list = []
                context_length_list = []

                # 5.1 Generate the prompt of every query for the current label.
                # TODO: truncation to `self.max_model_token_num` (generation tasks)
                # is temporarily disabled.
                for qid in range(query_num):
                    prompt = ice_dataset_retrievers[qid].generate_label_prompt(
                        0,
                        ice[qid],
                        label,
                        ice_template=ice_template,
                        prompt_template=prompt_template,
                        remain_sep=normalizing_str is not None,
                    )

                    if normalizing_str is not None:
                        # Split at the separator: context before it, answer after it.
                        sep_pos = prompt.find(sep_token)
                        context = prompt[0:sep_pos]
                        answer = prompt[sep_pos:].replace(sep_token, "")
                        prompt = context + answer

                        context_length_list.append(self.get_input_token_num(context))
                        normalizing_prompt_list.append(normalizing_str + answer)
                    prompt_list.append(prompt)

                # 5.2 Get PPL: loop through all queries with current given label
                logger.info(f"Calculating PPL for prompts labeled '{label}'")
                for idx in trange(
                    0,
                    len(prompt_list),
                    self.batch_size,
                    disable=not self.is_main_process,
                ):
                    sub_prompt_list = prompt_list[idx : idx + self.batch_size]
                    if normalizing_str is not None:
                        sub_context_length_list = context_length_list[
                            idx : idx + self.batch_size
                        ]
                        sub_normalizing_prompt_list = normalizing_prompt_list[
                            idx : idx + self.batch_size
                        ]

                    with torch.no_grad():
                        if normalizing_str is not None:
                            res1 = self.__get_ppl(
                                input_texts=sub_prompt_list,
                                mask_length=sub_context_length_list,
                            )
                            res2 = self.__get_ppl(
                                input_texts=sub_normalizing_prompt_list,
                                mask_length=[
                                    normalizing_str_len
                                    for _ in range(len(sub_prompt_list))
                                ],
                            )
                            sub_res = res1 - res2
                        else:
                            sub_res = self.__get_ppl(
                                sub_prompt_list
                            ).tolist()  # list of scores, length is batch size

                    for batch_offset, (res, prompt) in enumerate(
                        zip(sub_res, sub_prompt_list)
                    ):
                        sub_ppl_list.append(res)
                        output_handler.save_prompt_and_ppl(
                            label,
                            prompt[len(ice[idx + batch_offset]) :],
                            prompt,
                            res,
                            index,
                        )
                        index = index + 1
                ppl.append(sub_ppl_list)

            # 6. Get lowest PPL class as predictions (ppl transposed to per-query rows)
            ppl = list(zip(*ppl))
            for single_ppl in ppl:
                sub_predictions.append(labels[single_ppl.index(min(single_ppl))])
            output_handler.save_predictions(sub_predictions)

            # 7. Output
            output_handler.subprocess_write_to_json(
                output_json_filepath, output_json_filename
            )
            if self.accelerator is not None:
                self.accelerator.wait_for_everyone()
            output_handler.merge_to_main_process(
                output_json_filepath, output_json_filename
            )
            output_handler.write_to_json(output_json_filepath, output_json_filename)

            proxy_predictions = [
                sample["prediction"] for sample in output_handler.results_dict.values()
            ]
        else:
            proxy_predictions = None

        save_json(
            ice_client_source,
            os.path.join(output_json_filepath, f"{prefix}_ice_client_source.json"),
        )
        if cal_contribute:
            save_json(
                opt_per_client_budget,
                os.path.join(output_json_filepath, f"{prefix}_opt_client_budget.json"),
            )

        return ice_client_source, opt_per_client_budget, proxy_predictions

    def train_local_budget_model(
        self,
        retrievers: List[BaseRetriever],
        query_dataset: Union[Dataset, DatasetDict],
        query_split: Optional[str] = None,
        opt_client_budget: Optional[Dict[int, List[int]]] = None,
        model_name: Optional[str] = "SMLP",
        model_width: Optional[int] = 380,
        epochs: Optional[int] = 50,
        lr: Optional[float] = 0.02,
        batch_size: Optional[int] = 8,
        train_ratio: Optional[float] = 1,
        seed: Optional[int] = 0,
    ):
        """Train a model that maps a query embedding to a per-client ICE budget.

        Proxy queries are embedded with ``retrievers[0]`` / ``self.retriever_model``
        and paired with their target budgets ``opt_client_budget[query_orig_idx]``.
        A randomly chosen ``train_ratio`` fraction is used for training and the rest
        for evaluation. The model is trained with SGD on an MSE loss summed over
        outputs. It is stored in ``self.budget_model`` and its state dict is saved to
        ``{args.log_dir}/budget_model.pt``.

        Args:
            retrievers: client retrievers; only ``retrievers[0]`` is used to embed.
            query_dataset: proxy queries with an ``"idx"`` column.
            query_split: split name used when ``query_dataset`` is a ``DatasetDict``.
            opt_client_budget: target budgets keyed by query original index (e.g.
                the second return value of ``proxy_data_opt_budget``).
                NOTE(review): a dict re-loaded from JSON would have str keys, which
                would not match ``res["idx"]`` — confirm at the caller.
            model_name: model architecture passed to ``get_model``.
            model_width: model width passed to ``get_model``.
            epochs: number of training epochs.
            lr: SGD learning rate.
            batch_size: training batch size.
            train_ratio: fraction of proxy queries used for training, in (0, 1].
                With 1, no evaluation is done and eval losses are None.
            seed: seed for the split and for ``setup_seed``.

        Returns:
            ``(train_loss_hist, eval_loss_hist, res_list)``: per-epoch losses (eval
            entries are None without an eval split) and the query embeddings.

        Raises:
            ValueError: if ``self.args.retriever`` is not ``"topk"``.
        """
        if isinstance(query_dataset, DatasetDict):
            query_dataset = query_dataset[query_split]

        assert train_ratio > 0
        assert train_ratio <= 1

        query_num = len(query_dataset)

        if self.args.retriever != "topk":
            raise ValueError("Currently only support 'topk' retriever.")

        self.retriever_model.eval()

        # transform proxy data query to embedding vectors
        query_orig_idxs = query_dataset["idx"]
        query_datalist = retrievers[0].dataset_reader.generate_input_field_corpus(
            query_dataset
        )  # TODO: ugly code because of `dataset_reader.generate_input_field_corpus`
        query_dataloader = retrievers[0].create_dataloader(query_datalist)

        res_list = retrievers[0].forward(
            self.retriever_model,
            query_dataloader,
            orig_idxs=query_orig_idxs,
            process_bar=True,
            information="Embedding proxy data queries...",
        )  # each element is {'embed': np.array, 'metadata': {'id': int, 'len': int, 'text': sample_text}, 'idx': int}

        # Random train/eval split over positions in `res_list`.
        train_num = int(query_num * train_ratio)
        eval_num = query_num - train_num
        rng = np.random.default_rng(seed=seed)
        train_set_idxs = sorted(rng.choice(query_num, train_num, replace=False))
        train_set_lookup = set(train_set_idxs)  # O(1) membership (was an O(n) list scan)
        eval_set_idxs = [i for i in range(query_num) if i not in train_set_lookup]

        # train data
        train_embeds = [res_list[idx]["embed"] for idx in train_set_idxs]
        train_budgets = [
            opt_client_budget[res_list[idx]["idx"]] for idx in train_set_idxs
        ]
        train_dataset = QueryBudgetDataset(embeds=train_embeds, budgets=train_budgets)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        # eval data
        if eval_num > 0:
            eval_embeds = [res_list[idx]["embed"] for idx in eval_set_idxs]
            eval_budgets = [
                opt_client_budget[res_list[idx]["idx"]] for idx in eval_set_idxs
            ]
            eval_dataset = QueryBudgetDataset(embeds=eval_embeds, budgets=eval_budgets)
            eval_loader = DataLoader(eval_dataset, batch_size=16)

        # prepare the model: one output per client
        setup_seed(seed)
        embedding_size = res_list[0]["embed"].shape  # size is (dim,)
        model = get_model(
            model_name=model_name,
            output_size=self.args.num_clients,
            data_shape=embedding_size,
            width=model_width,
        )
        model = model.to(self.device)

        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        criterion = torch.nn.MSELoss(reduction="none")

        train_loss_hist = []
        eval_loss_hist = []
        for epoch in range(epochs):
            train_loss = self._epoch_train_budget_model(
                model, criterion, optimizer, train_loader
            )
            train_loss_hist.append(train_loss)

            if eval_num > 0:
                eval_loss = self._eval_budget_model(model, criterion, eval_loader)
            else:
                eval_loss = None
            eval_loss_hist.append(eval_loss)

            if (epoch + 1) % 10 == 0:
                # `f"{None:f}"` raises TypeError, which previously crashed training
                # whenever there was no eval split (the default train_ratio=1).
                eval_loss_str = "N/A" if eval_loss is None else f"{eval_loss:f}"
                logger.info(
                    f"Epoch [{epoch+1}/{epochs}], train_loss={train_loss:.10f}, eval_loss={eval_loss_str}"
                )

        self.train_loss_hist = train_loss_hist
        self.eval_loss_hist = eval_loss_hist
        self.budget_model = model
        budget_model_path = os.path.join(self.args.log_dir, "budget_model.pt")
        torch.save(model.state_dict(), budget_model_path)
        self.budget_model_path = budget_model_path
        return train_loss_hist, eval_loss_hist, res_list  # train_dataset, eval_dataset

    def _epoch_train_budget_model(self, model, criterion, optimizer, data_loader):
        """Run one training epoch of the budget model and return its average loss.

        Each batch's loss is the element-wise ``criterion`` output summed over
        dim 1 (presumably one column per client) and then averaged over the batch.
        The returned value is the sample-weighted mean of those batch losses.
        """
        model.train()
        meter = AverageMeter()
        for inputs, targets in data_loader:
            n_samples = targets.shape[0]

            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            predictions = model(inputs)

            per_element = criterion(predictions, targets)
            batch_loss = per_element.sum(dim=1).mean()

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            meter.update(batch_loss.item(), n_samples)

        return meter.avg

    def _eval_budget_model(self, model, criterion, data_loader):
        """Return the average budget-model loss over ``data_loader``, without gradients.

        The loss is computed like in training: element-wise ``criterion`` output
        summed over dim 1, averaged per batch, then sample-weighted across batches.
        """
        meter = AverageMeter()
        model.eval()
        with torch.no_grad():
            for inputs, targets in data_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                batch_loss = criterion(model(inputs), targets).sum(dim=1).mean()
                meter.update(batch_loss.item(), targets.shape[0])

        return meter.avg

    def evaluate_local_budget_model(
        self,
        retrievers: List[BaseRetriever],
        query_dataset: Union[Dataset, DatasetDict],
        query_split: Optional[str] = None,
        model: Optional[torch.nn.Module] = None,
        opt_client_budget: Optional[Dict[int, List[int]]] = None,
    ):
        """Evaluate a budget model on query data against target client budgets.

        Query texts are embedded with ``self.retriever_model`` via the first
        retriever. Each embedding is paired with ``opt_client_budget[idx]``, and
        the model's loss over the resulting dataset is computed with
        ``_eval_budget_model``. The loss is an MSE summed over dim 1 and
        averaged over the batch.

        Args:
            retrievers: retrievers; only ``retrievers[0]`` is used, for its
                dataset reader, dataloader and ``forward`` embedding routine.
            query_dataset: query data, or a ``DatasetDict`` indexed by
                ``query_split``. Must have an ``"idx"`` column.
            query_split: split name used when ``query_dataset`` is a DatasetDict.
            model: budget model to evaluate (required).
            opt_client_budget: mapping from query ``idx`` to its target budget
                vector (required).

        Returns:
            The average evaluation loss (float).

        Raises:
            ValueError: if ``model`` or ``opt_client_budget`` is missing.
        """
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if model is None:
            raise ValueError("`model` must be provided to evaluate the budget model.")
        if opt_client_budget is None:
            raise ValueError(
                "`opt_client_budget` must be provided to evaluate the budget model."
            )

        if isinstance(query_dataset, DatasetDict):
            query_dataset = query_dataset[query_split]

        self.retriever_model.eval()

        # transform proxy data query to embedding vectors
        query_orig_idxs = query_dataset["idx"]
        query_datalist = retrievers[0].dataset_reader.generate_input_field_corpus(
            query_dataset
        )  # TODO: ugly code because of `dataset_reader.generate_input_field_corpus`
        query_dataloader = retrievers[0].create_dataloader(query_datalist)

        res_list = retrievers[0].forward(
            self.retriever_model,
            query_dataloader,
            orig_idxs=query_orig_idxs,
            process_bar=True,
            information="Embedding query data queries...",
        )  # each element is {'embed': np.array, 'metadata': {'id': int, 'len': int, 'text': sample_text}, 'idx': int}

        # construct dataset for budget model evaluation
        embeds = [res["embed"] for res in res_list]
        budgets = [opt_client_budget[res["idx"]] for res in res_list]
        budget_dataset = QueryBudgetDataset(embeds=embeds, budgets=budgets)
        embed_loader = DataLoader(budget_dataset, batch_size=16)

        model = model.to(self.device)  # TODO: check this
        criterion = torch.nn.MSELoss(reduction="none")

        # Reuse the shared evaluation loop (eval mode, no_grad, same loss
        # reduction) instead of duplicating it here.
        eval_loss = self._eval_budget_model(model, criterion, embed_loader)

        logger.info(f"Eval loss={eval_loss:.10f}")
        return eval_loss

    def _get_budgets_model_prediction(self, query_res_list):
        total_budget = self.args.server_ice_num
        query_num = len(query_res_list)
        num_clients = self.args.num_clients
        self.budget_model.eval()
        model_outputs = []
        with torch.no_grad():
            for query_res in tqdm.tqdm(
                query_res_list,
                desc="Predict local budgets based on query embedding",
            ):
                embed = torch.tensor(
                    query_res["embed"], device=self.device, requires_grad=False
                ).reshape(1, -1)
                output = self.budget_model(embed)
                pred = output.detach().cpu().numpy()[0] * total_budget
                model_outputs.append(pred.tolist())

        budgets_prediction = []
        for qid in range(query_num):
            tmp = [int(budget) for budget in model_outputs[qid]]
            # TODO: randomly make up budget number
            if sum(tmp) < total_budget:
                selected_cids = sorted(
                    np.random.choice(
                        num_clients, total_budget - sum(tmp), replace=False
                    )
                )
                for cid in selected_cids:
                    tmp[cid] += 1

            budgets_prediction.append(tmp)

        return budgets_prediction

    def __get_ppl(self, input_texts: List[str], mask_length=None):
        return self._PPLInferencer__get_ppl(input_texts, mask_length)

    def _get_ppl(self, input_texts: List[str], mask_length=None):
        # TODO: only for debug
        return self._PPLInferencer__get_ppl(input_texts, mask_length)
