# Databricks notebook source
# MAGIC %md
# MAGIC # Visualizations

# COMMAND ----------

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy.optimize import curve_fit
from utils.evaluation import analyze_global_degradation_fit
from IPython.display import clear_output


# COMMAND ----------

# MAGIC %md
# MAGIC #### Random

# COMMAND ----------


STRATEGY = "A1_random"

# ---- Experiment configuration (re-declared so this cell runs standalone) ----
PERTURBATION_MODE = "mdp"
NUM_GRID_ROW = 14
# -----------------------------------------------------------------------------

MODEL = ["resnet34", "resnet50", "vgg16", "efficientnet", "densenet", "vit"]

# Per-model metric collectors: the "rel_" lists are filled from the
# "rel_blindness" CSV column, the others from the "blindness" column.
rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08 = [], [], [], []
all_aubc, all_rpbs, all_acp05, all_acp08 = [], [], [], []

# Relative pass over every model first, then the absolute pass.
for blindness_mode, (aubc_sink, rpb_sink, acp05_sink, acp08_sink) in (
    ("rel_blindness", (rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08)),
    ("blindness", (all_aubc, all_rpbs, all_acp05, all_acp08)),
):
    for m in MODEL:
        # One results CSV per (model, strategy) under patch_<rows>_<mode>/imagenet/.
        path = (
            "/Workspace/Users/ANONYM/Constrain_Framework/results/"
            f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/{m}_{STRATEGY}.csv"
        )
        # Returns (AUBC, RPB, dict of ACP values keyed 'ACP_0.5' / 'ACP_0.8').
        aubc, rbs, acps = analyze_global_degradation_fit(
            csv_path=path,
            output_path=None,
            title=None,
            blindness_col=blindness_mode,
            corr_analysis=False,
            quantile=0.999,
            return_data=False,
            return_metrics=True,
        )
        aubc_sink.append(aubc)
        rpb_sink.append(rbs)
        acp05_sink.append(acps['ACP_0.5'])
        acp08_sink.append(acps['ACP_0.8'])

# Clear the cell output produced so far, then report NaN-ignoring means.
clear_output()
for heading, metric_lists in (
    ("RELATIVE", (rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08)),
    ("ABSOLUTE", (all_aubc, all_rpbs, all_acp05, all_acp08)),
):
    print(heading)
    for label, values in zip(("AUBC", "RPB", "ACP@0.5", "ACP@0.8"), metric_lists):
        print(f"Mean {label}: {np.nanmean(values):.4f}")


# COMMAND ----------

# MAGIC %md
# MAGIC #### SHAP

# COMMAND ----------


STRATEGY = "B_shap"

# ---- Experiment configuration (re-declared so this cell runs standalone) ----
PERTURBATION_MODE = "mdp"
NUM_GRID_ROW = 14
# -----------------------------------------------------------------------------

MODEL = ["resnet34", "resnet50", "vgg16", "efficientnet", "densenet", "vit"]

# Per-model metric collectors: the "rel_" lists are filled from the
# "rel_blindness" CSV column, the others from the "blindness" column.
rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08 = [], [], [], []
all_aubc, all_rpbs, all_acp05, all_acp08 = [], [], [], []

# Relative pass over every model first, then the absolute pass.
for blindness_mode, (aubc_sink, rpb_sink, acp05_sink, acp08_sink) in (
    ("rel_blindness", (rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08)),
    ("blindness", (all_aubc, all_rpbs, all_acp05, all_acp08)),
):
    for m in MODEL:
        # One results CSV per (model, strategy) under patch_<rows>_<mode>/imagenet/.
        path = (
            "/Workspace/Users/ANONYM/Constrain_Framework/results/"
            f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/{m}_{STRATEGY}.csv"
        )
        # Returns (AUBC, RPB, dict of ACP values keyed 'ACP_0.5' / 'ACP_0.8').
        aubc, rbs, acps = analyze_global_degradation_fit(
            csv_path=path,
            output_path=None,
            title=None,
            blindness_col=blindness_mode,
            corr_analysis=False,
            quantile=0.999,
            return_data=False,
            return_metrics=True,
        )
        aubc_sink.append(aubc)
        rpb_sink.append(rbs)
        acp05_sink.append(acps['ACP_0.5'])
        acp08_sink.append(acps['ACP_0.8'])

