#%%
import sys
import os
sys.path.append(os.path.abspath(os.path.join(__file__, "../src")))
sys.path.append(os.path.abspath(os.path.join(__file__, "../src/record")))
import argparse
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

import torch
from torch import Tensor

from train import give_config, initiate
import logger
import tbwriter

#%%
# Build the option parser. The parsed namespace is handed to train.give_config
# below, which turns it into the `config` object used by the rest of the script.
parser = argparse.ArgumentParser("bort-infer")
parser.add_argument("--debug", action="store_true")
# One or more integer device ids; only the first one is used further down.
parser.add_argument("--device", type=int, nargs="+", default=[0])
parser.add_argument("--save_path", type=str, default=None)
# dataset
parser.add_argument("--dataset", type=str, default="mnist", 
                    help="cifar10 / mnist")
# model
parser.add_argument("--model", type=str, default="simple", 
                    help="simple / lenetfc / lenet / lenetc / nomaxnetfc / resnet50(nn.Conv2d) / resnet18(nn.Conv2d)")
parser.add_argument("--recon_ratio", type=float, default=0.95)
parser.add_argument("--act_type", type=str, default="guided",
                        help="It only applies to AllConv12 model: guided / leaky")
# optimizer
parser.add_argument("--optim", type=str, default="bort",
                    help="sgd / adamw / bort / abort")
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--wc", type=float, default=0.001)
parser.add_argument("--wd", type=float, default=0.01)
parser.add_argument("--scheduler", type=str, default=None, help="cosine / none")
parser.add_argument("--warmup_lr", type=float, default=1e-6)
parser.add_argument("--min_lr", type=float, default=1e-7)
parser.add_argument("--warmup_epochs", type=int, default=5)
# training
parser.add_argument("--bs", type=int, default=96)
parser.add_argument("--epochs", type=int, default=20)

# inference
parser.add_argument("--resume", type=str, default=None, 
                    help="path to the resume model checkpoint")
parser.add_argument("--trace_layer", type=str, default="gpool")
parser.add_argument("--nvis", type=int, default=6)

# args=[] makes argparse ignore sys.argv, so every option keeps its default
# (presumably so the file can be run cell-by-cell in an interactive session).
opt = parser.parse_args(args=[])
# NOTE(review): give_config comes from train.py; the meaning of the "infer-"
# argument is not visible here - confirm against that module.
config, args_dict = give_config(opt, "infer-")
# Replace the list of device ids with a torch.device for the first id.
config.device = torch.device(f"cuda:{config.device[0]}")

# Optimizer whose checkpoint is inspected: "bort" (active) or "sgd".
suffix = "bort"

# Dataset of the checkpoint: "mnist" (active), "cifar10" or "imagenet".
dataset = "mnist"

# Checkpoint directory for every (optimizer, dataset) pair that has one.
# Any other combination leaves opt.resume at its current value.
_resume_paths = {
    ("sgd", "mnist"): "./log/mnist-simple-sgd",
    ("bort", "mnist"): "./log/mnist-simple-bort",
}
_resume_key = (suffix, dataset)
if _resume_key in _resume_paths:
    opt.resume = _resume_paths[_resume_key]

def normalize(input: Tensor):
    """Min-max scale an image to the range [0, 1].

    Args:
        input: A ``torch.Tensor`` or ``numpy.ndarray``. A 3-D input is treated
            as channel-last (H, W, C) and each channel is scaled on its own;
            an input of any other rank is scaled as a whole.

    Returns:
        A new tensor/array of the same kind and shape as ``input``. The
        argument itself is never modified. A constant input (max == min) is
        mapped to all zeros instead of NaN.
    """
    def _scale(values):
        lo = values.min()
        span = values.max() - lo
        if span == 0:
            # Flat input: dividing by 1 keeps the result at zero and avoids 0/0.
            span = 1
        return (values - lo) / span

    if input.ndim != 3:
        return _scale(input)

    # Scale every channel separately and rebuild a fresh channel-last image,
    # so the caller's data is not written to and any channel count works.
    channels = [_scale(input[..., c]) for c in range(input.shape[-1])]
    if isinstance(input, torch.Tensor):
        return torch.stack(channels, dim=-1)
    return np.stack(channels, axis=-1)

