# Databricks notebook source
# MAGIC %md
# MAGIC # Guidance: Flowers ViT
# MAGIC
# MAGIC Load Data

# COMMAND ----------

import os
import json
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import timm


# ----------- Device -----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ----------- Transformations -----------
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

img_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
])


def _binarize_mask(x: torch.Tensor) -> torch.Tensor:
    """Threshold a [0, 1] mask tensor into {0.0, 1.0}.

    0.498 sits just below 127/255, so 8-bit pixel values >= 127 become 1.
    """
    return (x > 0.498).float()


def _first_channel(x: torch.Tensor) -> torch.Tensor:
    """Keep only channel 0, returning a (1, H, W) tensor."""
    return x[0].unsqueeze(0)


# Named functions instead of lambdas keep this transform picklable, which
# DataLoader needs when worker processes are started with spawn/forkserver.
mask_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),  # yields values in [0.0, 1.0]
    # Resize interpolates the mask, so re-binarize after conversion.
    transforms.Lambda(_binarize_mask),
    transforms.Lambda(_first_channel),  # (1, H, W)
])

# ----------- Utility Functions -----------
def nchw_to_nhwc(x: torch.Tensor) -> torch.Tensor:
    """Move the channel axis last: NCHW -> NHWC (4-D) or CHW -> HWC (3-D).

    A tensor whose last dimension already has size 3 is returned unchanged.
    """
    if x.shape[-1] == 3:
        return x
    order = (0, 2, 3, 1) if x.dim() == 4 else (1, 2, 0)
    return x.permute(*order)

def nhwc_to_nchw(x: torch.Tensor) -> torch.Tensor:
    """Move the channel axis first: NHWC -> NCHW (4-D) or HWC -> CHW (3-D).

    A tensor that already has 3 channels in the channels-first position
    (dim -3, i.e. dim 1 for 4-D and dim 0 for 3-D) is returned unchanged.
    Like ``nchw_to_nhwc``, this is a size-based heuristic and is ambiguous
    when a spatial dimension also equals 3.
    """
    # dim -3 is the channel axis for both NCHW and CHW; checking dim 1 would
    # inspect H for an unbatched CHW tensor and wrongly permute it.
    if x.shape[-3] == 3:
        return x
    return x.permute(0, 3, 1, 2) if x.dim() == 4 else x.permute(2, 0, 1)

# ----------- Dataset Classes -----------
class OxfordFlowersImageDataset(Dataset):
    """Oxford Flowers images paired with their labels.

    Each item is ``(image, label, basename)``: the RGB image (transformed if a
    transform was given), ``label_map[filename]``, and the file's basename.
    """

    def __init__(self, img_dir, filenames, label_map, transform=None):
        self.img_dir = img_dir
        self.filenames = filenames
        self.label_map = label_map
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        name = self.filenames[idx]
        path = os.path.join(self.img_dir, name)
        picture = Image.open(path).convert("RGB")
        picture = self.transform(picture) if self.transform else picture
        return picture, self.label_map[name], os.path.basename(path)


class OxfordFlowersMaskDataset(Dataset):
    """Oxford Flowers segmentation masks, looked up by image filename.

    For image ``<stem>.<ext>`` the mask is read from ``<mask_dir>/<stem>.png``
    as a single-channel ("L") image. Each item is ``(mask, mask_basename)``.
    """

    def __init__(self, mask_dir, filenames, transform=None):
        self.mask_dir = mask_dir
        self.filenames = filenames
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        stem, _ext = os.path.splitext(self.filenames[idx])
        path = os.path.join(self.mask_dir, f"{stem}.png")
        grey = Image.open(path).convert("L")
        if self.transform:
            grey = self.transform(grey)
        return grey, os.path.basename(path)

# ----------- Paths -----------
oxford_root = "/Workspace/Users/ANONYM/Base_Constrain_Analysis/data_oxford_flowers"
oxford_img_dir = os.path.join(oxford_root, "images")
oxford_mask_dir = os.path.join(oxford_root, "masks")
oxford_labels_file = os.path.join(oxford_root, "labels.json")

# ----------- Load Labels from JSON -----------
# labels.json is a JSON object keyed by image filename whose values are used
# as labels; its key order defines the sample order of both datasets.
# JSON text is UTF-8, so don't rely on the platform's default encoding.
with open(oxford_labels_file, "r", encoding="utf-8") as f:
    label_data = json.load(f)

split_filenames = list(label_data)
filename_to_label = dict(label_data)

# ----------- Initialize Datasets & Loaders -----------
oxford_img_dataset = OxfordFlowersImageDataset(
    img_dir=oxford_img_dir,
    filenames=split_filenames,
    label_map=filename_to_label,
    transform=img_transform
)

oxford_mask_dataset = OxfordFlowersMaskDataset(
    mask_dir=oxford_mask_dir,
    filenames=split_filenames,
    transform=mask_transform
)

# Both loaders share the same filename list and shuffle=False, so zipping
# them pairs each image with its own mask. Keep shuffle off on both.
oxford_img_loader = DataLoader(oxford_img_dataset, batch_size=1, shuffle=False)
oxford_mask_loader = DataLoader(oxford_mask_dataset, batch_size=1, shuffle=False)


