from dataclasses import dataclass
import abc
from typing import List, Optional, Callable, Union, Tuple, Dict, Any

from datasets import load_dataset
import numpy as np
import torch

@dataclass
class TaskConfig:
    """Static description of where a task's data comes from.

    ``Task`` passes ``dataset`` (and ``subset`` when it is not ``None``) to
    ``load_dataset`` and indexes the loaded result with ``evaluation_split``
    and ``fewshot_split``.
    """
    # First argument given to ``load_dataset``.
    dataset: str
    # Second argument given to ``load_dataset``; left out when ``None``.
    subset: Optional[str]
    # Not read anywhere in the visible code; presumably a human-readable task id.
    name: str
    # Key of the split returned by ``Task.evaluation_docs``.
    evaluation_split: str
    # Key of the split that ``Task.fewshot_docs`` samples from.
    fewshot_split: str

class Task:
    def __init__(self):
        self.dataset = (load_dataset(self.config.dataset, self.config.subset)
                        if self.config.subset is not None else load_dataset(self.config.dataset))
        self.doc_to_text: Callable = None
        self.doc_to_target: Callable = None
        self.fewshot_space: str = None
        self.sep: str = None
        self.loss: Callable = None
        self.fewshot: int = None
        self.fewshot_seed: int = None
        self.can_be_token_separable: bool = False

    @abc.abstractproperty
    @property
    def config(self) -> TaskConfig:
        pass

    @staticmethod
    def loss_function(logits, targets, *args, **kwargs):
        raise NotImplementedError()

    # Token-type ids produced by the get_token_types_* methods. The list-valued
    # constants are indexed by the number of few-shot examples already consumed,
    # so every few-shot example gets its own ids (ten entries per kind).
    # Ids for the tokens of the i-th few-shot input.
    FEWSHOT_QUERY_TYPES = [0, 7, 8, 16, 17, 18, 19, 20, 21, 22]
    # First half: ids for `sep` tokens of the i-th example;
    # second half: ids for `fewshot_space` tokens of the i-th example.
    SEPARATOR_TYPES = [1, 12, 13, 30, 31, 32, 33, 34, 35, 36,    9, 14, 15, 37, 38, 39, 40, 41, 42, 43]
    # Ids for the tokens of the i-th few-shot target.
    FEWSHOT_TARGET_TYPES = [2, 10, 11, 23, 24, 25, 26, 27, 28, 29]
    # Ids for the final query, the separator right before the answer,
    # pad/bos tokens, and the target tokens.
    QUERY_TYPE = 3
    LAST_SEP_TYPE = 4
    PAD_BOS_TYPE = 5
    TARGET_TYPE = 6

    # Word lists keyed by an integer. Not referenced in the visible code; the keys are
    # presumably the number of tokens each word splits into under some tokenizer -- TODO confirm.
    WORDS_PER_NUM_OF_TOKENS = {1: ['exempt', 'generators', 'broker', 'amide', 'nail', 'covered', 'where', 'capes', 'announce', 'any'], 2: ['norsemen', 'keysters', 'tamlung', 'beeswings', 'misreference', 'nerts', 'hetaery', 'pensacola', 'withies', 'jerl'], 4: ['ferronickel', 'triozonide', 'undecisively', 'chamaerops', 'preadjectival', 'epidemiographist', 'ebracteate', 'unquizzical', 'slapdashes', 'epicoelous'], 3: ['nothosaurus', 'coranoch', 'socotri', 'dramalogue', 'saponify', 'linalool', 'collegatary', 'culverkey', 'incohering', 'autophonous'], 6: ['calciovolborthite', 'cholecystostomies', 'asymmetranthous', 'uncontumaciousness', 'dodecasyllabic', 'acetoacetanilide', 'cricothyroidean', 'hypocraterimorphous', 'lepospondylous', 'vinosulphureous'], 5: ['rhipipterous', 'ventriloquised', 'syphiliphobia', 'redissolubleness', 'ochlophobist', 'parapsychologists', 'perfectibilist', 'semicolloquially', 'sphenomandibular', 'pycnometochic'], 7: ['pericardiosymphysis', 'uranostaphylorrhaphy']}
    # Same word lists as WORDS_PER_NUM_OF_TOKENS with every key increased by one
    # (presumably the token count when the word starts the text -- TODO confirm).
    WORDS_PER_NUM_OF_TOKENS_IN_BEGINNING = {2: ['exempt', 'generators', 'broker', 'amide', 'nail', 'covered', 'where', 'capes', 'announce', 'any'], 3: ['norsemen', 'keysters', 'tamlung', 'beeswings', 'misreference', 'nerts', 'hetaery', 'pensacola', 'withies', 'jerl'], 5: ['ferronickel', 'triozonide', 'undecisively', 'chamaerops', 'preadjectival', 'epidemiographist', 'ebracteate', 'unquizzical', 'slapdashes', 'epicoelous'], 4: ['nothosaurus', 'coranoch', 'socotri', 'dramalogue', 'saponify', 'linalool', 'collegatary', 'culverkey', 'incohering', 'autophonous'], 7: ['calciovolborthite', 'cholecystostomies', 'asymmetranthous', 'uncontumaciousness', 'dodecasyllabic', 'acetoacetanilide', 'cricothyroidean', 'hypocraterimorphous', 'lepospondylous', 'vinosulphureous'], 6: ['rhipipterous', 'ventriloquised', 'syphiliphobia', 'redissolubleness', 'ochlophobist', 'parapsychologists', 'perfectibilist', 'semicolloquially', 'sphenomandibular', 'pycnometochic'], 8: ['pericardiosymphysis', 'uranostaphylorrhaphy']}

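    # Token type ids for the first few-shot example and the query:
    # 0 - input, 1 - sep separator, 9 - fewshot_space separator, 2 - output, 3 - query,
    # 4 - last separator, the one before the answer, 5 - pad/bos,
    # 6 - target, only in get_token_types_for_contexts_with_targets.
    # Later few-shot examples use the following entries of FEWSHOT_QUERY_TYPES,
    # SEPARATOR_TYPES (first half: sep, second half: fewshot_space) and FEWSHOT_TARGET_TYPES.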

    # This method can ONLY be used to parse a context string, namely a string in the format
    # input(sep)output(fewshot_space)input(sep)output(fewshot_space)query(sep)
    def get_token_types_for_contexts(self, tokenizer, tokenized_contexts_batch):
        assert self.can_be_token_separable

        special_token = "<special-token-we-will-never-encounter-in-a-dataset>"
        tokenizer.add_tokens([special_token])

        contexts_tp_inds = torch.empty_like(tokenized_contexts_batch)
        tokenized_sep = tokenizer.convert_ids_to_tokens(tokenizer(special_token + self.sep, add_special_tokens=False, return_tensors="pt")["input_ids"][0])[1]
        tokenized_fewshot_space = tokenizer.convert_ids_to_tokens(tokenizer(special_token + self.fewshot_space, add_special_tokens=False, return_tensors="pt")["input_ids"][0])[1]
        for context_i in range(tokenized_contexts_batch.shape[0]):
            detokenized_context = tokenizer.convert_ids_to_tokens(tokenized_contexts_batch[context_i])
            num_fewshot_had = 0
            next_is_input = True
            for token_i, token in enumerate(detokenized_context):
                if token == tokenized_sep or token == tokenized_fewshot_space:
                    if num_fewshot_had == self.fewshot:
                        contexts_tp_inds[context_i, token_i] = 4
                    else:
                        if token == tokenized_sep:
                            contexts_tp_inds[context_i, token_i] = self.SEPARATOR_TYPES[num_fewshot_had]
                        else:
                            contexts_tp_inds[context_i, token_i] = self.SEPARATOR_TYPES[len(self.SEPARATOR_TYPES) // 2 + num_fewshot_had]
                    next_is_input = (not next_is_input)
                    if next_is_input:
                        num_fewshot_had += 1
                elif token == tokenizer.pad_token or token == tokenizer.bos_token:
                    contexts_tp_inds[context_i, token_i] = 5
                else:
                    if num_fewshot_had == self.fewshot:
                        contexts_tp_inds[context_i, token_i] = 3
                    else:
                        contexts_tp_inds[context_i, token_i] = \
                            (self.FEWSHOT_QUERY_TYPES[num_fewshot_had] if next_is_input else self.FEWSHOT_TARGET_TYPES[num_fewshot_had])
        return contexts_tp_inds
    
    # This method can ONLY be used to parse a context string followed by its target, namely a string in the format
    # input(sep)output(fewshot_space)input(sep)output(fewshot_space)query(sep)target
    def get_token_types_for_contexts_with_targets(self, tokenizer, tokenized_contexts_batch):
        assert self.can_be_token_separable

        special_token = "<special-token-we-will-never-encounter-in-a-dataset>"
        tokenizer.add_tokens([special_token])

        contexts_tp_inds = torch.empty_like(tokenized_contexts_batch)
        tokenized_sep = tokenizer.convert_ids_to_tokens(tokenizer(special_token + self.sep, add_special_tokens=False, return_tensors="pt")["input_ids"][0])[1]
        tokenized_fewshot_space = tokenizer.convert_ids_to_tokens(tokenizer(special_token + self.fewshot_space, add_special_tokens=False, return_tensors="pt")["input_ids"][0])[1]
        for context_i in range(tokenized_contexts_batch.shape[0]):
            detokenized_context = tokenizer.convert_ids_to_tokens(tokenized_contexts_batch[context_i])
            num_fewshot_had = 0
            next_is_input = True
            for token_i, token in enumerate(detokenized_context):
                if token == tokenized_sep or token == tokenized_fewshot_space:
                    if num_fewshot_had == self.fewshot:
                        contexts_tp_inds[context_i, token_i] = 4
                    else:
                        if token == tokenized_sep:
                            contexts_tp_inds[context_i, token_i] = self.SEPARATOR_TYPES[num_fewshot_had]
                        else:
                            contexts_tp_inds[context_i, token_i] = self.SEPARATOR_TYPES[len(self.SEPARATOR_TYPES) // 2 + num_fewshot_had]
                    next_is_input = (not next_is_input)
                    if next_is_input:
                        num_fewshot_had += 1
                elif token == tokenizer.pad_token or token == tokenizer.bos_token:
                    contexts_tp_inds[context_i, token_i] = 5
                else:
                    if num_fewshot_had == self.fewshot:
                        contexts_tp_inds[context_i, token_i] = (3 if next_is_input else 6)
                    else:
                        contexts_tp_inds[context_i, token_i] = \
                            (self.FEWSHOT_QUERY_TYPES[num_fewshot_had] if next_is_input else self.FEWSHOT_TARGET_TYPES[num_fewshot_had])
        return contexts_tp_inds

    def evaluation_docs(self, limit = None) -> List[dict]:
        if limit is not None and len(self.dataset[self.config.evaluation_split]) >= limit:
            return self.dataset[self.config.evaluation_split].select(range(limit))
        return self.dataset[self.config.evaluation_split]

    def fewshot_docs(self, number_of_contexts) -> List[List[dict]]:
        np.random.seed(self.fewshot_seed)
        docs = self.dataset[self.config.fewshot_split]
        inds = [np.random.choice(list(range(len(docs))), self.fewshot, replace=False)
                for _ in range(number_of_contexts)]

        return [docs.select(inds[i]) for i in range(number_of_contexts)]

    def get_docs(self, limit = None) -> Tuple[List[dict], List[List[dict]]]:
        eval_docs = self.evaluation_docs(limit)
        fewshot_docs = self.fewshot_docs(len(eval_docs))
        return eval_docs, fewshot_docs
    
    def get_docs_in_templates(self, limit = None, corrupted = False) -> List[Dict[str, str]]:
        if corrupted:
            raise NotImplementedError()
        else:
            eval_docs, fewshot_docs = self.get_docs(limit)
        contexts = [self.fewshot_space.join([self.doc_to_text(doc) + self.sep + self.doc_to_target(doc) for doc in docs]) + self.fewshot_space + self.doc_to_text(eval_doc) + self.sep
                    for docs, eval_doc in zip(fewshot_docs, eval_docs)]
        targets = [self.doc_to_target(eval_doc) for eval_doc in eval_docs]
        return [
        {
            "context": context,
            "target": target
        }
        for context, target in zip(contexts, targets)
        ]

