from datetime import datetime
import time
import os

import torch
import torch.nn.functional as F
from apex import amp

import utils.dist_utils as dist_utils
from utils.logger import TensorboardLogger, FileLogger
from utils.meter import AverageMeter, NetworkMeter, TimeMeter
from utils.arg_parser import get_parser
from utils.saver import Saver

from models.load import load_model
from classifiers.load import init_classifier
from training.utils import distributed_predict, accuracy, correct
from training.scheduler import Scheduler
from data.dataloader import get_imagenet_loaders, get_svhn_loaders, get_cure_tsr_loaders, get_gtsrb_loaders
from training.train_algs import *

from grid import save_model_grid, save_img_grid

from torchvision.utils import save_image

# Parsed command-line configuration. main/train/validate read it as a
# module-level global rather than receiving it as a parameter.
args = get_parser()
# Process-role flags: is_master gates tensorboard output, is_rank0 additionally
# gates file logging and checkpoint writing in main().
is_master, is_rank0 = dist_utils.whoami(args)
# World size / global rank come from the dist_utils environment helpers.
args.world_size = dist_utils.env_world_size()
args.rank = dist_utils.env_rank()

# Only the master rank should write to tensorboard.
tb = TensorboardLogger(args.logdir, is_master=is_master)
log = FileLogger(args.logdir, is_master=is_master, is_rank0=is_rank0)

def main():
    """Set up distributed training, then evaluate or train the classifier.

    All configuration comes from the module-level ``args``. When
    ``args.evaluate`` is set, a single validation pass is run, printed, and the
    function returns. Otherwise the model is trained from ``args.start_epoch``
    up to ``scheduler.tot_epochs``. It is validated after every epoch, and a
    checkpoint is written (rank 0 only) whenever top-1 accuracy improves.
    """
    tb.log('sizes/world', dist_utils.env_world_size())
    dist_utils.setup_dist_backend(args)

    # Load datasets, initialize classifiers, load model of natural variation.
    # Switch datasets by swapping which loader line below is active.
    # trn_loader, val_loader, trn_samp, val_samp = get_imagenet_loaders(args)
    trn_loader, val_loader, trn_samp, val_samp = get_svhn_loaders(args)
    # trn_loader, val_loader, trn_samp, val_samp = get_cure_tsr_loaders(args)
    # trn_loader, val_loader, trn_samp, val_samp = get_gtsrb_loaders(args)

    model, criterion, optimizer = init_classifier(args)

    if args.resume:
        # Map every stored tensor onto this process's local GPU.
        checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpoint['state_dict'])
        # save_checkpoint stores epoch + 1, i.e. the next epoch to run.
        args.start_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer'])

    scheduler = Scheduler(optimizer, args, tb, log)
    G = load_model(args, reverse=False)

    # Debug dump on local rank 0. It writes one training batch, one validation
    # batch, and G applied to that validation batch with random deltas. The
    # files train.png / test.png / mb_images.png go to the current working
    # directory. `save_image` comes from the module-level torchvision import.
    # NOTE(review): the delta size 8 is hard-coded; confirm it matches G.
    if args.local_rank == 0:
        imgs, _ = next(iter(trn_loader))
        save_image(imgs, 'train.png')

        imgs, _ = next(iter(val_loader))
        save_image(imgs, 'test.png')

        delta = torch.randn(imgs.size(0), 8, 1, 1).cuda()
        with torch.no_grad():
            out = G(imgs.cuda(), delta)

        save_image(out, 'mb_images.png')

    start_time = datetime.now()

    if args.evaluate:
        top1, top5 = validate(val_loader, model, criterion, 0, start_time)
        print(f'Top1: {top1} | Top5: {top5}')
        return

    if args.distributed:
        dist_utils.sync_processes(args)

    saver = Saver(args, scheduler.tot_epochs)
    alg = get_alg(args)

    # NOTE(review): best_top1 is not restored on resume, so the first epoch
    # after resuming always overwrites the saved best checkpoint.
    best_top1 = 0.
    for epoch in range(args.start_epoch, scheduler.tot_epochs):

        # Same truthiness test as the sync_processes guard above, so both
        # distributed code paths are enabled together.
        if args.distributed:
            # Presumably DistributedSamplers; set_epoch varies per-epoch ordering.
            trn_samp.set_epoch(epoch)
            val_samp.set_epoch(epoch)

        train(trn_loader, model, criterion, optimizer, scheduler, epoch, G, alg=alg)
        top1, top5 = validate(val_loader, model, criterion, epoch, start_time)
        saver.update(top1, top5)
        time_diff = (datetime.now() - start_time).total_seconds() / 3600.0

        log.event("~~epoch\t\thours\t\ttop1\t\ttop5")
        log.event(f"~~{epoch}\t\t{time_diff:.5f}\t\t{top1:.3f}\t\t{top5:.3f}\n")

        # Keep only the best-so-far checkpoint; written by rank 0 alone.
        if top1 > best_top1:
            if is_rank0:
                save_checkpoint(epoch, model, optimizer, args)
            best_top1 = top1

