import argparse

from pytorch_lightning.utilities.types import EVAL_DATALOADERS
from configs.utils import load_BBL_file
from datasets import concatenate_datasets, Dataset
import pytorch_lightning as pl
from tqdm import tqdm
import random
import os
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel, PeftConfig
from pytorch_lightning.utilities.data import DataLoader
import torch
import json
from distance_utils import DATASET2CONFIGS, preprocess, DistanceDataModule, get_dataset_hidden_states, get_distance
    
    

def _load_model(args):
    """Load the seq2seq model selected by ``args``.

    With ``args.use_prefix`` set, ``args.model_path_or_name`` is read as a PEFT
    adapter: its config names the base model, which is loaded first and then
    wrapped with the adapter. Otherwise the path is loaded directly.
    """
    load_kwargs = {
        "torch_dtype": torch.bfloat16,
        "output_hidden_states": True,
        "return_dict_in_generate": True,
    }
    if not args.use_prefix:
        return AutoModelForSeq2SeqLM.from_pretrained(args.model_path_or_name, **load_kwargs)

    peft_config = PeftConfig.from_pretrained(args.model_path_or_name)
    base_model = AutoModelForSeq2SeqLM.from_pretrained(
        peft_config.base_model_name_or_path, **load_kwargs
    )
    return PeftModel.from_pretrained(base_model, args.model_path_or_name)


def _instructions_from_results(dataset_result_df):
    """Split the result rows of one dataset into (observed, unobserved) names.

    Each name is ``"<Collection>/<Type>/<ID>"``. Rows whose ``Type`` is
    ``"Unobserved"`` or ``"Default"`` go to the unobserved list; every other
    row goes to the observed list.
    """
    obs_instructions, unobs_instructions = [], []
    for _, row in dataset_result_df.iterrows():
        instruction_name = "{}/{}/{}".format(row["Collection"], row["Type"], str(row["ID"]))
        if row["Type"] in ("Unobserved", "Default"):
            unobs_instructions.append(instruction_name)
        else:
            obs_instructions.append(instruction_name)
    return obs_instructions, unobs_instructions


def distance_stat(args):
    """Compute observed/unobserved instruction distances and write one CSV per dataset.

    For every selected dataset, hidden states are collected for the observed and
    the unobserved instructions, passed to ``get_distance``, and the resulting
    frame is saved as ``<args.output_dir>/<dataset>.csv``.

    Args:
        args: Namespace providing ``model_path_or_name``, ``use_prefix``,
            ``result_csv``, ``output_dir``, ``batch_size``,
            ``instance_samples_cap``, ``instruction_samples_cap``, ``layer``,
            ``token_pos`` and ``seed``.

    Raises:
        ValueError: If ``args.result_csv`` lists a different number of observed
            and unobserved instructions for a dataset.
    """
    # Create the output directory up front so a direct call (not via main())
    # cannot fail at the final to_csv after all the expensive work is done.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    model = _load_model(args)

    # NOTE(review): with use_prefix the tokenizer is still read from the adapter
    # path rather than the base model; confirm the adapter directory ships
    # tokenizer files.
    tokenizer = AutoTokenizer.from_pretrained(args.model_path_or_name)

    model = model.to("cuda")

    if args.result_csv is None:
        result_df = None
        all_datasets = list(DATASET2CONFIGS)
    else:
        result_df = pd.read_csv(args.result_csv)
        all_datasets = result_df["Dataset"].unique()

    for dataset in all_datasets:
        print("Getting distances for dataset: ", dataset)
        if dataset not in DATASET2CONFIGS:
            print("Dataset {} not found for distance experiment".format(dataset))
            continue
        data_dir, config_dir, proc_class, instruction_pool = DATASET2CONFIGS[dataset]

        if result_df is None:
            obs_instructions, unobs_instructions = instruction_pool()
        else:
            obs_instructions, unobs_instructions = _instructions_from_results(
                result_df[result_df["Dataset"] == dataset]
            )
            # Explicit raise instead of assert: the check must survive `python -O`.
            if len(obs_instructions) != len(unobs_instructions):
                raise ValueError(
                    "Dataset {!r} in {} has {} observed but {} unobserved instructions; "
                    "the counts must match".format(
                        dataset,
                        args.result_csv,
                        len(obs_instructions),
                        len(unobs_instructions),
                    )
                )

        obs_dataset, unob_dataset = preprocess(
            config_dir,
            data_dir,
            proc_class,
            obs_instructions,
            unobs_instructions,
            args.instance_samples_cap,
            args.instruction_samples_cap,
            args.seed,
        )

        obs_dataset = DistanceDataModule(obs_dataset, tokenizer, args.batch_size)
        unob_dataset = DistanceDataModule(unob_dataset, tokenizer, args.batch_size)

        obs_dataloader = obs_dataset.test_dataloader()
        unob_dataloader = unob_dataset.test_dataloader()

        # The mappings are handed to get_distance together with the hidden
        # states (presumably to tie each state back to its instruction - confirm
        # in distance_utils).
        obs_mapping = obs_dataset.test_mapping()
        unob_mapping = unob_dataset.test_mapping()

        print("Getting hidden states for observed instructions")
        obs_hs = get_dataset_hidden_states(model, obs_dataloader, args.token_pos, args.layer)
        print("Getting hidden states for unobserved instructions")
        unobs_hs = get_dataset_hidden_states(model, unob_dataloader, args.token_pos, args.layer)
        print("Getting closest instructions and their average distances")
        distance_df = get_distance(unob_mapping, obs_mapping, unobs_hs, obs_hs)

        distance_df.to_csv(os.path.join(args.output_dir, "{}.csv".format(dataset)))


def main(argv=None):
    """Parse command-line options, create the output directory, run distance_stat.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the script behaviour.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path_or_name",
        type=str,
        default="google/flan-t5-xl",
        help="Model name or path given to from_pretrained; with --use_prefix this is the PEFT adapter.",
    )
    parser.add_argument(
        "--use_prefix",
        default=False,
        action="store_true",
        help="Load --model_path_or_name as a PEFT adapter on top of its base model.",
    )
    parser.add_argument(
        "--result_csv",
        type=str,
        default=None,
        help="CSV with Dataset, Collection, Type and ID columns naming the instructions to compare. "
        "When omitted, every dataset in DATASET2CONFIGS is processed with its own instruction pool.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./results_csv/Distance/Flan-T5-XL",
        help="Directory that receives one <dataset>.csv per dataset.",
    )

    parser.add_argument(
        "--batch_size", type=int, default=128, help="Batch size handed to DistanceDataModule."
    )
    parser.add_argument(
        "--instance_samples_cap",
        type=int,
        default=600,
        help="Instance sample cap forwarded to preprocess().",
    )
    parser.add_argument(
        "--instruction_samples_cap",
        type=int,
        default=10,
        help="Instruction sample cap forwarded to preprocess().",
    )
    parser.add_argument(
        "--layer",
        type=int,
        default=-1,
        help="Layer argument forwarded to get_dataset_hidden_states().",
    )
    parser.add_argument(
        "--token_pos",
        type=int,
        default=0,
        help="Token position argument forwarded to get_dataset_hidden_states().",
    )

    parser.add_argument("--seed", type=int, default=42, help="Seed forwarded to preprocess().")

    args = parser.parse_args(argv)

    # exist_ok=True already tolerates an existing directory, so a separate
    # os.path.exists check is redundant (and racy).
    os.makedirs(args.output_dir, exist_ok=True)

    distance_stat(args)
        
    
# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
