import argparse

from pytorch_lightning.utilities.types import EVAL_DATALOADERS
from experiment import DATASET2CONFIGS
from configs.utils import load_BBL_file
from datasets import concatenate_datasets, Dataset
from torch.nn import PairwiseDistance
import pytorch_lightning as pl
from tqdm import tqdm
import random
import os
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel, PeftConfig, get_peft_model, PrefixTuningConfig, TaskType
from pytorch_lightning.utilities.data import DataLoader
import torch
import json

        
class DistanceDataModule(pl.LightningDataModule):
    """Lightning data module that serves a pre-built dataset for hidden-state extraction.

    Rows of ``dataset`` are read through the keys ``input_text`` (tokenized in
    ``collate_fn``), ``instruction`` and ``instance_id`` (grouped in ``test_mapping``).
    """

    def __init__(self, dataset, tokenizer, batch_size=32, max_length=512):
        """
        Args:
            dataset: indexable, sized collection of row mappings (e.g. a HF ``Dataset``).
            tokenizer: callable tokenizer invoked with ``padding``/``truncation``/``return_tensors``.
            batch_size: rows per test batch.
            max_length: truncation limit passed to the tokenizer (512 = previous hard-coded value).
        """
        super().__init__()
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_length = max_length

    def collate_fn(self, batch):
        """Tokenize the ``input_text`` of each row into one padded PyTorch batch."""
        input_text = [item["input_text"] for item in batch]
        return self.tokenizer(
            input_text,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=self.max_length,
        )

    def test_mapping(self) -> dict:
        """Group dataset row indices by instruction.

        Returns:
            dict mapping each instruction to a list of ``(instance_id, row_index)``
            tuples in dataset order; instructions appear in first-seen order.
        """
        mapping = {}
        # Index access (not iteration) keeps this working for any sized, indexable
        # dataset; each row is fetched only once because HF ``Dataset.__getitem__``
        # materialises the whole row on every call.
        for row_index in range(len(self.dataset)):
            row = self.dataset[row_index]
            mapping.setdefault(row["instruction"], []).append((row["instance_id"], row_index))
        return mapping

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """Return an unshuffled DataLoader over the whole dataset using ``collate_fn``."""
        return DataLoader(self.dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)
    
    
def load_MMLU_file(input_dir="./data/MMLU"):
    """Load every ``.csv`` file in ``input_dir`` into question/answer/options dicts.

    Each CSV is read without a header, using the columns question, A, B, C, D, answer.
    The ``answer`` column names the letter of the correct option.

    Args:
        input_dir: directory to scan (non-recursive). Defaults to the previously
            hard-coded path.

    Returns:
        A list of dicts with keys ``question``, ``answer`` (the text of the correct
        option) and ``options`` (the A-D texts in order). Files are processed in
        sorted name order, so the result is reproducible across filesystems. That
        matters because callers may take seeded random samples of it.
    """

    def to_qa_dict(item: dict) -> dict:
        # Replace the answer letter with the text of the option it refers to.
        item["answer"] = item[item.pop("answer")]
        item["options"] = [item.pop(x) for x in ["A", "B", "C", "D"]]
        assert list(item.keys()) == ["question", "answer", "options"]
        return item

    # os.listdir order is arbitrary; sort for a deterministic item order.
    filenames = sorted(f for f in os.listdir(input_dir) if f.endswith(".csv"))
    items = []
    for filename in filenames:
        df = pd.read_csv(
            os.path.join(input_dir, filename),
            names=["question", "A", "B", "C", "D", "answer"],
        )
        items.extend(to_qa_dict(item) for item in df.to_dict("records"))
    return items


def _build_instruction_dataset(data_subset, instruction, proc_class, data_dir):
    """Render every row of ``data_subset`` under one instruction and tag it.

    Adds an ``instruction`` column (constant) and an ``instance_id`` column (0..n-1),
    and drops the raw ``question``/``options``/``answer`` columns.
    """
    # MMLU preprocessors are constructed without ``data_dir``; the others take it.
    if "MMLU" in data_dir:
        processor = proc_class(instruction, [], True).processor
    else:
        processor = proc_class(instruction, [], True, data_dir).processor
    subset = data_subset.map(processor, remove_columns=["question", "options", "answer"], num_proc=1)
    subset = subset.add_column("instruction", [instruction] * len(subset))
    subset = subset.add_column("instance_id", list(range(len(subset))))
    return subset


