import numpy as np
from datasets import load_dataset


def icl_prompt(datum, dataset, samples_idx, prompt_type, task_info):
    """Build an in-context-learning prompt for ``datum``.

    The prompt has three parts:
    - an optional header;
    - one "Input / <label>" demonstration for each index in ``samples_idx``
      (taken from ``dataset``);
    - the query input followed by the label name, left for the model to
      complete.

    Args:
        datum: Query item; its ``"input"`` field is placed at the end.
        dataset: Indexable collection of items with ``"input"`` and
            ``"output"`` fields, from which demonstrations are read.
        samples_idx: Iterable of indices into ``dataset`` used as demos.
        prompt_type: One of ``"none"``, ``"description"``, ``"demo"`` or
            ``"type-only"``.
            - ``"description"`` prepends ``task_info["description"]``.
            - ``"demo"`` prepends "Demonstrations:".
            - ``"description"`` and ``"type-only"`` label answers with
              ``task_info["output_type"]``; the others use "Output".
        task_info: Mapping with ``"description"`` and ``"output_type"``.

    Returns:
        The prompt string.

    Raises:
        KeyError: if ``prompt_type`` is not one of the values above.
    """
    headers = {
        "none": "",
        # End the description with a newline so it does not run straight
        # into the first "Input:" line.
        "description": task_info["description"] + "\n",
        "demo": "Demonstrations:\n",
        "type-only": "",
    }
    prompt_header = headers[prompt_type]

    if prompt_type in ("description", "type-only"):
        output_prompt = task_info["output_type"]
    else:
        output_prompt = "Output"

    demonstrations = "".join(
        f"Input: {dataset[_idx]['input']}. \n"
        f"{output_prompt}: {dataset[_idx]['output']}. \n"
        for _idx in samples_idx
    )
    # The query ends with the answer label so the model fills in the answer.
    prompt_newinput = f"Input: {datum['input']}. {output_prompt}:"

    return prompt_header + demonstrations + prompt_newinput


def get_icl_samples(dataset, sampling, num_samples, replace=True):
    """Draw demonstration indices for every item of ``dataset``.

    For item ``i`` the indices are drawn uniformly from every position
    except ``i`` (leave-one-out), so an item is never its own demonstration.

    Args:
        dataset: Sized collection; only ``len(dataset)`` is used.
        sampling: Sampling strategy; only ``"uniform"`` is implemented.
        num_samples: Number of indices drawn per item.
        replace: Whether the indices drawn for one item may repeat.
            Defaults to ``True`` (the original behaviour, and the default
            of ``np.random.choice``). Pass ``False`` for distinct demos.

    Returns:
        A list with one integer index array per item, in dataset order.

    Raises:
        NotImplementedError: if ``sampling`` is not ``"uniform"``.
        ValueError: from ``np.random.choice`` when there are not enough
            other items to draw from.
    """
    if sampling != "uniform":
        raise NotImplementedError(f"Unsupported sampling strategy: {sampling!r}")

    n_data = len(dataset)
    icl_samples = []
    for query_idx in range(n_data):
        # Draw positions from the n_data - 1 items other than the query.
        # Then shift every position >= query_idx up by one. This is exactly
        # the mapping "range(n_data) with query_idx removed", without building
        # an O(n) list per item. It also consumes the same random numbers as
        # choosing from that list.
        sampled = np.random.choice(n_data - 1, num_samples, replace=replace)
        sampled = sampled + (sampled >= query_idx)
        icl_samples.append(sampled)
    return icl_samples


class glue_single_classification:
    """Common base class for the GLUE task classes in this module.

    It defines no behaviour of its own. Subclasses supply ``task_info``,
    ``convert_datum`` and ``convert_output``.
    """


class cola(glue_single_classification):
    """GLUE CoLA: binary grammatical-acceptability judgement.

    Label 0 is verbalized as "false" and label 1 as "true".
    """

    task_info = dict(
        description=(
            "This task is to predict the grammar validity of the input sentence"
        ),
        output_type="Acceptability",
    )

    @staticmethod
    def convert_datum(_datum):
        """Wrap a raw CoLA row as an ``original``/``input``/``output`` record."""
        # NOTE(review): a negative label would index from the end of this
        # tuple and silently become "true"; confirm only labelled rows get here.
        verbalizer = ("false", "true")
        sentence = _datum["sentence"]
        return dict(
            original=_datum,
            input=sentence,
            output=verbalizer[_datum["label"]],
        )

    @staticmethod
    def convert_output(output):
        """Map a verbalized answer back to its label id, or None if unknown."""
        label_ids = {"true": 1, "false": 0}
        return label_ids.get(output)


