import os
import argparse

import torch
import torchvision
from torch.utils.data.distributed import DistributedSampler
from tensorboardX import SummaryWriter

from config.train_config import cfg
from dataloader.coco_dataset import coco
from utils.evaluate_utils import evaluate
from utils.im_utils import Compose, ToTensor, RandomHorizontalFlip
from utils.plot_utils import plot_loss_and_lr, plot_map
from utils.train_utils import train_one_epoch, write_tb, create_model

def _resolve_local_rank():
    """Return the index of the GPU this process should use on its own node.

    The ``LOCAL_RANK`` environment variable exported by the launcher is
    preferred. When it is absent, the global rank is folded onto the number
    of visible GPUs, which equals the global rank on a single node.
    """
    env_rank = os.environ.get("LOCAL_RANK")
    if env_rank is not None:
        return int(env_rank)
    # max(..., 1) only guards the modulo; NCCL itself still requires a GPU
    return torch.distributed.get_rank() % max(torch.cuda.device_count(), 1)


def main():
    """Train the model with DistributedDataParallel and save improving checkpoints.

    Steps performed:
      * join the NCCL process group and bind this process to its local GPU;
      * build one training loader per entry of ``cfg.train_domain``, each
        sharded across ranks with a ``DistributedSampler``, plus a single
        un-sharded validation loader;
      * wrap the model in DDP and create the SGD optimizer and StepLR schedule;
      * optionally restore model / optimizer / scheduler state from
        ``cfg.resume`` and continue from the following epoch;
      * after every epoch, evaluate and save a checkpoint whenever the
        validation mAP beats the best value seen so far.

    Only global rank 0 writes the tensorboard log directory and checkpoint
    files; the other ranks train and evaluate but do not touch the disk.

    Raises:
        FileNotFoundError: if ``cfg.img_root`` does not exist.
    """
    torch.distributed.init_process_group(backend="nccl")
    # the global rank is only a valid device index on a single node, so the
    # GPU is chosen from the per-node rank instead
    local_rank = _resolve_local_rank()
    is_main_process = torch.distributed.get_rank() == 0
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)

    # exist_ok avoids the check-then-create race between sibling processes
    os.makedirs(cfg.model_save_dir, exist_ok=True)

    # tensorboard writer (rank 0 only, so ranks do not write duplicate event files)
    writer = None
    if is_main_process:
        writer = SummaryWriter(os.path.join(cfg.model_save_dir, 'epoch_log'))

    data_transform = {
        "train": Compose([ToTensor(), RandomHorizontalFlip(cfg.train_horizon_flip_prob)]),
        "val": Compose([ToTensor()])
    }

    if not os.path.exists(cfg.img_root):
        raise FileNotFoundError("dataset img dir not exist!")

    # load train data sets: one dataset / sampler / loader per training domain
    train_data_sets = [coco(cfg.img_root, os.path.join(cfg.data_root, i + '.json'), data_transform["train"]) for i in
                       cfg.train_domain]
    mini_batch_size = cfg.mini_batch_size
    nw = cfg.num_workers
    # DistributedSampler shuffles by default and gives every rank a disjoint
    # shard; it replaces shuffle=True, which made each rank read everything
    train_samplers = [DistributedSampler(t) for t in train_data_sets]
    train_data_loader = [torch.utils.data.DataLoader(t,
                                                     batch_size=mini_batch_size,
                                                     num_workers=nw,
                                                     sampler=s,
                                                     collate_fn=t.collate_fn)
                         for t, s in zip(train_data_sets, train_samplers)]

    # load validation data set (deliberately not sharded: every rank scores
    # the full set)
    val_data_set = coco(cfg.img_root, cfg.test_anno, data_transform["val"])
    val_data_set_loader = torch.utils.data.DataLoader(val_data_set,
                                                      batch_size=cfg.batch_size,
                                                      num_workers=nw,
                                                      shuffle=False,
                                                      collate_fn=val_data_set.collate_fn)

    # create model; cfg.num_class presumably counts the background class too
    # (TODO confirm against create_model)
    model = create_model(cfg.num_class, cfg)

    model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[local_rank],
                                                      output_device=local_rank)

    # define optimizer over the trainable parameters only
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=cfg.lr,
                                momentum=cfg.momentum, weight_decay=cfg.weight_decay)

    # learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=cfg.lr_dec_step_size,
                                                   gamma=cfg.lr_gamma)

    best_mAP = 0

    # resume from a previously saved checkpoint
    if cfg.resume != "":
        # map_location keeps each rank from loading onto the GPU the
        # checkpoint was saved from
        checkpoint = torch.load(cfg.resume, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        cfg.start_epoch = checkpoint['epoch'] + 1
        # checkpoints written before 'best_mAP' was stored fall back to 0
        best_mAP = checkpoint.get('best_mAP', best_mAP)
        print("the training process from epoch{}...".format(cfg.start_epoch))

    train_loss = []
    learning_rate = []
    val_mAP = []

    for epoch in range(cfg.start_epoch, cfg.num_epochs):
        # reseed the samplers so every epoch sees a different shuffle
        for sampler in train_samplers:
            sampler.set_epoch(epoch)

        loss_dict, total_loss = train_one_epoch(model, optimizer, train_data_loader,
                                                device, epoch, train_loss=train_loss, train_lr=learning_rate,
                                                print_freq=50, warmup=False)

        lr_scheduler.step()

        print("------>Starting validation data valid")
        coco_stats, mAP = evaluate(model, val_data_set_loader, device=device, mAP_list=val_mAP)

        print('validation mAp is {}'.format(mAP))
        print('best mAp is {}'.format(best_mAP))

        if mAP > best_mAP:
            best_mAP = mAP
            # save weights from rank 0 only; NOTE(review): assumes evaluate()
            # yields the same mAP on every rank (same weights, same data)
            if is_main_process:
                save_files = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'best_mAP': best_mAP}
                model_save_dir = cfg.model_save_dir
                os.makedirs(model_save_dir, exist_ok=True)
                torch.save(save_files,
                           os.path.join(model_save_dir, "{}-model-{}-mAp-{}.pth".format(cfg.backbone, epoch, mAP)))

    if writer is not None:
        writer.close()
    # release NCCL resources on the normal exit path
    torch.distributed.destroy_process_group()



if __name__ == "__main__":
    # Drop only the local build suffix (e.g. "+cu117"). Slicing to five
    # characters truncated two-digit components: "1.10.0" printed as "1.10.".
    version = torch.__version__.split('+')[0]
    print('torch version is {}'.format(version))
    main()

    
