"""PPL Federated Inferencer with Budget Prediction"""

from typing import List, Union, Optional, Dict, Any
import tqdm
from tqdm import trange
from collections import Counter
import json
import copy
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt


import torch
from torch.utils.data import DataLoader

from datasets import concatenate_datasets, load_dataset, DatasetDict, Dataset
from transformers import PretrainedConfig
from transformers import AutoTokenizer

from sentence_transformers import SentenceTransformer
from accelerate import Accelerator

from openicl.icl_dataset_reader import DatasetEncoder
from openicl.icl_retriever import BaseRetriever, TopkRetriever
from openicl.utils.collators import DataCollatorWithPaddingAndCuda
from openicl import PromptTemplate, PPLInferencer, DatasetReader
from openicl.icl_retriever import *
from openicl.icl_evaluator import *
from openicl.icl_inferencer.icl_base_inferencer import PPLInferencerOutputHandler

from openicl.utils.logging import get_logger
from openicl.utils.api_service import *

from util.concatenate import (
    collect_samples,
    simple_concatenate,
    merge_concatenate,
    random_concatenate,
)
from util.icl_ppl_fed_inferencer import PPLFedInferencer
from util.retriever import get_retriever, get_server_retriever
from util.misc import save_json, AverageMeter, setup_seed
from util.dataset import QueryBudgetDatasetNew
from util.budget_model_tools import *

from models import get_model


# Module-level logger, named after this module, used by the inferencer below.
logger = get_logger(__name__)


