import argparse
import os, sys
import shutil
import time
import copy

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from torch.autograd import Variable
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models
import random
import numpy as np
from scipy.spatial import distance
from collections import OrderedDict
from tqdm import tqdm
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
import matplotlib.pyplot as plt
import cv2

from wrapper import Wrapper
import torchattacks

def _scale_to_uint8(values, subtract_min=True):
    """Linearly rescale a numeric array to uint8 values in [0, 255].

    With ``subtract_min=True`` the range [min, max] is mapped to [0, 255];
    otherwise [0, max] is (meant for non-negative inputs such as norms).
    A zero-width (or empty/NaN) range yields an all-zero array instead of
    dividing by zero, whose NaN result would have an undefined uint8 cast.
    """
    values = np.asarray(values)
    if values.size == 0:
        return np.zeros(values.shape, dtype=np.uint8)
    low = values.min() if subtract_min else 0
    span = values.max() - low
    if not span > 0:
        return np.zeros(values.shape, dtype=np.uint8)
    return ((values - low) / span * 255).astype(np.uint8)


def _channel_grid_shape(num_channels):
    """Return ``(rows, cols)`` with ``rows * cols == num_channels``, as square as possible.

    ``rows`` is the largest divisor of ``num_channels`` not exceeding its square
    root, which gives 64 -> (8, 8), 128 -> (8, 16), 256 -> (16, 16) and
    512 -> (16, 32), and degrades to (1, n) for prime channel counts.
    """
    rows = 1
    candidate = 1
    while candidate * candidate <= num_channels:
        if num_channels % candidate == 0:
            rows = candidate
        candidate += 1
    return rows, num_channels // rows


def _capture_conv_features(backbone, inputs, vis_indices, sample_idx):
    """Run ``backbone(inputs)`` once and return selected Conv2d outputs for one sample.

    Conv2d layers are numbered in ``backbone.modules()`` order. The result maps
    each index in ``vis_indices`` whose conv fired to a CPU copy of
    ``conv_output[sample_idx]``. Hooks are removed even if the forward raises,
    so they never stay attached to ``backbone``.
    """
    wanted = set(vis_indices)
    captured = {}
    handles = []

    def make_hook(conv_idx):
        def hook(module, feat_in, feat_out):
            # Index first so only one sample (not the whole batch) is copied to the CPU.
            captured[conv_idx] = feat_out.detach()[sample_idx].clone().cpu()
        return hook

    convs = [m for m in backbone.modules() if isinstance(m, nn.Conv2d)]
    try:
        for conv_idx, conv in enumerate(convs):
            if conv_idx in wanted:
                handles.append(conv.register_forward_hook(make_hook(conv_idx)))
        # Only activations are needed here; skip building an autograd graph.
        with torch.no_grad():
            backbone(inputs)
    finally:
        for handle in handles:
            handle.remove()
    return captured


def visualize_resnet50_feature_maps(model, batch, target, output_dir, model_name='original',
                                    sample_idx=20, vis_indices=(0, 1, 5, 8, 11, 24, 43)):
    """Dump benign feature maps and benign-vs-adversarial feature differences as PNGs.

    Selected Conv2d outputs of ``model.backbone`` (numbered in ``modules()``
    order) are captured for sample ``sample_idx`` of ``batch``. Each channel is
    min-max scaled and written to
    ``<output_dir>/vis/{model_name}_layer{idx}_channel{c}.png``.
    An MI-FGSM adversarial batch is then generated (eps=16/255, alpha=1.6/255,
    10 steps, decay=1). For each captured layer, the per-channel L2 norm of
    (benign - adversarial) features is divided by its maximum and laid out as
    a channel grid. Each channel is drawn as an h x w tile in
    ``{model_name}_layer{idx}_diff.png``.

    Args:
        model: a ``Wrapper`` exposing ``.backbone``; requires CUDA (inputs are
            moved with ``.cuda()``).
        batch: input images. NOTE(review): the attack is told the ImageNet
            mean/std via ``set_normalization_used``, so ``batch`` is presumably
            already normalized with those statistics — confirm with callers.
        target: labels for ``batch``, used by the attack.
        output_dir: directory under which ``vis/`` is created.
        model_name: filename prefix for the saved images.
        sample_idx: which sample of the batch to visualize (default 20, the
            previously hard-coded value).
        vis_indices: Conv2d indices to visualize. The defaults are presumably
            chosen for a ResNet-50 backbone.
    """
    assert isinstance(model, Wrapper)
    vis_dir = os.path.join(output_dir, "vis")
    os.makedirs(vis_dir, exist_ok=True)

    backbone = model.backbone
    batch = batch.cuda()
    target = target.cuda()

    benign_features = _capture_conv_features(backbone, batch, vis_indices, sample_idx)
    for conv_idx, feat in sorted(benign_features.items()):
        for c, channel in enumerate(feat.numpy()):
            save_path = os.path.join(vis_dir, f"{model_name}_layer{conv_idx}_channel{c}.png")
            cv2.imwrite(save_path, _scale_to_uint8(channel))

    # Hooks are detached at this point, so the attack's repeated forward passes
    # do not trigger any feature copies.
    attack = torchattacks.MIFGSM(backbone, eps=16/255, alpha=1.6/255, steps=10, decay=1)
    normalize_mean = [0.485, 0.456, 0.406]
    normalize_std = [0.229, 0.224, 0.225]
    attack.set_normalization_used(normalize_mean, normalize_std)
    adv_batch = attack(batch, target)

    adv_features = _capture_conv_features(backbone, adv_batch, vis_indices, sample_idx)
    for conv_idx, benign_feat in sorted(benign_features.items()):
        adv_feat = adv_features[conv_idx]
        num_channels, h, w = benign_feat.shape
        # One L2 norm per channel, scaled by the layer's largest difference.
        diff = torch.norm(benign_feat - adv_feat, dim=(1, 2)).numpy()
        diff_pixels = _scale_to_uint8(diff, subtract_min=False)
        diff_image = np.reshape(diff_pixels, _channel_grid_shape(num_channels))
        # Blow each channel cell up to the feature map's spatial size.
        diff_image = np.repeat(diff_image, h, axis=0)
        diff_image = np.repeat(diff_image, w, axis=1)
        save_path = os.path.join(vis_dir, f"{model_name}_layer{conv_idx}_diff.png")
        cv2.imwrite(save_path, diff_image)