from __future__ import print_function
import random

import time
import argparse
import os
import sys
import shutil
import pprint

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.model import WideResnet, WideResnetLarge
from lr_scheduler import WarmupCosineLrScheduler
from models.ema import EMA

from utils import accuracy, mixmatch_interleave, create_logger
from utils import AverageMeter

from init import get_params, set_model

from tensorboardX import SummaryWriter


def linear_rampup(current, rampup_length=None):
    """Return a linear ramp-up factor in [0, 1] for the unlabeled-loss weight.

    Args:
        current: training progress in epochs; may be fractional (the training
            loop passes ``epoch + it / n_iters``).
        rampup_length: number of epochs over which the factor grows linearly
            from 0 to 1. When ``None`` (the default, matching the original
            behavior) the module-level ``num_epochs`` is used; ``main()``
            assigns it from ``args.epochs`` before training starts.

    Returns:
        ``clip(current / rampup_length, 0, 1)`` as a Python float, or ``1.0``
        when ``rampup_length`` is not positive (i.e. no ramp-up).
    """
    if rampup_length is None:
        # Looked up lazily: the global only exists once main() has set it.
        rampup_length = num_epochs
    if rampup_length <= 0:
        return 1.0
    return float(np.clip(current / rampup_length, 0.0, 1.0))


class MixMatchLoss(object):
    """MixMatch objective: soft-target cross-entropy on the labeled part plus a
    squared-error term on the unlabeled part.

    Calling an instance returns ``(Lx, Lu, weight)``; the training loop in this
    file combines them as ``Lx + weight * Lu``.
    """

    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, lambda_u):
        """Compute both loss terms and the current unlabeled-loss weight.

        Args:
            outputs_x: logits for the (mixed) labeled samples.
            targets_x: matching soft targets, one row per sample.
            outputs_u: logits for the (mixed) unlabeled samples.
            targets_u: matching soft targets for the unlabeled samples.
            epoch: training progress in epochs, forwarded to ``linear_rampup``.
            lambda_u: maximum weight of the unlabeled term.

        Returns:
            ``(Lx, Lu, lambda_u * linear_rampup(epoch))``.
        """
        # Labeled term: cross-entropy against soft targets -- summed over the
        # class dimension, then averaged over the batch.
        log_probs_x = F.log_softmax(outputs_x, dim=1)
        Lx = -(log_probs_x * targets_x).sum(dim=1).mean()

        # Unlabeled term: squared error between predicted probabilities and the
        # targets, averaged over every element (samples x classes).
        residual = torch.softmax(outputs_u, dim=1) - targets_u
        Lu = residual.pow(2).mean()

        weight = lambda_u * linear_rampup(epoch)
        return Lx, Lu, weight


