import torch
from torch.utils.data import DataLoader
from sklearn.metrics import roc_curve, accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from torchvision import datasets, transforms
from colorcubenet import CustomEfficientNet  # Import the CustomEfficientNet model
from colorcube import ColorCubeTransform  # Import the ColorCubeTransform
import os
import numpy as np

# Select the compute device: use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Path to the test dataset folder (passed to ImageFolder below).
test_link = '/projects/emarasco/iiitd_patch/test'  # <-- set this to the root folder of your test dataset

# Path to the saved model weights (a state dict loaded further down).
model_path = './pth/colorcubenet_iiitd.pth'

# Preprocessing pipeline applied to every test image.
transform_val_test = transforms.Compose([
    transforms.Resize(256),          # With an int size, the shorter edge is scaled to 256 (aspect ratio kept)
    transforms.CenterCrop(224),      # Center crop to 224x224
    ColorCubeTransform(),            # Project-specific transform; presumably yields a 9-channel tensor -- TODO confirm
    transforms.Normalize(mean=[0.485] * 9, std=[0.229] * 9)  # Same mean/std repeated for each of the 9 channels
])

# Load the test dataset; ImageFolder takes the class labels from the sub-directory layout of test_link.
test_dataset = datasets.ImageFolder(test_link, transform=transform_val_test)

# Create the test DataLoader; shuffle=False keeps the sample order fixed between runs.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
print(f"Loaded test dataset from {test_link}")



# Build the network and restore the saved weights.
model = CustomEfficientNet(num_classes=2)  # Two output classes (binary classification)
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source;
# consider weights_only=True if the installed torch version supports it.
model.load_state_dict(torch.load(model_path, map_location=device))  # map_location places the tensors on `device`
model = model.to(device)
print(f"Loaded model from {model_path}")

# Function to compute EER
def compute_eer(fpr, tpr):
    """Return the equal error rate (EER) as a percentage.

    The EER is estimated at the ROC operating point where the false positive
    rate and the false negative rate (``1 - tpr``) are closest to each other;
    the two rates at that point are averaged. No interpolation between ROC
    points is performed.

    Args:
        fpr: False positive rates, one per ROC operating point (array-like).
        tpr: True positive rates aligned element-wise with ``fpr``.

    Returns:
        The EER scaled to the 0-100 range. Only the rate is returned, not the
        threshold at which it occurs.

    Raises:
        ValueError: If ``fpr`` is empty.
    """
    # Accept plain sequences as well as arrays; `1 - tpr` is undefined for lists.
    fpr = np.asarray(fpr)
    tpr = np.asarray(tpr)
    if fpr.size == 0:
        raise ValueError("compute_eer requires at least one ROC point")

    fnr = 1 - tpr
    # Index of the operating point where FPR and FNR are closest.
    min_index = np.argmin(np.abs(fpr - fnr))
    eer = np.mean((fpr[min_index], fnr[min_index]))
    return eer * 100

# Function to compute BPCER at given APCER targets
def calculate_bpcer_at_apcer(fpr, tpr, thresholds, apcer_targets):
    """Look up BPCER and the decision threshold at given APCER targets.

    ``fpr`` is treated as the APCER curve and ``1 - tpr`` as the BPCER curve.
    For each target, the ROC point whose ``fpr`` is nearest to the target is
    selected (nearest match, so the selected APCER may lie slightly above or
    below the target; no interpolation is performed).

    Args:
        fpr: False positive rates, one per ROC operating point (array-like).
        tpr: True positive rates aligned with ``fpr``.
        thresholds: Score thresholds aligned with ``fpr``.
        apcer_targets: A single numeric target or an iterable of targets,
            expressed as fractions (e.g. ``0.05`` for 5%).

    Returns:
        A dict mapping each target to a ``(bpcer, threshold)`` tuple.

    Raises:
        ValueError: If ``fpr`` is empty.
    """
    # Accept plain sequences as well as arrays; list arithmetic would fail below.
    fpr = np.asarray(fpr)
    tpr = np.asarray(tpr)
    thresholds = np.asarray(thresholds)
    if fpr.size == 0:
        raise ValueError("calculate_bpcer_at_apcer requires at least one ROC point")

    # A lone target may be an int (0 or 1) or a numpy scalar, not only a float.
    if isinstance(apcer_targets, (int, float, np.number)):
        apcer_targets = [apcer_targets]

    results = {}
    for apcer_target in apcer_targets:
        idx = np.argmin(np.abs(fpr - apcer_target))
        results[apcer_target] = (1 - tpr[idx], thresholds[idx])

    return results

# Evaluation function
def evaluate_model(model, test_loader, device):
    """Run the model over the test loader, print metrics and return them.

    The softmax probability at class index 1 is used as the score for the ROC
    curve (label 1 is the positive class); hard predictions are the argmax
    over the model outputs.

    Args:
        model: Network called as ``model(inputs)``; its output is indexed as
            (batch, class). It is switched to eval mode.
        test_loader: Iterable yielding ``(inputs, labels)`` tensor batches.
        device: Device the input batches are moved to before the forward pass.

    Returns:
        Tuple ``(eer, accuracy, precision, recall, f1, cm)`` where ``eer`` is
        a percentage and ``cm`` is the confusion matrix. The BPCER/APCER
        figures are only printed, not returned.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()

    test_scores = []
    test_true_labels = []
    test_pred_labels = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            # Only the inputs are needed on the device; labels are consumed
            # solely by the CPU-side metrics, so they are not transferred.
            outputs = model(inputs.to(device))

            # Probability of the positive class (index 1) as the ROC score.
            probs = torch.nn.functional.softmax(outputs, dim=1)[:, 1]

            test_scores.extend(probs.cpu().numpy())
            test_true_labels.extend(labels.cpu().numpy())
            test_pred_labels.extend(torch.argmax(outputs, dim=1).cpu().numpy())

    if not test_true_labels:
        # Fail here with a clear message instead of deeper in the metric code.
        raise ValueError("test_loader yielded no samples; nothing to evaluate")

    # ROC operating points, with label 1 as the positive class.
    fpr, tpr, thresholds = roc_curve(test_true_labels, test_scores, pos_label=1)

    eer = compute_eer(fpr, tpr)

    # Metrics based on the hard (argmax) predictions.
    accuracy = accuracy_score(test_true_labels, test_pred_labels)
    precision = precision_score(test_true_labels, test_pred_labels)
    recall = recall_score(test_true_labels, test_pred_labels)
    f1 = f1_score(test_true_labels, test_pred_labels)
    cm = confusion_matrix(test_true_labels, test_pred_labels)

    # BPCER at fixed APCER operating points (targets are fractions).
    apcer_targets = [0.05, 0.10, 0.01, 0.005]
    bpcer_results = calculate_bpcer_at_apcer(fpr, tpr, thresholds, apcer_targets)

    print(f'Test EER: {eer:.2f}%')
    print(f'Test Accuracy: {accuracy:.4f}')
    print(f'Test Precision: {precision:.4f}')
    print(f'Test Recall: {recall:.4f}')
    print(f'Test F1 Score: {f1:.4f}')
    print(f'Test Confusion Matrix:\n{cm}')

    for apcer_target in apcer_targets:
        bpcer, threshold = bpcer_results[apcer_target]
        print(f'BPCER at APCER = {apcer_target * 100:.1f}%: {bpcer:.4f}, threshold: {threshold:.4f}')

    return eer, accuracy, precision, recall, f1, cm

# Evaluate the model on the test set. This runs whenever the file is executed or imported;
# evaluate_model prints the metrics and the returned tuple is discarded here.
evaluate_model(model, test_loader, device)
