import itertools
import random
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from datasets import load_dataset, get_dataset_split_names
from datasets import Dataset as DS
import evaluate
from torch.utils.data import Dataset
import numpy as np

class HuggingFaceClassificationDataset(Dataset):
    """Base class exposing a Hugging Face dataset as a text-classification task.

    Subclasses set the class attributes below and implement ``__getitem__``
    so that it returns ``(input_text, [label_text])``.
    """

    # Text placed in front of the input when a prompt is assembled.
    x_introduction: str
    # Text placed in front of the label when a prompt is assembled.
    y_introduction: str
    # Human-readable name of the task.
    task_name: str
    # Identifier handed to ``load_dataset`` / ``get_dataset_split_names``.
    dataset_identifier: str

    def __init__(self, split: str, **kwargs):
        """Load ``split`` of ``dataset_identifier``.

        Args:
            split: Name of the split to load.
            **kwargs: Forwarded unchanged to ``load_dataset``.

        Raises:
            TypeError: If ``load_dataset`` does not return a map-style
                ``datasets.Dataset``; the rest of the class relies on
                ``num_rows``, ``features`` and integer indexing.
        """
        loaded = load_dataset(self.dataset_identifier, split=split, **kwargs)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not isinstance(loaded, DS):
            raise TypeError(
                f"Expected a datasets.Dataset for split {split!r}, "
                f"got {type(loaded).__name__}"
            )
        self.ds: DS = loaded

    @property
    def splits(self):
        """Names of the splits available for ``dataset_identifier``."""
        return get_dataset_split_names(self.dataset_identifier)

    def __len__(self):
        return self.ds.num_rows

    def evaluate(self, y_pred: List[str], y_true: List[List[str]]) -> Dict:
        """Compute exact-match accuracy after stripping surrounding whitespace.

        Args:
            y_pred: One predicted label per sample.
            y_true: One list of reference labels per sample; only the first
                entry of each list is compared.

        Returns:
            ``{"accuracy": value}``; the value is NaN when no samples are given.

        Raises:
            ValueError: If ``y_pred`` and ``y_true`` differ in length. Pairing
                them with ``zip`` would otherwise drop the surplus samples
                silently and report a misleading score.
        """
        if len(y_pred) != len(y_true):
            raise ValueError(
                f"y_pred has {len(y_pred)} items but y_true has {len(y_true)}"
            )
        if not y_pred:
            # Mean of nothing is undefined; return NaN without numpy's warning.
            return {'accuracy': float("nan")}

        true: List[str] = [y[0].strip() for y in y_true]
        pred: List[str] = [y.strip() for y in y_pred]

        return {'accuracy': float(np.mean([t == p for t, p in zip(true, pred)]))}