def get_data(args):
    """Build the labeled, unlabeled and validation loaders for ``args.dataset``.

    The dataset-specific ``get_train_loader_mixmatch`` factory is imported
    lazily, so only the selected dataset module has to be importable.

    Args:
        args: parsed options; reads ``dataset``, ``batchsize``, ``mu``,
            ``n_iters_per_epoch``, ``num_label`` and ``num_val``.

    Returns:
        ``(dltrain_x, dltrain_u, dlval)`` as returned by the factory.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    dataset = args.dataset
    if dataset in ('CIFAR10', 'CIFAR100'):
        from datasets.cifar import get_train_loader_mixmatch
    elif dataset == 'SVHN':
        from datasets.svhn import get_train_loader_mixmatch
    elif dataset == 'STL10':
        from datasets.stl10 import get_train_loader_mixmatch
    elif dataset == 'TinyImageNet':
        from datasets.tiny_imagenet import get_train_loader_mixmatch
    else:
        # An explicit raise (unlike ``assert``) still fires under ``python -O``;
        # without it the call below would die with an UnboundLocalError.
        raise ValueError(
            'Unsupported dataset {!r}; expected one of CIFAR10, CIFAR100, SVHN, '
            'STL10, TinyImageNet'.format(dataset))

    dltrain_x, dltrain_u, dlval = get_train_loader_mixmatch(
        dataset, args.batchsize, args.mu, args.n_iters_per_epoch,
        L=args.num_label, num_val=args.num_val)

    return dltrain_x, dltrain_u, dlval


def train_one_epoch(epoch,
                    model,
                    criterion,
                    optim,
                    lr_schdlr,
                    ema,
                    dltrain_x,
                    dltrain_u,
                    lambda_u,
                    n_iters,
                    logger,
                    log_interval,
                    alpha,
                    T,
                    n_class
                    ):
    """Run one MixMatch training epoch consisting of ``n_iters`` optimizer steps.

    Each iteration:
      1. Draws a labeled batch ``(ims_x, _, labels)`` from ``dltrain_x`` and an
         unlabeled batch ``(view1, view2, _)`` from ``dltrain_u``.
      2. Guesses labels for the unlabeled views: it averages the softmax of both
         views and sharpens the result with temperature ``T``.
      3. Applies mixup across the concatenated labeled+unlabeled batch, with
         ``lam ~ Beta(alpha, alpha)`` folded to ``lam >= 0.5``.
      4. Interleaves the chunks via ``mixmatch_interleave`` before the forward
         pass and un-interleaves the logits afterwards.
      5. Minimises ``loss_x + w * loss_u``, then steps the optimizer, the EMA
         and the lr scheduler.

    Args:
        epoch: current epoch index. ``epoch + it / n_iters`` is passed to
            ``criterion`` as the ramp-up progress.
        model: network called as ``model(x)`` and returning ``(logits, _)``.
        criterion: callable returning ``(loss_x, loss_u, w)`` (e.g. MixMatchLoss).
        optim, lr_schdlr: optimizer and scheduler, both stepped every iteration.
        ema: object exposing ``update_params()`` (called every iteration) and
            ``update_buffer()`` (called once at the end of the epoch).
        dltrain_x, dltrain_u: labeled / unlabeled training loaders.
        lambda_u: maximum weight of the unlabeled loss, forwarded to criterion.
        n_iters: number of iterations in this epoch.
        logger: receives a progress line every ``log_interval`` iterations.
        log_interval: logging period in iterations.
        alpha: Beta-distribution parameter for mixup.
        T: sharpening temperature.
        n_class: number of classes (width of the one-hot labels).

    Returns:
        ``(loss, loss_x, loss_u)`` running averages over the epoch.
    """
    model.train()

    loss_meter = AverageMeter()
    loss_x_meter = AverageMeter()
    loss_u_meter = AverageMeter()

    # Reset after every log line, so the logged "Time" covers one log interval.
    epoch_start = time.time()
    dl_x, dl_u = iter(dltrain_x), iter(dltrain_u)
    for it in range(n_iters):
        # NOTE(review): assumes both loaders yield at least n_iters batches
        # (get_data passes n_iters_per_epoch to the loader factory); otherwise
        # next() raises StopIteration mid-epoch.
        ims_x, _, targets_x = next(dl_x)
        ims_u1, ims_u2, _ = next(dl_u)

        bt = ims_x.size(0)

        # Integer labels -> one-hot rows of width n_class.
        targets_x = torch.zeros(bt, n_class).scatter_(1, targets_x.view(-1, 1).long(), 1)

        ims_x, targets_x = ims_x.cuda(), targets_x.cuda()
        ims_u1, ims_u2 = ims_u1.cuda(), ims_u2.cuda()

        # Label guessing. Tensors built under no_grad carry no autograd history,
        # so the guessed targets need no extra .detach().
        with torch.no_grad():
            outputs_u1, _ = model(ims_u1)
            outputs_u2, _ = model(ims_u2)
            p = (torch.softmax(outputs_u1, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2
            pt = p**(1 / T)
            targets_u = pt / pt.sum(dim=1, keepdim=True)

        # Mixup over the concatenated [labeled, unlabeled view 1, unlabeled view 2] batch.
        all_inputs = torch.cat([ims_x, ims_u1, ims_u2], dim=0)
        all_targets = torch.cat([targets_x, targets_u, targets_u], dim=0)

        # Folding lam to >= 0.5 keeps each mixed sample closer to its own original.
        lam = np.random.beta(alpha, alpha)
        lam = max(lam, 1 - lam)

        perm = torch.randperm(all_inputs.size(0))
        mixed_input = lam * all_inputs + (1 - lam) * all_inputs[perm]
        mixed_target = lam * all_targets + (1 - lam) * all_targets[perm]

        # Interleave labeled and unlabeled samples between the forward chunks so
        # that batch-norm statistics see a mix of both.
        mixed_input = mixmatch_interleave(list(torch.split(mixed_input, bt)), bt)

        logits = []
        for chunk in mixed_input:
            chunk_logits, _ = model(chunk)
            logits.append(chunk_logits)

        # Put the interleaved samples back into their original order.
        logits = mixmatch_interleave(logits, bt)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        loss_x, loss_u, w = criterion(logits_x, mixed_target[:bt], logits_u, mixed_target[bt:],
                                      epoch + it / n_iters, lambda_u)

        loss = loss_x + w * loss_u

        optim.zero_grad()
        loss.backward()
        optim.step()
        ema.update_params()
        lr_schdlr.step()

        loss_meter.update(loss.item())
        loss_x_meter.update(loss_x.item())
        loss_u_meter.update(loss_u.item())

        if (it + 1) % log_interval == 0:
            t = time.time() - epoch_start

            lr_log = [pg['lr'] for pg in optim.param_groups]
            lr_log = sum(lr_log) / len(lr_log)

            logger.info("Epoch:{}, iter: {}. loss: {:.4f}. loss_x: {:.4f}. loss_u: {:.4f}. lambda_u: {:.4f}. lr: {}. Time: {:.2f}".format(
                epoch, it + 1,
                loss_meter.avg,
                loss_x_meter.avg,
                loss_u_meter.avg,
                w,
                lr_log, t))

            epoch_start = time.time()

    ema.update_buffer()

    logger.info("Train Epoch:{}. loss: {:.4f}. loss_x: {:.4f}. loss_u: {:.4f}.".format(epoch, loss_meter.avg, loss_x_meter.avg, loss_u_meter.avg))

    return loss_meter.avg, loss_x_meter.avg, loss_u_meter.avg


def evaluate(epoch, model, dataloader, criterion, logger):
    """Evaluate ``model`` on ``dataloader`` and log Top-1/Top-5 accuracy and loss.

    Args:
        epoch: epoch index; used only in the log line.
        model: network called as ``model(images)`` and returning ``(logits, _)``;
            it is switched to eval mode first.
        dataloader: iterable of ``(images, labels)`` batches; both are moved to
            the GPU with ``.cuda()``.
        criterion: loss applied to ``(logits, labels)`` (main passes
            ``nn.CrossEntropyLoss``).
        logger: receives one summary line.

    Returns:
        ``(top1, top5, loss)``: the ``.avg`` of three ``AverageMeter`` objects.
        Each meter receives one unweighted ``update`` per batch.
    """
    model.eval()

    top1_avg, top5_avg, loss_avg = AverageMeter(), AverageMeter(), AverageMeter()

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.cuda()
            labels = labels.cuda()

            logits, _ = model(images)
            batch_loss = criterion(logits, labels)

            probs = torch.softmax(logits, dim=1)
            acc1, acc5 = accuracy(probs, labels, (1, 5))

            loss_avg.update(batch_loss.item())
            top1_avg.update(acc1.item())
            top5_avg.update(acc5.item())

    logger.info("Test Epoch:{}. Top1: {:.4f}. Top5: {:.4f}. Loss: {:.4f}.".format(epoch, top1_avg.avg, top5_avg.avg, loss_avg.avg))

    return top1_avg.avg, top5_avg.avg, loss_avg.avg


def evaluate_ema(epoch, ema, dataloader, criterion, logger):
    """Evaluate with the EMA (shadow) parameters, then restore the live ones.

    The function swaps the shadow weights in with ``ema.apply_shadow()`` and
    runs the same evaluation as ``evaluate`` on ``ema.model``. It then rolls
    back with ``ema.restore()``, and the ``finally`` clause guarantees the
    rollback even if evaluation raises. Without that guarantee, training could
    silently continue on the EMA weights.

    Args:
        epoch: epoch index; used only in the log line.
        ema: EMA wrapper exposing ``apply_shadow()``, ``restore()`` and ``model``.
        dataloader: iterable of ``(images, labels)`` batches.
        criterion: loss applied to ``(logits, labels)``.
        logger: receives one summary line.

    Returns:
        ``(top1, top5, loss)`` exactly as returned by ``evaluate``.
    """
    ema.apply_shadow()
    try:
        ema.model.cuda()
        # evaluate() switches the model to eval mode and logs the summary line.
        return evaluate(epoch, ema.model, dataloader, criterion, logger)
    finally:
        # Roll back to the current training params so training can continue.
        ema.restore()


#############################################################################################
# Options
#############################################################################################
parser = argparse.ArgumentParser(description='Semi-supervised Learning')

parser.add_argument('--gpu_ids', default='0', type=str, help='gpu_ids: e.g. 0  0,1,2  0,2')

parser.add_argument('--run_dir', type=str, default='run', help='the directory for training')
parser.add_argument('--name', default='MixMatch', type=str, help='output model name')
parser.add_argument('--log_interval', type=int, default=10, help='log training status')
parser.add_argument('--save_epoch', type=int, default=64, help='step to save model')

parser.add_argument('--resume', default=None, type=str, help='checkpoint to continue')

parser.add_argument('--backbone', type=str, default='wideresnet', help='Wideresnet')
parser.add_argument('--wresnet-k', default=2, type=int, help='width factor of wide resnet')
parser.add_argument('--wresnet-n', default=28, type=int, help='depth of wide resnet')
parser.add_argument('--large_model', action='store_true', help='default is False. If True, using WideResnetLarge model')

# `choices` rejects unsupported datasets at parse time with a clear usage error,
# instead of failing later inside get_data().
parser.add_argument('--dataset', type=str, default='CIFAR10',
                    choices=['CIFAR10', 'CIFAR100', 'SVHN', 'STL10', 'TinyImageNet'],
                    help='CIFAR10, CIFAR100, SVHN, STL10, or TinyImageNet')
parser.add_argument('--num_label', type=int, default=40, help='number of labeled samples for training')
parser.add_argument('--num_val', type=int, default=5000, help='number of samples of cross-validation set')

parser.add_argument('--fold', type=int, default=0, help='used for STL10. This is to pick respective 1000-examples fold')

parser.add_argument('--start_epoch', type=int, default=0, help='epoch to start training')
parser.add_argument('--epochs', type=int, default=1024, help='number of training epoches')
parser.add_argument('--batchsize', type=int, default=64, help='train batch size of labeled samples')
parser.add_argument('--valbatchsize', default=100, type=int, help='validation batch size')
parser.add_argument('--mu', type=int, default=1, help='factor of train batch size of unlabeled samples')

# NOTE: --thr is not referenced anywhere in this file (MixMatch uses no threshold).
parser.add_argument('--thr', type=float, default=0.95, help='threshold for picking high-confidence predictions')

parser.add_argument('--k_imgs', type=int, default=64 * 1024, help='number of training images for each epoch')
parser.add_argument('--lam-u', type=float, default=75, help='weight of unlabeled loss')
parser.add_argument('--ema-alpha', type=float, default=0.999, help='decay rate for ema module')
parser.add_argument('--lr', type=float, default=0.03, help='base learning rate for training')
parser.add_argument('--weight-decay', type=float, default=5e-4, help='weight decay')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum for optimizer')
# The value is passed to WarmupCosineLrScheduler as `warmup_iter`, alongside
# `max_iter` = total optimizer iterations, so it counts iterations, not epochs.
parser.add_argument('--warmup', default=0, type=float, help='number of lr warmup iterations (passed to the scheduler as warmup_iter)')

parser.add_argument('--alpha', default=0.75, type=float, help='for Mixup algorithm')
parser.add_argument('--T', default=0.5, type=float, help='used for sharpening, entropy minimization')

parser.add_argument('--seed', type=int, default=-1, help='seed for random behaviors, no seed if negtive')
parser.add_argument('--seed_init', type=int, default=10000, help='the number to feed to seed()')

# store_false: the attribute is True unless the flag is given.
parser.add_argument('--cudnn_deter', action='store_false', help='cudnn deterministic mode is ON by default; pass this flag to turn it OFF')
parser.add_argument('--cudnn_bench', action='store_true', help='default is False. If True, program may run faster')
#############################################################################################


def _mean_lr(optimizer):
    """Return the mean learning rate across ``optimizer.param_groups`` (for logging)."""
    lrs = [pg['lr'] for pg in optimizer.param_groups]
    return sum(lrs) / len(lrs)


def _snapshot_sources(out_dir):
    """Copy this script and the project's ``.py`` sources into ``out_dir``.

    The script goes to ``out_dir`` itself. Top-level ``.py`` files and those in
    ``datasets/`` and ``models/`` go under ``out_dir/src``. Paths are relative,
    so this assumes the current working directory is the project root.
    """
    shutil.copy2(os.path.basename(__file__), out_dir)
    src_dir = os.path.join(out_dir, 'src')
    for f in os.listdir('.'):
        if f.endswith('.py'):
            shutil.copy2(f, src_dir)
    for folder in ['datasets', 'models']:
        dst = os.path.join(src_dir, folder)
        if not os.path.exists(dst):
            os.makedirs(dst)
        for f in os.listdir(folder):
            if f.endswith('.py'):
                shutil.copy2(os.path.join(folder, f), dst)


def _seed_everything(args, logger):
    """Resolve ``args.seed`` and seed Python, NumPy and torch RNGs; set cudnn flags.

    A non-negative ``--seed`` is used as given (consistent with its help text).
    A negative one is replaced by ``random.randrange(args.seed_init)`` and the
    drawn value is logged.
    """
    if args.seed >= 0:
        logger.info('Seed number is given by user: {}'.format(args.seed))
    else:
        args.seed = random.randrange(args.seed_init)
        logger.info('Seed number is randomly generated: {}'.format(args.seed))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.cudnn_deter
    torch.backends.cudnn.benchmark = args.cudnn_bench
    torch.cuda.manual_seed(args.seed)


def _build_optimizer(model, args):
    """Build Nesterov SGD; parameters whose name contains 'bn' get no weight decay."""
    wd_params, non_wd_params = [], []
    for name, param in model.named_parameters():
        # NOTE(review): only names containing 'bn' skip weight decay; other biases
        # (e.g. a classifier bias) are decayed unless their name contains 'bn'.
        if 'bn' in name:
            non_wd_params.append(param)
        else:
            wd_params.append(param)
    param_list = [
        {'params': wd_params}, {'params': non_wd_params, 'weight_decay': 0}]
    return torch.optim.SGD(param_list, lr=args.lr, weight_decay=args.weight_decay,
                           momentum=args.momentum, nesterov=True)


def main():
    """Parse CLI options, prepare a fresh run directory and train MixMatch.

    Side effects:
      * Sets ``CUDA_VISIBLE_DEVICES`` from ``--gpu_ids``.
      * Creates ``<run_dir>/<name>_<backbone>_<dataset>-<num_label>-<mu>_bz<bs>_epoch<E>_<time>``.
        If that folder already exists, it prints an error and returns early.
      * Writes into that folder:
          - a source snapshot;
          - TensorBoard scalars;
          - the latest checkpoint ``<name>_checkpoint``;
          - best-on-validation weights ``<name>_bestval`` / ``<name>_ema_bestval``;
          - snapshots ``<name>_e<epoch>`` / ``<name>_ema_e<epoch>`` every
            ``--save_epoch`` epochs.
      * With ``--resume``, it restores these from the checkpoint: epoch,
        optimizer, lr scheduler (if saved), model, EMA shadow and best scores.
    """
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids

    # Create folder for training and copy important files
    if not os.path.exists(args.run_dir):
        os.makedirs(args.run_dir)
    time_str = time.strftime('%Y-%m-%d-%H-%M')
    args.out_dir = os.path.join(
        args.run_dir,
        '{}_{}_{}-{}-{}_bz{}_epoch{}_{}'.format(
            args.name, args.backbone, args.dataset, args.num_label, args.mu,
            args.batchsize, args.epochs, time_str))
    if os.path.exists(args.out_dir):
        # time_str has minute resolution, so identical runs started in the
        # same minute collide.
        print('Error: the folder name is dupplicated!')
        return
    os.makedirs(args.out_dir)
    os.makedirs(os.path.join(args.out_dir, 'src'), exist_ok=True)

    writer_dict = {
        'writer': SummaryWriter(log_dir=os.path.join(args.out_dir, 'tensorboard')),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }
    writer = writer_dict['writer']

    logger = create_logger(args.out_dir, args.name.lower(), time_str)

    _snapshot_sources(args.out_dir)

    # global settings
    _seed_everything(args, logger)

    args.n_iters_per_epoch = args.k_imgs // args.batchsize
    args.n_iters_all = args.n_iters_per_epoch * args.epochs

    logger.info("***** Running training *****")
    logger.info(f"  Task = {args.dataset}@{args.num_label}")
    logger.info(f"  Num Epochs = {args.epochs}")
    logger.info(f"  Iterations per epoch = {args.n_iters_per_epoch}")
    logger.info(f"  Batch size per GPU = {args.batchsize}")
    logger.info(f"  Total optimization steps = {args.n_iters_all}")

    # NOTE(review): this replaces any --num_val from the command line with
    # get_params()' value (and with --fold for STL10); confirm this is intended.
    args.n_classes, args.num_val = get_params(args.dataset)
    if args.dataset == 'STL10':
        args.num_val = args.fold

    if args.dataset in ('STL10', 'TinyImageNet'):
        model = set_model(args.n_classes, args.wresnet_k, args.wresnet_n, stl=True)
    else:
        model = set_model(args.n_classes, args.wresnet_k, args.wresnet_n, large=args.large_model)
    logger.info("Total params: {:.2f}M".format(
        sum(p.numel() for p in model.parameters()) / 1e6))
    criterion = MixMatchLoss()

    dltrain_x, dltrain_u, dlval = get_data(args)

    ema = EMA(model, args.ema_alpha)

    optim = _build_optimizer(model, args)
    lr_schdlr = WarmupCosineLrScheduler(
        optim, max_iter=args.n_iters_all, warmup_iter=args.warmup
    )

    best_acc_val = -1
    best_epoch_val = 0
    best_acc_ema_val = -1
    best_epoch_ema_val = 0

    if args.resume is not None:
        logger.info('Continue training from checkpoint {}.'.format(args.resume))
        state = torch.load(args.resume)

        args.start_epoch = state['epoch']

        logger.info('optim before loaded. lr: {}.'.format(_mean_lr(optim)))
        optim.load_state_dict(state['optimizer'])
        logger.info('optim after loaded. lr: {}.'.format(_mean_lr(optim)))

        # A freshly built scheduler would otherwise start its schedule over.
        # Older checkpoints have no scheduler state.
        if 'lr_scheduler' in state and hasattr(lr_schdlr, 'load_state_dict'):
            lr_schdlr.load_state_dict(state['lr_scheduler'])
        else:
            logger.info('No lr_scheduler state in checkpoint; the lr schedule starts from its initial state.')

        model.load_state_dict(state['state_dict'])
        ema.shadow = state['ema_state_dict']

        best_acc_val = state['best_top1_val']
        best_epoch_val = state['best_epoch_val']
        best_acc_ema_val = state['best_top1_ema_val']
        best_epoch_ema_val = state['best_epoch_ema_val']

        logger.info("Val. best_acc_val: {:.4f} at epoch {}. best_acc_ema_val: {:.4f} at epoch {}.".format(best_acc_val, best_epoch_val, best_acc_ema_val, best_epoch_ema_val))

    train_args = dict(
        model=model,
        criterion=criterion,
        optim=optim,
        lr_schdlr=lr_schdlr,
        ema=ema,
        dltrain_x=dltrain_x,
        dltrain_u=dltrain_u,
        lambda_u=args.lam_u,
        n_iters=args.n_iters_per_epoch,
        logger=logger,
        log_interval=args.log_interval,
        alpha=args.alpha,
        T=args.T,
        n_class=args.n_classes
    )

    # Read by linear_rampup() as its default ramp-up length.
    global num_epochs
    num_epochs = args.epochs

    logger.info(pprint.pformat(args))

    # Built once instead of once per epoch.
    eval_criterion = nn.CrossEntropyLoss().cuda()

    logger.info('-----------start training--------------')
    try:
        for epoch in range(args.start_epoch, args.epochs):
            train_loss, loss_x, loss_u = train_one_epoch(epoch, **train_args)

            writer.add_scalars('train_loss', {
                'loss': train_loss,
                'loss_x': loss_x,
                'loss_u': loss_u
            }, epoch)

            logger.info('start validation.')
            logger.info('testing on val_loader using normal model...')
            top1_val, top5_val, valid_loss_val = evaluate(epoch, model, dlval, eval_criterion, logger)

            logger.info('testing on val_loader using ema model...')
            top1_ema_val, top5_ema_val, valid_loss_ema_val = evaluate_ema(epoch, ema, dlval, eval_criterion, logger)

            writer.add_scalars('val_acc', {
                'top1_val': top1_val,
                'top5_val': top5_val,
                'top1_ema_val': top1_ema_val,
                'top5_ema_val': top5_ema_val
            }, epoch)
            writer.add_scalars('val_loss', {
                'loss_val': valid_loss_val,
                'loss_ema_val': valid_loss_ema_val
            }, epoch)

            is_best_val = best_acc_val < top1_val
            if is_best_val:
                best_acc_val = top1_val
                best_epoch_val = epoch

            is_best_ema_val = best_acc_ema_val < top1_ema_val
            if is_best_ema_val:
                best_acc_ema_val = top1_ema_val
                best_epoch_ema_val = epoch

            logger.info('saving checkpoint...')
            checkpoint = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'ema_state_dict': ema.shadow,   # not ema.model.state_dict()
                'top1_val': top1_val,
                'best_top1_val': best_acc_val,
                'best_epoch_val': best_epoch_val,
                'top1_ema_val': top1_ema_val,
                'best_top1_ema_val': best_acc_ema_val,
                'best_epoch_ema_val': best_epoch_ema_val,
                'optimizer': optim.state_dict(),
            }
            if hasattr(lr_schdlr, 'state_dict'):
                # Saved so --resume continues the lr schedule where it stopped.
                checkpoint['lr_scheduler'] = lr_schdlr.state_dict()
            torch.save(checkpoint, os.path.join(args.out_dir, args.name + '_checkpoint'))

            if is_best_val:
                logger.info('saving best normal model on val_loader...')
                torch.save(model.state_dict(), os.path.join(args.out_dir, args.name + '_bestval'))

            if is_best_ema_val:
                logger.info('saving best ema model on val_loader...')
                torch.save(ema.shadow, os.path.join(args.out_dir, args.name + '_ema_bestval'))  # not ema.model.state_dict()

            logger.info("Val Epoch {}. best_acc_val: {:.4f} at epoch {}. best_acc_ema_val: {:.4f} at epoch {}.".format(epoch, best_acc_val, best_epoch_val, best_acc_ema_val, best_epoch_ema_val))

            if (epoch + 1) % args.save_epoch == 0:
                torch.save(model.state_dict(), os.path.join(args.out_dir, args.name + '_e{}'.format(epoch)))
                torch.save(ema.shadow, os.path.join(args.out_dir, args.name + '_ema_e{}'.format(epoch)))    # not ema.model.state_dict()
    finally:
        # Flush and close TensorBoard even if training is interrupted.
        writer.close()


# Script entry point. main() returns None, so the process exits with status 0
# on normal completion, just as if it had simply fallen off the end.
if __name__ == '__main__':
    sys.exit(main())
