from transformers import AutoTokenizer, T5ForConditionalGeneration
import torch
import json

from configs.utils import load_BBL_file
from datasets import concatenate_datasets, Dataset
from torch.nn import PairwiseDistance
from tqdm import tqdm
import random
import os
import pandas as pd


def load_MMLU_file(input_dir="./data/MMLU"):
    """Load every MMLU ``.csv`` file in *input_dir* as a flat list of QA dicts.

    Each CSV is read without a header row; its columns are taken to be
    ``question, A, B, C, D, answer`` in that order, where ``answer`` holds the
    letter of the correct option.

    Args:
        input_dir: Directory containing the MMLU ``.csv`` files. Defaults to
            the location that was previously hard-coded.

    Returns:
        A list of dicts with exactly the keys ``question``, ``answer`` (the
        text of the correct option, not its letter) and ``options`` (the four
        option texts in A-D order). Files are visited in sorted name order so
        the result is reproducible across platforms.

    Raises:
        KeyError: If a row's ``answer`` value is not one of its column names.
    """
    def to_qa_dict(item: dict) -> dict:
        # Build a fresh record instead of mutating the pandas row dict; the
        # answer letter is resolved to the text of the matching option.
        return {
            "question": item["question"],
            "answer": item[item["answer"]],
            "options": [item[letter] for letter in ("A", "B", "C", "D")],
        }

    # os.listdir order is arbitrary; sorting keeps downstream sampling
    # reproducible under a fixed random seed.
    files = sorted(f for f in os.listdir(input_dir) if f.endswith(".csv"))
    items = []
    for file in files:
        df = pd.read_csv(os.path.join(input_dir, file), names=["question", "A", "B", "C", "D", "answer"])
        items.extend(to_qa_dict(item) for item in df.to_dict("records"))
    return items


def _render_instruction(data_subset, proc_class, instruction, data_dir):
    """Apply one instruction's processor to *data_subset* and tag every row."""
    # MMLU processors are constructed without the data directory argument.
    if "MMLU" not in data_dir:
        processor = proc_class(instruction, [], True, data_dir).processor
    else:
        processor = proc_class(instruction, [], True).processor
    subset = data_subset.map(processor, remove_columns=["question", "options", "answer"], num_proc=1)
    subset = subset.add_column("instruction", [instruction] * len(subset))
    # instance_id is the row position inside the sampled subset, so the same
    # instance carries the same id under every instruction.
    subset = subset.add_column("instance_id", list(range(len(subset))))
    return subset


def preprocess(config_dir, data_dir, proc_class, obs_instructions=None, uno_instructions=None, instance_samples=300, instruction_samples=10):
    """Sample instances and instructions and render every (instruction, instance) pair.

    Args:
        config_dir: Not used by this function; kept for interface compatibility.
        data_dir: Task directory. If it contains ``"MMLU"`` the data comes from
            ``load_MMLU_file()``; otherwise from ``<data_dir>/task.json`` via
            ``load_BBL_file``.
        proc_class: Class whose instances expose a ``processor`` attribute that
            is passed to ``Dataset.map``.
        obs_instructions: Pool of observed instructions to sample from.
            ``None`` is treated as an empty list.
        uno_instructions: Pool of unobserved instructions to sample from.
            ``None`` is treated as an empty list.
        instance_samples: Maximum number of instances kept; the data set is
            randomly subsampled only when it is larger than this.
        instruction_samples: Number of instructions drawn from each pool.

    Returns:
        ``(observed_dataset, unobserved_dataset)``: two concatenated datasets,
        each with the added ``instruction`` and ``instance_id`` columns.

    Raises:
        ValueError: From ``random.sample`` when a pool holds fewer than
            ``instruction_samples`` instructions.
    """
    obs_instructions = [] if obs_instructions is None else obs_instructions
    uno_instructions = [] if uno_instructions is None else uno_instructions

    if "MMLU" not in data_dir:
        data_set, _, _ = load_BBL_file(os.path.join(data_dir, "task.json"), [], 0)
    else:
        data_set = load_MMLU_file()

    # Keep this sampling order (observed, unobserved, instances) so seeded
    # runs draw the same samples as before.
    observed_instructions = random.sample(obs_instructions, instruction_samples)
    unobserved_instructions = random.sample(uno_instructions, instruction_samples)
    if instance_samples < len(data_set):
        data_set = random.sample(data_set, instance_samples)

    data_subset = Dataset.from_list(data_set)

    # Unobserved instructions are rendered first, matching the original order
    # in case a processor consumes random state.
    bbl_unobs_datasets = [
        _render_instruction(data_subset, proc_class, instruction, data_dir)
        for instruction in unobserved_instructions
    ]
    bbl_obs_datasets = [
        _render_instruction(data_subset, proc_class, instruction, data_dir)
        for instruction in observed_instructions
    ]

    bbl_unobs_dataset = concatenate_datasets(bbl_unobs_datasets)
    bbl_obs_dataset = concatenate_datasets(bbl_obs_datasets)

    return bbl_obs_dataset, bbl_unobs_dataset