def preprocess(config_dir, data_dir, proc_class, observed_instructions=None, unobserved_instructions=None, instance_samples=300, instruction_samples=10, seed=None):
    """Build the observed and unobserved instruction datasets over a shared instance sample.

    Args:
        config_dir: unused; kept for signature compatibility.
        data_dir: task directory. If it contains "MMLU", the data comes from
            ``load_MMLU_file()``. Otherwise it comes from ``<data_dir>/task.json``
            via ``load_BBL_file``.
        proc_class: preprocessor class; ``proc_class(...).processor`` is mapped over rows.
        observed_instructions: instruction ids for the observed set (must be non-empty).
        unobserved_instructions: instruction ids for the unobserved set (must be non-empty).
        instance_samples: maximum number of data instances kept (random subset).
        instruction_samples: maximum number of instructions kept per list (random subset).
        seed: if not None, reseeds the *global* ``random`` module before sampling.

    Returns:
        ``(obs_dataset, unobs_dataset)``. Each is the concatenation of one rendered
        copy of the sampled instances per instruction.

    Raises:
        ValueError: if either instruction list is empty or ``instruction_samples < 1``.
            Either case would otherwise end in an empty ``concatenate_datasets`` call.
    """
    observed_instructions = [] if observed_instructions is None else observed_instructions
    unobserved_instructions = [] if unobserved_instructions is None else unobserved_instructions
    # Fail fast, before loading data, instead of crashing in concatenate_datasets([]).
    if not observed_instructions or not unobserved_instructions:
        raise ValueError("preprocess needs at least one observed and one unobserved instruction")
    if instruction_samples < 1:
        raise ValueError(f"instruction_samples must be >= 1, got {instruction_samples}")

    if "MMLU" in data_dir:
        data_set = load_MMLU_file()
    else:
        data_set, _, _ = load_BBL_file(os.path.join(data_dir, "task.json"), [], 0)

    if seed is not None:
        random.seed(seed)

    # Keep the original sampling order (observed, unobserved, instances) so seeded
    # runs select the same subsets as before.
    if instruction_samples < len(observed_instructions):
        observed_instructions = random.sample(observed_instructions, instruction_samples)
    if instruction_samples < len(unobserved_instructions):
        unobserved_instructions = random.sample(unobserved_instructions, instruction_samples)
    if instance_samples < len(data_set):
        data_set = random.sample(data_set, instance_samples)

    data_subset = Dataset.from_list(data_set)
    unobs_datasets = [
        _build_instruction_dataset(data_subset, instruction, proc_class, data_dir)
        for instruction in unobserved_instructions
    ]
    obs_datasets = [
        _build_instruction_dataset(data_subset, instruction, proc_class, data_dir)
        for instruction in observed_instructions
    ]
    return concatenate_datasets(obs_datasets), concatenate_datasets(unobs_datasets)


def get_distance(unobserved_mapping, observed_mapping, unobseved_hidden_states, observed_hidden_states, device="cuda"):
    """For each unobserved instruction, find the observed instruction with the closest hidden states.

    The distance between two instructions is the mean ``PairwiseDistance(p=2)``
    between their hidden states, paired row by row after sorting each instruction's
    rows by ``instance_id``.

    Args:
        unobserved_mapping / observed_mapping: ``{instruction: [(instance_id, row_index), ...]}``,
            e.g. as produced by ``DistanceDataModule.test_mapping``.
        unobseved_hidden_states / observed_hidden_states: tensors indexed by ``row_index``
            along dim 0. The instructions presumably all cover the same instance ids,
            because rows are paired positionally. TODO confirm with callers.
        device: device on which the distances are computed.

    Returns:
        DataFrame with one row per unobserved instruction and columns
        "Unobserved Instruction", "Closest Observed" and "L2 Distance". It is empty
        when ``unobserved_mapping`` is empty.
    """
    pdist = PairwiseDistance(p=2)

    def _ordered_rows(entries):
        # Sort by instance id so row k of every instruction refers to the same instance.
        return [row_index for _, row_index in sorted(entries, key=lambda pair: pair[0])]

    if not unobserved_mapping:
        return pd.DataFrame([])

    # The observed stack does not depend on the unobserved instruction, so build it
    # and move it to ``device`` once instead of once per unobserved instruction.
    obs_instr_idxs = list(observed_mapping.keys())
    observed_instances = torch.stack(
        [observed_hidden_states[_ordered_rows(observed_mapping[name])] for name in obs_instr_idxs],
        dim=0,
    ).to(device)

    rows = []
    for instruction, entries in unobserved_mapping.items():
        unobserved_instances = unobseved_hidden_states[_ordered_rows(entries)].to(device)
        mean_dists = [
            torch.mean(pdist(unobserved_instances, observed)).item()
            for observed in observed_instances
        ]
        closest_index = min(range(len(mean_dists)), key=mean_dists.__getitem__)
        rows.append({
            "Unobserved Instruction": instruction,
            "Closest Observed": obs_instr_idxs[closest_index],
            "L2 Distance": mean_dists[closest_index],
        })

    return pd.DataFrame(rows)


