import os
from pathlib import Path
from typing import Callable, List, Dict
from datasets import load_from_disk
import torch
import random
import numpy as np

from patching_gemma.tasks.task import Task, TaskConfig

class PresentPastAmbiguousDifferentCorruptionsSeparateFullCorruptionTask(Task):
    """Few-shot task over the "present_past" dataset with ambiguous few-shot slots.

    A configurable number of few-shot examples are replaced by entries from
    ``AMBIGUOUS_EXAMPLES`` whose input and target are the same word (see
    ``fewshot_docs``).  ``get_docs_in_templates`` renders the prompts and can
    additionally build several corrupted variants of every few-shot context.
    """

    # Values each constructor argument is expected to take.
    # NOTE: "where_abmiguous" (sic) is misspelt, but the same spelling is used
    # for the constructor parameter and the instance attribute, so it is part
    # of the public interface and must not be "corrected" in isolation.
    ALLOWED = {
        "n_shot": [3, 10],
        "num_ambiguous": [i for i in range(11)],
        "where_abmiguous": ["beginning", "middle", "end", "random"],
        "corruption": [
            "copy_task_in_all_fewshots_agreed",
        ],
    }
    def __init__(self, n_shot, num_ambiguous, where_abmiguous, corruption):
        assert corruption in PresentPastAmbiguousDifferentCorruptionsSeparateFullCorruptionTask.ALLOWED["corruption"]
        self.doc_to_text: Callable = lambda doc : str(doc["input"])
        self.doc_to_target: Callable = lambda doc : str(doc["target"])
        self.fewshot_space: str = "\n"
        self.sep: str = "\t"
        self.loss: Callable = self.loss_function
        self.fewshot: int = n_shot
        self.fewshot_seed: int = 42
        self.can_be_token_separable: bool = True
        self.separate_full_corruption: bool = True

        self.num_ambiguous = num_ambiguous
        self.where_abmiguous = where_abmiguous
        assert self.where_abmiguous in ["beginning", "middle", "end", "random"]
        self.corruption = corruption
        self.NUM_CORRUPTIONS = {
                                    "copy_task_in_all_fewshots_agreed": 1,
                               }[self.corruption]
        
        self.dataset = load_from_disk(self.config.dataset)

        self.TOKEN_TYPES = 16 if self.fewshot == 3 else (44 if self.fewshot == 10 else -1)
        assert self.TOKEN_TYPES != -1

    @staticmethod
    def loss_function(logits, targets, lens, tp_inds):
        predictive_inds = ((tp_inds == PresentPastAmbiguousDifferentCorruptionsSeparateFullCorruptionTask.TARGET_TYPE) |
                          (tp_inds == PresentPastAmbiguousDifferentCorruptionsSeparateFullCorruptionTask.LAST_SEP_TYPE))
        predictive_inds[torch.arange(predictive_inds.shape[0]), (predictive_inds.cumsum(dim=-1) * predictive_inds).argmax(dim=-1)] = 0
        return torch.nn.functional.cross_entropy(logits[predictive_inds, :], targets)

    @staticmethod
    def get_name(n_shot, num_ambiguous, where_abmiguous, corruption):
        return f"present_past_{num_ambiguous}_{where_abmiguous}_ambiguous_separate_full_corruption_{n_shot}_shot_{corruption}"
        
    @property
    def config(self) -> TaskConfig:
        """Build the ``TaskConfig`` for the current settings.

        The dataset path is ``datasets/present_past`` located two directory
        levels above the directory that contains this source file.
        """
        module_dir = Path(os.path.dirname(os.path.realpath(__file__)))
        dataset_dir = module_dir.parent.parent / "datasets" / "present_past"
        task_name = PresentPastAmbiguousDifferentCorruptionsSeparateFullCorruptionTask.get_name(
            self.fewshot,
            self.num_ambiguous,
            self.where_abmiguous,
            self.corruption,
        )
        return TaskConfig(
            dataset=str(dataset_dir),
            subset=None,
            name=task_name,
            evaluation_split="test",
            fewshot_split="train",
        )

    # Verbs that `fewshot_docs` inserts as "ambiguous" few-shot examples: each
    # one is emitted with identical input and target.  They appear to be verbs
    # whose present and past forms are spelt the same.
    AMBIGUOUS_EXAMPLES = [
        "bet", "broadcast", "burst", "cost", "cut", "fit", 
        "hit", "hurt", "let", "put", "read", "set", "shut", "split", "spread", "upset",
        "quit", "wet", "bid", "shed"
    ]
    # Replacement present->past pairs, keyed by (t1, t2).  In
    # `get_docs_in_templates` the key is built from
    # (number of query tokens, number of target tokens).
    # Token counts presumably refer to the tokenizer this project targets --
    # TODO confirm which tokenizer produced them.
    X_Y_CORRUPTION_FOR_PP = {
        # Used for every few-shot position except the first one.
        "regular": {(1, 1): [{'input': 'resolve', 'target': 'resolved'}, {'input': 'divide', 'target': 'divided'}, {'input': 'guide', 'target': 'guided'}, {'input': 'solve', 'target': 'solved'}, {'input': 'process', 'target': 'processed'}], (1, 2): [{'input': 'yield', 'target': 'yielded'}, {'input': 'exist', 'target': 'existed'}, {'input': 'model', 'target': 'modelled'}, {'input': 'gain', 'target': 'gained'}, {'input': 'assess', 'target': 'assessed'}], (2, 1): [{'input': 'utilize', 'target': 'utilized'}, {'input': 'possess', 'target': 'possessed'}, {'input': 'satisfy', 'target': 'satisfied'}, {'input': 'qualify', 'target': 'qualified'}, {'input': 'educate', 'target': 'educated'}], (2, 2): [{'input': 'instruct', 'target': 'instructed'}, {'input': 'investigate', 'target': 'investigated'}, {'input': 'survive', 'target': 'survived'}, {'input': 'clarify', 'target': 'clarified'}, {'input': 'enlarge', 'target': 'enlarged'}], (2, 3): [{'input': 'pursue', 'target': 'pursued'}, {'input': 'facilitate', 'target': 'facilitated'}], (3, 3): [{'input': 'persuade', 'target': 'persuaded'}], (1, 3): [{'input': 'breathe', 'target': 'breathed'}]},
        # Same word lists as "regular" with the first key component increased
        # by one; used for the first few-shot position, whose query length is
        # measured with add_special_tokens=True.
        "in_beginning": {(2, 1): [{'input': 'resolve', 'target': 'resolved'}, {'input': 'divide', 'target': 'divided'}, {'input': 'guide', 'target': 'guided'}, {'input': 'solve', 'target': 'solved'}, {'input': 'process', 'target': 'processed'}], (2, 2): [{'input': 'yield', 'target': 'yielded'}, {'input': 'exist', 'target': 'existed'}, {'input': 'model', 'target': 'modelled'}, {'input': 'gain', 'target': 'gained'}, {'input': 'assess', 'target': 'assessed'}], (3, 1): [{'input': 'utilize', 'target': 'utilized'}, {'input': 'possess', 'target': 'possessed'}, {'input': 'satisfy', 'target': 'satisfied'}, {'input': 'qualify', 'target': 'qualified'}, {'input': 'educate', 'target': 'educated'}], (3, 2): [{'input': 'instruct', 'target': 'instructed'}, {'input': 'investigate', 'target': 'investigated'}, {'input': 'survive', 'target': 'survived'}, {'input': 'clarify', 'target': 'clarified'}, {'input': 'enlarge', 'target': 'enlarged'}], (3, 3): [{'input': 'pursue', 'target': 'pursued'}, {'input': 'facilitate', 'target': 'facilitated'}], (4, 3): [{'input': 'persuade', 'target': 'persuaded'}], (2, 3): [{'input': 'breathe', 'target': 'breathed'}]},
    }
    # Words for the "copy task" corruption, keyed by an integer that
    # `get_docs_in_templates` matches against the token count of the original
    # few-shot target.  A drawn word is used as both input and target.
    # Several words occur more than once in a list (e.g. 'paint', 'run'), so
    # uniform sampling over the list picks them more often.
    X_Y_CORRUPTION_FOR_COPY = {
        # The table read by `get_docs_in_templates`.
        "regular": {1: ['run', 'jump', 'eat', 'sit', 'read', 'write', 'think', 'speak', 'listen', 'watch', 'play', 'sing', 'dance', 'swim', 'fly', 'drive', 'cook', 'bake', 'clean', 'wash', 'sleep', 'wake', 'study', 'teach', 'learn', 'work', 'build', 'create', 'design', 'draw', 'paint', 'knit', 'sew', 'weave', 'mend', 'fix', 'repair', 'break', 'cut', 'slice', 'chop', 'grill', 'fry', 'boil', 'steam', 'stir', 'mix', 'blend', 'beat', 'fold', 'roll', 'shape', 'mold', 'cast', 'forge', 'weld', 'hammer', 'nail', 'screw', 'drill', 'saw', 'plane', 'sand', 'polish', 'paint', 'decorate', 'arrange', 'organize', 'plan', 'schedule', 'manage', 'lead', 'follow', 'guide', 'direct', 'coach', 'mentor', 'inspire', 'support', 'help', 'assist', 'communicate', 'argue', 'debate', 'discuss', 'analyze', 'evaluate', 'assess', 'judge', 'be', 'have', 'do', 'say', 'go', 'get', 'make', 'know', 'think', 'take', 'see', 'come', 'want', 'look', 'use', 'find', 'give', 'tell', 'work', 'call', 'try', 'ask', 'need', 'feel', 'become', 'leave', 'put', 'mean', 'keep', 'let', 'begin', 'seem', 'help', 'talk', 'turn', 'start', 'show', 'hear', 'play', 'run', 'move', 'like', 'live', 'believe', 'hold', 'bring', 'happen', 'write', 'provide', 'sit', 'stand', 'lose', 'pay', 'meet', 'include', 'continue', 'set', 'learn', 'change', 'lead', 'understand', 'watch', 'follow', 'stop', 'create', 'speak', 'read', 'allow', 'add', 'spend', 'grow', 'open', 'walk', 'win', 'offer', 'remember', 'love', 'consider', 'appear', 'buy', 'wait', 'serve', 'die', 'send', 'build', 'stay', 'fall', 'cut', 'reach', 'kill', 'remain', 'suggest', 'raise', 'pass', 'sell', 'require', 'report', 'decide', 'pull', 'break', 'thank', 'join', 'cause'], 2: ['sculpt', 'carve', 'roast', 'poach', 'simmer', 'whisk', 'knead', 'solder', 'varnish', 'instruct', 'motivate', 'encourage', 'collaborate', 'negotiate', 'convince'], 3: ['persuade']},
        # Same lists with every key increased by one.  Not referenced by the
        # code visible in this class -- NOTE(review): confirm whether it is
        # still needed.
        "in_beginning": {2: ['run', 'jump', 'eat', 'sit', 'read', 'write', 'think', 'speak', 'listen', 'watch', 'play', 'sing', 'dance', 'swim', 'fly', 'drive', 'cook', 'bake', 'clean', 'wash', 'sleep', 'wake', 'study', 'teach', 'learn', 'work', 'build', 'create', 'design', 'draw', 'paint', 'knit', 'sew', 'weave', 'mend', 'fix', 'repair', 'break', 'cut', 'slice', 'chop', 'grill', 'fry', 'boil', 'steam', 'stir', 'mix', 'blend', 'beat', 'fold', 'roll', 'shape', 'mold', 'cast', 'forge', 'weld', 'hammer', 'nail', 'screw', 'drill', 'saw', 'plane', 'sand', 'polish', 'paint', 'decorate', 'arrange', 'organize', 'plan', 'schedule', 'manage', 'lead', 'follow', 'guide', 'direct', 'coach', 'mentor', 'inspire', 'support', 'help', 'assist', 'communicate', 'argue', 'debate', 'discuss', 'analyze', 'evaluate', 'assess', 'judge', 'be', 'have', 'do', 'say', 'go', 'get', 'make', 'know', 'think', 'take', 'see', 'come', 'want', 'look', 'use', 'find', 'give', 'tell', 'work', 'call', 'try', 'ask', 'need', 'feel', 'become', 'leave', 'put', 'mean', 'keep', 'let', 'begin', 'seem', 'help', 'talk', 'turn', 'start', 'show', 'hear', 'play', 'run', 'move', 'like', 'live', 'believe', 'hold', 'bring', 'happen', 'write', 'provide', 'sit', 'stand', 'lose', 'pay', 'meet', 'include', 'continue', 'set', 'learn', 'change', 'lead', 'understand', 'watch', 'follow', 'stop', 'create', 'speak', 'read', 'allow', 'add', 'spend', 'grow', 'open', 'walk', 'win', 'offer', 'remember', 'love', 'consider', 'appear', 'buy', 'wait', 'serve', 'die', 'send', 'build', 'stay', 'fall', 'cut', 'reach', 'kill', 'remain', 'suggest', 'raise', 'pass', 'sell', 'require', 'report', 'decide', 'pull', 'break', 'thank', 'join', 'cause'], 3: ['sculpt', 'carve', 'roast', 'poach', 'simmer', 'whisk', 'knead', 'solder', 'varnish', 'instruct', 'motivate', 'encourage', 'collaborate', 'negotiate', 'convince'], 4: ['persuade']},
    }
    
    def fewshot_docs(self, number_of_contexts) -> List[List[dict]]:
        """Sample few-shot contexts and overwrite some slots with ambiguous verbs.

        NumPy's global RNG is re-seeded with ``self.fewshot_seed``.  For every
        context, ``self.fewshot`` distinct rows are drawn from the few-shot
        split (restricted to rows whose input differs from the target), and
        ``self.num_ambiguous`` slots, placed according to
        ``self.where_abmiguous``, are replaced by words drawn without
        replacement from ``AMBIGUOUS_EXAMPLES`` (input == target).

        Side effect: the chosen slot indices are stored in
        ``self.inds_of_ambiguous`` (one sequence per context).

        Args:
            number_of_contexts: How many few-shot contexts to build.

        Returns:
            One list of ``self.fewshot`` example dicts per context.

        Raises:
            ValueError: If ``self.where_abmiguous`` is not a known placement.
        """
        np.random.seed(self.fewshot_seed)
        pool = self.dataset[self.config.fewshot_split]
        pool = pool.filter(lambda doc: doc["input"] != doc["target"])

        shots = self.fewshot
        n_ambiguous = self.num_ambiguous
        context_ids = range(number_of_contexts)

        # Draw order matters for reproducibility: rows first, then (for the
        # "random" placement) slot positions, then the ambiguous words.
        row_candidates = list(range(len(pool)))
        sampled_rows = [
            np.random.choice(row_candidates, shots, replace=False)
            for _ in context_ids
        ]

        placement = self.where_abmiguous
        if placement == "random":
            slots = [
                np.random.choice(list(range(shots)), size=n_ambiguous, replace=False)
                for _ in context_ids
            ]
        else:
            if placement == "beginning":
                pattern = list(range(n_ambiguous))
            elif placement == "middle":
                first_slot = shots // 2 - n_ambiguous // 2
                pattern = list(range(first_slot, first_slot + n_ambiguous))
            elif placement == "end":
                # Counted backwards from the last slot.
                pattern = list(range(shots - 1, shots - 1 - n_ambiguous, -1))
            else:
                raise ValueError()
            # Every context gets its own list object.
            slots = [list(pattern) for _ in context_ids]

        drawn_words = [
            np.random.choice(self.AMBIGUOUS_EXAMPLES, size=n_ambiguous, replace=False)
            for _ in context_ids
        ]
        # Per context: slot index -> plain-str ambiguous word.
        word_at_slot = [
            {slot: word.item() for slot, word in zip(ctx_slots, ctx_words)}
            for ctx_slots, ctx_words in zip(slots, drawn_words)
        ]

        self.inds_of_ambiguous = slots

        contexts = []
        for rows, ctx_slots, ctx_words in zip(sampled_rows, slots, word_at_slot):
            context = []
            for j in range(shots):
                if j in ctx_slots:
                    context.append({"input": ctx_words[j], "target": ctx_words[j]})
                else:
                    context.append(pool[rows[j].item()])
            contexts.append(context)
        return contexts
    
    
    def get_docs_in_templates(self, limit = None, corrupted = False, tokenizer = None) -> List[Dict[str, str]]:
        """Render evaluation prompts and, optionally, their corrupted variants.

        Each clean prompt is the few-shot examples rendered as
        ``<input><sep><target>`` and joined by ``self.fewshot_space``,
        followed by the evaluation input and ``self.sep``.

        With ``corrupted=True`` three corrupted prompts are built per example
        for the "copy_task_in_all_fewshots_agreed" scheme, in this order:

        0. few-shots replaced by random words with matching token counts,
           corrupted evaluation query;
        1. random words token-matched to the copy-task few-shots, corrupted
           evaluation query;
        2. copy-task few-shots (input == target), original evaluation query.

        Side effects when ``corrupted`` is true: a placeholder token is added
        to ``tokenizer``, NumPy's global RNG is re-seeded with
        ``self.fewshot_seed``, and ``self.X_CORRUPTION`` is (re)built.

        Args:
            limit: Forwarded to ``self.get_docs``.
            corrupted: Whether to also build corrupted prompts.
            tokenizer: Required when ``corrupted`` is true.  It is called as
                ``tokenizer(text, add_special_tokens=..., return_tensors="pt")``
                and must provide ``add_tokens`` -- presumably a Hugging Face
                tokenizer; TODO confirm.

        Returns:
            A list of dicts with keys "context" and "target"; when
            ``corrupted`` is true each dict also has "corrupted_contexts"
            (a list of prompt strings, despite the ``Dict[str, str]``
            annotation) and "corrupted_target" (equal to "target").
        """
        # `get_docs` is not defined in this class; it is expected to return
        # parallel sequences of evaluation docs and few-shot contexts.
        eval_docs, fewshot_docs = self.get_docs(limit)
        contexts = [self.fewshot_space.join([self.doc_to_text(doc) + self.sep + self.doc_to_target(doc) for doc in docs]) + self.fewshot_space + self.doc_to_text(eval_doc) + self.sep
                    for docs, eval_doc in zip(fewshot_docs, eval_docs)]
        targets = [self.doc_to_target(eval_doc) for eval_doc in eval_docs]
        if corrupted:
            corrupted_fewshot_docs = []
            corrupt_query = []
            corrupt_query_within_input_space = []
            # Placeholder prefix used when counting tokens: a word is tokenized
            # as "<placeholder>word" and one is subtracted from the length --
            # presumably so the count reflects the word as it is tokenized
            # mid-sequence rather than at the start of a string; TODO confirm.
            special_token = "<special-token-we-will-never-encounter-in-a-dataset>"
            tokenizer.add_tokens([special_token])
            # All draws below come from NumPy's global RNG, so the order of the
            # np.random.choice calls (including draws whose results are
            # discarded) determines the output.
            np.random.seed(self.fewshot_seed)
            # Replacement-word tables indexed by token count; these attributes
            # are not defined in this class.
            QUERY_CORR = self.WORDS_PER_NUM_OF_TOKENS
            QUERY_CORR_IN_BEGINNING = self.WORDS_PER_NUM_OF_TOKENS_IN_BEGINNING
            TARGET_CORR = self.WORDS_PER_NUM_OF_TOKENS

            for context_i, context in enumerate(fewshot_docs):
                num_tokens_per_query = [len(tokenizer(special_token + item["input"], add_special_tokens=False, return_tensors="pt")["input_ids"][0]) - 1 for item in context]
                # The first query is counted on its own with special tokens enabled.
                num_tokens_per_query[0] = len(tokenizer(context[0]["input"], add_special_tokens=True, return_tensors="pt")["input_ids"][0])
                num_tokens_per_target = [len(tokenizer(special_token + item["target"], add_special_tokens=False, return_tensors="pt")["input_ids"][0]) - 1 for item in context]
                # NUM_CORRUPTIONS + 1 lists are drawn, but only the last one is used below.
                corrupted_queries = [[np.random.choice(QUERY_CORR[num_tok]) if j != 0 else np.random.choice(QUERY_CORR_IN_BEGINNING[num_tok])
                                        for j, num_tok in enumerate(num_tokens_per_query)]
                                        for i in range(self.NUM_CORRUPTIONS + 1)]
                corrupted_targets = [[np.random.choice(TARGET_CORR[num_tok]) for num_tok in num_tokens_per_target] for i in range(self.NUM_CORRUPTIONS + 1)]
                # Not taken for "copy_task_in_all_fewshots_agreed", the only
                # corruption listed in ALLOWED; its result is not read below.
                if "pp_task" in self.corruption or "input_space" in self.corruption:
                    corrupted_pp_task = [[np.random.choice(self.X_Y_CORRUPTION_FOR_PP["regular" if j != 0 else "in_beginning"][(t1, t2)])
                                        for j, (t1, t2) in enumerate(zip(num_tokens_per_query, num_tokens_per_target))]
                                        for i in range(self.NUM_CORRUPTIONS)]
                if "copy_task" in self.corruption:
                    # Every target token count must have copy-task candidates.
                    assert all([t in self.X_Y_CORRUPTION_FOR_COPY["regular"] for j, t in enumerate(num_tokens_per_target)]), (num_tokens_per_target, fewshot_docs)
                    corrupted_copy_task = [[np.random.choice(self.X_Y_CORRUPTION_FOR_COPY["regular"][t2])
                                        for j, t2 in enumerate(num_tokens_per_target)]
                                        for i in range(self.NUM_CORRUPTIONS)]
                    # Turn every drawn word into an identity example (input == target).
                    corrupted_copy_task = [[{"input": corrupted_copy_task[i][j], "target": corrupted_copy_task[i][j]}
                                        for j, t2 in enumerate(num_tokens_per_target)]
                                        for i in range(self.NUM_CORRUPTIONS)]
                corrupted_fewshot_docs_per_fewshot = []
                corrupt_query_per_fewshot = []
                corrupt_query_within_input_space_per_fewshot = []
                # Variant 0: random token-matched words in every slot, and the
                # evaluation query is corrupted as well.
                new_items = [{"input" : corrupted_queries[-1][i], "target": corrupted_targets[-1][i]} for i, item in enumerate(context)]
                corrupted_fewshot_docs_per_fewshot.append(new_items)
                corrupt_query_per_fewshot.append(True)
                corrupt_query_within_input_space_per_fewshot.append(False)

                if self.corruption == "copy_task_in_all_fewshots_agreed":
                    new_items = [corrupted_copy_task[-1][i] for i, item in enumerate(context)]

                    # Re-measure token counts on the copy-task examples so that
                    # their random-word counterpart is token-matched to them.
                    num_tokens_per_query = [
                        len(tokenizer(special_token + item["input"], add_special_tokens=False, return_tensors="pt")["input_ids"][0]) - 1 if i != 0 else
                        len(tokenizer(item["input"], add_special_tokens=True, return_tensors="pt")["input_ids"][0])
                        for i, item in enumerate(new_items)
                    ]
                    num_tokens_per_target = [
                        len(tokenizer(special_token + item["target"], add_special_tokens=False, return_tensors="pt")["input_ids"][0]) - 1
                        for item in new_items
                    ]
                    corrupted_new_items = [
                        {
                            "input": np.random.choice(QUERY_CORR[num_tok_q]) if i != 0 else np.random.choice(QUERY_CORR_IN_BEGINNING[num_tok_q]),
                            "target": np.random.choice(TARGET_CORR[num_tok_t])
                        }
                        for i, (num_tok_q, num_tok_t) in enumerate(zip(num_tokens_per_query, num_tokens_per_target))
                    ]
                    # Variant 1: random words matched to the copy-task examples, corrupted query.
                    corrupted_fewshot_docs_per_fewshot.append(corrupted_new_items)
                    corrupt_query_per_fewshot.append(True)
                    corrupt_query_within_input_space_per_fewshot.append(False)
                    # Variant 2: the copy-task examples themselves, original query.
                    corrupted_fewshot_docs_per_fewshot.append(new_items)
                    corrupt_query_per_fewshot.append(False)
                    corrupt_query_within_input_space_per_fewshot.append(False)
                else:
                    assert False
                corrupted_fewshot_docs.append(corrupted_fewshot_docs_per_fewshot)
                corrupt_query.append(corrupt_query_per_fewshot)
                corrupt_query_within_input_space.append(corrupt_query_within_input_space_per_fewshot)

            # Corrupted evaluation queries: a random word with the same token
            # count as the original query; the target is left unchanged.
            num_tokens_per_query = [len(tokenizer(special_token + item["input"], add_special_tokens=False, return_tensors="pt")["input_ids"][0]) - 1 for item in eval_docs]
            corrupted_queries = [np.random.choice(QUERY_CORR[num_tok]) for num_tok in num_tokens_per_query]
            corrupted_eval_docs = [{"input" : corrupted_queries[i], "target": item["target"]} for i, item in enumerate(eval_docs)]
            
            # Regroup X_Y_CORRUPTION_FOR_PP by the first key component only:
            # table name -> query token count -> list of example dicts.
            self.X_CORRUPTION = {
                tp: {
                    nt: [item for (nt1_, nt2_), items in self.X_Y_CORRUPTION_FOR_PP[tp].items() for item in items if nt1_ == nt]
                    for nt in set([nt1 for (nt1, nt2) in self.X_Y_CORRUPTION_FOR_PP[tp]])
                }
                for tp in self.X_Y_CORRUPTION_FOR_PP
            }
            # Alternative query corruption taken from the task's own verbs;
            # falls back to the original query when no verb has that token
            # count.  Every "within input space" flag set above is False, so
            # these queries are never selected when rendering below.
            corrupted_within_input_space_queries = [np.random.choice(
                self.X_CORRUPTION["regular"][num_tok])["input"]
                if num_tok in self.X_CORRUPTION["regular"] else 
                eval_docs[i]["input"]
                for i, num_tok in enumerate(num_tokens_per_query)]
            corrupted_within_input_space_eval_docs = [{"input" : corrupted_within_input_space_queries[i], "target": item["target"]} for i, item in enumerate(eval_docs)]

            # Render each variant: its few-shot block, then the evaluation
            # query chosen by that variant's two flags (input-space corrupted,
            # random-word corrupted, or original).
            corrupted_contexts = [
                [
                    self.fewshot_space.join([self.doc_to_text(doc) + self.sep + self.doc_to_target(doc) for doc in docs]) + self.fewshot_space + ((self.doc_to_text(corrupted_within_input_space_eval_doc) if corrupt_within_input_space_query_per_item else self.doc_to_text(corrupted_eval_doc)) if corrupt_query_per_item else self.doc_to_text(eval_doc)) + self.sep
                    for docs, (corrupt_query_per_item, corrupt_within_input_space_query_per_item) in zip(docs_fer_fewshot, zip(corrupt_query_per_fewshot, corrupt_within_input_space_query_per_fewshot))
                ]
                for (docs_fer_fewshot, eval_doc), ((corrupt_query_per_fewshot, corrupted_eval_doc), (corrupt_within_input_space_query_per_fewshot, corrupted_within_input_space_eval_doc)) in zip(zip(corrupted_fewshot_docs, eval_docs), zip(zip(corrupt_query, corrupted_eval_docs), zip(corrupt_query_within_input_space, corrupted_within_input_space_eval_docs)))]
            return [
            {
                "context": context,
                "target": target,
                "corrupted_contexts": corrupted_context,
                "corrupted_target": target,
            }
            for context, target, corrupted_context in zip(contexts, targets, corrupted_contexts)
            ]
        return [
            {
                "context": context,
                "target": target
            }
            for context, target in zip(contexts, targets)
            ]