#%%
import sys
import os
sys.path.append(os.path.abspath(os.path.join(__file__, "../src")))
sys.path.append(os.path.abspath(os.path.join(__file__, "../src/record")))
import argparse
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

import torch
from torch import Tensor

from train import give_config, initiate
import logger
import tbwriter

# Command-line options handed to give_config() below.  Each entry is
# (flag, keyword arguments for ArgumentParser.add_argument).
parser = argparse.ArgumentParser("bort-infer")

_ARGUMENTS = (
    ("--debug", dict(action="store_true")),
    ("--device", dict(type=int, nargs="+", default=[0])),
    ("--save_path", dict(type=str, default=None)),
    # dataset
    ("--dataset", dict(type=str, default="mnist",
                       help="cifar10 / mnist")),
    # model
    ("--model", dict(type=str, default="simple",
                     help="simple / lenetfc / lenet / lenetc / nomaxnetfc / resnet50(nn.Conv2d) / resnet18(nn.Conv2d)")),
    ("--recon_ratio", dict(type=float, default=0.95)),
    ("--setting", dict(type=str, default=None,
                       help="It only applies to AllConv12 model")),
    ("--act_type", dict(type=str, default="guided",
                        help="It only applies to AllConv12 model: guided / leaky")),
    # optimizer
    ("--optim", dict(type=str, default="bort",
                     help="sgd / adamw / bort / abort")),
    ("--lr", dict(type=float, default=0.001)),
    ("--wc", dict(type=float, default=0.001)),
    ("--wd", dict(type=float, default=0.01)),
    ("--scheduler", dict(type=str, default=None, help="cosine / none")),
    ("--warmup_lr", dict(type=float, default=1e-6)),
    ("--min_lr", dict(type=float, default=1e-7)),
    ("--warmup_epochs", dict(type=int, default=5)),
    # training
    ("--bs", dict(type=int, default=96)),
    ("--epochs", dict(type=int, default=20)),
    # inference
    ("--resume", dict(type=str, default=None,
                      help="path to the resume model checkpoint")),
    ("--trace_layer", dict(type=str, default="gpool")),
    ("--nvis", dict(type=int, default=1)),
)
for _flag, _kwargs in _ARGUMENTS:
    parser.add_argument(_flag, **_kwargs)

# An empty argument list is parsed on purpose: sys.argv is ignored and every
# option takes its default value.
opt = parser.parse_args(args=[])
config, args_dict = give_config(opt, "infer-")
# Replace the list of device indices by a torch.device for the first index.
_first_gpu = config.device[0]
config.device = torch.device(f"cuda:{_first_gpu}")

### NOTE: select settings here first
suffix = "sgd"
# suffix = "bort"

# dataset = "mnist"; gamma = 3; scale = 20; idx_list = [4, 7, 11, 29, 53, 67]
# dataset = "cifar10"; gamma = 2; scale = 10; idx_list = [4, 7, 11, 29, 53, 67]
dataset = "imagenet"
gamma = 1
scale = 1
idx_list = [709, 6, 975, 360, 72, 964, 818, 752, 800, 96, 612, 162]

# Checkpoint path to resume from for every supported (suffix, dataset) pair.
_RESUME_PATHS = {
    ("sgd", "mnist"): "./log/mnist-simple-sgd",
    ("sgd", "cifar10"): "./log/cifar10-simple-sgd",
    ("sgd", "imagenet"): "./log/allconv12-base-guided-adamw",
    ("bort", "mnist"): "./log/mnist-simple-bort",
    ("bort", "cifar10"): "./log/cifar10-simple-bort",
    ("bort", "imagenet"): "./log/allconv12-base-guided-abort",
}

# Combinations missing from the table leave opt.resume at its current value.
if (suffix, dataset) in _RESUME_PATHS:
    opt.resume = _RESUME_PATHS[(suffix, dataset)]

