import os
import torch
import argparse
import warnings
import itertools
import torchvision
from tqdm import tqdm
import torch.optim as optim
from datetime import datetime
warnings.filterwarnings('ignore',category=UserWarning)

# Detection models and adversarial-patch attack utilities
from utils import *
from yolo2 import utils
from generator import *
from cfg import get_cfgs
from color_util import *
from yolo2 import load_data
from load_models import load_models
from arch.yolov3_models import YOLOv3Darknet
from yolov5.models.common import DetectMultiBackend
from load_data import *



def _generate_patch(kind):
    """Return a fresh CPU patch tensor to start the optimisation from.

    Args:
        kind: ``'gray'`` for a constant-0.5 patch of shape (3, 288, 288), or
            ``'random'`` for a uniform-random patch of shape (3, 112, 288).

    Raises:
        ValueError: if *kind* is not one of the supported values.
    """
    if kind == 'gray':
        return torch.full((3, 288, 288), 0.5)
    if kind == 'random':
        # NOTE(review): height differs from the gray patch (112 vs 288);
        # confirm this is intentional.
        return torch.rand((3, 112, 288))
    raise ValueError(f"unknown patch type: {kind!r}")


def _compute_det_loss(p_img_batch, lab_batch):
    """Dispatch to the detection-loss function matching ``pargs.arch``.

    Uses the module-level ``pargs``, ``darknet_model``, ``args`` and ``kwargs``
    set up in the ``__main__`` block. Returns ``(det_loss, valid_num)`` as
    produced by the selected ``get_*_loss`` helper; any arch other than
    yolov2/yolov3/yolov5 falls back to the Faster R-CNN loss.
    """
    if pargs.arch == "yolov2":
        return get_v2_loss(True, darknet_model, p_img_batch, lab_batch, args, kwargs)
    if pargs.arch == "yolov3":
        return get_v3_loss(darknet_model, p_img_batch, lab_batch, args, kwargs)
    if pargs.arch == "yolov5":
        return get_v5_loss(darknet_model, p_img_batch, lab_batch, args, kwargs)
    return get_rcnn_loss(darknet_model, p_img_batch, lab_batch, args, kwargs)


