# Databricks notebook source
# MAGIC %md
# MAGIC # Visualizations

# COMMAND ----------

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy.optimize import curve_fit
from utils.evaluation import analyze_global_degradation_fit


# Display names used when a per-model plot title is requested (show_title=True).
_MODEL_TITLES = {
    "resnet34": "ResNet34",
    "resnet50": "ResNet50",
    "vgg16": "VGG16",
    "efficientnet": "EfficientNet",
    "densenet": "DenseNet",
    "vit": "ViT",
}


def _neg_sigmoid(x, a, b, c):
    """Evaluate the decreasing sigmoid a * (1 - 1 / (1 + exp(-b * (x - c))))."""
    return a * (1 - 1 / (1 + np.exp(-b * (x - c))))


def _degradation_metrics(x, y, thresholds):
    """Compute (rbs, aubc, acps) summary metrics for a curve sampled as y(x).

    Returns:
        rbs:  the x value where |dy/dx| (via np.gradient) is largest.
        aubc: the trapezoidal area under y over x.
        acps: dict mapping "ACP_<tau>" to the smallest x with y < tau,
              or NaN when y never drops below tau.
    """
    dy = np.gradient(y, x)
    rbs = x[np.argmax(np.abs(dy))]
    # np.trapz is deprecated since NumPy 2.0 (renamed np.trapezoid); support both.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    aubc = trapezoid(y, x)
    acps = {}
    for tau in thresholds:
        below = y < tau
        acps[f"ACP_{tau}"] = np.min(x[below]) if np.any(below) else np.nan
    return rbs, aubc, acps


def plot_degradation_curves_for_strategies(
    model,
    csv_paths,
    strategy_names,
    save_path,
    blindness_mode="blindness",
    quantile=0.999,
    thresholds=None,
    num_bins=100,
    plot_legend=False,
    show_title=False,
):
    """Plot one fitted degradation curve per strategy into a single figure.

    For each CSV, ``analyze_global_degradation_fit`` (called with
    ``return_data=True``) is expected to return the three parameters
    (a, b, c) of a decreasing sigmoid. That sigmoid is evaluated on
    100 points in [0, 1] and plotted, and summary metrics are derived
    from the sampled curve.

    Args:
        model: Model key. Used only for the title when ``show_title`` is True.
        csv_paths: One CSV path per strategy.
        strategy_names: Legend label / result name for each CSV (same length).
        save_path: Where the figure is written via ``plt.savefig``.
        blindness_mode: Column name forwarded as ``blindness_col``; also picks
            the x-axis label ("rel_blindness" -> "Relative Blindness").
        quantile, num_bins: Forwarded to ``analyze_global_degradation_fit``.
        thresholds: Confidence levels for the ACP metrics; defaults to
            [0.5, 0.8]. Also forwarded to ``analyze_global_degradation_fit``.
        plot_legend: Draw a legend when True.
        show_title: When True, title the plot with the model's display name.

    Returns:
        pandas.DataFrame with one row per successfully processed strategy and
        columns strategy, a, b, c, rbs, aubc, ACP_<tau>...

    Raises:
        ValueError: if ``csv_paths`` and ``strategy_names`` differ in length.
    """
    if len(csv_paths) != len(strategy_names):
        raise ValueError("Each path must have a strategy name")
    if thresholds is None:
        thresholds = [0.5, 0.8]

    x_fit = np.linspace(0, 1, 100)
    results = []

    fig = plt.figure(figsize=(8, 6))
    try:
        for path, name in zip(csv_paths, strategy_names):
            # Best-effort per strategy: one failing CSV is reported and skipped
            # rather than aborting the whole figure.
            try:
                a, b, c = analyze_global_degradation_fit(
                    csv_path=path,
                    output_path=None,
                    title=None,
                    blindness_col=blindness_mode,
                    corr_analysis=False,
                    quantile=quantile,
                    thresholds=thresholds,
                    num_bins=num_bins,
                    return_data=True,
                )

                y_fit = _neg_sigmoid(x_fit, a, b, c)
                plt.plot(x_fit, y_fit, label=name, linewidth=2)

                rbs, aubc, acps = _degradation_metrics(x_fit, y_fit, thresholds)
                result = {"strategy": name, "a": a, "b": b, "c": c, "rbs": rbs, "aubc": aubc}
                result.update(acps)
                results.append(result)
            except Exception as e:
                print(f"[WARN] Failed for {name}: {e}")

        if blindness_mode == "rel_blindness":
            plt.xlabel("Relative Blindness", fontsize=20)
        else:
            plt.xlabel("Blindness", fontsize=20)
        plt.ylabel("Confidence C", fontsize=20)
        plt.grid(True)

        if show_title:
            plt.title(_MODEL_TITLES.get(model, str(model)), fontsize=16)

        if plot_legend:
            plt.legend(fontsize=18)

        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)

        plt.tight_layout()
        plt.savefig(save_path, bbox_inches="tight")
        plt.show()
    finally:
        # Always release the figure, even if saving/showing fails.
        plt.close(fig)

    return pd.DataFrame(results)