def get_batch_hidden_states(model, batch, token_pos=0, layer=-1):
    """Return one decoder layer's hidden state at decoder position ``token_pos`` for a batch.

    Args:
        model: encoder-decoder model exposing ``device`` and ``generate``.
        batch: mapping of tensors (tokenizer output); every value is moved to ``model.device``.
        token_pos: decoder generation step to read (0 = first decoding step).
        layer: index into that step's per-layer hidden-state tuple (-1 = last entry).

    Returns:
        CPU tensor with one row per generated sequence. This is presumably
        ``(batch_size, hidden_size)`` for greedy decoding.

    NOTE: generation may stop early (e.g. every sequence emits EOS). In that case
    fewer than ``token_pos + 1`` steps exist and indexing raises ``IndexError``.
    """
    batch = {k: v.to(model.device) for k, v in batch.items()}
    # By default ``generate`` returns only token ids. Request the hidden states
    # explicitly so "decoder_hidden_states" is present in the output.
    outputs = model.generate(
        **batch,
        max_new_tokens=token_pos + 1,
        output_hidden_states=True,
        return_dict_in_generate=True,
    )
    # decoder_hidden_states is indexed [step][layer] -> (rows, step_len, hidden).
    # The state for position ``token_pos`` is the last position of step ``token_pos``:
    # step_len is 1 with KV caching and token_pos + 1 without. Indexing step 0 at
    # ``token_pos`` only worked for token_pos == 0.
    step_states = outputs["decoder_hidden_states"][token_pos]
    return step_states[layer][:, -1, :].detach().cpu()

def get_dataset_hidden_states(model, dataloader, token_pos=0, layer=-1):
    """Collect ``get_batch_hidden_states`` over every batch and stack them along dim 0.

    Progress is shown with ``tqdm``. Arguments are forwarded unchanged to
    ``get_batch_hidden_states``.
    """
    per_batch = [
        get_batch_hidden_states(model, batch, token_pos, layer)
        for batch in tqdm(dataloader)
    ]
    return torch.cat(per_batch, dim=0)



from configs.BBH.classification.vitaminc_fact_verification import VitamincFactPreprocessor
from configs.BBH.binary_classification.winowhy import PlayDialogPreprocessor as WinowhyPreprocessor
from configs.BBH.binary_classification.strategy_qa import StrategyQAPreprocessor
from configs.BBH.binary_classification.strange_stories import StrangeStoriesPreprocessor
from configs.BBH.binary_classification.play_dialog_same_or_different import PlayDialogPreprocessor
from configs.BBH.multiple_choice.novel_concepts import NovelConceptsPreprocessor
from configs.BBH.multiple_choice.logical_deduction import LogicalDeductionPreprocessor
from configs.BBH.multiple_choice.known_unknowns import KnownUnknownPreprocessor
from configs.BBH.multiple_choice.hindu_knowledge import KnownUnknownPreprocessor as HinduKnowledgePreprocessor
from configs.BBH.multiple_choice.conceptual_combinations import ConceptualCombinationPreprocessor
from configs.BBH.multiple_choice.code_line_description import LogicalDeductionPreprocessor as CodeLineDescriptionPreprocessor
from configs.BBH.classification.language_identification import LanguageIdentificationPreprocessor
from configs.BBH.multiple_choice.bbq_lite import BBQLitePreprocessor
from configs.MMLU.general import MMLUGeneralPreprocessor


def language_identification_instruction_pool():
    """Return ``(observed, unobserved)`` instruction-id lists for Language_Identification.

    Observed: NIV2 classification ids 31-40, then FLAN classification ids 1-15.
    Unobserved: BBL ids 1-8.
    """
    niv2 = [f"NIV2/Classification/{n}" for n in range(31, 41)]
    flan = [f"FLAN/Classification/{n}" for n in range(1, 16)]
    unobserved = [f"BBL/Unobserved/{n}" for n in range(1, 9)]
    return niv2 + flan, unobserved

def mmlu_instruction_pool():
    """Return ``(observed, unobserved)`` instruction-id lists for MMLU_General.

    Observed: NIV2 QA ids 1-50. Unobserved: MMLU ids 1-20.
    """
    return (
        [f"NIV2/QA/{n}" for n in range(1, 51)],
        [f"MMLU/Unobserved/{n}" for n in range(1, 21)],
    )

