
import os
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import data.transforms as transform_train, transform_valid
import torchvision.models as models
from data.dataset import CalImageFolder
from .utils import generate_random_sampler_for_train
from .train import train
from .valid import validate
from tensorboardX import SummaryWriter
from .utils import adjust_learning_rate
from .save_checkpoint import save_checkpoint
from models.Metatrainer import DarMo


def _init_distributed(gpu, ngpus_per_node, args):
    """Join the torch.distributed process group for this worker.

    Resolves ``args.rank`` before calling ``dist.init_process_group``:
    it is read from the ``RANK`` environment variable when
    ``args.dist_url == "env://"`` and no rank was supplied (``-1``), and it
    is converted to a global rank when ``args.multiprocessing_distributed``
    is set. Mutates ``args.rank`` in place.
    """
    if args.dist_url == "env://" and args.rank == -1:
        args.rank = int(os.environ["RANK"])
    if args.multiprocessing_distributed:
        # For multiprocessing distributed training the rank must be the
        # global rank among all processes; the incoming rank is presumably
        # the node index, so combine it with the local GPU index.
        args.rank = args.rank * ngpus_per_node + gpu
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)


def _build_model(args):
    """Create the DarMo model, optionally seeding its encoder with torchvision weights.

    When ``args.pretrained`` is set, the torchvision model named
    ``args.arch`` is instantiated with ``pretrained=True``. Each of its
    state-dict entries is copied into the DarMo entry whose key is
    ``'Resnetencoder.ResNet34.' + <torchvision key>``, if such a key exists.
    All other DarMo parameters keep their fresh initialisation.
    """
    prefix = 'Resnetencoder.ResNet34.'
    pretrained_dict = {}
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        pretrained_dict = models.__dict__[args.arch](pretrained=True).state_dict()
    else:
        # Fixed: previously printed the literal text 'args.arch'.
        print("=> creating model '{}' without pretraining".format(args.arch))

    model = DarMo(3, 256, args.gd)
    if pretrained_dict:
        model_dict = model.state_dict()
        matched = {prefix + k: v for k, v in pretrained_dict.items()
                   if prefix + k in model_dict}
        if not matched:
            # Most likely args.arch is not a ResNet-34-compatible backbone.
            print("=> warning: no pre-trained weights from '{}' matched the model"
                  .format(args.arch))
        model_dict.update(matched)
        model.load_state_dict(model_dict)
    return model


def _place_model(model, ngpus_per_node, args):
    """Move ``model`` onto GPU(s) and return it, wrapped in DDP where applicable.

    In distributed mode with a fixed ``args.gpu``, the model is wrapped in
    ``DistributedDataParallel``. ``args.batch_size`` and ``args.workers`` are
    divided by ``ngpus_per_node`` (mutated in place) because each process
    drives a single GPU.
    """
    if args.distributed:
        # DDP needs a single device scope per process, otherwise it will use
        # every visible device.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            args.batch_size = args.batch_size // ngpus_per_node
            args.workers = args.workers // ngpus_per_node
            return torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        # NOTE(review): the DDP wrap for the "all visible GPUs" case was
        # commented out in the original; the model is only moved to CUDA.
        model.cuda()
        return model

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        return model.cuda(args.gpu)

    # NOTE(review): DarMo is built regardless of args.arch, so this branch
    # only works if DarMo exposes a ``features`` submodule -- confirm. No
    # DataParallel wrapping happens for other architectures.
    if args.arch.startswith(('alexnet', 'vgg')):
        model.features = torch.nn.DataParallel(model.features)
    model.cuda()
    return model


def _build_criteria(args):
    """Return ``(criterion_cls, criterion_gcn)`` loss modules on the current CUDA device.

    The classification loss is always cross-entropy. The GCN loss is
    multi-label soft-margin when ``args.multilabel`` is set, and
    cross-entropy otherwise.
    """
    criterion_cls = nn.CrossEntropyLoss().cuda()
    if args.multilabel:
        criterion_gcn = nn.MultiLabelSoftMarginLoss().cuda()
    else:
        criterion_gcn = nn.CrossEntropyLoss().cuda()
    return criterion_cls, criterion_gcn


def _build_optimizer(model, args):
    """Return an Adam (``args.adam``) or SGD optimizer over ``model``'s parameters.

    Note: ``args.momentum`` and ``args.weight_decay`` are only applied to SGD;
    Adam is built with ``lr`` and ``betas=(0.9, 0.999)`` only.
    """
    if args.adam:
        return torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
    return torch.optim.SGD(model.parameters(), args.lr,
                           momentum=args.momentum,
                           weight_decay=args.weight_decay)


