import random
import string
from typing import Any, Dict, Iterable, List

from transformers import PreTrainedTokenizer

from core.data.tasks.increment_task import IncrementTask
from core.data.tasks.list_operation_task import ListOperationTask
from core.data.tasks.mapping_task import BijectionTask, MappingTask
from core.data.tasks.task import Task
from core.data.tasks.token_operation_task import TokenOprationTask
from core.data.tasks.translation_task import TranslationTask


# Registry of every named task.
#
# Each entry maps a task name to:
#   "task_type":   a key of TASK_TYPE_TO_CLASS selecting the Task class to build;
#   "task_kwargs": keyword arguments that get_task forwards to that class's
#                  constructor (get_task additionally passes ``tokenizer``).
# get_all_tasks builds every entry, iterating this dict in insertion order.
ALL_TASKS = {
    # Algorithmic
    "algorithmic_copy": {
        "task_type": "token_operation",
        "task_kwargs": {"operation": "copy", "input_space": list(string.ascii_lowercase) + list(string.ascii_uppercase)},
    },
    "algorithmic_next_letter": {
        "task_type": "increment",
        "task_kwargs": {"increment": +1},
    },
    "algorithmic_prev_letter": {
        "task_type": "increment",
        "task_kwargs": {"increment": -1},
    },
    # NOTE(review): "list_lenghts" (sic) is forwarded verbatim to
    # ListOperationTask; presumably it matches that constructor's parameter
    # name -- confirm there before correcting the spelling here.
    "algorithmic_list_first": {
        "task_type": "list_operation",
        "task_kwargs": {"operation": "first", "list_lenghts": range(2, 5)},
    },
    "algorithmic_list_last": {
        "task_type": "list_operation",
        "task_kwargs": {"operation": "last", "list_lenghts": range(2, 5)},
    },
    "algorithmic_list_min": {
        "task_type": "list_operation",
        "task_kwargs": {"operation": "min", "list_lenghts": range(2, 5), "elements_space": list(string.digits)},
    },
    "algorithmic_list_max": {
        "task_type": "list_operation",
        "task_kwargs": {"operation": "max", "list_lenghts": range(2, 5), "elements_space": list(string.digits)},
    },
    "algorithmic_list_length": {
        "task_type": "list_operation",
        "task_kwargs": {"operation": "length", "list_lenghts": range(1, 4)},
    },
    "algorithmic_to_upper": {
        "task_type": "token_operation",
        "task_kwargs": {"operation": "to_upper", "input_space": list(string.ascii_lowercase)},
    },
    "algorithmic_to_lower": {
        "task_type": "token_operation",
        "task_kwargs": {"operation": "to_lower", "input_space": list(string.ascii_uppercase)},
    },
    # Inputs are the letters "a".."i" (9 letters), mirroring the digits
    # "1".."9" used as inputs of algorithmic_int_to_char below.
    "algorithmic_char_to_int": {
        "task_type": "token_operation",
        "task_kwargs": {"operation": "char_to_int", "input_space": list(string.ascii_lowercase[:9])},
    },  # low performance
    "algorithmic_int_to_char": {
        "task_type": "token_operation",
        "task_kwargs": {"operation": "int_to_char", "input_space": list(string.digits[1:])},
    },
    # Translation
    "translation_fr_en": {
        "task_type": "translation",
        "task_kwargs": {"mapping_type": "translation", "mapping_name": "fr_en"},
    },
    "translation_it_en": {
        "task_type": "translation",
        "task_kwargs": {"mapping_type": "translation", "mapping_name": "it_en"},
    },
    "translation_es_en": {
        "task_type": "translation",
        "task_kwargs": {"mapping_type": "translation", "mapping_name": "es_en"},
    },
    "translation_en_fr": {
        "task_type": "translation",
        "task_kwargs": {"mapping_type": "translation", "mapping_name": "en_fr"},
    },
    "translation_en_it": {
        "task_type": "translation",
        "task_kwargs": {"mapping_type": "translation", "mapping_name": "en_it"},
    },
    "translation_en_es": {
        "task_type": "translation",
        "task_kwargs": {"mapping_type": "translation", "mapping_name": "en_es"},
    },
    # Linguistic
    "linguistic_present_simple_gerund": {
        "task_type": "mapping",
        "task_kwargs": {"mapping_type": "linguistic", "mapping_name": "present_simple_gerund"},
    },
    "linguistic_gerund_present_simple": {
        "task_type": "mapping",
        "task_kwargs": {"mapping_type": "linguistic", "mapping_name": "gerund_present_simple"},
    },
    "linguistic_present_simple_past_simple": {
        "task_type": "mapping",
        "task_kwargs": {"mapping_type": "linguistic", "mapping_name": "present_simple_past_simple"},
    },
    "linguistic_past_simple_present_simple": {
        "task_type": "mapping",
        "task_kwargs": {"mapping_type": "linguistic", "mapping_name": "past_simple_present_simple"},
    },
    "linguistic_present_simple_past_perfect": {
        "task_type": "mapping",
        "task_kwargs": {"mapping_type": "linguistic", "mapping_name": "present_simple_past_perfect"},
    },
    "linguistic_past_perfect_present_simple": {
        "task_type": "mapping",
        "task_kwargs": {"mapping_type": "linguistic", "mapping_name": "past_perfect_present_simple"},
    },
    "linguistic_singular_plural": {
        "task_type": "mapping",
        "task_kwargs": {"mapping_type": "linguistic", "mapping_name": "singular_plural"},
    },
    "linguistic_plural_singular": {
        "task_type": "mapping",
        "task_kwargs": {"mapping_type": "linguistic", "mapping_name": "plural_singular"},
    },
    "linguistic_antonyms": {
        "task_type": "mapping",
        "task_kwargs": {"mapping_type": "linguistic", "mapping_name": "antonyms"},
    },
    # Knowledge
    # "allow_prefix" is forwarded to MappingTask; its exact semantics are
    # defined there (not visible in this module).
    "knowledge_country_capital": {
        "task_type": "mapping",
        "task_kwargs": {"mapping_type": "knowledge", "mapping_name": "country_capital", "allow_prefix": True},
    },
    "knowledge_person_language": {
        "task_type": "mapping",
        "task_kwargs": {"mapping_type": "knowledge", "mapping_name": "person_language", "allow_prefix": True},
    },
    "knowledge_location_continent": {
        "task_type": "mapping",
        "task_kwargs": {"mapping_type": "knowledge", "mapping_name": "location_continent", "allow_prefix": True},
    },
    "knowledge_location_religion": {
        "task_type": "mapping",
        "task_kwargs": {"mapping_type": "knowledge", "mapping_name": "location_religion", "allow_prefix": True},
    },
    # Sentiment is disabled: "sentiment" is also commented out of
    # TASK_TYPE_TO_CLASS and SentimentTask is not imported.
    # "sentiment": {
    #     "task_type": "sentiment",
    #     "task_kwargs": {"allow_prefix": True},
    # },
    # Bidirectional ("bijection") tasks.
    # "mixed" entries list other ALL_TASKS names; MixedTask builds each of them
    # and answers an input with the first sub-task whose input space contains it.
    "bijection_algorithmic_lower_upper": {
        "task_type": "mixed",
        "task_kwargs": {"tasks": ["algorithmic_to_upper", "algorithmic_to_lower"]},
    },
    "bijection_algorithmic_char_int": {
        "task_type": "mixed",
        "task_kwargs": {"tasks": ["algorithmic_char_to_int", "algorithmic_int_to_char"]},
    },
    "bijection_translation_fr_en": {
        "task_type": "mixed",
        "task_kwargs": {"tasks": ["translation_fr_en", "translation_en_fr"]},
    },
    "bijection_translation_it_en": {
        "task_type": "mixed",
        "task_kwargs": {"tasks": ["translation_it_en", "translation_en_it"]},
    },
    "bijection_translation_es_en": {
        "task_type": "mixed",
        "task_kwargs": {"tasks": ["translation_es_en", "translation_en_es"]},
    },
    # "bijection" entries instead use BijectionTask, configured with a single
    # mapping (mapping_type / mapping_name).
    "bijection_linguistic_present_simple_gerund": {
        "task_type": "bijection",
        "task_kwargs": {"mapping_type": "linguistic", "mapping_name": "present_simple_gerund"},
    },
    "bijection_linguistic_present_simple_past_simple": {
        "task_type": "bijection",
        "task_kwargs": {"mapping_type": "linguistic", "mapping_name": "present_simple_past_simple"},
    },
    "bijection_linguistic_present_simple_past_perfect": {
        "task_type": "bijection",
        "task_kwargs": {"mapping_type": "linguistic", "mapping_name": "present_simple_past_perfect"},
    },
    "bijection_linguistic_plural_singular": {
        "task_type": "bijection",
        "task_kwargs": {"mapping_type": "linguistic", "mapping_name": "plural_singular"},
    },
    "bijection_linguistic_antonyms": {
        "task_type": "bijection",
        "task_kwargs": {"mapping_type": "linguistic", "mapping_name": "antonyms"},
    },
}


