# Standard library imports
import os
import uuid
import random
import csv

# Third party library imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset

import hshap

from captum.attr import IntegratedGradients
from captum.attr import GradientShap
from captum.attr import GuidedGradCam
from captum.attr import Occlusion

# Local Imports
import tools
from dataset import CelebAHQ
from clf import Classifier

# ---------------------------------------------------------------------------
# PREREQS: device, paths, background, classifier, data and explanation
# methods. Everything in this section runs at import time.
# ---------------------------------------------------------------------------

# Set device (GPU index 2 when CUDA is available, otherwise CPU)
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

# Directory names
root_dir = "/export/io85/data/bbharti1/CelebAMask-HQ"
model_dir = os.path.join("checkpoints", "final_models")
img_dir = os.path.join(root_dir, "CelebA-HQ-img")
results_dir = "results"

# Load background sample. map_location lets a tensor that was saved on a
# different device (e.g. a GPU index missing on this host) load directly
# onto `device` instead of failing inside torch.load.
bkgd = torch.load(os.path.join("data", "mean_img.pt"), map_location=device).to(device)
h, w = bkgd.shape[-2:]  # last two dims of the background tensor
num_mask_entries = h * w

# Load trained classifier
clf_state_dict_path = os.path.join(model_dir, "clf.pt")
clf = tools.load_classifier(clf_state_dict_path, eval_mode=True)
clf.to(device)

# Load validation csv (index taken from the first column)
val_csv_path = os.path.join(root_dir, "val.csv")
val_csv = pd.read_csv(val_csv_path, index_col=[0])
val_data = CelebAHQ(img_dir, val_csv, "Smiling", transform=False)

# Dataloader: one shuffled image at a time
batch_size = 1
dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True, num_workers=4)

# Load different explanation methods
int_grad = IntegratedGradients(clf)
grad_shap = GradientShap(clf)
guided_cam = GuidedGradCam(clf, clf.resnet.layer4)

# h-shap configuration. s, R and A are forwarded to hexp.cycle_explain;
# presumably s is the minimum feature size and R / A the radii / angles of
# the cycle shifts -- confirm against the hshap documentation.
s = 64
R = np.linspace(0, s, 4, endpoint=False)
A = np.linspace(0, 2 * np.pi, 8, endpoint=False)
hexp = hshap.src.Explainer(model=clf, background=bkgd.squeeze(0))

# Load trained suff and nec explainer
suff_exp_state_dict_path = os.path.join(model_dir, "suff_explainer_v2.pt")
suff_explainer = tools.load_explainer(clf, device, bkgd, state_dict_path=suff_exp_state_dict_path, eval_mode=True)
suff_explainer.to(device)

nec_exp_state_dict_path = os.path.join(model_dir, "nec_explainer_v2.pt")
nec_explainer = tools.load_explainer(clf, device, bkgd, state_dict_path=nec_exp_state_dict_path, eval_mode=True)
nec_explainer.to(device)

# Explanation methods evaluated in the main loop, and the attribution
# normalisation passed to tools.process_attr.
exp_types = ["int_grad", "grad_shap", "guided_cam", "hshap", "sufficiency", "necessity"]
norm_type = "top_kp"

# ---------------------------------------------------------------------------
# MAIN: evaluate every explanation method on up to N images and save the
# per-method sufficiency / necessity / l0 metrics to `results_dir`.
# ---------------------------------------------------------------------------

# Short tag used in each method's output filename.
_RESULT_TAGS = {
    "int_grad": "ig",
    "grad_shap": "gs",
    "guided_cam": "gc",
    "hshap": "hs",
    "sufficiency": "suff",
    "necessity": "nec",
}


def _hshap_attribution(x):
    """Run h-shap on the single image in `x`; zero map if h-shap raises."""
    with torch.no_grad():
        try:
            attr = hexp.cycle_explain(
                x=x.squeeze(0),
                label=0,
                s=s,
                R=R,
                A=A,
                threshold_mode="absolute",
                threshold=0.0,
                softmax_activation=False,
                batch_size=2,
                binary_map=False,
            )
        except Exception as err:
            # Best effort: one failing image must not abort the whole run, but
            # the failure is reported instead of being silently scored as zeros.
            # The fallback matches the background's spatial size, not a fixed 256.
            print(f"\nh-shap failed ({err!r}); using an all-zero attribution")
            attr = torch.zeros(1, h, w)
    return attr.to(device)


def _attribution(exp_type, x):
    """Return the raw attribution / mask of method `exp_type` for image batch `x`."""
    if exp_type == "int_grad":
        return int_grad.attribute(x, n_steps=500).squeeze(0)
    if exp_type == "grad_shap":
        return grad_shap.attribute(
            x, n_samples=500, baselines=bkgd.unsqueeze(0)
        ).squeeze(0)
    if exp_type == "guided_cam":
        return guided_cam.attribute(x, 0).squeeze(0).detach()
    if exp_type == "hshap":
        return _hshap_attribution(x)
    if exp_type == "sufficiency":
        # Plain forward pass through the learned explainer: no gradients needed.
        with torch.no_grad():
            return torch.sigmoid(suff_explainer(x).squeeze(1))
    if exp_type == "necessity":
        with torch.no_grad():
            return torch.sigmoid(nec_explainer(x).squeeze(1))
    raise ValueError(f"Unknown explanation type: {exp_type!r}")


if __name__ == "__main__":
    top_kp = 1
    N = 100  # number of qualifying images to evaluate before stopping
    thresholds = np.linspace(0.01, 0.99, 99)
    y_0 = 0.181  # constant forwarded to tools.compute_metrics_batch

    # Every method is scored against the same batched background so that the
    # metrics are comparable (guided_cam previously got the un-batched tensor).
    bkgd_batch = bkgd.unsqueeze(0)

    # results[exp_type][i] -> metrics of the i-th evaluated image.
    results = {exp_type: {} for exp_type in exp_types}

    i = 0
    for x, _y in dataloader:
        if i >= N:
            break
        x = x.to(device)
        with torch.no_grad():
            f_x = clf(x).item()

        # Only images the classifier confidently scores as positive are
        # evaluated. NOTE(review): the label `_y` is not checked, so these are
        # confident positive predictions, not verified true positives.
        if f_x < 0.9:
            continue

        for exp_type in exp_types:
            attr = _attribution(exp_type, x)
            gs_attr = tools.process_attr(attr, top_kp=top_kp, norm_type=norm_type)
            suff, necc, l0 = tools.compute_metrics_batch(
                clf, x, gs_attr, f_x, y_0, bkgd_batch, thresholds
            )
            results[exp_type][i] = {"suff": suff, "necc": necc, "l0": l0}

        i += 1
        print(f"Progress: {i}/{N} complete", end="\r")
    print()  # finish the carriage-return progress line

    # torch.save does not create missing directories.
    os.makedirs(results_dir, exist_ok=True)
    for exp_type, tag in _RESULT_TAGS.items():
        torch.save(
            results[exp_type],
            os.path.join(results_dir, f"{norm_type}_{tag}_results.pt"),
        )