"""PPL FederatedInferencer"""

import copy
import json
import os
from collections import Counter
from typing import List, Union, Optional, Dict, Any

import tqdm
from tqdm import trange


import torch
from torch.utils.data import DataLoader

from datasets import concatenate_datasets, load_dataset, DatasetDict, Dataset
from transformers import PretrainedConfig
from transformers import AutoTokenizer

from sentence_transformers import SentenceTransformer
from accelerate import Accelerator

from openicl.icl_dataset_reader import DatasetEncoder
from openicl.icl_retriever import BaseRetriever, TopkRetriever
from openicl.utils.collators import DataCollatorWithPaddingAndCuda
from openicl import PromptTemplate, PPLInferencer, DatasetReader
from openicl.icl_retriever import *
from openicl.icl_evaluator import *
from openicl.icl_inferencer.icl_base_inferencer import PPLInferencerOutputHandler

from openicl.utils.logging import get_logger
from openicl.utils.api_service import *

from util.concatenate import (
    collect_samples,
    simple_concatenate,
    merge_concatenate,
    random_concatenate,
)
from util.retriever import get_retriever, get_server_retriever
from util.misc import save_json, AverageMeter, setup_seed
from util.dataset import QueryBudgetDataset
from models import get_model


logger = get_logger(__name__)


