import os
import re
import cv2
import torch
import fnmatch
import argparse
import warnings
import itertools
import matplotlib
# Select matplotlib's non-interactive Agg backend so plotting works without a
# display; it is set before the project imports below in case any of them
# imports pyplot (NOTE(review): confirm whether utils/yolo2 actually do).
matplotlib.use('Agg')
from tqdm import tqdm
from torchvision import transforms
# Shared tensor -> PIL.Image converter; used by the __main__ block when saving
# the padded test images.
unloader = transforms.ToPILImage()
# Process-wide: silence every UserWarning from any module.
warnings.filterwarnings('ignore',category=UserWarning)


from utils import *
from yolo2 import utils
from cfg import get_cfgs
from yolo2 import load_data
from load_models import load_models    


def truths_length(truths, max_len=50):
    """Return the number of valid ground-truth rows in ``truths``.

    ``truths`` is a padded label table in which unused rows carry ``-1`` in
    column 1. The count is the index of the first such padding row among the
    first ``max_len`` rows.

    Args:
        truths: Indexable sequence of rows (e.g. a 2-D tensor or a list of
            lists); each row must support ``row[1]``.
        max_len: Maximum number of rows to scan (historically fixed at 50).

    Returns:
        The index of the first padding row. If no padding row is found, returns
        ``min(len(truths), max_len)``. Previously this case returned ``None``,
        and tables shorter than 50 rows raised ``IndexError``.
    """
    # Never index past the end of a short table.
    limit = min(len(truths), max_len)
    for i in range(limit):
        if truths[i][1] == -1:
            return i
    return limit


def label_filter(truths, labels=None):
    """Keep only the truth rows whose class id (column 0) is in ``labels``.

    Kept rows are packed to the front of a new tensor with the same shape as
    ``truths``. All remaining rows are filled with ``-1``, the padding marker
    ``truths_length`` looks for.

    Args:
        truths: Tensor of label rows; column 0 holds the class id.
        labels: Collection of class ids to keep. ``None`` disables filtering.

    Returns:
        The filtered, ``-1``-padded tensor. When ``labels`` is ``None``,
        returns ``truths`` unchanged (previously this case returned ``None``).
    """
    if labels is None:
        return truths
    new_truths = truths.new(truths.shape).fill_(-1)
    kept = 0
    for row in truths:
        if row[0].item() in labels:
            new_truths[kept] = row
            kept += 1
    return new_truths


def test(model, loader, adv_cloth=None, gan=None, z=None, type=None, conf_thresh=0.5, nms_thresh=0.4, iou_thresh=0.5, num_of_samples=100,
         old_fasion=True, out_dir='test_show'):
    """Apply an adversarial patch to every test image and save detection plots.

    For each batch from ``loader``:
      1. ``adv_cloth`` is transformed onto the target boxes and applied to the
         images.
      2. The patched images go through ``model``.
      3. Boxes above ``conf_thresh`` are NMS-filtered with ``nms_thresh``.
      4. The boxes are drawn and saved to ``<out_dir>/<n>.jpg``, where ``n``
         counts images from 1. With batch size 1, ``n`` equals the batch number.

    NOTE: this function reads module-level globals that the ``__main__`` block
    creates: ``device``, ``args``, ``kwargs``, ``patch_transformer``,
    ``patch_applier`` and ``class_names``.

    Args:
        model: Detector to evaluate; switched to eval mode here.
        loader: Iterable of ``(images, targets)`` batches.
        adv_cloth: Patch tensor passed to ``patch_transformer``.
        gan, z, type, iou_thresh, num_of_samples: Accepted for interface
            compatibility; currently unused (``type`` shadows the builtin).
        conf_thresh: Detection confidence threshold.
        nms_thresh: IoU threshold used by NMS.
        old_fasion: Forwarded to ``patch_transformer``.
        out_dir: Directory for the rendered images; created if missing.

    Returns:
        True once every batch has been processed.
    """
    model.eval()
    # plot_boxes saves straight to disk and does not create parent directories.
    os.makedirs(out_dir, exist_ok=True)
    img_count = 0
    with torch.no_grad():
        for data, target in tqdm(loader, total=len(loader), position=0):
            data = data.to(device)
            target = target.to(device)
            adv_batch_t = patch_transformer(adv_cloth, target, args.img_size, do_rotate=True, rand_loc=False,
                                            pooling=args.pooling, old_fasion=old_fasion)
            data = patch_applier(data, adv_batch_t)

            output = model(data)
            all_boxes = utils.get_region_boxes_general(output, model, conf_thresh, kwargs['name'])
            # Draw every image in the batch. The old squeeze(0) + bboxes[0]
            # approach only worked when batch_size == 1.
            for img_tensor, boxes in zip(data, all_boxes):
                img_count += 1
                kept_boxes = utils.nms(boxes, nms_thresh)
                img = transforms.ToPILImage('RGB')(img_tensor)
                out_path = os.path.join(out_dir, '{}.jpg'.format(img_count))
                utils.plot_boxes(img, kept_boxes, out_path, class_names=class_names)
    return True