# Clear the cell output produced so far, then report NaN-ignoring means.
clear_output()
for heading, metric_lists in (
    ("RELATIVE", (rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08)),
    ("ABSOLUTE", (all_aubc, all_rpbs, all_acp05, all_acp08)),
):
    print(heading)
    for label, values in zip(("AUBC", "RPB", "ACP@0.5", "ACP@0.8"), metric_lists):
        print(f"Mean {label}: {np.nanmean(values):.4f}")


# COMMAND ----------

# MAGIC %md
# MAGIC #### GradCam

# COMMAND ----------


STRATEGY = "C_gradcam"

# ---- Experiment configuration (re-declared so this cell runs standalone) ----
PERTURBATION_MODE = "mdp"
NUM_GRID_ROW = 14
# -----------------------------------------------------------------------------

MODEL = ["resnet34", "resnet50", "vgg16", "efficientnet", "densenet", "vit"]

# Per-model metric collectors: the "rel_" lists are filled from the
# "rel_blindness" CSV column, the others from the "blindness" column.
rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08 = [], [], [], []
all_aubc, all_rpbs, all_acp05, all_acp08 = [], [], [], []

# Relative pass over every model first, then the absolute pass.
for blindness_mode, (aubc_sink, rpb_sink, acp05_sink, acp08_sink) in (
    ("rel_blindness", (rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08)),
    ("blindness", (all_aubc, all_rpbs, all_acp05, all_acp08)),
):
    for m in MODEL:
        # One results CSV per (model, strategy) under patch_<rows>_<mode>/imagenet/.
        path = (
            "/Workspace/Users/ANONYM/Constrain_Framework/results/"
            f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/{m}_{STRATEGY}.csv"
        )
        # Returns (AUBC, RPB, dict of ACP values keyed 'ACP_0.5' / 'ACP_0.8').
        aubc, rbs, acps = analyze_global_degradation_fit(
            csv_path=path,
            output_path=None,
            title=None,
            blindness_col=blindness_mode,
            corr_analysis=False,
            quantile=0.999,
            return_data=False,
            return_metrics=True,
        )
        aubc_sink.append(aubc)
        rpb_sink.append(rbs)
        acp05_sink.append(acps['ACP_0.5'])
        acp08_sink.append(acps['ACP_0.8'])

# Clear the cell output produced so far, then report NaN-ignoring means.
clear_output()
for heading, metric_lists in (
    ("RELATIVE", (rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08)),
    ("ABSOLUTE", (all_aubc, all_rpbs, all_acp05, all_acp08)),
):
    print(heading)
    for label, values in zip(("AUBC", "RPB", "ACP@0.5", "ACP@0.8"), metric_lists):
        print(f"Mean {label}: {np.nanmean(values):.4f}")


# COMMAND ----------

# MAGIC %md
# MAGIC #### SmoothedCam

# COMMAND ----------


STRATEGY = "D_smoothedcam"

# ---- Experiment configuration (re-declared so this cell runs standalone) ----
PERTURBATION_MODE = "mdp"
NUM_GRID_ROW = 14
# -----------------------------------------------------------------------------

MODEL = ["resnet34", "resnet50", "vgg16", "efficientnet", "densenet", "vit"]

# Per-model metric collectors: the "rel_" lists are filled from the
# "rel_blindness" CSV column, the others from the "blindness" column.
rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08 = [], [], [], []
all_aubc, all_rpbs, all_acp05, all_acp08 = [], [], [], []