def get_hidden_state(input_text, pos, tokenizer, model, device):
    """Run greedy generation on *input_text* and return one decoder hidden vector.

    Args:
        input_text: Prompt passed to ``tokenizer``.
        pos: Index applied to the hidden states of the first generation step
            (presumably the decoder layer index, per the Hugging Face
            ``generate`` output layout; TODO confirm).
        tokenizer: Callable returning an object with an ``input_ids`` tensor.
        model: Object exposing ``generate``.
        device: Device the input ids are moved to before generation.

    Returns:
        A NumPy array taken from
        ``decoder_hidden_states[0][pos][0][0]`` of the generation output.
    """
    encoded = tokenizer(input_text, return_tensors='pt')
    input_ids = encoded.input_ids.to(device)

    generation_kwargs = dict(
        output_hidden_states=True,
        return_dict_in_generate=True,
        num_beams=1,
        num_return_sequences=1,
        max_length=100,
        early_stopping=True,
        use_cache=True,
    )
    # Inference only: no gradients are needed to read hidden states.
    with torch.no_grad():
        generated = model.generate(input_ids, **generation_kwargs)

    # [0] -> first generation step, [pos] -> caller-selected entry,
    # [0][0] -> first element along each of the next two dimensions.
    first_step = generated['decoder_hidden_states'][0]
    return first_step[pos][0][0].cpu().numpy()

def get_hidden_states(dataset, tokenizer=None, model=None, device=None, pos=1):
    """Collect one hidden-state vector per dataset row and stack them.

    The previous implementation called ``get_hidden_state`` with only two of
    its five required arguments and therefore always raised ``TypeError``.
    The tokenizer, model and device are now accepted here and passed through.

    Args:
        dataset: Iterable of rows, each providing an ``"input_text"`` entry.
        tokenizer: Tokenizer forwarded to ``get_hidden_state``. Required.
        model: Generation model forwarded to ``get_hidden_state``. Required.
        device: Device for the input ids. When ``None``, ``model.device`` is
            used.
        pos: Index forwarded to ``get_hidden_state``; defaults to the value
            that was previously hard-coded.

    Returns:
        A tensor obtained by stacking the per-row vectors along a new first
        dimension, in dataset order.

    Raises:
        TypeError: If ``tokenizer`` or ``model`` is not supplied.
    """
    if tokenizer is None or model is None:
        raise TypeError("get_hidden_states() requires both `tokenizer` and `model`")
    if device is None:
        device = model.device

    hiddens = [
        torch.tensor(get_hidden_state(batch["input_text"], pos, tokenizer, model, device))
        for batch in tqdm(dataset)
    ]
    return torch.stack(hiddens)
    