# COMMAND ----------

# MAGIC %md
# MAGIC #### ResNet34

# COMMAND ----------

MODEL = "resnet34"
PERTURBATION_MODE = "mdp"
NUM_GRID_ROW = 14

# Build the per-strategy CSV paths from one results directory instead of
# repeating the full string concatenation for every strategy.
base_dir = "/Workspace/Users/ANONYM/Constrain_Framework"
results_dir = f"{base_dir}/results/patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet"
strategy_files = [
    "A1_random", "B_shap", "C_gradcam", "D_smoothedcam",
    "E_ig", "F_am", "G_os", "H_loss",
]
csvs = [f"{results_dir}/{MODEL}_{suffix}.csv" for suffix in strategy_files]

strategies = ["Rand.", "SHAP", "GC", "SC", "IG", "AM", "OS", "Loss"]

# Keep the absolute-blindness metrics; `df` still ends up holding the
# relative-blindness metrics as before.
df_abs = plot_degradation_curves_for_strategies(
    model=MODEL,
    csv_paths=csvs,
    strategy_names=strategies,
    save_path=f"{base_dir}/figs/degra_{MODEL}.pdf",
    blindness_mode="blindness",
)

df = plot_degradation_curves_for_strategies(
    model=MODEL,
    csv_paths=csvs,
    strategy_names=strategies,
    save_path=f"{base_dir}/figs/rel_degra_{MODEL}.pdf",
    blindness_mode="rel_blindness",
    plot_legend=True,
)



# COMMAND ----------

# MAGIC %md
# MAGIC #### ResNet50

# COMMAND ----------

MODEL = "resnet50"
PERTURBATION_MODE = "mdp"
NUM_GRID_ROW = 14

# Build the per-strategy CSV paths from one results directory instead of
# repeating the full string concatenation for every strategy.
base_dir = "/Workspace/Users/ANONYM/Constrain_Framework"
results_dir = f"{base_dir}/results/patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet"
strategy_files = [
    "A1_random", "B_shap", "C_gradcam", "D_smoothedcam",
    "E_ig", "F_am", "G_os", "H_loss",
]
csvs = [f"{results_dir}/{MODEL}_{suffix}.csv" for suffix in strategy_files]

strategies = ["Rand.", "SHAP", "GC", "SC", "IG", "AM", "OS", "Loss"]

# Keep the absolute-blindness metrics; `df` still ends up holding the
# relative-blindness metrics as before.
df_abs = plot_degradation_curves_for_strategies(
    model=MODEL,
    csv_paths=csvs,
    strategy_names=strategies,
    save_path=f"{base_dir}/figs/degra_{MODEL}.pdf",
    blindness_mode="blindness",
)

df = plot_degradation_curves_for_strategies(
    model=MODEL,
    csv_paths=csvs,
    strategy_names=strategies,
    save_path=f"{base_dir}/figs/rel_degra_{MODEL}.pdf",
    blindness_mode="rel_blindness",
    plot_legend=True,
)



# COMMAND ----------

# MAGIC %md
# MAGIC #### VGG16

# COMMAND ----------

MODEL = "vgg16"
PERTURBATION_MODE = "mdp"
NUM_GRID_ROW = 14

# Build the per-strategy CSV paths from one results directory instead of
# repeating the full string concatenation for every strategy.
base_dir = "/Workspace/Users/ANONYM/Constrain_Framework"
results_dir = f"{base_dir}/results/patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet"
strategy_files = [
    "A1_random", "B_shap", "C_gradcam", "D_smoothedcam",
    "E_ig", "F_am", "G_os", "H_loss",
]
csvs = [f"{results_dir}/{MODEL}_{suffix}.csv" for suffix in strategy_files]

strategies = ["Rand.", "SHAP", "GC", "SC", "IG", "AM", "OS", "Loss"]

# Keep the absolute-blindness metrics; `df` still ends up holding the
# relative-blindness metrics as before.
df_abs = plot_degradation_curves_for_strategies(
    model=MODEL,
    csv_paths=csvs,
    strategy_names=strategies,
    save_path=f"{base_dir}/figs/degra_{MODEL}.pdf",
    blindness_mode="blindness",
)