class Banking(HuggingFaceClassificationDataset):
    """Banking77 intent dataset with its labels collapsed into coarse categories."""

    task_name = "Banking"
    dataset_identifier = "PolyAI/banking77"
    x_introduction = "QUESTION:"
    y_introduction = "CATEGORY:"

    # Fine-grained label (underscores already replaced by spaces) -> coarse category.
    # NOTE(review): the capital "R" in "Refund not showing up" and the "?" in
    # "reverted card payment?" presumably mirror the upstream label names;
    # check the dataset before normalizing them, since a mismatch is a KeyError.
    reduced_labels = {
        "activate my card": "card",
        "age limit": "account",
        "apple pay or google pay": "apple pay or google pay",
        "atm support": "withdrawal",
        "automatic top up": "top up",
        "balance not updated after bank transfer": "transfer",
        "balance not updated after cheque or cash deposit": "top up",
        "beneficiary not allowed": "transfer",
        "cancel transfer": "transfer",
        "card about to expire": "card",
        "card acceptance": "card",
        "card arrival": "card",
        "card delivery estimate": "card",
        "card linking": "card",
        "card not working": "card",
        "card payment fee charged": "card",
        "card payment not recognised": "card",
        "card payment wrong exchange rate": "card",
        "card swallowed": "card",
        "cash withdrawal charge": "withdrawal",
        "cash withdrawal not recognised": "withdrawal",
        "change pin": "card",
        "compromised card": "card",
        "contactless not working": "card",
        "country support": "account",
        "declined card payment": "card",
        "declined cash withdrawal": "withdrawal",
        "declined transfer": "transfer",
        "direct debit payment not recognised": "transfer",
        "disposable card limits": "card",
        "edit personal details": "account",
        "exchange charge": "exchange",
        "exchange rate": "exchange",
        "exchange via app": "exchange",
        "extra charge on statement": "account",
        "failed transfer": "transfer",
        "fiat currency support": "withdrawal",
        "get disposable virtual card": "card",
        "get physical card": "card",
        "getting spare card": "card",
        "getting virtual card": "card",
        "lost or stolen card": "card",
        "lost or stolen phone": "account",
        "order physical card": "card",
        "passcode forgotten": "card",
        "pending card payment": "card",
        "pending cash withdrawal": "withdrawal",
        "pending top up": "top up",
        "pending transfer": "transfer",
        "pin blocked": "card",
        "receiving money": "transfer",
        "Refund not showing up": "refund",
        "request refund": "refund",
        "reverted card payment?": "card",
        "supported cards and currencies": "card",
        "terminate account": "account",
        "top up by bank transfer charge": "top up",
        "top up by card charge": "top up",
        "top up by cash or cheque": "top up",
        "top up failed": "top up",
        "top up limits": "top up",
        "top up reverted": "top up",
        "topping up by card": "top up",
        "transaction charged twice": "transfer",
        "transfer fee charged": "transfer",
        "transfer into account": "transfer",
        "transfer not received by recipient": "transfer",
        "transfer timing": "transfer",
        "unable to verify identity": "account",
        "verify my identity": "account",
        "verify source of funds": "account",
        "verify top up": "top up",
        "virtual card not working": "card",
        "visa or mastercard": "card",
        "why verify identity": "account",
        "wrong amount of cash received": "withdrawal",
        "wrong exchange rate for cash withdrawal": "withdrawal",
    }

    def __init__(self, split: str):
        """Load ``split`` and precompute the coarse label for every class index."""
        super().__init__(split=split)
        fine_grained = (
            name.replace("_", " ")
            for name in self.ds.features["label"].names  # type: ignore
        )
        # Position i holds the coarse category of class index i.
        self.labels = [self.reduced_labels[name] for name in fine_grained]

    def __getitem__(self, idx: int) -> Tuple[str, List[str]]:
        """Return ``(text, [coarse_label])`` for the row at ``idx``."""
        row = self.ds[idx]
        text, class_index = row["text"], row["label"]
        return text, [self.labels[class_index]]


class Emotion(HuggingFaceClassificationDataset):
    """Emotion classification task backed by ``dair-ai/emotion``."""

    task_name = "Emotion"
    dataset_identifier = "dair-ai/emotion"
    x_introduction = "TEXT:"
    y_introduction = "EMOTION:"

    def __init__(self, split: str):
        """Load ``split`` and cache the label names indexed by class id."""
        # name="split" is forwarded to load_dataset; presumably it selects the
        # dataset configuration of that name -- TODO confirm.
        super().__init__(split=split, name="split")
        raw_names = self.ds.features["label"].names  # type: ignore
        self.labels = [name.replace("_", " ") for name in raw_names]

    def __getitem__(self, idx: int) -> Tuple[str, List[str]]:
        """Return ``(text, [label_name])`` for the row at ``idx``."""
        row = self.ds[idx]
        text, class_index = row["text"], row["label"]
        return text, [self.labels[class_index]]
    
class EmotionChangedLabels(Emotion):
    """``Emotion`` variant whose labels are swapped pairwise through ``label_map``."""

    # Original label -> label reported instead (each pair is swapped both ways).
    label_map = {
        "sadness": "joy",
        "joy": "sadness",
        "anger": "love",
        "love": "anger",
        "fear": "surprise",
        "surprise": "fear",
    }

    def __getitem__(self, idx: int) -> Tuple[str, List[str]]:
        """Return the parent's sample with every label replaced via ``label_map``."""
        text, original_labels = super().__getitem__(idx)
        remapped = list(map(self.label_map.__getitem__, original_labels))
        return text, remapped


class YahooAnswers(HuggingFaceClassificationDataset):
    """Topic classification task backed by ``yahoo_answers_topics``."""

    task_name = "YahooAnswers"
    dataset_identifier = "yahoo_answers_topics"
    x_introduction = "QUESTION:"
    y_introduction = "CATEGORY:"

    def __init__(self, split: str):
        """Load ``split`` and cache the topic names indexed by class id."""
        super().__init__(split=split)
        raw_names = self.ds.features["topic"].names  # type: ignore
        self.labels = [name.replace("_", " ") for name in raw_names]

    def __getitem__(self, idx: int) -> Tuple[str, List[str]]:
        """Return ``(title + " " + content, [topic_name])`` for the row at ``idx``."""
        row = self.ds[idx]
        # The question title and body are joined by a single space.
        question = " ".join([row["question_title"], row["question_content"]])
        topic_index = row["topic"]
        return question, [self.labels[topic_index]]