def build_mapping(dataset):
    """Group dataset rows by their instruction.

    Returns:
        A dict mapping each ``"instruction"`` value to a list of
        ``(instance_id, row_index)`` tuples, in dataset order. Keys appear in
        order of first occurrence.
    """
    grouped = {}
    for row_index, row in enumerate(dataset):
        bucket = grouped.setdefault(row["instruction"], [])
        bucket.append((row["instance_id"], row_index))
    return grouped


def get_performance(observed_dataset, unobserved_datadset, csv_dir, dataset):
    """Look up the recorded performance of every instruction in both datasets.

    Args:
        observed_dataset: Iterable of rows with an ``"instruction"`` entry.
        unobserved_datadset: Iterable of rows with an ``"instruction"`` entry.
        csv_dir: Path of a CSV file with the columns ``Collection``, ``Type``,
            ``ID``, ``Dataset`` and ``Performance``.
        dataset: Only CSV rows whose ``Dataset`` value equals this are used.

    Returns:
        A dict keyed by instruction (observed ones first, each in order of
        first occurrence). The value is the ``Performance`` of the matching
        CSV row, whose name is built as ``Collection/Type/ID``, or ``None``
        when no row matches. If several rows match, the last one wins.
    """
    table = pd.read_csv(csv_dir, index_col=None)

    # Seed every instruction seen in either dataset with "no result yet".
    instructions = {}
    for source in (observed_dataset, unobserved_datadset):
        for sample in source:
            instructions.setdefault(sample["instruction"], None)

    for idx in table.index:
        record = table.loc[idx]
        name = f"{record['Collection']}/{record['Type']}/{str(record['ID'])}"
        if name in instructions and record["Dataset"] == dataset:
            instructions[name] = record["Performance"]

    return instructions


def get_performance_mmlu(observed_dataset, unobserved_datadset, csv_dir, dataset):
    """MMLU variant of the performance lookup with unobserved-id remapping.

    CSV rows whose ``Collection/Type/ID`` name contains ``"Unobserved"`` are
    translated before matching: names ``MMLU/Unobserved/1`` to
    ``MMLU/Unobserved/20`` are mapped, in order, onto the ids in 1..40 that
    are not in the exclusion list below. Unobserved rows outside that range
    are skipped. (Presumably the CSV renumbers the 20 retained unobserved
    instructions while the datasets keep their original ids; confirm against
    the CSV producer.)

    Args:
        observed_dataset: Iterable of rows with an ``"instruction"`` entry.
        unobserved_datadset: Iterable of rows with an ``"instruction"`` entry.
        csv_dir: Path of a CSV file with the columns ``Collection``, ``Type``,
            ``ID``, ``Dataset`` and ``Performance``.
        dataset: Only CSV rows whose ``Dataset`` value equals this are used.

    Returns:
        A dict keyed by instruction. The value is the matching
        ``Performance``, or ``None`` when no CSV row matches.
    """
    excluded_ids = [8, 12, 40, 16, 36, 9, 15, 11, 18, 33, 7, 17, 13, 20, 37, 5, 31, 24, 39, 35]
    instruction_from = [f"MMLU/Unobserved/{i}" for i in range(1, 41) if i not in excluded_ids]
    instruction_to = [f"MMLU/Unobserved/{i}" for i in range(1, 21)]
    assert len(instruction_from) == len(instruction_to)
    # Only the CSV-name -> dataset-name direction is consulted below.
    performance_mapping_reverse = dict(zip(instruction_to, instruction_from))

    table = pd.read_csv(csv_dir, index_col=None)

    # Seed every instruction seen in either dataset with "no result yet".
    instructions = {}
    for source in (observed_dataset, unobserved_datadset):
        for sample in source:
            instructions.setdefault(sample["instruction"], None)

    for idx in table.index:
        record = table.loc[idx]
        name = f"{record['Collection']}/{record['Type']}/{str(record['ID'])}"
        if "Unobserved" in name:
            if name not in performance_mapping_reverse:
                continue
            name = performance_mapping_reverse[name]
        if name in instructions and record["Dataset"] == dataset:
            instructions[name] = record["Performance"]

    return instructions
    