df = plot_degradation_curves_for_strategies(
    model=MODEL,
    csv_paths=csvs,
    strategy_names=strategies,
    save_path=f"{base_dir}/figs/rel_degra_{MODEL}.pdf",
    blindness_mode="rel_blindness",
)



# COMMAND ----------

# MAGIC %md
# MAGIC #### EfficientNet

# COMMAND ----------

MODEL = "efficientnet"
PERTURBATION_MODE = "mdp"
NUM_GRID_ROW = 14

# Build the per-strategy CSV paths from one results directory instead of
# repeating the full string concatenation for every strategy.
base_dir = "/Workspace/Users/ANONYM/Constrain_Framework"
results_dir = f"{base_dir}/results/patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet"
strategy_files = [
    "A1_random", "B_shap", "C_gradcam", "D_smoothedcam",
    "E_ig", "F_am", "G_os", "H_loss",
]
csvs = [f"{results_dir}/{MODEL}_{suffix}.csv" for suffix in strategy_files]

strategies = ["Rand.", "SHAP", "GC", "SC", "IG", "AM", "OS", "Loss"]

# Keep the absolute-blindness metrics; `df` still ends up holding the
# relative-blindness metrics as before.
df_abs = plot_degradation_curves_for_strategies(
    model=MODEL,
    csv_paths=csvs,
    strategy_names=strategies,
    save_path=f"{base_dir}/figs/degra_{MODEL}.pdf",
    blindness_mode="blindness",
)

df = plot_degradation_curves_for_strategies(
    model=MODEL,
    csv_paths=csvs,
    strategy_names=strategies,
    save_path=f"{base_dir}/figs/rel_degra_{MODEL}.pdf",
    blindness_mode="rel_blindness",
)


# COMMAND ----------

# MAGIC %md
# MAGIC #### DenseNet

# COMMAND ----------

MODEL = "densenet"
PERTURBATION_MODE = "mdp"
NUM_GRID_ROW = 14

# Build the per-strategy CSV paths from one results directory instead of
# repeating the full string concatenation for every strategy.
base_dir = "/Workspace/Users/ANONYM/Constrain_Framework"
results_dir = f"{base_dir}/results/patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet"
strategy_files = [
    "A1_random", "B_shap", "C_gradcam", "D_smoothedcam",
    "E_ig", "F_am", "G_os", "H_loss",
]
csvs = [f"{results_dir}/{MODEL}_{suffix}.csv" for suffix in strategy_files]

strategies = ["Rand.", "SHAP", "GC", "SC", "IG", "AM", "OS", "Loss"]

# Keep the absolute-blindness metrics; `df` still ends up holding the
# relative-blindness metrics as before.
df_abs = plot_degradation_curves_for_strategies(
    model=MODEL,
    csv_paths=csvs,
    strategy_names=strategies,
    save_path=f"{base_dir}/figs/degra_{MODEL}.pdf",
    blindness_mode="blindness",
)

df = plot_degradation_curves_for_strategies(
    model=MODEL,
    csv_paths=csvs,
    strategy_names=strategies,
    save_path=f"{base_dir}/figs/rel_degra_{MODEL}.pdf",
    blindness_mode="rel_blindness",
)


# COMMAND ----------

# MAGIC %md
# MAGIC #### ViT

# COMMAND ----------

MODEL = "vit"
PERTURBATION_MODE = "mdp"
NUM_GRID_ROW = 14

# Build the per-strategy CSV paths from one results directory instead of
# repeating the full string concatenation for every strategy.
base_dir = "/Workspace/Users/ANONYM/Constrain_Framework"
results_dir = f"{base_dir}/results/patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet"
strategy_files = [
    "A1_random", "B_shap", "C_gradcam", "D_smoothedcam",
    "E_ig", "F_am", "G_os", "H_loss",
]
csvs = [f"{results_dir}/{MODEL}_{suffix}.csv" for suffix in strategy_files]

strategies = ["Rand.", "SHAP", "GC", "SC", "IG", "AM", "OS", "Loss"]

# Keep the absolute-blindness metrics; `df` still ends up holding the
# relative-blindness metrics as before.
df_abs = plot_degradation_curves_for_strategies(
    model=MODEL,
    csv_paths=csvs,
    strategy_names=strategies,
    save_path=f"{base_dir}/figs/degra_{MODEL}.pdf",
    blindness_mode="blindness",
)

df = plot_degradation_curves_for_strategies(
    model=MODEL,
    csv_paths=csvs,
    strategy_names=strategies,
    save_path=f"{base_dir}/figs/rel_degra_{MODEL}.pdf",
    blindness_mode="rel_blindness",
)