class sst2(glue_single_classification):
    """GLUE SST-2: binary sentence sentiment.

    Label 0 is verbalized as "negative" and label 1 as "positive".
    """

    task_info = dict(
        description="This task is to predict the sentiment of the input sentence",
        output_type="Sentiment",
    )

    @staticmethod
    def convert_datum(_datum):
        """Wrap a raw SST-2 row as an ``original``/``input``/``output`` record."""
        # NOTE(review): a negative label would index from the end of this
        # tuple and silently become "positive"; confirm only labelled rows
        # get here.
        verbalizer = ("negative", "positive")
        sentence = _datum["sentence"]
        return dict(
            original=_datum,
            input=sentence,
            output=verbalizer[_datum["label"]],
        )

    @staticmethod
    def convert_output(output):
        """Map a verbalized answer back to its label id, or None if unknown."""
        label_ids = {"positive": 1, "negative": 0}
        return label_ids.get(output)


class mnli(glue_single_classification):
    """GLUE MNLI: three-way premise/hypothesis entailment.

    Labels 0, 1, 2 are verbalized as "entailment", "neutral" and
    "contradiction".
    """

    task_info = dict(
        description="This task is to predict the entailment of the two sentences",
        output_type="Entailment",
    )

    @staticmethod
    def convert_datum(_datum):
        """Wrap a raw MNLI row; the input joins premise and hypothesis."""
        premise = _datum["premise"]
        hypothesis = _datum["hypothesis"]
        # NOTE(review): a negative label would index from the end of this
        # tuple and silently become "contradiction"; confirm only labelled
        # rows get here.
        verbalizer = ("entailment", "neutral", "contradiction")
        return dict(
            original=_datum,
            input=f"Premise: {premise}. Hypothesis: {hypothesis}",
            output=verbalizer[_datum["label"]],
        )

    @staticmethod
    def convert_output(output):
        """Map a verbalized answer back to its label id, or None if unknown.

        "entail" is accepted as a short form of "entailment".
        """
        label_ids = {
            "entail": 0,
            "entailment": 0,
            "neutral": 1,
            "contradiction": 2,
        }
        return label_ids.get(output)


class imdb:
    """Binary review-sentiment task (0 -> "negative", 1 -> "positive").

    Used by both ``imdb_dataset`` and ``rotten_tomatoes_dataset`` in this
    module.
    """

    task_info = {
        # Model-facing task description used by the "description" prompt
        # variant. Phrased in parallel with the sst2 task description.
        "description":
        "This task is to predict the sentiment of the review",
        "output_type": "Review"
    }

    @staticmethod
    def convert_datum(_datum):
        """Wrap a raw review row as an ``original``/``input``/``output`` record.

        Only the first 100 characters of ``_datum["text"]`` are kept as the
        input. This is presumably to keep multi-demo prompts short; TODO
        confirm.
        """
        return {
            "original": _datum,
            "input": _datum["text"][:100],
            "output": ["negative", "positive"][_datum["label"]]
        }

    @staticmethod
    def convert_output(output):
        """Map a verbalized answer back to its label id, or None if unknown."""
        return {
            "positive": 1,
            "negative": 0,
        }.get(output, None)


class tweet_eval:
    """Base task for the binary TweetEval subsets.

    Label 0 is verbalized as "positive" and label 1 as "negative";
    ``convert_output`` inverts that mapping. The subclasses in this module
    override only ``task_info``.
    """

    task_info = dict(
        description="This task is to classify a sentence",
        output_type="Label",
    )

    @staticmethod
    def convert_datum(_datum):
        """Wrap a raw tweet row as an ``original``/``input``/``output`` record."""
        # NOTE(review): for the hate/irony/offensive subsets, label 1 is
        # presumably the hateful/ironic/offensive class. If so, those tweets
        # are rendered as "negative" (e.g. "Ironic: negative"). Confirm this
        # verbalizer is intended.
        verbalizer = ("positive", "negative")
        record = {"original": _datum}
        record["input"] = _datum["text"]
        record["output"] = verbalizer[_datum["label"]]
        return record

    @staticmethod
    def convert_output(output):
        """Map a verbalized answer back to its label id, or None if unknown."""
        label_ids = {"positive": 0, "negative": 1}
        return label_ids.get(output)


