# Databricks notebook source
# MAGIC %md
# MAGIC # Guidance: Oxford-IIIT Pets ViT
# MAGIC
# MAGIC Load Data

# COMMAND ----------

import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import matplotlib.pyplot as plt
import timm

# ----------- Device -----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ----------- Transformations -----------
# ImageNet channel statistics used to normalise the RGB input images.
# NOTE(review): the Hugging Face checkpoint loaded later ships its own image
# processor; confirm its image_mean / image_std agree with these values.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

img_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Masks hold discrete integer labels (pets_mapping reads label 1 as the
# object), not intensities. Resize with nearest-neighbour interpolation so
# no blended, non-existent label values appear along region borders. The
# geometry (Resize 256 -> CenterCrop 224) mirrors img_transform so mask
# pixels stay aligned with image pixels.
mask_transform = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.CenterCrop(224),
    transforms.ToTensor(),  # uint8 "L" image -> float tensor holding label / 255
    transforms.Lambda(lambda x: x[0].unsqueeze(0)),  # Keep first channel as (1, H, W)
])

# ----------- Utility Functions -----------
def nchw_to_nhwc(x: torch.Tensor) -> torch.Tensor:
    """Return *x* in channels-last layout.

    A tensor whose last dimension already has size 3 is treated as
    channels-last and returned unchanged (same object). Otherwise a 4-D
    (N, C, H, W) tensor becomes (N, H, W, C), and any other tensor is treated
    as 3-D (C, H, W) and becomes (H, W, C). The result is a permuted view,
    not a copy.
    """
    if x.shape[-1] == 3:
        return x
    if x.dim() == 4:
        return x.permute(0, 2, 3, 1)
    return x.permute(1, 2, 0)

def nhwc_to_nchw(x: torch.Tensor) -> torch.Tensor:
    """Return *x* in channels-first layout.

    The channel axis is dim 1 for 4-D (N, C, H, W) tensors and dim 0 for 3-D
    (C, H, W) tensors. If that axis already has size 3, *x* is returned
    unchanged. Otherwise a 4-D (N, H, W, C) tensor becomes (N, C, H, W) and a
    3-D (H, W, C) tensor becomes (C, H, W). The result is a permuted view,
    not a copy.

    Note: shapes such as (3, H, 3) are inherently ambiguous and are treated
    as already channels-first.
    """
    # Previously the 3-D case also looked at shape[1] (the H axis of a CHW
    # tensor), so an already-CHW (3, H, W) input was wrongly permuted.
    channel_axis = 1 if x.dim() == 4 else 0
    if x.shape[channel_axis] == 3:
        return x
    if x.dim() == 4:
        return x.permute(0, 3, 1, 2)
    return x.permute(2, 0, 1)

# ----------- Dataset Classes -----------
class OxfordPetsImageDataset(Dataset):
    """Map-style dataset yielding ``(image, label, file_name)`` triples.

    Args:
        img_dir: Directory that contains the ``<name>.jpg`` images.
        filenames: Sequence of image names without extension; index order
            defines the dataset order.
        image_to_label: Mapping from image name to its integer class label.
        transform: Optional callable applied to the RGB ``PIL.Image``.
    """

    def __init__(self, img_dir, filenames, image_to_label, transform=None):
        self.img_dir = img_dir
        self.filenames = filenames
        self.image_to_label = image_to_label
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        stem = self.filenames[idx]
        path = os.path.join(self.img_dir, stem + ".jpg")
        # convert() returns a new, fully loaded image, so the source file can
        # be closed right away.
        with Image.open(path) as raw:
            sample = raw.convert("RGB")
        if self.transform:
            sample = self.transform(sample)
        target = self.image_to_label[stem]
        return sample, target, os.path.basename(path)

