# Databricks notebook source
# MAGIC %md
# MAGIC # Guidance: ImageNet ViT
# MAGIC
# MAGIC Load Data

# COMMAND ----------


# Import
import os
import json
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
import shap

# Own Code
#from utils.imagenet_dataloader import *


# Custom dataset with image names
class CustomDataset(torch.utils.data.Dataset):
    """Wrap an image dataset so every sample also carries its file name.

    The wrapped ``dataset`` must expose an ``imgs`` sequence of
    ``(path, class_index)`` pairs and support integer indexing that yields
    ``(image, label)``; in this file it is a torchvision ``ImageFolder``.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        # Paths are kept in dataset order so position i matches sample i.
        self.img_paths = [path for path, _target in dataset.imgs]

    def __len__(self):
        """Return the number of samples of the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(image, label, file_name)`` for the sample at ``index``."""
        image, target = self.dataset[index]
        file_name = os.path.basename(self.img_paths[index])
        return image, target, file_name

# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)



# ----------- Image Transformations -----------
# Per-channel statistics handed to Normalize for the input images.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Input images: resize to 256, centre-crop to 224x224, convert to a tensor
# and normalise with the statistics above.
img_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Segmentation masks: same resize/crop geometry as the images so that pixels
# stay aligned. The masks hold class ids, not intensities, so they are resized
# with nearest-neighbour sampling; the default bilinear filter would blend
# neighbouring ids into values that belong to no class (or to a wrong one).
mask_transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x[0].unsqueeze(0)),  # keep only first channel → shape [1, H, W]
])



# ----------- Paths and Class Mapping -----------
val_root = "/Workspace/Users/ANONYM/Base_Constrain_Analysis/data/validation"
segmented_val_root = "/Workspace/Users/ANONYM/Base_Constrain_Analysis/data/validation-segmentation"

# List of synset names (class sub-folder names).
# Only directories are kept: ImageFolder numbers its classes from the sorted
# sub-directory names, so a stray file in val_root would otherwise shift the
# indices used below (or fail the synset lookup).
imagenets50_class_synset = sorted(
    name for name in os.listdir(val_root)
    if os.path.isdir(os.path.join(val_root, name))
)

# Load ImageNet-1k class index (synset → 0–999); each entry is read below as
# (synset, class name).
imagenet_class_index_url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
with open(shap.datasets.cache(imagenet_class_index_url), encoding="utf-8") as f:
    imagenet_class_index = json.load(f)

class_names = [entry[1] for entry in imagenet_class_index.values()]
synset_names = [entry[0] for entry in imagenet_class_index.values()]

# Mapping S50 indices → ImageNet-1k indices
imagenets50_to_imagenet1k = {
    i: synset_names.index(synset) for i, synset in enumerate(imagenets50_class_synset)
}


def target_transform(label):
    """Translate an S50 folder index into the matching ImageNet-1k index."""
    return imagenets50_to_imagenet1k[label]


# Inverse mapping: ImageNet-1k indices → S50 indices
imagenet1k_to_imagenets50 = {v: k for k, v in imagenets50_to_imagenet1k.items()}


# ----------- Dataset and Dataloaders -----------
# Validation images; wrapped so each sample also yields its file name.
val_dataset = CustomDataset(
    ImageFolder(
        root=val_root,
        transform=img_transform,
        target_transform=target_transform,
    )
)
# Segmentation masks, read with the mask transform and the same label mapping.
segm_dataset = ImageFolder(
    root=segmented_val_root,
    transform=mask_transform,
    target_transform=target_transform,
)

# Both loaders yield one sample at a time, unshuffled.
# NOTE(review): presumably sample i of both loaders refers to the same image;
# that only holds if the two folder trees list their files in the same order.
val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False)
segm_loader = DataLoader(dataset=segm_dataset, batch_size=1, shuffle=False)



# COMMAND ----------

# MAGIC %md
# MAGIC #### Load Model

# COMMAND ----------


from transformers import ViTForImageClassification, ViTImageProcessor

# Load the "google/vit-base-patch16-224" checkpoint and its image processor.
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

# Move the network to the selected device and switch it to evaluation mode.
model = model.to(device)
model.eval()

# COMMAND ----------

# Print a heading, an empty line and then the module tree of the loaded model.
print("Model architecture:\n", model, sep="\n")

# COMMAND ----------

# MAGIC %md
# MAGIC ### Degradation

# COMMAND ----------

from utils.framework import *


# MAPPING IMAGENET
# torch.Size([1, 1, 224, 224]) -> (224,224)
def imagenet_mapping(segm, label, map_dict=None):
    """Turn a segmentation tensor into a binary mask of the labelled object.

    Args:
        segm: Segmentation tensor. Its first dimension is squeezed and the
            next one is summed away (e.g. [1, 1, H, W] -> [H, W]); values are
            then multiplied by 255 and truncated to integers.
        label: Class index of the sample (tensor or int).
        map_dict: Optional lookup applied to ``label.item()`` before the
            comparison; when given, ``label`` must provide ``.item()``.

    Returns:
        Integer NumPy array that is 1 where the integer mask equals
        ``label + 1`` and 0 everywhere else.
    """
    # Collapse to a 2-D array and undo the 1/255 scaling of the mask values.
    class_ids = segm.squeeze(0).cpu().numpy().sum(axis=0)
    class_ids = (class_ids * 255).astype(int)

    # Resolve the class id searched for in the mask.
    # NOTE(review): the +1 presumably skips a background id of 0 - confirm.
    if map_dict is None:
        target_id = int(label) + 1
    else:
        target_id = int(map_dict[label.item()]) + 1

    return np.where(class_ids == target_id, 1, 0)

def predict(model, image_tensor):
    """Return softmax probabilities for ``image_tensor``.

    ``model`` is called on the tensor and must return an object with a
    ``logits`` attribute; the softmax is taken over dimension 1. Gradient
    tracking is disabled for the whole computation.
    """
    with torch.no_grad():
        return torch.softmax(model(image_tensor).logits, dim=1)


# Alternative settings, kept commented out for switching between runs.
#PERTURBATION_MODE = "mdp"
#NUM_GRID_ROW = 7

#PERTURBATION_MODE = "mdp"
#NUM_GRID_ROW = 14

#PERTURBATION_MODE = "mdp"
#NUM_GRID_ROW = 32

# Active setting for every evaluator cell below:
# - PERTURBATION_MODE is passed to ModelDegradation as `per_mode` and is part
#   of the results folder name.
# - NUM_GRID_ROW is passed to evaluate() as `grid_size` and is part of the
#   results folder name as well.
PERTURBATION_MODE = "blur"
NUM_GRID_ROW = 14


"""
Patch size (p × p)	Patches per side	Total number of patches
224	    1	    1
112	    2	    4
56	    4	    16
32	    7	    49
28	    8	    64
16	    14	    196
8	    28	    784
7	    32	    1024