class tweet_hate(tweet_eval):
    """TweetEval ``hate`` subset; reuses ``tweet_eval``'s label mapping."""

    task_info = dict(
        description="This task classifies if sentence contains hate speech",
        output_type="Hate speech",
    )


class tweet_irony(tweet_eval):
    """TweetEval ``irony`` subset; reuses ``tweet_eval``'s label mapping."""

    task_info = dict(
        description="This task classifies if sentence is ironic",
        output_type="Ironic",
    )


class tweet_offensive(tweet_eval):
    """TweetEval ``offensive`` subset; reuses ``tweet_eval``'s label mapping."""

    task_info = dict(
        description="This task classifies if sentence is offensive",
        output_type="Offensiveness",
    )


def rotten_tomatoes_dataset(split, subset, sampling, num_samples, prompt_type):
    """Load Rotten Tomatoes reviews and attach an ICL prompt to every item.

    Items are converted with the ``imdb`` task class (negative/positive).

    Args:
        split: Split name passed to ``load_dataset``.
        subset: If not None, the number of items drawn at random, without
            replacement, from the converted split.
        sampling: Demonstration sampling strategy for ``get_icl_samples``.
        num_samples: Demonstrations per prompt.
        prompt_type: Prompt variant for ``icl_prompt``.

    Returns:
        ``(dataset, dataset_cls)``: the list of converted items (each gaining
        ``"icl_prompt"`` and ``"sample_index"``) and the ``imdb`` class.
    """
    raw_split = load_dataset("rotten_tomatoes", split=split)
    task_cls = imdb
    records = [task_cls.convert_datum(row) for row in raw_split]

    if subset is not None:
        chosen = np.random.choice(range(len(records)), subset, replace=False)
        records = [records[pos] for pos in chosen]

    demo_indices = get_icl_samples(records, sampling, num_samples)
    for record, demos in zip(records, demo_indices):
        record["icl_prompt"] = icl_prompt(
            record, records, demos, prompt_type, task_cls.task_info)
        record["sample_index"] = demos

    return records, task_cls


def imdb_dataset(split, subset, sampling, num_samples, prompt_type):
    """Load IMDB reviews and attach an ICL prompt to every item.

    Args:
        split: Split name passed to ``load_dataset``.
        subset: If not None, the number of items drawn at random, without
            replacement, from the converted split.
        sampling: Demonstration sampling strategy for ``get_icl_samples``.
        num_samples: Demonstrations per prompt.
        prompt_type: Prompt variant for ``icl_prompt``.

    Returns:
        ``(dataset, dataset_cls)``: the list of converted items (each gaining
        ``"icl_prompt"`` and ``"sample_index"``) and the ``imdb`` class.
    """
    raw_split = load_dataset("imdb", split=split)
    task_cls = imdb
    items = list(map(task_cls.convert_datum, raw_split))

    if subset is not None:
        # Random rather than head-of-split selection. This is presumably
        # because the split may be ordered by label; TODO confirm.
        picked = np.random.choice(range(len(items)), subset, replace=False)
        items = [items[k] for k in picked]

    demos_per_item = get_icl_samples(items, sampling, num_samples)
    for item, demo_idx in zip(items, demos_per_item):
        item["icl_prompt"] = icl_prompt(
            item, items, demo_idx, prompt_type, task_cls.task_info)
        item["sample_index"] = demo_idx

    return items, task_cls


