import numpy as np
import matplotlib.pyplot as plt
import torch
import time
import random
import os
from datasets import load_dataset
from LLM_models.llama.llama import *
import glob
from typing import Dict, List

from transformers import AutoConfig, AutoTokenizer
from torch.profiler import profile, record_function, ProfilerActivity


import re

def get_dataset_dataset_processors(selected_datasets):
    """Load and preprocess the benchmark datasets named in `selected_datasets`.

    Args:
        selected_datasets: Container of dataset names. Each known name is
            tested with ``name in selected_datasets``; names that are not in
            the loader table below are ignored.

    Returns:
        Dict mapping each selected known name to the result of
        ``process_dataset`` for that benchmark, in the order of the loader
        table (not the order of `selected_datasets`).
    """
    # Each entry is a zero-argument callable, so a dataset is only loaded
    # and processed when its name was actually requested.
    loaders = {
        "piqa": (lambda: process_dataset(load_dataset("piqa", split='validation'), process_doc_piqa)),
        "boolq": (lambda: process_dataset(load_dataset("super_glue", "boolq", split='validation'), process_doc_boolq)),
        "arc": (lambda: process_dataset(load_dataset("allenai/ai2_arc", 'ARC-Challenge', split='validation'), process_doc_arc)),
        "hellaswag": (lambda: process_dataset(load_dataset("hellaswag", split='validation'), process_doc_hellaswag)),
        "winogrande": (lambda: process_dataset(load_dataset("winogrande", "winogrande_xl", split='validation'), process_doc_wino)),
        "gsm8k": (lambda: process_dataset(load_dataset("gsm8k", "main", split="test"), process_doc_gsm8k)),
        "math": (lambda: process_dataset(load_dataset("nlile/hendrycks-MATH-benchmark", split="test"), process_doc_math)),
    }

    prepared = {}
    for key, build in loaders.items():
        if key in selected_datasets:
            prepared[key] = build()
    return prepared

    