def bbq_instruction_pool():
    """Return ``(observed, unobserved)`` instruction-id lists for BBQ_Lite.

    Observed: NIV2 QA ids 1-50. Unobserved: BBL ids 1-12.
    """
    return (
        [f"NIV2/QA/{n}" for n in range(1, 51)],
        [f"BBL/Unobserved/{n}" for n in range(1, 13)],
    )

def code_line_description_instruction_pool():
    """Return ``(observed, unobserved)`` instruction-id lists for Code_Line_Description.

    Observed: NIV2 QA ids 1-50. Unobserved: BBL ids 1-10.
    """
    return (
        [f"NIV2/QA/{n}" for n in range(1, 51)],
        [f"BBL/Unobserved/{n}" for n in range(1, 11)],
    )

def conceptual_combinations_instruction_pool():
    """Return ``(observed, unobserved)`` instruction-id lists for Conceptual_Combinations.

    Observed: NIV2 QA ids 1-50. Unobserved: BBL ids 1-10.
    """
    return (
        [f"NIV2/QA/{n}" for n in range(1, 51)],
        [f"BBL/Unobserved/{n}" for n in range(1, 11)],
    )

def hindu_knowledge_instruction_pool():
    """Return ``(observed, unobserved)`` instruction-id lists for Hindu_Knowledge.

    Observed: NIV2 QA ids 1-50. Unobserved: BBL ids 1-10.
    """
    return (
        [f"NIV2/QA/{n}" for n in range(1, 51)],
        [f"BBL/Unobserved/{n}" for n in range(1, 11)],
    )

def known_unknowns_instruction_pool():
    """Return ``(observed, unobserved)`` instruction-id lists for Known_Unknowns.

    Observed: NIV2 QA ids 1-50. Unobserved: BBL ids 1-10.
    """
    return (
        [f"NIV2/QA/{n}" for n in range(1, 51)],
        [f"BBL/Unobserved/{n}" for n in range(1, 11)],
    )

def logical_deduction_instruction_pool():
    """Return ``(observed, unobserved)`` instruction-id lists for Logical_Deduction.

    Observed: NIV2 QA ids 1-50. Unobserved: BBL ids 1-10.
    """
    return (
        [f"NIV2/QA/{n}" for n in range(1, 51)],
        [f"BBL/Unobserved/{n}" for n in range(1, 11)],
    )

def novel_concepts_instruction_pool():
    """Return ``(observed, unobserved)`` instruction-id lists for Novel_Concepts.

    Observed: NIV2 QA ids 1-50. Unobserved: BBL ids 1-10.
    """
    return (
        [f"NIV2/QA/{n}" for n in range(1, 51)],
        [f"BBL/Unobserved/{n}" for n in range(1, 11)],
    )


def play_dialog_instruction_pool():
    """Return ``(observed, unobserved)`` instruction-id lists for Play_Dialog.

    Observed: NIV2 BC ids 1-10, then FLAN BC ids 1-8. Unobserved: BBL ids 1-10.
    """
    observed = [f"NIV2/BC/{n}" for n in range(1, 11)]
    observed += [f"FLAN/BC/{n}" for n in range(1, 9)]
    unobserved = [f"BBL/Unobserved/{n}" for n in range(1, 11)]
    return observed, unobserved

def strange_stories_instruction_pool():
    """Return ``(observed, unobserved)`` instruction-id lists for Strange_Stories.

    Observed: NIV2 BC ids 1-10, then FLAN BC ids 1-8. Unobserved: BBL ids 1-10.
    """
    observed = [f"NIV2/BC/{n}" for n in range(1, 11)]
    observed += [f"FLAN/BC/{n}" for n in range(1, 9)]
    unobserved = [f"BBL/Unobserved/{n}" for n in range(1, 11)]
    return observed, unobserved

def strategy_qa_instruction_pool():
    """Return ``(observed, unobserved)`` instruction-id lists for StrategyQA.

    Observed: NIV2 BC ids 1-10, then FLAN BC ids 1-8. Unobserved: BBL ids 1-10.
    """
    observed = [f"NIV2/BC/{n}" for n in range(1, 11)]
    observed += [f"FLAN/BC/{n}" for n in range(1, 9)]
    unobserved = [f"BBL/Unobserved/{n}" for n in range(1, 11)]
    return observed, unobserved

def vitaminc_fact_verification_instruction_pool():
    """Return ``(observed, unobserved)`` instruction-id lists for Vitaminc_Fact_Verification.

    Observed: NIV2 classification ids 1-20, then FLAN classification ids 1-15.
    Unobserved: BBL ids 1-10.
    """
    niv2 = [f"NIV2/Classification/{n}" for n in range(1, 21)]
    flan = [f"FLAN/Classification/{n}" for n in range(1, 16)]
    unobserved = [f"BBL/Unobserved/{n}" for n in range(1, 11)]
    return niv2 + flan, unobserved

