# Standard library imports
import os
import uuid
import random
import csv

# Third party library imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset

import hshap

from captum.attr import IntegratedGradients
from captum.attr import GradientShap
from captum.attr import GuidedGradCam
from captum.attr import Occlusion

# Local Imports
import utils
from dataset import RSNADataset
from clf import HemorrhageDetector

"""
PREREQS
"""

# Set device
# NOTE(review): the GPU index is hard-coded. On a CUDA host that exposes fewer
# than three devices "cuda:2" would presumably be invalid -- confirm against
# the target machine.
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

# Load background
# `map_location` remaps the stored tensor onto `device` while deserializing,
# so a file written from a GPU process still loads on the CPU fallback above.
# The trailing `.to(device)` is kept as a harmless safeguard.
bkgd = torch.load("mean_img.pt", map_location=device).to(device)

# Dataset
model_dir = 'models'
rsna_dir = '/export/gaon1/data/jteneggi/data/rsna-intracranial-hemorrhage-detection/RSNA'
val_data = RSNADataset(
    data_dir=rsna_dir,
    op="val",
    weak_supervision=True,
)
results_dir = os.path.join("results")
# "test_idx.csv" has no header row; its first column holds the indices into
# `val_data` that the main loop evaluates.
df = pd.read_csv("test_idx.csv", header=None)
test_idx = df.iloc[:,0].tolist()

# Load classifier
# `map_location=device` lets a checkpoint saved from a GPU process be restored
# on whichever device was selected above (including the CPU fallback);
# without it torch.load tries to restore tensors onto their original device.
clf_state_dict = torch.load(
    os.path.join(model_dir, 'wl_model.pt'), map_location=device
)
clf = HemorrhageDetector(
    encoder="resnet18",
    n_dim=128,
    hidden_size=64,
    embedding_dropout=0.50,
    attention_dropout=0.25,
    attention_activation="softmax",
)
clf.load_state_dict(clf_state_dict)
clf = clf.to(device)
# Inference mode: the dropout layers configured above must be inactive while
# the model is being explained.
clf.eval()
print('Classifier loaded successfully')

# Run configuration: which attribution methods the main loop evaluates, and
# the normalisation mode forwarded to utils.process_attr.
exp_types = ["int_grad"]
norm_type = "top_kp"

# h-shap settings forwarded to `hexp.cycle_explain` in the main loop.
# R holds 4 evenly spaced values in [0, s) and A holds 8 evenly spaced values
# in [0, 2*pi) -- presumably radii and angles; confirm against the h-shap docs.
s = 64
R = np.linspace(start=0, stop=s, num=4, endpoint=False)
A = np.linspace(start=0, stop=2 * np.pi, num=8, endpoint=False)

# Captum explainers, each wrapping the classifier. Guided Grad-CAM is also
# given the encoder's `layer4` as its target layer.
int_grad, grad_shap = IntegratedGradients(clf), GradientShap(clf)
guided_cam = GuidedGradCam(clf, clf.encoder.layer4)

# h-shap explainer; the background has its leading dimension squeezed.
hexp = hshap.src.Explainer(model=clf, background=bkgd.squeeze(0))

