import os
import argparse
import time
import os.path as osp
import sys
import torch.nn.functional as F
import shutil
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torchvision
from tensorboardX import SummaryWriter
from torchvision import transforms
from termcolor import cprint
from lib import dataloader
from model import resnet_snl_test, preresnet_snl
import cv2
from PIL import Image, ImageFile
from utils.loggers import Logger
from label import get_label
import matplotlib.pyplot as plt

def visualize_feature(feature, input):
    """Overlay a channel-summed feature heatmap on the first image of a batch.

    Only sample 0 of the batch is visualised.

    Args:
        feature: 4-D tensor indexed as (batch, channel, height, width). The
            channel axis is summed to obtain a single 2-D activation map.
        input: 4-D image tensor indexed as (batch, 3, height, width) that was
            normalised per channel with the mean/std constants below
            (the standard ImageNet statistics).

    Returns:
        A float64 NumPy array of shape (height, width, 3) with values in
        [0, 255]: 80% JET-coloured heatmap blended with 20% of the
        de-normalised image.
    """
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    # Work on a detached copy so the caller's tensor is not modified and
    # tensors that require grad can still be converted to NumPy.
    cur_img = input[0, :, :, :].detach().clone()
    for k in range(3):
        cur_img[k, :, :] = cur_img[k, :, :] * norm_std[k] + norm_mean[k]
    cur_img = cur_img.permute(1, 2, 0).cpu().numpy() * 255.0
    # Saturate instead of letting out-of-range pixels wrap around in uint8.
    cur_img = np.clip(cur_img, 0, 255).astype(np.uint8)

    cur_featuremap = feature.sum(dim=1)[0, :, :].detach().cpu().numpy()
    # cv2.resize takes the target size as (width, height).
    heatmap = cv2.resize(cur_featuremap, (cur_img.shape[1], cur_img.shape[0]))

    # Min-max normalise to [0, 1]. A constant (or non-finite) map has no
    # usable range; render it as all zeros rather than dividing by zero.
    amin, amax = heatmap.min(), heatmap.max()
    value_range = amax - amin
    if np.isfinite(value_range) and value_range > 0:
        heatmap = (heatmap - amin) / value_range
    else:
        heatmap = np.zeros_like(heatmap)
    heatmap = np.uint8(heatmap * 255.0)

    # NOTE(review): cv2.applyColorMap returns channels in BGR order, while
    # cur_img keeps the channel order of the input tensor (presumably RGB).
    # Confirm which order the code that saves/displays the result expects.
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

    # uint8 * float promotes to float64, so the blend cannot overflow.
    superimposed_img = heatmap * 0.8 + cur_img * 0.2
    return superimposed_img

def visualize_attention(feature, input, tmp):
    """Render one attention overlay per query position for the first sample.

    For every index ``n`` along dimension 1 of ``feature``, the row
    ``feature[0, n, :]`` is reshaped to the spatial grid of ``tmp``, resized
    to the image resolution, colour-mapped and blended with the
    de-normalised image. The area covered by position ``n`` itself (after
    resizing a one-hot mask) is painted with the solid colour (255, 0, 0).

    Args:
        feature: 3-D tensor indexed as (batch, N, M). M must equal
            ``tmp.size(2) * tmp.size(3)`` for the reshape to succeed, and
            every ``n < N`` must be a valid index into a row of length M.
        input: 4-D image tensor indexed as (batch, 3, height, width) that was
            normalised per channel with the mean/std constants below
            (the standard ImageNet statistics).
        tmp: tensor whose dimensions 2 and 3 give the height and width of
            the grid the attention rows are laid out on.

    Returns:
        dict mapping ``n`` to a float64 NumPy array of shape
        (height, width, 3): 0.3 * heatmap + 0.6 * image, with the query
        position highlighted.
    """
    N = feature.size(1)
    d_h = tmp.size(2)
    d_w = tmp.size(3)
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    # Work on a detached copy so the caller's tensor is not modified and
    # tensors that require grad can still be converted to NumPy.
    cur_img = input[0, :, :, :].detach().clone()
    for k in range(3):
        cur_img[k, :, :] = cur_img[k, :, :] * norm_std[k] + norm_mean[k]
    cur_img = cur_img.permute(1, 2, 0).cpu().numpy() * 255.0
    # Saturate instead of letting out-of-range pixels wrap around in uint8.
    cur_img = np.clip(cur_img, 0, 255).astype(np.uint8)

    # cv2.resize takes the target size as (width, height).
    out_size = (cur_img.shape[1], cur_img.shape[0])
    # Loop-invariant image contribution; uint8 * float promotes to float64.
    image_part = cur_img * 0.6
    # Move the attention rows of sample 0 to host memory once, not per row.
    rows = feature[0, :, :].detach().cpu().numpy()

    attention_map = {}
    for n in range(N):
        row = rows[n]

        # One-hot marker of the query position on the (d_h, d_w) grid.
        # Resizing spreads it over the matching image region, so every
        # non-zero pixel of the resized mask belongs to the highlight.
        mask = np.zeros_like(row)
        mask[n] = 1
        mask = cv2.resize(mask.reshape([d_h, d_w]), out_size)

        heatmap = cv2.resize(row.reshape([d_h, d_w]), out_size)

        # Min-max normalise to [0, 1]. A constant (or non-finite) row has no
        # usable range; render it as all zeros rather than dividing by zero.
        amin, amax = heatmap.min(), heatmap.max()
        value_range = amax - amin
        if np.isfinite(value_range) and value_range > 0:
            heatmap = (heatmap - amin) / value_range
        else:
            heatmap = np.zeros_like(heatmap)
        heatmap = np.uint8(heatmap * 255.0)

        # NOTE(review): cv2.applyColorMap returns channels in BGR order,
        # while cur_img keeps the channel order of the input tensor
        # (presumably RGB). Confirm which order the consumer expects.
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

        superimposed_img = heatmap * 0.3 + image_part

        # Paint the query location: channel 0 = 255, channels 1 and 2 = 0
        # (red when the array is interpreted as RGB).
        superimposed_img[mask != 0] = (255, 0, 0)

        attention_map[n] = superimposed_img
    return attention_map