def show_image(images, size=(5,5), title="", ax=None, is_save=True):
    """Draw a batch of images side by side and optionally save the figure.

    Args:
        images: Tensor whose first dimension is the batch. 4-D input is read
            as (N, C, H, W); 3-D input as (N, H, W).
        size: (width, height) in inches of one panel when a new figure is made.
        title: Stem of the output file name under ``./results/decompose``.
        ax: Existing axes to draw on. A single ``Axes`` when the batch holds
            one image, otherwise an indexable collection with one entry per
            image. A new figure is created when this is ``None``.
        is_save: When true, write ``./results/decompose/{title}.png``.

    The input tensor is not modified.
    """
    bs = images.shape[0]
    # Channel-less (N, H, W) and single-channel (N, 1, H, W) batches are
    # drawn in grayscale; everything else uses matplotlib's default.
    cmap = None
    if images.ndim == 3 or (images.ndim == 4 and images.shape[1] == 1):
        cmap = "gray"

    if ax is None:
        _, ax = plt.subplots(1, bs, figsize=(bs*size[0], size[1]))
    axes = [ax] if bs == 1 else [ax[i] for i in range(bs)]

    for i, cur_ax in enumerate(axes):
        # Work on a detached copy so nothing below can write into `images`
        # and tensors that track gradients can still be converted to numpy.
        tmp = images[i].detach().clone()
        if tmp.ndim == 3:
            tmp = tmp.permute(1, 2, 0)
        if bs == 1:
            # Kept from the single-image path: undo the 0.5/0.5 scaling for
            # cifar10 and clip negative values before rescaling.
            if dataset == "cifar10":
                tmp = 0.5 * tmp + 0.5
            tmp = torch.maximum(tmp, torch.zeros_like(tmp))
        # Squeeze first so a single-channel image reaches normalize() as 2-D,
        # and normalize while the data is still a tensor.
        tmp = normalize(tmp.squeeze())
        cur_ax.imshow(tmp.cpu().numpy(), cmap=cmap, vmin=0.0, vmax=1.0)
        # Hide the frame of every panel, not only the current axes.
        cur_ax.axis("off")

    if is_save:
        # Save the figure that was drawn on, which is not necessarily the
        # pyplot "current" figure when the caller supplied `ax`.
        fig = axes[0].figure
        fig.tight_layout()
        os.makedirs("./results/decompose", exist_ok=True)
        fig.savefig(f"./results/decompose/{title}.png", dpi=500)

#%%
# Point the config at the checkpoint path chosen above, then let
# train.initiate build the objects used below (model, loaders, device).
config.resume = opt.resume
# NOTE(review): force_replace presumably overrides these config entries
# inside initiate - confirm in train.py.
objects = initiate(config, force_replace={
    "is_dist": False,
    "save_path": "./log/infer-resume-save-cache"
})
model = objects["model"]
loader = objects["loaders"]["test"]
device = objects["device"]

# Take the first batch of the test loader and keep at most its first 12 samples.
data, label = next(iter(loader))
data = data.to(device)[:12]
label = label[:12]

# register forward hook
def register_hook(model, data):
    """Run ``model`` on ``data`` once and record each hooked module's I/O.

    A forward hook is attached to every direct child of ``model`` (or, for
    the ImageNet VGG16 case, to the children of ``model.features`` plus
    ``model.avgpool`` and ``model.classifier``). The model is switched to
    eval mode, run once, and all hooks are removed again.

    Args:
        model: The network to inspect.
        data: Input batch passed to ``model``.

    Returns:
        ``(infeat_dict, outfeat_dict, output)``: per-module hook inputs (the
        tuple PyTorch passes to forward hooks), per-module outputs, and the
        model's final output. Both dicts are keyed by child-module name.
    """
    handle_dict = {}
    infeat_dict = {}
    outfeat_dict = {}

    def get_hook(name):
        def _hook(module, input, output):
            infeat_dict[name] = input
            outfeat_dict[name] = output
        return _hook

    # The previous test `model == "vgg16"` compared the module object with a
    # string and was therefore always False, so this branch never ran.
    # NOTE(review): assumes the architecture name is the `--model` option
    # (opt.model) - confirm it matches the model that was actually built.
    is_vgg16 = opt.model == "vgg16"

    try:
        if dataset == "imagenet" and is_vgg16:
            for k, v in model.features.named_children():
                handle_dict[k] = v.register_forward_hook(get_hook(k))
            handle_dict["31"] = model.avgpool.register_forward_hook(get_hook("31"))
            handle_dict["fc"] = model.classifier.register_forward_hook(get_hook("fc"))
        else:
            for k, v in model.named_children():
                handle_dict[k] = v.register_forward_hook(get_hook(k))

        # model forward
        model.eval()
        output = model(data)
    finally:
        # Always detach the hooks, even when the forward pass raises, so
        # repeated runs do not pile hooks up on the model.
        for handle in handle_dict.values():
            handle.remove()
    return infeat_dict, outfeat_dict, output

