import os 
import json
# Load machine-specific settings (Hugging Face cache location and access token)
# from a local JSON file. This runs before the torch / accelerate / project
# imports further down so the environment variables are already in place when
# those modules are imported.
json_path = 'env_config.json'
with open(json_path, 'r', encoding='utf-8') as file:
    env_config = json.load(file)

hf_home = env_config['HF_HOME']
# Set the HF_HOME environment variable (Hugging Face cache root).
os.environ['HF_HOME'] = hf_home
# Expose the access token to the Hugging Face Hub client.
access_token = env_config['access_token']
# huggingface_hub looks for HF_TOKEN (legacy name: HUGGING_FACE_HUB_TOKEN);
# the HUGGINGFACE_HUB_TOKEN spelling below is not one of the names it reads,
# so HF_TOKEN is set as well. The original variable is kept in case other
# project code reads it.
os.environ['HF_TOKEN'] = access_token
os.environ['HUGGINGFACE_HUB_TOKEN'] = access_token
# Default to GPU 2, but respect a CUDA_VISIBLE_DEVICES value already provided
# by the shell or a job scheduler instead of silently overriding it.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '2')

from typing import List
import numpy as np
from llmexp.llm.smollm import LLMWrapper, Template
from accelerate import Accelerator
from llmexp.data import EvalDataset
import random 
from llmexp.utils.evaluation import calculate_avg_log_prob_diff, calculate_bertscore
from llmexp.explainer.leave_one_out import LeaveOneOutWrapper
import pandas as pd
import torch

def get_perturbed_segments(segments: List[str], theta: np.ndarray, k: int = 3) -> List[str]:
    """Return ``segments`` with the ``k`` highest-scoring segments removed.

    Args:
        segments: Input segments; ``segments[i]`` is scored by ``theta[i]``.
        theta: Per-segment scores (any array-like accepted by ``np.asarray``).
            The segments with the largest scores are treated as the "top" ones.
        k: Number of top-scoring segments to drop. ``k <= 0`` drops nothing;
            ``k >= len(theta)`` drops everything.

    Returns:
        The remaining segments, in their original relative order.
    """
    # Ascending order, so the k highest-scoring indices are at the end.
    order = np.argsort(np.asarray(theta))
    # Slice by an explicit keep-count rather than ``[:-k]``: with k == 0 that
    # expression is ``[:0]`` and would drop every segment instead of none.
    n_keep = max(len(order) - max(k, 0), 0)
    # Restore original positional order of the surviving segments.
    kept_indices = np.sort(order[:n_keep])
    return [segments[i] for i in kept_indices]

def get_results_for_k(llm: LLMWrapper, segments: List[str], question: str, response: str, theta: np.ndarray, log_likelihood: torch.Tensor, k: int = 3):
    """Measure how much removing the top-``k`` attributed segments changes the output.

    The ``k`` segments with the highest ``theta`` scores are dropped via
    ``get_perturbed_segments``. Then, using the remaining context:

    * a fresh response is generated with ``llm.get_response``;
    * the log-likelihood of the ORIGINAL ``response`` is recomputed with
      ``llm.get_log_likelihood``.

    Args:
        llm: Wrapper used for generation and likelihood scoring.
        segments: Full (unperturbed) context segments.
        question: Question posed to the model.
        response: Response produced from the unperturbed context.
        theta: Per-segment attribution scores.
        log_likelihood: Output of ``llm.get_log_likelihood`` on the unperturbed
            context (exact shape/semantics defined by ``LLMWrapper`` — not
            visible here).
        k: Number of top-scoring segments to remove.

    Returns:
        ``(log_prob_drop, bertscore)`` as produced by
        ``calculate_avg_log_prob_diff(perturbed, original)`` and
        ``calculate_bertscore(perturbed_response, response)``.
    """
    remaining = get_perturbed_segments(segments, theta, k=k)

    new_response = llm.get_response(remaining, question)
    new_log_likelihood = llm.get_log_likelihood(remaining, question, response)

    # Tuple elements are evaluated left to right, so the metric calls run in
    # the same order as before: log-prob difference first, then BERTScore.
    return (
        calculate_avg_log_prob_diff(new_log_likelihood, log_likelihood),
        calculate_bertscore(new_response, response),
    )

if __name__ == "__main__":
    # Seeds only Python's `random`, which drives the choice of test examples.
    random.seed(42)

    # define parameters
    K = 100  # number of test examples to evaluate
    num_samples = 1  # Number of samples for LOO averaging
    dataset_name = "cnn"
    model_name = "smollm"
    ks = (1, 3, 5)  # how many top-attributed segments to remove per evaluation

    accelerator = Accelerator()
    device = accelerator.device

    # load the model
    checkpoints = {
        "llama": "meta-llama/Meta-Llama-3-8B-Instruct",
        "qwen": "Qwen/Qwen3-8B",
        "smollm": "HuggingFaceTB/SmolLM-1.7B-Instruct",
    }
    if model_name not in checkpoints:
        # Previously an unknown name left `checkpoint` unbound (NameError).
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of {sorted(checkpoints)}"
        )
    checkpoint = checkpoints[model_name]
    llm = LLMWrapper(checkpoint, device=device, access_token=access_token)
    tokenizer = llm.tokenizer

    # load the data and draw a reproducible random subset of test examples
    dataset = EvalDataset(dataset_name, split='test')
    total_samples = len(dataset)
    # random.sample raises ValueError if asked for more items than exist.
    sampled_indices = random.sample(range(total_samples), min(K, total_samples))
    test_data = [dataset[i] for i in sampled_indices]

    # load explainer
    explainer = LeaveOneOutWrapper(llm, tokenizer, device)

    # Per-k metric histories, one entry per processed example.
    log_prob_drops = {k: [] for k in ks}
    bertscores = {k: [] for k in ks}

    # Ensure the output directory exists before spending compute on examples.
    os.makedirs('results', exist_ok=True)
    output_path = f'results/{model_name}_{dataset_name}_loo_nsamples_{num_samples}.csv'

    for data in test_data:
        # Use standardized EvalDataset format
        segments = data['segment']
        question = data['question']
        response = llm.get_response(segments, question)

        segments, theta = explainer.attribute(segments, question, response, num_samples=num_samples)

        log_likelihood = llm.get_log_likelihood(segments, question, response)

        for k in ks:
            log_prob_drop, bertscore = get_results_for_k(
                llm, segments, question, response, theta, log_likelihood, k=k
            )
            log_prob_drops[k].append(log_prob_drop)
            bertscores[k].append(bertscore)

        # Column order: all log_prob_drop_k* columns, then all bertscore_k* columns.
        results = {f'log_prob_drop_k{k}': log_prob_drops[k] for k in ks}
        results.update({f'bertscore_k{k}': bertscores[k] for k in ks})

        # Rewrite the CSV after every example so partial results survive an
        # interrupted run.
        results_df = pd.DataFrame(results)
        results_df.to_csv(output_path, index=False)