class WinoGrande(HuggingFaceClassificationDataset):
    """WinoGrande (``winogrande_debiased``) posed as a two-option classification task."""

    task_name = "WinoGrande"
    dataset_identifier = "winogrande"
    x_introduction = ""
    y_introduction = ""

    def __init__(self, split: str):
        """Load ``split``; a request for ``"test"`` is served from ``"validation"``."""
        # The test split of this dataset has no labels, so validation stands in for it.
        effective_split = "validation" if split == "test" else split
        super().__init__(split=effective_split, name="winogrande_debiased")

    def __getitem__(self, idx: int) -> Tuple[str, List[str]]:
        """Return ``(prompt, [correct_option])`` for the row at ``idx``."""
        row = self.ds[idx]
        sentence, first, second = row["sentence"], row["option1"], row["option2"]
        prompt = f"{sentence} What is missing: {first} or {second}?"
        # An answer of 1 selects option1; any other value selects option2.
        if int(row["answer"]) == 1:
            correct = first
        else:
            correct = second
        return prompt, [correct]


class HuggingFaceTableToTextDataset(Dataset):
    """Base class exposing a Hugging Face dataset as a table-to-text task.

    Subclasses set the class attributes below and implement ``__getitem__``
    so that it returns ``(input_text, [reference_text, ...])``.
    """

    # Text placed in front of the input when a prompt is assembled.
    x_introduction: str
    # Text placed in front of the target when a prompt is assembled.
    y_introduction: str
    # Human-readable name of the task.
    task_name: str
    # Identifier handed to ``load_dataset`` / ``get_dataset_split_names``.
    dataset_identifier: str

    def __init__(self, split: str, **kwargs):
        """Load ``split`` of ``dataset_identifier`` and the generation metrics.

        Args:
            split: Name of the split to load.
            **kwargs: Forwarded unchanged to ``load_dataset``.

        Raises:
            TypeError: If ``load_dataset`` does not return a map-style
                ``datasets.Dataset``.
        """
        loaded = load_dataset(self.dataset_identifier, split=split, **kwargs)
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not isinstance(loaded, DS):
            raise TypeError(
                f"Expected a datasets.Dataset for split {split!r}, "
                f"got {type(loaded).__name__}"
            )
        self.ds: DS = loaded

        self.metrics = {
            "bleu": evaluate.load("bleu"),
            "ter": evaluate.load("ter"),
            "meteor": evaluate.load("meteor"),
            "bertscore": evaluate.load("bertscore"),
            "rouge": evaluate.load("rouge"),
        }

    @property
    def splits(self):
        """Names of the splits available for ``dataset_identifier``."""
        return get_dataset_split_names(self.dataset_identifier)

    def __len__(self):
        return self.ds.num_rows

    def evaluate(self, y_pred: List[str], y_true: List[List[str]]) -> Dict:
        """Score predictions with BLEU, TER, BERTScore, METEOR and ROUGE.

        Args:
            y_pred: One generated text per sample.
            y_true: One list of reference texts per sample.

        Returns:
            A flat dict. BLEU, TER and BERTScore entries are prefixed with
            ``bleu_``, ``ter_`` and ``bertscore_``; METEOR and ROUGE entries
            keep the key names returned by their metric objects.
        """
        results = {}

        bleu = self.metrics["bleu"].compute(predictions=y_pred, references=y_true)
        results.update({f"bleu_{k}": v for k, v in bleu.items()})  # type: ignore

        # TER is given references equalized to the same count per sample.
        ter = self.metrics["ter"].compute(
            predictions=y_pred,
            references=self.equalize_number_of_targets(y_true),
            case_sensitive=True,
        )
        results.update({f"ter_{k}": v for k, v in ter.items()})  # type: ignore

        bertscore = self.metrics["bertscore"].compute(
            predictions=y_pred, references=y_true, lang="en"
        )
        results.update({f"bertscore_{k}": v for k, v in bertscore.items()})  # type: ignore

        results.update(self.metrics["meteor"].compute(predictions=y_pred, references=y_true))  # type: ignore
        results.update(self.metrics["rouge"].compute(predictions=y_pred, references=y_true))  # type: ignore

        return results

    @classmethod
    def equalize_number_of_targets(cls, targets: List[List[str]]):
        """Give every sample the same number of reference texts.

        Some metrics require that all samples have the same number of targets.
        As a workaround, each reference list is truncated, or repeated
        cyclically, to the median list length rounded down.

        Args:
            targets: One list of reference texts per sample.

        Returns:
            A new list of lists; ``targets`` is not modified. An empty
            ``targets`` yields ``[]``. A sample with no references stays
            empty, because there is nothing to repeat.
        """
        if not targets:
            # np.median([]) is NaN and int(NaN) raises, so short-circuit here.
            return []
        median_length = int(np.median([len(t) for t in targets]))
        return [list(itertools.islice(itertools.cycle(t), median_length)) for t in targets]


