"""Basic Inferencer."""
import json
import os
from pathlib import Path
from typing import List, Optional

import numpy as np
from mmengine.dist import is_main_process
from torch.utils.data import DataLoader

from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever


class BaseInferencer:
    """Common base class shared by every evaluation inferencer.

    The constructor only records its configuration on the instance and makes
    sure the output directory exists; the actual inference logic has to be
    supplied by subclasses through :meth:`inference`.

    Attributes:
        model (:obj:`BaseModel`, optional): The model used for inference.
        max_seq_len (:obj:`int`, optional): Maximum number of tokens allowed
            by the LM. Stored as given; this base class does not read it.
        batch_size (:obj:`int`, optional): Batch size for the
            :obj:`DataLoader`.
        output_json_filepath (:obj:`str`, optional): Directory for the
            output `JSON` file. It is created on construction if missing.
        output_json_filename (:obj:`str`, optional): File name for the
            output `JSON` file.
        is_main_process (:obj:`bool`): Value returned by
            ``mmengine.dist.is_main_process()`` at construction time.
    """
    model = None

    def __init__(
        self,
        model,
        max_seq_len: Optional[int] = None,
        batch_size: Optional[int] = 1,
        output_json_filepath: Optional[str] = './icl_inference_output',
        output_json_filename: Optional[str] = 'predictions',
        fix_id_list: Optional[List[int]] = None,
        **kwargs,
    ) -> None:
        """Store the inferencer configuration and create the output dir.

        Args:
            model: The model used for inference.
            max_seq_len (:obj:`int`, optional): Maximum sequence length.
                Defaults to None.
            batch_size (:obj:`int`, optional): Batch size. Defaults to 1.
            output_json_filepath (:obj:`str`, optional): Output directory.
                Defaults to ``'./icl_inference_output'``.
            output_json_filename (:obj:`str`, optional): Output file name.
                Defaults to ``'predictions'``.
            fix_id_list (:obj:`List[int]`, optional): Deprecated. Any truthy
                value is rejected. Defaults to None.
            **kwargs: Accepted and ignored by this base class.

        Raises:
            ValueError: If a non-empty ``fix_id_list`` is passed.
        """
        # The option moved to FixKRetriever; refuse it loudly rather than
        # ignoring it silently.
        if fix_id_list:
            raise ValueError('Passing fix_id_list to Inferencer is no longer '
                             'allowed. Please pass it to FixKRetriever '
                             'instead.')

        # Model and sequence/batch settings.
        self.model = model
        self.max_seq_len = max_seq_len
        self.batch_size = batch_size

        # Output location.
        self.output_json_filepath = output_json_filepath
        self.output_json_filename = output_json_filename

        self.is_main_process = is_main_process()

        # Every process attempts the creation; exist_ok makes that harmless.
        os.makedirs(self.output_json_filepath, exist_ok=True)

    def inference(self,
                  retriever: BaseRetriever,
                  ice_template: Optional[PromptTemplate] = None,
                  prompt_template: Optional[PromptTemplate] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None) -> List:
        """Run in-context inference with a retriever and optional templates.

        This base implementation is a placeholder and always raises;
        subclasses are expected to override it.

        Args:
            retriever (:obj:`BaseRetriever`): Retriever instance used to
                fetch the in-context examples.
            ice_template (:obj:`PromptTemplate`, optional): Template for
                building the in-context examples prompt. Defaults to None.
            prompt_template (:obj:`PromptTemplate`, optional): Template for
                building the final prompt. Defaults to None.
            output_json_filepath (:obj:`str`, optional): Directory in which
                to save the results as a `JSON` file. Defaults to None.
            output_json_filename (:obj:`str`, optional): File name under
                which to save the results as a `JSON` file. Defaults to
                None.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.

        Returns:
            :obj:`List`: A list of strings, each one the result of a single
                inference.
        """
        raise NotImplementedError("Method hasn't been implemented yet")

    @staticmethod
    def get_dataloader(datalist: List[List], batch_size: int) -> DataLoader:
        """Wrap ``datalist`` in a :obj:`DataLoader`.

        The collate function is the identity, so every batch is handed back
        as the plain list of samples instead of being stacked.

        Args:
            datalist (:obj:`List[List]`): The data to iterate over.
            batch_size (:obj:`int`): Number of samples per batch.

        Returns:
            :obj:`DataLoader`: A loader over ``datalist``.
        """
        return DataLoader(datalist,
                          batch_size=batch_size,
                          collate_fn=lambda batch: batch)


def dump_results_dict(results_dict, filename):
    """Serialize ``results_dict`` as indented UTF-8 JSON into ``filename``.

    Non-ASCII characters are written verbatim rather than as ``\\uXXXX``
    escapes. An existing file is overwritten.

    Args:
        results_dict: JSON-serializable object to write.
        filename: Target path, anything accepted by :func:`open`.
    """
    dump_options = {'indent': 4, 'ensure_ascii': False}
    with open(filename, mode='w', encoding='utf-8') as sink:
        json.dump(results_dict, sink, **dump_options)