def normalize(input: "np.ndarray | Tensor", gamma=1, scale=1):
    """Gamma-correct an image and min-max rescale it for display.

    The values are first raised element-wise to the power ``gamma``.  A
    three-dimensional input whose last axis has length 3 is then treated as
    a channel-last colour image: each channel is rescaled to ``[0, 1]`` on
    its own and ``scale`` is *not* applied.  Any other input is rescaled as
    a whole to ``[0, scale]``.

    Args:
        input: Image values, e.g. a NumPy array of shape ``(H, W)`` or
            ``(H, W, 3)``.  The argument itself is never modified.
        gamma: Exponent applied before rescaling.
        scale: Factor applied after min-max rescaling on the non-colour
            path.

    Returns:
        A new array holding the rescaled values.  A constant image (or a
        constant colour channel) maps to zeros instead of NaN.
    """
    input = np.power(input, gamma)

    # Require three dimensions so that a 2-D image which merely happens to be
    # 3 pixels wide is not mistaken for a colour image.
    if input.ndim == 3 and input.shape[-1] == 3:
        # Stack freshly computed channels instead of writing them back in
        # place; in-place assignment would truncate the float results when
        # the data has an integer dtype.
        channels = [normalize(input[..., c]) for c in range(3)]
        return np.stack(channels, axis=-1)

    lowest = input.min()
    span = input.max() - lowest
    shifted = input - lowest
    if span == 0:
        # Constant image: there is no range to stretch, and dividing by the
        # zero span would fill the result with NaN.
        return shifted * 0.0
    return shifted / span * scale

def show_image(images, size=(5,5), title="", ax=None, gamma=gamma, scale=scale):
    """Draw a batch of images side by side and save the figure as a PNG.

    The figure is written to
    ``./results/reconstruction/{dataset}/{title}.png``; the directory is
    created when it does not exist yet.

    Args:
        images: Torch tensor whose first axis indexes the images.  A 3-D
            batch, or a 4-D batch whose second axis has length 1, is drawn
            with a gray colormap.  3-D per-image tensors are permuted with
            ``(1, 2, 0)`` before drawing (presumably channel-first input).
        size: ``(width, height)`` of one panel, used only when ``ax`` is
            None and a new figure has to be created.
        title: File name stem of the saved image.
        ax: Existing axes to draw into: a single Axes for a batch of one,
            otherwise an indexable sequence of Axes.  A new one-row figure
            is created when omitted.
        gamma: Forwarded to ``normalize``.
        scale: Forwarded to ``normalize``.
    """
    bs = images.shape[0]
    cmap = None
    if images.ndim == 3 or (images.ndim == 4 and images.shape[1] == 1):
        cmap = "gray"

    if ax is None:
        _, ax = plt.subplots(1, bs, figsize=(bs * size[0], size[1]))
    # plt.subplots returns a bare Axes for one panel and an array otherwise.
    axes = [ax] if bs == 1 else [ax[i] for i in range(bs)]

    for i, panel in enumerate(axes):
        tmp = images[i]
        if tmp.ndim == 3:
            tmp = tmp.permute(1, 2, 0)
        if bs == 1:
            if dataset == "cifar10":
                tmp = 0.5 * tmp + 0.5
            # Clamp negative values to zero before the gamma correction.
            tmp = torch.maximum(tmp, torch.zeros_like(tmp))
            tmp = tmp.squeeze().detach().cpu().numpy()
            tmp = normalize(tmp, gamma=gamma, scale=scale)
        else:
            # normalize() works on NumPy arrays, so move the data to host
            # memory first; calling it on a GPU tensor fails.
            tmp = tmp.detach().cpu().numpy()
            tmp = normalize(tmp, gamma=gamma, scale=scale)
            tmp = tmp - tmp.min()
            tmp = tmp.squeeze()
        panel.imshow(tmp, cmap=cmap, vmin=0.0, vmax=1.0)
        # Hide the frame of every panel, not only of pyplot's current axes.
        panel.axis("off")

    # Use the figure that owns the axes so that caller-supplied axes are
    # saved correctly even when another figure is current.
    fig = axes[0].figure
    fig.tight_layout()
    out_dir = f"./results/reconstruction/{dataset}"
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f"{out_dir}/{title}.png", dpi=500)