# Relative pass over every model first, then the absolute pass.
for blindness_mode, (aubc_sink, rpb_sink, acp05_sink, acp08_sink) in (
    ("rel_blindness", (rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08)),
    ("blindness", (all_aubc, all_rpbs, all_acp05, all_acp08)),
):
    for m in MODEL:
        # One results CSV per (model, strategy) under patch_<rows>_<mode>/imagenet/.
        path = (
            "/Workspace/Users/ANONYM/Constrain_Framework/results/"
            f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/{m}_{STRATEGY}.csv"
        )
        # Returns (AUBC, RPB, dict of ACP values keyed 'ACP_0.5' / 'ACP_0.8').
        aubc, rbs, acps = analyze_global_degradation_fit(
            csv_path=path,
            output_path=None,
            title=None,
            blindness_col=blindness_mode,
            corr_analysis=False,
            quantile=0.999,
            return_data=False,
            return_metrics=True,
        )
        aubc_sink.append(aubc)
        rpb_sink.append(rbs)
        acp05_sink.append(acps['ACP_0.5'])
        acp08_sink.append(acps['ACP_0.8'])

# Clear the cell output produced so far, then report NaN-ignoring means.
clear_output()
for heading, metric_lists in (
    ("RELATIVE", (rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08)),
    ("ABSOLUTE", (all_aubc, all_rpbs, all_acp05, all_acp08)),
):
    print(heading)
    for label, values in zip(("AUBC", "RPB", "ACP@0.5", "ACP@0.8"), metric_lists):
        print(f"Mean {label}: {np.nanmean(values):.4f}")


# COMMAND ----------

# MAGIC %md
# MAGIC #### IG

# COMMAND ----------


STRATEGY = "E_ig"

# ---- Experiment configuration (re-declared so this cell runs standalone) ----
PERTURBATION_MODE = "mdp"
NUM_GRID_ROW = 14
# -----------------------------------------------------------------------------

MODEL = ["resnet34", "resnet50", "vgg16", "efficientnet", "densenet", "vit"]

# Per-model metric collectors: the "rel_" lists are filled from the
# "rel_blindness" CSV column, the others from the "blindness" column.
rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08 = [], [], [], []
all_aubc, all_rpbs, all_acp05, all_acp08 = [], [], [], []

# Relative pass over every model first, then the absolute pass.
for blindness_mode, (aubc_sink, rpb_sink, acp05_sink, acp08_sink) in (
    ("rel_blindness", (rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08)),
    ("blindness", (all_aubc, all_rpbs, all_acp05, all_acp08)),
):
    for m in MODEL:
        # One results CSV per (model, strategy) under patch_<rows>_<mode>/imagenet/.
        path = (
            "/Workspace/Users/ANONYM/Constrain_Framework/results/"
            f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/{m}_{STRATEGY}.csv"
        )
        # Returns (AUBC, RPB, dict of ACP values keyed 'ACP_0.5' / 'ACP_0.8').
        aubc, rbs, acps = analyze_global_degradation_fit(
            csv_path=path,
            output_path=None,
            title=None,
            blindness_col=blindness_mode,
            corr_analysis=False,
            quantile=0.999,
            return_data=False,
            return_metrics=True,
        )
        aubc_sink.append(aubc)
        rpb_sink.append(rbs)
        acp05_sink.append(acps['ACP_0.5'])
        acp08_sink.append(acps['ACP_0.8'])

# Clear the cell output produced so far, then report NaN-ignoring means.
clear_output()
for heading, metric_lists in (
    ("RELATIVE", (rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08)),
    ("ABSOLUTE", (all_aubc, all_rpbs, all_acp05, all_acp08)),
):
    print(heading)
    for label, values in zip(("AUBC", "RPB", "ACP@0.5", "ACP@0.8"), metric_lists):
        print(f"Mean {label}: {np.nanmean(values):.4f}")


# COMMAND ----------

# MAGIC %md
# MAGIC #### AM

# COMMAND ----------


STRATEGY = "F_am"

# ---- Experiment configuration (re-declared so this cell runs standalone) ----
PERTURBATION_MODE = "mdp"
NUM_GRID_ROW = 14
# -----------------------------------------------------------------------------

