# train.py
#!/usr/bin/env  python3

""" train network using pytorch

author baiyu
"""

import os
import sys
import argparse
import time
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


import torch.distributed as dist
# from torch.nn.parallel import DistributedDataParallel as DDPs

from conf import settings
from utils import get_network, get_training_dataloader, get_test_dataloader, WarmUpLR, \
    most_recent_folder, most_recent_weights, last_epoch, best_acc_weights
import torchvision.models as models
import torchvision.datasets as datasets
from tqdm import tqdm
from enum import Enum
import subprocess

def train(args, epoch):
    """Train ``net`` for one epoch, evaluate on rank 0, then step the LR scheduler.

    Relies on module-level objects created in the ``__main__`` section:
    ``net``, ``cifar100_training_loader``, ``optimizer``, ``loss_function``,
    ``train_scheduler`` and, on rank 0 only, ``writer``.

    Args:
        args: parsed options; this function reads ``args.gpu``,
            ``args.local_rank`` and ``args.rank``.
        epoch: epoch number, starting at 1 (used for the TensorBoard step
            counter and to reseed the distributed sampler).
    """
    net.train()

    # A DistributedSampler derives its shuffle order from the epoch number;
    # without this call every epoch replays the same permutation. The guard
    # keeps this a no-op for loaders whose sampler has no ``set_epoch``.
    sampler = getattr(cifar100_training_loader, 'sampler', None)
    if hasattr(sampler, 'set_epoch'):
        sampler.set_epoch(epoch)

    batches_per_epoch = len(cifar100_training_loader)
    for batch_index, (images, labels) in enumerate(tqdm(cifar100_training_loader)):

        if args.gpu:
            labels = labels.to(args.local_rank)
            images = images.to(args.local_rank)

        optimizer.zero_grad()
        outputs = net(images)
        loss = loss_function(outputs, labels)
        loss.backward()
        # Manual data parallelism: average gradients over all ranks before
        # the optimizer consumes them.
        average_gradients(net)
        optimizer.step()
        if args.rank == 0:
            # Global step count across epochs, starting at 1.
            n_iter = (epoch - 1) * batches_per_epoch + batch_index + 1
            writer.add_scalar('Train/loss', loss.item(), n_iter)

    if args.rank == 0:
        # Only rank 0 evaluates on the test set and reports accuracy.
        acc1 = eval_training(epoch)
        print("Epoch: ", epoch)
        print("current acc1 is ", acc1)

    # Every rank steps its own scheduler so learning rates stay in lockstep.
    train_scheduler.step()
    if args.rank == 0:
        print("learning_rate: ", train_scheduler.get_last_lr())

@torch.no_grad()
def eval_training(epoch=0):
    """Evaluate ``net`` on ``cifar100_test_loader`` and log the results.

    Uses the module-level ``net``, ``cifar100_test_loader``, ``loss_function``,
    ``args``, ``device`` and ``writer`` set up in the ``__main__`` section.
    Prints a one-line summary and writes ``Test/Average loss`` and
    ``Test/Top1Accuracy`` scalars to ``writer``.

    The reported "average loss" is the sum of the per-batch loss values
    divided by the number of samples in the test dataset.

    Args:
        epoch: value used as the TensorBoard step and echoed in the summary.

    Returns:
        Top-1 accuracy: number of correct predictions divided by the number
        of samples in the test dataset.
    """
    started_at = time.time()
    net.eval()

    loss_total = 0.0  # running sum of per-batch loss values
    num_correct = 0.0

    for batch_images, batch_labels in cifar100_test_loader:

        if args.gpu:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

        logits = net(batch_images)
        loss_total += loss_function(logits, batch_labels).item()

        # Index of the largest logit along dim 1 is the predicted class.
        _, predicted = logits.max(1)
        num_correct += predicted.eq(batch_labels).sum()

    elapsed = time.time() - started_at

    dataset_size = len(cifar100_test_loader.dataset)
    mean_loss = loss_total / dataset_size
    accuracy = num_correct.float() / dataset_size

    summary = 'Test set: Epoch: {}, Average loss: {:.4f}, Accuracy: {:.4f}, Time consumed:{:.2f}s'
    print(summary.format(epoch, mean_loss, accuracy, elapsed))
    print()

    writer.add_scalar('Test/Average loss', mean_loss, epoch)
    writer.add_scalar('Test/Top1Accuracy', accuracy, epoch)

    return accuracy



""" Parameter averaging. """
def average_parameters(model):
    """Replace every parameter of ``model`` with its mean across all ranks.

    Each parameter tensor is summed over the default process group with
    ``all_reduce`` and then divided in place by the world size.

    Args:
        model: module whose ``parameters()`` are averaged in place.
    """
    world_size = float(dist.get_world_size())
    for weights in (p.data for p in model.parameters()):
        dist.all_reduce(weights, op=dist.ReduceOp.SUM)
        weights.div_(world_size)