# COMMAND ----------

# MAGIC %md
# MAGIC #### Load Model

# COMMAND ----------

# Load the classifier from the Hugging Face Hub via timm (per its hub name, a
# ViT-B/16 at 224 px fine-tuned on Flowers-102) and switch it to eval mode.
model = timm.create_model("hf_hub:anonauthors/flowers102-timm-vit_base_patch16_224.orig_in21k_ft_in1k", pretrained=True)
model = model.to(device)
model.eval()

# ----------- Example: Process One Sample -----------
for (img, label, fname_img), (mask, fname_mask) in zip(oxford_img_loader, oxford_mask_loader):
    # The two loaders are zipped, so they must yield the same file stem.
    # Raise explicitly (asserts are stripped under `python -O`).
    if os.path.splitext(fname_img[0])[0] != os.path.splitext(fname_mask[0])[0]:
        raise RuntimeError(f"Image/mask mismatch: {fname_img[0]} vs {fname_mask[0]}")

    img = img.to(device)
    # Inference only: avoid building an autograd graph for the forward pass.
    with torch.no_grad():
        output = model(img)
    pred = torch.argmax(output, dim=1).item()

    print(f"Processed: {fname_img[0]} | True Label: {label.item()} | Predicted Class ID: {pred}")
    break  # Remove break to process entire dataset


# COMMAND ----------

# MAGIC %md
# MAGIC ### Degradation

# COMMAND ----------

from utils.framework import *


# MAPPING OXFORD FLOWERS (segmentation mask -> binary object region)
# torch.Size([1, 1, 224, 224]) -> (224, 224)
def flowers_mapping(segm, label, map_dict=None):
    """Turn a batched mask tensor into a binary object-region array.

    Args:
        segm: Mask tensor with a leading batch dimension of size 1, e.g.
            (1, 1, H, W) as produced by the mask loader above. The remaining
            channel axis is summed out.
        label: Label tensor; only read when ``map_dict`` is given.
        map_dict: Optional mapping keyed by ``label.item()``.

    Returns:
        np.ndarray of shape (H, W) holding 1 where the channel-summed mask
        equals 1 and 0 elsewhere.
    """
    # (1, C, H, W) -> (C, H, W) -> (H, W), as integers
    segm = segm.squeeze(0).cpu().numpy().sum(axis=0).astype(int)

    # NOTE(review): the mapped label is never returned or used; the lookup is
    # kept so an unmapped label still raises KeyError as before. Confirm
    # whether callers expect the mapped label back.
    if map_dict is not None:
        label = map_dict[label.item()]

    return np.where(segm == 1, 1, 0)


def predict(model, image_tensor):
    """Return softmax probabilities over dim 1 of ``model(image_tensor)``.

    Runs under ``torch.no_grad()``, so the result carries no autograd history.
    """
    with torch.no_grad():
        logits = model(image_tensor)
        return logits.softmax(dim=1)



# Passed to every ModelDegradation run below as `per_mode`.
PERTURBATION_MODE = "mdp"
# Passed as `grid_size` to every evaluate() call and used in the results path.
# Per the table below, 14 cells per side means 16x16-pixel patches on 224x224.
NUM_GRID_ROW = 14

# Grid options for a 224x224 input:
#
#   Patch size (p x p) | Patches per side | Total patches
#   224                | 1                | 1
#   112                | 2                | 4
#   56                 | 4                | 16
#   32                 | 7                | 49
#   28                 | 8                | 64
#   16                 | 14               | 196
#   8                  | 28               | 784
#   7                  | 32               | 1024
#
# Noted candidates: 7, 28, 32.
# TODO: confirm whether these are patch sizes or patches per side.


# COMMAND ----------

# MAGIC %md
# MAGIC #### (A) Random

# COMMAND ----------

# Experiment (A): random perturbation order.
results_csv = (
    "/Workspace/Users/ANONYM/Constrain_Framework/results/"
    f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/flowers/vit_A_random.csv"
)
evaluator = ModelDegradation(
    model=model,
    csv_save_path=results_csv,
    mode="random",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=flowers_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (A1) Random ObjFirst

# COMMAND ----------

# Experiment (A1): ModelDegradation with mode="random_objfirst", results written to CSV.
# NOTE(review): the CSV file name says "random_backfirst", while the mode string and
# the cell title say "ObjFirst". Confirm which ordering this run actually produces
# and align the file name so downstream analysis doesn't mislabel the results.
evaluator = ModelDegradation(
    model=model,
    csv_save_path="/Workspace/Users/ANONYM/Constrain_Framework/results/patch_" + str(NUM_GRID_ROW) + "_" + PERTURBATION_MODE + "/flowers/vit_A1_random_backfirst.csv",
    mode="random_objfirst",
    max_rows_per_file=20000,
    per_mode= PERTURBATION_MODE,
    segm_mapping=flowers_mapping,
    predict = predict
    )

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size = NUM_GRID_ROW)


# COMMAND ----------