"""
MAIN FUNCTION
"""
if __name__ == "__main__":
    # For every test sample, run each attribution method named in `exp_types`,
    # post-process the attribution with utils.process_attr, score it with
    # utils.compute_metrics_batch, and checkpoint the accumulated results to
    # `results_dir` after each sample.
    top_kp = 1
    N = len(test_idx)
    thresholds = np.linspace(0, 1, 101)
    results_ig = {}
    results_gs = {}
    results_gc = {}
    results_hs = {}
    results_suff = {}
    results_nec = {}

    # Result store per method. Entries are keyed by the sample's position in
    # `test_idx` (0-based), not by the dataset index itself.
    results = {
        "int_grad": results_ig,
        "grad_shap": results_gs,
        "guided_cam": results_gc,
        "hshap": results_hs,
        "sufficiency": results_suff,
        "necessity": results_nec,
    }
    # Short tag used in the output file name "<norm_type>_<tag>_results.pt".
    file_tags = {
        "int_grad": "ig",
        "grad_shap": "gs",
        "guided_cam": "gc",
        "hshap": "hs",
        "sufficiency": "suff",
        "necessity": "nec",
    }
    # Only the methods actually requested are written out, so enabling a
    # method in `exp_types` is enough to persist its results, and no empty
    # file is written for a method that was not run.
    save_paths = {
        exp_type: os.path.join(results_dir, f"{norm_type}_{tag}_results.pt")
        for exp_type, tag in file_tags.items()
        if exp_type in exp_types
    }

    # Make sure the output directory exists before any expensive work is done.
    os.makedirs(results_dir, exist_ok=True)

    # Loop-invariant batched view of the background, shared by most methods.
    bkgd_batch = bkgd.unsqueeze(0)

    for i, idx in enumerate(test_idx):
        # The dataset returns a triple; only the middle element (the model
        # input) is used here.
        _, x, _ = val_data[idx]
        x = x.to(device).unsqueeze(0)
        # Only the scalar prediction is needed, so skip building a graph.
        with torch.no_grad():
            f_x = clf(x).item()
        y_hat = (f_x >= 0.5) * 1.0
        y_0 = 0  # forwarded unchanged to utils.compute_metrics_batch

        for exp_type in exp_types:
            if exp_type not in results:
                # Unknown method names are skipped silently, as before.
                continue

            baseline = bkgd_batch
            if exp_type == "int_grad":
                # Integrated gradients
                attr = int_grad.attribute(x, n_steps=100).squeeze(0)
            elif exp_type == "grad_shap":
                # Gradient Shapley
                attr = grad_shap.attribute(
                    x, n_samples=5, baselines=bkgd_batch
                ).squeeze(0)
            elif exp_type == "guided_cam":
                # Guided GradCam
                attr = guided_cam.attribute(x, 0).squeeze(0)
                attr = attr.detach()
                # NOTE(review): this is the only method scored against the
                # un-batched `bkgd`; every other branch uses
                # `bkgd.unsqueeze(0)`. Kept as-is -- confirm whether the
                # difference is intentional.
                baseline = bkgd
            elif exp_type == "hshap":
                with torch.no_grad():
                    # h-shap
                    try:
                        attr = hexp.cycle_explain(
                            x=x.squeeze(0),
                            label=0,
                            s=s,
                            R=R,
                            A=A,
                            threshold_mode="absolute",
                            threshold=0.0,
                            softmax_activation=False,
                            batch_size=2,
                            binary_map=False,
                        )
                    except Exception as err:
                        # Best-effort fallback: report the failure and score
                        # an all-zero map instead of aborting the whole run.
                        # Catching `Exception` (not a bare except) keeps
                        # Ctrl-C and SystemExit working.
                        # NOTE(review): the 512x512 size is hard-coded;
                        # presumably it matches the input resolution.
                        print(
                            f"h-shap failed for test index {idx}: {err!r}; "
                            "using an all-zero attribution"
                        )
                        attr = torch.zeros(1, 512, 512)
                attr = attr.to(device)
            elif exp_type == "sufficiency":
                # Sufficiency
                # NOTE(review): `suff_explainer` is not defined anywhere in
                # this file; enabling this method raises NameError.
                mask = suff_explainer(x)
                mask_01, x_S, x_Sc = suff_explainer.sample(x, mask, bkgd, 5)
                attr = torch.mean(mask_01, dim=1)
            else:
                # Necessity
                # NOTE(review): `nec_explainer` is not defined anywhere in
                # this file; enabling this method raises NameError.
                mask = nec_explainer(x)
                mask_01, x_S, x_Sc = nec_explainer.sample(x, mask, bkgd, 5)
                attr = torch.mean(mask_01, dim=1)

            # Shared post-processing and scoring for every method.
            gs_attr = utils.process_attr(attr, top_kp=top_kp, norm_type=norm_type)
            suff, necc, l0 = utils.compute_metrics_batch(
                clf, x, gs_attr, y_hat, y_0, baseline, thresholds
            )
            results[exp_type][i] = {"suff": suff, "necc": necc, "l0": l0}

        print(f"Progress: {i + 1}/{N} complete", end="\r")
        # Checkpoint after every sample so an interrupted run keeps its
        # partial results.
        for exp_type, path in save_paths.items():
            torch.save(results[exp_type], path)