def train(trn_loader, model, criterion, optimizer, scheduler, epoch, G, alg):
    """Run one training epoch over ``trn_loader``.

    Before the optimizer step, each batch may be replaced by the batch produced
    by the algorithm named in ``alg``: 'MDA', 'MRT', 'PGD' or 'MAT'. Any other
    value (e.g. 'Baseline') trains on the clean batch. When
    ``args.distributed`` is set, loss and accuracy counts are summed across
    ranks. Progress is logged whenever ``should_print`` allows it.

    Args:
        trn_loader: iterable of ``(input, target)`` batches.
        model: classifier being trained (put in train mode here).
        criterion: loss function applied to ``model`` outputs.
        optimizer: optimizer stepped once per batch.
        scheduler: ``Scheduler`` whose ``update_lr`` is called every batch.
        epoch: current epoch index, used for the LR schedule and logging.
        G: model of natural variation passed to the MDA/MRT/MAT algorithms.
        alg: algorithm name, as returned by ``get_alg``.
    """
    net_meter = NetworkMeter()
    timer = TimeMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.train()
    for i, (inputs, target) in enumerate(trn_loader):
        if args.short_epoch and (i > 10): break

        inputs, target = inputs.cuda(), target.cuda()

        batch_num = i + 1
        timer.batch_start()
        scheduler.update_lr(epoch, batch_num, len(trn_loader))

        # All branches test membership of the name; the PGD branch previously
        # used `is`, which compares object identity instead of string value.
        if 'MDA' in alg:
            inputs, target = mda_train(inputs, target, model, G, args)
        elif 'MRT' in alg:
            inputs, target = mrt_train(inputs, target, model, criterion, G, args)
        elif 'PGD' in alg:
            inputs, target = pgd_train(inputs, target, model, criterion)
        elif 'MAT' in alg:
            inputs, target = mat_train(inputs, target, model, criterion, G, args)

        output = model(inputs)
        loss = criterion(output, target)

        optimizer.zero_grad()
        if args.half_prec:
            # apex AMP: backpropagate through the scaled loss.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        timer.batch_end()
        corr1, corr5 = correct(output.data, target, topk=(1, 5))
        reduced_loss, batch_total = loss.data.item(), inputs.size(0)

        if args.distributed:
            # Sum sample counts, loss and correct counts over all ranks, then
            # turn the summed loss back into a per-rank mean.
            metrics = torch.tensor([batch_total, reduced_loss, corr1, corr5]).float().cuda()
            batch_total, reduced_loss, corr1, corr5 = dist_utils.sum_tensor(metrics).cpu().numpy()
            reduced_loss = reduced_loss / dist_utils.env_world_size()
        else:
            corr1, corr5 = corr1.item(), corr5.item()

        top1acc = corr1 * (100.0 / batch_total)
        top5acc = corr5 * (100.0 / batch_total)

        losses.update(reduced_loss, batch_total)
        top1.update(top1acc, batch_total)
        top5.update(top5acc, batch_total)

        if should_print(batch_num, trn_loader, args):
            tb.log_memory()
            tb.log_trn_times(timer.batch_time.val, timer.data_time.val, inputs.size(0))
            tb.log_trn_loss(losses.val, top1.val, top5.val)

            recv_gbit, transmit_gbit = net_meter.update_bandwidth()
            tb.log("sizes/batch_total", batch_total)
            tb.log('net/recv_gbit', recv_gbit)
            tb.log('net/transmit_gbit', transmit_gbit)

            msg = (f'Epoch: [{epoch}][{batch_num}/{len(trn_loader)}]\t'
                   f'Time {timer.batch_time.val:.3f} ({timer.batch_time.avg:.3f})\t'
                   f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                   f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                   f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                   f'Data {timer.data_time.val:.3f} ({timer.data_time.avg:.3f})\t'
                   f'BW {recv_gbit:.3f} {transmit_gbit:.3f}')
            log.verbose(msg)

        tb.update_step_count(batch_total)


