import os
import argparse
import time
import os.path as osp
import sys
import torch.nn.functional as F
import shutil
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torchvision

from tensorboardX import SummaryWriter
from torchvision import transforms
from termcolor import cprint
from lib import dataloader
from model import resnet_snl_test, preresnet_snl
import cv2
from PIL import Image, ImageFile

from utils.loggers import Logger

from label import get_label

# Report the torch version up front so environment mismatches are easy to spot.
cprint('=> Torch Version: ' + torch.__version__, 'green')

# Command-line arguments.
parser = argparse.ArgumentParser(
    description='PyTorch SNL attention / feature-map visualization')
parser.add_argument('--debug', '-d', dest='debug', action='store_true',
                    help='enable debug mode')
parser.add_argument('--num_gpu', default=1, type=int,
                    help='number of gpu')
parser.add_argument('--model-path', default='./imagenet_snl.tar', type=str,
                    help='the trained model path')
parser.add_argument('--img-dir', default='./img', type=str,
                    help='directory containing the input image')
parser.add_argument('--name', default='2_bluetick.JPEG', type=str,
                    help='input image file name inside --img-dir, e.g. '
                         '1_drake.JPEG, 2_bluetick.JPEG, 3_agaric.JPEG, '
                         '4_cabbage_butterfly.JPEG, 5_croquet ball.JPEG, '
                         '6_tick.JPEG, 7_loafer.JPEG')

# Module-level bookkeeping globals (kept for compatibility with other code).
best_prec1 = 0
best_prec5 = 0

args = parser.parse_args()

# Single-GPU runs are pinned to device 0.
if args.num_gpu == 1:
    torch.cuda.set_device(0)

# Fixed seeds for reproducible runs.
np.random.seed(67)
torch.manual_seed(67)


def main():
    """Run the SNL ResNet-50 on one image, save visualizations, and report top-1.

    Steps:
      1. Load ``<args.img_dir>/<args.name>`` as RGB and apply resize(256) /
         center-crop(224) / ImageNet normalization, adding a batch dimension.
      2. Build the model via ``resnet_snl_test.model_hub`` and load the
         ``'state_dict'`` entry of the checkpoint at ``args.model_path``.
      3. Run one forward pass on the GPU under ``torch.no_grad()``.
      4. Write the feature-map and attention-map overlays with
         ``visualize_feature`` / ``visualize_attention``, using the image file
         name without its extension as the output name prefix.
      5. Print the forward time, the top-1 label, and its softmax probability.

    The process exits with status 0 when done.
    """
    label = get_label()

    base_size = 256
    crop_size = 224

    trans = transforms.Compose([
        transforms.Resize(base_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        # ImageNet per-channel mean / std.
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225]),
    ])

    # ---- load and preprocess the input image ----
    print(args.name)
    img_name = os.path.join(args.img_dir, args.name)
    # Strip the extension whatever its length. The old ``[:-5]`` only worked
    # for 4-letter extensions such as ".JPEG" (e.g. "3.jpg" became "").
    save_name = os.path.splitext(args.name)[0]

    # Tolerate truncated image files instead of failing on decode.
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    img = Image.open(img_name).convert('RGB')
    img_torch = trans(img).unsqueeze(0)  # add a batch dimension of size 1

    # ---- build the model and restore weights ----
    model = resnet_snl_test.model_hub('50',
                                      pretrained=False,
                                      nl_type='snl',
                                      nl_nums=1,
                                      stage_num=1,
                                      pool_size=7, div=2, isrelu=False)
    model = model.cuda()

    print('loading checkpoint {}'.format(args.model_path))
    checkpoint = torch.load(args.model_path, map_location="cuda:0")
    model.load_state_dict(checkpoint['state_dict'], strict=True)
    model.eval()

    with torch.no_grad():
        inputs = img_torch.cuda()
        # CUDA work is asynchronous: synchronize around the forward pass so the
        # measured time covers the computation, not just kernel launches.
        torch.cuda.synchronize()
        start = time.time()
        output, _x_1, _x_2, att, _x_in, _x_out, final = model(inputs)
        torch.cuda.synchronize()
        end = time.time()

        visualize_feature(final, inputs, 0, save_name)
        visualize_attention(att, inputs, 0, final, save_name)
        print("Time:", end - start)

        # Top-1 class index for the single image in the batch.
        _, pred = output.topk(1, 1, True, True)
        cur_pred = int(pred[0, 0].item())
        print("Top1 Prediction:", label[cur_pred])

        probs = F.softmax(output, dim=1)
        cur_logit = probs[:, cur_pred].cpu().numpy()[0]
        print("Probability:", str(cur_logit * 100) + "%")
    sys.exit(0)


