import torch
import argparse
import matplotlib.pyplot as plt
import os
import numpy as np
from config.config_loader import ConfigLoader
from utils import create_actor_critic, get_device
from tensordict import TensorDict


def load_model(model_path, config, device):
    """Build the actor for *config* on *device*, load checkpoint weights, and set eval mode.

    The object read from ``model_path`` is handled in one of two ways:
    - A dict that has an ``'actor'`` key is treated as a wrapper, and that entry is loaded.
    - Anything else is passed directly to ``actor.load_state_dict``.

    The critic returned by ``create_actor_critic`` is discarded.

    Returns:
        The actor module, already switched to ``eval()`` mode.
    """
    actor, _critic = create_actor_critic(config, device)

    # NOTE(review): torch.load unpickles its input -- only point this at trusted files.
    loaded = torch.load(model_path, map_location=device)
    wrapped = isinstance(loaded, dict) and 'actor' in loaded
    actor.load_state_dict(loaded['actor'] if wrapped else loaded)

    actor.eval()
    return actor


def get_raw_energy_scores(batch, actor, device):
    """Run *actor* on *batch* without gradients and return its ``"raw_energy"`` output as a list.

    Args:
        batch: either a dict of tensors, wrapped as-is, or a single tensor, wrapped under the
            ``"image"`` key. The TensorDict batch size is the leading dimension of the
            first dict value, or of the tensor itself.
        actor: module called with the TensorDict; it must return a mapping with a
            ``"raw_energy"`` entry.
        device: device that every input tensor, and the TensorDict itself, is moved to.

    Returns:
        ``raw_energy`` moved to CPU and converted with ``.tolist()``. It is presumably of
        shape [N], giving a flat list of floats; confirm against the actor.
    """
    if isinstance(batch, dict):
        tensors = {key: value.to(device) for key, value in batch.items()}
        leading_dim = next(iter(batch.values())).shape[0]
    else:
        tensors = {"image": batch.to(device)}
        leading_dim = batch.shape[0]

    # The trailing .to(device) also sets the TensorDict's own device attribute.
    input_td = TensorDict(tensors, batch_size=[leading_dim]).to(device)

    with torch.no_grad():
        energies = actor(input_td)["raw_energy"]  # [N]

    return energies.cpu().tolist()



def extract_model_number(model_path):
    """Return the ``{num}`` part of a checkpoint path whose basename is ``model-{num}.pt``.

    Everything between the ``model-`` prefix and the ``.pt`` suffix is returned as a
    string. Names such as ``model-1.5.pt`` or ``model-ema-100.pt`` therefore keep their
    full identifier. They are not truncated at the next ``.`` or ``-``, which would make
    distinct checkpoints share output filenames.

    Raises:
        ValueError: if the basename is not of the form ``model-{num}.pt``, or if
            ``{num}`` is empty.
    """
    prefix, suffix = "model-", ".pt"
    filename = os.path.basename(model_path)
    if filename.startswith(prefix) and filename.endswith(suffix):
        number = filename[len(prefix):len(filename) - len(suffix)]
        if number:
            return number
    raise ValueError("Model filename must be in the format model-{num}.pt")


def extract_filename_base(path):
    """Return the last component of *path* with its final extension removed.

    For example, ``"/data/random_iid.pt"`` becomes ``"random_iid"``, and ``"a/b.c.d"``
    becomes ``"b.c"``.
    """
    stem, _extension = os.path.splitext(os.path.basename(path))
    return stem