# Switch the model's trace mode on before the hooked forward pass.
# NOTE(review): what trace mode records is defined by the model class, which
# is not visible here.
model.trace(True)
infeat_dict, outfeat_dict, output = register_hook(model, data)

# k: name of the child module whose *input* features are traced back.
# to_vis: only used in the status print below.
k = "act3"
to_vis = "conv3"
# The settings below are not read by any live code visible in this file
# (`fmt` appears only in commented-out heatmap calls, and `idx` is
# overwritten by the decomposition loop further down).
k_list = ["act1", "conv2", "act2", "conv3", "act3"]
to_vis_list = ["conv1", "apool1", "conv2", "apool2", "conv3"]

idx = 2
c_idx = 1
figsize = 18
fmt = ".2f"
reserve_idx = 0
topk_n = 20

# Forward hooks receive the module inputs as a tuple; [0] is the first one.
hidden_feature = infeat_dict[k][0]
# Unmodified copy of the whole batch's features, reused by the adversarial
# section further down.
all_hidden_feature = hidden_feature.clone()
x = hidden_feature.clone()

# Trace the features at module `k` back through the model.
# NOTE(review): trace_back presumably maps them to input space; with
# return_all=True it returns two values - confirm their meaning in the model.
recon_data, recon_all = model.trace_back(
    x=x, 
    module_name=k, return_all=True
)
model.trace(False)
print(f"optim: {suffix}, vis layer: {to_vis}")
print(f"recon max value: {recon_data.max()}")
print(f"recon min value: {recon_data.min()}")

#%% 
# Decompose the top-layer features
# topk_num: how many channels the top-k selection below keeps.
topk_num = 64
# sample_idx: index within `data` of the sample that is decomposed.
sample_idx = 0
def filter_bound(max_idx, h, w):
    """Return a mask that is true for positions strictly inside an h x w grid.

    ``max_idx`` holds row-major flattened positions. An entry is false when
    its row is 0 or ``h-1`` or its column is 0 or ``w-1``.
    """
    row = max_idx // w
    col = max_idx % w
    inner_row = (row != 0) & (row != (h-1))
    inner_col = (col != 0) & (col != (w-1))
    return inner_row & inner_col

# Re-run the hooked forward pass on one sample only (a batch of size 1).
model.trace(True)
sub_data = data[sample_idx:(sample_idx + 1)].clone()
show_image(sub_data, title=f"{sample_idx}-ori")
infeat_dict, outfeat_dict, output = register_hook(model, sub_data)
# Input features of module "act3" for that sample.
hidden_feature = infeat_dict["act3"][0]

sample_x = hidden_feature.clone()
tmp_x = torch.zeros_like(sample_x)

# Trace the full (x10 scaled) feature map back and save the resulting image.
model.trace(True)
recon_img = model.trace_back(x=sample_x*10, module_name="act3")
show_image(recon_img, title=f"{sample_idx}-recon-{suffix}")
model.trace(False)

# edit tmp_x
# Mark position (3, 3) of the channel with the largest activation.
# The channel count is read from the tensor instead of being hard-coded
# (it used to be the literal 2592), matching the `ch_n` unpacked just below.
max_value, max_idx = sample_x.view(1, sample_x.size(1), -1).max(dim=-1)
best_value, best_idx = max_value.max(dim=-1)
tmp_x[0][best_idx[0]][3][3] = 2000

# edit sample_x
# Keep, per channel, only the spatial maximum, and only for the `topk_num`
# channels with the largest maxima that do not sit on the border.
bs, ch_n, h, w = sample_x.size()
sample_x = sample_x.view(bs, ch_n, -1)
chm_sample_x, max_idx = sample_x.max(dim=-1)
mask_bound = filter_bound(max_idx, h, w)
# Channels whose maximum lies on the border get a very low score so that
# top-k does not pick them.
chm_sample_x[~mask_bound] = - 100
chm_topk_v, chm_topk = chm_sample_x.topk(k=topk_num, dim=-1, largest=True)
# chm_topk = torch.randint(low=0, high=2592, size=(1, topk_num)).to(sample_x.device)
# mask_ch holds the channel maximum for selected channels and 0 elsewhere.
mask_ch = torch.zeros(bs, ch_n).to(sample_x.device)
# mask_ch.scatter_(1, chm_topk, 1)
mask_ch.scatter_(1, chm_topk, chm_topk_v)
mask_ch = mask_ch.view(bs, ch_n, 1)
# mask_x is 1 at each channel's argmax position and 0 elsewhere.
mask_x = torch.zeros_like(sample_x)
mask_x.scatter_(2, max_idx.unsqueeze(-1), 1)
sample_x = sample_x * mask_x * mask_ch
sample_x = sample_x.view(bs, ch_n, h, w) * 50

