"""PPL Federated Inferencer with weak (random or uniform) local ICE budget allocation."""

from typing import List, Union, Optional, Dict, Any
import tqdm
from tqdm import trange
from collections import Counter
import json
import copy


import torch
from torch.utils.data import DataLoader

from datasets import concatenate_datasets, load_dataset, DatasetDict, Dataset
from transformers import PretrainedConfig
from transformers import AutoTokenizer

from sentence_transformers import SentenceTransformer
from accelerate import Accelerator

from openicl.icl_dataset_reader import DatasetEncoder
from openicl.icl_retriever import BaseRetriever, TopkRetriever
from openicl.utils.collators import DataCollatorWithPaddingAndCuda
from openicl import PromptTemplate, PPLInferencer, DatasetReader
from openicl.icl_retriever import *
from openicl.icl_evaluator import *
from openicl.icl_inferencer.icl_base_inferencer import PPLInferencerOutputHandler

from openicl.utils.logging import get_logger
from openicl.utils.api_service import *

from util.concatenate import (
    collect_samples,
    simple_concatenate,
    merge_concatenate,
    random_concatenate,
)
from util.icl_ppl_fed_inferencer import PPLFedInferencer
from util.retriever import get_retriever, get_server_retriever
from util.misc import save_json, AverageMeter, setup_seed
from util.dataset import QueryBudgetDatasetNew
from util.budget_model_tools import *

from models import get_model


# Module-level logger (from openicl's logging utilities) used by the inferencer below.
logger = get_logger(__name__)


