import os
import argparse
import warnings

import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm
unloader = transforms.ToPILImage()
warnings.filterwarnings('ignore', category=UserWarning)

from utils import *
from yolo2 import utils
from cfg import get_cfgs
from yolo2 import load_data
from load_models import load_models  

max_label_lth = 15  # default number of label rows a sample is padded to


def pad_lab(label, max_len=None):
    """Pad the first label tensor of *label* to a fixed number of rows.

    Entries filled with ``-1`` are appended at the end of dim -2 (dim 0 for a
    2-D label tensor) until the tensor has *max_len* rows. A tensor that
    already has *max_len* or more rows is returned unchanged; it is NOT
    truncated.

    Args:
        label: indexable collection of label tensors (e.g. a tensor whose
            dim 0 indexes samples). Only ``label[0]`` is padded and returned;
            the script below passes ``labs.unsqueeze(0)``.
        max_len: target number of rows; defaults to ``max_label_lth``
            (read at call time).

    Returns:
        The padded ``label[0]``.

    Raises:
        IndexError: if *label* is empty.
    """
    if max_len is None:
        max_len = max_label_lth
    # Only the first element is ever returned, so only it is padded.
    lab = label[0]
    pad_size = max_len - lab.shape[0]
    if pad_size <= 0:
        return lab
    # (0, 0, 0, pad_size): leave the last dim alone and append pad_size
    # entries after the existing ones on the second-to-last dim.
    return torch.nn.functional.pad(lab, (0, 0, 0, pad_size), value=-1)

def _png_name(labs_path):
    """Return the base name of *labs_path* with its extension replaced by ``.png``."""
    # splitext only touches the final extension; str.replace('.txt', '.png')
    # would also rewrite a '.txt' occurring earlier in the name.
    stem, _ = os.path.splitext(os.path.basename(labs_path))
    return stem + '.png'


def _prepare_data(args, kwargs, img_ori_dir, label_ori_dir, img_dir, lab_dir):
    """Write the unpatched test images (PNG) and their labels (text) to disk.

    Args:
        args: config object from ``get_cfgs``; ``args.img_size`` is read.
        kwargs: config dict from ``get_cfgs``; ``kwargs['max_lab']`` is read.
        img_ori_dir: source image directory for ``load_data.add_InriaDataset``.
        label_ori_dir: source label directory for ``load_data.add_InriaDataset``.
        img_dir: destination for PNG images; an empty string disables image output.
        lab_dir: destination for label files (written with ``np.savetxt``);
            an empty string disables label output.
    """
    print("It is time to prepare data")
    data_nl = load_data.add_InriaDataset(img_ori_dir, label_ori_dir, kwargs['max_lab'], args.img_size, shuffle=False)

    # Truthiness rather than `is not None`: the paths are '' placeholders by
    # default, and os.makedirs('') raises FileNotFoundError.
    for out_dir in (lab_dir, img_dir):
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    print('preparing the test data')
    with torch.no_grad():
        for i in range(len(data_nl)):
            data, labs, labs_path = data_nl[i]
            if lab_dir:
                save_dir = os.path.join(lab_dir, os.path.basename(labs_path))
                np.savetxt(save_dir, labs.detach().cpu().numpy(), fmt='%f')
            if img_dir:
                # Built here (not in the label branch) so image output also
                # works when label output is disabled.
                img = unloader(data.detach().cpu())
                save_dir = os.path.join(img_dir, _png_name(labs_path))
                img.save(save_dir)
                print(save_dir)
    print('preparing done')


def _save_patched_images(args, kwargs, device, patch_applier, patch_transformer,
                         img_dir_test, lab_dir_test, adv_path, target_path):
    """Apply the patch stored at *adv_path* to every test sample and save PNGs.

    Args:
        args: config object from ``get_cfgs``; ``img_size`` and ``pooling`` are read.
        kwargs: config dict from ``get_cfgs``; ``kwargs['max_lab']`` is read.
        device: torch device the patch, images and labels are moved to.
        patch_applier: ``load_data.PatchApplier`` instance.
        patch_transformer: ``load_data.PatchTransformer`` instance.
        img_dir_test: test image directory for ``load_data.add_InriaDataset``.
        lab_dir_test: test label directory for ``load_data.add_InriaDataset``.
        adv_path: file read with ``np.load``; its first entry along axis 0
            is used as the patch.
        target_path: output directory (created if missing); an empty string
            writes into the current working directory.
    """
    data_nl = load_data.add_InriaDataset(img_dir_test, lab_dir_test, kwargs['max_lab'], args.img_size, shuffle=False)
    adv_cloth = torch.from_numpy(np.load(adv_path)[0]).to(device)
    if target_path:
        os.makedirs(target_path, exist_ok=True)
    with torch.no_grad():
        for i in tqdm(range(len(data_nl))):
            data, labs, labs_path = data_nl[i]
            # Add a leading batch dim of 1 to the image and the padded labels.
            data = data.unsqueeze(0).to(device)
            labs = pad_lab(labs.unsqueeze(0)).unsqueeze(0).to(device)
            adv_batch_t = patch_transformer(adv_cloth, labs, args.img_size, do_rotate=True, rand_loc=False, pooling=args.pooling, old_fasion=True)
            data = patch_applier(data, adv_batch_t).squeeze(0)
            img = unloader(data.detach().cpu())
            img.save(os.path.join(target_path, _png_name(labs_path)))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--net', default='yolov2', help='network name passed to get_cfgs')
    parser.add_argument('--method', default='TCA', help='method name passed to get_cfgs')
    parser.add_argument('--device', default='cuda:0', help='torch device string')
    parser.add_argument('--prepare_data', action='store_true', help='first dump the unpatched test images and labels')
    # NOTE(review): --epoch, --load_path and --npz_dir are parsed but not read by this script.
    parser.add_argument('--epoch', type=int, default=200, help='')
    parser.add_argument('--load_path', default=None, help='')
    parser.add_argument('--npz_dir', default=None, help='')
    pargs = parser.parse_args()

    args, kwargs = get_cfgs(pargs.net, pargs.method, 'test')

    device = torch.device(pargs.device)
    # NOTE(review): class_names is not used in this script.
    class_names = utils.load_class_names('./data/coco.names')

    patch_applier = load_data.PatchApplier().to(device)
    patch_transformer = load_data.PatchTransformer(args).to(device)

    if pargs.prepare_data:
        # TODO: fill in the directories before running with --prepare_data.
        _prepare_data(args, kwargs, img_ori_dir='', label_ori_dir='', img_dir='', lab_dir='')

    # TODO: fill in the test directories, the patch file and the output directory.
    _save_patched_images(args, kwargs, device, patch_applier, patch_transformer,
                         img_dir_test='', lab_dir_test='', adv_path='', target_path='')