# Standard library imports
import os
import uuid
import random
import csv

# Third party library imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

# Local Imports
import utils
from dataset import RSNADataset
from clf import HemorrhageDetector
from instance_explainer import InstanceExplainer

"""
PREREQS
"""
# Set device: first CUDA GPU when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Relevant directories
model_dir = 'models'
data_dir = 'data'
# The dataset root can be overridden via the RSNA_DIR environment variable;
# the default is the original hard-coded cluster path.
rsna_dir = os.environ.get(
    "RSNA_DIR",
    '/export/gaon1/data/jteneggi/data/rsna-intracranial-hemorrhage-detection/RSNA',
)
results_dir = "results"

# Load data and background image
val_data = RSNADataset(
    data_dir=rsna_dir,
    op="val",
    weak_supervision=True,
)
# map_location makes tensors saved from a GPU run loadable on a CPU-only host,
# consistent with the CPU fallback chosen for `device` above.
bkgd = torch.load(os.path.join(data_dir, "mean_img.pt"), map_location=device).to(device)
# test_idx.csv has no header row; its first column holds the indices into
# val_data that the evaluation loop iterates over.
df = pd.read_csv(os.path.join(data_dir, "test_idx.csv"), header=None)
test_idx = df.iloc[:, 0].tolist()

# Load classifier weights (mapped onto `device` for the same reason as above).
clf_state_dict = torch.load(os.path.join(model_dir, 'wl_model.pt'), map_location=device)
clf = HemorrhageDetector(
    encoder="resnet18",
    n_dim=128,
    hidden_size=64,
    embedding_dropout=0.50,
    attention_dropout=0.25,
    attention_activation="softmax",
)
clf.load_state_dict(clf_state_dict)
clf = clf.to(device)
# Inference mode for layers such as dropout.
clf.eval()
print('Classifier loaded successfully')

# Explainer hyper-parameters shared by every explainer run below.
area = 0.02  # NOTE(review): not referenced in the visible code — confirm whether still needed
num_samples = 3       # masks requested per explainer call
num_steps = 2500
learning_rate = 1e-2
norm_type = "top_kp"  # forwarded to utils.process_attr

# Build the three explainers in order; they differ only in the third
# positional argument (1, 0, 0.5). Judging by the variable names this selects
# sufficiency / necessity / a blend of both — TODO confirm against InstanceExplainer.
suff_explainer, necc_explainer, uni_explainer = (
    InstanceExplainer(clf, device, weight, constraint=False, num_channels=3)
    for weight in (1, 0, 0.5)
)
# Evaluation order of explanation types per image.
exp_types = ["sufficiency", "necessity", "unified"]


"""
MAIN FUNCTION
"""
def _evaluate_explainer(explainer, x, y_hat, y_0, sp_mult, sm_mult, sh_mult,
                        init_mask, top_kp, thresholds, bkgd_batch):
    """Run one explainer on ``x`` and return its sample-averaged metrics.

    The explainer is called with the module-level ``bkgd``, ``num_samples``,
    ``num_steps`` and ``learning_rate`` settings. Each returned mask
    ``mask_01[0][k]`` is post-processed with ``utils.process_attr`` and scored
    with ``utils.compute_metrics_batch`` against ``clf``.

    Returns:
        dict with keys ``"suff"``, ``"necc"`` and ``"l0"``, each the
        element-wise mean (``np.mean(axis=0)``) of the per-sample results.
    """
    _, mask_01 = explainer(
        x, y_0, bkgd, num_samples, num_steps, learning_rate,
        sp_mult, sm_mult, sh_mult, 0.01, C=0, M=4, log_frac=1,
        init_mask=init_mask, return_logs=False,
    )
    suff_list, necc_list, l0_list = [], [], []
    for k in range(num_samples):
        gs_attr = utils.process_attr(
            mask_01[0][k].unsqueeze(0), top_kp=top_kp, norm_type=norm_type
        )
        suff, necc, l0 = utils.compute_metrics_batch(
            clf, x, gs_attr, y_hat, y_0, bkgd_batch, thresholds
        )
        suff_list.append(suff)
        necc_list.append(necc)
        l0_list.append(l0)
    return {
        "suff": np.array(suff_list).mean(axis=0),
        "necc": np.array(necc_list).mean(axis=0),
        "l0": np.array(l0_list).mean(axis=0),
    }


if __name__ == "__main__":
    # NOTE(review): the script previously also set y_0 = 0.647 here, but that
    # value was always overwritten with 0 for every image before use, so 0 is
    # what was actually evaluated. Confirm 0 is the intended target.
    y_0 = 0
    init_mask = "rand"

    # Regularisation multipliers passed positionally to the explainers.
    sp_mult_1 = 3
    sm_mult_1 = 20

    sp_mult_2 = 3
    sm_mult_2 = 20
    sh_mult = 0

    top_kp = 1
    N = len(test_idx)
    thresholds = np.linspace(0, 1, 101)
    results_suff = {}
    results_nec = {}
    results_uni = {}

    # exp_type -> (explainer, sp multiplier, sm multiplier, result dict).
    # The unified explainer intentionally reuses the necessity multipliers.
    configs = {
        "sufficiency": (suff_explainer, sp_mult_1, sm_mult_1, results_suff),
        "necessity": (necc_explainer, sp_mult_2, sm_mult_2, results_nec),
        "unified": (uni_explainer, sp_mult_2, sm_mult_2, results_uni),
    }

    # Create the output directory up front so a missing directory fails fast
    # rather than after the whole (long) evaluation loop.
    os.makedirs(results_dir, exist_ok=True)

    # The background batch is the same for every image; build it once.
    bkgd_batch = bkgd.unsqueeze(0).to(device)

    for i, idx in enumerate(test_idx):
        _, x, _ = val_data[idx]
        x = x.to(device).unsqueeze(0)
        # Only the scalar prediction is needed here; skip autograd bookkeeping.
        with torch.no_grad():
            f_x = clf(x).item()
        y_hat = (f_x >= 0.5) * 1.0
        # Iterate in exp_types order so explainer calls happen in the same
        # sequence as before (matters for random mask initialisation).
        for exp_type in exp_types:
            explainer, sp_mult, sm_mult, results = configs[exp_type]
            results[i] = _evaluate_explainer(
                explainer, x, y_hat, y_0, sp_mult, sm_mult, sh_mult,
                init_mask, top_kp, thresholds, bkgd_batch,
            )
        print(f"Progress: {i + 1}/{N} complete", end="\r")
    print()
    torch.save(results_suff, os.path.join(results_dir, norm_type + "_" + "suff_results_final.pt"))
    torch.save(results_nec, os.path.join(results_dir, norm_type + "_" + "nec_results_final.pt"))
    torch.save(results_uni, os.path.join(results_dir, norm_type + "_" + "uni_results_final.pt"))