import torch
import argparse
import random
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch.nn as nn
from wilds import get_dataset
from wilds.common.data_loaders import get_eval_loader
import torchvision.transforms as transforms
from torchvision import models
from src.utils import *
from src.lp_robust_cp import LPRobustCP

# specify device: prefer CUDA when available, otherwise run on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# argument parser
# NOTE(review): --alpha and --save are parsed but appear unused in this file;
# scores are written to ./wilds_scores.npz regardless of --save.
parser = argparse.ArgumentParser('Robust-CP')
parser.add_argument('--alpha', type=float, default=0.1,
                    help="user prescribed miscoverage level (target coverage is 1-alpha)")
parser.add_argument('--batch_size', type=int, default=512,
                    help="batch size for loading data")
parser.add_argument('--save', type=str, default='experiments/wilds',
                    help="define the save directory")
args = parser.parse_args()


def get_scores(loader):
    """Return the score each example receives at its ground-truth label.

    Runs the module-level ``model`` over every batch of ``loader`` under
    ``torch.no_grad()``. Each batch is scored with ``nll_score``, presumably
    provided by the ``from src.utils import *`` import. For each example, the
    entry at its true label is then selected.

    Args:
        loader: iterable of ``(features, labels, metadata)`` batches that also
            supports ``len()`` (used for the progress bar). Labels are used
            as integer class indices into the score tensor.

    Returns:
        1-D numpy array with one score per example, in loader order. It is
        empty if the loader yields no batches.
    """
    score_list = []
    label_list = []
    with torch.no_grad():
        for features, labels, _ in tqdm(loader, total=len(loader)):
            # Tensor.to() returns a new tensor rather than moving in place,
            # so the result must be rebound for the batch to reach `device`.
            features = features.to(device)
            batch_scores = nll_score(model, features, labels.to(device))
            # Accumulate on the CPU so device memory does not grow with the
            # size of the split.
            score_list.append(batch_scores.cpu())
            label_list.append(labels.cpu())
            torch.cuda.empty_cache()

    # torch.cat raises on an empty list; an empty split yields no scores.
    if not score_list:
        return np.empty(0)

    # Advanced indexing below needs a 2-D (example x class) score tensor; all
    # tensors are on the CPU here, so indices and scores share a device.
    scores = torch.cat(score_list, dim=0)
    labels = torch.cat(label_list)
    truey_scores = scores[torch.arange(scores.size(0)), labels]
    return truey_scores.numpy()


def plot_distributions(oodtest, idtest, idval, save_path='wilds_score_plt.png'):
    """Overlay density-normalised histograms of three score arrays.

    Each input is flattened to 1-D and drawn on the same axes so the
    distributions can be compared directly. The colours are blue for OOD
    test, red for ID test and green for ID val. Only histograms are drawn;
    no KDE curve (``kde=False``).

    Args:
        oodtest: scores for the OOD test split (anything ``np.ravel`` accepts).
        idtest: scores for the ID test split.
        idval: scores for the ID validation split.
        save_path: file the figure is written to at 300 dpi. It defaults to
            the previously hard-coded ``'wilds_score_plt.png'``.
    """
    # NOTE(review): `sns` and `plt` are not in this file's explicit imports;
    # presumably they arrive via `from src.utils import *` -- confirm.
    series = (
        (oodtest, 'blue', 'OOD Test'),
        (idtest, 'red', 'ID Test'),
        (idval, 'green', 'ID Val'),
    )

    # Start a fresh figure so repeated calls don't stack onto earlier plots.
    plt.figure()
    for values, color, label in series:
        sns.histplot(np.ravel(values), color=color, alpha=0.7,
                     stat='density', kde=False, label=label)

    plt.legend()
    plt.title('Comparison of Distributions')
    plt.xlabel('Value')
    plt.ylabel('Density')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    

"""
Set-up Stage
"""

# Load model: ResNet-50 without pretrained weights, with a 182-way linear head.
# NOTE(review): 182 presumably matches the class count the checkpoint was
# trained with -- confirm if load_state_dict reports a shape mismatch.
print('Loading model ...')
model = models.resnet50(weights=None, progress=False)
model.fc = nn.Linear(2048, 182)
model.eval()

# Load parameters
print('Loading parameters ...')

# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
# host; the model is moved to `device` once the weights are in place.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load('../pretrained_models/wilds_model.pth', map_location='cpu')

# Weights are read from the checkpoint's 'algorithm' entry. The first 6
# characters of each key are dropped, presumably a 'model.' wrapper prefix,
# so the keys match the bare ResNet -- confirm if loading fails.
state_dict = checkpoint['algorithm']
new_state_dict = {k[6:]: v for k, v in state_dict.items()}
model.load_state_dict(new_state_dict)
model.to(device)

# load dataset
dataset = get_dataset(dataset="iwildcam", download=False)

# One evaluation transform shared by every split: resize to 448x448, then
# convert to a tensor.
# NOTE(review): there is no Normalize step here. If the checkpoint was trained
# on normalized inputs, scores are computed on shifted inputs -- confirm
# against the training-time transform.
eval_transform = transforms.Compose([
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
])

ood_test_data = dataset.get_subset("test", transform=eval_transform)    # TEST OOD
id_test_data = dataset.get_subset("id_test", transform=eval_transform)  # TEST ID
id_val_data = dataset.get_subset("id_val", transform=eval_transform)    # VAL ID

# One "standard" evaluation loader per split, in the same order as above.
oodtest_loader, idtest_loader, idval_loader = (
    get_eval_loader("standard", subset, batch_size=args.batch_size)
    for subset in (ood_test_data, id_test_data, id_val_data)
)

# Score each split in turn (OOD test, ID test, ID val) with the same routine.
oodtst_scores, idtst_scores, idval_scores = (
    get_scores(split_loader)
    for split_loader in (oodtest_loader, idtest_loader, idval_loader)
)

# Persist all three arrays in a single .npz archive, keyed by the names above.
np.savez(
    'wilds_scores.npz',
    oodtst_scores=oodtst_scores,
    idtst_scores=idtst_scores,
    idval_scores=idval_scores,
)

# Optional: uncomment to also plot and save the score histograms.
# plot_distributions(oodtst_scores, idtst_scores, idval_scores)