class GenInferencerOutputHandler:
    """Collect generation-style inference results and dump them as JSON.

    Results are accumulated in ``results_dict``, keyed by the stringified
    sample index.

    Attributes:
        origin_prompt_dict (:obj:`dict`): Per-instance dict; not written by
            this class.
        output_dict (:obj:`dict`): Per-instance dict; not written by this
            class.
        prediction_dict (:obj:`dict`): Per-instance dict; not written by this
            class.
        results_dict (:obj:`dict`): Maps ``str(idx)`` to the saved record.
    """
    # Class-level defaults are kept so attribute access on the class itself
    # keeps working; every instance rebinds its own copies in ``__init__``.
    origin_prompt_dict = {}
    output_dict = {}
    prediction_dict = {}
    results_dict = {}

    def __init__(self) -> None:
        # Rebind every container per instance: mutating a class-level dict in
        # place would otherwise leak entries across all handlers.
        self.origin_prompt_dict = {}
        self.output_dict = {}
        self.prediction_dict = {}
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the collected results to ``save_dir / filename`` as JSON.

        Args:
            save_dir (:obj:`str`): Directory to write into.
            filename (:obj:`str`): Name of the JSON file.
        """
        dump_results_dict(self.results_dict, Path(save_dir) / filename)

    def save_results(self, origin_prompt, prediction, idx, gold=None):
        """Record the prompt and prediction (and gold, if given) of a sample.

        Any record previously stored under the same index is replaced.

        Args:
            origin_prompt: The prompt that was fed to the model.
            prediction: The model output for that prompt.
            idx: Sample index; stored under ``str(idx)``.
            gold: Optional reference answer. It is stored whenever it is not
                ``None``, so falsy but valid values such as ``0``, ``''`` or
                ``False`` are kept. Defaults to None.
        """
        record = {
            'origin_prompt': origin_prompt,
            'prediction': prediction,
        }
        # Compare against None explicitly: a plain truthiness test would
        # silently drop legitimate gold values like 0 or an empty string.
        if gold is not None:
            record['gold'] = gold
        self.results_dict[str(idx)] = record


class PPLInferencerOutputHandler:
    """Collect perplexity-style inference results and dump them as JSON.

    Every record lives in ``results_dict`` under the stringified sample
    index and is created lazily the first time that index is written.
    """
    results_dict = {}

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the collected results to ``save_dir / filename`` as JSON."""
        target = Path(save_dir) / filename
        dump_results_dict(self.results_dict, target)

    def save_ice(self, ice):
        """Store in-context examples, one per sample.

        Args:
            ice: Iterable of examples; the position in the iterable is used
                as the sample index.
        """
        for position, example in enumerate(ice):
            record = self.results_dict.setdefault(str(position), {})
            record['in-context examples'] = example

    def save_predictions(self, predictions):
        """Store predictions, one per sample.

        Args:
            predictions: Iterable of predictions; the position in the
                iterable is used as the sample index.
        """
        for position, prediction in enumerate(predictions):
            record = self.results_dict.setdefault(str(position), {})
            record['prediction'] = prediction

    def save_prompt_and_ppl(self, label, input, prompt, ppl, idx):
        """Store the input, prompt and PPL of one candidate label.

        The values go into a sub-record named ``'label: <label>'`` inside
        the record of sample ``idx``.

        Args:
            label: Candidate label; stringified to build the sub-record key.
            input: The testing input.
            prompt: The prompt associated with this label.
            ppl: The perplexity value to store.
            idx: Sample index; stored under ``str(idx)``.
        """
        record = self.results_dict.setdefault(str(idx), {})
        per_label = record.setdefault('label: ' + str(label), {})
        per_label['testing input'] = input
        per_label['prompt'] = prompt
        per_label['PPL'] = ppl

    def save_golds(self, golds):
        """Store gold answers, one per sample.

        Args:
            golds: Iterable of gold answers; the position in the iterable is
                used as the sample index.
        """
        for position, gold in enumerate(golds):
            record = self.results_dict.setdefault(str(position), {})
            record['gold'] = gold


class CLPInferencerOutputHandler:
    """Collect conditional-probability inference results and dump as JSON.

    Every record lives in ``results_dict`` under the stringified sample
    index and is created lazily the first time that index is written.
    """
    results_dict = {}

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the collected results to ``save_dir / filename`` as JSON."""
        target = Path(save_dir) / filename
        dump_results_dict(self.results_dict, target)

    def save_ice(self, ice):
        """Store in-context examples, one per sample.

        Args:
            ice: Iterable of examples; the position in the iterable is used
                as the sample index.
        """
        for position, example in enumerate(ice):
            record = self.results_dict.setdefault(str(position), {})
            record['in-context examples'] = example

    def save_prompt_and_condprob(self,
                                 input,
                                 prompt,
                                 cond_prob,
                                 idx,
                                 choices,
                                 gold=None):
        """Store the prompt, choice scores and predicted label of a sample.

        Args:
            input: The testing input.
            prompt: The prompt fed to the model.
            cond_prob: Scores over ``choices``; stored as the prediction and
                passed to ``np.argmax`` to derive the predicted label.
            idx: Sample index; stored under ``str(idx)``.
            choices: The candidate choices.
            gold: Optional reference answer. It is always written, even when
                it is ``None``. Defaults to None.
        """
        record = self.results_dict.setdefault(str(idx), {})
        # TODO:
        # for single token situation, the input will always be yes currently
        record['testing input'] = input
        record['prompt'] = prompt
        # TODO: hard code here
        record['choices'] = choices
        # The raw scores are kept as the prediction so that AUC can be
        # computed from them later.
        record['prediction'] = cond_prob
        # Also keep the arg-max label in case a hard prediction is needed.
        record['pred_label'] = int(np.argmax(cond_prob))
        record['gold'] = gold
