# train.py
#!/usr/bin/env  python3

""" train network using pytorch

author baiyu
"""

import os
import sys
import argparse
import time
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
# import torch.optim as optim
import adam_microbatch_grad_DDP as optim
import torchvision
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from MicroGradScheduler import MicroGradScheduler
import torch.distributed as dist
# from torch.nn.parallel import DistributedDataParallel as DDPs

from conf import settings
from utils import get_network, get_training_dataloader, get_test_dataloader, WarmUpLR, \
    most_recent_folder, most_recent_weights, last_epoch, best_acc_weights
import torchvision.models as models
import torchvision.datasets as datasets
from tqdm import tqdm
from enum import Enum
import subprocess

global_step = 0
scaled_loss = 0.0


def train(args, epoch):
    """Run one training epoch over the module-level training loader.

    Relies on objects created at module level by the ``__main__`` section:
    ``net``, ``cifar100_training_loader``, ``loss_function``,
    ``grad_scheduler`` and ``train_scheduler``.

    Args:
        args: Parsed command-line namespace; ``gpu``, ``local_rank`` and
            ``grad_accumulation_step`` are read here.
        epoch: Epoch number supplied by the caller. It is forwarded to the
            loader's sampler when that sampler exposes ``set_epoch``.
    """
    net.train()

    # A distributed sampler derives its shuffle order from the epoch it was
    # last told about; without this call every epoch would replay the same
    # order. The sampler is built in utils.get_training_dataloader, which is
    # not visible here, so only call the hook when it actually exists.
    sampler = getattr(cifar100_training_loader, 'sampler', None)
    if hasattr(sampler, 'set_epoch'):
        sampler.set_epoch(epoch)

    for images, labels in tqdm(cifar100_training_loader):
        if args.gpu:
            labels = labels.to(args.local_rank)
            images = images.to(args.local_rank)

        outputs = net(images)
        # Each micro-batch loss is divided by the accumulation count before
        # backward(), so gradients accumulated over that many micro-batches
        # presumably average rather than sum -- confirm against the optimizer.
        loss = loss_function(outputs, labels) / args.grad_accumulation_step
        loss.backward()
        # NOTE(review): MicroGradScheduler is defined elsewhere; it presumably
        # decides when the accumulated gradients are applied.
        grad_scheduler.windup_optim_step()

    # Learning-rate schedule advances once per epoch.
    train_scheduler.step()


""" Parameter averaging. """
def average_parameters(model):
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
        param.data /= size
        # print(param.data[-1])
        # import sys; sys.exit()

""" Gradient averaging. """
def average_gradients(model):
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size
        # print(param.grad.data[-1])
        # import sys; sys.exit()

if __name__ == '__main__':
    # Script entry point: parse options, join the distributed process group,
    # build the network / data loaders / optimizer, then train for
    # settings.EPOCH epochs. The names bound here (net, loss_function,
    # cifar100_training_loader, grad_scheduler, train_scheduler) are read as
    # module-level globals by train().

    parser = argparse.ArgumentParser()
    parser.add_argument('-net', type=str, required=True, help='net type')
    parser.add_argument('-gpu', action='store_true', default=False, help='use gpu or not')
    parser.add_argument('-microbszpergpu', type=int, default=128, help='batch size for dataloader')
    parser.add_argument('-warm', type=int, default=1, help='warm up training phase')
    parser.add_argument('-grad_accumulation_step', type=int, default=2, help='grad_accumulation_step')
    parser.add_argument('-lr', type=float, default=3e-4, help='initial learning rate')
    parser.add_argument('-resume', action='store_true', default=False, help='resume training')
    parser.add_argument('--local_rank', type=int, default=-1,
                        help='local rank; the LOCAL_RANK environment variable takes precedence')
    args = parser.parse_args()

    # =================DDP=========================
    # Index os.environ directly so a missing variable is reported by name
    # (KeyError) rather than as an opaque int(None) TypeError.
    args.rank = int(os.environ['RANK'])
    args.world_size = int(os.environ['WORLD_SIZE'])
    # Prefer the launcher-provided LOCAL_RANK; fall back to --local_rank.
    args.local_rank = int(os.environ.get('LOCAL_RANK', args.local_rank))
    if args.local_rank < 0:
        parser.error('local rank is unknown: set LOCAL_RANK or pass --local_rank')
    args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])

    # The CUDA device index is per node, so it must come from the local rank.
    # The global rank exceeds the number of GPUs on a node in multi-node
    # runs, and train() already moves batches to args.local_rank. Bind the
    # device before creating the NCCL process group.
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)

    dist.init_process_group(backend='nccl', init_method=args.dist_url, world_size=args.world_size, rank=args.rank)

    # =============================================

    # ===================SLURM=========================
    # Alternative launch path, kept commented out for reference.

    #$num_gpus = torch.cuda.device_count()

    #$if "SLURM_JOB_ID" in os.environ:
    #$    rank = int(os.environ["SLURM_PROCID"])
    #$    args.rank = rank
    #$    args.local_rank = rank % num_gpus
    #$    world_size = int(os.environ["SLURM_NTASKS"])
    #$    args.world_size = world_size
    #$    node_list = os.environ["SLURM_NODELIST"]
    #$    addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
    #$    # specify master port
    #$    port = None
    #$    if port is not None:
    #$        os.environ["MASTER_PORT"] = str(port)
    #$    elif "MASTER_PORT" not in os.environ:
    #$        os.environ["MASTER_PORT"] = "29500"
    #$    if "MASTER_ADDR" not in os.environ:
    #$        os.environ["MASTER_ADDR"] = addr
    #$    os.environ["WORLD_SIZE"] = str(world_size)
    #$    os.environ["LOCAL_RANK"] = str(rank % num_gpus)
    #$    os.environ["RANK"] = str(rank)
    #$else:
    #$    rank = int(os.environ["RANK"])
    #$    world_size = int(os.environ["WORLD_SIZE"])
    #$
    #$args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
    #$dist.init_process_group(backend='nccl', init_method=args.dist_url, world_size=args.world_size, rank=args.rank)

    #$torch.cuda.set_device(args.rank)
    #$device = torch.device("cuda", args.rank)

    # =================================================

    print(device)
    net = get_network(args, device)

    # data preprocessing:
    cifar100_training_loader = get_training_dataloader(
        settings.CIFAR100_TRAIN_MEAN,
        settings.CIFAR100_TRAIN_STD,
        num_workers=16,
        batch_size=args.microbszpergpu,
        shuffle=True,
        DDP=True,
    )

    cifar100_test_loader = get_test_dataloader(
        settings.CIFAR100_TRAIN_MEAN,
        settings.CIFAR100_TRAIN_STD,
        num_workers=16,
        batch_size=args.microbszpergpu,
        shuffle=True,
    )

    loss_function = nn.CrossEntropyLoss().to(device)
    # `optim` is the project's adam_microbatch_grad_DDP module; its Adam takes
    # the accumulation step count as the second positional argument.
    optimizer = optim.Adam(net.parameters(), args.grad_accumulation_step, args.lr, weight_decay=0)
    # learning rate decay
    train_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=settings.MILESTONES, gamma=0.2)

    grad_scheduler = MicroGradScheduler(net, optimizer)

    # Average the initial weights across ranks so every process starts from
    # the same parameters.
    average_parameters(net)

    for epoch in range(1, settings.EPOCH + 1):
        train(args, epoch)

    torch.distributed.destroy_process_group()

