import os
import re
import torch
import fnmatch
import warnings
import argparse
import itertools
import torchvision
import matplotlib
matplotlib.use('Agg')
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import transforms
unloader = transforms.ToPILImage()
from scipy.interpolate import interp1d
warnings.filterwarnings('ignore', category=UserWarning)

from utils import *
from yolo2 import utils
from cfg import get_cfgs
from yolo2 import load_data
from load_models import load_models   
from yolov5.models.common import DetectMultiBackend


def truths_length(truths, max_len=50):
    """Return the number of valid ground-truth rows at the start of ``truths``.

    Rows are scanned in order; counting stops at the first row whose column 1
    equals ``-1`` (``label_filter`` fills unused rows with ``-1``).

    Args:
        truths: indexable sequence of label rows (e.g. a 2-D tensor).
        max_len: maximum number of rows to scan; ``None`` scans every row.

    Returns:
        Index of the first padding row, or the number of scanned rows when no
        padding row was found.
    """
    limit = len(truths) if max_len is None else min(len(truths), max_len)
    for i in range(limit):
        if truths[i][1] == -1:
            return i
    # Previously this fell through and returned None (or raised IndexError when
    # there were fewer than 50 rows), breaking the caller's slicing and counting.
    return limit

def label_filter(truths, labels=None):
    """Keep only ground-truth rows whose class id (column 0) is in ``labels``.

    Kept rows are packed to the front of a new tensor with the same shape as
    ``truths``; all remaining rows are filled with ``-1`` so that
    ``truths_length`` can locate the end of the valid entries.

    Args:
        truths: tensor of label rows; ``row[0]`` is compared against ``labels``.
        labels: iterable of class ids to keep, or ``None`` to keep every row.

    Returns:
        The filtered tensor, or ``truths`` itself when ``labels`` is ``None``.
    """
    if labels is None:
        # Previously returned None implicitly, which crashed every caller.
        return truths
    wanted = tuple(labels)
    filtered = truths.new(truths.shape).fill_(-1)
    count = 0
    for row in truths:
        if row[0].item() in wanted:
            filtered[count] = row
            count += 1
    return filtered


def _pr_curve(positives, total):
    """Build cumulative precision/recall lists from scored detections.

    Args:
        positives: list of ``(confidence, is_true_positive)`` pairs.
        total: number of ground-truth boxes over the whole dataset.

    Returns:
        ``(precision, recall, confs)`` lists with one entry per detection,
        ordered by descending confidence.
    """
    ranked = sorted(positives, key=lambda d: d[0], reverse=True)
    precision, recall, confs = [], [], []
    tp_counter = 0
    fp_counter = 0
    for conf, is_tp in ranked:
        if is_tp:
            tp_counter += 1
        else:
            fp_counter += 1
        # Recall is undefined without any ground-truth boxes; report NaN
        # instead of raising ZeroDivisionError.
        recall.append(tp_counter / total if total > 0 else float('nan'))
        precision.append(tp_counter / (tp_counter + fp_counter))
        confs.append(conf)
    return precision, recall, confs


def _interpolated_ap(precision, recall, num_of_samples=100):
    """Average precision sampled at ``num_of_samples`` recall levels in [0, 1).

    Precision is linearly interpolated over recall. Recall levels below the
    smallest observed recall use the precision at that point. Levels above the
    largest observed recall count as precision 0. With a single point the
    product ``precision * recall`` is returned; with no points (or undefined
    recall) the result is NaN.
    """
    if len(precision) > 1 and len(recall) > 1:
        p = np.array(precision)
        r = np.array(recall)
        if not np.all(np.isfinite(r)):
            return float('nan')
        p_start = p[np.argmin(r)]
        samples = np.arange(0., 1., 1.0 / num_of_samples)
        interpolated = interp1d(r, p, fill_value=(p_start, 0.), bounds_error=False)(samples)
        return sum(interpolated) / len(interpolated)
    if len(precision) > 0 and len(recall) > 0:
        return precision[0] * recall[0]
    return float('nan')