def glue_dataset(dataset_name, split, subset, sampling,
                 num_samples, prompt_type):
    """Load a GLUE task and attach an ICL prompt to every item.

    Args:
        dataset_name: ``"glue-<config>"``, where ``<config>`` is one of
            ``"cola"``, ``"sst2"`` or ``"mnli_matched"``.
        split: Split name passed to ``load_dataset``.
        subset: If not None, keep only the first ``subset`` items.
        sampling: Demonstration sampling strategy for ``get_icl_samples``.
        num_samples: Demonstrations per prompt.
        prompt_type: Prompt variant for ``icl_prompt``.

    Returns:
        ``(dataset, dataset_cls)``: the list of converted items (each gaining
        ``"icl_prompt"`` and ``"sample_index"``) and the task class.

    Raises:
        KeyError: if ``<config>`` is not one of the supported tasks.
    """
    group = dataset_name[len("glue-"):]
    # Resolve the task class first so that an unsupported name fails before
    # anything is downloaded.
    dataset_cls = {
        "cola": cola,
        "sst2": sst2,
        "mnli_matched": mnli,
    }[group]
    # "glue" bundles several tasks, so load_dataset needs the config name to
    # select one. Without it the requested task is never chosen.
    dataset = load_dataset("glue", group, split=split)
    dataset = [
        dataset_cls.convert_datum(_datum)
        for _datum in dataset
    ]
    if subset is not None:
        dataset = dataset[:subset]
    icl_samples = get_icl_samples(dataset, sampling, num_samples)

    for _datum, _samples_idx in zip(dataset, icl_samples):
        _datum["icl_prompt"] = icl_prompt(
            _datum, dataset, _samples_idx, prompt_type, dataset_cls.task_info)
        _datum["sample_index"] = _samples_idx

    return dataset, dataset_cls


def tweet_eval_dataset(dataset_name, split, subset, sampling,
                       num_samples, prompt_type):
    """Load a TweetEval subset and attach an ICL prompt to every item.

    Args:
        dataset_name: ``"tweet_eval-<config>"``, where ``<config>`` is one of
            ``"hate"``, ``"irony"`` or ``"offensive"``.
        split: Split name passed to ``load_dataset``.
        subset: If not None, keep only the first ``subset`` items.
        sampling: Demonstration sampling strategy for ``get_icl_samples``.
        num_samples: Demonstrations per prompt.
        prompt_type: Prompt variant for ``icl_prompt``.

    Returns:
        ``(dataset, dataset_cls)``: the list of converted items (each gaining
        ``"icl_prompt"`` and ``"sample_index"``) and the task class.
    """
    config = dataset_name[len("tweet_eval-"):]
    raw_split = load_dataset("tweet_eval", config, split=split)
    task_cls = {
        "hate": tweet_hate,
        "irony": tweet_irony,
        "offensive": tweet_offensive,
    }[config]

    rows = list(map(task_cls.convert_datum, raw_split))
    if subset is not None:
        rows = rows[:subset]

    demo_indices = get_icl_samples(rows, sampling, num_samples)
    for row, demos in zip(rows, demo_indices):
        row["icl_prompt"] = icl_prompt(
            row, rows, demos, prompt_type, task_cls.task_info)
        row["sample_index"] = demos

    return rows, task_cls


def get_icl_dataset(dataset_name, split, subset, sampling,
                    num_samples, prompt_type):
    """Dispatch to the loader matching ``dataset_name``.

    Supported names are ``"glue-<task>"``, ``"imdb"``, ``"rotten_tomatoes"``
    and ``"tweet_eval-<task>"``. The remaining arguments are forwarded
    unchanged to the chosen loader.

    Returns:
        ``(dataset, dataset_cls)`` as produced by the chosen loader.

    Raises:
        NotImplementedError: if ``dataset_name`` matches no known loader.
    """
    if dataset_name.startswith("glue-"):
        return glue_dataset(
            dataset_name, split, subset, sampling, num_samples, prompt_type)
    if dataset_name == "imdb":
        return imdb_dataset(
            split, subset, sampling, num_samples, prompt_type)
    if dataset_name == "rotten_tomatoes":
        return rotten_tomatoes_dataset(
            split, subset, sampling, num_samples, prompt_type)
    if dataset_name.startswith("tweet_eval-"):
        return tweet_eval_dataset(
            dataset_name, split, subset, sampling, num_samples, prompt_type)
    # Name the offending dataset so misconfigured runs are easy to diagnose.
    raise NotImplementedError(f"Unsupported dataset: {dataset_name!r}")
