import argparse
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
from attribution_methods.our.model import resnet,vgg
from attribution_methods.our import networks as nw
import numpy as np
import torch.nn.functional as F

# Expose only physical GPU 5 to CUDA for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

# Dataset layout: class sub-folders of validation images live under
# ``val_root``; the per-image XML box annotations live under ``bbox_root``.
ROOT = '/data_SSD2/zgh/workspace/data/ImageNet'
val_root = f'{ROOT}/val'
bbox_root = f'{ROOT}/val_bbox'

def main():
    """Placeholder entry point: performs no work and returns 0."""
    exit_status = 0
    return exit_status

def read_bbox(bbox_path):
    """
    Parse an XML annotation file and return its single bounding box.

    The file is expected to hold a ``size`` element (``width`` / ``height``)
    and ``object`` elements, each with a ``bndbox`` child carrying
    ``xmin`` / ``ymin`` / ``xmax`` / ``ymax``. Missing fields default to 0.

    An annotation is rejected (``None`` is returned) when:

    * it does not contain exactly one ``object`` element (zero objects used
      to yield a degenerate all-zero box; it is now rejected as well), or
    * the box covers at least half of the declared image area.

    :param bbox_path: path of the XML annotation file
    :return: ``(width, height, xmax, xmin, ymax, ymin)`` as ints, or ``None``
        when the annotation is rejected
    """
    # ElementTree selects its C accelerator automatically; the separate
    # ``cElementTree`` alias was removed in Python 3.9.
    import xml.etree.ElementTree as ET

    root = ET.parse(bbox_path).getroot()

    # Image dimensions as declared by the annotation itself.
    size = {'width': 0, 'height': 0}
    for size_node in root.findall('size'):
        for child in size_node:
            if child.tag in size:
                size[child.tag] = int(child.text)

    # Only unambiguous, single-object annotations are usable.
    objects = root.findall('object')
    if len(objects) != 1:
        return None

    box = {'xmin': 0, 'ymin': 0, 'xmax': 0, 'ymax': 0}
    for bndbox in objects[0].findall('bndbox'):
        for child in bndbox:
            if child.tag in box:
                box[child.tag] = int(child.text)

    width, height = size['width'], size['height']
    box_area = (box['xmax'] - box['xmin']) * (box['ymax'] - box['ymin'])
    # Boxes that fill half the image or more are discarded.
    if box_area >= width * height * 0.5:
        return None
    # NOTE: the max/min ordering of the tuple is part of the interface.
    return width, height, box['xmax'], box['xmin'], box['ymax'], box['ymin']


import cv2
from PIL import Image
import matplotlib.pyplot as plt
def show_cam_on_image(imgs, masks, file_name="result/cam1.jpg"):
    """
    Overlay a mask on a single image and save the figure to disk.

    Only one image is handled per call (no batch dimension).

    :param imgs: one image as a channel-first array (CHW); it is transposed
        to HWC and min-max rescaled for display
    :param masks: map drawn semi-transparently over the image with the
        "jet" colormap (the caller in this file passes a 2-D HxW array)
    :param file_name: destination of the saved figure; when it is a path,
        missing parent directories are created first
    """
    imgs = imgs.transpose(1, 2, 0)
    imgs = imgs - imgs.min()
    peak = imgs.max()
    # A constant image has peak == 0 after the shift; skip the division so
    # the result stays all-zero instead of turning into NaNs.
    if peak > 0:
        imgs = imgs / peak

    # savefig does not create directories, so make sure the target exists.
    if isinstance(file_name, (str, os.PathLike)):
        out_dir = os.path.dirname(file_name)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    try:
        plt.imshow(imgs)
        plt.imshow(masks, cmap="jet", alpha=0.5)
        plt.savefig(file_name)
    finally:
        # Always release the current figure, even if drawing/saving failed.
        plt.close()
def _project_bbox(bbox, width, height, img_size):
    """Map ``(xmin, ymin, xmax, ymax)`` from original pixels into the frame
    produced by ``Resize(img_size)`` followed by ``CenterCrop(img_size)``."""
    xmin, ymin, xmax, ymax = bbox
    # Resize(int) matches the *shorter* side to img_size and keeps the
    # aspect ratio, truncating the longer side to an int.
    if width <= height:
        new_w, new_h = img_size, int(img_size * height / width)
    else:
        new_w, new_h = int(img_size * width / height), img_size
    # CenterCrop removes an equal margin from both ends of the longer side.
    left = int(round((new_w - img_size) / 2.0))
    top = int(round((new_h - img_size) / 2.0))

    def clamp(value):
        return max(0, min(img_size, value))

    return (
        clamp(int(xmin * new_w / width) - left),
        clamp(int(ymin * new_h / height) - top),
        clamp(int(xmax * new_w / width) - left),
        clamp(int(ymax * new_h / height) - top),
    )


