# Tim Dieter Eberhardt - KIT (ITIV)

import os
import numpy as np
import torch
import torch.nn.functional as F

from tqdm import tqdm
import math
import random
import pandas as pd
import glob
import noise
import shap
from IPython.display import clear_output

# for debugging
import matplotlib.pyplot as plt

# own libs
from .image import Image as SHAPImage  # your custom SHAP masker


##### --- DEGRADATION FRAMEWORK --- #####
class ModelDegradation:
    def __init__(self, model, csv_save_path, max_rows_per_file=20000, mode="random", per_mode="black", segm_mapping=None, label_mapping=None, predict=None, class_names=None, model_name="resnet", device="cuda", debug=False):
        """Configure a degradation experiment.

        Args:
            model: torch module; it is switched to eval mode and moved to ``device``.
            csv_save_path: target CSV path; results are written to
                ``<path without .csv>_part<k>.csv`` chunks by ``evaluate``.
            max_rows_per_file: maximum number of rows per CSV chunk.
            mode: patch-ordering ("guidance") strategy used by ``evaluate``.
            per_mode: how a patch is perturbed ("zero", "black", "white", "blur<k>", "mdp").
            segm_mapping: optional callable ``(segm, label, map_dict) -> segm``.
            label_mapping: optional callable ``label -> label``.
            predict: callable ``(model, input) -> probabilities`` used by ``evaluate``.
            class_names: optional output names handed to the SHAP explainer.
            model_name: selects the Grad-CAM target layer; "vgg16" additionally
                replaces in-place ReLUs.
            device: torch device for the model and inputs.
            debug: when True, intermediate heatmaps/masks are plotted.
        """
        # Plain configuration attributes.
        self.csv_save_path = csv_save_path
        self.max_rows_per_file = max_rows_per_file
        self.mode = mode
        self.per_mode = per_mode
        self.segm_mapping = segm_mapping
        self.label_mapping = label_mapping
        self.predict = predict
        self.class_names = class_names
        self.model_name = model_name
        self.debug = debug

        # Model placement.
        self.device = device
        self.model = model.eval().to(device)

        # VGG16 uses in-place ReLUs; swap them for out-of-place ones.
        if model_name == "vgg16":
            self.replace_relu_with_out_of_place(self.model)


    def current_csv_path(self, base_path, part_idx):
        """Return the path of CSV chunk number ``part_idx`` for ``base_path`` (given without extension)."""
        return "{}_part{}.csv".format(base_path, part_idx)


    def evaluate(self, val_loader, segm_loader, map_dict=None, grid_size=33):
        """Run the progressive patch-degradation experiment and stream rows to CSV.

        For every (image, segmentation) pair a visiting order over the
        ``grid_size`` x ``grid_size`` grid cells is chosen according to
        ``self.mode``. The cells are then perturbed one after another according
        to ``self.per_mode``. Perturbations accumulate: after step k the first k
        cells are all perturbed. One baseline row (unperturbed image) plus one
        row per perturbed cell is recorded. Each row holds the probability of the
        ground-truth label, the top-1 probability, and the values returned by
        ``self.calc_blindness`` for the accumulated mask.

        Rows are appended to ``<base>_part<k>.csv`` files of at most
        ``self.max_rows_per_file`` rows each, where ``<base>`` is
        ``csv_save_path`` without its ``.csv`` suffix. Images whose name already
        appears in existing part files are skipped, and new rows then go to
        ``<base>_continued_part<k>.csv``.

        Args:
            val_loader: iterable of ``(image_tensor, label, img_name)``.
                ``image_tensor`` is unpacked as ``[B, C, H, W]``; predictions
                read sample 0 only, so B == 1 is assumed.
            segm_loader: iterable of ``(segm, _)`` aligned with ``val_loader``.
            map_dict: forwarded to ``self.segm_mapping`` and ``self.calc_blindness``.
            grid_size: number of grid cells along each image axis.

        Raises:
            ValueError: if ``max_rows_per_file`` < 1, if ``mode``/``per_mode`` is
                unknown, or if the mask is empty in mode ``"centred_perlin"``.
        """
        if self.max_rows_per_file < 1:
            # A non-positive limit would make the chunked writer below loop forever.
            raise ValueError("max_rows_per_file must be at least 1")

        self.test = 0  # running count of perturbation steps over the whole run
        already_done, base_path = self._load_resume_state()

        # Start at the first unused part index so a resumed run never appends
        # to a part file that is already full.
        part_idx = 1
        while os.path.exists(self.current_csv_path(base_path, part_idx)):
            part_idx += 1

        all_results = []

        #### ITERATE THROUGH DATASET
        for (image_tensor, label, img_name), (segm, _) in tqdm(
            zip(val_loader, segm_loader), total=len(val_loader)
        ):
            # Names read back from CSV are strings. Comparing str(img_name)
            # also works when the loader collates names into an (unhashable) list.
            if str(img_name) in already_done:
                continue

            #### PREPROCESSING
            if self.label_mapping is not None:
                label = self.label_mapping(label)
            if self.segm_mapping is not None:
                segm = self.segm_mapping(segm, label, map_dict)

            _, _, H, W = image_tensor.shape
            image_tensor = image_tensor.to(self.device)

            # Use the same linspace grid as compute_grid_cell_ranking. Then the
            # ranked cell (i, j) is exactly the region perturbed below, and the
            # last row/column of cells reaches the image border even when H or W
            # is not divisible by grid_size.
            y_edges = self._grid_edges(H, grid_size)
            x_edges = self._grid_edges(W, grid_size)

            #### GUIDANCE: order in which the cells get perturbed
            patch_coords = self._patch_order(image_tensor, label, segm, img_name, grid_size, y_edges, x_edges)

            #### BASELINE PREDICTION (unperturbed image)
            probs = self.predict(self.model, image_tensor)
            gt_conf, pred_conf = self._confidences(probs, label)
            results = [{
                "img_name": img_name,
                "gt_confidence": gt_conf,
                "pred_confidence": pred_conf,
                "blindness": 0,
                "rel_blindness": 0,
            }]

            #### PROGRESSIVE PERTURBATION
            # Mask convention: 1 = original pixel, 0 = perturbed pixel.
            current_mask = torch.ones((1, 1, H, W), device=self.device)
            for step, (i, j) in enumerate(patch_coords):
                self.test += 1
                y0, y1 = y_edges[i], y_edges[i + 1]
                x0, x1 = x_edges[j], x_edges[j + 1]

                # ``region`` is a view, so the perturbation is written into
                # image_tensor and all earlier cells stay perturbed.
                region = image_tensor[:, :, y0:y1, x0:x1]
                if not self._perturb_region(region):
                    # Nothing was perturbed (blur kernel larger than the cell).
                    # Leave the mask untouched and record no row.
                    continue
                current_mask[:, :, y0:y1, x0:x1] = 0.0

                masked = image_tensor.clone()

                if self.debug:
                    self._show_masked(masked)
                    if step > 2:
                        raise RuntimeError("Debug stop after the first masked images.")

                probs = self.predict(self.model, masked)
                gt_conf, pred_conf = self._confidences(probs, label)
                blindness, rel_blindness = self.calc_blindness(current_mask, segm, map_dict)
                results.append({
                    "img_name": img_name,
                    "gt_confidence": gt_conf,
                    "pred_confidence": pred_conf,
                    "blindness": blindness,
                    "rel_blindness": rel_blindness,
                })

            all_results.extend(results)
            torch.cuda.empty_cache()

            # Flush every complete chunk of max_rows_per_file rows.
            while len(all_results) >= self.max_rows_per_file:
                to_save = all_results[:self.max_rows_per_file]
                path = self._write_rows(to_save, base_path, part_idx)
                print(f"[INFO] Wrote {len(to_save)} rows to {path}")
                all_results = all_results[self.max_rows_per_file:]
                part_idx += 1

        #### SAVE RESULTS: remaining (< max_rows_per_file) rows
        if all_results:
            path = self._write_rows(all_results, base_path, part_idx)
            print(f"[INFO] Wrote final {len(all_results)} rows to {path}")

    def _load_resume_state(self):
        """Return ``(already_done, base_path)`` for resuming a previous run.

        Reads every ``<base>_part*.csv`` and ``<base>_continued_part*.csv`` file
        and collects their ``img_name`` values as strings. If any file exists,
        new rows are redirected to ``<base>_continued``. A file that cannot be
        read is reported and treated as "nothing done yet".
        """
        path = self.csv_save_path
        # Strip only a trailing ".csv" (not occurrences elsewhere in the path).
        base_path = path[:-len(".csv")] if path.endswith(".csv") else path
        already_done = set()

        escaped = glob.escape(base_path)
        part_files = sorted(
            glob.glob(escaped + "_part*.csv") + glob.glob(escaped + "_continued_part*.csv")
        )
        if not part_files:
            return already_done, base_path

        try:
            existing_df = pd.concat([pd.read_csv(f) for f in part_files], ignore_index=True)
            already_done = set(existing_df["img_name"].astype(str).unique())
            print(f"[INFO] Skipping {len(already_done)} unique images from {len(part_files)} files.")
            base_path += "_continued"
        except Exception as e:  # best effort: fall back to a fresh run
            print(f"[WARN] Could not read part files due to: {e}")
        return already_done, base_path

    @staticmethod
    def _grid_edges(length, grid_size):
        """Split ``length`` pixels into ``grid_size`` near-equal cells; return the ``grid_size + 1`` boundaries as ints."""
        return np.linspace(0, length, grid_size + 1, dtype=int).tolist()

    def _patch_order(self, image_tensor, label, segm, img_name, grid_size, y_edges, x_edges):
        """Return the (row, col) grid cells in the order ``evaluate`` perturbs them.

        Random modes shuffle the cells. The obj/bg-first variants split cells by
        ``segm.sum() > 0`` inside the cell. All other modes rank the cells by
        the mean of a guidance heatmap, highest first.
        """
        cells = [(i, j) for i in range(grid_size) for j in range(grid_size)]

        ##### (A) Random
        if self.mode == "random":
            random.shuffle(cells)
            return cells

        ##### (A1)/(A2) Object-first / background-first random
        if self.mode in ("random_objfirst", "random_bgfirst"):
            obj_cells, bg_cells = [], []
            for i, j in cells:
                patch = segm[y_edges[i]:y_edges[i + 1], x_edges[j]:x_edges[j + 1]]
                # A cell with any non-zero mask pixel counts as "object".
                (obj_cells if patch.sum() > 0 else bg_cells).append((i, j))
            # Shuffle in the same order as before so seeded runs stay reproducible.
            if self.mode == "random_objfirst":
                random.shuffle(obj_cells)
                random.shuffle(bg_cells)
                return obj_cells + bg_cells
            random.shuffle(bg_cells)
            random.shuffle(obj_cells)
            return bg_cells + obj_cells

        ##### Heatmap-guided modes
        heatmap = self._guidance_heatmap(image_tensor, label, segm, img_name, grid_size)
        _, ranked = self.compute_grid_cell_ranking(heatmap, grid_size)

        if self.debug:
            plt.imshow(heatmap)
            plt.show()
        return list(ranked)

    def _guidance_heatmap(self, image_tensor, label, segm, img_name, grid_size):
        """Compute the guidance heatmap (higher = perturbed earlier) for the non-random modes.

        Raises:
            ValueError: unknown ``self.mode``, or an empty mask for ``"centred_perlin"``.
        """
        ##### (A3) Gaussian around the object centroid
        if self.mode == "gauss":
            return self._centroid_gaussian(segm, label, img_name, strict=False)
        ##### (A4) Perlin noise
        if self.mode == "perlin":
            return self.generate_perlin_noise(segm.shape[0], segm.shape[1], scale=50.0)
        ##### (A5) Perlin noise modulated by the centroid Gaussian
        if self.mode == "centred_perlin":
            gaussian_center = self._centroid_gaussian(segm, label, img_name, strict=True)
            return self.generate_perlin_noise(segm.shape[0], segm.shape[1], scale=50) * gaussian_center
        ##### (G) Occlusion sensitivity needs the grid size as well
        if self.mode == "os":
            return self.compute_occlusion_sensitivity(image_tensor, label, grid_size)

        ##### (B)-(F), (H) attribution methods with signature (image_tensor, label)
        explainers = {
            "shap": self.compute_shap,
            "gradcam": self.compute_gradcam,
            "scam": self.compute_smoothgrad,
            "ig": self.compute_integrated_gradients,
            "am": self.compute_activation_maximization,
            "loss": self.compute_loss_gradient,
        }
        if self.mode not in explainers:
            raise ValueError(f"Unknown masking mode: {self.mode}")
        return explainers[self.mode](image_tensor, label)

    @staticmethod
    def _centroid_gaussian(segm, label, img_name, strict):
        """Return a Gaussian (sigma = height / 6) centred on the centroid of the pixels where ``segm == 1``.

        With no object pixels, raise ValueError if ``strict``. Otherwise log the
        image and centre the Gaussian on the image centre.
        """
        height, width = segm.shape[0], segm.shape[1]
        obj_indices = np.argwhere(segm == 1)
        if len(obj_indices) == 0:
            if strict:
                raise ValueError("Segmentation mask contains no object pixels.")
            print("label", label, "segm unique", np.unique(segm), "IMAGENAME", img_name)
            # The centroid of an empty set is NaN and would turn the whole
            # heatmap (and thus the ranking) into NaN; use the image centre.
            cy, cx = (height - 1) / 2.0, (width - 1) / 2.0
        else:
            cy, cx = obj_indices.astype(np.float64).mean(axis=0)

        Y, X = np.meshgrid(np.arange(height), np.arange(width), indexing="ij")
        sigma = height / 6.0
        return np.exp(-((X - cx) ** 2 + (Y - cy) ** 2) / (2 * sigma ** 2))

    def _perturb_region(self, region):
        """Overwrite ``region`` in place according to ``self.per_mode``.

        ``region`` is a slice of the model input indexed as [B, 3, h, w]; colour
        constants are converted with the ImageNet mean/std used throughout
        this class.

        Returns:
            False if the patch was left untouched (blur kernel larger than the
            patch), True otherwise.

        Raises:
            ValueError: unknown ``self.per_mode``.
        """
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        device = region.device
        mode = self.per_mode

        def to_input_space(rgb):
            # [0, 1] RGB colour -> normalized model-input value, broadcastable over the patch.
            values = [(c - m) / s for c, m, s in zip(rgb, imagenet_mean, imagenet_std)]
            return torch.tensor(values, device=device).view(3, 1, 1)

        if mode == "zero":
            region[:] = torch.zeros(3, 1, 1, device=device)
        elif mode == "black":
            region[:] = to_input_space([0.0, 0.0, 0.0])
        elif mode == "white":
            region[:] = to_input_space([1.0, 1.0, 1.0])
        elif mode.startswith("blur"):
            # Kernel size from the mode name (e.g. blur5, blur15); must be odd and positive.
            try:
                kernel_size = int(mode.replace("blur", ""))
            except ValueError:
                kernel_size = 5
            if kernel_size <= 0 or kernel_size % 2 == 0:
                kernel_size = 5  # fallback: symmetric padding needs an odd kernel

            patch_h, patch_w = region.shape[2:]
            if kernel_size > patch_h or kernel_size > patch_w:
                return False  # do not blur patches smaller than the kernel

            sigma = kernel_size / 3.0
            padding = kernel_size // 2
            channels = region.shape[1]
            x = torch.arange(kernel_size, dtype=torch.float32, device=device) - kernel_size // 2
            gauss = torch.exp(-x ** 2 / (2 * sigma ** 2))
            gauss = gauss / gauss.sum()
            kernel_2d = gauss[:, None] @ gauss[None, :]
            kernel_2d = (kernel_2d / kernel_2d.sum()).expand(channels, 1, kernel_size, kernel_size)

            padded = F.pad(region, (padding, padding, padding, padding), mode="reflect")
            region[:] = F.conv2d(padded, kernel_2d, groups=channels)
        elif mode == "mdp":
            # Maximum-distance colour: the RGB cube corner farthest (in normalized
            # space) from the patch's mean colour.
            mean_rgb = region.mean(dim=[0, 2, 3])
            rgb_corners = torch.tensor([
                [0.0, 0.0, 0.0],  # black
                [1.0, 0.0, 0.0],  # red
                [0.0, 1.0, 0.0],  # green
                [0.0, 0.0, 1.0],  # blue
                [1.0, 1.0, 0.0],  # yellow
                [1.0, 0.0, 1.0],  # magenta
                [0.0, 1.0, 1.0],  # cyan
                [1.0, 1.0, 1.0],  # white
            ], device=device)
            mean_t = torch.tensor(imagenet_mean, device=device)
            std_t = torch.tensor(imagenet_std, device=device)
            normalized_corners = (rgb_corners - mean_t[None, :]) / std_t[None, :]
            dists = torch.norm(normalized_corners - mean_rgb[None, :], dim=1)
            region[:] = normalized_corners[dists.argmax()].view(3, 1, 1)
        else:
            # Previously an unknown mode silently left the image unperturbed.
            raise ValueError(f"Unknown perturbation mode: {self.per_mode}")
        return True

    @staticmethod
    def _show_masked(masked):
        """Debug helper: display the first image of ``masked`` after undoing the ImageNet normalization."""
        mean = torch.tensor([0.485, 0.456, 0.406], device=masked.device).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=masked.device).view(3, 1, 1)
        img = masked[0] if masked.dim() == 4 else masked
        img = img * std + mean
        plt.imshow(img.detach().cpu().permute(1, 2, 0).clamp(0, 1).numpy())
        plt.title("Masked (Denormalized)")
        plt.show()

    @staticmethod
    def _confidences(probs, label):
        """Return ``(probability of label, top-1 probability)`` for sample 0 of ``probs``."""
        top1 = torch.argmax(probs, dim=1).item()
        return probs[0, label].item(), probs[0, top1].item()

    def _write_rows(self, rows, base_path, part_idx):
        """Append ``rows`` to CSV chunk ``part_idx`` (writing the header only for a new file); return its path."""
        path = self.current_csv_path(base_path, part_idx)
        pd.DataFrame(rows).to_csv(path, mode="a", index=False, header=not os.path.exists(path))
        return path




    ##### (A4) Perlin
    def generate_perlin_noise(self, height, width, scale=10.0, octaves=1, persistence=0.5, lacunarity=2.0, seed=None):
        """Build a (height, width) Perlin-noise map min-max scaled to [0, 1].

        Pixel (row, col) samples ``noise.pnoise2(row / scale, col / scale, ...)``.
        When ``seed`` is None, a random base in [0, 1000) is drawn from NumPy's
        global RNG.
        """
        base = np.random.randint(0, 1000) if seed is None else seed

        field = np.array(
            [
                [
                    noise.pnoise2(
                        row / scale,
                        col / scale,
                        octaves=octaves,
                        persistence=persistence,
                        lacunarity=lacunarity,
                        repeatx=width,
                        repeaty=height,
                        base=base,
                    )
                    for col in range(width)
                ]
                for row in range(height)
            ],
            dtype=float,
        )

        # Rescale to [0, 1]; the epsilon guards against a constant field.
        lowest, highest = field.min(), field.max()
        return (field - lowest) / (highest - lowest + 1e-8)


    ##### (B) SHAP
    def oldold_predict_fn(self, img_batch_nhwc: np.ndarray) -> np.ndarray:
        """Legacy SHAP predict function: NHWC numpy batch -> softmax over the raw model output.

        Unlike ``predict_fn``, the model output is used directly as logits.
        """
        with torch.no_grad():
            batch_nchw = torch.tensor(img_batch_nhwc).permute(0, 3, 1, 2)
            raw_output = self.model(batch_nchw.to(self.device))
            return torch.softmax(raw_output, dim=1).cpu().numpy()
        
    def predict_fn(self, img_batch_nhwc: np.ndarray) -> np.ndarray:
        """SHAP predict function: map an NHWC numpy batch to softmax class probabilities.

        The batch is permuted to NCHW, run through the model without gradients,
        and the logits are obtained via ``self.get_logits``.
        """
        with torch.no_grad():
            batch_nchw = torch.tensor(img_batch_nhwc).permute(0, 3, 1, 2).to(self.device)
            logits = self.get_logits(self.model(batch_nchw))
            return torch.softmax(logits, dim=1).cpu().numpy()

        
        
    def compute_shap(self, image_tensor, label):
        """Return a per-pixel SHAP saliency map for class ``label``.

        A new ``SHAPImage("mdp", ...)`` masker and ``shap.Explainer`` are built
        for this image's shape and stored on ``self.masker`` / ``self.explainer``.
        The explainer is run with ``max_evals=1000``. The SHAP values of the
        first image are squeezed and averaged over the last (channel) axis.
        The Jupyter output is cleared afterwards.
        """
        image_nhwc = image_tensor.permute(0, 2, 3, 1).cpu().numpy()

        # The masker is built from the per-image shape (H, W, C), hence per call.
        self.masker = SHAPImage("mdp", image_nhwc.shape[1:])
        explainer_kwargs = {} if self.class_names is None else {"output_names": self.class_names}
        self.explainer = shap.Explainer(self.predict_fn, self.masker, **explainer_kwargs)

        # Only the ground-truth output is explained.
        explanation = self.explainer(image_nhwc, max_evals=1000, outputs=label)
        # NOTE(review): values[0] is presumably (H, W, C[, 1]); squeeze drops the
        # singleton output axis — confirm against the custom masker.
        per_channel = explanation.values[0].squeeze()
        saliency = np.mean(per_channel, axis=-1)  # -> (H, W)

        clear_output(wait=True)
        return saliency


    ##### (C) GradCam

    # ----------  PATCH IN-PLACE RELU IN VGG ----------
    def replace_relu_with_out_of_place(self, module):
        """Recursively replace every in-place ``torch.nn.ReLU`` below ``module`` with ``ReLU(inplace=False)``.

        Presumably needed so that Grad-CAM hooks on VGG16 see activations that
        are not overwritten in place (called from ``__init__`` for "vgg16").
        Other modules, including out-of-place ReLUs, are kept as they are.
        """
        for child_name, child in module.named_children():
            is_inplace_relu = isinstance(child, torch.nn.ReLU) and child.inplace
            if not is_inplace_relu:
                # Descend into containers (Sequential, blocks, ...).
                self.replace_relu_with_out_of_place(child)
                continue
            setattr(module, child_name, torch.nn.ReLU(inplace=False))

    # ---------------------------------------------------
    def compute_gradcam(self, image_tensor, label):
        """Return a Grad-CAM heatmap for class ``label``, upsampled to the input size and scaled to [0, 1].

        The hooked layer is chosen from ``self.model_name``.
        - CNNs: channel weights are the spatially averaged gradients, and the
          CAM is ReLU(sum_c weight_c * activation_c).
        - ViT variants: token relevance is sum_c(grad * activation), with the
          first (CLS) token dropped and the rest reshaped to a square grid.

        Args:
            image_tensor: model input indexed as [1, C, H, W].
            label: target class index.

        Returns:
            np.ndarray of shape (H, W) with values in [0, 1].

        Raises:
            ValueError: ``self.model_name`` has no configured target layer.
            RuntimeError: the hooks did not fire (layer not used in the forward pass).
        """
        # ----------- SETTINGS: model_name -> (layer accessor, is_vit) -----------
        layer_specs = {
            "resnet": (lambda m: m.layer4, False),
            "vgg16": (lambda m: m.features[29], False),  # last VGG16 conv layer
            # Final MBConv -> block[3] is Conv2dNormActivation -> [0] is its Conv2d
            "efficientnet": (lambda m: m.features[7][0].block[3][0], False),
            # Last convolution in denseblock4
            "densenet": (lambda m: m.features.denseblock4.denselayer16.conv2, False),
            # HuggingFace ViT: LayerNorm before attention in the last encoder block
            "vit_hf": (lambda m: m.vit.encoder.layer[11].layernorm_before, True),
            "vitpets": (lambda m: m.vit.encoder.layer[11].layernorm_before, True),
            # timm-style ViT: first norm of the last block (blocks[11].mlp.fc2 worked badly)
            "vitflowers": (lambda m: m.blocks[11].norm1, True),
        }
        if self.model_name not in layer_specs:
            raise ValueError(f"Model {self.model_name} not supported for Grad-CAM.")
        get_layer, is_vit = layer_specs[self.model_name]
        target_layer = get_layer(self.model)

        activations = []
        gradients = []

        def forward_hook(module, inputs, output):
            activations.append(output)

        def backward_hook(module, grad_input, grad_output):
            gradients.append(grad_output[0])

        handle_f = target_layer.register_forward_hook(forward_hook)
        handle_b = target_layer.register_full_backward_hook(backward_hook)
        try:
            logits = self.get_logits(self.model(image_tensor))
            score = logits[:, label]
            self.model.zero_grad()
            score.backward()
        finally:
            # Always detach the hooks. A leaked hook would keep firing (and keep
            # appending tensors to these lists) on every later pass through the model.
            handle_f.remove()
            handle_b.remove()

        if not activations or not gradients:
            raise RuntimeError(f"Grad-CAM hooks did not fire for model {self.model_name}; check the target layer.")

        # ------------ Grad-CAM calculation ------------
        if is_vit:
            A = activations[0][:, 1:, :]   # [B, N-1, C] without CLS token
            dA = gradients[0][:, 1:, :]    # [B, N-1, C]
            token_importance = (dA * A).sum(dim=2)  # [B, N-1]
            side = int(token_importance.shape[1] ** 0.5)
            cam = F.relu(token_importance.reshape(1, 1, side, side))
        else:
            A = activations[0]   # [1, C, H', W']
            dA = gradients[0]    # [1, C, H', W']
            weights = dA.mean(dim=(2, 3), keepdim=True)          # [1, C, 1, 1]
            cam = (weights * A).sum(dim=1, keepdim=True)         # [1, 1, H', W']
            cam = F.relu(cam.clone())

        # Upsample to the input resolution and min-max scale to [0, 1].
        cam = F.interpolate(cam, size=(image_tensor.shape[2], image_tensor.shape[3]), mode="bilinear", align_corners=False)
        cam = cam.squeeze().detach().cpu().numpy()
        return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

    def old_compute_gradcam(self, image_tensor, label):
        """Legacy Grad-CAM (ResNet ``layer4`` / VGG16 ``features[29]`` only), scaled to [0, 1].

        Args:
            image_tensor: model input indexed as [1, C, H, W].
            label: target class index.

        Returns:
            np.ndarray of shape (H, W) with values in [0, 1].

        Raises:
            ValueError: ``self.model_name`` is neither "resnet" nor "vgg16".
        """
        # ----------- SETTINGS -----------
        if self.model_name == "resnet":
            target_layer = self.model.layer4
        elif self.model_name == "vgg16":
            target_layer = self.model.features[29]  # last VGG16 convolutional layer
        else:
            raise ValueError(f"Model {self.model_name} not supported for Grad-CAM.")

        activations = []
        gradients = []

        def forward_hook(module, inputs, output):
            activations.append(output)

        def backward_hook(module, grad_input, grad_output):
            gradients.append(grad_output[0])

        handle_f = target_layer.register_forward_hook(forward_hook)
        handle_b = target_layer.register_full_backward_hook(backward_hook)
        try:
            logits = self.get_logits(self.model(image_tensor))
            score = logits[:, label]
            self.model.zero_grad()
            score.backward()
        finally:
            # Remove the hooks even if the forward/backward pass raises.
            handle_f.remove()
            handle_b.remove()

        # Grad-CAM: weight each channel by its spatially averaged gradient.
        A = activations[0]   # [1, C, H', W']
        dA = gradients[0]    # [1, C, H', W']
        weights = dA.mean(dim=(2, 3), keepdim=True)                  # [1, C, 1, 1]
        cam = F.relu((weights * A).sum(dim=1, keepdim=True).clone())  # [1, 1, H', W']

        cam = F.interpolate(cam, size=(image_tensor.shape[2], image_tensor.shape[3]), mode="bilinear", align_corners=False)
        cam = cam.squeeze().detach().cpu().numpy()
        return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    
    ##### (D) SmoothGrad
    def compute_smoothgrad(self, image_tensor, label, n_samples=25, noise_level=0.15):
        """Return a SmoothGrad saliency map for class ``label``, scaled to [0, 1].

        The gradient of the ``label`` logit with respect to the input is
        averaged over ``n_samples`` copies of the input with added Gaussian
        noise (std ``noise_level``). The per-pixel absolute value is then
        averaged over channels.

        Args:
            image_tensor: model input indexed as [1, C, H, W]; not modified.
            label: target class index.
            n_samples: number of noisy samples (must be >= 1).
            noise_level: standard deviation of the added noise.

        Returns:
            np.ndarray of shape (H, W) with values in [0, 1].

        Raises:
            ValueError: ``n_samples`` < 1.
        """
        if n_samples < 1:
            raise ValueError("n_samples must be at least 1")

        base = image_tensor.clone().detach().requires_grad_(True)
        smooth_grad = torch.zeros_like(base)

        for _ in range(n_samples):
            # Named "perturbation" so the module-level ``noise`` import is not shadowed.
            perturbation = torch.randn_like(base) * noise_level
            score = self.get_logits(self.model(base + perturbation))[0, label]
            # Differentiate w.r.t. the input only: no graph retention and no
            # .grad accumulation on the model parameters.
            (grad,) = torch.autograd.grad(score, base)
            smooth_grad += grad

        smooth_grad /= n_samples
        saliency = smooth_grad.abs().mean(dim=1, keepdim=True)  # average across channels
        saliency = saliency.squeeze().cpu().numpy()
        return (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-8)
    

    ##### (E) Integrated Gradients
    def compute_integrated_gradients(self, image_tensor, label, baseline=None, steps=50):
        """Return an Integrated Gradients saliency map for class ``label``, scaled to [0, 1].

        The path integral from ``baseline`` (default: zeros) to the input is
        approximated with a right Riemann sum over ``steps`` points. All points
        are evaluated in one batch of size ``steps``. The absolute attribution
        is averaged over channels.

        Args:
            image_tensor: model input indexed as [1, C, H, W].
            label: target class index.
            baseline: reference input of the same shape, or None for zeros.
            steps: number of interpolation points.

        Returns:
            np.ndarray of shape (H, W) with values in [0, 1].
        """
        self.model.zero_grad()
        x = image_tensor.clone().detach().requires_grad_(True)
        if baseline is None:
            baseline = torch.zeros_like(x).to(self.device)

        # Points baseline + k/steps * (x - baseline) for k = 1..steps, stacked on dim 0.
        delta = x - baseline
        path = torch.cat([baseline + (float(k) / steps) * delta for k in range(1, steps + 1)], dim=0)
        path = path.detach().requires_grad_()  # [steps, C, H, W]

        target_scores = self.get_logits(self.model(path))[:, label]
        (grads,) = torch.autograd.grad(
            outputs=target_scores.sum(),
            inputs=path,
            create_graph=False,
            retain_graph=False,
        )

        mean_grad = grads.view(steps, *x.shape[1:]).mean(dim=0)  # [C, H, W]
        attribution = delta.squeeze(0) * mean_grad               # [C, H, W]

        saliency = attribution.abs().mean(dim=0, keepdim=True).unsqueeze(0)  # [1, 1, H, W]
        saliency = F.interpolate(saliency, size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
        saliency = saliency.squeeze().detach().cpu().numpy()
        return (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-8)

    

    ##### (F) Activation Maximization
    def compute_activation_maximization(self, image_tensor, label, steps=30, lr=1e-2, clamp_range=(-1.0, 1.0)):
        """Return a saliency map from activation maximization of class ``label``, scaled to [0, 1].

        A copy of the input is optimized with Adam for ``steps`` steps to
        increase the ``label`` logit, and is clamped to ``clamp_range`` after
        each step. The saliency is |d logit / d input| at the final optimized
        input, averaged over channels.

        Args:
            image_tensor: model input indexed as [1, C, H, W]; not modified.
            label: target class index.
            steps: number of Adam steps (0 yields a plain input-gradient map).
            lr: Adam learning rate.
            clamp_range: ``(min, max)`` applied to the optimized input after
                every step, or None to disable clamping.
                NOTE(review): the default (-1, 1) is narrower than inputs
                normalized with the ImageNet stats used elsewhere in this class
                (about [-2.12, 2.64]). Confirm it matches the preprocessing.

        Returns:
            np.ndarray of shape (H, W) with values in [0, 1].
        """
        x = image_tensor.clone().detach().to(self.device)
        x.requires_grad_(True)
        optimizer = torch.optim.Adam([x], lr=lr)

        for _ in range(steps):
            loss = -self.get_logits(self.model(x))[0, label]
            # Assign (not accumulate) the fresh input gradient every step; it
            # previously summed over all steps. autograd.grad also leaves the
            # model parameters' .grad untouched.
            (grad,) = torch.autograd.grad(loss, x)
            x.grad = grad
            optimizer.step()
            if clamp_range is not None:
                with torch.no_grad():
                    x.clamp_(clamp_range[0], clamp_range[1])

        # Saliency = gradient at the final optimized input (the sign is
        # irrelevant because of abs()).
        score = self.get_logits(self.model(x))[0, label]
        (final_grad,) = torch.autograd.grad(score, x)

        saliency = final_grad.abs().mean(dim=1, keepdim=True)  # [1, 1, H, W]
        saliency = saliency.squeeze().cpu().numpy()
        return (saliency - saliency.min()) / (saliency.max() - saliency.min() + 1e-8)
    

    ##### (G) Occlusion Sensitivity
    def compute_occlusion_sensitivity(self, image_tensor, label, grid_size):
        """Return an occlusion-sensitivity heatmap on a ``grid_size`` x ``grid_size`` grid, scaled to [0, 1].

        Each cell is set to 0.0 (in model-input space) in turn. Every pixel of
        the cell receives the drop in the softmax probability of ``label``
        relative to the unoccluded input.

        Args:
            image_tensor: model input indexed as [B, C, H, W] (probabilities are read from sample 0).
            label: target class index.
            grid_size: number of cells along each image axis.

        Returns:
            np.ndarray of shape (H, W) with values in [0, 1].
        """
        _, _, H, W = image_tensor.shape
        heatmap = np.zeros((H, W))

        # linspace edges guarantee full coverage even if H, W are not divisible by grid_size.
        y_edges = np.linspace(0, H, grid_size + 1, dtype=int)
        x_edges = np.linspace(0, W, grid_size + 1, dtype=int)

        # Pure inference: no forward pass here needs an autograd graph
        # (previously the baseline pass built one needlessly).
        with torch.no_grad():
            base_logits = self.get_logits(self.model(image_tensor))
            base_conf = F.softmax(base_logits, dim=1)[0, label].item()

            for i in range(grid_size):
                for j in range(grid_size):
                    y0, y1 = y_edges[i], y_edges[i + 1]
                    x0, x1 = x_edges[j], x_edges[j + 1]

                    masked = image_tensor.clone()
                    masked[:, :, y0:y1, x0:x1] = 0.0

                    masked_logits = self.get_logits(self.model(masked))
                    masked_conf = F.softmax(masked_logits, dim=1)[0, label].item()
                    heatmap[y0:y1, x0:x1] = base_conf - masked_conf

        # Min-max scale for ranking/visualization.
        return (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
    

    ##### (H) Loss Gradient
    def compute_loss_gradient(self, image_tensor, label):
        """Return the saliency |d CE(logits, label) / d input|, averaged over channels and scaled to [0, 1].

        Args:
            image_tensor: model input indexed as [1, C, H, W]; not modified.
            label: target class index.

        Returns:
            np.ndarray of shape (H, W) with values in [0, 1].
        """
        self.model.zero_grad()
        x = image_tensor.clone().detach().requires_grad_(True)

        logits = self.get_logits(self.model(x))
        target = torch.tensor([label], device=x.device)
        F.cross_entropy(logits, target).backward()

        input_grad = x.grad.detach()                               # [1, C, H, W]
        saliency = input_grad.abs().mean(dim=1, keepdim=True)      # [1, 1, H, W]
        saliency = F.interpolate(saliency, size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
        saliency = saliency.squeeze().cpu().numpy()

        lowest, highest = saliency.min(), saliency.max()
        return (saliency - lowest) / (highest - lowest + 1e-8)





    ############ Helper Functions ############
    def compute_grid_cell_ranking(self, heatmap, grid_size):
        H, W = heatmap.shape
        y_edges = np.linspace(0, H, grid_size + 1, dtype=int)
        x_edges = np.linspace(0, W, grid_size + 1, dtype=int)

        scores = []
        coords = []

        for i in range(grid_size):
            for j in range(grid_size):
                y0, y1 = y_edges[i], y_edges[i + 1]
                x0, x1 = x_edges[j], x_edges[j + 1]

                #print("i=", i, "j=", j, "x0,x1,y0,y1", x0,x1,y0,y1)

                cell_mean = np.mean(heatmap[y0:y1, x0:x1])
                scores.append(cell_mean)
                coords.append((i, j))

        sorted_cells = sorted(zip(scores, coords), key=lambda x: -x[0])
        scores_sorted, coords_sorted = zip(*sorted_cells)
        return scores_sorted, coords_sorted

    def calc_blindness(self, mask, segm, map_dict):
        # Mask (0 -> Perturbated, 1 -> Original)
        # Segm (0 -> Background,  1 -> Object)
        # Convert
        mask = mask.squeeze(0).cpu().numpy()[0].astype(int)
        
        # Inverse Segm
        segm = 1 - segm


        ########## DEBUG ########## DEBUG ########## DEBUG ########## DEBUG ########## DEBUG ########## DEBUG ##########
        if self.debug:
            print("MASK")
            plt.imshow(mask)
            plt.colorbar()
            plt.show()
            print("SEGM")
            plt.imshow(segm)
            plt.colorbar()
            plt.show()

            fig, axs = plt.subplots(1, 2, figsize=(10, 5))
            axs[0].imshow(mask, cmap='viridis')
            axs[0].set_title("MASK")
            axs[0].axis("off")
            fig.colorbar(axs[0].imshow(mask, cmap='viridis'), ax=axs[0], shrink=0.7)
            axs[1].imshow(segm, cmap='viridis')
            axs[1].set_title("SEGM")
            axs[1].axis("off")
            fig.colorbar(axs[1].imshow(segm, cmap='viridis'), ax=axs[1], shrink=0.7)
            if self.test > 10:
                stop
        ########## DEBUG ########## DEBUG ########## DEBUG ########## DEBUG ########## DEBUG ########## DEBUG ##########


        # Calculate Metrics
        blind_region = np.where((mask == 0) & (segm == 1), 1, 0)
        obj_region = np.where((segm == 1), 1, 0)

        obj_region = obj_region.sum()
        blind_region = blind_region.sum()

        if obj_region > 0:
            rel_blindness = blind_region / obj_region
        else:
            rel_blindness = 0

        # Compute overall blindness
        image_area = mask.shape[0] * mask.shape[1]
        blindness = (1 - mask).sum() / image_area

        ########## DEBUG ########## DEBUG ########## DEBUG ########## DEBUG ########## DEBUG ########## DEBUG ##########
        if self.debug:
            print("BLIDNESS=", blindness, " Relative BLIDNESS=", rel_blindness)
        ########## DEBUG ########## DEBUG ########## DEBUG ########## DEBUG ########## DEBUG ########## DEBUG ##########
        return blindness, rel_blindness


    def _get_batch(self, loader, target_idx):
        for idx, batch in enumerate(loader):
            if idx == target_idx:
                yield batch
                return
            

    def get_logits(self, output):
        """
        Extracts logits from various possible model output types.
        Supports: Tensor, tuple/list, dict, and Huggingface ModelOutput.
        """
        if isinstance(output, (tuple, list)):
            # Meist Output[0] = logits (z.B. timm, torchvision, ältere torch Modelle)
            return output[0]
        elif hasattr(output, 'logits'):
            return output.logits
        elif isinstance(output, dict) and "logits" in output:
            return output["logits"]
        else:
            # fallback: ist vermutlich schon ein Tensor
            return output