class PPLFedWeakBudgetInferencer(PPLFedInferencer):
    """Federated PPL inferencer whose local ICE budgets follow a fixed rule.

    For every query, each client retrieves as many in-context examples (ICE)
    as its allocated budget allows. Budgets are allocated either at random or
    uniformly (see ``budget_strategy`` in :meth:`inference`). The server then
    concatenates the clients' examples into one ICE dataset per query. It
    predicts the label whose prompt gets the lowest perplexity score.
    """

    def __init__(
        self,
        model_name: Optional[str] = "gpt2-xl",
        tokenizer_name: Optional[str] = None,
        max_model_token_num: Optional[int] = None,
        model_config: Optional[PretrainedConfig] = None,
        batch_size: Optional[int] = 1,
        accelerator: Optional[Accelerator] = None,
        output_json_filepath: Optional[str] = "./server_icl_inference_output",
        output_json_filename: Optional[str] = "predictions",
        api_name: Optional[str] = None,
        labels: Optional[List] = None,
        model_parallel: Optional[bool] = False,
        args=None,
        sentence_transformers_model_name: Optional[str] = "all-mpnet-base-v2",
        device: Optional[str] = None,
        **kwargs,
    ) -> None:
        # All configuration is forwarded unchanged to ``PPLFedInferencer``;
        # this subclass only changes how inference assigns ICE budgets.
        super().__init__(
            model_name,
            tokenizer_name,
            max_model_token_num,
            model_config,
            batch_size,
            accelerator,
            output_json_filepath,
            output_json_filename,
            api_name,
            labels,
            model_parallel,
            args,
            sentence_transformers_model_name,
            device,
            **kwargs,
        )

    def inference(
        self,
        retrievers: List[BaseRetriever],
        query_dataset: Union[Dataset, DatasetDict],
        query_split: Optional[str] = None,
        ice_template: Optional[PromptTemplate] = None,
        prompt_template: Optional[PromptTemplate] = None,
        output_json_filepath: Optional[str] = None,
        output_json_filename: Optional[str] = None,
        concat: Optional[str] = "simple",
        budget_strategy: Optional[str] = "random",
        normalizing_str: Optional[str] = None,
        args=None,
    ):
        """Run federated PPL classification with rule-based local ICE budgets.

        Args:
            retrievers (List[BaseRetriever]): One retriever per client; client
                ``cid`` retrieves from ``retrievers[cid]``. ``retrievers[0]``
                is also used to embed the queries.
            query_dataset (Union[Dataset, DatasetDict]): Test queries. Must
                have an ``"idx"`` column with each query's original index.
            query_split (Optional[str], optional): Split to use when
                ``query_dataset`` is a ``DatasetDict``. Defaults to None.
            ice_template (Optional[PromptTemplate], optional): Template used
                to render in-context examples. Defaults to None.
            prompt_template (Optional[PromptTemplate], optional): Template
                used to render the full prompt. Defaults to None.
            output_json_filepath (Optional[str], optional): Output directory;
                falls back to ``self.output_json_filepath``. Defaults to None.
            output_json_filename (Optional[str], optional): Output file name;
                falls back to ``self.output_json_filename``. Defaults to None.
            concat (Optional[str], optional): How the server combines the
                clients' examples. One of 'simple', 'merge', 'random' or
                'reorder'. 'reorder' concatenates like 'simple', then lets the
                server retriever select and order the examples.
                Defaults to "simple".
            budget_strategy (Optional[str], optional): Either 'random' or
                'uniform'. Defaults to "random".
            normalizing_str (Optional[str], optional): If given, the score of
                each label is the PPL of ``context + answer`` (with
                ``mask_length`` set to the context length) minus the PPL of
                ``normalizing_str + answer`` (with ``mask_length`` set to the
                length of ``normalizing_str``). Defaults to None.
            args (optional): Experiment config. It must provide
                ``overall_local_ice_num``, ``seed`` and ``server_ice_num``,
                and it is passed to ``get_server_retriever``. Falls back to
                ``self.args`` when None. Defaults to None.

        Returns:
            list: The predicted label for every query, as stored by the
            output handler.

        Raises:
            ValueError: If ``budget_strategy`` or ``concat`` is unsupported,
                or if ``normalizing_str`` is set and a prompt lacks the
                template's separator token.
        """
        if args is None:
            # Same config the budget-construction helper reads by default.
            args = self.args

        # Fail fast on invalid options, before any expensive embedding/retrieval.
        if budget_strategy not in ("random", "uniform"):
            raise ValueError(
                f"Only supports 'uniform' and 'random' for budget_strategy, rather than '{budget_strategy}'."
            )
        self._get_concat_func(concat)

        num_clients = len(retrievers)
        if isinstance(query_dataset, DatasetDict):
            query_dataset = query_dataset[query_split]
        query_num = len(query_dataset)

        # 1. Preparation for output logs
        output_handler = PPLInferencerOutputHandler(self.accelerator)
        ppl = []  # ppl[label_position][qid]
        ice = []  # ice[qid]: rendered in-context examples for query qid

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # Embed all queries once, using the first client's retriever.
        query_orig_idxs = query_dataset["idx"]
        # TODO: ugly code because of `dataset_reader.generate_input_field_corpus`
        query_datalist = retrievers[0].dataset_reader.generate_input_field_corpus(
            query_dataset
        )
        query_dataloader = retrievers[0].create_dataloader(query_datalist)
        # each element is {'embed': np.array, 'metadata': {'id': int, 'len': int, 'text': sample_text}, 'idx': int}
        query_res_list = retrievers[0].forward(
            self.retriever_model,
            query_dataloader,
            orig_idxs=query_orig_idxs,
            process_bar=True,
            information="Embedding test data queries...",
        )

        # Allocate local budgets: local_budgets[cid][qid] is the number of
        # examples client ``cid`` may retrieve for query ``qid``.
        if budget_strategy == "random":
            local_budgets = random_budgets_allocation(
                query_num=query_num,
                total_ice_num=args.overall_local_ice_num,
                num_clients=num_clients,
                seed=args.seed,
            )
        else:  # "uniform" (validated above)
            local_budgets = uniform_budgets_allocation(
                query_num=query_num,
                total_ice_num=args.overall_local_ice_num,
                num_clients=num_clients,
            )

        # 2. Budgeted client retrieval + server concatenation:
        # server_ice_datasets[qid] is the ICE Dataset for query qid.
        server_ice_datasets = self._construct_server_ice_dataset_with_budgets(
            retrievers=retrievers,
            query_res_list=query_res_list,
            ice_budgets=local_budgets,
            concat=concat,
            output_json_filepath=output_json_filepath,
            args=args,
        )

        # 3. Collect labels from every client (a NonIID partition means no
        # single client is guaranteed to see all classes).
        if self.labels is None:
            labels = []
            for retriever in retrievers:
                labels.extend(
                    retriever.get_labels(
                        ice_template=ice_template, prompt_template=prompt_template
                    )
                )
            # De-duplicate while keeping first-seen order. ``list(set(...))``
            # has a hash-seed-dependent order, which made argmin tie-breaking
            # and output order irreproducible.
            labels = list(dict.fromkeys(labels))
        else:
            labels = self.labels

        # 4. Build a per-query server retriever and render each query's ICE.
        ice_dataset_retrievers = []
        server_sample_orig_idx = {}  # query orig idx -> selected example orig idxs
        input_columns = retrievers[0].dataset_reader.input_columns
        output_column = retrievers[0].dataset_reader.output_column
        for qid in trange(
            query_num,
            disable=not self.is_main_process,
            desc="Server side generating ICE: ",
        ):
            data = DatasetReader(
                DatasetDict(
                    {
                        "train": server_ice_datasets[qid],
                        "test": Dataset.from_list([query_dataset[qid]]),
                    }
                ),
                input_columns=input_columns,
                output_column=output_column,
            )
            per_query_retriever = get_server_retriever(args)(
                data, ice_num=args.server_ice_num
            )

            if concat == "reorder":
                # The server retriever selects/orders examples from the pool.
                ice_idxs = per_query_retriever.retrieve(
                    model=self.retriever_model, use_trange=False
                )[0]
            else:
                # Keep every concatenated example in its given order.
                ice_idxs = list(range(len(server_ice_datasets[qid])))

            q_orig_idx = per_query_retriever.test_ds[0]["idx"]
            server_sample_orig_idx[q_orig_idx] = per_query_retriever.index_ds.select(
                ice_idxs
            )["idx"]
            ice.append(
                per_query_retriever.generate_ice(ice_idxs, ice_template=ice_template)
            )
            ice_dataset_retrievers.append(per_query_retriever)

        output_handler.save_ice(ice)
        save_json(
            server_sample_orig_idx, f"{output_json_filepath}/server_ice_indices.json"
        )

        # 5. Score the prompts of every label.
        if normalizing_str is not None:
            # Loop invariants, computed once instead of per label / per prompt.
            normalizing_str_len = self.get_input_token_num(normalizing_str)
            if prompt_template is not None:
                sep_token = prompt_template.sep_token
            else:
                sep_token = ice_template.sep_token

        for label in labels:
            index = 0
            prompt_list = []
            sub_ppl_list = []
            normalizing_prompt_list = []
            context_length_list = []

            # 5.1 Generate the prompt of the current label for every query.
            # TODO: prompts are not truncated to `self.max_model_token_num`
            # (temporarily removed; it was meant for generation tasks).
            for qid in range(query_num):
                # Each per-query retriever holds exactly one test sample (index 0).
                prompt = ice_dataset_retrievers[qid].generate_label_prompt(
                    0,
                    ice[qid],
                    label,
                    ice_template=ice_template,
                    prompt_template=prompt_template,
                    remain_sep=normalizing_str is not None,
                )

                if normalizing_str is not None:
                    sep_pos = prompt.find(sep_token)
                    if sep_pos == -1:
                        # Without the separator the context/answer split is
                        # meaningless (find() would return -1 and the slices
                        # below would silently cut off the last character).
                        raise ValueError(
                            f"Separator token {sep_token!r} not found in prompt "
                            f"for query {qid}; it is required when "
                            "`normalizing_str` is set."
                        )
                    context = prompt[:sep_pos]
                    answer = prompt[sep_pos:].replace(sep_token, "")
                    prompt = context + answer
                    context_length_list.append(self.get_input_token_num(context))
                    normalizing_prompt_list.append(normalizing_str + answer)
                prompt_list.append(prompt)

            # 5.2 Compute scores batch by batch.
            logger.info(f"Calculating PPL for prompts labeled '{label}'")
            for start in trange(
                0, len(prompt_list), self.batch_size, disable=not self.is_main_process
            ):
                end = start + self.batch_size
                sub_prompt_list = prompt_list[start:end]

                with torch.no_grad():
                    if normalizing_str is not None:
                        res1 = self.__get_ppl(
                            input_texts=sub_prompt_list,
                            mask_length=context_length_list[start:end],
                        )
                        res2 = self.__get_ppl(
                            input_texts=normalizing_prompt_list[start:end],
                            mask_length=[normalizing_str_len] * len(sub_prompt_list),
                        )
                        # Convert to a plain list, matching the other branch.
                        sub_res = (res1 - res2).tolist()
                    else:
                        # list of scores, length is batch size
                        sub_res = self.__get_ppl(sub_prompt_list).tolist()

                for offset, (res, prompt) in enumerate(zip(sub_res, sub_prompt_list)):
                    sub_ppl_list.append(res)
                    output_handler.save_prompt_and_ppl(
                        label,
                        prompt[len(ice[start + offset]) :],
                        prompt,
                        res,
                        index,
                    )
                    index += 1
            ppl.append(sub_ppl_list)

        # 6. Predict the label with the lowest score for each query
        # (zip(*ppl) transposes [label][query] into per-query score tuples).
        sub_predictions = [
            labels[scores.index(min(scores))] for scores in zip(*ppl)
        ]
        output_handler.save_predictions(sub_predictions)

        # 7. Output
        output_handler.subprocess_write_to_json(
            output_json_filepath, output_json_filename
        )
        if self.accelerator is not None:
            self.accelerator.wait_for_everyone()
        output_handler.merge_to_main_process(output_json_filepath, output_json_filename)
        output_handler.write_to_json(output_json_filepath, output_json_filename)

        return [sample["prediction"] for sample in output_handler.results_dict.values()]

    @staticmethod
    def _get_concat_func(concat: Optional[str]):
        """Return the server-side concatenation function for ``concat``.

        'reorder' shares ``simple_concatenate``; the actual reordering is done
        afterwards by the per-query server retriever in :meth:`inference`.

        Raises:
            ValueError: If ``concat`` is not a supported mode.
        """
        if concat in ("simple", "reorder"):
            return simple_concatenate
        if concat == "random":
            return random_concatenate
        if concat == "merge":
            return merge_concatenate
        raise ValueError(
            f"concate should be either 'simple', 'merge', 'random', or 'reorder', rather than '{concat}'."
        )

    def _construct_server_ice_dataset_with_budgets(
        self,
        retrievers: List[BaseRetriever],
        query_res_list: List[Dict[str, Any]],
        ice_budgets: List[List[int]],
        concat: Optional[str] = "simple",
        output_json_filepath: Optional[str] = None,
        args=None,
    ):
        """Build one server-side ICE dataset per query from budgeted client retrieval.

        Args:
            retrievers (List[BaseRetriever]): One retriever per client.
            query_res_list (List[Dict[str, Any]]): Embedded queries; each entry
                must have an ``"idx"`` key with the query's original index.
            ice_budgets (List[List[int]]): ``ice_budgets[cid][qid]`` is client
                ``cid``'s budget for query ``qid``.
            concat (Optional[str], optional): Concatenation mode (see
                ``_get_concat_func``). Defaults to "simple".
            output_json_filepath (Optional[str], optional): Directory where
                ``local_ice_indices.json`` is written. Defaults to None.
            args (optional): Config providing ``server_ice_num`` and ``seed``;
                falls back to ``self.args``. Defaults to None.

        Returns:
            list: ``server_ice[qid]`` is the concatenated ICE data for query
            ``qid``, as returned by the concatenation function.

        Raises:
            ValueError: If ``concat`` is unsupported (raised before retrieval).
        """
        if args is None:
            args = self.args
        # Resolve first so an invalid mode fails before the costly retrieval.
        concat_func = self._get_concat_func(concat)

        num_clients = len(retrievers)
        query_num = len(query_res_list)

        ice_idx_dict = {}
        for cid, retriever in enumerate(retrievers):
            logger.info(f"Client {cid} retrieves local ICE samples...")
            ice_idx_dict[cid] = retriever.retrieve_with_budget(
                query_res_list, ice_budgets=ice_budgets[cid], model=self.retriever_model
            )

        samples_for_all_query = collect_samples(ice_idx_dict, retrievers)

        # Original training-set indices of each client's examples, per query.
        local_sample_orig_idx = {
            query_res_list[qid]["idx"]: {
                cid: samples_for_all_query[qid][cid]["idx"]
                for cid in range(num_clients)
            }
            for qid in range(query_num)
        }
        save_json(
            local_sample_orig_idx, f"{output_json_filepath}/local_ice_indices.json"
        )

        # For each query, merge all clients' local examples into the
        # server-side ICE dataset.
        server_ice = []
        for qid in trange(
            query_num,
            disable=not self.is_main_process,
            desc="Construct server side ICE dataset for each query: ",
        ):
            server_ice.append(
                concat_func(
                    samples_for_all_query[qid],
                    chunk_sample_num=args.server_ice_num,
                    seed=args.seed,
                )
            )

        return server_ice

    def __get_ppl(self, input_texts: List[str], mask_length=None):
        """Forward to ``PPLInferencer``'s private ``__get_ppl``.

        Double-underscore names are mangled per class, so ``self.__get_ppl``
        inside this class resolves to ``_PPLFedWeakBudgetInferencer__get_ppl``.
        This shim reaches the parent's mangled ``_PPLInferencer__get_ppl``.
        """
        return self._PPLInferencer__get_ppl(input_texts, mask_length)