MODEL = ["resnet34", "resnet50", "vgg16", "efficientnet", "densenet", "vit"]

# Per-model metric collectors: the "rel_" lists are filled from the
# "rel_blindness" CSV column, the others from the "blindness" column.
rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08 = [], [], [], []
all_aubc, all_rpbs, all_acp05, all_acp08 = [], [], [], []

# Relative pass over every model first, then the absolute pass.
for blindness_mode, (aubc_sink, rpb_sink, acp05_sink, acp08_sink) in (
    ("rel_blindness", (rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08)),
    ("blindness", (all_aubc, all_rpbs, all_acp05, all_acp08)),
):
    for m in MODEL:
        # One results CSV per (model, strategy) under patch_<rows>_<mode>/imagenet/.
        path = (
            "/Workspace/Users/ANONYM/Constrain_Framework/results/"
            f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/{m}_{STRATEGY}.csv"
        )
        # Returns (AUBC, RPB, dict of ACP values keyed 'ACP_0.5' / 'ACP_0.8').
        aubc, rbs, acps = analyze_global_degradation_fit(
            csv_path=path,
            output_path=None,
            title=None,
            blindness_col=blindness_mode,
            corr_analysis=False,
            quantile=0.999,
            return_data=False,
            return_metrics=True,
        )
        aubc_sink.append(aubc)
        rpb_sink.append(rbs)
        acp05_sink.append(acps['ACP_0.5'])
        acp08_sink.append(acps['ACP_0.8'])

# Clear the cell output produced so far, then report NaN-ignoring means.
clear_output()
for heading, metric_lists in (
    ("RELATIVE", (rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08)),
    ("ABSOLUTE", (all_aubc, all_rpbs, all_acp05, all_acp08)),
):
    print(heading)
    for label, values in zip(("AUBC", "RPB", "ACP@0.5", "ACP@0.8"), metric_lists):
        print(f"Mean {label}: {np.nanmean(values):.4f}")


# COMMAND ----------

# MAGIC %md
# MAGIC #### OC

# COMMAND ----------


STRATEGY = "G_os"

# ---- Experiment configuration (re-declared so this cell runs standalone) ----
PERTURBATION_MODE = "mdp"
NUM_GRID_ROW = 14
# -----------------------------------------------------------------------------

MODEL = ["resnet34", "resnet50", "vgg16", "efficientnet", "densenet", "vit"]

# Per-model metric collectors: the "rel_" lists are filled from the
# "rel_blindness" CSV column, the others from the "blindness" column.
rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08 = [], [], [], []
all_aubc, all_rpbs, all_acp05, all_acp08 = [], [], [], []

# Relative pass over every model first, then the absolute pass.
for blindness_mode, (aubc_sink, rpb_sink, acp05_sink, acp08_sink) in (
    ("rel_blindness", (rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08)),
    ("blindness", (all_aubc, all_rpbs, all_acp05, all_acp08)),
):
    for m in MODEL:
        # One results CSV per (model, strategy) under patch_<rows>_<mode>/imagenet/.
        path = (
            "/Workspace/Users/ANONYM/Constrain_Framework/results/"
            f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/{m}_{STRATEGY}.csv"
        )
        # Returns (AUBC, RPB, dict of ACP values keyed 'ACP_0.5' / 'ACP_0.8').
        aubc, rbs, acps = analyze_global_degradation_fit(
            csv_path=path,
            output_path=None,
            title=None,
            blindness_col=blindness_mode,
            corr_analysis=False,
            quantile=0.999,
            return_data=False,
            return_metrics=True,
        )
        aubc_sink.append(aubc)
        rpb_sink.append(rbs)
        acp05_sink.append(acps['ACP_0.5'])
        acp08_sink.append(acps['ACP_0.8'])

# Clear the cell output produced so far, then report NaN-ignoring means.
clear_output()
for heading, metric_lists in (
    ("RELATIVE", (rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08)),
    ("ABSOLUTE", (all_aubc, all_rpbs, all_acp05, all_acp08)),
):
    print(heading)
    for label, values in zip(("AUBC", "RPB", "ACP@0.5", "ACP@0.8"), metric_lists):
        print(f"Mean {label}: {np.nanmean(values):.4f}")


