import os
import sys; sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import os
import time
import numpy as np

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data

from as_models.resnet_cifar import WideResNetCIFAR, ResNetV2CIFAR
from as_models.model_utils import save_checkpoint, load_model
from as_utils.tools import accuracy, AverageMeter
from as_utils.logger import LoggerText
from as_data_reader.dataset_reader import create_train_and_test_data_loaders

from noise_self_distil.args import args


def get_model_id():
    """Return the run identifier used as the output sub-directory name.

    The identifier encodes the main hyper-parameters from the parsed CLI
    ``args``:
    ``kd[-<prefix>]-<dataset>-<network>-epo_..-b_..-lr_..-wd_..-t_..-rate_..-noise_<force>_<type>-<init>[-<suffix>]``.
    The optional ``prefix``/``suffix`` parts are included only when they are
    not ``None``.
    """
    head = 'kd' if args.prefix is None else f'kd-{args.prefix}'
    pieces = [
        f'{head}-{args.dataset}-{args.network}-epo_{args.epochs}-',
        f'b_{args.batch_size}-lr_{args.lr}-wd_{args.weight_decay}-',
        f't_{args.temperature}-rate_{args.kd_rate}-',
        f'noise_{args.noise_force}_{args.noise_type}-',
        f'{args.init}',
    ]
    run_id = ''.join(pieces)
    if args.suffix is not None:
        run_id = f'{run_id}-{args.suffix}'
    return run_id


def _num_classes(dataset):
    """Return the classifier width for *dataset*: 100 for ``'cifar100'``, otherwise 10."""
    return 100 if dataset == 'cifar100' else 10


def _build_model(network, dataset, init):
    """Instantiate a CIFAR network from its name.

    Args:
        network: architecture name. Names containing ``'wrn'`` are parsed as
            ``..._<layers>_<widen_factor>`` (e.g. ``wrn_28_10``); names containing
            ``'resnet_v2'`` as ``resnet_v2_<layers>`` (e.g. ``resnet_v2_110``).
        dataset: dataset name; only used to choose the number of classes.
        init: initializer option forwarded to the model constructor.

    Returns:
        A freshly constructed ``WideResNetCIFAR`` or ``ResNetV2CIFAR``.

    Raises:
        ValueError: if *network* belongs to neither supported family.
    """
    num_classes = _num_classes(dataset)
    if 'wrn' in network:
        parts = network.split('_')
        layers, widen_factor = int(parts[-2]), int(parts[-1])
        return WideResNetCIFAR(layers, num_classes, widen_factor, init=init)
    if 'resnet_v2' in network:
        layers = int(network.split('resnet_v2')[-1].split('_')[1])
        return ResNetV2CIFAR(layers, num_classes, init=init)
    raise ValueError('Not supported {} yet'.format(network))


def _l2_penalty(model, weight_decay):
    """Return ``weight_decay * 0.5 * sum(p ** 2)`` over all parameters of *model*.

    Used for logging only; parameters are detached so no autograd graph is
    built. The result is a scalar tensor (assuming *model* has parameters).
    """
    return weight_decay * 0.5 * sum((p.detach() ** 2).sum() for p in model.parameters())


