import os
import sys; sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import shutil
import time
import pickle

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data

from as_models.resnet_cifar import WideResNetCIFAR, ResNetV2CIFAR
from as_models.vgg_cifar import VGG
from as_models.model_utils import save_checkpoint
from as_utils.tools import accuracy, AverageMeter
from as_utils.logger import LoggerText
from as_data_reader.dataset_reader import create_train_and_val_data_loaders

from noise_self_distil.args import args


def get_model_id():
    """Build the run identifier used to name the output sub-directory.

    The identifier joins the following pieces with '-':
    'baseline' (extended with ``args.prefix`` when one is given), the dataset,
    the network, ``epo_<epochs>``, ``b_<batch_size>``, ``lr_<lr>``,
    ``wd_<weight_decay>``, the init scheme and the augmentation setting.
    ``args.suffix`` is appended when it is not None.

    Returns:
        str: the dash-joined identifier derived from the global ``args``.
    """
    head = 'baseline' if args.prefix is None else 'baseline-' + args.prefix
    pieces = [
        head,
        args.dataset,
        args.network,
        'epo_' + str(args.epochs),
        'b_' + str(args.batch_size),
        'lr_' + str(args.lr),
        'wd_' + str(args.weight_decay),
        args.init,
        str(args.augment),
    ]
    if args.suffix is not None:
        pieces.append(args.suffix)
    return '-'.join(pieces)


def _weight_decay_penalty(model, weight_decay):
    """Return ``weight_decay * 0.5 * sum(||p||^2)`` over all parameters of ``model``.

    Used only for logging; parameters are detached so no autograd graph is built.
    Returns a 0-dim tensor, or the number 0 when the model has no parameters,
    so callers should convert with ``float()``.
    """
    return weight_decay * 0.5 * sum((p.detach() ** 2).sum() for p in model.parameters())


def _remaining_hours(total_epochs, epoch, num_batches, batch_idx, sec_per_batch):
    """Estimate the remaining training time in hours.

    Args:
        total_epochs: number of epochs the run trains for (``args.epochs``).
        epoch: 0-based index of the current epoch, as produced by
            ``range(start_epoch, epochs)``.
        num_batches: number of batches per epoch.
        batch_idx: 0-based index of the batch that has just finished.
        sec_per_batch: average wall-clock seconds per batch.

    Returns:
        float: estimated hours for the batches still to run (0 after the very
        last batch of the last epoch).
    """
    # Epochs epoch..total_epochs-1 are left (including the current one), minus
    # the batch_idx + 1 batches of the current epoch that are already done.
    remaining_batches = (total_epochs - epoch) * num_batches - (batch_idx + 1)
    return remaining_batches * sec_per_batch / 3600


def _build_model():
    """Instantiate the network named by ``args.network`` for ``args.dataset``.

    Raises:
        ValueError: if ``args.network`` does not match a supported family.
    """
    num_classes = 100 if args.dataset == 'cifar100' else 10
    if 'wrn' in args.network:  # like wrn_28_10
        layers, widen_factor = (int(x) for x in args.network.split('_')[-2:])
        return WideResNetCIFAR(layers, num_classes, widen_factor, init=args.init)
    if 'resnet_v2' in args.network:  # like resnet_v2_110
        layers = int(args.network.split('resnet_v2')[-1].split('_')[1])
        return ResNetV2CIFAR(layers, num_classes, init=args.init)
    if 'VGG' in args.network:
        return VGG(args.network, num_classes, args.init)
    raise ValueError('Not suppported {} yet'.format(args.network))