# Trace the sparsified feature map back and save the image.
model.trace(True)
recon_img = model.trace_back(x=sample_x, module_name="act3")
show_image(recon_img, title=f"{sample_idx}-topk64-{suffix}")
model.trace(False)

# Trace an all-zero feature map back through the model. The result is only
# referenced by commented-out code in this file.
model.trace(True)
recon_zeros = model.trace_back(x=torch.zeros_like(sample_x), module_name="act3")
# show_image(recon_zeros)
model.trace(False)

# heat = (recon_img - recon_zeros).abs().sum(dim=1).squeeze()
# fig, ax = plt.subplots(figsize=(10, 10))
# sns.heatmap(heat.detach().cpu().numpy(), fmt=fmt, linewidths=0.5, ax=ax)

# draw the decomposition parts
# For each of the first 64 selected channels, trace back a feature map that
# is zero everywhere except at that channel's argmax position, and show the
# 64 results in an 8 x 8 grid.
recon_img_list = []
fig, ax = plt.subplots(8, 8, figsize=(20, 20))
for i in range(8):
    for j in range(8):
        idx = 8*i + j
        tmp_x = torch.zeros_like(sample_x).view(bs, ch_n, -1)
        cur_chidx = chm_topk.flatten()[idx]
        cur_spidx = max_idx.flatten()[cur_chidx]
        tmp_x[0][cur_chidx][cur_spidx] = 3000
        model.trace(True)
        recon_img = model.trace_back(x=tmp_x.view(bs, ch_n, h, w), module_name="act3")
        recon_img_list.append(recon_img)
        # show_image(recon_img, ax=ax[i][j], is_save=False)
        ax[i][j].imshow(normalize(recon_img.cpu().detach().numpy()).squeeze(), cmap="gray")
        ax[i][j].axis("off")
        model.trace(False)
recon_img_list = torch.cat(recon_img_list, dim=0)
# Detach before converting, as every other numpy conversion of trace_back
# output in this file does; .numpy() raises on tensors that require grad.
recon_img_list = recon_img_list.detach().cpu().numpy()
# One flattened row per reconstruction; the row count follows the data
# instead of repeating the grid size as a literal.
recon_img_list = recon_img_list.reshape(recon_img_list.shape[0], -1)
plt.savefig(f"./results/decompose/all_decom-{suffix}.png", dpi=500)

# img_show_tmp = np.matmul(chm_topk_v.detach().cpu().numpy(), recon_img_list)
# img_show_tmp = img_show_tmp.reshape(32, 32)
# plt.figure()
# plt.imshow(img_show_tmp)

# Cluster the flattened single-position reconstructions with k-means, then
# trace back one combined feature map per cluster and plot it.
from sklearn.cluster import KMeans

n_clusters = 8
n_rows, n_cols = 4, 2
assert n_rows * n_cols == n_clusters
kmeans = KMeans(n_clusters=n_clusters)
kmeans.fit(recon_img_list)
# cluster id -> np.where(...) result for the reconstructions in that cluster
indices_dict = {i: np.where(kmeans.labels_ == i) for i in range(n_clusters)}

fig, ax = plt.subplots(n_rows, n_cols, figsize=(n_cols*5, n_rows*5))
flat_topk = chm_topk.flatten()
flat_argmax = max_idx.flatten()
for i in range(n_clusters):
    indices = indices_dict[i][0]
    # Activate the argmax position of every channel assigned to this cluster.
    tmp_x = torch.zeros_like(sample_x).view(bs, ch_n, -1)
    for j in indices:
        cur_chidx = flat_topk[j]
        cur_spidx = flat_argmax[cur_chidx]
        tmp_x[0, cur_chidx, cur_spidx] = 1000
    model.trace(True)
    recon_img = model.trace_back(x=tmp_x.view(bs, ch_n, h, w), module_name="act3")
    grid_row, grid_col = divmod(i, n_cols)
    cur_ax = ax[grid_row][grid_col]
    cluster_img = recon_img.squeeze().detach().cpu().numpy()
    cur_ax.imshow(normalize(cluster_img), cmap="gray")
    cur_ax.set_title(f"Count: {len(indices)}", fontsize=30)
    cur_ax.axis("off")
    model.trace(False)