def main():
    """Train a student network by distilling from a frozen, noise-perturbed teacher.

    Everything is configured by the parsed CLI ``args``. Each iteration the
    student minimises ``CE(student_logits, y) + kd_rate * KD``. ``KD`` is the
    batch mean of the squared L2 distance between the student logits and the
    teacher logits after noise is applied. ``noise_type='s'`` randomly
    reassigns teacher logit columns (symmetric flips); ``noise_type='n2'``
    adds Gaussian noise scaled by ``noise_force``.

    The learning rate is multiplied by 0.1 after 1/2 and 3/4 of all training
    iterations. Every epoch is followed by validation, optional
    checkpointing (when ``args.suffix`` is set) and one log entry in the run
    directory.

    Raises:
        ValueError: for an unsupported network name, an unknown noise type,
            or a symmetric noise rate outside [0, 1].
    """
    print(args)

    # Validate noise options up front so a bad value fails before the
    # (potentially slow) data and teacher loading.
    if args.noise_type not in ('s', 'n2'):
        raise ValueError('Unknown noise type: {}'.format(args.noise_type))
    if args.noise_type == 's' and not 0.0 <= args.noise_force <= 1.0:
        raise ValueError('noise_force must be in [0, 1] for symmetric noise, '
                         'got {}'.format(args.noise_force))

    if args.seed is not None:
        # Seed torch and force deterministic cuDNN kernels. Some operations
        # may still be non-deterministic; see
        # https://pytorch.org/docs/stable/notes/randomness.html
        torch.random.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    train_loader, val_loader = create_train_and_test_data_loaders(
        args.augment, args.dataset, args.batch_size, args.test_batch_size, args.num_worker)

    # Teacher first, student second: model construction presumably draws from
    # the torch RNG, so this order matters for seeded runs.
    teacher_model = _build_model(args.teacher_network, args.dataset, args.init)

    print('use device: ', device)
    print("=> loading checkpoint '{}'".format(args.teacher_model))
    load_model(teacher_model, args.teacher_model)
    print('teacher_network: ', args.teacher_network)

    # Freeze the teacher: only the student is optimised.
    for param in teacher_model.parameters():
        param.requires_grad = False

    model = _build_model(args.network, args.dataset, args.init)

    print('dataset:', args.dataset)
    print('network: ', args.network)
    print('initializer: ', args.init)
    print('Number of model parameters: {}'.format(
        sum(p.numel() for p in model.parameters())))
    print('Weight decay: ', _l2_penalty(model, args.weight_decay).item())

    model = model.to(device)
    teacher_model = teacher_model.to(device)
    teacher_model.eval()

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum, nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    from torch.optim.lr_scheduler import MultiStepLR
    # The scheduler is stepped once per *iteration*, so milestones are
    # expressed in iterations: decay at 1/2 and 3/4 of the whole run.
    quarter = args.epochs * len(train_loader) // 4
    scheduler = MultiStepLR(optimizer, gamma=0.1, milestones=[quarter * 2, quarter * 3])
    print('learning steps: ', quarter * 2, quarter * 3)

    directory = os.path.join(f'{args.outdir}', get_model_id())
    os.makedirs(directory, exist_ok=True)

    train_results_names = ['wd', 'loss', 'precision', 'kd']
    test_results_names = ['loss', 'precision']
    logger = LoggerText(directory)
    logger.write_once(args)
    logger.write_once(train_results_names)
    logger.write_once(test_results_names)

    if args.suffix is not None:
        save_checkpoint(model, directory, 'init.pth')

    def train(epoch):
        """Train the student for one epoch (``epoch`` is 0-based).

        Returns:
            ``(l2_penalty, avg CE loss, avg top-1 precision, avg KD loss)``.
        """
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        kd_losses = AverageMeter()
        t_losses = AverageMeter()

        model.train()
        num_batches = len(train_loader)
        # Symmetric label-flip probabilities; built lazily once the number of
        # classes is known from the teacher output.
        flip_probs = None

        end = time.time()
        for i, (inputs, target) in enumerate(train_loader):
            target = target.to(device, non_blocking=True)
            inputs = inputs.to(device, non_blocking=True)
            batch_size = inputs.size(0)

            logits, _ = model(inputs)
            # The teacher is frozen and in eval mode; skip autograd bookkeeping.
            with torch.no_grad():
                logits_0, _ = teacher_model(inputs)
            loss = criterion(logits, target)
            t_loss = criterion(logits_0, target)
            prec1 = accuracy(logits.detach(), target, topk=(1,))[0]

            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            t_losses.update(t_loss.item(), batch_size)

            if args.noise_type == 's':
                num_classes = logits_0.shape[-1]
                if flip_probs is None:
                    # Row c: keep class c with prob. 1 - noise_force, otherwise
                    # move to each other class with equal probability.
                    flip_rate = args.noise_force
                    flip_probs = (flip_rate / (num_classes - 1)) * torch.ones(
                        (num_classes, num_classes), device=device)
                    flip_probs.fill_diagonal_(1. - flip_rate)
                # One draw per class, shared by the whole batch: column c of the
                # noisy teacher logits becomes column source[c]. This equals a
                # matmul with the one-hot transition matrix T[source[c], c] = 1.
                source = torch.multinomial(flip_probs, 1)[:, 0]
                logits_0 = logits_0[:, source]
            elif args.noise_type == 'n2':
                # Zero-mean Gaussian noise scaled by noise_force.
                logits_0 = logits_0 + args.noise_force * torch.randn_like(logits_0)
            else:
                raise ValueError('Unknown noise type: {}'.format(args.noise_type))

            kd_loss = torch.mean(torch.sum((logits_0 - logits) ** 2, -1))
            total_loss = loss + kd_loss * args.kd_rate
            kd_losses.update(kd_loss.item(), batch_size)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            scheduler.step()  # per-iteration LR schedule

            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                wd = _l2_penalty(model, args.weight_decay).item()
                # Iterations left in this epoch plus all later epochs.
                remaining_iters = (args.epochs - epoch) * num_batches - (i + 1)
                rest_time = remaining_iters * batch_time.avg / 3600

                print('Epoch: [{0}/{1}][{2}/{3}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.3f} ({loss.avg:.3f})\t'
                      'T-Loss {t_loss.val:.2f} ({t_loss.avg:.2f})\t'
                      'KDLoss {kdloss.val:.4f} ({kdloss.avg:.4f})\t'
                      'Weight Decay {wd: .3f}\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Rest: {rest: .1f} hours'.format(epoch, args.epochs, i, num_batches,
                                                       rest=rest_time,
                                                       batch_time=batch_time,
                                                       loss=losses,
                                                       top1=top1,
                                                       wd=wd,
                                                       t_loss=t_losses,
                                                       kdloss=kd_losses)
                      )

        wd = _l2_penalty(model, args.weight_decay).item()
        return wd, losses.avg, top1.avg, kd_losses.avg

    def validate():
        """Evaluate the student on the validation set; return ``(avg loss, avg top-1)``."""
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()

        model.eval()
        end = time.time()
        with torch.no_grad():
            for i, (inputs, target) in enumerate(val_loader):
                target = target.to(device, non_blocking=True)
                inputs = inputs.to(device, non_blocking=True)

                output, _ = model(inputs)
                loss = criterion(output, target)

                prec1 = accuracy(output, target, topk=(1,))[0]
                losses.update(loss.item(), inputs.size(0))
                top1.update(prec1.item(), inputs.size(0))

                batch_time.update(time.time() - end)
                end = time.time()

                if i % args.print_freq == 0:
                    print('Test: [{0}/{1}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(i, len(val_loader),
                                                                          batch_time=batch_time,
                                                                          loss=losses,
                                                                          top1=top1)
                          )

        print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
        return losses.avg, top1.avg

    for epoch in range(args.start_epoch, args.epochs):
        print(time.ctime())
        # get_last_lr() reports the LR actually in use; get_lr() is internal to
        # the scheduler and can return a re-decayed value on a milestone step.
        print('lr {0}, kd_rate {1}, at {2} epoch: '.format(
            scheduler.get_last_lr()[0], args.kd_rate, epoch))

        train_results = train(epoch)
        test_results = validate()

        if args.suffix is not None:
            if 'all' in args.suffix:
                save_checkpoint(model, directory, f'ckpt-{epoch}.pth')
            else:
                save_checkpoint(model, directory)

        logger.write_one_step_results(epoch, train_results, test_results)
        logger.flush()


if __name__ == "__main__":
    # Script entry point: run the distillation training configured by the CLI args.
    main()