def load_llama_model_and_tokenizer(model_name,tokenizer_name,device):
    """Load a LLaMA causal LM in float16 on `device`, plus its tokenizer.

    Args:
        model_name: Local checkpoint directory (a leading ``~`` is expanded)
            or any other identifier accepted by
            ``LlamaForCausalLM.from_pretrained``.
        tokenizer_name: Local directory (a leading ``~`` is expanded) or any
            other identifier accepted by ``AutoTokenizer.from_pretrained``.
        device: Target passed to ``.to()`` on the loaded model.

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    hf_model_dir = os.path.expanduser(model_name)
    if os.path.exists(hf_model_dir):
        # Hand from_pretrained the same (expanded) path that was just
        # checked, so a "~/..." argument resolves to the directory found here.
        model_source = hf_model_dir
        print(f'Loading "{model_source}"')
    else:
        # Not a local path: keep the original hint and still try the name
        # as given, exactly as before.
        print('You should download the LLaMa-2-7b-hf model from huggingface.')
        model_source = model_name

    print('CUDA is available?', torch.cuda.is_available())
    print('CUDA version:', torch.version.cuda)

    model = LlamaForCausalLM.from_pretrained(
        model_source,
        torch_dtype=torch.float16
    ).to(device)

    # Same treatment for the tokenizer: prefer the expanded local path when
    # it exists, otherwise pass the name through untouched.
    tokenizer_dir = os.path.expanduser(tokenizer_name)
    tokenizer_source = tokenizer_dir if os.path.exists(tokenizer_dir) else tokenizer_name
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)

    return model, tokenizer

def process_doc_gsm8k(doc):
    """Convert a GSM8K record into the query/choices/gold layout.

    The stripped question is followed by "\\nAnswer:"; the stripped reference
    answer is the single choice, so `gold` is always 0.
    """
    prompt = doc["question"].strip() + "\nAnswer:"
    reference = doc["answer"].strip()
    return dict(query=prompt, choices=[reference], gold=0)

def process_doc_math(doc):
    """Convert a MATH benchmark record into the query/choices/gold layout.

    The stripped problem is followed by "\\nAnswer:"; the stripped reference
    solution is the single choice, so `gold` is always 0.
    """
    prompt = doc["problem"].strip() + "\nAnswer:"
    worked_solution = doc["solution"].strip()
    return dict(query=prompt, choices=[worked_solution], gold=0)


    
def process_doc_wino(doc):
    """Convert a Winogrande record into the query/choices/gold layout.

    The query is the raw sentence (blank included). Each choice is the text
    before the first "_" followed by one option; text after the blank is not
    part of the choices.

    Raises:
        ValueError: If the sentence contains no "_".
        KeyError: If `answer` is neither "1" nor "2".
    """
    sentence = doc["sentence"]
    prefix = sentence[:sentence.index("_")]
    candidates = [prefix + doc[field] for field in ("option1", "option2")]
    gold_index = {"1": 0, "2": 1}[doc["answer"]]
    return {"query": sentence, "choices": candidates, "gold": gold_index}

def process_doc_piqa(doc):
    """Convert a PIQA record into the query/choices/gold layout.

    The goal is the query, the two solutions are the choices (sol1 first),
    and the label is converted to an int index.
    """
    goal = doc['goal']
    solutions = [doc[field] for field in ('sol1', 'sol2')]
    gold_index = int(doc["label"])
    return dict(query=goal, choices=solutions, gold=gold_index)

def process_doc_boolq(doc):
    """Convert a BoolQ record into the query/choices/gold layout.

    The prompt shows the passage and the question (a "?" is appended), the
    choices are fixed as ["no", "yes"], and the label is used as the index
    into them.
    """
    prompt = f"Passage: {doc['passage']}\nQuestion: {doc['question']}?\nAnswer:"
    answer_index = int(doc["label"])
    return dict(query=prompt, choices=["no", "yes"], gold=answer_index)

def process_doc_arc(doc):
    """Convert an ARC record into the id/query/choices/gold layout.

    Numeric answer keys "1".."5" are mapped to "A".."E" before the gold index
    is looked up. The input record is left unmodified.

    Args:
        doc: Mapping with `id`, `question`, `choices` (which has a `text`
            list) and `answerKey`.

    Returns:
        Dict with `id`, `query`, `choices` and `gold` (0-based index of the
        answer letter within "A".."E").

    Raises:
        ValueError: If the answer key is not one of "A".."E" or "1".."5".
    """
    num_to_letter = {"1": "A", "2": "B", "3": "C", "4": "D", "5": "E"}
    # Keep the normalised key in a local: writing it back into `doc` would
    # silently alter the caller's record.
    answer_key = num_to_letter.get(doc["answerKey"], doc["answerKey"])
    return {
        "id": doc["id"],
        "query": "Question: " + doc["question"] + "\nAnswer:",
        "choices": doc["choices"]["text"],
        "gold": ["A", "B", "C", "D", "E"].index(answer_key),
    }

def process_doc_hellaswag(doc):
    """Convert a HellaSwag record into the query/choices/gold layout.

    The query is "<activity_label>: <ctx_a> <ctx_b capitalised>" after text
    clean-up; every ending goes through the same clean-up.

    Returns:
        The converted dict, or None when any step raises (the record and the
        exception are printed in that case).
    """
    def _clean(raw):
        # Order matters: " [title]" becomes ". " before the remaining
        # bracketed tags are dropped, then double spaces are collapsed in a
        # single pass.
        stripped = raw.strip().replace(" [title]", ". ")
        untagged = re.sub("\\[.*?\\]","", stripped)
        return untagged.replace("  "," ")

    try:
        context = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
        query = _clean(doc["activity_label"] + ": " + context)
        choices = list(map(_clean, doc["endings"]))
        gold = int(doc["label"])
    except Exception as e:
        print(f"Error: {doc}")
        print(f"Exception: {e}")
        return None
    return {"query": query, "choices": choices, "gold": gold}

def process_dataset(dataset, process_function):
    """Apply `process_function` to every record, then drop None rows.

    Args:
        dataset: Object exposing `column_names`, `map` and (on the mapped
            result) `filter`; presumably a Hugging Face `Dataset` — confirm
            against the caller.
        process_function: Per-record converter passed to `dataset.map`.

    Returns:
        The mapped dataset with all original columns removed, after a filter
        that keeps rows which are not None.
    """
    def _is_present(row):
        return row is not None

    original_columns = dataset.column_names
    converted = dataset.map(process_function, remove_columns=original_columns)
    return converted.filter(_is_present)

def opti_log_likelihood(model, tokenizer, context, ending,sc_config):
    """Score `context + " " + ending` as the negated model loss.

    The joined text is encoded, moved to the model's device and fed to the
    model with labels equal to a copy of the inputs, so the loss covers the
    whole sequence (context included).

    Args:
        model: Callable model exposing `.device` and returning an object with
            a `.loss` tensor; it must accept `labels` and `sc_config`.
        tokenizer: Tokenizer whose `encode(..., return_tensors="pt")` yields
            a tensor.
        context: Prompt text.
        ending: Candidate continuation.
        sc_config: Forwarded unchanged to the model call.

    Returns:
        `-loss` as a Python float.
    """
    text = context + " " + ending
    token_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model(token_ids, labels=token_ids.clone(), sc_config = sc_config)
    return (-outputs.loss).item()

def Llama_opti_test(model,sc_config,loader,tokenizer):
    """Measure multiple-choice accuracy by picking the best-scoring choice.

    Each example's choices are scored with `opti_log_likelihood`; the choice
    with the highest score is the prediction and is compared to `gold`.

    Args:
        model: Model to evaluate; switched to eval mode first.
        sc_config: Forwarded to `opti_log_likelihood`.
        loader: Iterable of examples with `query`, `choices` and `gold`.
            It does not need to support `len()`.
        tokenizer: Forwarded to `opti_log_likelihood`.

    Returns:
        Fraction of examples predicted correctly, or 0.0 when `loader`
        yields no examples.
    """
    model.eval()
    correct = 0
    # Count examples while iterating rather than via len(loader), so plain
    # iterables/generators work and an empty loader cannot divide by zero.
    total = 0

    with torch.no_grad():
        for example in loader:
            context = example['query']
            label = example['gold']

            log_likelihoods = [
                opti_log_likelihood(model, tokenizer=tokenizer, context=context,
                                    ending=ending, sc_config=sc_config)
                for ending in example['choices']
            ]

            predicted_label = np.argmax(log_likelihoods)
            correct += int(predicted_label == label)
            total += 1

    if total == 0:
        return 0.0
    return correct / total



def slice_tensor_front(tensor, dim=1, length=10):
    """Return the first `length` entries of `tensor` along `dim`.

    All other dimensions are kept whole. Ordinary slice semantics apply, so
    a `length` larger than the dimension returns the full extent.
    """
    # Start from a "take everything" selector per dimension, then narrow
    # only the requested one.
    selector = [slice(None) for _ in range(tensor.dim())]
    selector[dim] = slice(0, length)
    return tensor[tuple(selector)]

def get_LLM_ifs(model,tokenizer,save_path,w_bar,datasets,split_layer,n_samples):
    """Dump the first `w_bar` positions of `model.split_edge` outputs to disk.

    For every example/choice pair, the text ``query + " " + choice`` is
    encoded and passed to ``model.split_edge``. Outputs with fewer than
    `w_bar` entries along dim 1 are skipped; the others are truncated to
    `w_bar` along dim 1, moved to CPU and saved as
    ``<save_path>/<model class name>/<dataset>/W<w_bar>/<index>.pt``.

    NOTE(review): the file index and the `n_samples` limit are shared across
    all datasets, so the function returns as soon as the combined count
    reaches `n_samples` and later datasets may never be visited — confirm
    this is intended.

    Args:
        model: Model exposing `.device` and `split_edge(input_ids, split_layer=...)`.
        tokenizer: Tokenizer whose `encode(..., return_tensors="pt")` yields a tensor.
        save_path: Root output directory.
        w_bar: Number of leading positions (dim 1) to keep.
        datasets: Dict of dataset name -> iterable of examples with `query`
            and `choices`.
        split_layer: Forwarded to `model.split_edge`.
        n_samples: Stop once this many files have been written in total.

    Returns:
        None.
    """
    arch_name = model.__class__.__name__
    saved = 0

    for dataset_name, tokenized_dataset in datasets.items():
        print(f"\n=== Starting experiments on dataset: {dataset_name} ===")
        out_dir = os.path.join(save_path, arch_name,f"{dataset_name}", f"W{w_bar}")
        os.makedirs(out_dir, exist_ok=True)

        with torch.no_grad():
            for example in tokenized_dataset:
                prompt = example['query']
                for ending in example['choices']:
                    token_ids = tokenizer.encode(prompt + " " + ending, return_tensors="pt").to(model.device)
                    hidden, _ = model.split_edge(token_ids, split_layer=split_layer)

                    # Too short to fill a full window: skip this pair.
                    if hidden.shape[1] < w_bar:
                        continue

                    window = slice_tensor_front(hidden, dim=1, length=w_bar)
                    target = os.path.join(out_dir, f"{saved}.pt")
                    saved += 1
                    torch.save(window.detach().cpu(), target)

                    if saved >= n_samples:
                        return
                        
                        
def load_all_LLM_IF_features(save_path, model_name,dataset_name, w_bar):
    """Load every saved ``<index>.pt`` feature file in numeric index order.

    Files are read from ``<save_path>/<model_name>/<dataset_name>/W<w_bar>``.

    Returns:
        List of loaded objects ordered by the integer in each file name;
        empty when the directory has no ``.pt`` files.

    Raises:
        ValueError: If a matching file name is not an integer once ".pt" is
            removed.
    """
    def _numeric_index(path):
        return int(os.path.basename(path).replace('.pt', ''))

    feature_dir = os.path.join(save_path, model_name,f"{dataset_name}", f"W{w_bar}")
    ordered_paths = sorted(glob.glob(os.path.join(feature_dir, '*.pt')), key=_numeric_index)

    # NOTE(review): torch.load may unpickle arbitrary objects depending on
    # the PyTorch version/defaults; only point this at trusted directories.
    return [torch.load(path) for path in ordered_paths]