def train_patch():
    """Optimise an adversarial patch against the selected detector.

    Relies on module-level globals created in the ``__main__`` block:
    ``device``, ``args``, ``kwargs``, ``pargs``, ``loader``, ``epoch_length``,
    ``patch_transformer``, ``patch_applier``, ``darknet_model`` and
    ``results_dir``.

    Side effects:
        * every ``save_interval`` epochs (``args.n_epochs // 10`` clamped to
          [1, 100]) the current patch is written with ``np.save`` to
          ``<results_dir>/patch<epoch>.npy``;
        * every 5th epoch, a preview of the first patched image of the first
          batch is written to ``./middle_images/p_img_batch0.jpg``.

    Returns:
        0 on completion.
    """
    adv_patch = _generate_patch("gray").to(device)
    adv_patch.unsqueeze_(0)  # add a leading batch dimension -> (1, 3, H, W)
    adv_patch.requires_grad_(True)
    optimizer = optim.Adam([adv_patch], lr=0.03, amsgrad=True)
    # Reduce the learning rate when the mean epoch loss stops decreasing.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=50)

    save_interval = max(min(args.n_epochs // 10, 100), 1)
    n_batches = max(len(loader), 1)  # guard against ZeroDivisionError on an empty loader
    preview_dir = "./middle_images"

    for epoch in range(1, args.n_epochs + 1):
        ep_det_loss = 0.0
        ep_loss = 0.0
        save_preview = epoch % 5 == 0
        if save_preview:
            # The original code assumed this directory existed and crashed otherwise.
            os.makedirs(preview_dir, exist_ok=True)

        for i_batch, (img_batch, lab_batch) in tqdm(enumerate(loader), desc=f'Running epoch {epoch}', total=epoch_length):
            optimizer.zero_grad()
            img_batch = img_batch.to(device)
            lab_batch = lab_batch.to(device)
            adv_batch_t = patch_transformer(adv_patch, lab_batch, args.img_size, do_rotate=True, rand_loc=False,
                                            pooling=args.pooling, old_fasion=kwargs['old_fasion'])
            p_img_batch = patch_applier(img_batch, adv_batch_t)

            # One preview per epoch is enough; doing it every batch forced a
            # device->host copy and JPEG encode for each batch.
            if save_preview and i_batch == 0:
                img = p_img_batch[0].detach().cpu().numpy().transpose(1, 2, 0)
                img = (np.clip(img, 0, 1) * 255).astype(np.uint8)  # clip to avoid uint8 wrap-around
                Image.fromarray(img).save(os.path.join(preview_dir, "p_img_batch0.jpg"))

            det_loss, valid_num = _compute_det_loss(p_img_batch, lab_batch)
            if valid_num > 0:
                det_loss = det_loss / valid_num  # average over the valid targets

            # float() accepts both a scalar tensor and a plain Python number,
            # so a helper that returns 0 for "nothing detected" does not crash.
            loss_value = float(det_loss)
            ep_det_loss += loss_value
            if loss_value == 0:
                continue
            det_loss.backward()
            optimizer.step()
            ep_loss += loss_value
            with torch.no_grad():
                adv_patch.clamp_(0, 1)  # keep the patch a valid image

        # Save once per epoch (previously rewritten after every batch); the
        # stored tensor is the patch state after the last optimisation step.
        if epoch % save_interval == 0:
            rpath = os.path.join(results_dir, 'patch%d' % epoch)
            np.save(rpath, adv_patch.detach().cpu().numpy())

        ep_det_loss /= n_batches
        ep_loss /= n_batches
        print('detection loss', ep_det_loss)

        scheduler.step(ep_loss)
    return 0

def loadModel(name_digital, name_physical, map_location="cpu"):
    """Load a ``BPnet`` checkpoint from ``ckpt/<digital>_to_<physical>.pth``.

    Args:
        name_digital: first component of the checkpoint file name.
        name_physical: second component of the checkpoint file name.
        map_location: forwarded to ``torch.load``. It defaults to ``"cpu"`` so
            that a checkpoint saved on a GPU can be loaded on any host.
            ``load_state_dict`` copies the values into the model's existing
            parameters, so the model's own device is not changed.

    Returns:
        A freshly constructed ``BPnet`` with the checkpoint weights loaded.
    """
    path_model = os.path.join('ckpt', name_digital + '_to_' + name_physical + '.pth')
    model = BPnet()
    print("Loading model from", path_model, "...")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(path_model, map_location=map_location)
    model.load_state_dict(state_dict)
    print("Successful")
    return model


if __name__ == '__main__':
    # ---- command-line configuration ----
    parser = argparse.ArgumentParser(description='PyTorch Training')
    parser.add_argument('--net', default='yolov2', help='识别模型')
    parser.add_argument('--method', default='TCA', help='')
    parser.add_argument('--suffix', default=None, help='suffix name')
    parser.add_argument('--gen_suffix', default=None, help='generator suffix name')
    parser.add_argument('--epoch', type=int, default=None, help='')
    parser.add_argument('--z_epoch', type=int, default=None, help='')
    parser.add_argument('--device', default='cuda:7', help='')
    parser.add_argument("--arch", type=str, default="yolov5", help='识别模型')
    # Previously hard-coded to ''; exposed so the dataset location can be set
    # without editing the script. Defaults keep the old behaviour.
    parser.add_argument('--img_dir', default='', help='training image directory')
    parser.add_argument('--lab_dir', default='', help='training label directory')
    pargs = parser.parse_args()

    args, kwargs = get_cfgs(pargs.net, pargs.method)
    if pargs.epoch is not None:
        args.n_epochs = pargs.epoch
    if pargs.z_epoch is not None:
        args.z_epochs = pargs.z_epoch
    if pargs.suffix is None:
        pargs.suffix = 'car_y5_woc_1128'
    device = torch.device(pargs.device)

    # ---- detector under attack ----
    if pargs.arch == "yolov2":
        darknet_model = load_models(**kwargs)
        darknet_model = darknet_model.eval().to(device)
    elif pargs.arch == "yolov3":
        darknet_model = YOLOv3Darknet().eval().to(device)
        darknet_model.load_darknet_weights('./arch/weights/yolov3.weights')
    elif pargs.arch == "yolov5":
        darknet_model = DetectMultiBackend("yolov5s.pt").eval().to(device)
        # Disable in-place ops in sub-modules that expose an `inplace` flag,
        # since gradients must flow back through the model to the patch.
        for m in darknet_model.modules():
            if hasattr(m, 'inplace'):
                m.inplace = False
    else:
        darknet_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True).eval().to(device)

    # Only the patch is optimised. Freeze every detector (previously only
    # Faster R-CNN) so backward() does not compute gradients for the model
    # weights or keep accumulating them; train_patch's optimizer.zero_grad()
    # only clears the patch's gradient.
    for p in darknet_model.parameters():
        p.requires_grad_(False)

    # ---- data and patch pipeline ----
    img_dir_train = pargs.img_dir
    lab_dir_train = pargs.lab_dir
    train_data = load_data.InriaDataset(img_dir_train, lab_dir_train, kwargs['max_lab'], args.img_size, shuffle=True)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=kwargs['batch_size'], shuffle=True, num_workers=4)
    patch_applier = load_data.PatchApplier().to(device)
    patch_transformer = load_data.PatchTransformer(args).to(device)

    results_dir = os.path.join('results', pargs.suffix)
    print(results_dir)
    os.makedirs(results_dir, exist_ok=True)  # race-free replacement for exists()+makedirs()

    # train_patch() reads these module-level names as globals.
    loader = train_loader
    epoch_length = len(loader)
    print(f'One epoch is {len(loader)}')
    train_patch()