def winowhy_instruction_pool():
    """Return ``(observed, unobserved)`` instruction-id lists for Winowhy.

    Observed: NIV2 BC ids 1-10, then FLAN BC ids 1-8. Unobserved: BBL ids 1-10.
    """
    observed = [f"NIV2/BC/{n}" for n in range(1, 11)]
    observed += [f"FLAN/BC/{n}" for n in range(1, 9)]
    unobserved = [f"BBL/Unobserved/{n}" for n in range(1, 11)]
    return observed, unobserved


# Per-task configuration table.
# NOTE(review): this rebinds the ``DATASET2CONFIGS`` name imported from ``experiment``
# at the top of the file, so the imported table is shadowed here. Confirm that is intended.
DATASET2CONFIGS = {
    # name -> (data_dir, config_dir, preprocessor class, instruction-pool function)
    # The pool function returns (observed_instructions, unobserved_instructions).
    "MMLU_General": (
        "./data/MMLU",
        "./configs/MMLU/general.py",
        MMLUGeneralPreprocessor,
        mmlu_instruction_pool
    ),
    "BBQ_Lite": (
        "./data/benchmark_tasks/bbq_lite_json/age_ambig",
        "./configs/BBH/multiple_choice/bbq_lite.py",
        BBQLitePreprocessor,
        bbq_instruction_pool
    ),
    "Code_Line_Description": (
        "./data/benchmark_tasks/code_line_description",
        "./configs/BBH/multiple_choice/code_line_description.py",
        CodeLineDescriptionPreprocessor,
        code_line_description_instruction_pool
    ),
    "Logical_Deduction": (
        "./data/benchmark_tasks/logical_deduction/five_objects/",
        "./configs/BBH/multiple_choice/logical_deduction.py",
        LogicalDeductionPreprocessor,
        logical_deduction_instruction_pool
    ),
    "Play_Dialog": (
        "./data/benchmark_tasks/play_dialog_same_or_different",
        "./configs/BBH/binary_classification/play_dialog_same_or_different.py",
        PlayDialogPreprocessor,
        play_dialog_instruction_pool
    ),
    "Vitaminc_Fact_Verification": (
        "./data/benchmark_tasks/vitaminc_fact_verification",
        "./configs/BBH/classification/vitaminc_fact_verification.py",
        VitamincFactPreprocessor,
        vitaminc_fact_verification_instruction_pool
    ),
    "StrategyQA": (
        "./data/benchmark_tasks/strategyqa",
        "./configs/BBH/binary_classification/strategy_qa.py",
        StrategyQAPreprocessor,
        strategy_qa_instruction_pool
    ),
    "Strange_Stories": (
        "./data/benchmark_tasks/strange_stories/boolean/",
        "./configs/BBH/binary_classification/strange_stories.py",
        StrangeStoriesPreprocessor,
        strange_stories_instruction_pool
    ),
    "Language_Identification": (
        "./data/benchmark_tasks/language_identification",
        "./configs/BBH/classification/language_identification.py",
        LanguageIdentificationPreprocessor,
        language_identification_instruction_pool
    ),
    "Known_Unknowns": (
        "./data/benchmark_tasks/known_unknowns",
        "./configs/BBH/multiple_choice/known_unknowns.py",
        KnownUnknownPreprocessor,
        known_unknowns_instruction_pool
    ),
    "Hindu_Knowledge": (
        "./data/benchmark_tasks/hindu_knowledge",
        "./configs/BBH/multiple_choice/hindu_knowledge.py",
        HinduKnowledgePreprocessor,
        hindu_knowledge_instruction_pool
    ),
    "Novel_Concepts": (
        "./data/benchmark_tasks/novel_concepts",
        "./configs/BBH/multiple_choice/novel_concepts.py",
        NovelConceptsPreprocessor,
        novel_concepts_instruction_pool
    ),
    "Winowhy": (
        "./data/benchmark_tasks/winowhy",
        "./configs/BBH/binary_classification/winowhy.py",
        WinowhyPreprocessor,
        winowhy_instruction_pool
    ),
    "Conceptual_Combinations": (
        "./data/benchmark_tasks/conceptual_combinations/contradictions/",
        "./configs/BBH/multiple_choice/conceptual_combinations.py",
        ConceptualCombinationPreprocessor,
        conceptual_combinations_instruction_pool
    ),
}