# Hand the selected checkpoint path to the config, then let initiate() build
# the run objects with `is_dist` and `save_path` overridden.
config.resume = opt.resume
_overrides = dict(is_dist=False, save_path="./log/infer-resume-save-cache")
objects = initiate(config, force_replace=_overrides)

# Only the model, the test loader and the device are needed below.
model, loader, device = (
    objects["model"],
    objects["loaders"]["test"],
    objects["device"],
)

def gen_img(idx):
    """Save the original image and its traced-back reconstruction.

    Sample ``idx`` of the test dataset is drawn as ``{idx}-ori``.  The model
    is then run once with forward hooks on its child modules, the input of
    the selected activation layer is passed to ``model.trace_back`` and the
    result is drawn as ``{idx}-{suffix}``.

    Args:
        idx: Index into ``loader.dataset``.

    Raises:
        KeyError: If the module-level ``suffix`` is neither ``"bort"`` nor
            ``"sgd"``.
    """
    # Fail fast on a bad optimizer tag instead of after the forward pass and
    # the trace-back have already been paid for.
    if suffix not in ("bort", "sgd"):
        raise KeyError(f"Invalid suffix: {suffix}")

    data, _ = loader.dataset[idx]
    data = data.unsqueeze(0).to(device)

    # draw the original image
    show_image(data, title=f"{idx}-ori")

    # register forward hook
    def register_hook(model, data):
        """Run one forward pass while recording each hooked module's I/O."""
        handle_dict = {}
        infeat_dict = {}
        outfeat_dict = {}

        def get_hook(name):
            def _hook(module, inputs, outputs):
                infeat_dict[name] = inputs
                outfeat_dict[name] = outputs
            return _hook

        try:
            # NOTE(review): `model` is the network object passed in, not a
            # name string, so this comparison looks like it can never be
            # true - confirm what was meant before relying on this branch.
            if dataset == "imagenet" and model == "vgg16":
                for k, v in model.features.named_children():
                    handle_dict[k] = v.register_forward_hook(get_hook(k))
                handle_dict["31"] = model.avgpool.register_forward_hook(get_hook("31"))
                handle_dict["fc"] = model.classifier.register_forward_hook(get_hook("fc"))
            else:
                for k, v in model.named_children():
                    handle_dict[k] = v.register_forward_hook(get_hook(k))

            # model forward
            model.eval()
            output = model(data)
        finally:
            # Always detach the hooks, even when the forward pass raises, so
            # that they do not pile up on the model across calls.
            for handle in handle_dict.values():
                handle.remove()
        return infeat_dict, outfeat_dict, output

    if dataset == "imagenet":
        l_n = 8
        k = f"act{l_n}"
        to_vis = f"conv{l_n}"
    else:
        k = "act3"
        to_vis = "conv3"

    model.trace(True)
    try:
        infeat_dict, _, _ = register_hook(model, data)

        # A forward hook receives the module inputs as a tuple; take the
        # first positional input of the selected layer.
        hidden_feature = infeat_dict[k][0]
        x = hidden_feature.clone()

        # model traceback
        # NOTE: trace back
        recon_data, _ = model.trace_back(
            x=x,
            module_name=k, return_all=True
        )
    finally:
        # Switch tracing off again even if the forward pass or the
        # trace-back fails.
        model.trace(False)

    print(f"optim: {suffix}, vis layer: {to_vis}")
    print(f"recon max value: {recon_data.max()}")
    print(f"recon min value: {recon_data.min()}")

    show_image(recon_data, title=f"{idx}-{suffix}")

# Generate the original and reconstructed image for every selected sample.
for idx in idx_list:
    gen_img(idx)
#%%