class DART(HuggingFaceTableToTextDataset):
    """Table-to-text task backed by the ``dart`` dataset."""

    task_name = "DART"
    dataset_identifier = "dart"
    x_introduction = "TABLE:"
    y_introduction = "TEXT:"

    def __getitem__(self, idx: int) -> Tuple[str, List[str]]:
        """Return ``(linearized_tripleset, reference_texts)`` for the row at ``idx``."""
        row = self.ds[idx]
        # Each triplet becomes "a : b : c"; triplets are separated by " | ".
        linearized = (" : ".join(triplet) for triplet in row["tripleset"])
        table = " | ".join(linearized)
        references = row["annotations"]["text"]
        return table, references


class E2E(HuggingFaceTableToTextDataset):
    """Table-to-text task backed by the ``e2e_nlg`` dataset."""

    task_name = "E2E"
    dataset_identifier = "e2e_nlg"
    x_introduction = "TABLE:"
    y_introduction = "TEXT:"

    @classmethod
    def _reformat_x(cls, string: str):
        """
        Rewrite the bracketed meaning representation as "key: value" pairs, e.g.
        "name[The Rice Boat], food[French], priceRange[£20-25], customer rating[3 out of 5]"
        becomes "name: The Rice Boat, food: French, priceRange: £20-25, customer rating: 3 out of 5"
        """
        # "[" turns into ": " and "]" is dropped, in a single pass over the text.
        bracket_table = str.maketrans({"[": ": ", "]": None})
        return string.translate(bracket_table)

    def __getitem__(self, idx: int) -> Tuple[str, List[str]]:
        """Return ``(reformatted_meaning_representation, [human_reference])``."""
        row = self.ds[idx]
        table = self._reformat_x(row["meaning_representation"])
        return table, [row["human_reference"]]


if __name__ == "__main__":
    # Manual smoke test: for each dataset class, load the train split, show the
    # metrics for a few matched/mismatched prediction pairs and report token
    # length statistics.
    import tqdm
    from transformers import LlamaTokenizer

    # You might need to adjust this path
    tokenizer = LlamaTokenizer.from_pretrained("../llama/META_RELEASED_WEIGHTS")

    # Upper bound on how many samples are tokenized for the length statistics.
    max_stat_samples = 20_000

    def _smoke_test(ds_class) -> None:
        """Exercise ``ds_class`` on its train split and print token statistics."""
        print(f"\n-> {ds_class.__name__}")
        ds = ds_class(split="train")

        # randrange excludes len(ds); randint(0, len(ds)) includes it and
        # could therefore request an index one past the last row.
        x1, y1 = ds[random.randrange(len(ds))]
        x2, y2 = ds[random.randrange(len(ds))]

        # Score each sample against its own and the other sample's target.
        for (x, y_corr), y_pred in itertools.product([(x1, y1), (x2, y2)], [y1, y2]):
            print(f"   x='{x}'\n   y='{y_pred[0]}'")
            for m, v in ds.evaluate(y_pred=[y_pred[0]], y_true=[y_corr]).items():
                print(f"   {m} = {v}")
            print()

        # Measure the longest sequence and the 95th percentile in tokens;
        # these numbers are used to set up the training parameters.
        n_samples = min(max_stat_samples, len(ds))
        n_tokens = np.zeros(n_samples)
        for i in tqdm.tqdm(range(n_samples)):
            x, y = ds[i]
            # y is a list of targets; tokenize the first target itself rather
            # than the list's repr, which would add brackets and quotes.
            n_tokens[i] = len(
                tokenizer.encode(f"{ds.x_introduction} {x} {ds.y_introduction} {y[0]}")
            )

        print(
            f" Average number of tokens: {np.mean(n_tokens):.2f}; 95th percentile: {int(np.percentile(n_tokens, 95))}; Max: {int(np.max(n_tokens))}"
        )

    # Test that the dataset classes behave as expected

    print("TEST THE CLASSIFICATION DATASETS:")

    for ds_class in [Banking, Emotion, YahooAnswers, WinoGrande]:
        _smoke_test(ds_class)

    print("TEST THE TABLE-TO-TEXT DATASETS:")

    for ds_class in [E2E, DART]:
        _smoke_test(ds_class)