def main():
    """Train ``args.network`` on ``args.dataset`` and log per-epoch results.

    The function:
    - builds the train/validation loaders and the model;
    - trains with SGD plus a per-iteration MultiStepLR schedule (lr x0.1 at
      50% and again at 75% of all iterations);
    - evaluates on the validation loader after every epoch;
    - saves checkpoints when ``args.suffix`` is set;
    - writes results with ``LoggerText`` into ``<args.outdir>/<get_model_id()>``.
    """
    print(args)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # NOTE(review): split_train_val=0.8 presumably keeps 80% of the training
    # set for training and holds out the rest for validation — confirm in
    # dataset_reader.
    train_loader, val_loader = create_train_and_val_data_loaders(
        args.augment, args.dataset, args.batch_size, args.test_batch_size, split_train_val=0.8,
        num_worker=args.num_worker)

    model = _build_model()
    assert isinstance(model, torch.nn.Module)

    print('dataset:', args.dataset)
    print('network: ', args.network)
    print('Number of model parameters: {}'.format(
        sum(p.numel() for p in model.parameters())))
    print('Weight decay: ', float(_weight_decay_penalty(model, args.weight_decay)))

    cudnn.benchmark = True
    # For multi-GPU training wrap the model in torch.nn.DataParallel and pick
    # the GPUs with CUDA_VISIBLE_DEVICES (e.g. CUDA_VISIBLE_DEVICES=0,1).
    model = model.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum, nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    from torch.optim.lr_scheduler import MultiStepLR
    # scheduler.step() is called once per *iteration* in train(), so the
    # milestones are iteration counts: 50% and 75% of all training iterations.
    # NOTE(review): with args.start_epoch > 0 the schedule still starts from
    # iteration 0 (no scheduler state is restored here) — confirm intent.
    quarter = args.epochs * len(train_loader) // 4
    scheduler = MultiStepLR(optimizer, gamma=0.1, milestones=[quarter * 2, quarter * 3])

    print('learning steps: ', quarter * 2, quarter * 3)
    print('use device: ', device)

    directory = os.path.join(str(args.outdir), get_model_id())
    os.makedirs(directory, exist_ok=True)

    def train(epoch):
        """Train for one epoch on the training set.

        Args:
            epoch: 0-based index of the current epoch (used for progress output).

        Returns:
            tuple: (weight-decay penalty after the epoch, average loss,
            average top-1 precision).
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()

        model.train()

        end = time.time()
        for i, (images, target) in enumerate(train_loader):
            images = images.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            # The model returns a pair; only the first element feeds the loss.
            output, _ = model(images)
            loss = criterion(output, target)

            prec1 = accuracy(output.detach(), target, topk=(1,))[0]
            losses.update(loss.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()  # per-iteration schedule (see milestones above)

            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                wd = _weight_decay_penalty(model, args.weight_decay)
                rest_time = _remaining_hours(args.epochs, epoch, len(train_loader), i,
                                             batch_time.avg)

                print('Epoch: [{0}/{1}][{2}/{3}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.3f} ({loss.avg:.4f})\t'
                      'WD {wd: .3f}\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Rest: {rest: .1f} hours'.format(epoch, args.epochs, i, len(train_loader),
                                                       rest=rest_time,
                                                       batch_time=batch_time,
                                                       loss=losses,
                                                       top1=top1,
                                                       wd=float(wd))
                      )

        return float(_weight_decay_penalty(model, args.weight_decay)), losses.avg, top1.avg

    def validate():
        """Evaluate on the validation set.

        Returns:
            tuple: (average loss, average top-1 precision).
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()

        model.eval()

        end = time.time()
        with torch.no_grad():
            for i, (images, target) in enumerate(val_loader):
                images = images.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)

                output, _ = model(images)
                loss = criterion(output, target)

                prec1 = accuracy(output, target, topk=(1,))[0]
                losses.update(loss.item(), images.size(0))
                top1.update(prec1.item(), images.size(0))

                batch_time.update(time.time() - end)
                end = time.time()

                if i % args.print_freq == 0:
                    print('Test: [{0}/{1}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(i, len(val_loader),
                                                                          batch_time=batch_time,
                                                                          loss=losses,
                                                                          top1=top1)
                          )

        print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))

        return losses.avg, top1.avg

    train_results_names = ['wd', 'loss', 'precision']
    test_results_names = ['loss', 'precision']
    logger = LoggerText(directory)
    try:
        logger.write_once(args)
        logger.write_once(train_results_names)
        logger.write_once(test_results_names)

        if args.suffix is not None:
            save_checkpoint(model, directory, 'ckpt-0.pth')

        for epoch in range(args.start_epoch, args.epochs):
            print(time.ctime())
            # Read the lr actually in use. Since PyTorch 1.4,
            # scheduler.get_lr() is only meant to be called inside step(). For
            # MultiStepLR it applies gamma once more when the current
            # iteration equals a milestone, which would misreport the lr.
            print('lr {0}, at {1} epoch: '.format(optimizer.param_groups[0]['lr'], epoch))

            train_results = train(epoch)
            test_results = validate()

            # When a suffix is given, checkpoint after every epoch. A suffix
            # containing 'all' keeps one file per epoch. Otherwise
            # save_checkpoint's default filename is used (presumably
            # overwritten each epoch — confirm in model_utils).
            if args.suffix is not None:
                if 'all' in args.suffix:
                    save_checkpoint(model, directory, f'ckpt-{epoch + 1}.pth')
                else:
                    save_checkpoint(model, directory)

            logger.write_one_step_results(epoch, train_results, test_results)
            logger.flush()
    finally:
        # Close the log even if training fails part-way.
        logger.close()


if __name__ == '__main__':
    # Script entry point: start training only when run directly, not on import.
    main()

