#%%
import sys
import os
sys.path.append(os.path.abspath(os.path.join(__file__, "../src")))
sys.path.append(os.path.abspath(os.path.join(__file__, "../src/record")))
import numpy as np
import h5py
import time
import yaml
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib as mpl
from PIL import Image

# Draw figure text, tick marks and axis labels in black ("k").
for _rc_key in ("text.color", "xtick.color", "ytick.color", "axes.labelcolor"):
    mpl.rcParams[_rc_key] = "k"

# Project-local helpers. Presumably resolved through the sys.path entries
# appended at the top of this file -- TODO confirm against the package layout.
# NOTE: the imported `show_samples` is rebound by the function of the same
# name defined later in this file, so the imported version is never called here.
from benchmark.misc import load_data_model, random_seed, show_samples, absjoin
# Seed with a fixed value (42) before the remaining project imports.
# NOTE(review): whether the seed must be set before these imports is not
# visible from this file -- keep the order unless confirmed otherwise.
random_seed(42)
from benchmark.methods import give_method
from bort.datasets import _get_transform

class CONFIG:
    """Settings namespace handed to the project helpers used by this script.

    The class is never instantiated: attributes are read and written directly
    on the class object. Attributes initialised to ``...`` are placeholders
    that the script overwrites before use (from the preset selection and from
    the trained model's ``config.yaml``).
    """

    # NOTE: Model path
    # Directory of the trained model; the script reads "config.yaml" from it.
    _model_path_ = ...
    setting = "base"

    # Placeholders filled in by the script below.
    dataset = ...
    model = ...
    optim_name = ...
    act_type = ...

    # Explanation method name; also embedded in the output file names.
    xai_name = "tracetopk"
    # Layer name passed to the explanation method together with the model.
    layer_name = "act8"
    recon_ratio = 1.

    # Per-dataset data locations.
    # NOTE(review): presumably consumed by `load_data_model` -- confirm.
    _data_path_ = {
        "mnist": {
            "data": "~/datasets/mnist"
        },
        "cifar10": {
            "data": "~/datasets/cifar10"
        },
        "imagenet": {
            "data": absjoin(__file__, "../data/imagenet50.p"),
            "label": absjoin(__file__, "../data/imagenet50_labels.p"),
            "map": absjoin(__file__, "../data/imagenet_mapping_labels.p")
        }
    }

# Enable exactly one preset line. Each preset sets the dataset name, the sample
# indices to visualise, and the overlay opacities used when drawing
# (alpha1: input image, alpha2: saliency mask).
# dataset = "mnist"; idx_list = [943, 296, 847, 70, 281, 914, 394, 456, 71, 39, 336, 979]; alpha1 = 0.3; alpha2 = 0.8
# dataset = "cifar10"; idx_list = [943, 296, 847, 70, 281, 914, 394, 456, 71, 39, 336, 979]; alpha1 = 0.9; alpha2 = 0.5
dataset = "imagenet"; idx_list = [15, 27, 1, 8, 43, 37, 29, 49, 18, 41, 19, 46]; alpha1 = 0.9; alpha2 = 0.4

optim = "bort"
# optim = "sgd"

CONFIG.dataset = dataset

# (optimizer, dataset) -> (log directory of the trained model, channel count).
_MODEL_PRESETS = {
    ("bort", "mnist"): ("./log/mnist-simple-bort", 1),
    ("bort", "cifar10"): ("./log/cifar10-simple-bort", 3),
    ("bort", "imagenet"): ("./log/allconv12-base-guided-abort", 3),
    ("sgd", "mnist"): ("./log/mnist-simple-sgd", 1),
    ("sgd", "cifar10"): ("./log/cifar10-simple-sgd", 3),
    ("sgd", "imagenet"): ("./log/allconv12-base-guided-adamw", 3),
}

try:
    mpath, ch = _MODEL_PRESETS[(optim, dataset)]
except KeyError:
    # Fail here with a clear message instead of a NameError on `mpath` below.
    raise ValueError(
        f"Unsupported optim/dataset combination: {optim!r}/{dataset!r}"
    ) from None

if dataset == "imagenet":
    # The model name must be set before the transform is built from CONFIG.
    CONFIG.model = "allconv"
    transform = _get_transform(CONFIG, False)

CONFIG._model_path_ = mpath

# Read config file
# The context manager closes the file handle, which the bare open() leaked.
# NOTE(review): yaml.FullLoader is kept for compatibility with existing config
# files; it should only be used on trusted files. Consider yaml.safe_load if
# config.yaml can come from an untrusted source.
with open(os.path.join(CONFIG._model_path_, "config.yaml")) as config_file:
    config_dict = yaml.load(config_file, Loader=yaml.FullLoader)