if __name__ == '__main__':
    # Import numpy explicitly rather than relying on `from utils import *`.
    import numpy as np

    parser = argparse.ArgumentParser(description='PyTorch Training')
    parser.add_argument('--net', default='yolov2', help='')
    parser.add_argument('--method', default='TCA', help='')
    parser.add_argument('--suffix', default=None, help='suffix name')
    parser.add_argument('--gen_suffix', default=None, help='generator suffix name')
    parser.add_argument('--device', default='cuda:0', help='')
    parser.add_argument('--prepare_data', default=False, action='store_true', help='')
    parser.add_argument('--epoch', type=int, default=200, help='')
    parser.add_argument('--load_path', default=None, help='')
    parser.add_argument('--load_path_z', default=None, help='')
    pargs = parser.parse_args()

    # Only patch-based methods are supported. Fail before loading the model/data.
    if pargs.method not in ('RCA', 'TCA'):
        raise ValueError("unsupported method %r; expected 'RCA' or 'TCA'" % pargs.method)

    args, kwargs = get_cfgs(pargs.net, pargs.method, 'test')
    if pargs.epoch is not None:
        args.n_epochs = pargs.epoch
    if pargs.suffix is None:
        pargs.suffix = pargs.net + '_' + pargs.method

    # These names must stay module-level globals: test() reads them directly.
    device = torch.device(pargs.device)
    darknet_model = load_models(**kwargs)
    darknet_model = darknet_model.eval().to(device)
    class_names = utils.load_class_names('./data/coco.names')

    patch_applier = load_data.PatchApplier().to(device)
    patch_transformer = load_data.PatchTransformer().to(device)

    if pargs.prepare_data:
        # Build the padded test images and pseudo-labels from the clean model's detections.
        conf_thresh = 0.5
        nms_thresh = 0.4
        img_ori_dir = './data/INRIAPerson/Test/pos'
        img_dir = './data/test_padded'
        lab_dir = './data/test_lab_%s' % kwargs['name']
        data_nl = load_data.InriaDataset(img_ori_dir, None, kwargs['max_lab'], args.img_size, shuffle=False)
        loader_nl = torch.utils.data.DataLoader(data_nl, batch_size=args.batch_size, shuffle=False, num_workers=10)
        if lab_dir is not None:
            os.makedirs(lab_dir, exist_ok=True)
        if img_dir is not None:
            os.makedirs(img_dir, exist_ok=True)

        print('preparing the test data')
        with torch.no_grad():
            for data, labs in tqdm(loader_nl, total=len(loader_nl)):
                data = data.to(device)
                output = darknet_model(data)
                all_boxes = utils.get_region_boxes_general(output, darknet_model, conf_thresh, kwargs['name'])
                for i in range(data.size(0)):
                    boxes = utils.nms(all_boxes[i], nms_thresh)
                    # Move column 6 to the front, followed by columns 0-3.
                    # Presumably (class_id, x, y, w, h); TODO confirm against
                    # get_region_boxes_general.
                    new_boxes = boxes[:, [6, 0, 1, 2, 3]]
                    # Keep class 0 only (presumably 'person' in coco.names).
                    new_boxes = new_boxes[new_boxes[:, 0] == 0]
                    new_boxes = new_boxes.detach().cpu().numpy()
                    if lab_dir is not None:
                        np.savetxt(os.path.join(lab_dir, labs[i]), new_boxes, fmt='%f')
                    if img_dir is not None:
                        # Convert here: the image is only needed when img_dir is set.
                        img = unloader(data[i].detach().cpu())
                        img.save(os.path.join(img_dir, labs[i].replace('.txt', '.png')))
        print('preparing done')

    img_dir_test = './data/test_padded'
    lab_dir_test = './data/test_lab_%s' % kwargs['name']
    test_data = load_data.InriaDataset(img_dir_test, lab_dir_test, kwargs['max_lab'], args.img_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False, num_workers=10)   # batch_Size= args.batch_size
    print(f'One epoch is {len(test_loader)}')

    if pargs.load_path is None:
        result_dir = './results/result_' + pargs.net + '_' + pargs.method
        img_path = os.path.join(result_dir, 'patch%d.npy' % args.n_epochs)
    else:
        img_path = pargs.load_path
    # Only the first patch stored in the .npy file is evaluated.
    test_cloth = torch.from_numpy(np.load(img_path)[:1]).to(device)
    test(darknet_model, test_loader, adv_cloth=test_cloth, gan=None, z=None, type='patch', conf_thresh=0.8, old_fasion=kwargs['old_fasion'])