# Databricks notebook source
# MAGIC %md
# MAGIC #### Evaluation Perturbation Strategy - ResNet18

# COMMAND ----------

# MAGIC %md
# MAGIC Init

# COMMAND ----------

import os
import json
import numpy as np
import csv
import cv2
import torch
import torchvision
import math
from PIL import Image
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
import shap

# ----------- Custom Dataset Wrapper -----------
class CustomDataset(torch.utils.data.Dataset):
    """Wrap an ImageFolder-like dataset so every item also yields its file name.

    Items are ``(image, label, basename_of_image_path)``.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        # ``dataset.imgs`` holds (path, class_index) pairs; keep only the paths.
        self.img_paths = [path for path, _cls in dataset.imgs]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample, target = self.dataset[index]
        file_name = os.path.basename(self.img_paths[index])
        return sample, target, file_name

# ----------- Device -----------
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# ----------- Utility Functions -----------
def nchw_to_nhwc(x: torch.Tensor) -> torch.Tensor:
    """Move the channel axis last: (N, C, H, W) -> (N, H, W, C), or (C, H, W) -> (H, W, C).

    A tensor whose last axis already has size 3 is treated as channel-last and
    returned unchanged (the same object, not a copy).
    """
    if x.shape[-1] == 3:
        return x
    if x.dim() == 4:
        return x.permute(0, 2, 3, 1)
    return x.permute(1, 2, 0)

def nhwc_to_nchw(x: torch.Tensor) -> torch.Tensor:
    """Move the channel axis first: (N, H, W, C) -> (N, C, H, W), or (H, W, C) -> (C, H, W).

    A tensor that is already channel-first (size-3 axis at dim 1 for 4-D input,
    at dim 0 for 3-D input) is returned unchanged.
    """
    if x.dim() == 4:
        return x if x.shape[1] == 3 else x.permute(0, 3, 1, 2)
    # 3-D: a CHW tensor has its channel axis at dim 0, not dim 1; checking
    # dim 1 would wrongly permute an already channel-first image.
    return x if x.shape[0] == 3 else x.permute(2, 0, 1)

# ----------- Image Transformations -----------
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std  = [0.229, 0.224, 0.225]

img_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
])

# Segmentation masks hold integer class labels, so they must be resized with
# nearest-neighbour interpolation: Resize's default bilinear filter would blend
# neighbouring labels into values that belong to no class.
mask_transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x[0].unsqueeze(0))  # keep only the first channel -> (1, H, W)
])

# ----------- Paths -----------
val_root = "/Workspace/Users/ANONYM/ShapCon/data/validation"
segmented_val_root = "/Workspace/Users/ANONYM/ShapCon/data/validation-segmentation"

# Class list built exactly like ImageFolder builds it (sorted sub-directory
# names only), so position i here matches ImageFolder's class index i. A plain
# os.listdir would also pick up stray files and shift every index after them.
imagenets50_class_synset = sorted(entry.name for entry in os.scandir(val_root) if entry.is_dir())
imagenet_class_index_url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
with open(shap.datasets.cache(imagenet_class_index_url)) as f:
    imagenet_class_index = json.load(f)

# Order entries by their numeric key so that list position == numeric key
# (used below as the ImageNet-1k output index), independent of JSON key order.
_ordered_class_entries = [imagenet_class_index[key] for key in sorted(imagenet_class_index, key=int)]
class_names = [entry[1] for entry in _ordered_class_entries]
synset_names = [entry[0] for entry in _ordered_class_entries]

# ImageNet-S50 class index -> ImageNet-1k class index.
imagenets50_to_imagenet1k = {
    i: synset_names.index(synset) for i, synset in enumerate(imagenets50_class_synset)
}
target_transform = lambda label: imagenets50_to_imagenet1k[label]


imagenet1k_to_imagenets50 = {v: k for k, v in imagenets50_to_imagenet1k.items()}

# Both loaders are unshuffled so the evaluator can zip image/mask pairs in order.
val_dataset = CustomDataset(ImageFolder(val_root, transform=img_transform, target_transform=target_transform))
segm_dataset = ImageFolder(segmented_val_root, transform=mask_transform, target_transform=target_transform)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
segm_loader = DataLoader(segm_dataset, batch_size=1, shuffle=False)

model = torchvision.models.resnet18(pretrained=True, progress=False).to(device).eval()



# COMMAND ----------



#import matplotlib.pyplot as plt


# ----------- Evaluation Class -----------
class GridPerturbationEvaluator:
    """Grid-occlusion saliency evaluator restricted to the target object's patches.

    The 224x224 input is divided into a near-square ``grid_rows x grid_cols``
    grid. For every patch that overlaps the segmentation of the image's class,
    the patch is perturbed with ``perturbation_method`` and the drop in softmax
    probability of the true ImageNet-1k class is recorded. Patches that do not
    overlap the object stay NaN. One CSV row per image is written to
    ``csv_path``; after :meth:`run` the per-image maps are also available in
    ``self.results``.

    ``perturbation_method`` may be:

    * an ``int``/``float`` -- fill the patch with that constant grey level in
      *pixel* space ([0, 1]: 0 = black, 1 = white);
    * ``"blur(kx,ky)"`` -- replace the patch with a box-blurred copy;
    * ``"inpaint_telea"`` / ``"inpaint_ns"`` -- OpenCV inpainting of the patch;
    * ``"mdp"`` / ``"median_mdp"`` -- fill with the RGB-cube corner farthest
      from the patch's mean / median colour (computed in pixel space).
    """

    # ImageNet normalisation statistics per RGB channel. Input images are
    # assumed to be normalised with exactly these values.
    _MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    _STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    # Corners of the unit RGB cube, i.e. the extreme colours in [0, 1] pixel space.
    _RGB_CORNERS = np.array([
        [0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1],
        [1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 1, 1],
    ], dtype=np.float32)

    def __init__(self, model, val_loader, segm_loader, imagenet1k_to_imagenets50,
                 num_cells=100, perturbation_method="blur(3,3)", csv_path="saliency_outputs"):
        """Store the evaluation inputs and precompute the patch grid.

        Args:
            model: classifier returning ImageNet-1k logits; called under ``torch.no_grad()``.
            val_loader: yields ``(image, imagenet1k_label, file_name)`` batches;
                assumed batch size 1.
            segm_loader: yields ``(mask, imagenet1k_label)`` batches aligned
                one-to-one with ``val_loader``.
            imagenet1k_to_imagenets50: maps an ImageNet-1k class index to its
                ImageNet-S50 index.
            num_cells: minimum number of grid cells; the grid has
                ``floor(sqrt(num_cells))`` rows and enough columns to reach it.
            perturbation_method: see the class docstring.
            csv_path: output CSV file path.

        Raises:
            ValueError: if ``num_cells`` is smaller than 1.
        """
        if num_cells < 1:
            raise ValueError(f"num_cells must be >= 1, got {num_cells}")

        self.model = model
        self.val_loader = val_loader
        self.segm_loader = segm_loader
        self.map_dict = imagenet1k_to_imagenets50
        self.num_cells = num_cells
        self.input_size = 224

        # Near-square grid with at least num_cells cells.
        self.grid_rows = int(np.floor(np.sqrt(num_cells)))
        self.grid_cols = int(np.ceil(num_cells / self.grid_rows))
        self.patch_height = int(self.input_size / self.grid_rows)
        self.patch_width = int(self.input_size / self.grid_cols)

        # Top-left patch corners, spread evenly so the last patch ends at the border.
        self.ys = np.linspace(0, self.input_size - self.patch_height, self.grid_rows).astype(int)
        self.xs = np.linspace(0, self.input_size - self.patch_width, self.grid_cols).astype(int)

        self.perturbation_method = perturbation_method
        self.csv_path = csv_path
        self.results = []

    def run(self):
        """Evaluate every image, write the saliency grids to ``csv_path`` and keep them in ``self.results``.

        Each ``self.results`` entry is ``{"image", "label", "saliency"}`` where
        ``saliency`` is a ``(grid_rows, grid_cols)`` float32 array (NaN where
        the patch does not overlap the object).

        Raises:
            ValueError: if the image and segmentation loaders differ in length.
            RuntimeError: if a zipped image/mask pair carries different labels,
                i.e. the two datasets are not aligned.
        """
        print(f"Running perturbation method: {self.perturbation_method}")
        # zip() silently truncates; a length mismatch means pairs are misaligned.
        if len(self.val_loader) != len(self.segm_loader):
            raise ValueError(
                f"Image loader ({len(self.val_loader)} batches) and segmentation loader "
                f"({len(self.segm_loader)} batches) differ in length; image/mask pairs would be misaligned."
            )

        results = []
        csv_rows = []
        header = ["image", "label"] + [f"s_{i}_{j}" for i in range(self.grid_rows) for j in range(self.grid_cols)]

        pairs = zip(self.val_loader, self.segm_loader)
        for (img_tensor, label, img_name), (segm_tensor, segm_label) in tqdm(
                pairs, total=len(self.val_loader), desc=f"Evaluating ({self.perturbation_method})"):
            img_name = img_name[0]
            imagenet1k_label = label.item()

            if segm_label.item() != imagenet1k_label:
                raise RuntimeError(
                    f"Mask for {img_name} has label {segm_label.item()} but the image has "
                    f"{imagenet1k_label}; image and segmentation datasets are misaligned."
                )

            if imagenet1k_label not in self.map_dict:
                print(f"Skipping {img_name} (label {imagenet1k_label} not in mapping).")
                continue

            # NOTE(review): assumes masks encode ImageNet-S50 index + 1 (0 = background) — confirm.
            s50_label = self.map_dict[imagenet1k_label] + 1
            # ToTensor() scaled the mask to [0, 1]; round back to exact integer labels
            # because float k/255*255 is not guaranteed to equal k.
            segm_labels = np.rint(segm_tensor.squeeze(0).squeeze(0).cpu().numpy() * 255).astype(np.int64)

            saliency_map = self._image_saliency(img_tensor.to(device), segm_labels, imagenet1k_label, s50_label)

            results.append({"image": img_name, "label": imagenet1k_label, "saliency": saliency_map})
            csv_rows.append([img_name, imagenet1k_label] + saliency_map.flatten().tolist())

        self._write_csv(header, csv_rows)
        self.results = results
        print(f"Saved: {self.csv_path}")

    def _image_saliency(self, img_tensor, segm_labels, imagenet1k_label, s50_label):
        """Return the (grid_rows, grid_cols) probability-drop map for one image.

        ``img_tensor`` is the normalised image batch (assumed batch size 1) on
        ``device``; ``segm_labels`` is the integer (H, W) label mask.
        """
        with torch.no_grad():
            baseline_prob = torch.softmax(self.model(img_tensor), dim=1)[0, imagenet1k_label].item()

        saliency_map = np.full((self.grid_rows, self.grid_cols), np.nan, dtype=np.float32)
        img_np = nchw_to_nhwc(img_tensor.squeeze(0).cpu()).numpy()

        for i, y in enumerate(self.ys):
            for j, x in enumerate(self.xs):
                rows = slice(y, y + self.patch_height)
                cols = slice(x, x + self.patch_width)
                # Only patches touching the target object are scored; others stay NaN.
                if not (segm_labels[rows, cols] == s50_label).any():
                    continue

                mask = np.zeros(img_np.shape[:2], dtype=bool)
                mask[rows, cols] = True

                pert_np = self.perturb(mask, img_np, self.perturbation_method)[0]
                pert_tensor = nhwc_to_nchw(torch.tensor(pert_np).float()).to(device)

                with torch.no_grad():
                    pert_prob = torch.softmax(self.model(pert_tensor), dim=1)[0, imagenet1k_label].item()

                saliency_map[i, j] = baseline_prob - pert_prob

        return saliency_map

    def _write_csv(self, header, rows):
        """Write ``header`` and ``rows`` to ``self.csv_path``, creating its directory if needed."""
        out_dir = os.path.dirname(self.csv_path)
        # A bare file name has no directory part, and os.makedirs("") would raise.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(self.csv_path, mode="w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(header)
            writer.writerows(rows)

    def _normalize(self, rgb):
        """Map [0, 1] pixel-space RGB values (channel on the last axis) to normalised space."""
        return (rgb - self._MEAN) / self._STD

    def _unnormalize(self, x):
        """Map normalised values (channel on the last axis) back to [0, 1] pixel space."""
        return x * self._STD + self._MEAN

    def perturb(self, mask, x, method):
        """Return a perturbed copy of the normalised HWC image ``x`` inside ``mask``.

        Args:
            mask: (H, W) or (H, W, 1) array; truthy pixels are perturbed.
            x: (H, W, 3) normalised float image. It is not modified.
            method: perturbation spec (see the class docstring).

        Returns:
            tuple: a single (1, H, W, 3) array holding the perturbed image.

        Raises:
            ValueError: if ``method`` is not recognised.
        """
        x = x.copy()
        if mask.ndim == 3:
            mask = mask[:, :, 0]
        mask = mask.astype(bool)

        if isinstance(method, (int, float)):
            # Constant grey level given in pixel space (0 = black, 1 = white),
            # converted into the normalised space the model consumes.
            x[mask] = self._normalize(np.full(3, method, dtype=np.float32))

        elif not isinstance(method, str):
            raise ValueError(f"Unknown perturbation method: {method}")

        elif method.startswith("blur"):
            # "blur(kx,ky)" -> box-filter kernel size (kx, ky).
            k = tuple(map(int, method.split("(")[1].strip(")").split(",")))
            # Box blur is linear, so blurring in normalised space is equivalent
            # to blurring in pixel space.
            blurred = cv2.blur(x, k)
            x[mask] = blurred[mask]

        elif method in ("inpaint_telea", "inpaint_ns"):
            flag = "INPAINT_TELEA" if method == "inpaint_telea" else "INPAINT_NS"
            x = self._inpaint(x, mask, flag)

        elif method in ("mdp", "median_mdp"):
            if mask.any():
                # Colour statistics must be taken in pixel space, where the RGB cube lives.
                patch_pixels = self._unnormalize(x[mask])  # (K, 3)
                reduce = np.mean if method == "mdp" else np.median
                centre = reduce(patch_pixels, axis=0)
                x[mask] = self._normalize(self._farthest_rgb(centre))

        else:
            raise ValueError(f"Unknown perturbation method: {method}")

        return (x[None, ...],)  # shape: (1, H, W, 3)

    def _farthest_rgb(self, mean_rgb):
        """Return the unit-RGB-cube corner farthest (Euclidean) from ``mean_rgb``.

        ``mean_rgb`` is a length-3 colour in [0, 1] pixel space. On ties the
        first corner in ``_RGB_CORNERS`` wins.
        """
        distances = np.linalg.norm(self._RGB_CORNERS - mean_rgb, axis=1)
        return self._RGB_CORNERS[np.argmax(distances)]

    def _inpaint(self, x, mask, method):
        """Fill the masked pixels of normalised image ``x`` by OpenCV inpainting.

        ``method`` names a cv2 flag (``"INPAINT_TELEA"`` or ``"INPAINT_NS"``).
        Only masked pixels are replaced. The rest of ``x`` is returned
        unchanged, so the 8-bit round trip OpenCV requires does not add
        quantisation noise elsewhere in the image.
        """
        if mask.ndim == 3:
            mask = mask[:, :, 0]
        mask = mask.astype(bool)

        # cv2.inpaint works on 8-bit images: unnormalise, clip, and round
        # (not truncate) to uint8.
        x_uint8 = np.rint(np.clip(self._unnormalize(x), 0, 1) * 255).astype(np.uint8)
        mask_uint8 = mask.astype(np.uint8) * 255

        inpainted_uint8 = cv2.inpaint(
            x_uint8,
            mask_uint8,
            inpaintRadius=3,
            flags=getattr(cv2, method)
        )

        inpainted = self._normalize(inpainted_uint8.astype(np.float32) / 255.0)
        out = x.copy()
        out[mask] = inpainted[mask]
        return out



# COMMAND ----------

# MAGIC %md
# MAGIC #### Calculation

# COMMAND ----------

# ----------- Run Evaluation -----------
# MDP: fill each object patch with the RGB-cube corner farthest from the patch's mean colour.
evaluator = GridPerturbationEvaluator(
    model, val_loader, segm_loader, imagenet1k_to_imagenets50,
    num_cells=154,
    perturbation_method="mdp",
    csv_path="/Workspace/Users/ANONYM/SHAP_MDP/results/imagenet_resnet18_mdp.csv",
)
evaluator.run()

# COMMAND ----------

# ----------- Run Evaluation -----------
# Median MDP: like MDP, but uses the patch's median colour instead of its mean.
evaluator = GridPerturbationEvaluator(
    model, val_loader, segm_loader, imagenet1k_to_imagenets50,
    num_cells=154,
    perturbation_method="median_mdp",
    csv_path="/Workspace/Users/ANONYM/SHAP_MDP/results/imagenet_resnet18_median_mdp.csv",
)
evaluator.run()

# COMMAND ----------

# ----------- Run Evaluation -----------
# NOTE(review): this cell repeats the previous cell exactly (same method and
# csv_path), so it recomputes and overwrites the same results file. Confirm
# whether a different perturbation method / output path was intended here.
evaluator = GridPerturbationEvaluator(
    model, val_loader, segm_loader, imagenet1k_to_imagenets50,
    num_cells=154,
    perturbation_method="median_mdp",
    csv_path="/Workspace/Users/ANONYM/SHAP_MDP/results/imagenet_resnet18_median_mdp.csv",
)
evaluator.run()

# COMMAND ----------

# ----------- Run Evaluation -----------
# Uniform constant fill with value 0 (results file: udp_black).
evaluator = GridPerturbationEvaluator(
    model, val_loader, segm_loader, imagenet1k_to_imagenets50,
    num_cells=154,
    perturbation_method=0,
    csv_path="/Workspace/Users/ANONYM/SHAP_MDP/results/imagenet_resnet18_udp_black.csv",
)
evaluator.run()

# COMMAND ----------

# ----------- Run Evaluation -----------
# Uniform constant fill with value 1 (results file: udp_white).
evaluator = GridPerturbationEvaluator(
    model, val_loader, segm_loader, imagenet1k_to_imagenets50,
    num_cells=154,
    perturbation_method=1,
    csv_path="/Workspace/Users/ANONYM/SHAP_MDP/results/imagenet_resnet18_udp_white.csv",
)
evaluator.run()



# COMMAND ----------

# ----------- Run Evaluation -----------
# Replace each object patch with a 15x15 box-blurred copy.
evaluator = GridPerturbationEvaluator(
    model, val_loader, segm_loader, imagenet1k_to_imagenets50,
    num_cells=154,
    perturbation_method="blur(15,15)",
    csv_path="/Workspace/Users/ANONYM/SHAP_MDP/results/imagenet_resnet18_blurr.csv",
)
evaluator.run()

# COMMAND ----------

# ----------- Run Evaluation -----------
# Fill each object patch with OpenCV's Telea inpainting.
evaluator = GridPerturbationEvaluator(
    model, val_loader, segm_loader, imagenet1k_to_imagenets50,
    num_cells=154,
    perturbation_method="inpaint_telea",
    csv_path="/Workspace/Users/ANONYM/SHAP_MDP/results/imagenet_resnet18_inpaint_telea.csv",
)
evaluator.run()

# COMMAND ----------

# ----------- Run Evaluation -----------
# Fill each object patch with OpenCV's Navier-Stokes inpainting.
evaluator = GridPerturbationEvaluator(
    model, val_loader, segm_loader, imagenet1k_to_imagenets50,
    num_cells=154,
    perturbation_method="inpaint_ns",
    csv_path="/Workspace/Users/ANONYM/SHAP_MDP/results/imagenet_resnet18_inpaint_ns.csv",
)
evaluator.run()

# COMMAND ----------

# MAGIC %md
# MAGIC #### Evaluation

# COMMAND ----------

from utils.eva import *

# COMMAND ----------

# MAGIC %md
# MAGIC Ours - MDP

# COMMAND ----------

# Summarise the saliency CSV written by the "mdp" evaluation cell.
# analyze_saliency comes from the `utils.eva` star import; its output is not visible in this file.
csv_file = "/Workspace/Users/ANONYM/SHAP_MDP/results/imagenet_resnet18_mdp.csv"
analyze_saliency(csv_file)

# COMMAND ----------

# MAGIC %md
# MAGIC Median-MDP

# COMMAND ----------

# Summarise the saliency CSV written by the "median_mdp" evaluation cell.
# analyze_saliency comes from the `utils.eva` star import; its output is not visible in this file.
csv_file = "/Workspace/Users/ANONYM/SHAP_MDP/results/imagenet_resnet18_median_mdp.csv"
analyze_saliency(csv_file)

# COMMAND ----------

# MAGIC %md
# MAGIC

# COMMAND ----------

# MAGIC %md
# MAGIC UDP-Black

# COMMAND ----------

# Summarise the saliency CSV written by the constant-fill (value 0) evaluation cell.
# analyze_saliency comes from the `utils.eva` star import; its output is not visible in this file.
csv_file = "/Workspace/Users/ANONYM/SHAP_MDP/results/imagenet_resnet18_udp_black.csv"
analyze_saliency(csv_file)

# COMMAND ----------

# MAGIC %md
# MAGIC UDP-White

# COMMAND ----------

# Summarise the saliency CSV written by the constant-fill (value 1) evaluation cell.
# analyze_saliency comes from the `utils.eva` star import; its output is not visible in this file.
csv_file = "/Workspace/Users/ANONYM/SHAP_MDP/results/imagenet_resnet18_udp_white.csv"
analyze_saliency(csv_file)

# COMMAND ----------

# MAGIC %md
# MAGIC Blur

# COMMAND ----------

# Summarise the saliency CSV written by the "blur(15,15)" evaluation cell.
# analyze_saliency comes from the `utils.eva` star import; its output is not visible in this file.
csv_file = "/Workspace/Users/ANONYM/SHAP_MDP/results/imagenet_resnet18_blurr.csv"
analyze_saliency(csv_file)

# COMMAND ----------

# MAGIC %md
# MAGIC Inpaint Telea

# COMMAND ----------

# Summarise the saliency CSV written by the "inpaint_telea" evaluation cell.
# analyze_saliency comes from the `utils.eva` star import; its output is not visible in this file.
csv_file = "/Workspace/Users/ANONYM/SHAP_MDP/results/imagenet_resnet18_inpaint_telea.csv"
analyze_saliency(csv_file)

# COMMAND ----------

# MAGIC %md
# MAGIC Inpaint NS

# COMMAND ----------

# Summarise the saliency CSV written by the "inpaint_ns" evaluation cell.
# analyze_saliency comes from the `utils.eva` star import; its output is not visible in this file.
csv_file = "/Workspace/Users/ANONYM/SHAP_MDP/results/imagenet_resnet18_inpaint_ns.csv"
analyze_saliency(csv_file)

# COMMAND ----------

"""
                    ####### DEBUG
                    # Debug visualization (before sending to model)
                    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

                    axes[0].imshow(img_np)
                    axes[0].set_title("Original Image")
                    axes[0].axis('off')

                    axes[1].imshow(mask.squeeze(), cmap='gray')
                    axes[1].set_title(f"Mask Patch ({i},{j})")
                    axes[1].axis('off')

                    axes[2].imshow(pert_np.squeeze(0))
                    axes[2].set_title("Perturbed Image")
                    axes[2].axis('off')

                    plt.tight_layout()
                    plt.show()
                    #######
"""