# COMMAND ----------

# MAGIC %md
# MAGIC #### Loss

# COMMAND ----------


STRATEGY = "H_loss"

# ---- Experiment configuration (re-declared so this cell runs standalone) ----
PERTURBATION_MODE = "mdp"
NUM_GRID_ROW = 14
# -----------------------------------------------------------------------------

MODEL = ["resnet34", "resnet50", "vgg16", "efficientnet", "densenet", "vit"]

# Per-model metric collectors: the "rel_" lists are filled from the
# "rel_blindness" CSV column, the others from the "blindness" column.
rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08 = [], [], [], []
all_aubc, all_rpbs, all_acp05, all_acp08 = [], [], [], []

# Relative pass over every model first, then the absolute pass.
for blindness_mode, (aubc_sink, rpb_sink, acp05_sink, acp08_sink) in (
    ("rel_blindness", (rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08)),
    ("blindness", (all_aubc, all_rpbs, all_acp05, all_acp08)),
):
    for m in MODEL:
        # One results CSV per (model, strategy) under patch_<rows>_<mode>/imagenet/.
        path = (
            "/Workspace/Users/ANONYM/Constrain_Framework/results/"
            f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/{m}_{STRATEGY}.csv"
        )
        # Returns (AUBC, RPB, dict of ACP values keyed 'ACP_0.5' / 'ACP_0.8').
        aubc, rbs, acps = analyze_global_degradation_fit(
            csv_path=path,
            output_path=None,
            title=None,
            blindness_col=blindness_mode,
            corr_analysis=False,
            quantile=0.999,
            return_data=False,
            return_metrics=True,
        )
        aubc_sink.append(aubc)
        rpb_sink.append(rbs)
        acp05_sink.append(acps['ACP_0.5'])
        acp08_sink.append(acps['ACP_0.8'])

# Clear the cell output produced so far, then report NaN-ignoring means.
clear_output()
for heading, metric_lists in (
    ("RELATIVE", (rel_all_aubc, rel_all_rpbs, rel_all_acp05, rel_all_acp08)),
    ("ABSOLUTE", (all_aubc, all_rpbs, all_acp05, all_acp08)),
):
    print(heading)
    for label, values in zip(("AUBC", "RPB", "ACP@0.5", "ACP@0.8"), metric_lists):
        print(f"Mean {label}: {np.nanmean(values):.4f}")


# COMMAND ----------

# Result CSV paths for every attribution strategy evaluated above.
#
# Bug fix: at this point MODEL is a *list* of model names (set in the cells
# above), so the previous ``str(MODEL)`` embedded the list repr into the file
# name (".../imagenet/['resnet34', ...]_A1_random.csv"), a path that cannot
# exist. Paths are now built per model name instead.
STRATEGIES = [
    "A1_random",
    "B_shap",
    "C_gradcam",
    "D_smoothedcam",
    "E_ig",
    "F_am",
    "G_os",
    "H_loss",
]

# Accept either a single model name or a list of names.
_csv_models = [MODEL] if isinstance(MODEL, str) else list(MODEL)

_results_dir = (
    "/Workspace/Users/ANONYM/Constrain_Framework/results/"
    f"patch_{NUM_GRID_ROW}_{PERTURBATION_MODE}/imagenet/"
)

# model name -> list of the per-strategy CSV paths, in STRATEGIES order.
csvs_by_model = {
    model_name: [f"{_results_dir}{model_name}_{strategy}.csv" for strategy in STRATEGIES]
    for model_name in _csv_models
}

# Keep ``csvs`` as a flat list of one path per strategy (same shape as before).
# TODO(review): when MODEL is a list this uses the first model; confirm which
# model downstream cells expect, or switch them to ``csvs_by_model``.
csvs = csvs_by_model[_csv_models[0]]