import os
from pathlib import Path
from typing import Callable, List, Dict
from datasets import load_from_disk
import torch
import random
import numpy as np

from patching_gemma.tasks.task import Task, TaskConfig

class PresentPastAmbiguousDifferentCorruptionsTask(Task):
    """Few-shot present -> past task with "ambiguous" shots and corrupted prompt variants.

    A prompt is ``n_shot`` demonstrations rendered as ``input<TAB>target`` joined by
    newlines, followed by the query ``input<TAB>``. ``num_ambiguous`` demonstrations,
    placed according to ``where_abmiguous``, are replaced by a word from
    ``AMBIGUOUS_EXAMPLES`` whose input and target are the same string (e.g. "cut" -> "cut").

    With ``corrupted=True``, ``get_docs_in_templates`` also builds corrupted prompt
    variants. ``corruption`` (one of ``ALLOWED["corruption"]``) selects which one.
    """

    # Accepted constructor values. The "where_abmiguous" spelling matches the
    # constructor parameter name and is part of the public interface.
    ALLOWED = {
        "n_shot": [3, 10],
        "num_ambiguous": list(range(11)),
        "where_abmiguous": ["beginning", "middle", "end", "random"],
        "corruption": [
            "x_in_all_fewshots_agreed",
            "y_in_all_fewshots_agreed",
            "x_within_input_space_in_all_fewshots_agreed",
            "y_within_input_space_in_all_fewshots_agreed",
            "full_corruption_in_all_fewshots_agreed",
            "pp_task_in_all_fewshots_agreed",
        ],
    }

    def __init__(self, n_shot, num_ambiguous, where_abmiguous, corruption):
        """Configure the task and load the dataset from disk.

        Args:
            n_shot: number of demonstrations per prompt. Only 3 and 10 are supported.
            num_ambiguous: number of demonstrations replaced by ambiguous words.
                NOTE(review): values above ``n_shot`` are not rejected. "random"
                placement then fails inside ``np.random.choice``, while the other
                placements end up making every shot ambiguous.
            where_abmiguous: placement of the ambiguous shots, one of
                ``ALLOWED["where_abmiguous"]``.
            corruption: one of ``ALLOWED["corruption"]``.

        Raises:
            AssertionError: if ``corruption``, ``where_abmiguous`` or ``n_shot`` is
                not supported.
        """
        allowed = PresentPastAmbiguousDifferentCorruptionsTask.ALLOWED
        assert corruption in allowed["corruption"]
        self.doc_to_text: Callable = lambda doc : str(doc["input"])
        self.doc_to_target: Callable = lambda doc : str(doc["target"])
        self.fewshot_space: str = "\n"
        self.sep: str = "\t"
        self.loss: Callable = self.loss_function
        self.fewshot: int = n_shot
        self.fewshot_seed: int = 42
        self.can_be_token_separable: bool = True

        self.num_ambiguous = num_ambiguous
        self.where_abmiguous = where_abmiguous
        assert self.where_abmiguous in allowed["where_abmiguous"]
        self.corruption = corruption
        # How many corrupted variants are drawn per context (currently 1 for every corruption).
        self.NUM_CORRUPTIONS = {
            "x_in_all_fewshots_agreed": 1,
            "y_in_all_fewshots_agreed": 1,
            "x_within_input_space_in_all_fewshots_agreed": 1,
            "y_within_input_space_in_all_fewshots_agreed": 1,
            "full_corruption_in_all_fewshots_agreed": 1,
            "pp_task_in_all_fewshots_agreed": 1,
        }[self.corruption]

        # Validate n_shot before the (potentially slow) dataset load.
        # 16 token types for 3 shots, 44 for 10 shots.
        # NOTE(review): consumed outside this class; meaning of the count not visible here.
        self.TOKEN_TYPES = 16 if self.fewshot == 3 else (44 if self.fewshot == 10 else -1)
        assert self.TOKEN_TYPES != -1

        self.dataset = load_from_disk(self.config.dataset)

    @staticmethod
    def loss_function(logits, targets, lens, tp_inds):
        """Cross-entropy over the positions whose token type marks a prediction.

        Positions where ``tp_inds`` equals ``TARGET_TYPE`` or ``LAST_SEP_TYPE`` are
        selected. The last selected position in each row is then dropped, and the
        logits at the remaining positions are scored against ``targets``.

        NOTE(review): assumes ``logits`` is (batch, seq, vocab), ``tp_inds`` is
        (batch, seq), and ``targets`` holds one label per selected position in
        row-major order. Confirm against the caller. ``lens`` is unused.
        ``TARGET_TYPE`` / ``LAST_SEP_TYPE`` are not defined in this class and are
        presumably inherited from ``Task``.
        """
        predictive_inds = ((tp_inds == PresentPastAmbiguousDifferentCorruptionsTask.TARGET_TYPE) |
                          (tp_inds == PresentPastAmbiguousDifferentCorruptionsTask.LAST_SEP_TYPE))
        # cumsum * mask peaks at the last True entry of each row; clear that entry.
        last_selected = (predictive_inds.cumsum(dim=-1) * predictive_inds).argmax(dim=-1)
        predictive_inds[torch.arange(predictive_inds.shape[0]), last_selected] = 0
        return torch.nn.functional.cross_entropy(logits[predictive_inds, :], targets)

    @staticmethod
    def get_name(n_shot, num_ambiguous, where_abmiguous, corruption):
        """Return the task identifier built from the constructor arguments."""
        return f"present_past_{num_ambiguous}_{where_abmiguous}_ambiguous_{n_shot}_shot_{corruption}"

    @property
    def config(self) -> TaskConfig:
        """Task configuration.

        The dataset path is ``<grandparent of this file's directory>/datasets/present_past``.
        Evaluation docs come from the "test" split and demonstrations from "train".
        """
        return TaskConfig(
            dataset=str(Path(os.path.realpath(__file__)).parent.parent.parent.joinpath("datasets", "present_past")),
            subset=None,
            name=PresentPastAmbiguousDifferentCorruptionsTask.get_name(self.fewshot, self.num_ambiguous, self.where_abmiguous, self.corruption),
            evaluation_split="test",
            fewshot_split="train"
        )

    # Words inserted as "ambiguous" shots: the input and the target are the same string.
    # The order matters: np.random.choice samples from this list under a fixed seed.
    AMBIGUOUS_EXAMPLES = [
        "bet", "broadcast", "burst", "cost", "cut", "fit",
        "hit", "hurt", "let", "put", "read", "set", "shut", "split", "spread", "upset",
        "quit", "wet", "bid", "shed"
    ]
    # Present/past word pairs used by the "pp_task" and "input_space" corruptions.
    # Keys are (token count of input, token count of target).
    # "regular" is used for every shot after the first. "in_beginning" is used for the
    # first shot, whose input is counted with the tokenizer's special tokens; its keys
    # are the "regular" keys with the input count shifted by one.
    X_Y_CORRUPTION_FOR_PP = {
        "regular": {
            (1, 1): [{'input': 'resolve', 'target': 'resolved'}, {'input': 'divide', 'target': 'divided'}, {'input': 'guide', 'target': 'guided'}, {'input': 'solve', 'target': 'solved'}, {'input': 'process', 'target': 'processed'}],
            (1, 2): [{'input': 'yield', 'target': 'yielded'}, {'input': 'exist', 'target': 'existed'}, {'input': 'model', 'target': 'modelled'}, {'input': 'gain', 'target': 'gained'}, {'input': 'assess', 'target': 'assessed'}],
            (2, 1): [{'input': 'utilize', 'target': 'utilized'}, {'input': 'possess', 'target': 'possessed'}, {'input': 'satisfy', 'target': 'satisfied'}, {'input': 'qualify', 'target': 'qualified'}, {'input': 'educate', 'target': 'educated'}],
            (2, 2): [{'input': 'instruct', 'target': 'instructed'}, {'input': 'investigate', 'target': 'investigated'}, {'input': 'survive', 'target': 'survived'}, {'input': 'clarify', 'target': 'clarified'}, {'input': 'enlarge', 'target': 'enlarged'}],
            (2, 3): [{'input': 'pursue', 'target': 'pursued'}, {'input': 'facilitate', 'target': 'facilitated'}],
            (3, 3): [{'input': 'persuade', 'target': 'persuaded'}],
            (1, 3): [{'input': 'breathe', 'target': 'breathed'}],
        },
        "in_beginning": {
            (2, 1): [{'input': 'resolve', 'target': 'resolved'}, {'input': 'divide', 'target': 'divided'}, {'input': 'guide', 'target': 'guided'}, {'input': 'solve', 'target': 'solved'}, {'input': 'process', 'target': 'processed'}],
            (2, 2): [{'input': 'yield', 'target': 'yielded'}, {'input': 'exist', 'target': 'existed'}, {'input': 'model', 'target': 'modelled'}, {'input': 'gain', 'target': 'gained'}, {'input': 'assess', 'target': 'assessed'}],
            (3, 1): [{'input': 'utilize', 'target': 'utilized'}, {'input': 'possess', 'target': 'possessed'}, {'input': 'satisfy', 'target': 'satisfied'}, {'input': 'qualify', 'target': 'qualified'}, {'input': 'educate', 'target': 'educated'}],
            (3, 2): [{'input': 'instruct', 'target': 'instructed'}, {'input': 'investigate', 'target': 'investigated'}, {'input': 'survive', 'target': 'survived'}, {'input': 'clarify', 'target': 'clarified'}, {'input': 'enlarge', 'target': 'enlarged'}],
            (3, 3): [{'input': 'pursue', 'target': 'pursued'}, {'input': 'facilitate', 'target': 'facilitated'}],
            (4, 3): [{'input': 'persuade', 'target': 'persuaded'}],
            (2, 3): [{'input': 'breathe', 'target': 'breathed'}],
        },
    }

    def fewshot_docs(self, number_of_contexts) -> List[List[dict]]:
        """Sample the demonstrations for ``number_of_contexts`` prompts.

        The method reseeds NumPy's global RNG with ``self.fewshot_seed``. For every
        context it draws ``self.fewshot`` distinct docs from the few-shot split, excluding
        docs whose input equals their target. It then overwrites the positions chosen by
        ``self.where_abmiguous`` with ambiguous words (input == target). Indices drawn
        for overwritten slots are discarded.

        Side effect: stores the chosen positions in ``self.inds_of_ambiguous``.

        Returns:
            ``number_of_contexts`` lists of ``self.fewshot`` docs. Every doc has
            ``"input"`` and ``"target"`` keys.
        """
        np.random.seed(self.fewshot_seed)
        docs = self.dataset[self.config.fewshot_split]
        docs = docs.filter(lambda doc: doc["input"] != doc["target"])
        num_docs = len(docs)
        # An int population samples "as if it were np.arange(a)". This gives the same
        # draws as passing list(range(num_docs)), without rebuilding that list per context.
        inds = [np.random.choice(num_docs, self.fewshot, replace=False)
                for _ in range(number_of_contexts)]

        if self.where_abmiguous == "random":
            inds_of_ambiguous = [np.random.choice(self.fewshot, size=self.num_ambiguous, replace=False)
                                 for _ in range(number_of_contexts)]
        elif self.where_abmiguous == "beginning":
            inds_of_ambiguous = [list(range(self.num_ambiguous)) for _ in range(number_of_contexts)]
        elif self.where_abmiguous == "middle":
            # A contiguous run roughly centred on position fewshot // 2.
            start = self.fewshot // 2 - self.num_ambiguous // 2
            inds_of_ambiguous = [list(range(start, start + self.num_ambiguous)) for _ in range(number_of_contexts)]
        elif self.where_abmiguous == "end":
            # Listed from the last position backwards; the order decides which sampled
            # word lands where.
            inds_of_ambiguous = [[self.fewshot - k - 1 for k in range(self.num_ambiguous)] for _ in range(number_of_contexts)]
        else:
            raise ValueError(f"unsupported where_abmiguous: {self.where_abmiguous!r}")

        sampled_words = [np.random.choice(self.AMBIGUOUS_EXAMPLES, size=self.num_ambiguous, replace=False)
                         for _ in range(number_of_contexts)]
        # position -> ambiguous word, one mapping per context.
        ambiguous_by_position = [
            {position: words[k].item() for k, position in enumerate(positions)}
            for positions, words in zip(inds_of_ambiguous, sampled_words)
        ]

        self.inds_of_ambiguous = inds_of_ambiguous

        return [
            [
                {"input": ambiguous[j], "target": ambiguous[j]} if j in ambiguous
                else docs[sampled[j].item()]
                for j in range(self.fewshot)
            ]
            for sampled, ambiguous in zip(inds, ambiguous_by_position)
        ]

    def _render_prompt(self, shots, query_doc) -> str:
        """Render demonstrations plus a query as ``x<TAB>y\\n...\\nquery<TAB>``."""
        demonstrations = self.fewshot_space.join(
            self.doc_to_text(doc) + self.sep + self.doc_to_target(doc) for doc in shots
        )
        return demonstrations + self.fewshot_space + self.doc_to_text(query_doc) + self.sep

    def get_docs_in_templates(self, limit = None, corrupted = False, tokenizer = None) -> List[Dict[str, object]]:
        """Render every evaluation doc as a few-shot prompt.

        Args:
            limit: forwarded to ``self.get_docs``, which is defined outside this class.
            corrupted: when True, also build corrupted prompt variants (see Returns).
            tokenizer: required when ``corrupted`` is True. It is used to count the
                tokens of each word so that replacements can be drawn per token count.
                It is MUTATED: a private marker token is registered via
                ``tokenizer.add_tokens``.

        Returns:
            One dict per evaluation doc with ``"context"`` (the prompt, ending in the
            separator) and ``"target"``. When ``corrupted`` is True, each dict also has
            ``"corrupted_target"`` (equal to ``"target"``) and ``"corrupted_contexts"``,
            a list of two prompts:

            * index 0: every demonstration's input and target is replaced, and the
              query input is replaced too;
            * index 1: the variant selected by ``self.corruption``, with the original
              query kept.

            Replacements for an n-token word are drawn from ``WORDS_PER_NUM_OF_TOKENS[n]``
            (``..._IN_BEGINNING`` for the first shot's input) or from
            ``X_Y_CORRUPTION_FOR_PP``. NOTE(review): the ``WORDS_PER_NUM_OF_TOKENS*``
            tables are defined outside this class.

        Raises:
            ValueError: if ``corrupted`` is True but no tokenizer is given, or if
                ``self.corruption`` is unknown.
        """
        eval_docs, fewshot_docs = self.get_docs(limit)
        contexts = [self._render_prompt(shots, eval_doc) for shots, eval_doc in zip(fewshot_docs, eval_docs)]
        targets = [self.doc_to_target(eval_doc) for eval_doc in eval_docs]
        if not corrupted:
            return [
                {
                    "context": context,
                    "target": target
                }
                for context, target in zip(contexts, targets)
            ]
        if tokenizer is None:
            raise ValueError("get_docs_in_templates(corrupted=True) requires a tokenizer")

        special_token = "<special-token-we-will-never-encounter-in-a-dataset>"
        tokenizer.add_tokens([special_token])

        def token_count(text, at_start=False):
            """Return how many tokens ``text`` occupies in a prompt."""
            if at_start:
                # The first shot opens the prompt, so the tokenizer's special tokens count too.
                return len(tokenizer(text, add_special_tokens=True, return_tensors="pt")["input_ids"][0])
            # Tokenize behind the marker and subtract it. This is presumably so the word
            # is split as it would be mid-prompt rather than at the start of a string.
            return len(tokenizer(special_token + text, add_special_tokens=False, return_tensors="pt")["input_ids"][0]) - 1

        # Reseed so the corruptions are reproducible. The order of the np.random calls
        # below determines the output and must not change.
        np.random.seed(self.fewshot_seed)
        query_corr = self.WORDS_PER_NUM_OF_TOKENS
        query_corr_in_beginning = self.WORDS_PER_NUM_OF_TOKENS_IN_BEGINNING
        target_corr = self.WORDS_PER_NUM_OF_TOKENS
        needs_pp_pairs = "pp_task" in self.corruption or "input_space" in self.corruption

        corrupted_fewshot_docs = []
        corrupt_query = []
        corrupt_query_within_input_space = []
        for context in fewshot_docs:
            query_lens = [token_count(item["input"], at_start=(j == 0)) for j, item in enumerate(context)]
            target_lens = [token_count(item["target"]) for item in context]
            # NUM_CORRUPTIONS + 1 rows are drawn but only the last row is used. The extra
            # draws are kept so the random stream, and therefore the output, is unchanged.
            corrupted_queries = [
                [np.random.choice(query_corr[n]) if j != 0 else np.random.choice(query_corr_in_beginning[n])
                 for j, n in enumerate(query_lens)]
                for _ in range(self.NUM_CORRUPTIONS + 1)
            ]
            corrupted_targets = [
                [np.random.choice(target_corr[n]) for n in target_lens]
                for _ in range(self.NUM_CORRUPTIONS + 1)
            ]
            if needs_pp_pairs:
                corrupted_pp_task = [
                    [np.random.choice(self.X_Y_CORRUPTION_FOR_PP["regular" if j != 0 else "in_beginning"][(t1, t2)])
                     for j, (t1, t2) in enumerate(zip(query_lens, target_lens))]
                    for _ in range(self.NUM_CORRUPTIONS)
                ]
            new_queries = corrupted_queries[-1]
            new_targets = corrupted_targets[-1]

            # Variant 0: every shot fully replaced (the query is corrupted too, see flags below).
            fully_corrupted = [{"input": new_queries[i], "target": new_targets[i]} for i in range(len(context))]

            # Variant 1: the corruption selected at construction time.
            if self.corruption == "x_in_all_fewshots_agreed":
                variant = [{"input": new_queries[i], "target": item["target"]} for i, item in enumerate(context)]
            elif self.corruption == "y_in_all_fewshots_agreed":
                variant = [{"input": item["input"], "target": new_targets[i]} for i, item in enumerate(context)]
            elif self.corruption == "pp_task_in_all_fewshots_agreed":
                variant = [corrupted_pp_task[-1][i] for i in range(len(context))]
            elif self.corruption == "x_within_input_space_in_all_fewshots_agreed":
                variant = [{"input": corrupted_pp_task[-1][i]["input"], "target": item["target"]} for i, item in enumerate(context)]
            elif self.corruption == "y_within_input_space_in_all_fewshots_agreed":
                variant = [{"input": item["input"], "target": corrupted_pp_task[-1][i]["target"]} for i, item in enumerate(context)]
            elif self.corruption == "full_corruption_in_all_fewshots_agreed":
                variant = [{"input": new_queries[i], "target": new_targets[i]} for i in range(len(context))]
            else:
                raise ValueError(f"unknown corruption: {self.corruption!r}")

            corrupted_fewshot_docs.append([fully_corrupted, variant])
            # Per variant: should the query be corrupted, and if so, within the input space?
            corrupt_query.append([True, False])
            corrupt_query_within_input_space.append([False, False])

        eval_query_lens = [token_count(item["input"]) for item in eval_docs]
        corrupted_eval_docs = [
            {"input": np.random.choice(query_corr[n]), "target": item["target"]}
            for n, item in zip(eval_query_lens, eval_docs)
        ]

        # input token count -> pp pairs with that input length, per table in X_Y_CORRUPTION_FOR_PP.
        self.X_CORRUPTION = {
            tp: {
                nt: [item for (nt1_, nt2_), items in self.X_Y_CORRUPTION_FOR_PP[tp].items() for item in items if nt1_ == nt]
                for nt in set([nt1 for (nt1, nt2) in self.X_Y_CORRUPTION_FOR_PP[tp]])
            }
            for tp in self.X_Y_CORRUPTION_FOR_PP
        }
        regular_by_input_len = self.X_CORRUPTION["regular"]
        # Queries whose token count has no pp pair keep their original input.
        corrupted_within_input_space_eval_docs = [
            {
                "input": np.random.choice(regular_by_input_len[n])["input"] if n in regular_by_input_len else item["input"],
                "target": item["target"],
            }
            for n, item in zip(eval_query_lens, eval_docs)
        ]

        corrupted_contexts = []
        for variants, eval_doc, query_flags, within_flags, corrupted_eval_doc, within_eval_doc in zip(
            corrupted_fewshot_docs,
            eval_docs,
            corrupt_query,
            corrupt_query_within_input_space,
            corrupted_eval_docs,
            corrupted_within_input_space_eval_docs,
        ):
            rendered = []
            for shots, corrupt_this_query, within_input_space in zip(variants, query_flags, within_flags):
                if not corrupt_this_query:
                    query_doc = eval_doc
                elif within_input_space:
                    query_doc = within_eval_doc
                else:
                    query_doc = corrupted_eval_doc
                rendered.append(self._render_prompt(shots, query_doc))
            corrupted_contexts.append(rendered)

        return [
            {
                "context": context,
                "target": target,
                "corrupted_contexts": corrupted_context,
                "corrupted_target": target,
            }
            for context, target, corrupted_context in zip(contexts, targets, corrupted_contexts)
        ]