--> 7
--> 28
--> 32 
"""


# COMMAND ----------

# MAGIC %md
# MAGIC #### (A) Random

# COMMAND ----------

# (A) Random: evaluator configured with mode="random".
evaluator = ModelDegradation(
    model=model,
    predict=predict,
    segm_mapping=imagenet_mapping,
    mode="random",
    per_mode=PERTURBATION_MODE,
    max_rows_per_file=20000,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/vit_A_random.csv"
    ),
)

evaluator.evaluate(
    val_loader,
    segm_loader,
    map_dict=imagenet1k_to_imagenets50,
    grid_size=NUM_GRID_ROW,
)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (A1) Random ObjFirst

# COMMAND ----------

# (A1) Random ObjFirst: evaluator configured with mode="random_objfirst".
evaluator = ModelDegradation(
    model=model,
    predict=predict,
    segm_mapping=imagenet_mapping,
    mode="random_objfirst",
    per_mode=PERTURBATION_MODE,
    max_rows_per_file=20000,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/vit_A1_random_objfirst.csv"
    ),
)

evaluator.evaluate(
    val_loader,
    segm_loader,
    map_dict=imagenet1k_to_imagenets50,
    grid_size=NUM_GRID_ROW,
)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (A2) Background First

# COMMAND ----------

# (A2) Background First.
# NOTE(review): this cell passes mode="random", exactly like cell (A); only
# the CSV file name differs. The heading suggests a background-first ordering
# was intended - confirm the mode name against utils.framework.
evaluator = ModelDegradation(
    model=model,
    predict=predict,
    segm_mapping=imagenet_mapping,
    mode="random",
    per_mode=PERTURBATION_MODE,
    max_rows_per_file=20000,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/vit_A2_background_first.csv"
    ),
)

evaluator.evaluate(
    val_loader,
    segm_loader,
    map_dict=imagenet1k_to_imagenets50,
    grid_size=NUM_GRID_ROW,
)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (A3) Gauss

# COMMAND ----------

# (A3) Gauss: evaluator configured with mode="gauss".
evaluator = ModelDegradation(
    model=model,
    predict=predict,
    segm_mapping=imagenet_mapping,
    mode="gauss",
    per_mode=PERTURBATION_MODE,
    max_rows_per_file=20000,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/vit_A3_gauss.csv"
    ),
)

evaluator.evaluate(
    val_loader,
    segm_loader,
    map_dict=imagenet1k_to_imagenets50,
    grid_size=NUM_GRID_ROW,
)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (A4) Perlin

# COMMAND ----------

# (A4) Perlin: evaluator configured with mode="perlin".
evaluator = ModelDegradation(
    model=model,
    predict=predict,
    segm_mapping=imagenet_mapping,
    mode="perlin",
    per_mode=PERTURBATION_MODE,
    max_rows_per_file=20000,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/vit_A4_perlin.csv"
    ),
)

evaluator.evaluate(
    val_loader,
    segm_loader,
    map_dict=imagenet1k_to_imagenets50,
    grid_size=NUM_GRID_ROW,
)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (A5) Gauss + Perlin

# COMMAND ----------

# (A5) Gauss + Perlin is currently disabled: the whole cell body below is
# wrapped in a string literal, so no evaluator is built and nothing is
# evaluated here. The disabled call uses mode="centred_perlin".
"""evaluator = ModelDegradation(
    model=model,
    csv_save_path="/Workspace/Users/ANONYM/Constrain_Framework/results/patch_" + str(NUM_GRID_ROW) + "_" + PERTURBATION_MODE + "/imagenet/vit_A5_gaussplusperlin.csv",
    mode="centred_perlin",
    max_rows_per_file=20000,
    per_mode= PERTURBATION_MODE,
    segm_mapping=imagenet_mapping,
    predict = predict
    )