class OxfordPetsMaskDataset(Dataset):
    """Map-style dataset yielding ``(mask, file_name)`` pairs.

    Args:
        mask_dir: Directory that contains the ``<name>.png`` masks.
        filenames: Sequence of image names without extension; index order
            defines the dataset order.
        transform: Optional callable applied to the greyscale ("L")
            ``PIL.Image`` mask.
    """

    def __init__(self, mask_dir, filenames, transform=None):
        self.mask_dir = mask_dir
        self.filenames = filenames
        self.transform = transform

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        path = os.path.join(self.mask_dir, self.filenames[idx] + ".png")
        # convert() returns a new, fully loaded image, so the source file can
        # be closed right away.
        with Image.open(path) as raw:
            sample = raw.convert("L")
        if self.transform:
            sample = self.transform(sample)
        return sample, os.path.basename(path)

# ----------- Paths -----------
oxford_root = "/Workspace/Users/ANONYM/Base_Constrain_Analysis/data_oxford_pets"
oxford_img_dir = os.path.join(oxford_root, "images")
oxford_mask_dir = os.path.join(oxford_root, "masks")
oxford_split_file = os.path.join(oxford_root, "annotations/test.txt")


# ----------- Load Filenames and Labels from the split file -----------
def load_split_file(split_path):
    """Parse a split file into image names and 0-based class labels.

    Each data line must start with ``<image_name> <class_id> ...`` where
    ``class_id`` is 1-based. The following lines are skipped rather than
    crashing ``int()``:
      * blank lines and lines with fewer than two fields;
      * comment/header lines starting with ``#``.

    Args:
        split_path: Path to the whitespace-separated split file.

    Returns:
        ``(filenames, image_to_label)``: the image names in file order, and
        a dict mapping each name to its 0-based label.
    """
    filenames = []
    labels = {}
    with open(split_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if len(parts) < 2 or parts[0].startswith("#"):
                continue
            fname = parts[0]
            filenames.append(fname)
            labels[fname] = int(parts[1]) - 1  # Convert to 0-based label
    return filenames, labels


split_filenames, image_to_label = load_split_file(oxford_split_file)

# ----------- Initialize Datasets & Loaders -----------
# Both datasets share one filename list, and both loaders use batch_size=1
# with shuffle=False. Zipping the loaders therefore yields aligned
# (image, mask) samples in split-file order.
oxford_img_dataset = OxfordPetsImageDataset(
    oxford_img_dir,
    split_filenames,
    image_to_label,
    transform=img_transform,
)
oxford_mask_dataset = OxfordPetsMaskDataset(
    oxford_mask_dir,
    split_filenames,
    transform=mask_transform,
)

oxford_img_loader, oxford_mask_loader = (
    DataLoader(dataset, batch_size=1, shuffle=False)
    for dataset in (oxford_img_dataset, oxford_mask_dataset)
)




# COMMAND ----------

# MAGIC %md
# MAGIC #### Load Model

# COMMAND ----------

# Load model directly
from transformers import AutoImageProcessor, AutoModelForImageClassification

processor = AutoImageProcessor.from_pretrained("muellje3/vit-base-oxford-iiit-pets")
model = AutoModelForImageClassification.from_pretrained("muellje3/vit-base-oxford-iiit-pets")

model = model.to(device)
model.eval()

# The data loaders normalise with the ImageNet statistics above and never use
# `processor`. If the checkpoint's processor declares different statistics,
# every input is silently shifted, so warn about it.
_proc_mean = getattr(processor, "image_mean", None)
_proc_std = getattr(processor, "image_std", None)
if _proc_mean is not None and _proc_std is not None:
    if not (np.allclose(_proc_mean, imagenet_mean) and np.allclose(_proc_std, imagenet_std)):
        print(
            f"WARNING: processor expects image_mean={_proc_mean}, image_std={_proc_std}, "
            f"but the data loaders normalise with mean={imagenet_mean}, std={imagenet_std}."
        )

# ----------- Example: Process One Sample -----------
# Inference only: disable autograd so the forward pass builds no graph.
with torch.no_grad():
    for (img, label, fname_img), (mask, fname_mask) in zip(oxford_img_loader, oxford_mask_loader):
        # batch_size=1, so each name is a one-element batch. Both loaders
        # must point at the same sample.
        img_stem = os.path.splitext(fname_img[0])[0]
        mask_stem = os.path.splitext(fname_mask[0])[0]
        if img_stem != mask_stem:
            raise ValueError(f"Image/mask mismatch: {fname_img[0]} vs {fname_mask[0]}")
        img = img.to(device)
        output = model(img)
        pred = torch.argmax(output.logits, dim=1).item()

        print(f"Processed: {fname_img[0]} | Predicted Class ID: {pred}")
        break  # Remove break to process entire dataset

# ATTENTION: the model's output class indices are INVERSE to the dataset's
# labelling. label_mapping in the Degradation section maps between the two.

# COMMAND ----------

# MAGIC %md
# MAGIC ### Degradation

# COMMAND ----------

from utils.framework import *

#######################################
# NOTE: The triple-quoted string below is inert. It is a bare string
# expression, never executed. It appears to be a reference copy of a
# `_calc_blindness` method (presumably from utils.framework — confirm),
# showing how overall and relative "blindness" were computed from the
# Oxford-Pets mask. TODO(review): delete once no longer needed as reference.
"""
    def _calc_blindness(self, mask, segm_tensor, label, map_dict):

        # ATTENTION!!! label + 1  because background is already 0 !!!
        if map_dict is not None:    
            label = map_dict[label.item()]

        # Mask Mapping Oxford-Pets
        segm = segm_tensor.squeeze(0).squeeze(0).cpu().numpy() *255
        segm = (segm != 1).astype(np.uint8)
        mask = mask.squeeze(0).cpu().numpy()[0]       # [H,W]

        # Be sure its Int
        mask = mask.astype(int)
        segm = segm.astype(int)
        segm = 1 - segm

        blind_region = np.where((mask == 0) & (segm == 1), 1, 0)
        obj_region = np.where((segm == 1), 1, 0)
        
        blind_region = blind_region.sum()
        obj_region = obj_region.sum()
        if obj_region > 0:
            rel_blindness = blind_region / obj_region
        else:
            rel_blindness = 0

        # Compute overall blindness
        image_area = mask.shape[0] * mask.shape[1]
        blindness = (1 - mask).sum() / image_area
        
        return blindness, rel_blindness"""
    ##########################################



# MAPPING OXFORD-PETS MASK -> BINARY OBJECT REGION
# torch.Size([1, 1, 224, 224]) -> (224,224)
def pets_mapping(segm, label, map_dict=None):
    """Convert an Oxford-Pets mask tensor into a binary object map.

    Args:
        segm: Mask tensor whose values are integer labels divided by 255 (as
            produced by ``ToTensor`` on a uint8 mask), with up to two leading
            singleton dims, e.g. ``torch.Size([1, 1, 224, 224])``.
        label: Unused here. Presumably kept to match the ``segm_mapping``
            callback signature expected by ModelDegradation.
        map_dict: Unused here, for the same reason.

    Returns:
        Integer ``np.ndarray`` with the remaining spatial shape (e.g.
        (224, 224)). It holds 1 where the mask label equals 1 (treated as
        the object region) and 0 elsewhere.
    """
    # Undo the /255 scaling and snap to the nearest integer label, so the
    # comparison is not defeated by float round-off (x / 255 * 255 is not
    # guaranteed to be exactly x).
    labels = np.rint(segm.squeeze(0).squeeze(0).cpu().numpy() * 255)
    return (labels == 1).astype(int)


def label_mapping(label, verbose=False):
    """Map a dataset label into the index space of the ViT checkpoint.

    The model's output indices were observed to run inversely to the dataset
    labels, so this returns ``abs(label - 32)``.

    NOTE(review): assuming 37 classes (labels 0..36), this is not a
    bijection. For example, 31 and 33 both map to 1, and 33..36 fold onto
    1..4. A full reversal would be ``36 - label``. Verify against the
    checkpoint's ``id2label`` before trusting per-class results.

    Args:
        label: Integer label (a Python int or a tensor; anything supporting
            ``-`` and ``abs``).
        verbose: If True, print the label before and after mapping. This was
            the original debug output, now off by default so evaluation runs
            are not flooded with two lines per call.

    Returns:
        The mapped label, of the same kind as *label*.
    """
    mapped = abs(label - 32)  #### PETS
    if verbose:
        print("TEST Label BEFORE", label)
        print("TEST Label AFTER", mapped)
    return mapped

def predict(model, image_tensor):
    """Return softmax class probabilities of *model* for *image_tensor*.

    Runs without gradient tracking. *model* must return an object exposing a
    ``logits`` attribute; softmax is taken over dim 1 (the class axis for
    (batch, num_classes) logits).
    """
    with torch.no_grad():
        outputs = model(image_tensor)
        return outputs.logits.softmax(dim=1)


# Perturbation strategy handed to every ModelDegradation run as `per_mode`.
PERTURBATION_MODE = "mdp"
# Grid cells per image side. Per the table below, 14 cells on a 224 px input
# means 16 x 16 px patches.
NUM_GRID_ROW = 14

# Grid-size reference for a 224 x 224 input:
#   patch size (p x p) | patches per side | total patches
#   224                | 1                | 1
#   112                | 2                | 4
#   56                 | 4                | 16
#   32                 | 7                | 49
#   28                 | 8                | 64
#   16                 | 14               | 196
#   8                  | 28               | 784
#   7                  | 32               | 1024
# Values shortlisted by the author: 7, 28, 32.


# COMMAND ----------

# MAGIC %md
# MAGIC #### (A) Random

# COMMAND ----------

# (A) Random. Results go to results/patch_<grid>_<per_mode>/pets/.
# NOTE(review): this CSV is prefixed "TEST", unlike every other run —
# confirm whether it should be "vit_A_random.csv".
evaluator = ModelDegradation(
    model=model,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/pets/TESTvit_A_random.csv"
    ),
    mode="random",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=pets_mapping,
    label_mapping=label_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (A1) Random ObjFirst

# COMMAND ----------

# (A1) Random, object first. Results go to results/patch_<grid>_<per_mode>/pets/.
# NOTE(review): mode is "random_objfirst" but the CSV is named
# "..._random_backfirst.csv" — confirm which ordering this run really is.
evaluator = ModelDegradation(
    model=model,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/pets/vit_A1_random_backfirst.csv"
    ),
    mode="random_objfirst",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=pets_mapping,
    label_mapping=label_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (A2) Background First

# COMMAND ----------

# (A2) Background first. Results go to results/patch_<grid>_<per_mode>/pets/.
# NOTE(review): mode is "random", identical to run (A), so this does not
# obviously perturb the background first. Confirm the intended mode name in
# ModelDegradation.
evaluator = ModelDegradation(
    model=model,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/pets/vit_A2_background_first.csv"
    ),
    mode="random",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=pets_mapping,
    label_mapping=label_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (A3) Gauss

# COMMAND ----------

# (A3) Gauss. Results go to results/patch_<grid>_<per_mode>/pets/.
evaluator = ModelDegradation(
    model=model,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/pets/vit_A3_gauss.csv"
    ),
    mode="gauss",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=pets_mapping,
    label_mapping=label_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (A4) Perlin

# COMMAND ----------

# (A4) Perlin. Results go to results/patch_<grid>_<per_mode>/pets/.
evaluator = ModelDegradation(
    model=model,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/pets/vit_A4_perlin.csv"
    ),
    mode="perlin",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=pets_mapping,
    label_mapping=label_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (A5) Gauss + Perlin

# COMMAND ----------

# (A5) Gauss + Perlin — DISABLED: the whole cell is a bare string literal, so
# nothing below runs.
# NOTE(review): if re-enabled, the CSV is still named "resnet50_A5_..."
# (apparently copied from a ResNet notebook). Rename it to "vit_A5_..." first,
# or it will write into the ResNet result file.
"""evaluator = ModelDegradation(
    model=model,
    csv_save_path="/Workspace/Users/ANONYM/Constrain_Framework/results/patch_" + str(NUM_GRID_ROW) + "_" + PERTURBATION_MODE + "/pets/resnet50_A5_gaussplusperlin.csv",
    mode="centred_perlin",
    max_rows_per_file=20000,
    per_mode= PERTURBATION_MODE,
    segm_mapping=pets_mapping,
    label_mapping = label_mapping,
    predict = predict
    )


evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size = NUM_GRID_ROW)"""

# COMMAND ----------

# MAGIC %md
# MAGIC #### (B) SHAP

# COMMAND ----------


# (B) SHAP. Results go to results/patch_<grid>_<per_mode>/pets/.
evaluator = ModelDegradation(
    model=model,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/pets/vit_B_shap.csv"
    ),
    mode="shap",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=pets_mapping,
    label_mapping=label_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (C) Grad-Cam

# COMMAND ----------

def print_model_layers(model, prefix=''):
    """Print every submodule of *model*, depth-first, as ``name → ClassName``.

    Names are dotted paths relative to *model*, optionally rooted at
    *prefix*. Children are visited in registration order, each parent before
    its own children (pre-order). An explicit stack replaces recursion.
    """
    def qualify(parent, child):
        return f"{parent}.{child}" if parent else child

    # Push children reversed so that pop() yields them in declaration order.
    stack = [(qualify(prefix, n), m) for n, m in reversed(list(model.named_children()))]
    while stack:
        full_name, module = stack.pop()
        print(full_name, '→', module.__class__.__name__)
        stack.extend(
            (qualify(full_name, n), m) for n, m in reversed(list(module.named_children()))
        )


# Print every submodule's dotted name and class, presumably to choose a
# Grad-CAM target layer for the run below.
print_model_layers(model)


# COMMAND ----------

# (C) Grad-CAM. Results go to results/patch_<grid>_<per_mode>/pets/.
# This is the only run that also passes model_name ("vitpets").
evaluator = ModelDegradation(
    model=model,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/pets/vit_C_gradcam.csv"
    ),
    mode="gradcam",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=pets_mapping,
    label_mapping=label_mapping,
    predict=predict,
    model_name="vitpets",
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (D) SmoothedCam

# COMMAND ----------

# (D) Smoothed CAM. Results go to results/patch_<grid>_<per_mode>/pets/.
evaluator = ModelDegradation(
    model=model,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/pets/vit_D_smoothedcam.csv"
    ),
    mode="scam",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=pets_mapping,
    label_mapping=label_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (E) Integrated Gradients

# COMMAND ----------

# (E) Integrated Gradients. Results go to results/patch_<grid>_<per_mode>/pets/.
evaluator = ModelDegradation(
    model=model,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/pets/vit_E_ig.csv"
    ),
    mode="ig",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=pets_mapping,
    label_mapping=label_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (F) Activation maximization

# COMMAND ----------

# (F) Activation maximization. Results go to results/patch_<grid>_<per_mode>/pets/.
evaluator = ModelDegradation(
    model=model,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/pets/vit_F_am.csv"
    ),
    mode="am",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=pets_mapping,
    label_mapping=label_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (G) Occlusion Sensitivity

# COMMAND ----------

# (G) Occlusion sensitivity. Results go to results/patch_<grid>_<per_mode>/pets/.
evaluator = ModelDegradation(
    model=model,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/pets/vit_G_os.csv"
    ),
    mode="os",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=pets_mapping,
    label_mapping=label_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)

# COMMAND ----------

# MAGIC %md
# MAGIC #### (H) Loss

# COMMAND ----------

# (H) Loss. Results go to results/patch_<grid>_<per_mode>/pets/.
evaluator = ModelDegradation(
    model=model,
    csv_save_path=(
        "/Workspace/Users/ANONYM/Constrain_Framework/results/"
        f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/pets/vit_H_loss.csv"
    ),
    mode="loss",
    max_rows_per_file=20000,
    per_mode=PERTURBATION_MODE,
    segm_mapping=pets_mapping,
    label_mapping=label_mapping,
    predict=predict,
)

evaluator.evaluate(oxford_img_loader, oxford_mask_loader, grid_size=NUM_GRID_ROW)