import os
from pathlib import Path
from typing import Callable, List, Dict
from datasets import load_from_disk
import torch
import random
import numpy as np

from patching_gemma.tasks.task import Task, TaskConfig

class PresentPastAmbiguousTask(Task):
    """Few-shot task over the ``present_past`` dataset with ambiguous demonstrations.

    Every few-shot context holds ``n_shot`` demonstrations sampled from the
    few-shot split (only rows whose input differs from their target).
    ``num_ambiguous`` of those demonstrations are replaced by a word from
    ``AMBIGUOUS_EXAMPLES`` that is used as both input and target, at positions
    selected by ``where_abmiguous``.

    NOTE: ``where_abmiguous`` is misspelled, but it is a public parameter,
    attribute and ``ALLOWED`` key, so the spelling is kept for compatibility.

    NOTE(review): ``TARGET_TYPE``, ``LAST_SEP_TYPE``, ``WORDS_PER_NUM_OF_TOKENS``,
    ``WORDS_PER_NUM_OF_TOKENS_IN_BEGINNING`` and ``get_docs`` are used below but
    not defined in this part of the file; presumably they come from ``Task`` --
    confirm.
    """

    # Accepted constructor values, keyed by parameter name.
    ALLOWED = {
        "n_shot": [3, 10],
        "num_ambiguous": list(range(11)),
        "where_abmiguous": ["beginning", "middle", "end", "random"]
    }

    # Words inserted as demonstrations whose input and target are the same string.
    AMBIGUOUS_EXAMPLES = [
        "bet", "broadcast", "burst", "cost", "cut", "fit",
        "hit", "hurt", "let", "put", "read", "set", "shut", "split", "spread", "upset",
        "quit", "wet", "bid", "shed"
    ]

    def __init__(self, n_shot, num_ambiguous, where_abmiguous):
        """Configure the task and load its dataset from disk.

        Args:
            n_shot: number of demonstrations per context; must be 3 or 10.
            num_ambiguous: number of demonstrations replaced by ambiguous words.
            where_abmiguous: placement of the ambiguous demonstrations, one of
                ``ALLOWED["where_abmiguous"]``.

        Raises:
            AssertionError: if ``where_abmiguous`` or ``n_shot`` is unsupported.
        """
        self.doc_to_text: Callable = lambda doc : str(doc["input"])
        self.doc_to_target: Callable = lambda doc : str(doc["target"])
        self.fewshot_space: str = "\n"
        self.sep: str = "\t"
        self.loss: Callable = self.loss_function
        self.fewshot: int = n_shot
        self.fewshot_seed: int = 42
        self.can_be_token_separable: bool = True

        self.num_ambiguous = num_ambiguous
        self.where_abmiguous = where_abmiguous
        assert self.where_abmiguous in self.ALLOWED["where_abmiguous"], (
            f"where_abmiguous must be one of {self.ALLOWED['where_abmiguous']}, "
            f"got {self.where_abmiguous!r}"
        )

        self.dataset = load_from_disk(self.config.dataset)

        # Only the 3-shot and 10-shot layouts have a known value here.
        self.TOKEN_TYPES = 16 if self.fewshot == 3 else (44 if self.fewshot == 10 else -1)
        assert self.TOKEN_TYPES != -1, f"n_shot must be 3 or 10, got {self.fewshot!r}"

    @staticmethod
    def loss_function(logits, targets, lens, tp_inds):
        """Cross-entropy over the positions selected through ``tp_inds``.

        Positions whose entry in ``tp_inds`` equals ``TARGET_TYPE`` or
        ``LAST_SEP_TYPE`` are selected, the last selected position of every row
        is then dropped, and the logits at the remaining positions are scored
        against ``targets``. ``lens`` is accepted but not used.

        NOTE(review): assumes ``tp_inds`` is 2-D (row, position) and that
        ``targets`` holds one label per remaining position -- confirm with the
        caller.
        """
        cls = PresentPastAmbiguousTask
        predictive_inds = (tp_inds == cls.TARGET_TYPE) | (tp_inds == cls.LAST_SEP_TYPE)

        # The running count of selected positions, masked by the selection,
        # is largest at the last selected position of each row.
        rows = torch.arange(predictive_inds.shape[0])
        last_selected = (predictive_inds.cumsum(dim=-1) * predictive_inds).argmax(dim=-1)
        predictive_inds[rows, last_selected] = 0

        return torch.nn.functional.cross_entropy(logits[predictive_inds, :], targets)

    @staticmethod
    def get_name(n_shot, num_ambiguous, where_abmiguous):
        """Return the task name encoding the three configuration values."""
        return f"present_past_{num_ambiguous}_{where_abmiguous}_ambiguous_{n_shot}_shot"

    @property
    def config(self) -> TaskConfig:
        """Build the ``TaskConfig`` for this instance.

        The dataset path is ``datasets/present_past`` two directory levels
        above the directory containing this file. A new ``TaskConfig`` is
        created on every access.
        """
        module_dir = Path(os.path.dirname(os.path.realpath(__file__)))
        dataset_dir = module_dir.parent.parent / "datasets" / "present_past"
        return TaskConfig(
            dataset=str(dataset_dir),
            subset=None,
            name=PresentPastAmbiguousTask.get_name(self.fewshot, self.num_ambiguous, self.where_abmiguous),
            evaluation_split="test",
            fewshot_split="train"
        )

    def fewshot_docs(self, number_of_contexts) -> List[List[dict]]:
        """Sample the demonstrations for ``number_of_contexts`` contexts.

        Reseeds NumPy's global RNG with ``self.fewshot_seed``, so repeated calls
        return the same contexts. The order of the random draws (regular
        demonstrations, then random positions if requested, then ambiguous
        words) is significant for reproducibility.

        Args:
            number_of_contexts: how many few-shot contexts to build.

        Returns:
            One list of ``self.fewshot`` docs per context. Docs at ambiguous
            positions are ``{"input": word, "target": word}``.

        Raises:
            ValueError: if ``self.where_abmiguous`` is not a known placement,
                or if NumPy cannot draw the requested sample without
                replacement.
        """
        np.random.seed(self.fewshot_seed)
        docs = self.dataset[self.config.fewshot_split]
        docs = docs.filter(lambda doc: doc["input"] != doc["target"])

        # An integer population is sampled exactly like the explicit index
        # list, without building that list once per context.
        num_docs = len(docs)
        inds = [np.random.choice(num_docs, self.fewshot, replace=False)
                for _ in range(number_of_contexts)]

        if self.where_abmiguous == "random":
            inds_of_ambiguous = [np.random.choice(self.fewshot, size=self.num_ambiguous, replace=False)
                                 for _ in range(number_of_contexts)]
        else:
            if self.where_abmiguous == "beginning":
                fixed_positions = list(range(self.num_ambiguous))
            elif self.where_abmiguous == "middle":
                first = self.fewshot // 2 - self.num_ambiguous // 2
                fixed_positions = [first + k for k in range(self.num_ambiguous)]
            elif self.where_abmiguous == "end":
                # Descending order: the first sampled word lands on the last slot.
                fixed_positions = [self.fewshot - k - 1 for k in range(self.num_ambiguous)]
            else:
                raise ValueError(f"unknown where_abmiguous: {self.where_abmiguous!r}")
            # The positions are only read, so every context can share the list.
            inds_of_ambiguous = [fixed_positions] * number_of_contexts

        sampled_words = [np.random.choice(self.AMBIGUOUS_EXAMPLES, size=self.num_ambiguous, replace=False)
                         for _ in range(number_of_contexts)]
        # Per context: position in the context -> ambiguous word placed there.
        replacements = [
            {int(position): words[k].item() for k, position in enumerate(positions)}
            for words, positions in zip(sampled_words, inds_of_ambiguous)
        ]

        return [
            [
                {"input": words_by_position[j], "target": words_by_position[j]}
                if j in words_by_position
                else docs[doc_inds[j].item()]
                for j in range(self.fewshot)
            ]
            for doc_inds, words_by_position in zip(inds, replacements)
        ]

    def get_docs_in_templates(self, limit = None, corrupted = False, tokenizer = None) -> List[Dict[str, str]]:
        """Render evaluation docs into prompt contexts with their targets.

        A context is the demonstrations rendered as ``input + sep + target``
        joined by ``fewshot_space``, followed by the evaluation input and a
        trailing ``sep``.

        With ``corrupted=True`` every demonstration input and target is also
        replaced by a random word looked up by token count in
        ``self.WORDS_PER_NUM_OF_TOKENS`` (the first input uses
        ``self.WORDS_PER_NUM_OF_TOKENS_IN_BEGINNING``). The evaluation input is
        left unchanged and ``"corrupted_target"`` equals the clean target.

        Side effects when ``corrupted=True``: a placeholder token is added to
        ``tokenizer`` and NumPy's global RNG is reseeded with
        ``self.fewshot_seed``.

        Args:
            limit: forwarded to ``self.get_docs``.
            corrupted: whether to also build corrupted contexts.
            tokenizer: tokenizer used to count tokens; required when
                ``corrupted`` is true.

        Returns:
            One dict per evaluation doc with ``"context"`` and ``"target"``,
            plus ``"corrupted_context"`` and ``"corrupted_target"`` when
            ``corrupted`` is true.

        Raises:
            ValueError: if ``corrupted`` is true and ``tokenizer`` is ``None``.
        """
        eval_docs, fewshot_docs = self.get_docs(limit)

        def render(demos, eval_doc):
            # Demonstration lines followed by the open query line.
            shots = [self.doc_to_text(doc) + self.sep + self.doc_to_target(doc) for doc in demos]
            return self.fewshot_space.join(shots) + self.fewshot_space + self.doc_to_text(eval_doc) + self.sep

        contexts = [render(demos, eval_doc) for demos, eval_doc in zip(fewshot_docs, eval_docs)]
        targets = [self.doc_to_target(eval_doc) for eval_doc in eval_docs]

        if not corrupted:
            return [
                {
                    "context": context,
                    "target": target
                }
                for context, target in zip(contexts, targets)
            ]

        if tokenizer is None:
            raise ValueError("a tokenizer is required when corrupted=True")

        special_token = "<special-token-we-will-never-encounter-in-a-dataset>"
        tokenizer.add_tokens([special_token])

        def count_tokens(word):
            # Tokenize behind the placeholder and discount one token for it.
            # NOTE(review): presumably this measures the word as it tokenizes
            # when it is not at the start of the text -- confirm.
            ids = tokenizer(special_token + word, add_special_tokens=False, return_tensors="pt")["input_ids"][0]
            return len(ids) - 1

        np.random.seed(self.fewshot_seed)
        corrupted_fewshot_docs = []
        for context in fewshot_docs:
            num_tokens_per_query = [count_tokens(item["input"]) for item in context]
            # The first query is counted on its own, with special tokens included.
            num_tokens_per_query[0] = len(tokenizer(context[0]["input"], add_special_tokens=True, return_tensors="pt")["input_ids"][0])
            num_tokens_per_target = [count_tokens(item["target"]) for item in context]

            # A word is drawn for slot 0 here and then overwritten below; the
            # draw is kept so the random stream stays identical.
            corrupted_queries = [np.random.choice(self.WORDS_PER_NUM_OF_TOKENS[num_tok]) for num_tok in num_tokens_per_query]
            corrupted_queries[0] = np.random.choice(self.WORDS_PER_NUM_OF_TOKENS_IN_BEGINNING[num_tokens_per_query[0]])
            corrupted_targets = [np.random.choice(self.WORDS_PER_NUM_OF_TOKENS[num_tok]) for num_tok in num_tokens_per_target]

            corrupted_fewshot_docs.append([
                {"input": query, "target": answer}
                for query, answer in zip(corrupted_queries, corrupted_targets)
            ])

        # Rendered once after the loop rather than on every iteration; this
        # also keeps the name defined when there are no contexts at all.
        corrupted_contexts = [render(demos, eval_doc)
                              for demos, eval_doc in zip(corrupted_fewshot_docs, eval_docs)]

        return [
            {
                "context": context,
                "target": target,
                "corrupted_context": corrupted_context,
                "corrupted_target": target,
            }
            for context, target, corrupted_context in zip(contexts, targets, corrupted_contexts)
        ]