def get_distance(unobserved_dataset, observed_dataset, unobseved_hidden_states, observed_hidden_states, device, performances):
    """For each unobserved instruction, find the closest observed instruction.

    Closeness is the mean L2 distance between the hidden states of the two
    instructions, paired instance by instance (rows are ordered by
    ``instance_id`` before pairing).

    Fixes over the previous version:
      * the closest instruction is now the arg-min over all per-instruction
        mean distances; before, ``torch.topk`` ran on the last iteration's
        scalar, so index 0 was always chosen and the reported distance
        belonged to the last observed instruction;
      * rows are collected in a list and turned into a DataFrame once,
        because ``DataFrame.append`` no longer exists in pandas >= 2.0;
      * the observed hidden states are gathered once, not once per
        unobserved instruction.

    Args:
        unobserved_dataset: Rows with ``"instruction"`` and ``"instance_id"``.
        observed_dataset: Rows with ``"instruction"`` and ``"instance_id"``.
        unobseved_hidden_states: Tensor indexed by row position in
            ``unobserved_dataset`` (presumably the output of
            ``get_hidden_states``; confirm with caller).
        observed_hidden_states: Same, for ``observed_dataset``.
        device: Device on which the distances are computed.
        performances: Mapping from instruction to a value convertible to
            ``float``.

    Returns:
        A DataFrame with one row per unobserved instruction and the columns
        ``Unobserved Instruction``, ``Closest Observed``, ``L2 Distance``,
        ``Degradation`` (unobserved performance minus closest observed
        performance) and ``Closest Performance``. Column dtypes are inferred
        from the collected values.

    Raises:
        KeyError: If an instruction is missing from ``performances``.
        TypeError, ValueError: If a performance value cannot be converted to
            ``float`` (for example ``None``).
        RuntimeError: From torch when there are no observed instructions or
            the instruction groups have mismatching shapes.
    """
    columns = ["Unobserved Instruction", "Closest Observed", "L2 Distance", "Degradation", "Closest Performance"]
    pdist = PairwiseDistance(p=2)
    observed_mapping = build_mapping(observed_dataset)
    unobserved_mapping = build_mapping(unobserved_dataset)
    if not unobserved_mapping:
        return pd.DataFrame(columns=columns)

    def _states_for(pairs, hidden_states):
        # Order rows by instance_id so position k refers to the same instance
        # for every instruction.
        ordered_rows = [row for _, row in sorted(pairs, key=lambda pair: pair[0])]
        return hidden_states[ordered_rows].to(device)

    # Loop-invariant: one stacked tensor, first dimension = observed instruction.
    observed_names = list(observed_mapping)
    observed_states = torch.stack(
        [_states_for(observed_mapping[name], observed_hidden_states) for name in observed_names]
    )

    rows = []
    for instruction, pairs in unobserved_mapping.items():
        unobserved_states = _states_for(pairs, unobseved_hidden_states)
        mean_distances = torch.stack(
            [torch.mean(pdist(unobserved_states, candidate)) for candidate in observed_states]
        )
        # "Closest" means the smallest mean distance.
        closest_dist, closest_index = torch.min(mean_distances, dim=0)
        closest_instruction = observed_names[closest_index.item()]

        original_performance = performances[instruction]
        closest_performance = performances[closest_instruction]
        try:
            original_value = float(original_performance)
            closest_value = float(closest_performance)
        except (TypeError, ValueError):
            print(performances.keys(), instruction)
            raise

        rows.append({
            "Unobserved Instruction": instruction,
            "Closest Observed": closest_instruction,
            "L2 Distance": closest_dist.item(),
            "Degradation": original_value - closest_value,
            "Closest Performance": closest_value,
        })

    return pd.DataFrame(rows, columns=columns)
        