def visualize_attention(feature, input, num, tmp, val_type='attention'):
    """Save one attention overlay image per position ``n`` of ``feature.size(1)``.

    For each batch element ``b_i`` and each ``n``:

    - the row ``feature[b_i, n, :]`` is reshaped to the ``tmp.size(2) x
      tmp.size(3)`` grid;
    - it is upsampled to the input resolution and min-max normalized;
    - it is rendered with the JET colormap and blended (0.3 heatmap +
      0.6 image) over the de-normalized input image;
    - position ``n`` of the flattened grid is painted solid red.

    The result is written to ``result/attention_map/<val_type>_<n>.jpg``.

    Args:
        feature: tensor indexed as ``feature[b, n, :]``. Each row must have
            ``tmp.size(2) * tmp.size(3)`` elements, and ``n`` must index into
            that row. Presumably (B, HW, HW) — TODO confirm.
        input: 4-D image batch, normalized with the ImageNet mean/std used
            below (3 channels).
        num: unused; kept for call-site compatibility.
        tmp: tensor whose dimensions 2 and 3 give the attention grid size.
        val_type: file-name prefix for the saved images.

    Note: file names do not include the batch index, so for a batch larger
    than 1 later images overwrite earlier ones.
    """
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    d_h = tmp.size(2)
    d_w = tmp.size(3)
    out_dir = "result/attention_map"
    os.makedirs(out_dir, exist_ok=True)

    for b_i in range(input.size(0)):
        # Undo the normalization and convert to an H x W x 3 uint8 image.
        cur_img = input[b_i].clone()
        for k in range(3):
            cur_img[k] = cur_img[k] * norm_std[k] + norm_mean[k]
        cur_img = np.uint8(cur_img.permute(1, 2, 0).cpu().numpy() * 255.0)
        # The image is RGB (PIL order), but applyColorMap produces BGR and
        # cv2.imwrite expects BGR. Flip the image so both layers agree and the
        # photo keeps its real colors.
        cur_img = cur_img[:, :, ::-1]
        out_size = (cur_img.shape[1], cur_img.shape[0])  # cv2 wants (width, height)

        for n in range(feature.size(1)):
            cur_featuremap = feature[b_i, n, :].cpu().numpy().reshape([d_h, d_w])

            # One-hot mask at flattened position n, upsampled like the heatmap
            # (bilinear by default, so the marker covers a small neighborhood).
            mask = np.zeros_like(cur_featuremap)
            mask.reshape(-1)[n] = 1
            mask = cv2.resize(mask, out_size)

            heatmap = cv2.resize(cur_featuremap, out_size)
            amin, amax = heatmap.min(), heatmap.max()
            if amax > amin:
                heatmap = (heatmap - amin) / (amax - amin)  # min-max scale to [0, 1]
            else:
                # Constant map: avoid 0/0 = NaN, which has no defined uint8 value.
                heatmap = np.zeros_like(heatmap)
            heatmap = cv2.applyColorMap(np.uint8(heatmap * 255.0), cv2.COLORMAP_JET)

            superimposed_img = heatmap * 0.3 + cur_img * 0.6
            # Mark the query position in red (BGR order: 0, 0, 255).
            superimposed_img[mask != 0] = (0, 0, 255)
            cv2.imwrite(os.path.join(out_dir, val_type + "_" + str(n) + ".jpg"),
                        superimposed_img)


def visualize_feature(feature, input, num, val_type='query'):
    """Save a channel-summed feature-map overlay for each image in the batch.

    ``feature`` is summed over dimension 1, giving one 2-D map per batch
    element. Each map is then:

    - upsampled to the input resolution;
    - min-max normalized and rendered with the JET colormap;
    - blended (0.8 heatmap + 0.2 image) over the de-normalized input image.

    The result is written to ``result/featuremaps/<val_type>.jpg``.

    Args:
        feature: 4-D tensor whose dimension 1 is summed out, leaving
            ``feature[b]`` as a 2-D spatial map.
        input: 4-D image batch, normalized with the ImageNet mean/std used
            below (3 channels).
        num: unused; kept for call-site compatibility.
        val_type: file name (without extension) for the saved image.

    Note: the file name does not include the batch index, so for a batch
    larger than 1 later images overwrite earlier ones.
    """
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    # Collapse the channel dimension into one activation map per image.
    feature = feature.sum(dim=1)
    out_dir = "result/featuremaps"
    os.makedirs(out_dir, exist_ok=True)

    for b_i in range(input.size(0)):
        # Undo the normalization and convert to an H x W x 3 uint8 image.
        cur_img = input[b_i].clone()
        for k in range(3):
            cur_img[k] = cur_img[k] * norm_std[k] + norm_mean[k]
        cur_img = np.uint8(cur_img.permute(1, 2, 0).cpu().numpy() * 255.0)
        # The image is RGB (PIL order), but applyColorMap produces BGR and
        # cv2.imwrite expects BGR. Flip the image so colors are saved correctly.
        cur_img = cur_img[:, :, ::-1]

        cur_featuremap = feature[b_i, :, :].cpu().numpy()
        heatmap = cv2.resize(cur_featuremap, (cur_img.shape[1], cur_img.shape[0]))
        amin, amax = heatmap.min(), heatmap.max()
        if amax > amin:
            heatmap = (heatmap - amin) / (amax - amin)  # min-max scale to [0, 1]
        else:
            # Constant map: avoid 0/0 = NaN, which has no defined uint8 value.
            heatmap = np.zeros_like(heatmap)
        heatmap = cv2.applyColorMap(np.uint8(heatmap * 255.0), cv2.COLORMAP_JET)

        superimposed_img = heatmap * 0.8 + cur_img * 0.2
        cv2.imwrite(os.path.join(out_dir, val_type + ".jpg"), superimposed_img)



if __name__ == "__main__":
    # Script entry point: run only when executed directly, never on import.
    main()
