import os
import argparse
import torch
from huggingface_hub import login
import pandas as pd

from utils import model_utils, data_utils, evaluate_utils, seed_utils, extract_utils
from utils.logging_utils import logger

def get_model_dir_name(model_name):
    """Return the short model name used for this model's log directory.

    For a Hugging Face style identifier such as ``"org/model"`` this is the
    second ``/``-separated component (``"model"``). A bare name without any
    ``/`` (e.g. ``"gpt2"``) is returned unchanged instead of raising
    ``IndexError``.

    Args:
        model_name: Model identifier as passed on the command line.

    Returns:
        The path component to use under ``./logs/``.
    """
    parts = model_name.split('/')
    return parts[1] if len(parts) > 1 else parts[0]


def resolve_default_device(device):
    """Return the device string to install as PyTorch's default device.

    Any CUDA device string (``"cuda"``, ``"cuda:1"``, ...) is passed through
    unchanged so the default device matches the requested one; every other
    value falls back to ``"cpu"``.

    Args:
        device: Value of the ``--device`` command-line argument.

    Returns:
        A device string suitable for ``torch.set_default_device``.
    """
    return device if device.startswith('cuda') else 'cpu'


if __name__ == "__main__":
    # Optional Hugging Face login (disabled); enable if authentication is needed:
    # HF_TOKEN = os.getenv("HF_TOKEN")
    # login(token=HF_TOKEN)

    parser = argparse.ArgumentParser()

    parser.add_argument('--model_name', type=str, required=True, help="Name of the model to load")
    parser.add_argument('--task_name', type=str, required=False, default='country-capital',
                        help="Name of the task to evaluate")
    # NOTE(review): the three dataset arguments have no default, so an omitted
    # flag yields None, which is then passed to data_utils.load_dataset and
    # shows up as "None" in the cache file names below. Confirm load_dataset
    # handles None, or make these arguments required.
    parser.add_argument('--train_dataset_name', type=str, required=False, help="Name of the training dataset")
    parser.add_argument('--test_dataset_name', type=str, required=False,
                        help="Name of the causal indirect effect test dataset")
    parser.add_argument('--inject_dataset_name', type=str, required=False,
                        help="Name of the dataset for FV intervention")
    parser.add_argument('--topk', type=int, required=False, default=20,
                        help="Top k heads to consider for function vector")
    # Default to CUDA only when it is actually available so the script also
    # runs on CPU-only machines without an explicit --device flag.
    parser.add_argument('--device', type=str, required=False,
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        help="Device to use for model inference")
    parser.add_argument('--seed', type=int, required=False, default=42, help="Random seed for reproducibility")

    args = parser.parse_args()
    model_name = args.model_name
    task_name = args.task_name
    train_dataset_name = args.train_dataset_name
    test_dataset_name = args.test_dataset_name
    inject_dataset_name = args.inject_dataset_name
    topk = args.topk
    device = args.device
    seed = args.seed  # already an int: argparse applies type=int

    # Short model name used to build the log directory paths.
    model_name_split = get_model_dir_name(model_name)

    # Set the default device for PyTorch once, after the requested device is known.
    torch.set_default_device(resolve_default_device(device))
    # Forbid gradient calculation
    torch.set_grad_enabled(False)

    # Set random seed
    seed_utils.set_seed(seed)

    # Load the model and tokenizer
    model_wrapper = model_utils.load_model_and_tokenizer(model_name, device=device)

    # Load the datasets
    train_dataset = data_utils.load_dataset(train_dataset_name)
    test_dataset = data_utils.load_dataset(test_dataset_name)
    inject_dataset = data_utils.load_dataset(inject_dataset_name)

    # Make save directories
    DATA_SAVE_DIR = f"./logs/{model_name_split}/{task_name}"
    FIGURE_SAVE_DIR = f"./logs/{model_name_split}/figures/{task_name}"
    os.makedirs(DATA_SAVE_DIR, exist_ok=True)
    os.makedirs(FIGURE_SAVE_DIR, exist_ok=True)

    # Get mean head activation (cached on disk per training dataset).
    # NOTE: torch.load relies on pickle; these paths are expected to contain
    # only files previously written by this script.
    mean_head_activation_path = os.path.join(DATA_SAVE_DIR, f"mean_head_activation_{train_dataset_name}.pt")
    if os.path.exists(mean_head_activation_path):
        # If file exists, load the mean head activation
        mean_head_activation = torch.load(mean_head_activation_path, map_location=device)
    else:
        # Extract and save mean head activation
        mean_head_activation = extract_utils.get_mean_head_activation(model_wrapper, train_dataset)
        torch.save(mean_head_activation, mean_head_activation_path)

    # Get mean causal indirect effects (cached on disk per train/test dataset pair).
    CIE_path = os.path.join(DATA_SAVE_DIR, f"mean_causal_indirect_effects_{train_dataset_name}_{test_dataset_name}.pt")
    if os.path.exists(CIE_path):
        # If file exists, load the mean causal indirect effects
        mean_causal_indirect_effects = torch.load(CIE_path, map_location=device)
    else:
        # Calculate the causal indirect effects
        causal_indirect_effects = evaluate_utils.calculate_causal_indirect_effect(
            model_wrapper, test_dataset, mean_head_activation
        )  # Shape: (batch_size, num_layers, num_heads)
        # Average over the first (batch) dimension.
        mean_causal_indirect_effects = torch.mean(causal_indirect_effects, dim=0)

        # Persist the result before plotting so that a plotting failure does
        # not discard the computed effects.
        torch.save(mean_causal_indirect_effects, CIE_path)

        # Plot and save the causal indirect effect heatmap
        CIE_heatmap_path = os.path.join(FIGURE_SAVE_DIR,
                                        f"causal_indirect_effect_heatmap_{train_dataset_name}_{test_dataset_name}.pdf")
        evaluate_utils.plot_causal_indirect_effect_heatmap(mean_causal_indirect_effects, save_path=CIE_heatmap_path)