""" Gradient averaging. """
def average_gradients(model):
    """Average the gradients of ``model`` across all ranks, in place.

    Each gradient is summed over the default process group with
    ``all_reduce`` and then divided by the world size.

    Parameters whose ``.grad`` is ``None`` (for example parameters frozen
    with ``requires_grad=False``) are skipped instead of raising
    ``AttributeError``. NOTE: every rank must skip the same parameters,
    otherwise the collective calls no longer line up between ranks.

    Args:
        model: module whose ``parameters()`` are iterated.
    """
    size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is None:
            # Nothing was back-propagated into this parameter.
            continue
        grad = param.grad.data
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        grad /= size

if __name__ == '__main__':
    # Script entry point. Expects to be started by a distributed launcher
    # that exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT
    # for every worker process.

    parser = argparse.ArgumentParser()
    parser.add_argument('-net', type=str, required=True, help='net type')
    parser.add_argument('-gpu', action='store_true', default=False, help='use gpu or not')
    parser.add_argument('-microbszpergpu', type=int, default=128, help='batch size for dataloader')
    parser.add_argument('-warm', type=int, default=1, help='warm up training phase')
    parser.add_argument('-lr', type=float, default=3e-4, help='initial learning rate')
    parser.add_argument('-resume', action='store_true', default=False, help='resume training')
    # Launchers pass either spelling depending on the PyTorch version; the
    # value is superseded by the LOCAL_RANK environment variable below.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=-1,
                        help='local rank passed by the launcher (LOCAL_RANK env var takes precedence)')
    args = parser.parse_args()


    # =================DDP=========================
    # Indexing os.environ directly makes a missing variable fail with a
    # KeyError naming it, rather than an opaque int(None) TypeError.
    args.rank = int(os.environ['RANK'])
    args.local_rank = int(os.environ['LOCAL_RANK'])
    args.world_size = int(os.environ['WORLD_SIZE'])
    args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])

    # The GPU index must be the node-local rank: the global rank exceeds the
    # number of GPUs per node as soon as the job spans more than one node.
    # This also matches train(), which moves batches to args.local_rank.
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)

    dist.init_process_group(backend='nccl', init_method=args.dist_url, world_size=args.world_size, rank=args.rank)

    # =============================================

    # ===================SLURM=========================
    # Alternative setup for SLURM launches, kept for reference.
    # NOTE: if re-enabled, select the device with args.local_rank rather
    # than args.rank, as done in the DDP section above.

    # num_gpus = torch.cuda.device_count()

    # if "SLURM_JOB_ID" in os.environ:
    #     rank = int(os.environ["SLURM_PROCID"])
    #     args.rank = rank
    #     args.local_rank = rank % num_gpus
    #     world_size = int(os.environ["SLURM_NTASKS"])
    #     args.world_size = world_size
    #     node_list = os.environ["SLURM_NODELIST"]
    #     addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
    #     # specify master port
    #     port = None
    #     if port is not None:
    #         os.environ["MASTER_PORT"] = str(port)
    #     elif "MASTER_PORT" not in os.environ:
    #         os.environ["MASTER_PORT"] = "29500"
    #     if "MASTER_ADDR" not in os.environ:
    #         os.environ["MASTER_ADDR"] = addr
    #     os.environ["WORLD_SIZE"] = str(world_size)
    #     os.environ["LOCAL_RANK"] = str(rank % num_gpus)
    #     os.environ["RANK"] = str(rank)
    # else:
    #     rank = int(os.environ["RANK"])
    #     world_size = int(os.environ["WORLD_SIZE"])

    # args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
    # dist.init_process_group(backend='nccl', init_method=args.dist_url, world_size=args.world_size, rank=args.rank)

    # torch.cuda.set_device(args.rank)
    # device = torch.device("cuda", args.rank)

    # =================================================


    print(device)
    net = get_network(args, device)

    # data preprocessing:
    cifar100_training_loader = get_training_dataloader(
        settings.CIFAR100_TRAIN_MEAN,
        settings.CIFAR100_TRAIN_STD,
        num_workers=16,
        batch_size=args.microbszpergpu,
        shuffle=True,
        DDP = True
    )

    cifar100_test_loader = get_test_dataloader(
        settings.CIFAR100_TRAIN_MEAN,
        settings.CIFAR100_TRAIN_STD,
        num_workers=16,
        batch_size=args.microbszpergpu,
        shuffle=True
    )

    loss_function = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(net.parameters(), args.lr)
    # learning rate decay: multiply the LR by 0.2 at each milestone epoch
    train_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=settings.MILESTONES, gamma=0.2)


    LOG_DIR = 'path to log dir'
    modelname = 'resnet50_cifar100_baseline'
    # Only rank 0 owns a TensorBoard writer; every use of `writer` elsewhere
    # is guarded by the same rank check.
    if args.rank == 0:
        writer = SummaryWriter(log_dir = os.path.join(LOG_DIR, modelname))


    # Average the initial parameters over all ranks so every replica starts
    # from identical weights before gradient averaging begins.
    average_parameters(net)


    for epoch in range(1, settings.EPOCH + 1):
        train(args, epoch)

    torch.distributed.destroy_process_group()
    if args.rank == 0:
        writer.close()

