# Databricks notebook source
# MAGIC %md
# MAGIC # ViT-B/16 Evaluation with Random Image Region Deletion
# MAGIC
# MAGIC

# COMMAND ----------

import os
import json
import numpy as np
from PIL import Image
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
import shap

# Custom dataset with image names
class CustomDataset(torch.utils.data.Dataset):
    """Wrap an ImageFolder-style dataset so every item also carries its file name.

    The wrapped dataset must expose an ``imgs`` list of ``(path, label)`` pairs and
    return ``(image, label)`` from indexing. Items of this wrapper are
    ``(image, label, basename_of_path)``.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        # Full source paths, aligned index-for-index with the wrapped dataset.
        self.img_paths = [sample[0] for sample in dataset.imgs]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, target = self.dataset[index]
        file_name = os.path.basename(self.img_paths[index])
        return image, target, file_name

# Device configuration: use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# ----------- Utility Functions -----------

def nchw_to_nhwc(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` in channels-last layout (HWC or NHWC).

    A tensor whose last dimension is already 3 is returned unchanged. Otherwise a
    4-D tensor is treated as NCHW and anything else as CHW.
    """
    if x.shape[-1] == 3:
        return x
    if x.dim() == 4:
        return x.permute(0, 2, 3, 1)
    return x.permute(1, 2, 0)