def validate(val_loader, model, criterion, epoch, start_time):
    """Evaluate ``model`` on ``val_loader`` and return its average accuracies.

    When ``args.distributed`` is set, each batch goes through
    ``distributed_predict``. Otherwise it is evaluated locally under
    ``torch.no_grad()``.

    Args:
        val_loader: iterable of ``(input, target)`` batches.
        model: classifier to evaluate (put in eval mode here).
        criterion: loss function applied to ``model`` outputs.
        epoch: epoch index, used only for logging.
        start_time: unused; kept so existing call sites keep working.

    Returns:
        ``(top1_avg, top5_avg)``: sample-weighted averages over the pass.
    """
    timer = TimeMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    eval_start_time = time.time()

    for i, (inputs, target) in enumerate(val_loader):
        # Check first so the discarded batch is not copied to the GPU.
        if args.short_epoch and (i > 10): break

        inputs, target = inputs.cuda(), target.cuda()

        batch_num = i + 1
        timer.batch_start()

        if args.distributed:
            top1acc, top5acc, loss, batch_total = distributed_predict(inputs, target, model, criterion)
        else:
            with torch.no_grad():
                output = model(inputs)
                # Accumulate a Python float instead of a device tensor.
                loss = criterion(output, target).item()
            batch_total = inputs.size(0)
            top1acc, top5acc = accuracy(output.data, target, topk=(1, 5))

        # Eval batch done; record results.
        timer.batch_end()
        losses.update(loss, batch_total)
        top1.update(top1acc, batch_total)
        top5.update(top5acc, batch_total)

        if should_print(batch_num, val_loader, args):
            msg = (f'Test:  [{epoch}][{batch_num}/{len(val_loader)}]\t'
                   f'Time {timer.batch_time.val:.3f} ({timer.batch_time.avg:.3f})\t'
                   f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                   f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                   f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})')
            log.verbose(msg)

    tb.log_eval(top1.avg, top5.avg, time.time() - eval_start_time)
    tb.log('epoch', epoch)

    return top1.avg, top5.avg

def get_alg(args):
    """Return the name of the training algorithm selected by ``args``.

    The flags ``mrt``, ``mda``, ``mat`` and ``pgd`` are checked in that
    priority order. The first one that is exactly ``True`` decides the result.
    If none is, the result is ``'Baseline'``.
    """
    priority = (
        ('mrt', 'MRT'),
        ('mda', 'MDA'),
        ('mat', 'MAT'),
        ('pgd', 'PGD'),
    )
    for flag_name, alg_name in priority:
        if getattr(args, flag_name) is True:
            return alg_name
    return 'Baseline'

def should_print(batch_num, loader, args):
    """Decide whether the logger should emit output for this batch.

    Output happens every ``args.print_freq`` batches and on the final batch of
    ``loader``, and only on the process whose ``args.local_rank`` is 0.
    """
    on_schedule = batch_num % args.print_freq == 0 or batch_num == len(loader)
    return bool(on_schedule and args.local_rank == 0)


def save_checkpoint(epoch, model, optimizer, args, filename='best-top1-checkpoint.tar'):
    """Save model and optimizer state to ``os.path.join(args.save_path, filename)``.

    The stored ``'epoch'`` is ``epoch + 1``, the epoch a resumed run should
    start from. The resume path in ``main`` assigns it directly to
    ``args.start_epoch``.

    Args:
        epoch: index of the epoch that just finished.
        model: module whose ``state_dict()`` is stored under ``'state_dict'``.
        optimizer: optimizer whose ``state_dict()`` is stored under ``'optimizer'``.
        args: namespace providing ``save_path``. The directory must already
            exist; this function does not create it.
        filename: file name inside ``args.save_path``. Defaults to the
            best-top-1 checkpoint name.
    """
    state = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, os.path.join(args.save_path, filename))

if __name__ == '__main__':
    try:
        main()
    finally:
        # Close the tensorboard logger even if training raises.
        tb.close()


