import os
from pathlib import Path
from typing import Callable, List, Dict
from datasets import load_from_disk
import torch
import random
import numpy as np

from patching_gemma.tasks.task import Task, TaskConfig

class FineGrainedTypesTaskPhi2(Task):
    """Few-shot "input -> target" tasks formatted as tab-separated lines for Phi-2.

    Each prompt consists of ``n_shot`` demonstration lines of the form
    ``input<TAB>target`` joined by newlines, followed by the evaluation input and a
    trailing tab (see ``_format_prompt``).

    NOTE(review): ``TARGET_TYPE``, ``LAST_SEP_TYPE``, ``WORDS_PER_NUM_OF_TOKENS``,
    ``WORDS_PER_NUM_OF_TOKENS_IN_BEGINNING`` and ``get_docs`` are used here but not
    defined in this class; presumably they come from ``Task`` or are attached
    elsewhere — confirm.
    """

    ALLOWED = {
        "task_name": ["capitalization_phi2", "country_capital_phi2",
                      "present_past_phi2", "person_sport_phi2"],
        "n_shot": [3, 10],
        "corrupt_query": [True, False],
    }

    # Value stored in ``self.TOKEN_TYPES`` for each supported ``n_shot``. Both values
    # equal 4 * n_shot + 4; presumably the number of token-type ids in one prompt
    # (TODO confirm against the code that consumes TOKEN_TYPES).
    _TOKEN_TYPES_BY_N_SHOT = {3: 16, 10: 44}

    def __init__(self, task_name, n_shot, corrupt_query=False):
        """Configure the task and load its dataset from disk.

        Args:
            task_name: one of ``ALLOWED["task_name"]``; also the dataset directory name.
            n_shot: number of demonstrations per prompt; must be 3 or 10.
            corrupt_query: when True, corrupted prompts also replace the evaluation query.

        Raises:
            AssertionError: (when assertions are enabled) for an unsupported
                ``task_name`` or ``n_shot``.
        """
        assert task_name in FineGrainedTypesTaskPhi2.ALLOWED["task_name"], (
            f"unsupported task_name: {task_name!r}")
        # Validate n_shot before the dataset load below, so a bad value fails fast.
        token_types = FineGrainedTypesTaskPhi2._TOKEN_TYPES_BY_N_SHOT.get(n_shot, -1)
        assert token_types != -1, f"unsupported n_shot: {n_shot!r}"

        self.dataset_name = task_name
        self.doc_to_text: Callable = lambda doc: str(doc["input"])
        self.doc_to_target: Callable = lambda doc: str(doc["target"])
        self.fewshot_space: str = "\n"  # joins demonstration lines
        self.sep: str = "\t"  # separates an input from its target
        self.loss: Callable = self.loss_function
        self.fewshot: int = n_shot
        self.fewshot_seed: int = 42  # seeds the corruption sampling
        self.corrupt_query = corrupt_query
        self.can_be_token_separable: bool = True
        self.TOKEN_TYPES: int = token_types
        # ``config`` reads dataset_name / fewshot / corrupt_query, all set above.
        self.dataset = load_from_disk(self.config.dataset)

    @staticmethod
    def get_name(task_name, n_shot, corrupt_query):
        """Return the canonical name, e.g. ``"capitalization_phi2_3_shot_corrupt_query"``."""
        suffix = "_corrupt_query" if corrupt_query else ""
        return f"{task_name}_{n_shot}_shot{suffix}"

    @staticmethod
    def loss_function(logits, targets, lens, tp_inds):
        """Cross-entropy over the logits at target / last-separator positions.

        A position is selected when its type id in ``tp_inds`` equals ``TARGET_TYPE``
        or ``LAST_SEP_TYPE``. The last selected position of each row is then dropped.
        Presumably this is because its logit would predict a token beyond the
        labelled target — confirm.

        Args:
            logits: scores whose leading dims match ``tp_inds`` (boolean-mask indexed).
            targets: class indices for the selected positions; assumed to be aligned
                one-to-one with them in row-major order — TODO confirm with caller.
            lens: unused here; kept for interface compatibility.
            tp_inds: 2-D tensor of per-token type ids (indexed by row and position).

        Returns:
            Scalar ``torch.nn.functional.cross_entropy`` loss.
        """
        predictive_inds = ((tp_inds == FineGrainedTypesTaskPhi2.TARGET_TYPE) |
                           (tp_inds == FineGrainedTypesTaskPhi2.LAST_SEP_TYPE))
        # cumsum * mask strictly increases over the selected positions, so its argmax is
        # each row's last selected position. A row with no selection yields index 0,
        # which is already False, so clearing it is a no-op.
        last_selected = (predictive_inds.cumsum(dim=-1) * predictive_inds).argmax(dim=-1)
        predictive_inds[torch.arange(predictive_inds.shape[0]), last_selected] = 0
        return torch.nn.functional.cross_entropy(logits[predictive_inds, :], targets)

    @property
    def config(self) -> TaskConfig:
        """Task configuration.

        The dataset is read from ``datasets/<task_name>``, located one level above
        the directory containing this file (symlinks resolved). Evaluation docs come
        from the "test" split and demonstrations from the "train" split.
        """
        package_root = Path(os.path.realpath(__file__)).parent.parent
        return TaskConfig(
            dataset=str(package_root / "datasets" / self.dataset_name),
            subset=None,
            name=FineGrainedTypesTaskPhi2.get_name(self.dataset_name, self.fewshot, self.corrupt_query),
            evaluation_split="test",
            fewshot_split="train",
        )

    def _format_prompt(self, demos, eval_doc) -> str:
        """Render demonstrations plus an unanswered query: ``in\\tout\\n...\\nquery\\t``."""
        demonstrations = self.fewshot_space.join(
            self.doc_to_text(doc) + self.sep + self.doc_to_target(doc) for doc in demos
        )
        return demonstrations + self.fewshot_space + self.doc_to_text(eval_doc) + self.sep

    def get_docs_in_templates(self, limit=None, corrupted=False, tokenizer=None) -> List[Dict[str, str]]:
        """Render evaluation prompts, optionally paired with corrupted counterparts.

        Args:
            limit: forwarded to ``self.get_docs``.
            corrupted: when True, each record also gets a ``corrupted_context``. In it,
                every demonstration input and target is replaced by a random entry of
                ``WORDS_PER_NUM_OF_TOKENS`` keyed by its token count; the first
                demonstration input instead draws from
                ``WORDS_PER_NUM_OF_TOKENS_IN_BEGINNING``. When ``self.corrupt_query``
                is set, the evaluation query is replaced the same way. Presumably these
                tables map a token count to words of that length — confirm.
            tokenizer: required when ``corrupted`` is True. NOTE: a placeholder token
                is added to its vocabulary via ``add_tokens``.

        Returns:
            One dict per evaluation doc with ``context`` and ``target``. When
            ``corrupted`` is set, each dict also has ``corrupted_context`` and
            ``corrupted_target``; the latter equals ``target``.

        Raises:
            ValueError: if ``corrupted`` is True but no tokenizer is given.
        """
        eval_docs, fewshot_docs = self.get_docs(limit)
        contexts = [self._format_prompt(demos, eval_doc)
                    for demos, eval_doc in zip(fewshot_docs, eval_docs)]
        targets = [self.doc_to_target(eval_doc) for eval_doc in eval_docs]
        if not corrupted:
            return [
                {"context": context, "target": target}
                for context, target in zip(contexts, targets)
            ]
        if tokenizer is None:
            raise ValueError("get_docs_in_templates(corrupted=True) requires a tokenizer")

        special_token = "<special-token-we-will-never-encounter-in-a-dataset>"
        # The placeholder is added as a vocabulary token and prefixed to a word. The
        # word is then tokenized as it would be mid-prompt; subtracting 1 removes the
        # placeholder itself. This mutates the caller's tokenizer.
        tokenizer.add_tokens([special_token])

        def tokens_mid_prompt(text):
            ids = tokenizer(special_token + text, add_special_tokens=False, return_tensors="pt")["input_ids"][0]
            return len(ids) - 1

        def tokens_at_prompt_start(text):
            return len(tokenizer(text, add_special_tokens=True, return_tensors="pt")["input_ids"][0])

        # A local RandomState draws exactly what the seeded global RNG would. Unlike
        # np.random.seed, it does not reset NumPy's global random state for other code.
        rng = np.random.RandomState(self.fewshot_seed)
        words_by_len = self.WORDS_PER_NUM_OF_TOKENS
        words_by_len_at_start = self.WORDS_PER_NUM_OF_TOKENS_IN_BEGINNING

        corrupted_fewshot_docs = []
        for demos in fewshot_docs:
            query_lens = [tokens_mid_prompt(doc["input"]) for doc in demos]
            query_lens[0] = tokens_at_prompt_start(demos[0]["input"])
            target_lens = [tokens_mid_prompt(doc["target"]) for doc in demos]
            fake_queries = [rng.choice(words_by_len[n]) for n in query_lens]
            # The first demonstration begins the prompt, so it is redrawn from the
            # start-of-prompt table. The position-0 draw above is discarded, but it is
            # kept on purpose so the seeded RNG stream (and hence every corruption)
            # stays stable.
            # NOTE(review): that draw looks up words_by_len with the start-of-prompt
            # count; if the table is a dict lacking that key it raises KeyError.
            fake_queries[0] = rng.choice(words_by_len_at_start[query_lens[0]])
            fake_targets = [rng.choice(words_by_len[n]) for n in target_lens]
            corrupted_fewshot_docs.append(
                [{"input": query, "target": target} for query, target in zip(fake_queries, fake_targets)]
            )

        if self.corrupt_query:
            # Every evaluation query follows its demonstrations, so none of them begins
            # a prompt. All are measured and replaced mid-prompt; eval_docs[0] is just
            # the first example and is deliberately not special-cased.
            final_docs = [
                {"input": rng.choice(words_by_len[tokens_mid_prompt(doc["input"])]), "target": doc["target"]}
                for doc in eval_docs
            ]
        else:
            final_docs = eval_docs

        corrupted_contexts = [self._format_prompt(demos, eval_doc)
                              for demos, eval_doc in zip(corrupted_fewshot_docs, final_docs)]
        return [
            {
                "context": context,
                "target": target,
                "corrupted_context": corrupted_context,
                "corrupted_target": target,
            }
            for context, target, corrupted_context in zip(contexts, targets, corrupted_contexts)
        ]