# MAGIC %md
# MAGIC #### (A2) Background First

# COMMAND ----------

# Experiment (A2): intended as "Background First" ordering.
# NOTE(review): this passes mode="random", exactly like cell (A) Random. It
# appears to re-run that experiment under a different CSV name. Confirm the
# ModelDegradation mode that implements background-first ordering and set it here.
evaluator = ModelDegradation(
    model=model,
    csv_save_path="/Workspace/Users/ANONYM/Constrain_Framework/results/patch_" + str(NUM_GRID_ROW) + "_" + PERTURBATION_MODE + "/flowers/vit_A2_background_first.csv",
    mode="random",
    max_rows_per_file=20000,
    per_mode= PERTURBATION_MODE,
    segm_mapping=flowers_mapping,
    predict = predict
    )

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size = NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (A3) Gauss

# COMMAND ----------

# Experiment (A3): "gauss" perturbation ordering.
results_csv = (
    "/Workspace/Users/ANONYM/Constrain_Framework/results/"
    f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/flowers/vit_A3_gauss.csv"
)
evaluator = ModelDegradation(
    model=model,
    csv_save_path=results_csv,
    mode="gauss",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=flowers_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (A4) Perlin

# COMMAND ----------

# Experiment (A4): "perlin" perturbation ordering.
results_csv = (
    "/Workspace/Users/ANONYM/Constrain_Framework/results/"
    f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/flowers/vit_A4_perlin.csv"
)
evaluator = ModelDegradation(
    model=model,
    csv_save_path=results_csv,
    mode="perlin",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=flowers_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (A5) Gauss + Perlin

# COMMAND ----------

"""evaluator = ModelDegradation(
    model=model,
    csv_save_path="/Workspace/Users/ANONYM/Constrain_Framework/results/patch_" + str(NUM_GRID_ROW) + "_" + PERTURBATION_MODE + "/flowers/vit_A5_gaussplusperlin.csv",
    mode="centred_perlin",
    max_rows_per_file=20000,
    per_mode= PERTURBATION_MODE,
    segm_mapping=flowers_mapping,
    predict = predict
    )

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size = NUM_GRID_ROW)"""

# COMMAND ----------

# MAGIC %md
# MAGIC #### (B) SHAP

# COMMAND ----------

# Experiment (B): "shap" attribution-based ordering.
results_csv = (
    "/Workspace/Users/ANONYM/Constrain_Framework/results/"
    f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/flowers/vit_B_shap.csv"
)
evaluator = ModelDegradation(
    model=model,
    csv_save_path=results_csv,
    mode="shap",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=flowers_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (C) Grad-Cam

# COMMAND ----------

"""evaluator = ModelDegradation(
    model=model,
    csv_save_path="/Workspace/Users/ANONYM/Constrain_Framework/results/patch_" + str(NUM_GRID_ROW) + "_" + PERTURBATION_MODE + "/flowers/vit_C_gradcam.csv",
    mode="gradcam",
    max_rows_per_file=20000,
    per_mode= PERTURBATION_MODE,
    segm_mapping=flowers_mapping,
    predict = predict
    )

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size = NUM_GRID_ROW)"""

# COMMAND ----------

# MAGIC %md
# MAGIC #### (D) SmoothedCam

# COMMAND ----------

# Experiment (D): "scam" (smoothed CAM) ordering.
results_csv = (
    "/Workspace/Users/ANONYM/Constrain_Framework/results/"
    f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/flowers/vit_D_smoothedcam.csv"
)
evaluator = ModelDegradation(
    model=model,
    csv_save_path=results_csv,
    mode="scam",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=flowers_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (E) Integrated Gradients

# COMMAND ----------

# Experiment (E): "ig" (integrated gradients) ordering.
results_csv = (
    "/Workspace/Users/ANONYM/Constrain_Framework/results/"
    f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/flowers/vit_E_ig.csv"
)
evaluator = ModelDegradation(
    model=model,
    csv_save_path=results_csv,
    mode="ig",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=flowers_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (F) Activation maximization

# COMMAND ----------

# Experiment (F): "am" (activation maximization) ordering.
results_csv = (
    "/Workspace/Users/ANONYM/Constrain_Framework/results/"
    f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/flowers/vit_F_am.csv"
)
evaluator = ModelDegradation(
    model=model,
    csv_save_path=results_csv,
    mode="am",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=flowers_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (G) Occlusion Sensitivity

# COMMAND ----------

# Experiment (G): "os" (occlusion sensitivity) ordering.
results_csv = (
    "/Workspace/Users/ANONYM/Constrain_Framework/results/"
    f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/flowers/vit_G_os.csv"
)
evaluator = ModelDegradation(
    model=model,
    csv_save_path=results_csv,
    mode="os",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=flowers_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (H) Loss

# COMMAND ----------

# Experiment (H): "loss"-based ordering.
results_csv = (
    "/Workspace/Users/ANONYM/Constrain_Framework/results/"
    f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/flowers/vit_H_loss.csv"
)
evaluator = ModelDegradation(
    model=model,
    csv_save_path=results_csv,
    mode="loss",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=flowers_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)