def nhwc_to_nchw(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` in channels-first layout (CHW or NCHW).

    A tensor whose channel axis already has size 3 is returned unchanged. The
    channel axis is 1 for 4-D (NCHW) tensors and 0 for 3-D (CHW) tensors.
    Otherwise a 4-D tensor is treated as NHWC and a 3-D tensor as HWC.
    """
    is_batched = x.dim() == 4
    # Check the channel axis that matches the rank; for a 3-D CHW tensor
    # shape[1] is the height, not the channel count.
    channel_axis = 1 if is_batched else 0
    if x.shape[channel_axis] == 3:
        return x
    return x.permute(0, 3, 1, 2) if is_batched else x.permute(2, 0, 1)

def predict(img: np.ndarray) -> torch.Tensor:
    """Run the module-level ``model`` on an image or a batch given in channels-last layout.

    Args:
        img: a single HWC image or an NHWC batch (array-like). It is converted to a
            float32 tensor, reordered to NCHW and moved to the module-level ``device``.

    Returns:
        The raw model output, with a leading batch dimension.
    """
    # Cast explicitly: NumPy arrays are often float64, while the model weights are
    # presumably float32 (torchvision default).
    img_tensor = nhwc_to_nchw(torch.as_tensor(img, dtype=torch.float32))
    # Only a single image needs a batch axis; an NHWC batch already has one.
    if img_tensor.dim() == 3:
        img_tensor = img_tensor.unsqueeze(0)
    return model(img_tensor.to(device))

# ----------- Image Transformations -----------
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std  = [0.229, 0.224, 0.225]

# Input images: resize, center-crop to 224x224, scale to [0, 1] and normalize with ImageNet stats.
img_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
])

# Segmentation masks hold integer class ids, so they must be resized with
# nearest-neighbour interpolation. The default (bilinear) interpolation blends
# neighbouring ids at object boundaries and produces labels that do not exist.
# The geometry (256 -> center-crop 224) matches img_transform so masks stay aligned.
# NOTE(review): assumes the class id is stored in the first channel of the mask
# as loaded by ImageFolder; confirm for palette-encoded PNGs.
mask_transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x[0].unsqueeze(0))  # keep only first channel → shape [1, H, W]
])



# ----------- Paths and Class Mapping -----------
val_root = "/Workspace/Users/ANONYM/ShapCon/data/validation"
segmented_val_root = "/Workspace/Users/ANONYM/ShapCon/data/validation-segmentation"

# Synset names of the ImageNet-S50 subset (the class subfolders of val_root).
# Only directories are kept, in sorted order, mirroring how ImageFolder assigns
# class indices. Position i here is therefore ImageFolder's class index i; a stray
# file in val_root would otherwise shift every index.
imagenets50_class_synset = sorted(entry.name for entry in os.scandir(val_root) if entry.is_dir())

# Load the ImageNet-1k class index (JSON mapping "0".."999" -> [synset, class name]).
imagenet_class_index_url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
with open(shap.datasets.cache(imagenet_class_index_url), encoding="utf-8") as f:
    imagenet_class_index = json.load(f)

# Order entries by their integer key explicitly instead of relying on the JSON's key order.
_ordered_index_entries = [imagenet_class_index[str(k)] for k in range(len(imagenet_class_index))]
class_names = [entry[1] for entry in _ordered_index_entries]
synset_names = [entry[0] for entry in _ordered_index_entries]

# Mapping S50 indices → ImageNet-1k indices
imagenets50_to_imagenet1k = {
    i: synset_names.index(synset) for i, synset in enumerate(imagenets50_class_synset)
}


def target_transform(label):
    """Map an ImageFolder (S50) class index to the corresponding ImageNet-1k index."""
    return imagenets50_to_imagenet1k[label]


# Mapping ImageNet-1k indices → S50 indices (inverse of the mapping above)
imagenet1k_to_imagenets50 = {v: k for k, v in imagenets50_to_imagenet1k.items()}


# ----------- Dataset and Dataloaders -----------
val_dataset = CustomDataset(ImageFolder(val_root, transform=img_transform, target_transform=target_transform))
segm_dataset = ImageFolder(segmented_val_root, transform=mask_transform, target_transform=target_transform)

# Images and masks are later paired purely by position, and both datasets reuse
# the S50 mapping derived from val_root, so their class folders must be identical.
if segm_dataset.classes != val_dataset.dataset.classes:
    raise ValueError(
        "Class folders of the validation and segmentation roots differ; "
        "images and masks would be mis-paired."
    )

# shuffle=False keeps both loaders in the same (sorted) order so batches line up.
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
segm_loader = DataLoader(segm_dataset, batch_size=1, shuffle=False)



# COMMAND ----------

# The loader (default sampler, no drop_last) yields every sample of its dataset
# exactly once. The image count is therefore just the dataset length, so there is
# no need to load and transform every image to count them.
count = len(val_loader.dataset)

print(f"Number of images: {count}")


# COMMAND ----------

# MAGIC %md
# MAGIC #### Load Model

# COMMAND ----------

import torchvision.models as models

# Pretrained ViT-B/16 with torchvision's ImageNet-1k weights, placed on `device` in inference mode.
model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1).to(device).eval()
print("ViT-B/16 (ImageNet) loaded!")


# COMMAND ----------

# MAGIC %md
# MAGIC #### Run the Random-Masking Evaluation

# COMMAND ----------

from tqdm import tqdm
import math
import random
import pandas as pd

# for debugging
import matplotlib.pyplot as plt

##### --- RANDOM --- #####
class RandomMaskingBlindnessEvaluator:
    """Measure how a classifier's confidence degrades as random grid patches are masked out.

    Each image is divided into a ``grid_size x grid_size`` grid. Patches are zeroed
    cumulatively in a random order, and the model is re-run after every step.
    Per-step metrics are appended to a CSV file: ground-truth and predicted
    confidence, blindness, relative blindness, information loss, pixel-space L2
    distance and a Lipschitz-style ratio.
    """

    # ImageNet normalization statistics, used to map model inputs back to pixel space.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, model, class_names, csv_save_path, device="cuda"):
        """
        Args:
            model: classifier returning logits of shape (batch, num_classes).
            class_names: names indexed by the model's class index.
            csv_save_path: CSV file that per-step results are appended to.
            device: device the model and tensors are moved to.
        """
        self.model = model.eval().to(device)
        self.class_names = class_names
        self.csv_save_path = csv_save_path
        self.device = device

    def evaluate(self, val_loader, segm_loader, map_dict=None, grid_size=33):
        """Run the random-masking sweep over every image and append the results to the CSV.

        Args:
            val_loader: yields ``(image, label, img_name)`` batches of size 1.
            segm_loader: yields ``(mask, _)`` batches of size 1, in the same order as
                ``val_loader`` (both must be unshuffled).
            map_dict: optional mapping from the model's label index to the mask's class
                index (see ``_calc_blindness``).
            grid_size: number of patches per image side.

        Raises:
            ValueError: if ``segm_loader`` has fewer batches than ``val_loader``.
        """
        if len(segm_loader) < len(val_loader):
            raise ValueError(
                f"segm_loader has {len(segm_loader)} batches but val_loader has {len(val_loader)}"
            )

        total_correct = 0
        total_samples = 0

        # Iterate both loaders in lockstep instead of re-scanning segm_loader per image.
        paired = zip(val_loader, segm_loader)
        for (image_tensor, label, img_name), (segmented_tensor, _) in tqdm(paired, total=len(val_loader)):
            image_tensor = image_tensor.to(self.device)
            label = label.to(self.device)
            segmented_tensor = segmented_tensor.to(self.device)
            _, _, H, W = image_tensor.shape
            label_idx = label.item()

            # Random deletion order over all grid patches.
            patch_boxes = self._patch_boxes(H, W, grid_size)
            random.shuffle(patch_boxes)

            # Baseline confidence for the ground-truth class on the unmasked image.
            with torch.no_grad():
                baseline_pred = torch.softmax(self.model(image_tensor), dim=1)[0, label_idx].item()

            # The unmasked image in pixel space is loop-invariant, so compute it once.
            x0 = self.unnormalize(image_tensor)

            results = []
            current_mask = torch.ones((1, 1, H, W), device=self.device)

            for y_start, y_end, x_start, x_end in patch_boxes:
                # Masking is cumulative: previously zeroed patches stay zeroed.
                current_mask[:, :, y_start:y_end, x_start:x_end] = 0.0
                masked = image_tensor * current_mask

                with torch.no_grad():
                    probs = torch.softmax(self.model(masked), dim=1)

                pred_idx = torch.argmax(probs, dim=1).item()
                gt_conf = probs[0, label_idx].item()
                pred_conf = probs[0, pred_idx].item()
                blindness, rel_blindness = self._calc_blindness(current_mask, segmented_tensor, label, map_dict)

                info_loss = baseline_pred - gt_conf
                abs_info_loss = abs(info_loss)

                l2_distance = torch.norm(x0 - self.unnormalize(masked)).item()
                if l2_distance == 0.0:
                    lip_const = 0.0
                    abs_lip_const = 0.0
                else:
                    lip_const = info_loss / l2_distance
                    abs_lip_const = abs_info_loss / l2_distance

                correct = int(pred_idx == label_idx)
                total_correct += correct
                total_samples += 1

                results.append({
                    "img_name": img_name[0],
                    "gt_class": self.class_names[label_idx],
                    "gt_confidence": gt_conf,
                    "pred_class": self.class_names[pred_idx],
                    "pred_confidence": pred_conf,
                    "blindness": blindness,
                    "rel_blindness": rel_blindness,
                    "information_loss": info_loss,
                    "abs_info_loss": abs_info_loss,
                    "l2_distance": l2_distance,
                    "lipstitz": lip_const,
                    "abs_lipstitz": abs_lip_const,
                    "correct": correct
                })

            # Append this image's rows; write the header only when the file is new.
            df = pd.DataFrame(results)
            df.to_csv(self.csv_save_path, mode="a", index=False, header=not os.path.exists(self.csv_save_path))

        # Accuracy over every masking step of every image.
        accuracy = total_correct / total_samples if total_samples > 0 else 0.0
        print(f"\nClassification Accuracy during masking: {accuracy:.4f}")

    @staticmethod
    def _patch_boxes(H, W, grid_size):
        """Return ``(y_start, y_end, x_start, x_end)`` for every grid cell, in row-major order.

        Cell boundaries are spread with integer division, so the cells tile the full
        ``H x W`` image even when ``H``/``W`` are not multiples of ``grid_size``.
        A fixed ``H // grid_size`` patch size would leave a bottom/right strip that is
        never masked.
        """
        ys = [k * H // grid_size for k in range(grid_size + 1)]
        xs = [k * W // grid_size for k in range(grid_size + 1)]
        return [
            (ys[i], ys[i + 1], xs[j], xs[j + 1])
            for i in range(grid_size)
            for j in range(grid_size)
        ]

    def unnormalize(self, img):
        """Undo ImageNet mean/std normalization of an NCHW tensor (back to ~[0, 1] pixels)."""
        mean = torch.tensor(self._MEAN).view(1, 3, 1, 1).to(img.device)
        std = torch.tensor(self._STD).view(1, 3, 1, 1).to(img.device)
        return img * std + mean

    def _calc_blindness(self, mask, segm_tensor, label, map_dict):
        """Compute overall and object-relative blindness of a binary mask.

        Args:
            mask: (1, 1, H, W) tensor, 1 = visible, 0 = deleted.
            segm_tensor: (1, C, H, W) mask tensor scaled to [0, 1] by ToTensor; summed
                over C and rescaled by 255 to recover integer labels.
            label: model label (1-element tensor or int).
            map_dict: optional mapping applied to ``label`` before lookup.

        Returns:
            ``(blindness, rel_blindness)``: the fraction of the image that is deleted, and
            the fraction of the object's pixels that is deleted (0 if the object is absent).
        """
        label_idx = label.item() if torch.is_tensor(label) else label
        if map_dict is not None:
            label_idx = map_dict[label_idx]
        # ATTENTION: mask pixels store class index + 1 because background is 0.
        target_value = int(label_idx) + 1

        mask = mask.squeeze(0).cpu().numpy()[0].astype(int)        # [H, W]
        segm = segm_tensor.squeeze(0).cpu().numpy().sum(axis=0)    # [H, W]
        # Round instead of truncating: k/255*255 can come out just below k in float32.
        segm = np.rint(segm * 255).astype(int)

        object_pixels = segm == target_value
        object_area = object_pixels.sum()
        if object_area > 0:
            rel_blindness = (object_pixels & (mask == 0)).sum() / object_area
        else:
            rel_blindness = 0

        blindness = (1 - mask).sum() / mask.size
        return blindness, rel_blindness

    def _get_batch(self, loader, target_idx):
        """Yield the ``target_idx``-th batch of ``loader`` (re-iterates from the start).

        Kept for backward compatibility; ``evaluate`` no longer uses it because
        re-scanning the loader for every image is quadratic.
        """
        for idx, batch in enumerate(loader):
            if idx == target_idx:
                yield batch
                return
            

# NOTE: results are appended to the CSV. Delete any previous file before re-running,
# otherwise rows from earlier runs are mixed into the analysis.
evaluator = RandomMaskingBlindnessEvaluator(
    model=model,
    class_names=class_names,
    csv_save_path="/Workspace/Users/ANONYM/ShapCon/model_comparison/results/vit.csv",
    # Pass the detected device explicitly; the evaluator's default is "cuda", which
    # fails on CPU-only hosts.
    device=device,
)

# map_dict converts ImageNet-1k labels back to S50 indices, which the masks encode as index + 1.
evaluator.evaluate(val_loader, segm_loader, map_dict=imagenet1k_to_imagenets50, grid_size=13)


# COMMAND ----------

# MAGIC %md
# MAGIC ### Overall Results

# COMMAND ----------

from utils.helper import *

# Global degradation analysis of the ViT random-masking results against relative blindness.
# To analyse absolute blindness instead, change blindness_col (the CSV also has a
# "blindness" column).
csv_path = "/Workspace/Users/ANONYM/ShapCon/model_comparison/results/vit.csv"
print("RELATIVE BLINDNESS")
analyze_global_degradation_fit(
    csv_path,
    blindness_col="rel_blindness",
    corr_analysis=True,
    quantile=0.99999,
)


# COMMAND ----------

from utils.helper import *
# NOTE(review): this cell analyses the ResNet-34 results of an earlier experiment,
# not the ViT CSV produced above. It relies on the helper's default blindness
# column, which is not visible from this notebook.
csv_path = "/Workspace/Users/ANONYM/ShapCon/comparison/results/01_random_resnet34.csv"
analyze_global_degradation_fit(csv_path)


# COMMAND ----------

# Reuses csv_path from the previous cell (ResNet-34 results), analysed against relative blindness.
analyze_global_degradation_fit(csv_path, blindness_col="rel_blindness")

# COMMAND ----------

# MAGIC %md
# MAGIC ### Class-wise Results

# COMMAND ----------

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error

# Load paths
csv_save_path = "/Workspace/Users/ANONYM/ShapCon/comparison/results/01_random_resnet34.csv"
output_path = "/Workspace/Users/ANONYM/ShapCon/comparison/figs/01_random_resnet34_first12.pdf"

# NumPy 2.0 renamed np.trapz to np.trapezoid (the old name is deprecated).
_trapezoid = getattr(np, "trapezoid", None) or np.trapz

# Load data
data = pd.read_csv(csv_save_path)
gt_classes = data.gt_class.unique()

# Plot settings (only the first n_classes classes, in CSV order)
n_classes = 12
n_cols = 6
n_rows = 2
# The x-axis shows relative blindness, so its limits must come from that column.
xlim = (data.rel_blindness.min(), data.rel_blindness.max())
ylim = (data.gt_confidence.min(), data.gt_confidence.max())

plt.figure(figsize=(n_cols * 4, n_rows * 3))
metric_results = []

# Negative sigmoid: for b > 0 it starts near a for x << c and falls towards 0 around x = c.
def neg_sigmoid(x, a, b, c):
    return a * (1 - 1 / (1 + np.exp(-b * (x - c))))

# Per-class plots and metrics (only the first n_classes)
for i, gt_class in enumerate(gt_classes[:n_classes]):
    plt.subplot(n_rows, n_cols, i + 1)
    class_data = data[data.gt_class == gt_class].sort_values("rel_blindness")
    x = class_data.rel_blindness.values
    y = class_data.gt_confidence.values

    plt.scatter(x, y, alpha=0.6, color="#00876C")

    # Fallback when there is too little data or the fit fails: use the raw points.
    x_fit, y_fit = x, y
    rbs = np.nan
    r2, rmse, mae = np.nan, np.nan, np.nan

    if len(x) >= 4:
        # Upper envelope: a near-max quantile of the confidence inside each blindness bin.
        num_bins = 100
        bins = np.linspace(x.min(), x.max(), num_bins + 1)
        bin_centers = 0.5 * (bins[:-1] + bins[1:])
        env_vals = np.full(num_bins, np.nan)

        for j in range(num_bins):
            # Bins are half-open, except the last one, which is closed so x.max() is kept.
            if j == num_bins - 1:
                upper = x <= bins[j + 1]
            else:
                upper = x < bins[j + 1]
            bin_mask = (x >= bins[j]) & upper
            if np.any(bin_mask):
                env_vals[j] = np.quantile(y[bin_mask], 0.99999)

        valid = ~np.isnan(env_vals)
        bin_centers = bin_centers[valid]
        env_vals = env_vals[valid]

        if len(bin_centers) >= 4:
            # Initial guess: a = envelope max, b = steepness, c = drop location ~ median.
            p0 = [np.max(env_vals), 10.0, np.median(bin_centers)]
            try:
                popt, _ = curve_fit(neg_sigmoid, bin_centers, env_vals, p0=p0, maxfev=10000)
            except RuntimeError:
                popt = None

            if popt is not None:
                x_fit = np.linspace(bin_centers.min(), bin_centers.max(), 100)
                y_fit = neg_sigmoid(x_fit, *popt)
                plt.plot(x_fit, y_fit, color="#D5001C", linewidth=2)

                # RBS = blindness at which the fitted curve is steepest
                dy = np.gradient(y_fit, x_fit)
                rbs = x_fit[np.argmax(np.abs(dy))]

                # Goodness of fit against the envelope points
                y_pred = neg_sigmoid(bin_centers, *popt)
                r2 = r2_score(env_vals, y_pred)
                rmse = np.sqrt(mean_squared_error(env_vals, y_pred))
                mae = mean_absolute_error(env_vals, y_pred)

    # Area under the (fitted or raw) curve, and the same normalized by its starting value
    aubc = _trapezoid(y_fit, x_fit)
    aubc_norm = aubc / y_fit[0] if y_fit[0] > 0 else np.nan

    # ACP_tau: smallest blindness at which the curve drops below tau
    thresholds = [0.3, 0.5, 0.75, 0.8, 0.9]
    acps = {}
    for tau in thresholds:
        below = y_fit < tau
        acps[f"ACP_{tau}"] = np.min(x_fit[below]) if np.any(below) else np.nan

    metric_results.append({
        "class": gt_class,
        "AUBC": aubc,
        "AUBC_norm": aubc_norm,
        "RBS": rbs,
        "R2": r2,
        "RMSE": rmse,
        "MAE": mae,
        **acps
    })

    plt.xlim(xlim)
    plt.ylim(ylim)
    plt.title(f"Class: {gt_class}", fontsize=18)
    plt.xlabel(r"Relative Blindness $B'$", fontsize=14)
    if i % n_cols == 0:
        plt.ylabel(r"Confidence $C$", fontsize=14)

plt.tight_layout()
plt.savefig(output_path, format="pdf", bbox_inches="tight")
plt.show()

# Output metrics as DataFrame
metrics_df = pd.DataFrame(metric_results)

# Print selected metrics for the plotted classes
selected_columns = ["class", "AUBC", "RBS", "ACP_0.5", "ACP_0.8"]
subset_df = metrics_df[selected_columns]

print("\n=== Selected Metrics for First 12 Classes ===\n")
print(subset_df.to_string(index=False, float_format="%.4f"))


# COMMAND ----------

import math

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error

# Load paths
csv_save_path = "/Workspace/Users/ANONYM/ShapCon/comparison/results/01_random_resnet34.csv"
output_path = "/Workspace/Users/ANONYM/ShapCon/comparison/figs/01_random_resnet34.pdf"

# NumPy 2.0 renamed np.trapz to np.trapezoid (the old name is deprecated).
_trapezoid = getattr(np, "trapezoid", None) or np.trapz

# Load data
data = pd.read_csv(csv_save_path)
gt_classes = data.gt_class.unique()

# Plot settings: one subplot per class. The grid grows with the number of classes
# (10 x 5 for 50 classes) instead of being fixed at 50 slots.
n_classes = len(gt_classes)
n_cols = 5
n_rows = max(1, math.ceil(n_classes / n_cols))
xlim = (data.blindness.min(), data.blindness.max())
ylim = (data.gt_confidence.min(), data.gt_confidence.max())

plt.figure(figsize=(n_cols * 4, n_rows * 3))
metric_results = []

# Negative sigmoid: for b > 0 it starts near a for x << c and falls towards 0 around x = c.
def neg_sigmoid(x, a, b, c):
    return a * (1 - 1 / (1 + np.exp(-b * (x - c))))

# Per-class plots and metrics
for i, gt_class in enumerate(gt_classes):
    plt.subplot(n_rows, n_cols, i + 1)
    class_data = data[data.gt_class == gt_class].sort_values("blindness")
    x = class_data.blindness.values
    y = class_data.gt_confidence.values

    # Scatter
    plt.scatter(x, y, alpha=0.6, color="#00876C")

    # Fallback when there is too little data or the fit fails: use the raw points.
    x_fit, y_fit = x, y
    rbs = np.nan
    r2, rmse, mae = np.nan, np.nan, np.nan

    # Upper envelope via binning: the maximum confidence inside each blindness bin.
    if len(x) >= 4:
        num_bins = 20
        bins = np.linspace(x.min(), x.max(), num_bins + 1)
        bin_centers = 0.5 * (bins[:-1] + bins[1:])
        env_vals = np.full(num_bins, np.nan)

        for j in range(num_bins):
            # Bins are half-open, except the last one, which is closed so x.max() is kept.
            if j == num_bins - 1:
                upper = x <= bins[j + 1]
            else:
                upper = x < bins[j + 1]
            bin_mask = (x >= bins[j]) & upper
            if np.any(bin_mask):
                env_vals[j] = np.max(y[bin_mask])

        valid = ~np.isnan(env_vals)
        bin_centers = bin_centers[valid]
        env_vals = env_vals[valid]

        # Fit sigmoid if enough points
        if len(bin_centers) >= 4:
            # Initial guess: a = max, b = steepness, c = drop location ~ median
            p0 = [np.max(env_vals), 10.0, np.median(bin_centers)]
            try:
                popt, _ = curve_fit(neg_sigmoid, bin_centers, env_vals, p0=p0, maxfev=10000)
            except RuntimeError:
                popt = None

            if popt is not None:
                x_fit = np.linspace(bin_centers.min(), bin_centers.max(), 100)
                y_fit = neg_sigmoid(x_fit, *popt)
                plt.plot(x_fit, y_fit, color="#D5001C", linewidth=2)

                # RBS = where slope is steepest
                dy = np.gradient(y_fit, x_fit)
                rbs = x_fit[np.argmax(np.abs(dy))]

                # Fit metrics
                y_pred = neg_sigmoid(bin_centers, *popt)
                r2 = r2_score(env_vals, y_pred)
                rmse = np.sqrt(mean_squared_error(env_vals, y_pred))
                mae = mean_absolute_error(env_vals, y_pred)

    # AUBC and normalized AUBC
    aubc = _trapezoid(y_fit, x_fit)
    aubc_norm = aubc / y_fit[0] if y_fit[0] > 0 else np.nan

    # ACP_tau: smallest blindness at which the curve drops below tau
    thresholds = [0.3, 0.5, 0.75, 0.8, 0.9]
    acps = {}
    for tau in thresholds:
        below = y_fit < tau
        acps[f"ACP_{tau}"] = np.min(x_fit[below]) if np.any(below) else np.nan

    # Store results
    metric_results.append({
        "class": gt_class,
        "AUBC": aubc,
        "AUBC_norm": aubc_norm,
        "RBS": rbs,
        "R2": r2,
        "RMSE": rmse,
        "MAE": mae,
        **acps
    })

    # Axis and labels
    plt.xlim(xlim)
    plt.ylim(ylim)
    plt.title(f"Class: {gt_class}", fontsize=10)
    plt.xlabel(r"Blindness $B$")
    if i % n_cols == 0:
        plt.ylabel(r"Confidence $C$")

# Save and show plot
plt.tight_layout()
plt.savefig(output_path, format="pdf", bbox_inches="tight")
plt.show()

# Output metrics
metrics_df = pd.DataFrame(metric_results)

print("\n=== Per-Class Robustness & Fit Metrics ===\n")
print(metrics_df.to_string(index=False, float_format="%.4f"))

# Mean metrics
mean_metrics = metrics_df.drop(columns=["class"]).mean(numeric_only=True)
print("\n=== Mean Metrics Across All Classes ===\n")
for metric, value in mean_metrics.items():
    print(f"{metric}: {value:.4f}")



# COMMAND ----------