def _resume_from_checkpoint(model, optimizer, args):
    """Restore model/optimizer state and ``args.start_epoch`` from ``args.resume``.

    Prints a message and returns without changes if the file does not exist.
    If the checkpoint carries a ``best_acc1`` entry, it is stored in the
    module-level ``best_acc1`` global.
    """
    global best_acc1
    if not os.path.isfile(args.resume):
        print("=> no checkpoint found at '{}'".format(args.resume))
        return

    print("=> loading checkpoint '{}'".format(args.resume))
    # Map tensors onto this process's GPU; without map_location every rank
    # deserialises onto the device the checkpoint was saved from.
    map_location = 'cuda:{}'.format(args.gpu) if args.gpu is not None else None
    checkpoint = torch.load(args.resume, map_location=map_location)
    args.start_epoch = checkpoint['epoch']
    # Checkpoints written by main_worker store 'acc'/'auc' but not
    # 'best_acc1', so an unconditional lookup raised KeyError on resume.
    if 'best_acc1' in checkpoint:
        best_acc1 = checkpoint['best_acc1']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(args.resume, checkpoint['epoch']))


def _build_loaders(args):
    """Build the training and validation data pipelines.

    Returns ``(train_sampler, train_loader, val_loader)``. ``train_sampler``
    is a ``DistributedSampler`` in distributed mode and ``None`` otherwise.

    NOTE(review): the module-level
    ``import data.transforms as transform_train, transform_valid`` binds
    ``transform_train`` to the whole module and imports a top-level module
    named ``transform_valid``;
    ``from data.transforms import transform_train, transform_valid`` was
    presumably intended.
    """
    train_dataset = CalImageFolder(None, args.trd, transform_train)

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    bac_sampler = generate_random_sampler_for_train(torch_dataset=train_dataset,
                                                    batch_size=args.batch_size)
    # DataLoader raises ValueError when ``batch_sampler`` is combined with a
    # non-None ``sampler``, so the per-item sampler is only used when there
    # is no batch sampler. NOTE(review): this means the custom batch sampler
    # does not shard data across ranks in distributed mode -- confirm.
    loader_sampler = train_sampler if bac_sampler is None else None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=False,
        num_workers=args.workers, pin_memory=True,
        sampler=loader_sampler, batch_sampler=bac_sampler)

    val_loader = torch.utils.data.DataLoader(
        CalImageFolder(None, args.vd, transform_valid),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
    return train_sampler, train_loader, val_loader


def main_worker(gpu, ngpus_per_node, args):
    """Run training (or a one-off evaluation) for a single worker process.

    Sets up optional distributed training, builds the DarMo model, loss
    functions and optimizer, and optionally resumes from ``args.resume``.
    It then builds the train/validation loaders and either validates once
    (``args.evaluate``) or trains from ``args.start_epoch`` to
    ``args.epochs``, validating and saving a checkpoint after every epoch.

    Args:
        gpu: GPU index used by this process, or ``None``.
        ngpus_per_node: Number of GPUs per node. Used to compute the global
            rank and, in distributed single-GPU mode, to split
            ``args.batch_size`` and ``args.workers`` per process.
        args: Parsed configuration namespace. Mutated in place: ``gpu``,
            ``rank``, ``batch_size``, ``workers`` and ``start_epoch`` may be
            overwritten.
    """
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        _init_distributed(gpu, ngpus_per_node, args)

    model = _place_model(_build_model(args), ngpus_per_node, args)
    criterion_cls, criterion_gcn = _build_criteria(args)
    optimizer = _build_optimizer(model, args)

    if args.resume:
        _resume_from_checkpoint(model, optimizer, args)

    cudnn.benchmark = True

    train_sampler, train_loader, val_loader = _build_loaders(args)

    if args.evaluate:
        print('evaluating')
        # NOTE(review): this call passes the validation directory as an extra
        # positional argument, whereas the per-epoch call below does not --
        # one of the two does not match validate()'s signature; confirm.
        validate(val_loader, model, criterion_cls, criterion_gcn, args.vd, args)
        return

    name = '{}/{}_{}_{}_model_'.format(args.saved, args.batch_size, args.lr, args.epochs)
    writer_root = './{}/runs/{}_{}'.format(args.saved, args.batch_size, args.lr)
    writer = SummaryWriter(writer_root)
    # Only the first process on each node writes checkpoints.
    is_checkpoint_writer = (not args.multiprocessing_distributed
                            or args.rank % ngpus_per_node == 0)
    try:
        for epoch in range(args.start_epoch, args.epochs):
            if args.distributed:
                # Reshuffles the DistributedSampler; has no effect on the
                # loader when a custom batch sampler replaced it.
                train_sampler.set_epoch(epoch)
            adjust_learning_rate(optimizer, epoch, args)

            train(train_loader, model, criterion_cls, criterion_gcn, optimizer, epoch, args)

            # NOTE(review): acc/auc are not written to the SummaryWriter.
            acc, auc = validate(val_loader, model, criterion_cls, criterion_gcn, args)

            if is_checkpoint_writer:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'acc': acc,
                    'auc': auc,
                    'optimizer': optimizer.state_dict(),
                }, name, epoch, args.epochs)

        writer.export_scalars_to_json(
            '{}/{}_{}_all_scalars.json'.format(writer_root, args.batch_size, args.lr))
    finally:
        # Release the event-file handle even if training raises.
        writer.close()