evaluator.evaluate(val_loader, segm_loader, map_dict=imagenet1k_to_imagenets50, grid_size = NUM_GRID_ROW)"""

# COMMAND ----------

# MAGIC %md
# MAGIC #### (B) SHAP

# COMMAND ----------

# (B) SHAP: evaluator configured with mode="shap"; this cell additionally
# passes the ImageNet-1k class names.
evaluator = ModelDegradation(
    model=model,
    predict=predict,
    segm_mapping=imagenet_mapping,
    class_names=class_names,
    mode="shap",
    per_mode=PERTURBATION_MODE,
    max_rows_per_file=20000,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/vit_B_shap.csv"
    ),
)

evaluator.evaluate(
    val_loader,
    segm_loader,
    map_dict=imagenet1k_to_imagenets50,
    grid_size=NUM_GRID_ROW,
)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (C) Grad-CAM

# COMMAND ----------

# (C) Grad-CAM: evaluator configured with mode="gradcam"; this cell also
# passes model_name="vit_hf".
evaluator = ModelDegradation(
    model=model,
    model_name="vit_hf",
    predict=predict,
    segm_mapping=imagenet_mapping,
    mode="gradcam",
    per_mode=PERTURBATION_MODE,
    max_rows_per_file=20000,
    debug=False,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/vit_C_gradcam.csv"
    ),
)

evaluator.evaluate(
    val_loader,
    segm_loader,
    map_dict=imagenet1k_to_imagenets50,
    grid_size=NUM_GRID_ROW,
)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (D) SmoothedCam

# COMMAND ----------

# (D) SmoothedCam: evaluator configured with mode="scam".
evaluator = ModelDegradation(
    model=model,
    predict=predict,
    segm_mapping=imagenet_mapping,
    mode="scam",
    per_mode=PERTURBATION_MODE,
    max_rows_per_file=20000,
    debug=False,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/vit_D_smoothedcam.csv"
    ),
)

evaluator.evaluate(
    val_loader,
    segm_loader,
    map_dict=imagenet1k_to_imagenets50,
    grid_size=NUM_GRID_ROW,
)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (E) Integrated Gradients

# COMMAND ----------

# (E) Integrated Gradients: evaluator configured with mode="ig".
evaluator = ModelDegradation(
    model=model,
    predict=predict,
    segm_mapping=imagenet_mapping,
    mode="ig",
    per_mode=PERTURBATION_MODE,
    max_rows_per_file=20000,
    debug=False,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/vit_E_ig.csv"
    ),
)

evaluator.evaluate(
    val_loader,
    segm_loader,
    map_dict=imagenet1k_to_imagenets50,
    grid_size=NUM_GRID_ROW,
)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (F) Activation maximization

# COMMAND ----------

# (F) Activation maximization: evaluator configured with mode="am".
evaluator = ModelDegradation(
    model=model,
    predict=predict,
    segm_mapping=imagenet_mapping,
    mode="am",
    per_mode=PERTURBATION_MODE,
    max_rows_per_file=20000,
    debug=False,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/vit_F_am.csv"
    ),
)

evaluator.evaluate(
    val_loader,
    segm_loader,
    map_dict=imagenet1k_to_imagenets50,
    grid_size=NUM_GRID_ROW,
)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (G) Occlusion Sensitivity

# COMMAND ----------

# (G) Occlusion Sensitivity: evaluator configured with mode="os".
evaluator = ModelDegradation(
    model=model,
    predict=predict,
    segm_mapping=imagenet_mapping,
    mode="os",
    per_mode=PERTURBATION_MODE,
    max_rows_per_file=20000,
    debug=False,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/vit_G_os.csv"
    ),
)

evaluator.evaluate(
    val_loader,
    segm_loader,
    map_dict=imagenet1k_to_imagenets50,
    grid_size=NUM_GRID_ROW,
)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (H) Loss

# COMMAND ----------

# (H) Loss: evaluator configured with mode="loss".
evaluator = ModelDegradation(
    model=model,
    predict=predict,
    segm_mapping=imagenet_mapping,
    mode="loss",
    per_mode=PERTURBATION_MODE,
    max_rows_per_file=20000,
    debug=False,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/vit_H_loss.csv"
    ),
)

evaluator.evaluate(
    val_loader,
    segm_loader,
    map_dict=imagenet1k_to_imagenets50,
    grid_size=NUM_GRID_ROW,
)