def plot_energy_distributions(iid_scores, ood_scores, save_path, model_num,
                              threshold_percentile=80, bins=50):
    """Plot overlaid, density-normalised histograms of IID and OOD raw-energy scores and save them.

    A dashed vertical line marks the ``threshold_percentile``-th percentile of the IID scores.

    Args:
        iid_scores: sequence of raw-energy values for the in-distribution batch.
        ood_scores: sequence of raw-energy values for the out-of-distribution batch.
        save_path: output image path. Its parent directory is created if it is missing.
        model_num: checkpoint identifier shown in the plot title.
        threshold_percentile: IID percentile used for the threshold line (default 80).
        bins: number of histogram bins for each distribution (default 50).
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.hist(iid_scores, bins=bins, density=True, label='iid', alpha=0.5)
        ax.hist(ood_scores, bins=bins, density=True, label='ood-all', alpha=0.5)

        # A percentile of zero values is undefined, so only draw the line when there is data.
        if len(iid_scores) > 0:
            iid_threshold = np.percentile(iid_scores, threshold_percentile)
            ax.axvline(x=iid_threshold, color='black', linestyle='dashed', label='iid threshold')

        ax.set_xlabel('Raw Energy')
        ax.set_ylabel('Density')
        ax.set_title(f'Histogram for Checkpoint {model_num}')
        ax.legend()

        save_dir = os.path.dirname(save_path)
        if save_dir:  # os.makedirs('') raises when save_path is a bare filename
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    print(f"Saved plot to {save_path}")


def _output_filenames(model_num, iid_name, ood_name, specific):
    """Return the (plot, iid scores, ood scores) filenames for a checkpoint and batch pair.

    With ``specific``, the batch basenames are embedded in the names. Otherwise names are
    ``{model_num}`` plus a ``_random`` tag:
    - each scores file is tagged when its own batch name starts with "random";
    - the plot is tagged when either batch name starts with "random".
    """
    if specific:
        return (
            f"{model_num}_{iid_name}_{ood_name}.png",
            f"{model_num}_{iid_name}.pt",
            f"{model_num}_{ood_name}.pt",
        )

    def tagged(is_random):
        return f"{model_num}_random" if is_random else f"{model_num}"

    iid_random = iid_name.startswith("random")
    ood_random = ood_name.startswith("random")
    return (
        f"{tagged(iid_random or ood_random)}.png",
        f"{tagged(iid_random)}.pt",
        f"{tagged(ood_random)}.pt",
    )


def main():
    """Score IID and OOD batches with a checkpoint's actor, then plot and save the raw energies.

    Outputs are written next to the checkpoint, into these subdirectories:
    - ``energy_distribution/`` (plot),
    - ``iid_raw_energy/`` (IID scores),
    - ``ood_raw_energy/`` (OOD scores).

    File names depend on ``--specific``; see ``_output_filenames``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--iid_batch", type=str, required=True, help="Path to IID batch .pt file")
    parser.add_argument("--ood_batch", type=str, required=True, help="Path to OOD batch .pt file")
    parser.add_argument("--model", type=str, required=True, help="Path to model checkpoint (model-xxx.pt)")
    parser.add_argument("--config", type=str, default="multi_grid", help="Name of config to initialize model")
    parser.add_argument("--specific", action="store_true", help="Use specific naming with batch names")
    args = parser.parse_args()

    # Validate the checkpoint name before any expensive loading or inference, so a badly
    # named file fails immediately with a usage error instead of after scoring both batches.
    try:
        model_num = extract_model_number(args.model)
    except ValueError as err:
        parser.error(str(err))

    config = ConfigLoader.load_config(args.config, None)
    device = get_device(config)
    actor = load_model(args.model, config, device)

    # NOTE(review): torch.load unpickles its input -- only pass trusted batch files.
    iid_batch = torch.load(args.iid_batch, map_location=device)
    ood_batch = torch.load(args.ood_batch, map_location=device)

    print("Computing raw_energy scores for IID batch...")
    iid_raw_energy = get_raw_energy_scores(iid_batch, actor, device)

    print("Computing raw_energy scores for OOD batch...")
    ood_raw_energy = get_raw_energy_scores(ood_batch, actor, device)

    model_dir = os.path.dirname(args.model)
    iid_name = extract_filename_base(args.iid_batch)
    ood_name = extract_filename_base(args.ood_batch)

    plot_dir = os.path.join(model_dir, "energy_distribution")
    iid_save_dir = os.path.join(model_dir, "iid_raw_energy")
    ood_save_dir = os.path.join(model_dir, "ood_raw_energy")

    plot_filename, iid_filename, ood_filename = _output_filenames(
        model_num, iid_name, ood_name, args.specific
    )

    plot_path = os.path.join(plot_dir, plot_filename)
    iid_save_path = os.path.join(iid_save_dir, iid_filename)
    ood_save_path = os.path.join(ood_save_dir, ood_filename)

    # Save plot
    plot_energy_distributions(iid_raw_energy, ood_raw_energy, plot_path, model_num)

    # Save raw energy scores
    os.makedirs(iid_save_dir, exist_ok=True)
    os.makedirs(ood_save_dir, exist_ok=True)
    torch.save(torch.tensor(iid_raw_energy), iid_save_path)
    print(f"Saved IID energy scores to {iid_save_path}")
    torch.save(torch.tensor(ood_raw_energy), ood_save_path)
    print(f"Saved OOD energy scores to {ood_save_path}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
