from vit_pytorch import SimpleViT
from mlp_mixer_pytorch import MLPMixer
from models.resnet import resnet50, CIFAR_ResNet18
from torchvision.models import efficientnet_v2_s, convnext_tiny
from dataset import load_dataset
from torchvision.utils import save_image
from captum.attr import (
    DeepLift,
    IntegratedGradients
)
import cv2
import torchvision.utils as vutils
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision.utils as vutils
import argparse
import numpy as np
import os
import random
import math
import time
# Print tensors in full instead of summarizing large ones with "...".
torch.set_printoptions(profile="full")


def save_cam_image(img, mask, filename):
    """Blend a JET-colored heatmap of ``mask`` over ``img`` and write it with OpenCV.

    Args:
        img: image array, presumably H x W x 3 (it must broadcast against the
            H x W x 3 colormap output); min-max rescaled to [0, 1] before blending.
        mask: saliency map expected in [0, 1]; out-of-range values are clipped
            so they saturate instead of wrapping around in the uint8 cast.
        filename: output path passed to ``cv2.imwrite`` (which treats 3-channel
            data as BGR).

    Raises:
        OSError: if ``cv2.imwrite`` reports that the file could not be written.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * np.clip(mask, 0, 1)), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    img = img - np.min(img)
    img_peak = np.max(img)
    # A constant image has zero range; dividing would produce 0/0 = NaN pixels.
    if img_peak > 0:
        img = img / img_peak
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    if not cv2.imwrite(filename, np.uint8(255 * cam)):
        raise OSError("cv2.imwrite failed to write %s" % filename)

def save_image2(img, filename):
    """Min-max rescale ``img`` to [0, 255] and write it with OpenCV.

    Args:
        img: H x W x C (or H x W) array. NOTE(review): ``cv2.imwrite`` treats
            3-channel data as BGR, so an RGB input is saved with red and blue
            swapped; confirm the channel order produced by the caller.
        filename: output path; the format is chosen by OpenCV from the extension.

    Raises:
        OSError: if ``cv2.imwrite`` reports that the file could not be written.
    """
    img = img - np.min(img)
    img_peak = np.max(img)
    # A constant image has zero range; dividing would produce 0/0 = NaN pixels.
    if img_peak > 0:
        img = img / img_peak
    if not cv2.imwrite(filename, np.uint8(255 * np.float32(img))):
        raise OSError("cv2.imwrite failed to write %s" % filename)

def _init_fn(worker_id):
    """DataLoader ``worker_init_fn`` that reseeds NumPy's global RNG per worker.

    Worker ``k`` gets seed ``args.seed + k``, so NumPy randomness inside the
    workers differs between workers but is reproducible for a fixed ``--seed``.
    ``args`` is the module-level namespace parsed below; it is looked up when a
    worker starts, not when this function is defined.
    """
    worker_seed = args.seed + worker_id
    np.random.seed(worker_seed)


parser = argparse.ArgumentParser()
# Required: f() compares each batch index against this value, so a missing
# value (previously defaulting to None) crashed with a TypeError mid-run.
parser.add_argument("--eval_bnum", type=int, required=True,
                    help="last batch index to visualize (batches 0..eval_bnum inclusive)")
parser.add_argument("--gpu", default=7, type=int)
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--name", default="c10_ss1_r18", type=str,
                    help="comma-separated run names; weights are read from ./results/<name>/best.pt")
parser.add_argument("--dataset", default="cifar10", type=str, help="imagenet | cifar10 | Caltech256")
parser.add_argument("--dataroot", default="./data", type=str, help="data dir")
# Choices mirror the architectures handled by load_models().
parser.add_argument("--model", default="CIFAR_ResNet18", type=str,
                    choices=["CIFAR_ResNet18", "resnet50", "eff2s", "convnext", "vit", "mlpm"],
                    help="CIFAR_ResNet18 | resnet50 | eff2s | convnext | vit | mlpm")
args = parser.parse_args()

# Seed every RNG in use and force deterministic cuDNN kernels for reproducibility.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Batch size 1: f() visualizes one image per batch index.
train_loader, test_loader, dim, hw = load_dataset(args.dataset, args.dataroot, bs=1, worker_init_fn=_init_fn, drop_last=False, subset_num=1, seed=args.seed)
device = torch.device("cuda:%d"%args.gpu)

remove = 0  # forwarded unchanged to the project ResNet constructors
in_channels = 3

def load_models(dim, remove, in_channels, name, device):
    """Build the architecture selected by ``args.model`` and load its checkpoint.

    The ViT / MLP-Mixer hyper-parameters depend on ``args.dataset``: Caltech256
    uses larger patches and wider layers than the other datasets.

    Args:
        dim: number of output classes.
        remove: forwarded to the project ResNet constructors.
        in_channels: number of input image channels.
        name: run name; weights are read from ``./results/<name>/best.pt``.
        device: ``map_location`` used when deserializing the checkpoint.

    Returns:
        The model with the checkpoint's ``"model"`` state dict loaded.

    Raises:
        NotImplementedError: if ``args.model`` is not a supported architecture.
    """
    # Only reads the module-level ``args`` and ``hw``, so no ``global`` is needed.
    if args.dataset == "Caltech256":
        patch_size, Dim, depth, heads, mlp_dim = 32, 1024, 6, 16, 2048
        m_patch_size, m_Dim, m_depth = 16, 512, 12
    else:
        patch_size, Dim, depth, heads, mlp_dim = 4, 512, 6, 8, 512
        m_patch_size, m_Dim, m_depth = 4, 512, 6
    if args.model == "CIFAR_ResNet18":
        model = CIFAR_ResNet18(pretrained=False, num_classes=dim, remove=remove, in_channels=in_channels, retain_grad=False)
    elif args.model == "resnet50":
        model = resnet50(pretrained=False, num_classes=dim, remove=remove, in_channels=in_channels, retain_grad=False)
    elif args.model == "eff2s":
        # torchvision >= 0.13 (required for efficientnet_v2_s) spells the
        # deprecated ``pretrained=False`` as ``weights=None``.
        model = efficientnet_v2_s(weights=None, num_classes=dim)
    elif args.model == "convnext":
        model = convnext_tiny(weights=None, num_classes=dim)
    elif args.model == "vit":
        model = SimpleViT(image_size=hw, patch_size=patch_size, num_classes=dim, dim=Dim, depth=depth, heads=heads, mlp_dim=mlp_dim, channels=in_channels)
    elif args.model == "mlpm":
        model = MLPMixer(image_size=hw, channels=in_channels, patch_size=m_patch_size, dim=m_Dim, depth=m_depth, num_classes=dim)
    else:
        raise NotImplementedError("unsupported --model %r" % args.model)
    checkpoint = torch.load("./results/%s/best.pt" % name, map_location=device)
    model.load_state_dict(checkpoint["model"])
    return model

names = args.name.split(",")
# One model per comma-separated run name, kept in the same order as ``names``.
models = [load_models(dim, remove, in_channels, run_name, device) for run_name in names]
# All visualizations go into the first run's directory.
os.makedirs("./ig_viz/%s/" % names[0], exist_ok=True)
print(args)
cpu = torch.device("cpu")

def _noisy_copies(images, ng, std, device):
    """Return a fresh leaf tensor (``requires_grad=True``) of attribution inputs.

    ``ng == 0`` gives a copy of ``images`` and ignores ``std``. Otherwise
    ``ng`` Gaussian-noised copies are stacked along dim 0 in the layout of
    ``images.repeat(ng, 1, 1, 1)``, so ``reshape(ng, bb, ...)`` regroups them.
    """
    if ng == 0:
        return images.clone().detach().requires_grad_(True)
    bb, cc, hh, ww = images.size()
    # Same NumPy draws, in the same order, as before, so the global np.random
    # stream seeded from --seed produces identical noise.
    noise = np.concatenate(
        [np.random.normal(0, std, (ng, cc, hh, ww)).astype(np.float32) for _ in range(bb)]
    )
    noisy = images.repeat(ng, 1, 1, 1) + torch.from_numpy(noise).to(device)
    return noisy.detach().requires_grad_(True)


def _reduce_attributions(attributions, ng, bb):
    """Collapse raw attributions to one CPU saliency map per image, shape (bb, 1, H, W).

    With ``ng > 0`` the noisy copies are averaged first. Then the absolute
    attribution is max-reduced over the channel axis.
    """
    if ng != 0:
        attributions = attributions.reshape(ng, bb, *attributions.shape[1:]).mean(0)
    saliency = attributions.abs().max(1).values.detach().cpu()
    # Keep an explicit channel axis: torchvision's save_image/make_grid reads a
    # 3-D (3, H, W) tensor as one RGB image rather than three grayscale maps.
    return saliency.unsqueeze(1)


def f(loader, device, test=False):
    """Save Integrated-Gradients and DeepLift saliency maps for the first batches.

    For batch indices ``0..args.eval_bnum`` (inclusive), writes the input image.
    Then, for every model in ``models`` and every (ng, std) pair, it writes an
    IG map and a DeepLift map into ``./ig_viz/<names[0]>/``.

    ``ng`` is the number of noisy copies averaged (0 = the clean input only) and
    ``std`` is the Gaussian noise standard deviation. Attributions target each
    input's own arg-max prediction, relative to an all-zeros baseline.

    Args:
        loader: iterable of ``(images, labels)`` batches.
        device: device that models and inputs are moved to during attribution.
        test: only selects the "test"/"train" tag and ``T`` flag in filenames.
    """
    NG = [0, 1, 5, 20, 50]
    # NOTE: std is ignored when ng == 0, so those files are identical per std.
    STDs = [0, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0]
    T = "test" if test else "train"
    out_dir = "./ig_viz/%s/" % names[0]
    for i, (images, _) in enumerate(loader):
        # NOTE(review): ``>`` processes eval_bnum + 1 batches; use ``>=`` if
        # eval_bnum is meant as a batch count.
        if i > args.eval_bnum:
            break
        print(i, args.eval_bnum)
        save_image2(np.transpose(images[0].detach().cpu().numpy(), (1, 2, 0)), "%s%s_%d.jpg" % (out_dir, T, i))
        images = images.to(device)
        bb = images.size(0)
        for kk, model in enumerate(models):
            model = model.eval().to(device)
            ig = IntegratedGradients(model)
            dl = DeepLift(model)
            for ng in NG:
                for std in STDs:
                    imgs = _noisy_copies(images, ng, std, device)
                    # The prediction needs no autograd graph.
                    with torch.no_grad():
                        class_idx = model(imgs).max(1).indices
                    tag = "T%d_l%d_M%d_ng%d_std%.1f.png" % (int(test), i, kk, ng, std)
                    ig_attr = ig.attribute(imgs, torch.zeros_like(imgs), target=class_idx, n_steps=20)
                    vutils.save_image(_reduce_attributions(ig_attr, ng, bb), out_dir + tag, normalize=True)
                    dl_attr = dl.attribute(imgs, torch.zeros_like(imgs), target=class_idx)
                    vutils.save_image(_reduce_attributions(dl_attr, ng, bb), out_dir + "dl_" + tag, normalize=True)
            # nn.Module.to is in place, so this also offloads the entry in ``models``.
            model = model.to(cpu)

# Visualize the training split first, then the test split.
for split_loader, is_test in ((train_loader, False), (test_loader, True)):
    f(split_loader, device, test=is_test)