# Values stored with the trained model override the placeholders on CONFIG.
# NOTE(review): the module-level `dataset` variable is not updated here, so
# the output paths keep using the preset name even if config.yaml differs.
CONFIG.dataset = config_dict["dataset"]
CONFIG.model = config_dict["model"]
CONFIG.optim_name = config_dict["optim"]
CONFIG.act_type = config_dict.get("act_type", "leaky")

# Print every public setting. Names with a leading or trailing underscore
# (paths, dunder attributes) are skipped.
for k in dir(CONFIG):
    if not k.startswith("_") and not k.endswith("_"):
        print(f"{k}: {getattr(CONFIG, k)}")

data, label, mapping, model = load_data_model(CONFIG)
# Bring 4-D image batches into channel-last layout when the last axis is not 3.
if data.ndim == 4 and data.shape[-1] != 3:
    data = data.transpose(0, 2, 3, 1)

def show_samples(config, data, label, mapping):
    """Save the original image of every sample listed in the global ``idx_list``.

    One 5x5-inch figure per index is written to
    ``./results/saliency_map/{dataset}/{idx}-ori.png``, where ``dataset`` is
    the module-level preset name. The output directory is created if missing.

    Args:
        config: Unused; kept for interface compatibility.
        data: Image array. 3-D data is drawn with a gray colormap. 4-D data is
            moved to channel-last layout when its last axis is not 3.
        label: Unused; kept for interface compatibility.
        mapping: Unused; kept for interface compatibility.

    Returns:
        The module-level ``idx_list`` that was visualised.

    Raises:
        TypeError: If ``data`` is neither 3-D nor 4-D (raised while drawing
            the first sample).
    """
    out_dir = f"./results/saliency_map/{dataset}"
    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    if data.ndim == 4:
        # Layout and value range are fixed once for the whole array. Doing
        # this per sample rescanned the full dataset on every iteration and
        # could rescale later samples more often than earlier ones.
        if data.shape[-1] != 3:
            data = data.transpose(0, 2, 3, 1)
        # NOTE(review): presumably maps [-1, 1]-normalised data to [0, 1].
        if dataset != "imagenet" and data.min() < -0.1:
            data = 0.5 * data + 0.5

    # NOTE(review): figures are left open after saving. That keeps them
    # visible in interactive cell runs but accumulates open figures; confirm
    # whether they should be closed after savefig.
    for idx in idx_list:
        _, ax = plt.subplots(1, 1, figsize=(5, 5))
        ax.axis("off")
        if data.ndim == 3:
            ax.imshow(data[idx, :], cmap="gray")
        elif data.ndim == 4:
            if dataset == "imagenet":
                # ImageNet samples are divided by 255 before display.
                ax.imshow(data[idx, :] / 255.)
            else:
                ax.imshow(data[idx, :])
        else:
            raise TypeError(f"Invalid ndim of images: {data.ndim}")
        plt.tight_layout()
        plt.savefig(f"{out_dir}/{idx}-ori.png")
    return idx_list

# Save the untouched samples first; the returned indices drive the loop below.
vis_idx_list = show_samples(CONFIG, data, label, mapping)

xai_obj = give_method(CONFIG)
layer_name = CONFIG.layer_name

# savefig does not create missing directories.
os.makedirs(f"./results/saliency_map/{dataset}", exist_ok=True)

for idx in vis_idx_list:
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    sub_label, sub_img = label[idx], data[idx]

    # Named `model_input` so the builtin `input` is not shadowed.
    # Channel-last RGB images are moved to channel-first layout.
    if sub_img.shape[-1] == 3:
        model_input = sub_img.transpose(2, 0, 1)
    else:
        model_input = sub_img

    if dataset == "imagenet":
        # The ImageNet transform is applied to a channel-last uint8 PIL image.
        model_input = Image.fromarray(model_input.astype("uint8").transpose(1, 2, 0))
        model_input = transform(model_input)

    # perf_counter is a monotonic clock intended for measuring elapsed time.
    start = time.perf_counter()
    output, mask, pred = xai_obj(model, model_input, layer_name)
    end = time.perf_counter()
    print(f"Running time: {end - start}s")

    # Overlay the saliency mask on the input image.
    # NOTE(review): `sub_img/255` assumes 0-255 pixel values. For datasets
    # whose data is already normalised, an RGB background may render almost
    # black -- confirm against the data range of mnist/cifar10.
    ax.imshow(sub_img/255, alpha=alpha1, cmap="gray")
    ax.imshow(mask, alpha=alpha2, cmap="jet")
    ax.axis("off")
    plt.tight_layout()
    plt.savefig(f"./results/saliency_map/{dataset}/{idx}-{optim}-{CONFIG.xai_name}.png")

print("done")

#%%