def test(model, loader, adv_cloth=None, gan=None, z=None, type=None, conf_thresh=0.5, nms_thresh=0.4, iou_thresh=0.5, num_of_samples=100,
         old_fasion=True, cls_id=7):
    """Evaluate ``model`` on ``loader`` with ``adv_cloth`` applied; return PR/AP for one class.

    Processing steps:
      1. Each batch is patched with the module-level ``patch_transformer`` /
         ``patch_applier`` and run through the model.
      2. Outputs are decoded with ``utils.get_region_boxes_general`` and
         filtered with ``utils.nms``.
      3. Detections of class ``cls_id`` are greedily matched to the
         ground-truth rows of that class. A detection is a true positive when
         its best IoU with a still-unmatched box exceeds ``iou_thresh``; that
         box is then consumed.

    Relies on module-level globals assigned in the ``__main__`` section:
    ``device``, ``args`` (``img_size``, ``pooling``), ``kwargs`` (``name``),
    ``patch_transformer`` and ``patch_applier``.

    Args:
        model: detector to evaluate (switched to eval mode).
        loader: iterable of ``(images, labels)`` batches; labels are viewed as
            rows of 5 values with the class id in column 0.
        adv_cloth: adversarial patch passed to ``patch_transformer``.
        gan, z, type: accepted for interface compatibility; unused here.
        conf_thresh: confidence threshold for box decoding.
        nms_thresh: NMS threshold.
        iou_thresh: IoU above which a detection matches a ground-truth box.
        num_of_samples: number of recall levels used for interpolated AP.
        old_fasion: forwarded to ``patch_transformer``.
        cls_id: class id to evaluate. The default 7 matches the previously
            hard-coded value. NOTE(review): the ``--prepare_data`` path keeps
            class 0 labels — confirm which class is intended.

    Returns:
        ``(precision, recall, avg, confs)``. ``precision``, ``recall`` and
        ``confs`` are per-detection lists in descending-confidence order;
        ``avg`` is the interpolated AP (NaN when it cannot be computed).
    """
    model.eval()
    total = 0.0
    positives = []
    with torch.no_grad():
        for _, (data, target) in tqdm(enumerate(loader), total=len(loader), position=0):
            data = data.to(device)
            target = target.to(device)
            adv_batch_t = patch_transformer(adv_cloth, target, args.img_size, do_rotate=True, rand_loc=False,
                                            pooling=args.pooling, old_fasion=old_fasion)
            data = patch_applier(data, adv_batch_t)
            output = model(data)
            all_boxes = utils.get_region_boxes_general(False, output, model, conf_thresh, kwargs['name'])
            for i, boxes in enumerate(all_boxes):
                boxes = utils.nms(boxes, nms_thresh)
                truths = label_filter(target[i].view(-1, 5), labels=[cls_id])
                num_gts = truths_length(truths)
                # Drop the class column; only the box coordinates are needed for IoU.
                truths = truths[:num_gts, 1:].tolist()
                total += num_gts
                for box in boxes:
                    # Column 6 holds the predicted class id, column 4 the confidence.
                    if box[6].item() != cls_id:
                        continue
                    best_iou = 0
                    best_index = 0
                    for ib, box_gt in enumerate(truths):
                        iou = utils.bbox_iou(box_gt, box, x1y1x2y2=False)
                        if iou > best_iou:
                            best_iou = iou
                            best_index = ib
                    is_tp = bool(best_iou > iou_thresh)
                    if is_tp:
                        # Consume the matched ground truth so it cannot be matched twice.
                        del truths[best_index]
                    positives.append((box[4].item(), is_tp))

    precision, recall, confs = _pr_curve(positives, total)
    avg = _interpolated_ap(precision, recall, num_of_samples)
    return precision, recall, avg, confs


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Training')
    parser.add_argument('--net', default='yolov2', help='')
    parser.add_argument('--method', default='TCA', help='')
    parser.add_argument('--suffix', default=None, help='suffix name')
    parser.add_argument('--gen_suffix', default=None, help='generator suffix name')
    parser.add_argument('--device', default='cuda:0', help='')
    parser.add_argument('--prepare_data', action='store_true', help='')
    parser.add_argument('--epoch', type=int, default=200, help='')
    parser.add_argument('--load_path', default=None, help='')
    parser.add_argument('--npz_dir', default=None, help='')
    pargs = parser.parse_args()

    # NOTE: `args`, `kwargs`, `device`, `patch_applier` and `patch_transformer`
    # must stay module-level globals: `test()` reads them directly.
    args, kwargs = get_cfgs(pargs.net, pargs.method, 'test')
    if pargs.epoch is not None:
        args.n_epochs = pargs.epoch
    if pargs.suffix is None:
        pargs.suffix = pargs.net + '_' + pargs.method

    device = torch.device(pargs.device)

    darknet_model = load_models(**kwargs)
    darknet_model = darknet_model.eval().to(device)
    class_names = utils.load_class_names('./data/coco.names')

    patch_applier = load_data.PatchApplier().to(device)
    patch_transformer = load_data.PatchTransformer(args).to(device)

    if pargs.prepare_data:
        # Run the clean detector over the original test images and save its
        # detections as labels, plus the resized input images.
        print("It is time to prepare data")
        conf_thresh = 0.5
        nms_thresh = 0.4
        img_ori_dir = './data/INRIAPerson/Test/pos'
        img_dir = './data/images'
        lab_dir = './data/labels'
        data_nl = load_data.InriaDataset(img_ori_dir, None, kwargs['max_lab'], args.img_size, shuffle=False)
        loader_nl = torch.utils.data.DataLoader(data_nl, batch_size=args.batch_size, shuffle=False, num_workers=10)
        for out_dir in (lab_dir, img_dir):
            if out_dir is not None:
                os.makedirs(out_dir, exist_ok=True)

        print('preparing the test data')
        with torch.no_grad():
            for _, (data, labs) in tqdm(enumerate(loader_nl), total=len(loader_nl)):
                data = data.to(device)
                output = darknet_model(data)
                all_boxes = utils.get_region_boxes_general(False, output, darknet_model, conf_thresh, kwargs['name'])
                for i in range(data.size(0)):
                    boxes = utils.nms(all_boxes[i], nms_thresh)
                    # Reorder to (class id, box coords) and keep class 0 only.
                    # NOTE(review): test() scores class 7 by default — confirm
                    # which class is intended.
                    new_boxes = boxes[:, [6, 0, 1, 2, 3]]
                    new_boxes = new_boxes[new_boxes[:, 0] == 0]
                    new_boxes = new_boxes.detach().cpu().numpy()
                    if lab_dir is not None:
                        np.savetxt(os.path.join(lab_dir, labs[i]), new_boxes, fmt='%f')
                    if img_dir is not None:
                        # Convert here so the image exists even when lab_dir is None
                        # (it used to be created only inside the lab_dir branch).
                        img = unloader(data[i].detach().cpu())
                        img.save(os.path.join(img_dir, labs[i].replace('.txt', '.png')))
        print('preparing done')

    # TODO: fill in the test image / label directories before running.
    img_dir_test = ''
    lab_dir_test = ''
    test_data = load_data.InriaDataset(img_dir_test, lab_dir_test, kwargs['max_lab'], args.img_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=10)
    loader = test_loader
    print(f'One epoch is {len(loader)}')

    if pargs.npz_dir is None:
        # Evaluate a single patch and plot its PR curve.
        if pargs.method == 'RCA' or pargs.method == 'TCA':
            if pargs.load_path is None:
                # TODO: set a default patch path; np.load('') will fail.
                img_path = ''
            else:
                img_path = pargs.load_path
            cloth = torch.from_numpy(np.load(img_path)[0]).to(device)
            test_cloth = cloth.detach().clone()
            test_gan = None
            test_z = None
            test_type = 'patch'
        else:
            raise ValueError(f'Unsupported method for evaluation: {pargs.method!r} (expected RCA or TCA)')

        save_dir = 'test_results'
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, pargs.suffix)
        plt.figure(figsize=[15, 10])
        prec, rec, ap, confs = test(darknet_model, test_loader, adv_cloth=test_cloth, gan=test_gan, z=test_z, type=test_type, conf_thresh=0.01, old_fasion=kwargs['old_fasion'])
        np.savez(save_path, prec=prec, rec=rec, ap=ap, confs=confs, adv_patch=cloth.detach().cpu().numpy())
        print('AP is %.4f'% ap)
        plt.plot(rec, prec)
        leg = [pargs.suffix + ': ap %.3f' % ap]
        unloader(cloth[0]).save(save_path + '.tif')
    else:
        # Re-plot previously saved results (.npz files) on one PR figure.
        files = fnmatch.filter(os.listdir(pargs.npz_dir), '*.npz')
        order = {'RCA': 0, 'TCA': 1, 'EGA': 2, 'TCEGA': 3}
        method_re = re.compile('(RCA)|(TCA)|(EGA)|(TCEGA)')

        def _method_rank(name):
            # Sort by method name; files without a known method go last.
            match = method_re.search(name)
            return order[match.group()] if match is not None else 1e5

        files.sort()
        files.sort(key=_method_rank)  # stable: keeps alphabetical order within a method
        leg = []
        for file in files:
            save_path = os.path.join(pargs.npz_dir, file)
            # Arrays are unpacked positionally (prec, rec, ap, confs, patch) and
            # materialised before the archive is closed.
            with np.load(save_path, allow_pickle=True) as save_data:
                prec, rec, ap, confs, clothi = list(save_data.values())
            plt.plot(rec, prec)
            leg.append(file.replace('.npz', '') + ', ap: %.3f' % ap)
            unloader(torch.from_numpy(clothi[0])).save(save_path.replace('.npz', '.png'))
        save_dir = pargs.npz_dir
    plt.plot([0, 1], [0, 1], 'k--')
    plt.legend(leg, loc=4)
    plt.title('PR-curve')
    plt.ylabel('Precision')
    plt.xlabel('Recall')
    plt.ylim([0, 1.05])
    plt.xlim([0, 1.05])
    plt.savefig(os.path.join(save_dir, 'PR-curve.png'), dpi=300)