def get_task(task_type: str, task_kwargs: Dict[str, Any], tokenizer: PreTrainedTokenizer) -> Task:
    """Instantiate the task class registered under ``task_type``.

    Args:
        task_type: Key into ``TASK_TYPE_TO_CLASS`` (e.g. ``"mapping"``).
        task_kwargs: Keyword arguments forwarded to the task class constructor.
            Values are arbitrary objects (ranges, lists, booleans, ...), not
            only strings.
        tokenizer: Passed to the constructor as the ``tokenizer`` keyword.

    Returns:
        The constructed task instance.

    Raises:
        KeyError: If ``task_type`` is not a registered task type.
    """
    try:
        task_class = TASK_TYPE_TO_CLASS[task_type]
    except KeyError:
        # Re-raise the same exception type, but name the valid choices so a
        # typo in a task spec is easy to diagnose.
        known = ", ".join(sorted(TASK_TYPE_TO_CLASS))
        raise KeyError(f"Unknown task type {task_type!r}; expected one of: {known}") from None
    return task_class(**task_kwargs, tokenizer=tokenizer)


def get_task_by_name(tokenizer: PreTrainedTokenizer, task_name: str) -> Task:
    """Build the task registered in ``ALL_TASKS`` under ``task_name``.

    Args:
        tokenizer: Tokenizer handed to the task constructor.
        task_name: Key of ``ALL_TASKS``.

    Returns:
        The instantiated task.

    Raises:
        KeyError: If ``task_name`` is not a key of ``ALL_TASKS``.
    """
    spec = ALL_TASKS[task_name]
    return get_task(spec["task_type"], spec["task_kwargs"], tokenizer)


