import torch
import argparse
import random
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch.nn as nn
from wilds import get_dataset
from wilds.common.data_loaders import get_eval_loader, get_train_loader
import torchvision.transforms as transforms
from torchvision import models
import matplotlib.pyplot as plt

# specify device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def evaluate_model_accuracy(model, dat_loader):
    """Compute and print the top-1 accuracy of ``model`` over ``dat_loader``.

    The model is switched to evaluation mode and run without gradient
    tracking. Inputs and labels are moved to the module-level ``device``.

    Args:
        model: Callable network that maps a batch ``x`` to per-class scores
            with the class axis at ``dim=1``.
        dat_loader: Iterable yielding ``(x, y_true, metadata)`` triples; the
            third element is ignored.
            NOTE(review): assumes ``y_true`` holds one class index per sample
            with shape ``(batch,)`` -- confirm against the dataset.

    Returns:
        float: Fraction of samples whose predicted label equals ``y_true``.

    Raises:
        ValueError: If ``dat_loader`` yields no samples.
    """
    model.eval()  # Set the model to evaluation mode
    num_correct = 0
    total_samples = 0

    with torch.no_grad():  # Disable gradient computation for evaluation
        # Iterate the loader that was passed in, not a global one.
        for x, y_true, _ in tqdm(dat_loader, desc="Evaluating", unit="batch"):
            x, y_true = x.to(device), y_true.to(device)

            # Forward pass to get predictions
            y_pred = model(x)

            # Softmax is monotonic, so the argmax of the raw scores picks the
            # same class; no separate softmax module is needed here.
            label_pred = torch.argmax(y_pred, dim=1)

            # Count correct predictions
            num_correct += (label_pred == y_true).sum().item()
            total_samples += y_true.size(0)

    if total_samples == 0:
        raise ValueError("dat_loader yielded no samples; accuracy is undefined")

    # Compute overall accuracy
    accuracy = num_correct / total_samples
    print(f'Overall Accuracy: {accuracy:.4f}')
    return accuracy


"""
Set-up Stage
"""

# Load model
print('Loading model ...')
# weights=None is the documented way to request a randomly initialised
# network; the trained parameters are loaded from the checkpoint below.
model = models.resnet50(weights=None, progress=False)
# Replace the classification head with a 182-way linear layer.
model.fc = nn.Linear(2048, 182)

# Load parameters
print('Loading parameters ...')
# Alternative checkpoint, kept for switching between runs:
# checkpoint = torch.load('../pretrained_models/iwildcam_seed_0_epoch_best_model_erm.pth', map_location=device)
checkpoint = torch.load('../pretrained_models/iwildcam_seed_0_epoch_best_model_aug.pth', map_location=device)
state_dict = checkpoint['algorithm']
# Drop the first 6 characters of every key so the names match the bare ResNet.
# NOTE(review): presumably this strips a "model." prefix -- confirm against
# the checkpoint's key names.
new_state_dict = {k[6:]: v for k, v in state_dict.items()}
model.load_state_dict(new_state_dict)
model.to(device)

# Module-level softmax over the class axis (dim=1).
softmax = nn.Softmax(dim=1)

# Load the iWildCam dataset without downloading it (download=False).
dataset = get_dataset(dataset="iwildcam", download=False)

# All three splits share the same preprocessing: resize to 448x448, then
# convert the image to a tensor.
resize_to_tensor = transforms.Compose(
    [transforms.Resize((448, 448)), transforms.ToTensor()]
)

# Training split
train_data = dataset.get_subset("train", transform=resize_to_tensor)
train_loader = get_train_loader("standard", train_data, batch_size=32)

# "id_test" split
id_test_data = dataset.get_subset("id_test", transform=resize_to_tensor)
idtest_loader = get_eval_loader("standard", id_test_data, batch_size=32)

# "test" split
test_data = dataset.get_subset("test", transform=resize_to_tensor)
test_loader = get_eval_loader("standard", test_data, batch_size=32)

# Report accuracy on the train, id_test and test loaders, in that order.
for split_loader in (train_loader, idtest_loader, test_loader):
    evaluate_model_accuracy(model, split_loader)