class PPLFedInferencer(PPLInferencer):
    """Perplexity (PPL) based in-context-learning inferencer for a federated setup.

    Every client owns a retriever over its local training data. For each test
    query the samples retrieved by all clients are concatenated into a
    per-query server-side ICE (in-context example) dataset, optionally
    re-ranked on the server, and the label whose prompt yields the lowest PPL
    is returned as the prediction.
    """

    def __init__(
        self,
        model_name: Optional[str] = "gpt2-xl",
        tokenizer_name: Optional[str] = None,
        max_model_token_num: Optional[int] = None,
        model_config: Optional[PretrainedConfig] = None,
        batch_size: Optional[int] = 1,
        accelerator: Optional[Accelerator] = None,
        output_json_filepath: Optional[str] = "./server_icl_inference_output",
        output_json_filename: Optional[str] = "predictions",
        api_name: Optional[str] = None,
        labels: Optional[List] = None,
        model_parallel: Optional[bool] = False,
        args=None,
        sentence_transformers_model_name: Optional[str] = "all-mpnet-base-v2",
        device: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Initialise the base ``PPLInferencer`` and the optional retriever encoder.

        Args:
            model_name .. model_parallel: Forwarded positionally, unchanged, to
                ``PPLInferencer.__init__``.
            args: Run configuration object. This class reads ``args.retriever``,
                ``args.server_ice_num`` and ``args.seed`` from it.
            sentence_transformers_model_name: Name of the ``SentenceTransformer``
                loaded when ``args.retriever == "topk"``.
            device: Device for the retriever encoder. Defaults to ``"cuda"``
                when available, otherwise ``"cpu"``.
            **kwargs: Extra keyword arguments forwarded to the base class.
        """
        super().__init__(
            model_name,
            tokenizer_name,
            max_model_token_num,
            model_config,
            batch_size,
            accelerator,
            output_json_filepath,
            output_json_filename,
            api_name,
            labels,
            model_parallel,
            **kwargs,
        )
        self.args = args
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        # ``args`` defaults to None, so read the retriever type defensively
        # instead of crashing with an AttributeError.
        if getattr(self.args, "retriever", None) == "topk":
            self.retriever_model = SentenceTransformer(sentence_transformers_model_name)
            self.retriever_model = self.retriever_model.to(self.device)
            # Report the actual device; it is not necessarily CUDA.
            logger.info(f"Retriever model moved to {self.device}")
            self.retriever_model.eval()
        else:
            self.retriever_model = None

    def _construct_server_ice_dataset(
        self,
        retrievers: List[BaseRetriever],
        query_dataset: Dataset,
        ice_num: Optional[int] = None,
        concat: Optional[str] = "simple",
        output_json_filepath: Optional[str] = None,
        prefix: Optional[str] = None,
    ) -> List[Dataset]:
        """Build the per-query server-side ICE datasets from the clients' retrievals.

        Each local retriever retrieves samples for every query, the selected
        samples are collected per query, their original ``"idx"`` values are
        saved to ``<prefix>_local_ice_indices.json`` and the per-client samples
        are concatenated into one server-side ICE dataset per query.

        Args:
            retrievers (List[BaseRetriever]): One local retriever per client.
            query_dataset (Dataset): Dataset containing the test queries; every
                row must expose an ``"idx"`` field.
            ice_num (Optional[int], optional): Forwarded as ``ice_num`` to each
                client's ``retrieve`` call. Defaults to None.
            concat (Optional[str], optional): Concatenation strategy:
                ``'simple'`` and ``'reorder'`` use ``simple_concatenate``,
                ``'random'`` uses ``random_concatenate`` and ``'merge'`` uses
                ``merge_concatenate``. Defaults to 'simple'.
            output_json_filepath (Optional[str], optional): Directory for the
                index file. Falls back to ``self.output_json_filepath`` when
                None. Defaults to None.
            prefix (Optional[str], optional): Optional prefix (joined with
                ``'_'``) for the index file name. Defaults to None.

        Raises:
            ValueError: If ``concat`` is not one of the supported strategies.

        Returns:
            List[Dataset]: ``[q1_ice_Dataset, q2_ice_Dataset, q3_ice_Dataset, ...]``
        """
        # Validate the strategy first so an invalid value fails before the
        # (potentially expensive) client-side retrieval is run.
        concat_funcs = {
            "simple": simple_concatenate,
            "reorder": simple_concatenate,
            "random": random_concatenate,
            "merge": merge_concatenate,
        }
        concat_func = concat_funcs.get(concat)
        if concat_func is None:
            raise ValueError(
                f"concate should be either 'simple', 'merge', 'random', or 'reorder', rather than '{concat}'."
            )

        # os.path.join(None, ...) would raise a TypeError; use the instance default.
        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath

        num_clients = len(retrievers)
        query_num = len(query_dataset)

        ice_idx_dict = {}
        for cid in range(num_clients):
            print(f"Client {cid} retrieves local ICE samples...")
            ice_idx_dict[cid] = retrievers[cid].retrieve(
                query_dataset, ice_num=ice_num, model=self.retriever_model
            )

        # Indexed below as [query position][client id]
        # (presumably one group of retrieved samples per client -- TODO confirm
        # against util.concatenate.collect_samples).
        samples_for_all_query = collect_samples(ice_idx_dict, retrievers)

        # Map: query's original idx -> {client id -> "idx" values of its samples}
        local_sample_orig_idx = {}
        for qid in range(query_num):
            q_orig_idx = query_dataset[qid]["idx"]
            local_sample_orig_idx[q_orig_idx] = {
                cid: samples_for_all_query[qid][cid]["idx"]
                for cid in range(num_clients)
            }

        file_prefix = f"{prefix}_" if prefix is not None else ""
        local_ice_index_file = os.path.join(
            output_json_filepath, f"{file_prefix}local_ice_indices.json"
        )
        save_json(local_sample_orig_idx, local_ice_index_file)

        # For each test query, concatenate all clients' local ICE samples into
        # the server-side ICE dataset.
        server_ice = []
        for qid in trange(
            query_num,
            disable=False,
            desc="Construct server side ICE dataset for each query: ",
        ):
            server_ice.append(
                concat_func(
                    samples_for_all_query[qid],
                    chunk_sample_num=self.args.server_ice_num,
                    seed=self.args.seed,
                )
            )

        return server_ice

    def inference(
        self,
        retrievers: List[BaseRetriever],
        query_dataset: Union[Dataset, DatasetDict],
        query_split: Optional[str] = None,
        ice_template: Optional[PromptTemplate] = None,
        prompt_template: Optional[PromptTemplate] = None,
        output_json_filepath: Optional[str] = None,
        output_json_filename: Optional[str] = None,
        concat: Optional[str] = "simple",
        normalizing_str: Optional[str] = None,
        args=None,
    ) -> List:
        """Run federated PPL inference and return one prediction per query.

        Args:
            retrievers (List[BaseRetriever]): One local retriever per client.
            query_dataset (Union[Dataset, DatasetDict]): Test queries. When a
                ``DatasetDict`` is given, ``query_split`` selects the split.
            query_split (Optional[str], optional): Split name used for a
                ``DatasetDict`` input. Defaults to None.
            ice_template (Optional[PromptTemplate], optional): Template used to
                render in-context examples. Defaults to None.
            prompt_template (Optional[PromptTemplate], optional): Template used
                to render the label prompts. Defaults to None.
            output_json_filepath (Optional[str], optional): Output directory;
                falls back to ``self.output_json_filepath``. Defaults to None.
            output_json_filename (Optional[str], optional): Output file name;
                falls back to ``self.output_json_filename``. Defaults to None.
            concat (Optional[str], optional): Concatenation strategy passed to
                ``_construct_server_ice_dataset``; ``'reorder'`` additionally
                re-ranks the samples with the server retriever. Defaults to
                'simple'.
            normalizing_str (Optional[str], optional): When given, the PPL of
                ``normalizing_str + answer`` is subtracted from the PPL of the
                full prompt. Defaults to None.
            args: Run configuration used to build the server retriever; falls
                back to ``self.args`` when None.

        Returns:
            List: The ``"prediction"`` entry of every result stored in the
            output handler.
        """
        # ``args`` defaults to None but is dereferenced below; use the
        # configuration stored on the instance in that case.
        if args is None:
            args = self.args

        if isinstance(query_dataset, DatasetDict):
            query_dataset = query_dataset[query_split]

        query_num = len(query_dataset)

        # 1. Preparation for output logs
        output_handler = PPLInferencerOutputHandler(self.accelerator)

        sub_predictions = []
        ppl = []
        ice = []

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # 2. Get results of retrieval process
        server_ice_datasets = self._construct_server_ice_dataset(
            retrievers=retrievers,
            query_dataset=query_dataset,
            concat=concat,
            output_json_filepath=output_json_filepath,
        )  # [q1_ice_Dataset, q2_ice_Dataset, ...]

        # 3. Get labels of all the classes (need to check all clients' local
        #    datasets since the partition may be non-IID)
        if self.labels is None:
            labels = []
            for retriever in retrievers:
                labels.extend(
                    retriever.get_labels(
                        ice_template=ice_template, prompt_template=prompt_template
                    )
                )
            # De-duplicate while keeping first-seen order, so the label order
            # (and therefore PPL tie-breaking) does not depend on set ordering.
            labels = list(dict.fromkeys(labels))
        else:
            labels = self.labels

        # 4. Generate in-context examples indices for testing inputs
        ice_dataset_retrievers = []
        server_sample_orig_idx = {}
        for qid in trange(
            query_num,
            disable=not self.is_main_process,
            desc="Server side generating ICE: ",
        ):
            data_dict = DatasetDict(
                {
                    "train": server_ice_datasets[qid],
                    "test": Dataset.from_list([query_dataset[qid]]),
                }
            )
            data = DatasetReader(
                data_dict,
                input_columns=retrievers[0].dataset_reader.input_columns,
                output_column=retrievers[0].dataset_reader.output_column,
            )
            per_query_retriever = get_server_retriever(args)(
                data, ice_num=args.server_ice_num
            )

            # Only 'reorder' re-ranks on the server side; every other strategy
            # keeps the concatenation order as is.
            if concat == "reorder":
                ice_idxs = per_query_retriever.retrieve(
                    model=self.retriever_model, use_trange=False
                )[0]
            else:
                ice_idxs = list(range(len(server_ice_datasets[qid])))

            q_orig_idx = per_query_retriever.test_ds[0]["idx"]
            server_sample_orig_idx[q_orig_idx] = per_query_retriever.index_ds.select(
                ice_idxs
            )["idx"]
            # use selected samples in each query's ICE dataset to generate ICE
            ice.append(
                per_query_retriever.generate_ice(ice_idxs, ice_template=ice_template)
            )
            ice_dataset_retrievers.append(per_query_retriever)

        output_handler.save_ice(ice)
        server_ice_index_file = os.path.join(
            output_json_filepath, "server_ice_indices.json"
        )
        save_json(server_sample_orig_idx, server_ice_index_file)

        # Values that do not depend on the label or the query are computed once.
        use_normalizing = normalizing_str is not None
        if use_normalizing:
            normalizing_str_len = self.get_input_token_num(normalizing_str)
            if prompt_template is not None:
                sep_token = prompt_template.sep_token
            else:
                sep_token = ice_template.sep_token

        # 5. Calculating PPL for prompts in each label's class
        for label in labels:
            prompt_list = []
            sub_ppl_list = []
            normalizing_prompt_list = []
            context_length_list = []

            # 5.1 Generate prompts of current label
            for qid in range(query_num):
                # Every per-query retriever holds exactly one test sample,
                # hence the fixed test index 0.
                prompt = ice_dataset_retrievers[qid].generate_label_prompt(
                    0,
                    ice[qid],
                    label,
                    ice_template=ice_template,
                    prompt_template=prompt_template,
                    remain_sep=use_normalizing,
                )

                # TODO: truncation against ``max_model_token_num`` (used for
                # generation tasks) is temporarily disabled here.

                if use_normalizing:
                    # Split the prompt at the separator into context and answer,
                    # then drop the separator from the text that is scored.
                    sep_pos = prompt.find(sep_token)
                    context = prompt[0:sep_pos]
                    answer = prompt[sep_pos:].replace(sep_token, "")
                    prompt = context + answer

                    context_length_list.append(self.get_input_token_num(context))
                    normalizing_prompt_list.append(normalizing_str + answer)
                prompt_list.append(prompt)

            # 5.2 Get PPL: loop through all queries with the current label
            logger.info(f"Calculating PPL for prompts labeled '{label}'")
            for start in trange(
                0, len(prompt_list), self.batch_size, disable=not self.is_main_process
            ):
                stop = start + self.batch_size
                sub_prompt_list = prompt_list[start:stop]

                with torch.no_grad():
                    if use_normalizing:
                        res1 = self.__get_ppl(
                            input_texts=sub_prompt_list,
                            mask_length=context_length_list[start:stop],
                        )
                        res2 = self.__get_ppl(
                            input_texts=normalizing_prompt_list[start:stop],
                            mask_length=[normalizing_str_len] * len(sub_prompt_list),
                        )
                        # Convert to a plain list like the other branch does
                        # (assumes the PPL result supports ``-`` and
                        # ``.tolist()``, e.g. a numpy array -- TODO confirm).
                        sub_res = (res1 - res2).tolist()
                    else:
                        # list of scores, length is batch size
                        sub_res = self.__get_ppl(sub_prompt_list).tolist()

                for batch_offset, (res, prompt) in enumerate(
                    zip(sub_res, sub_prompt_list)
                ):
                    # Position of this prompt among all queries of this label.
                    query_pos = start + batch_offset
                    sub_ppl_list.append(res)
                    output_handler.save_prompt_and_ppl(
                        label,
                        prompt[len(ice[query_pos]) :],
                        prompt,
                        res,
                        query_pos,
                    )
            ppl.append(sub_ppl_list)

        # 6. Get lowest PPL class as predictions
        #    (transpose: one tuple of per-label PPLs for each query)
        for single_ppl in zip(*ppl):
            sub_predictions.append(labels[single_ppl.index(min(single_ppl))])
        output_handler.save_predictions(sub_predictions)

        # 7. Output
        output_handler.subprocess_write_to_json(
            output_json_filepath, output_json_filename
        )
        if self.accelerator is not None:
            self.accelerator.wait_for_everyone()
        output_handler.merge_to_main_process(output_json_filepath, output_json_filename)
        output_handler.write_to_json(output_json_filepath, output_json_filename)

        return [sample["prediction"] for sample in output_handler.results_dict.values()]

    def __get_ppl(self, input_texts: List[str], mask_length=None):
        """Delegate to the base class' private ``__get_ppl``.

        Name mangling turns ``self.__get_ppl`` inside this class into
        ``_PPLFedInferencer__get_ppl``, so the parent's implementation has to
        be reached through its own mangled name.
        """
        return self._PPLInferencer__get_ppl(input_texts, mask_length)