plt.savefig(f"./results/decompose/kmeans-{suffix}.png", dpi=500)

#%% Generate adversarial samples
# Batch indices (into the 12 samples kept earlier) of the target and source
# samples for the adversarial experiment.
trg_idx = 0
src_idx = 3
# Size-1 batch slices of the "act3" input features saved earlier.
trg_x = all_hidden_feature.clone()[trg_idx:(trg_idx+1)]
src_x = all_hidden_feature.clone()[src_idx:(src_idx+1)]

def get_stats(x):
    """Return per-channel mean and largest deviation above / below the mean.

    ``x`` is reshaped to ``(x.size(1), -1)``, so each row is treated as one
    channel. Returns ``(mean, pos, neg)``, each of shape ``(1, n_ch, 1, 1)``:
    ``pos`` is the largest ``x - mean`` and ``neg`` the largest ``mean - x``
    of the channel, both floored at zero.
    """
    n_ch = x.size(1)
    flat = x.reshape(n_ch, -1)
    mean = flat.mean(dim=-1)
    mean_col = mean.unsqueeze(1)
    zeros = torch.zeros_like(flat)
    pos, _ = torch.maximum(flat - mean_col, zeros).max(dim=-1)
    neg, _ = torch.maximum(mean_col - flat, zeros).max(dim=-1)
    out_shape = (1, n_ch, 1, 1)
    return mean.view(out_shape), pos.view(out_shape), neg.view(out_shape)

### Version 3
def transform(src_x, trg_x, topk_n=512):
    """Build a sparse feature map from ``src_x`` restricted to ``trg_x``'s top channels.

    The ``topk_n`` channels of ``trg_x`` with the largest spatial maxima are
    selected. In the result those channels hold the value 200 at the position
    where ``src_x`` has its spatial maximum; every other entry is zero.

    Shapes are taken from the arguments themselves rather than from module
    globals, so the function no longer depends on another cell having run.

    Args:
        src_x: Tensor of shape (N, C, H, W) providing the argmax positions.
        trg_x: Tensor with the same batch and channel sizes providing the
            channel selection.
        topk_n: Number of channels to keep; must not exceed C.

    Returns:
        Tensor of shape (N, C, H, W).
    """
    src_bs, src_ch, src_h, src_w = src_x.shape
    trg_flat = trg_x.reshape(trg_x.size(0), trg_x.size(1), -1)
    chm_idx = trg_flat.max(dim=-1)[0].topk(k=topk_n, largest=True)[1]

    src_flat = src_x.reshape(src_bs, src_ch, -1)
    _, max_i = src_flat.max(dim=-1, keepdim=True)
    # One-hot over spatial positions: 1 at each channel's argmax.
    out = torch.zeros_like(src_flat)
    out.scatter_(2, max_i, 1)
    # 1 for the selected channels, 0 for the rest.
    mask_ch = torch.zeros(src_bs, src_ch).to(src_x.device)
    mask_ch.scatter_(1, chm_idx, 1)
    out = out * mask_ch.unsqueeze(-1)
    out = out.view(src_bs, src_ch, src_h, src_w) * 200
    return out

def leave_max(x):
    """Keep only each channel's spatial maximum, scaled by 200.

    The shape is read from ``x`` itself rather than from module globals, so
    the function works for any (N, C, H, W) input.

    Args:
        x: Tensor of shape (N, C, H, W).

    Returns:
        Tensor of the same shape that is zero everywhere except at each
        channel's argmax position, where it holds ``200 * max``.
    """
    n, c, height, width = x.shape
    flat = x.reshape(n, c, -1)
    max_v, max_i = flat.max(dim=-1, keepdim=True)
    out = torch.zeros_like(flat)
    out.scatter_(2, max_i, max_v)
    out = out.view(n, c, height, width) * 200
    return out