def read_imgs(datadir=val_root,bboxdir=bbox_root):
    """
    Load images and their bounding-box masks from a class-per-folder layout.

    Class folders are visited in sorted order (at most the first 1000) and
    the files inside each folder are visited in sorted order as well, so the
    output is reproducible. An image is skipped when ``read_bbox`` rejects
    its annotation, when its real size disagrees with the annotated size, or
    when its box falls entirely outside the centre crop. The first accepted
    sample is also rendered through ``show_cam_on_image`` as a sanity check.

    :param datadir: folder containing one sub-folder of images per class
    :param bboxdir: folder containing ``<image stem>.xml`` for every image
    :return: ``(images, targets, bboxes)`` where ``images`` is a list of
        normalised 1x3x224x224 tensors, ``targets`` the matching class
        indices (position of the class folder in sorted order) and
        ``bboxes`` a list of 1x1x224x224 masks that are 1 inside the box
    """
    means = [0.485, 0.456, 0.406]
    stds = [0.229, 0.224, 0.225]
    img_size = 224
    max_classes = 1000

    # The pipeline is identical for every image, so build it once.
    preprocess = transforms.Compose([
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=means, std=stds),
    ])

    images = []
    targets = []
    bboxes = []
    show = True

    for index, fn in enumerate(sorted(os.listdir(datadir))):
        if index == max_classes:
            break
        img_dir = datadir + '/' + fn
        print(index, fn)
        for img in sorted(os.listdir(img_dir)):
            img_path = img_dir + '/' + img
            stem = os.path.splitext(img)[0]
            # Honour the bboxdir argument instead of the module-level default.
            bbox_property = read_bbox(bboxdir + '/' + stem + '.xml')
            if bbox_property is None:
                continue
            width, height, xmax, xmin, ymax, ymin = bbox_property

            with Image.open(img_path) as raw_image:
                # The annotation must describe this exact image geometry.
                if raw_image.size != (width, height):
                    continue
                input_image = raw_image.convert('RGB')

            # Express the box in the coordinates of the cropped tensor.
            xmin, ymin, xmax, ymax = _project_bbox(
                (xmin, ymin, xmax, ymax), width, height, img_size)
            if xmax <= xmin or ymax <= ymin:
                # Nothing of the box survives the crop; an empty mask is useless.
                continue

            image = preprocess(input_image).unsqueeze(0)
            mask = image.new_zeros((1, 1) + tuple(image.shape[-2:]))
            mask[:, :, ymin:ymax, xmin:xmax] = 1

            images.append(image)
            targets.append(index)
            bboxes.append(mask)

            if show:
                show = False
                show_cam_on_image(images[0].squeeze().numpy(), bboxes[0].squeeze().numpy())

    return images,targets,bboxes

def save_data(images,targets,bboxes):
    """
    Stack the per-sample data and write it out as three ``.npy`` files.

    ``images`` and ``bboxes`` are concatenated along the first dimension and
    converted to NumPy; ``targets`` becomes a NumPy array. The three shapes
    are printed, then ``images.npy``, ``targets.npy`` and ``bboxes.npy`` are
    written to the current working directory.

    :param images: sequence of image tensors to concatenate
    :param targets: sequence of labels
    :param bboxes: sequence of mask tensors to concatenate
    """
    image_array = torch.cat(list(images), dim=0).numpy()
    target_array = np.array(targets)
    bbox_array = torch.cat(list(bboxes), dim=0).numpy()

    print(image_array.shape, target_array.shape, bbox_array.shape)

    outputs = (
        ('images.npy', image_array),
        ('targets.npy', target_array),
        ('bboxes.npy', bbox_array),
    )
    for filename, array in outputs:
        np.save(filename, array)

if __name__ == '__main__':
    # Collect (images, targets, bboxes) and hand them straight to the writer.
    save_data(*read_imgs())