class PPLFedBudgetModelInferencer(PPLFedInferencer):
    def __init__(
        self,
        model_name: Optional[str] = "gpt2-xl",
        tokenizer_name: Optional[str] = None,
        max_model_token_num: Optional[int] = None,
        model_config: Optional[PretrainedConfig] = None,
        batch_size: Optional[int] = 1,
        accelerator: Optional[Accelerator] = None,
        output_json_filepath: Optional[str] = "./server_icl_inference_output",
        output_json_filename: Optional[str] = "predictions",
        api_name: Optional[str] = None,
        labels: Optional[List] = None,
        model_parallel: Optional[bool] = False,
        args=None,
        sentence_transformers_model_name: Optional[str] = "all-mpnet-base-v2",
        device: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Build the inferencer and start with no budget model attached.

        All named arguments are forwarded positionally, in declaration order,
        to the parent ``PPLFedInferencer.__init__``; ``**kwargs`` is forwarded
        as keyword arguments. Their meaning is defined by the parent class.
        The only state added here is ``self.budget_model``, initialised to
        ``None``.
        """
        # NOTE: the call is positional, so the order below must stay in sync
        # with the parent constructor's parameter order.
        super().__init__(
            model_name,
            tokenizer_name,
            max_model_token_num,
            model_config,
            batch_size,
            accelerator,
            output_json_filepath,
            output_json_filename,
            api_name,
            labels,
            model_parallel,
            args,
            sentence_transformers_model_name,
            device,
            **kwargs,
        )
        self.budget_model = None

    def inference(
        self,
        retrievers: List[BaseRetriever],
        query_dataset: Union[Dataset, DatasetDict],
        query_split: Optional[str] = None,
        ice_template: Optional[PromptTemplate] = None,
        prompt_template: Optional[PromptTemplate] = None,
        output_json_filepath: Optional[str] = None,
        output_json_filename: Optional[str] = None,
        output_model_filepath: Optional[str] = None,
        concat: Optional[str] = "simple",
        normalizing_str: Optional[str] = None,
        strategy: str = "medium",
        buffer: int = 0,
        args=None,
        use_budget_model=True,
        query_budgets=None,
    ):
        """Predict a label for every query by picking the lowest-perplexity class.

        Pipeline: embed the queries, obtain a per-client ICE budget (predicted
        or supplied), let each client retrieve within its budget, assemble one
        server-side ICE dataset per query, build one prompt per (query, label)
        pair, score the prompts with ``__get_ppl`` and keep the label with the
        smallest score.

        Args:
            retrievers: One retriever per client. ``retrievers[0]`` is also
                used to embed the queries and to read the dataset reader's
                input/output column names.
            query_dataset: Queries to classify; must expose an ``"idx"``
                column. A ``DatasetDict`` is indexed with ``query_split``.
            query_split: Split name used when ``query_dataset`` is a
                ``DatasetDict``.
            ice_template: Template passed to the retrievers when generating
                in-context examples, labels and prompts.
            prompt_template: Template passed to the retrievers when generating
                labels and prompts; its ``sep_token`` is preferred over the
                ICE template's when ``normalizing_str`` is set.
            output_json_filepath: Output directory; defaults to
                ``self.output_json_filepath``.
            output_json_filename: Output file stem; defaults to
                ``self.output_json_filename``.
            output_model_filepath: Forwarded to ``self.model_based_budgets``.
            concat: Concatenation mode forwarded to
                ``_construct_server_ice_dataset_with_budgets``. With
                ``"reorder"`` the server retriever re-selects the examples;
                otherwise all concatenated examples are used in order.
            normalizing_str: When given, each score is the prompt's PPL minus
                the PPL of ``normalizing_str + answer``.
            strategy: Forwarded to ``self.model_based_budgets``.
            buffer: Forwarded to ``self.model_based_budgets``.
            args: Run configuration providing ``server_ice_num`` and whatever
                ``get_server_retriever`` reads. Defaults to ``self.args``.
            use_budget_model: If true, budgets come from
                ``self.model_based_budgets``; otherwise from ``query_budgets``.
            query_budgets: Budgets indexed by client id first; required when
                ``use_budget_model`` is false.

        Returns:
            The ``"prediction"`` entry of every record in the output handler's
            ``results_dict``.

        Raises:
            ValueError: If ``use_budget_model`` is false and ``query_budgets``
                is ``None``.
        """
        # Fail before any expensive embedding work if the budgets are missing.
        if not use_budget_model and query_budgets is None:
            raise ValueError(
                "`query_budgets` must be provided when `use_budget_model` is False."
            )

        # The budget helper below already relies on ``self.args``, so it is a
        # safe default for a caller that does not pass ``args`` explicitly.
        if args is None:
            args = self.args

        if isinstance(query_dataset, DatasetDict):
            query_dataset = query_dataset[query_split]

        query_num = len(query_dataset)

        # 1. Preparation for output logs
        output_handler = PPLInferencerOutputHandler(self.accelerator)

        sub_predictions = []
        ppl = []
        ice = []

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # Embed the queries with the first client's retriever.
        query_orig_idxs = query_dataset["idx"]
        query_datalist = retrievers[0].dataset_reader.generate_input_field_corpus(
            query_dataset
        )  # TODO: ugly code because of `dataset_reader.generate_input_field_corpus`
        query_dataloader = retrievers[0].create_dataloader(query_datalist)
        query_res_list = retrievers[0].forward(
            self.retriever_model,
            query_dataloader,
            orig_idxs=query_orig_idxs,
            process_bar=True,
            information="Embedding test data queries...",
        )  # each element is {'embed': np.array, 'metadata': {'id': int, 'len': int, 'text': sample_text}, 'idx': int}

        if use_budget_model:
            pred_local_budgets = self.model_based_budgets(
                [res["embed"] for res in query_res_list],
                query_orig_idxs,
                strategy,
                buffer,
                output_json_filepath,
                output_model_filepath,
            )
        else:
            pred_local_budgets = query_budgets

        # 2. Get results of retrieval process
        server_ice_datasets = self._construct_server_ice_dataset_with_budgets(
            retrievers=retrievers,
            query_res_list=query_res_list,
            ice_budgets=pred_local_budgets,
            concat=concat,
            output_json_filepath=output_json_filepath,
        )  # [q1_ice_Dataset, q2_ice_Dataset, ...]

        # 3. Get labels of all the classes (need to check all client's local
        #    dataset since NonIID partition)
        if self.labels is None:
            labels = []
            for retriever in retrievers:
                labels.extend(
                    retriever.get_labels(
                        ice_template=ice_template, prompt_template=prompt_template
                    )
                )
            labels = list(set(labels))
        else:
            labels = self.labels

        # 4. Generate in-context examples indices for testing inputs
        ice_dataset_retrievers = []
        server_sample_orig_idx = {}
        for qid in trange(
            query_num,
            disable=not self.is_main_process,
            desc="Server side generating ICE: ",
        ):
            data_dict = DatasetDict(
                {
                    "train": server_ice_datasets[qid],
                    "test": Dataset.from_list([query_dataset[qid]]),
                }
            )
            data = DatasetReader(
                data_dict,
                input_columns=retrievers[0].dataset_reader.input_columns,
                output_column=retrievers[0].dataset_reader.output_column,
            )
            per_query_retriever = get_server_retriever(args)(
                data, ice_num=args.server_ice_num
            )

            # Only "reorder" asks the server to re-select; every other mode
            # keeps all concatenated examples in their existing order.
            if concat == "reorder":
                ice_idxs = per_query_retriever.retrieve(
                    model=self.retriever_model, use_trange=False
                )[0]
            else:
                ice_idxs = list(range(len(server_ice_datasets[qid])))

            q_orig_idx = per_query_retriever.test_ds[0]["idx"]
            server_sample_orig_idx[q_orig_idx] = per_query_retriever.index_ds.select(
                ice_idxs
            )["idx"]
            # use selected samples in each query's ICE dataset to generate ICE
            ice.append(
                per_query_retriever.generate_ice(ice_idxs, ice_template=ice_template)
            )
            ice_dataset_retrievers.append(per_query_retriever)

        output_handler.save_ice(ice)
        server_ice_index_file = f"{output_json_filepath}/server_ice_indices.json"
        save_json(server_sample_orig_idx, server_ice_index_file)

        # Values that do not depend on the label or the query are computed once.
        use_normalization = normalizing_str is not None
        if use_normalization:
            if prompt_template is not None:
                sep_token = prompt_template.sep_token
            else:
                sep_token = ice_template.sep_token
            normalizing_str_len = self.get_input_token_num(normalizing_str)

        # 5. Calculating PPL for prompts in each label's class
        for label in labels:
            index = 0
            prompt_list = []
            sub_ppl_list = []
            normalizing_prompt_list = []
            context_length_list = []

            # 5.1 Generate prompts of current label
            # TODO: truncation to `max_model_token_num` (meant for generation
            # tasks) is temporarily disabled; prompts are used at full length.
            for per_query_retriever, query_ice in zip(ice_dataset_retrievers, ice):
                # Each per-query retriever holds exactly one test sample, hence
                # the constant index 0.
                prompt = per_query_retriever.generate_label_prompt(
                    0,
                    query_ice,
                    label,
                    ice_template=ice_template,
                    prompt_template=prompt_template,
                    remain_sep=use_normalization,
                )

                if use_normalization:
                    # Split at the separator into context / answer, then drop
                    # the separator from the text that is actually scored.
                    sep_pos = prompt.find(sep_token)
                    context = prompt[0:sep_pos]
                    answer = prompt[sep_pos:].replace(sep_token, "")
                    prompt = context + answer

                    context_length_list.append(self.get_input_token_num(context))
                    normalizing_prompt_list.append(normalizing_str + answer)
                prompt_list.append(prompt)

            # 5.2 Get PPL: loop through all queries with current give label
            logger.info(f"Calculating PPL for prompts labeled '{label}'")
            for idx in trange(
                0, len(prompt_list), self.batch_size, disable=not self.is_main_process
            ):
                batch = slice(idx, idx + self.batch_size)
                sub_prompt_list = prompt_list[batch]

                with torch.no_grad():
                    if use_normalization:
                        res1 = self.__get_ppl(
                            input_texts=sub_prompt_list,
                            mask_length=context_length_list[batch],
                        )
                        res2 = self.__get_ppl(
                            input_texts=normalizing_prompt_list[batch],
                            mask_length=[normalizing_str_len] * len(sub_prompt_list),
                        )
                        # Convert to plain Python floats, as the other branch
                        # does, so the scores stored below are JSON-friendly.
                        sub_res = (res1 - res2).tolist()
                    else:
                        sub_res = self.__get_ppl(
                            sub_prompt_list
                        ).tolist()  # list of scores, length is batch size

                for batch_offset, (res, prompt) in enumerate(
                    zip(sub_res, sub_prompt_list)
                ):
                    sub_ppl_list.append(res)
                    output_handler.save_prompt_and_ppl(
                        label,
                        # The prompt starts with its ICE block; strip it to log
                        # only the query part.
                        prompt[len(ice[idx + batch_offset]) :],
                        prompt,
                        res,
                        index,
                    )
                    index = index + 1
            ppl.append(sub_ppl_list)

        # 6. Get lowest PPL class as predictions
        # `ppl` is label-major; transpose so each tuple holds one query's
        # scores in `labels` order.
        for single_ppl in zip(*ppl):
            sub_predictions.append(labels[single_ppl.index(min(single_ppl))])
        output_handler.save_predictions(sub_predictions)

        # 7. Output
        output_handler.subprocess_write_to_json(
            output_json_filepath, output_json_filename
        )
        if self.accelerator is not None:
            self.accelerator.wait_for_everyone()
        output_handler.merge_to_main_process(output_json_filepath, output_json_filename)
        output_handler.write_to_json(output_json_filepath, output_json_filename)

        return [sample["prediction"] for sample in output_handler.results_dict.values()]

    def _construct_server_ice_dataset_with_budgets(
        self,
        retrievers: List[BaseRetriever],
        query_res_list: List[Dict[str, Any]],
        ice_budgets: List[List[int]],
        concat: Optional[str] = "simple",
        output_json_filepath: Optional[str] = None,
    ):
        """Build one server-side ICE dataset per query from budgeted client retrievals.

        Each client retrieves local examples for every query under its own
        budget, the original indices of those examples are written to
        ``local_ice_indices.json``, and the per-client samples of each query
        are then combined with the concatenation function chosen by ``concat``.

        Args:
            retrievers: One retriever per client.
            query_res_list: Embedding results of the queries; each entry must
                carry the query's original index under ``"idx"``.
            ice_budgets: Budgets indexed by client id; ``ice_budgets[cid]`` is
                handed to that client's ``retrieve_with_budget``.
            concat: ``"simple"`` or ``"reorder"`` (both use
                ``simple_concatenate``), ``"random"`` or ``"merge"``.
            output_json_filepath: Directory for ``local_ice_indices.json``;
                defaults to ``self.output_json_filepath``.

        Returns:
            A list with one concatenated ICE dataset per query, in query order.

        Raises:
            ValueError: If ``concat`` is not one of the supported modes.
        """
        # Validate up front so a bad mode fails before any retrieval work or
        # file output has happened.
        concat_funcs = {
            "simple": simple_concatenate,
            "reorder": simple_concatenate,
            "random": random_concatenate,
            "merge": merge_concatenate,
        }
        if concat not in concat_funcs:
            raise ValueError(
                f"concat should be one of 'simple', 'reorder', 'random', or 'merge', rather than '{concat}'."
            )
        concat_func = concat_funcs[concat]

        # Without this fallback the index file would be written under a
        # directory literally named "None".
        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath

        num_clients = len(retrievers)
        query_num = len(query_res_list)

        ice_idx_dict = {}
        for cid, retriever in enumerate(retrievers):
            logger.info(f"Client {cid} retrieves local ICE samples...")
            ice_idx_dict[cid] = retriever.retrieve_with_budget(
                query_res_list, ice_budgets=ice_budgets[cid], model=self.retriever_model
            )

        # Indexed below as samples_for_all_query[qid][cid].
        samples_for_all_query = collect_samples(ice_idx_dict, retrievers)

        # Record, per query and per client, the original training-set index of
        # every retrieved sample.
        local_sample_orig_idx = {
            query_res["idx"]: {
                cid: samples_for_all_query[qid][cid]["idx"]
                for cid in range(num_clients)
            }
            for qid, query_res in enumerate(query_res_list)
        }

        local_ice_index_file = f"{output_json_filepath}/local_ice_indices.json"
        save_json(local_sample_orig_idx, local_ice_index_file)

        server_ice = []
        # for each test query, perform concatenate for server side ICE dataset
        # based on all clients local ICE dataset
        for qid in trange(
            query_num,
            disable=False,
            desc="Construct server side ICE dataset for each query: ",
        ):
            ice_data = concat_func(
                samples_for_all_query[qid],
                chunk_sample_num=self.args.server_ice_num,
                seed=self.args.seed,
            )
            server_ice.append(ice_data)

        return server_ice

    def opt_budget_allocation(
        self,
        retrievers: List[BaseRetriever],
        query_dataset: Union[Dataset, DatasetDict],
        query_split: Optional[str] = None,
        ice_template: Optional[PromptTemplate] = None,
        prompt_template: Optional[PromptTemplate] = None,
        output_json_filepath: Optional[str] = None,
        output_json_filename: Optional[str] = "proxy_predictions",
        args=None,
        prefix: str = "proxy",
        local_ice_num: Optional[int] = None,
        server_ice_num: Optional[int] = None,
    ):
        """Compute each client's optimal ICE budget for every query of a set.

        For every query, all clients first contribute ``local_ice_num``
        candidates, the server retriever then picks its ``server_ice_num``
        examples from the pooled candidates, and the number of picked examples
        originating from each client is taken as that client's optimal budget
        for the query. The query embeddings are computed and saved as well.

        Files written into ``output_json_filepath``:
        ``{prefix}_server_ice_indices.json``, ``{prefix}_ice_client_source.json``,
        ``{prefix}_opt_client_budget.json`` and ``{prefix}_res_list.pt``.

        Args:
            retrievers: One retriever per client. ``retrievers[0]`` also
                supplies the dataset reader (column names) and embeds the
                queries.
            query_dataset: Queries to analyse; must expose an ``"idx"``
                column. A ``DatasetDict`` is indexed with ``query_split``.
            query_split: Split name used when ``query_dataset`` is a
                ``DatasetDict``.
            ice_template: Accepted for interface compatibility; not used,
                because no prompt text is generated here.
            prompt_template: Accepted for interface compatibility; not used.
            output_json_filepath: Output directory; defaults to
                ``self.output_json_filepath``.
            output_json_filename: Accepted for interface compatibility; not
                used, because no predictions file is written here.
            args: Run configuration handed to ``get_server_retriever``.
                Defaults to ``self.args``.
            prefix: Name of the query set, used in log messages and as the
                prefix of every output file.
            local_ice_num: Number of candidates each client contributes; may
                equal or exceed ``server_ice_num``. Defaults to
                ``server_ice_num``.
            server_ice_num: Number of examples the server selects per query.
                Defaults to ``args.server_ice_num``.

        Returns:
            A tuple ``(ice_client_source, opt_per_client_budget, query_res_list)``:
            per query original index, the source client id of each selected
            example; per query original index, a list with one count per
            client; and the embedding results of the queries.
        """
        logger.info(f"===== Construct optimal budget allocation for {prefix} set ====")

        # Resolve defaults so `None` is never handed to the retrievers.
        if args is None:
            args = self.args
        if server_ice_num is None:
            server_ice_num = args.server_ice_num
        if local_ice_num is None:
            local_ice_num = server_ice_num
        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath

        num_clients = len(retrievers)
        if isinstance(query_dataset, DatasetDict):
            query_dataset = query_dataset[query_split]

        query_num = len(query_dataset)

        # 1. Pool every client's local candidates for each query
        server_ice_datasets = self._construct_server_ice_dataset(
            retrievers=retrievers,
            query_dataset=query_dataset,
            ice_num=local_ice_num,
            concat="reorder",
            output_json_filepath=output_json_filepath,
            prefix=prefix,
        )  # [q1_ice_Dataset, q2_ice_Dataset, ...]

        # 2. Let the server select its ICE per query and count where the
        #    selected examples came from
        # Column names come from the clients' dataset reader, the same reader
        # that is used further down to embed this query set.
        dataset_reader = retrievers[0].dataset_reader
        server_sample_orig_idx = {}
        ice_client_source = {}  # {q_orig_idx: [source client id of each ICE]}
        opt_per_client_budget = (
            {}
        )  # {q_orig_idx: [client1_budget, client2_budget, client3_budget, ...], ....}
        for qid in trange(
            query_num,
            disable=not self.is_main_process,
            desc=f"Server calculates each client's Top-{server_ice_num} ICE contribution",
        ):
            data_dict = DatasetDict(
                {
                    "train": server_ice_datasets[qid],
                    "test": Dataset.from_list([query_dataset[qid]]),
                }
            )
            data = DatasetReader(
                data_dict,
                input_columns=dataset_reader.input_columns,
                output_column=dataset_reader.output_column,
            )
            per_query_retriever = get_server_retriever(args)(
                data, ice_num=server_ice_num
            )

            # perform reorder on server side
            ice_idxs = per_query_retriever.retrieve(
                model=self.retriever_model, use_trange=False
            )[0]
            selected_ice = per_query_retriever.index_ds.select(ice_idxs)

            q_orig_idx = per_query_retriever.test_ds[0]["idx"]
            server_sample_orig_idx[q_orig_idx] = selected_ice["idx"]
            per_query_ice_cids = selected_ice["cid"]  # client ID of each selected ICE
            ice_client_source[q_orig_idx] = per_query_ice_cids

            # count each client's contribution in the server ICE set; clients
            # that contributed nothing get an explicit 0
            counter = Counter(per_query_ice_cids)
            opt_per_client_budget[q_orig_idx] = [
                counter.get(cid, 0) for cid in range(num_clients)
            ]

        # 3. Save the server-side selection and the derived budgets
        save_json(
            server_sample_orig_idx,
            os.path.join(output_json_filepath, f"{prefix}_server_ice_indices.json"),
        )
        save_json(
            ice_client_source,
            os.path.join(output_json_filepath, f"{prefix}_ice_client_source.json"),
        )
        save_json(
            opt_per_client_budget,
            os.path.join(output_json_filepath, f"{prefix}_opt_client_budget.json"),
        )
        logger.info(
            f"{prefix} set optimal budget allocation saved to {prefix}_opt_client_budget.json."
        )

        # 4. Embed the query dataset and save the embeddings
        query_orig_idxs = query_dataset["idx"]
        query_datalist = dataset_reader.generate_input_field_corpus(query_dataset)
        query_dataloader = retrievers[0].create_dataloader(query_datalist)
        query_res_list = retrievers[0].forward(
            self.retriever_model,
            query_dataloader,
            orig_idxs=query_orig_idxs,
            process_bar=True,
            information=f"Embedding {prefix} data queries...",
        )

        torch.save(
            {"res_list": query_res_list},
            os.path.join(output_json_filepath, f"{prefix}_res_list.pt"),
        )
        logger.info(f"{prefix} res_list saved to {prefix}_res_list.pt")

        return ice_client_source, opt_per_client_budget, query_res_list

    def train_local_budget_model(
        self,
        query_res_list: List[Dict[Any, Any]],
        opt_per_client_budget: Dict[int, List[int]] = None,
        model_name: Optional[str] = "SMLP",
        model_width: Optional[int] = 300,
        num_classes_per_client: Optional[int] = 3,
        epochs: Optional[int] = 300,
        lr: Optional[float] = 0.02,
        batch_size: Optional[int] = 8,
        train_ratio: Optional[float] = 0.8,
        output_json_filepath: Optional[str] = None,
        output_model_filepath: Optional[str] = None,
        seed: Optional[int] = 0,
    ):
        assert train_ratio > 0
        assert train_ratio < 1

        query_num = len(query_res_list)

        if self.args.retriever != "topk":
            raise ValueError("Currently only support 'topk' retriever.")

        # ========== get basic info of query dataset
        query_orig_idxs = [res["idx"] for res in query_res_list]
        embed_mat = np.array([res["embed"] for res in query_res_list])
        opt_budgets_list = [
            opt_per_client_budget[q_orig_idx] for q_orig_idx in query_orig_idxs
        ]
        per_client_budgets = [
            [*x] for x in zip(*opt_budgets_list)
        ]  # [[b1, ..., ]_1, [b1, ...]_2,  [b1, ...]_{num_clients}]
        num_clients = len(per_client_budgets)

        # ========== auto decide budget value allocation into budget classes
        budgets_allocations = {}
        budget_label_ranges = {}
        client_budgets_labels = []  # List of each client's label list
        budget_label_cnt = {}
        fig, axs = plt.subplots(1, num_clients, figsize=(12 * num_clients, 7))
        fig.subplots_adjust(
            hspace=0.3, wspace=0.1
        )  # Adjust the space between subplots, 'hspace' is vertical space, 'wspace' is horizontal space
        for cid in range(num_clients):
            budget_vals = []
            sample_cnt = []
            # budget_val_space = list(set(per_client_budgets[cid]))
            tmp_cnt = Counter(per_client_budgets[cid])
            max_budget_val = max(tmp_cnt.keys())  # max budget value on current client
            for val in range(max_budget_val + 1):
                budget_vals.append(val)
                sample_cnt.append(tmp_cnt.get(val, 0))

            # allocate budget values to multiple budget classes
            allocations = basket_assign_brute_force(sample_cnt, num_classes_per_client)
            budgets_allocations[cid] = copy.deepcopy(allocations)

            # assign budgets for each sample based on budget
            single_client_budget_label_ranges = get_basket_range(
                budget_vals, allocations
            )
            budget_label_ranges[cid] = copy.deepcopy(single_client_budget_label_ranges)
            # map budget values to labels
            labels = find_group_indices(
                single_client_budget_label_ranges, per_client_budgets[cid]
            )
            client_budgets_labels.append(labels)

            # distribution
            tmp = Counter(labels)
            label_cnt = {b: tmp[b] for b in sorted(tmp.keys())}
            budget_label_cnt[cid] = copy.deepcopy(label_cnt)

            # plot the local budget distribution figure for each client
            df = pd.DataFrame(per_client_budgets[cid], columns=["Numbers"])
            number_counts = df["Numbers"].value_counts().reset_index()
            number_counts.columns = ["Number", "Frequency"]
            number_counts_desc = number_counts.sort_values(by="Number")
            # Create a bar plot using seaborn
            sns.barplot(x="Number", y="Frequency", data=number_counts_desc, ax=axs[cid])
            # Adding titles and labels
            axs[cid].set_title(
                f"Frequency of proxy set budget on client-{cid}\n Budget label range for each class: {budget_label_ranges[cid]}\n Budget Label count: {budget_label_cnt[cid]}"
            )
            axs[cid].set_xlabel("Local Budget Value")
            axs[cid].set_ylabel("Frequency")

        fig.savefig(
            os.path.join(output_json_filepath, "barplot-local-budget.png"),
            dpi=600,
            bbox_inches="tight",
        )
        logger.info("Local budget distribution barplot save.")
        plt.close()

        # train/val split for both embedding and all clients' local budget labels
        train_X, train_y, val_X, val_y = train_val_split(
            embed_mat, client_budgets_labels, train_ratio=train_ratio, seed=seed
        )
        logger.info(
            f"Shape of train embedding:{train_X.shape}; train labels:{len(train_y[0])}x{len(train_y)}"
        )
        logger.info(
            f"Shape of val embedding:{val_X.shape}; val labels: {len(val_y[0])}x{len(val_y)}"
        )

        # add budget info into results_dict
        results_dict = {
            "budget_label_ranges": budget_label_ranges,
            "budget_label_cnt": budget_label_cnt,
        }

        # ========== train budget model for each client
        all_clients_best_val_accs = {}
        for cid in range(num_clients):
            logger.info(f"Train budget model for Client {cid}...")
            checkpoint_path = os.path.join(
                output_model_filepath, f"best_budget_model-client-{cid}.pt"
            )
            setup_seed(seed)

            # ------- prepare train/val dataloader
            train_dataset = QueryBudgetDatasetNew(train_X, train_y[cid])
            train_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=batch_size, shuffle=True
            )
            val_dataset = QueryBudgetDatasetNew(val_X, val_y[cid])
            val_loader = torch.utils.data.DataLoader(
                val_dataset, batch_size=batch_size, shuffle=False
            )

            # ------- prepare initial budget model
            embedding_size = (self.retriever_model.get_sentence_embedding_dimension(),)
            budget_model = get_model(
                model_name=model_name,
                output_size=num_classes_per_client,
                data_shape=embedding_size,
                width=model_width,
            )
            budget_model = budget_model.to(self.device)
            optimizer = torch.optim.SGD(budget_model.parameters(), lr=lr)
            criterion = torch.nn.CrossEntropyLoss(reduction="mean")

            train_loss_hist = []
            val_loss_hist = []
            train_acc_hist = []
            val_acc_hist = []
            train_per_class_acc_hist = [[] for _ in range(num_classes_per_client)]
            val_per_class_acc_hist = [[] for _ in range(num_classes_per_client)]
            best_val_acc = 0
            best_epoch = 0

            # ------ train loop
            for epoch in range(epochs):
                # train
                train_loss, train_acc, train_per_class_acc = epoch_train_budget_model(
                    budget_model,
                    criterion,
                    optimizer,
                    train_loader,
                    num_classes_per_client,
                    self.device,
                )
                train_loss_hist.append(train_loss)
                train_acc_hist.append(train_acc)
                for c in range(num_classes_per_client):
                    train_per_class_acc_hist[c].append(train_per_class_acc[c])

                # validation
                val_loss, val_acc, val_per_class_acc = eval_budget_model(
                    budget_model,
                    criterion,
                    val_loader,
                    num_classes_per_client,
                    self.device,
                )
                val_loss_hist.append(val_loss)
                val_acc_hist.append(val_acc)
                for c in range(num_classes_per_client):
                    val_per_class_acc_hist[c].append(val_per_class_acc[c])

                # save the best checkpoint
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    best_epoch = epoch
                    torch.save(budget_model.state_dict(), checkpoint_path)
                    # logger.info(f"Saved for epoch {epoch} as best checkpoint")

                if (epoch + 1) % 50 == 0:
                    logger.info(
                        f"Eopch [{epoch+1}/{epochs}], train_loss={train_loss :.10f}, train_acc={train_acc:.2f}, val_loss={val_loss:f}, val_acc={val_acc:.2f}"
                    )

            # ------ add train hist into results dict, and save
            all_clients_best_val_accs[f"client-{cid}-best-val-acc"] = best_val_acc
            results_dict[f"client-{cid}-train_acc_hist"] = copy.deepcopy(train_acc_hist)
            results_dict[f"client-{cid}-val_acc_hist"] = copy.deepcopy(val_acc_hist)
            results_dict[f"client-{cid}-train_loss_hist"] = copy.deepcopy(
                train_loss_hist
            )
            results_dict[f"client-{cid}-val_loss_hist"] = copy.deepcopy(val_loss_hist)
            results_dict[f"client-{cid}-train_per_class_acc_hist"] = copy.deepcopy(
                train_per_class_acc_hist
            )
            results_dict[f"client-{cid}-val_per_class_acc_hist"] = copy.deepcopy(
                val_per_class_acc_hist
            )
            results_dict[f"client-{cid}-best-val_acc"] = copy.deepcopy(best_val_acc)
            torch.save(
                results_dict, os.path.join(output_model_filepath, f"train-record.pt")
            )
            logger.info(f"client {cid} training history saved")

            # ------ plot training history and save figure
            fig, axs = plt.subplots(1, 3, figsize=(20, 5))
            x = list(range(epochs))
            axs[0].plot(x, train_acc_hist, label="train overall acc", linestyle="-")
            axs[0].plot(x, val_acc_hist, label="val overall acc", linestyle=":")
            axs[0].set_title(f"Best overall val acc={max(val_acc_hist):.4f}")
            axs[0].set_ylabel("Overall Accuracy")
            axs[0].set_xlabel("Epoch")
            axs[0].legend()

            colors = ["royalblue", "orange", "purple", "green"]
            for c in range(num_classes_per_client):
                axs[1].plot(
                    x,
                    train_per_class_acc_hist[c],
                    label=f"class-{c}-{budget_label_ranges[cid][c]} train acc",
                    linestyle="-",
                    color=colors[c],
                )
                axs[1].plot(
                    x,
                    val_per_class_acc_hist[c],
                    label=f"class-{c}-{budget_label_ranges[cid][c]} val acc",
                    linestyle=":",
                    color=colors[c],
                )
            axs[1].set_title("class-[class_id]-[$budget_{min}$, $budget_{max}$]")
            axs[1].set_ylabel("Per-class Accuracy")
            axs[1].set_xlabel("Epoch")
            axs[1].legend()

            axs[2].plot(x, train_loss_hist, label="train loss", linestyle="-")
            axs[2].plot(x, val_loss_hist, label="val loss", linestyle=":")
            axs[2].set_title("CrossEntropy History")
            axs[2].set_ylabel("CrossEntropy Loss")
            axs[2].set_xlabel("Epoch")
            axs[2].legend()

            fig.suptitle(
                f"Client {cid} ProxyData Embedding->BudgetLabel Training: {model_name}-W={model_width}-lr={lr}-E={epochs}",
                fontsize=15,
            )
            plt.tight_layout()
            fig.savefig(
                os.path.join(output_model_filepath, f"client-{cid}.png"), dpi=600
            )
            plt.close()
            logger.info(f"Client {cid} figure saved.")
            # logger.info(" " * 10 + "-" * 15 + " " * 10)

        return all_clients_best_val_accs

    def _predict_budgets_label(
        self,
        embed_list: List[Any],
        budget_models: List[torch.nn.Module],
    ):
        num_clients = len(budget_models)
        budget_pred_labels = [[] for _ in range(num_clients)]
        budget_outputs = [[] for _ in range(num_clients)]
        for cid in range(num_clients):
            budget_models[cid].eval()

        # ----- use budget model to predict budget range
        for embed_val in embed_list:
            embed = torch.tensor(
                embed_val, device=self.device, requires_grad=False
            ).reshape(1, -1)
            # get the prediction of budget on each local budget model
            for cid in range(num_clients):
                with torch.no_grad():
                    output = budget_models[cid](embed)
                    _, pred_label = torch.max(output, 1)
                    budget_outputs[cid].append(
                        output.detach().cpu().numpy()[0]
                    )  # output logits value
                    budget_pred_labels[cid].append(
                        pred_label.detach().cpu().numpy()[0].item()
                    )

        return budget_pred_labels, budget_outputs

    def _create_and_load_budget_model(
        self,
        output_model_filepath: Optional[str] = None,
    ):
        """Rebuild every client's budget model and load its best checkpoint.

        Reads ``train-record.pt`` from ``output_model_filepath`` to obtain the
        ``"budget_label_ranges"`` entry, builds one model per client with
        ``get_model`` (the output size is the number of label ranges of that
        client), loads ``best_budget_model-client-{cid}.pt`` into it and moves
        the model to ``self.device``.

        Args:
            output_model_filepath: Directory that holds ``train-record.pt`` and
                the per-client checkpoint files. Despite the ``None`` default a
                real path is required; ``os.path.join`` rejects ``None``.

        Returns:
            A pair ``(budget_models, budget_label_ranges)``: the list of loaded
            models indexed by client id, and the label ranges read from the
            training record.
        """
        # ----- get budget model info
        budget_model_info_file = os.path.join(output_model_filepath, "train-record.pt")
        # Map storages to CPU, like the checkpoints below, so that a record
        # holding tensors saved from another device can still be read on a
        # host where that device is unavailable.
        budget_model_info = torch.load(budget_model_info_file, map_location="cpu")
        budget_label_ranges = budget_model_info["budget_label_ranges"]
        num_clients = len(budget_label_ranges)
        budget_model_num_classes = {
            cid: len(budget_label_ranges[cid]) for cid in range(num_clients)
        }
        logger.info(
            f"Budget model number of classes on each client: {budget_model_num_classes}"
        )

        # ---- create and load the budget models
        # data_shape is passed as a one-element tuple holding the embedding size
        embedding_size = (self.retriever_model.get_sentence_embedding_dimension(),)
        budget_models = []
        for cid in range(num_clients):
            model = get_model(
                self.args.budget_model_name,
                output_size=budget_model_num_classes[cid],
                data_shape=embedding_size,
                width=self.args.budget_model_width,
            )
            checkpoint_file = os.path.join(
                output_model_filepath, f"best_budget_model-client-{cid}.pt"
            )
            # load on CPU first, then move the whole model to the target device
            model.load_state_dict(torch.load(checkpoint_file, map_location="cpu"))
            budget_models.append(model.to(self.device))
            logger.info(f"Client {cid} budget model loaded to GPU")
        return budget_models, budget_label_ranges

    def model_based_budgets(
        self,
        query_embed_list,
        query_orig_idxs,
        strategy,
        buffer,
        output_json_filepath,
        output_model_filepath,
    ):
        """Predict per-client budgets for every query and persist the results.

        Loads the trained budget models, predicts a budget label per query and
        client, converts the labels to budget values with ``map_budget_values``
        and writes three artifacts into ``output_json_filepath``:
        ``pred_local_budget.json``, ``pred_budget_outputs.pt`` and
        ``barplot-overall-local-budget.png`` (frequency of the per-query sum of
        client budgets).

        Args:
            query_embed_list: One embedding per query, fed to the budget models.
            query_orig_idxs: Original index of each query, in the same order as
                ``query_embed_list``; used as key in the saved dictionaries.
            strategy: Forwarded unchanged to ``map_budget_values``.
            buffer: Forwarded unchanged to ``map_budget_values`` (which is
                always called with ``use_buffer=True``).
            output_json_filepath: Directory receiving the saved predictions
                and the bar plot.
            output_model_filepath: Directory holding the trained budget models
                and their training record.

        Returns:
            The value returned by ``map_budget_values``; it is indexed here as
            ``pred_local_budgets[cid][qid]``.
        """
        # get prediction of each query given embedding vector using budget_model
        budget_models, budget_label_ranges = self._create_and_load_budget_model(
            output_model_filepath
        )
        budget_pred_labels, budget_outputs = self._predict_budgets_label(
            query_embed_list, budget_models
        )

        query_num = len(query_orig_idxs)
        num_clients = len(budget_models)

        pred_local_budgets = map_budget_values(
            budget_pred_labels,
            budget_label_ranges,
            server_ice_num=self.args.server_ice_num,
            strategy=strategy,
            use_buffer=True,
            buffer=buffer,
        )  # [client0_budgets_list, client1_budgets_list, ...]

        # regroup the client-major results per query, keyed by original index
        pred_local_budgets_dict = {}
        budget_outputs_dict = {}
        for qid in range(query_num):
            q_orig_idx = query_orig_idxs[qid]
            pred_local_budgets_dict[q_orig_idx] = [
                pred_local_budgets[cid][qid] for cid in range(num_clients)
            ]
            budget_outputs_dict[q_orig_idx] = [
                budget_outputs[cid][qid] for cid in range(num_clients)
            ]

        # save the budget prediction
        save_json(
            pred_local_budgets_dict,
            os.path.join(output_json_filepath, "pred_local_budget.json"),
        )
        torch.save(
            {"pred_budget_outputs_dict": budget_outputs_dict},
            os.path.join(output_json_filepath, "pred_budget_outputs.pt"),
        )

        # plot distribution of overall budgets on all queries
        pred_overall_budgets = [sum(vals) for vals in pred_local_budgets_dict.values()]
        df = pd.DataFrame(pred_overall_budgets, columns=["Numbers"])
        number_counts = df["Numbers"].value_counts().reset_index()
        number_counts.columns = ["Number", "Frequency"]
        number_counts_desc = number_counts.sort_values(by="Number")
        fig, axs = plt.subplots(1, 1, figsize=(15, 7))
        sns.barplot(x="Number", y="Frequency", data=number_counts_desc, ax=axs)
        axs.set_title(
            f"Frequency of total predicted budgets (strategy='{strategy}', buffer={buffer}, server_ice_num={self.args.server_ice_num})"
        )
        # report the actual number of clients rather than a fixed count
        axs.set_xlabel(f"Overall budgets values from {num_clients} clients")
        axs.set_ylabel("Frequency")
        fig.savefig(
            os.path.join(output_json_filepath, "barplot-overall-local-budget.png"),
            dpi=600,
            bbox_inches="tight",
        )
        logger.info("Overall local budget distribution barplot save.")
        # close exactly the figure created above
        plt.close(fig)

        return pred_local_budgets

    def __get_ppl(self, input_texts: List[str], mask_length=None):
        return self._PPLInferencer__get_ppl(input_texts, mask_length)

    def _get_ppl(self, input_texts: List[str], mask_length=None):
        # TODO: only for debug
        return self._PPLInferencer__get_ppl(input_texts, mask_length)