def show_hist(x1, x2, idx_list):
    """Plot side-by-side value histograms of selected channels of two tensors.

    For every channel index in ``idx_list`` a new two-panel figure is made:
    the left panel shows ``x1[0][idx]``, the right panel ``x2[0][idx]``.

    Args:
        x1: Tensor indexed as ``x1[0][channel]``.
        x2: Tensor indexed the same way as ``x1``.
        idx_list: Channel indices to plot.
    """
    # Iterate over the indices themselves. The loop used to run over
    # range(len(idx_list)) and so always plotted channels 0..len-1,
    # whatever idx_list contained.
    for ch in idx_list:
        fig, ax = plt.subplots(1, 2, figsize=(12, 2))
        sns.histplot(x1[0][ch].flatten().detach().cpu(), ax=ax[0])
        sns.histplot(x2[0][ch].flatten().detach().cpu(), ax=ax[1])

def predict(model, data):
    """Classify ``data`` with trace mode switched off.

    Returns ``(predict, output)``: the index of the largest entry of the
    flattened model output, and that output reshaped to ``(1, 10)``.
    """
    model.trace(False)
    scores = model(data).flatten()
    _, best = scores.max(dim=-1)
    return best, scores.reshape(1, 10)

def gen_bar(data, title=""):
    """Save a bar chart of class scores to ``./results/decompose/{title}-bar.png``.

    Args:
        data: 1-D tensor of class scores; bar ``i`` is labelled ``"i"``.
        title: Stem of the output file name.

    Note: calls ``sns.set_theme``, which changes the global plotting style.
    """
    sns.set_theme(style="white", context="talk")

    # Set up the matplotlib figure
    fig, ax = plt.subplots(1, 1, figsize=(5, 3), dpi=500)

    y1 = data.detach().cpu().numpy()
    # One label per score ("0".."9" for ten classes) instead of a fixed
    # ten-character string, so the label count always matches the data.
    x = np.array([str(i) for i in range(len(y1))])
    sns.barplot(x=x, y=y1, palette="rocket", ax=ax)
    ax.axhline(0, color="k", clip_on=False)
    ax.set_ylabel("Class score", fontsize=20)
    ax.tick_params(labelsize=20)

    # Finalize the plot
    sns.despine(bottom=True)
    plt.tight_layout(h_pad=2)
    # Make sure the output directory exists before writing.
    os.makedirs("./results/decompose", exist_ok=True)
    plt.savefig(f"./results/decompose/{title}-bar.png", dpi=500)

def vis_prediction(x, title=""):
    """Trace features back to an image, classify it, and save both plots.

    ``x`` is passed to ``model.trace_back`` at module "act3". The resulting
    image is classified with ``predict`` and saved via ``show_image``, and
    the class scores are saved via ``gen_bar``. Both file names start with
    ``"src-x: {predicted class}-{title}"``.

    Args:
        x: Feature tensor accepted by ``model.trace_back`` for "act3".
        title: Suffix used in the output file names.
    """
    model.trace(True)
    try:
        recon_img = model.trace_back(x=x, module_name="act3")
        pred, output = predict(model, recon_img)
        show_image(recon_img, title=f"src-x: {pred.item()}-{title}")
        # _, ax = plt.subplots(1, 1, figsize=(10,1))
        # sns.heatmap(output.detach().cpu(), annot=True, linewidths=0.5, fmt=".2f", ax=ax)
        gen_bar(output.squeeze(), title=f"src-x: {pred.item()}-{title}")
    finally:
        # Leave trace mode off even if a step above fails, so a later cell
        # does not run the model in an unexpected mode.
        model.trace(False)

# original image: source features reduced to their per-channel maxima
vis_prediction(leave_max(src_x), "srcimg")

# target image: same reduction applied to the target features
vis_prediction(leave_max(trg_x), "trgimg")

# adversarial image: source argmax positions on the target's top channels
src_x_trans = transform(src_x, trg_x)
# Feed the per-channel maxima of the transformed features straight into
# model.fc and print the result.
trans_out = model.fc(src_x_trans.view(bs, ch_n, -1).max(dim=-1)[0])
print(trans_out)
vis_prediction(src_x_trans, "trasimg")

#%%

# Hand-built probe: a size-1 batch of zeros with three diagonal positions of
# channel 2 set to a large value, traced back from "act3".
tmp_x = torch.zeros_like(x)[0:1]
tmp_x[0][2][2][2] = 4000
tmp_x[0][2][3][3] = 4000
tmp_x[0][2][4][4] = 4000
model.trace(True)
recon_img = model.trace_back(
    x=tmp_x,
    module_name="act3"
)
# show_image is called with its default empty title, so the image is saved
# as "./results/decompose/.png".
show_image(recon_img)
# Separate figure showing the probed channel itself as a heatmap.
plt.figure()
sns.heatmap(tmp_x[0][2].cpu())
model.trace(False)