def get_all_tasks(tokenizer: PreTrainedTokenizer) -> Dict[str, Task]:
    """Instantiate every task in ``ALL_TASKS``, keyed by task name.

    The result preserves the insertion order of ``ALL_TASKS``.
    """
    return {name: get_task_by_name(tokenizer, name) for name in ALL_TASKS}


class MixedTask(Task):
    """Task that combines several named tasks from ``ALL_TASKS`` into one.

    The input space is the concatenation of the sub-tasks' input spaces, and
    each input is answered by the first sub-task whose input space contains
    it. ``ALL_TASKS`` uses this to build bidirectional tasks from two
    directional ones (e.g. ``translation_fr_en`` + ``translation_en_fr``).
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        tasks: Iterable[str],
    ):
        """Build each sub-task by name.

        Args:
            tokenizer: Tokenizer shared by this task and every sub-task.
            tasks: Names of ``ALL_TASKS`` entries to combine, in priority
                order (earlier tasks win when input spaces overlap).

        Raises:
            KeyError: If a name is not a key of ``ALL_TASKS``.
        """
        super().__init__(tokenizer)
        self.tasks = [get_task_by_name(tokenizer, task_name) for task_name in tasks]

    def sample_inputs(self, num_inputs: int, exclude: Iterable[str] = ()) -> List[str]:
        """Sample ``num_inputs`` distinct inputs, skipping any in ``exclude``.

        Duplicates across sub-task input spaces are collapsed first, and the
        candidates are sorted so the draw depends only on ``random``'s state.

        Raises:
            ValueError: If fewer than ``num_inputs`` candidates remain.
        """
        candidates = sorted(set(self.input_space) - set(exclude))
        return random.sample(candidates, num_inputs)

    @property
    def input_space(self) -> List[str]:
        """Concatenation of all sub-task input spaces (may contain duplicates)."""
        input_space: List[str] = []
        for task in self.tasks:
            input_space += task.input_space
        return input_space

    def calc_output(self, inp: str) -> Any:
        """Return the output of the first sub-task whose input space contains ``inp``.

        Raises:
            ValueError: If ``inp`` is outside every sub-task's input space.
        """
        for task in self.tasks:
            if inp in task.input_space:
                return task.calc_output(inp)
        # Fail loudly rather than silently producing ``None`` as an "output".
        raise ValueError(f"Input {inp!r} is not in the input space of any sub-task")

    def num_examples(self) -> int:
        """Number of distinct inputs, i.e. the largest valid ``sample_inputs`` size."""
        # Count distinct inputs so this agrees with sample_inputs, which
        # de-duplicates inputs shared by several sub-tasks.
        return len(set(self.input_space))


# Maps each "task_type" string used in ALL_TASKS to the Task subclass that
# get_task instantiates for it. Defined after MixedTask because it refers to it.
TASK_TYPE_TO_CLASS = dict(
    increment=IncrementTask,
    list_operation=ListOperationTask,
    token_operation=TokenOprationTask,
    mapping=MappingTask,
    translation=TranslationTask,
    # sentiment=SentimentTask,  # disabled; SentimentTask is not imported
    mixed=MixedTask,
    bijection=BijectionTask,
)