from __future__ import print_function
import random

import time
import argparse
import os
import sys
import shutil
import pprint

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.model import WideResnet, WideResnetLarge
from lr_scheduler import WarmupCosineLrScheduler
from models.ema import EMA

from utils import accuracy, interleave, de_interleave, create_logger
from utils import AverageMeter

from loss import contrastive_loss

from init import get_params, set_model

from tensorboardX import SummaryWriter


class RankingMatchLoss(object):
    """Callable that computes the four loss terms of one training step.

    The terms are returned separately (not summed) so the caller can weight
    and log them individually.
    """

    def __call__(self, logits_x, lbs_x, logits_u_s, lbs_u_guess, mask, norm_feas_x, norm_feas_u_s, temperature):
        """Compute the supervised, unsupervised and contrastive losses.

        Args:
            logits_x: model outputs for the labeled samples.
            lbs_x: labels of the labeled samples.
            logits_u_s: model outputs for the unlabeled samples.
            lbs_u_guess: guessed labels for the unlabeled samples.
            mask: per-sample weights for the unlabeled samples; entries equal
                to 1 mark the samples kept for the contrastive term.
            norm_feas_x: features of the labeled samples given to
                ``contrastive_loss``.
            norm_feas_u_s: features of the unlabeled samples given to
                ``contrastive_loss``.
            temperature: forwarded to ``contrastive_loss``.

        Returns:
            Tuple ``(loss_x, loss_u, loss_contrast_x, loss_contrast_u)``.
        """
        loss_x = F.cross_entropy(logits_x, lbs_x)

        # Mean over *all* unlabeled samples; masked-out ones contribute zero.
        loss_u = (F.cross_entropy(logits_u_s, lbs_u_guess, reduction='none') * mask).mean()

        loss_contrast_x = contrastive_loss(norm_feas_x, lbs_x, temperature=temperature)

        selected = mask == 1
        idens_u = lbs_u_guess[selected]
        if idens_u.nelement() == 0:
            # No unlabeled sample was selected: skip the contrastive term so an
            # empty batch cannot produce NaN. The zero is created on the same
            # device as the inputs instead of assuming the default CUDA device.
            loss_contrast_u = torch.zeros((), device=logits_u_s.device)
        else:
            loss_contrast_u = contrastive_loss(norm_feas_u_s[selected], idens_u, temperature=temperature)

        return loss_x, loss_u, loss_contrast_x, loss_contrast_u


def get_data(args):
    """Build the training and validation data loaders for ``args.dataset``.

    Args:
        args: namespace providing ``dataset``, ``batchsize``, ``mu``,
            ``n_iters_per_epoch``, ``num_label`` and ``num_val``.

    Returns:
        Tuple ``(dltrain_x, dltrain_u, dlval)`` as returned by the selected
        dataset module's ``get_train_loader``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    # The import lives inside each branch so only the selected dataset module
    # is loaded.
    if args.dataset in ('CIFAR10', 'CIFAR100'):
        from datasets.cifar import get_train_loader
    elif args.dataset == 'SVHN':
        from datasets.svhn import get_train_loader
    elif args.dataset == 'STL10':
        from datasets.stl10 import get_train_loader
    elif args.dataset == 'TinyImageNet':
        from datasets.tiny_imagenet import get_train_loader
    else:
        # Previously this fell through and crashed later with an
        # UnboundLocalError on get_train_loader.
        raise ValueError(
            "Unsupported dataset '{}'. Expected one of: CIFAR10, CIFAR100, "
            "SVHN, STL10, TinyImageNet.".format(args.dataset))

    dltrain_x, dltrain_u, dlval = get_train_loader(
        args.dataset, args.batchsize, args.mu, args.n_iters_per_epoch, L=args.num_label, num_val=args.num_val)

    return dltrain_x, dltrain_u, dlval


def train_one_epoch(epoch,
                    model,
                    criterion,
                    optim,
                    lr_schdlr,
                    ema,
                    dltrain_x,
                    dltrain_u,
                    lambda_u,
                    lambda_con,
                    n_iters,
                    logger,
                    log_interval,
                    threshold,
                    temperature,
                    l2_norm
                    ):
    """Run ``n_iters`` optimisation steps and report the averaged metrics.

    Each step draws one labeled and one unlabeled batch, runs them through the
    model in a single forward pass, derives guessed labels and a confidence
    mask from the weakly-augmented unlabeled outputs, and optimises
    ``loss_x + lambda_u * loss_u + lambda_con * (loss_conx + loss_conu)``.

    Args:
        epoch: epoch index, used only in log messages.
        model: network returning a ``(logits, _)`` pair.
        criterion: callable returning ``(loss_x, loss_u, loss_conx, loss_conu)``.
        optim: optimizer stepped once per iteration.
        lr_schdlr: learning-rate scheduler stepped once per iteration.
        ema: EMA helper; ``update_params`` is called every iteration and
            ``update_buffer`` once after the loop.
        dltrain_x: loader yielding ``(weak images, _, labels)``.
        dltrain_u: loader yielding ``(weak images, strong images, _)``.
        lambda_u: weight of the unlabeled cross-entropy term.
        lambda_con: weight of the two contrastive terms.
        n_iters: number of iterations to run.
        logger: logger receiving progress messages.
        log_interval: log every ``log_interval`` iterations.
        threshold: minimum max-probability for a guessed label to be kept.
        temperature: forwarded to ``criterion``.
        l2_norm: if true, L2-normalise the logits handed to ``criterion`` as
            features.

    Returns:
        Tuple of meter averages: ``(loss, loss_x, loss_u, loss_conx,
        loss_conu, mask)``.
    """
    model.train()

    # Meters are kept in the order in which they are logged and returned:
    # total loss, loss_x, loss_u, loss_conx, loss_conu, mask ratio.
    meters = [AverageMeter() for _ in range(6)]

    interval_start = time.time()  # restarted after every progress message
    labeled_iter = iter(dltrain_x)
    unlabeled_iter = iter(dltrain_u)

    for step in range(1, n_iters + 1):
        ims_x_weak, _, lbs_x = next(labeled_iter)
        ims_u_weak, ims_u_strong, _ = next(unlabeled_iter)

        ims_x_weak = ims_x_weak.cuda()
        lbs_x = lbs_x.cuda()
        ims_u_weak = ims_u_weak.cuda()
        ims_u_strong = ims_u_strong.cuda()

        # --------------------------------------
        n_labeled = ims_x_weak.size(0)
        ratio = int(ims_u_weak.size(0) // n_labeled)
        # One labeled chunk plus `ratio` weak and `ratio` strong unlabeled chunks.
        n_chunks = 2 * ratio + 1

        # Labeled and unlabeled samples are interleaved for the forward pass
        # (presumably so batch-norm sees a mix of both) and restored to their
        # original order afterwards.
        batch = torch.cat([ims_x_weak, ims_u_weak, ims_u_strong], dim=0)
        batch = interleave(batch, n_chunks)
        logits, _ = model(batch)
        logits = de_interleave(logits, n_chunks)

        logits_x = logits[:n_labeled]
        logits_u_w, logits_u_s = torch.split(logits[n_labeled:], n_labeled * ratio)

        feas_x, feas_u_s = logits_x, logits_u_s
        if l2_norm:
            feas_x = F.normalize(logits_x, p=2, dim=1)
            feas_u_s = F.normalize(logits_u_s, p=2, dim=1)

        # Guessed labels and the confidence mask are targets, not part of the graph.
        with torch.no_grad():
            scores, lbs_u_guess = torch.softmax(logits_u_w, dim=1).max(dim=1)
            mask = (scores >= threshold).float()

        loss_x, loss_u, loss_conx, loss_conu = criterion(
            logits_x, lbs_x, logits_u_s, lbs_u_guess, mask, feas_x, feas_u_s, temperature)

        loss = loss_x + lambda_u * loss_u + lambda_con * (loss_conx + loss_conu)
        # --------------------------------------

        optim.zero_grad()
        loss.backward()
        optim.step()
        ema.update_params()
        lr_schdlr.step()

        step_values = (loss, loss_x, loss_u, loss_conx, loss_conu, mask.mean())
        for meter, value in zip(meters, step_values):
            meter.update(value.item())

        if step % log_interval == 0:
            elapsed = time.time() - interval_start

            lrs = [group['lr'] for group in optim.param_groups]
            mean_lr = sum(lrs) / len(lrs)

            logger.info("Epoch:{}, iter: {}. loss: {:.4f}. loss_x: {:.4f}. loss_u: {:.4f}. loss_conx: {:.4f}. loss_conu: {:.4f}. "
                        "Mask:{:.4f} . lr: {}. Time: {:.2f}".format(
                epoch, step, *[meter.avg for meter in meters], mean_lr, elapsed))

            interval_start = time.time()

    ema.update_buffer()

    averages = tuple(meter.avg for meter in meters)
    logger.info("Train Epoch:{}. loss: {:.4f}. loss_x: {:.4f}. loss_u: {:.4f}. loss_conx: {:.4f}. loss_conu: {:.4f}. Mask:{:.4f}.".format(epoch, *averages))

    return averages


def evaluate(epoch, model, dataloader, criterion, logger):
    """Evaluate ``model`` on ``dataloader`` and log the averaged metrics.

    Args:
        epoch: epoch index, used only in the log message.
        model: network returning a ``(logits, _)`` pair; put into eval mode.
        dataloader: iterable yielding ``(images, labels)`` batches.
        criterion: loss callable taking ``(logits, labels)``.
        logger: logger receiving the summary line.

    Returns:
        Tuple ``(top1, top5, loss)`` of per-batch meter averages.
    """
    model.eval()

    loss_meter, top1_meter, top5_meter = AverageMeter(), AverageMeter(), AverageMeter()

    with torch.no_grad():
        for images, targets in dataloader:
            images, targets = images.cuda(), targets.cuda()
            logits, _ = model(images)
            batch_loss = criterion(logits, targets)
            top1, top5 = accuracy(torch.softmax(logits, dim=1), targets, (1, 5))
            for meter, value in ((loss_meter, batch_loss), (top1_meter, top1), (top5_meter, top5)):
                meter.update(value.item())

    results = (top1_meter.avg, top5_meter.avg, loss_meter.avg)
    logger.info("Test Epoch:{}. Top1: {:.4f}. Top5: {:.4f}. Loss: {:.4f}.".format(epoch, *results))

    return results


def evaluate_ema(epoch, ema, dataloader, criterion, logger):
    """Evaluate the model with its EMA (shadow) parameters applied.

    The shadow parameters are applied for the duration of the evaluation and
    the current training parameters are restored afterwards, even if the
    evaluation raises, so training never continues on the shadow weights by
    accident.

    Args:
        epoch: epoch index, used only in the log message.
        ema: EMA helper exposing ``apply_shadow``, ``restore`` and ``model``.
        dataloader: iterable yielding ``(images, labels)`` batches.
        criterion: loss callable taking ``(logits, labels)``.
        logger: logger receiving the summary line.

    Returns:
        Tuple ``(top1, top5, loss)`` of per-batch meter averages.
    """
    # using EMA params to evaluate performance
    ema.apply_shadow()
    try:
        ema.model.eval()
        ema.model.cuda()

        loss_meter = AverageMeter()
        top1_meter = AverageMeter()
        top5_meter = AverageMeter()

        with torch.no_grad():
            for ims, lbs in dataloader:
                ims = ims.cuda()
                lbs = lbs.cuda()
                logits, _ = ema.model(ims)
                loss = criterion(logits, lbs)
                scores = torch.softmax(logits, dim=1)
                top1, top5 = accuracy(scores, lbs, (1, 5))
                loss_meter.update(loss.item())
                top1_meter.update(top1.item())
                top5_meter.update(top5.item())
    finally:
        # Roll the model back to its current training parameters, including
        # when evaluation failed part-way through.
        ema.restore()

    logger.info("Test Epoch:{}. Top1: {:.4f}. Top5: {:.4f}. Loss: {:.4f}.".format(epoch, top1_meter.avg, top5_meter.avg, loss_meter.avg))

    return top1_meter.avg, top5_meter.avg, loss_meter.avg


#############################################################################################
# Options
#############################################################################################
parser = argparse.ArgumentParser(description='Semi-supervised Learning')

# --- hardware -------------------------------------------------------------
parser.add_argument('--gpu_ids', default='0', type=str, help='gpu_ids: e.g. 0  0,1,2  0,2')

# --- run directory, logging and checkpoints ---------------------------------
parser.add_argument('--run_dir', type=str, default='run', help='the directory for training')
parser.add_argument('--name', default='RankingMatch_Contrastive', type=str, help='output model name')
parser.add_argument('--log_interval', type=int, default=10, help='log training status')
parser.add_argument('--save_epoch', type=int, default=64, help='step to save model')

parser.add_argument('--resume', default=None, type=str, help='checkpoint to continue training')

# --- model ------------------------------------------------------------------
parser.add_argument('--backbone', type=str, default='wideresnet', help='Wideresnet')
parser.add_argument('--wresnet-k', default=2, type=int, help='width factor of wide resnet')
parser.add_argument('--wresnet-n', default=28, type=int, help='depth of wide resnet')
parser.add_argument('--large_model', action='store_true', help='default is False. If True, using WideResnetLarge model')

# --- data -------------------------------------------------------------------
parser.add_argument('--dataset', type=str, default='CIFAR10', help='CIFAR10, CIFAR100, SVHN, STL10, or TinyImageNet')
parser.add_argument('--num_label', type=int, default=40, help='number of labeled samples for training')
parser.add_argument('--num_val', type=int, default=5000, help='number of samples of cross-validation set')

parser.add_argument('--fold', type=int, default=0, help='used for STL10. This is to pick respective 1000-examples fold')

# --- schedule and batch sizes -----------------------------------------------
parser.add_argument('--start_epoch', type=int, default=0, help='epoch to start training')
parser.add_argument('--epochs', type=int, default=1024, help='number of training epochs')
parser.add_argument('--batchsize', type=int, default=64, help='train batch size of labeled samples')
parser.add_argument('--valbatchsize', default=100, type=int, help='validation batch size')
parser.add_argument('--mu', type=int, default=7, help='factor of train batch size of unlabeled samples')

parser.add_argument('--thr', type=float, default=0.95, help='threshold for picking high-confidence predictions')

# --- optimisation -----------------------------------------------------------
parser.add_argument('--k_imgs', type=int, default=64 * 1024, help='number of training images for each epoch')
parser.add_argument('--lam-u', type=float, default=1., help='weight of cross-entropy loss for unlabeled data')
parser.add_argument('--lam-contrast', type=float, default=1., help='weight of contrastive loss')
parser.add_argument('--ema-alpha', type=float, default=0.999, help='decay rate for ema module')
parser.add_argument('--lr', type=float, default=0.03, help='base learning rate for training')
parser.add_argument('--weight-decay', type=float, default=5e-4, help='weight decay')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum for optimizer')
# The value is handed to the LR scheduler as its `warmup_iter` argument.
parser.add_argument('--warmup', default=0, type=float, help='number of warmup iterations passed to the lr scheduler')

# --- reproducibility --------------------------------------------------------
parser.add_argument('--seed', type=int, default=-1, help='seed for random behaviors; a random seed is drawn if not positive')
parser.add_argument('--seed_init', type=int, default=10000, help='exclusive upper bound used when drawing a random seed')

parser.add_argument('--temperature', type=float, default=0.2, help='temperature for contrastive loss')

# NOTE: the two options below are `store_false` switches. The feature is ON by
# default and passing the flag turns it OFF.
parser.add_argument('--l2_norm', action='store_false', help='disable L2 normalization of the features (enabled by default)')

parser.add_argument('--cudnn_deter', action='store_false', help='disable cudnn deterministic mode (enabled by default)')
parser.add_argument('--cudnn_bench', action='store_true', help='default is False. If True, program may run faster')
#############################################################################################


def main():
    """Parse the command line, prepare the run directory and train the model.

    Steps, in order:
      1. create a time-stamped output directory and copy the source files
         into it (the function returns early if the directory already exists);
      2. seed the random number generators and configure cudnn;
      3. build the model, data loaders, EMA helper, optimizer and scheduler;
      4. optionally restore everything from ``--resume``;
      5. for every epoch: train, validate the plain and the EMA model, write
         TensorBoard scalars and save checkpoints.
    """
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids

    # Create folder for training and copy important files
    if not os.path.exists(args.run_dir):
        os.makedirs(args.run_dir)
    time_str = time.strftime('%Y-%m-%d-%H-%M')
    args.out_dir = os.path.join(args.run_dir, args.name + '_' + args.backbone + '_' + args.dataset + '-{}-{}'.format(args.num_label, args.mu) + '_bz' + str(args.batchsize) + '_epoch' + str(args.epochs) + '_' + time_str)
    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)
    else:
        print('Error: the folder name is dupplicated!')
        return

    if not os.path.exists(args.out_dir + '/src'):
        os.makedirs(args.out_dir + '/src')

    tb_log_dir = os.path.join(args.out_dir, 'tensorboard')

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }
    writer = writer_dict['writer']

    logger = create_logger(args.out_dir, args.name.lower(), time_str)

    # Snapshot the code used for this run. All paths are relative, so this
    # expects the current working directory to be the directory of this script.
    shutil.copy2(
        os.path.basename(__file__),
        args.out_dir)
    for f in os.listdir('.'):
        if f.endswith('.py'):
            shutil.copy2(
                f,
                args.out_dir + '/src')
    folders = ['datasets', 'models']
    for folder in folders:
        if not os.path.exists(args.out_dir + '/src/{}'.format(folder)):
            os.makedirs(args.out_dir + '/src/{}'.format(folder))
        for f in os.listdir(folder):
            if f.endswith('.py'):
                shutil.copy2(
                    '{}/'.format(folder) + f,
                    args.out_dir + '/src/{}'.format(folder))

    # global settings
    # A seed that is not strictly positive means "draw one at random".
    if args.seed > 0:
        logger.info('Seed number is given by user: {}'.format(args.seed))
    else:
        args.seed = random.randrange(args.seed_init)
        logger.info('Seed number is randomly generated: {}'.format(args.seed))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.cudnn_deter
    torch.backends.cudnn.benchmark = args.cudnn_bench
    torch.cuda.manual_seed(args.seed)

    args.n_iters_per_epoch = args.k_imgs // args.batchsize
    args.n_iters_all = args.n_iters_per_epoch * args.epochs

    logger.info("***** Running training *****")
    logger.info(f"  Task = {args.dataset}@{args.num_label}")
    # This line used to print n_iters_per_epoch under the "Num Epochs" label.
    logger.info(f"  Num Epochs = {args.epochs}")
    logger.info(f"  Num iterations per epoch = {args.n_iters_per_epoch}")
    logger.info(f"  Batch size per GPU = {args.batchsize}")
    logger.info(f"  Total optimization steps = {args.n_iters_all}")

    # NOTE: this overwrites any --num_val value given on the command line.
    args.n_classes, args.num_val = get_params(args.dataset)

    if args.dataset == 'STL10':
        # For STL10 the fold index is passed on in place of num_val.
        args.num_val = args.fold

    if args.dataset in ('STL10', 'TinyImageNet'):
        model = set_model(args.n_classes, args.wresnet_k, args.wresnet_n, stl=True)
    else:
        model = set_model(args.n_classes, args.wresnet_k, args.wresnet_n, large=args.large_model)
    logger.info("Total params: {:.2f}M".format(
        sum(p.numel() for p in model.parameters()) / 1e6))
    criterion = RankingMatchLoss()

    dltrain_x, dltrain_u, dlval = get_data(args)

    ema = EMA(model, args.ema_alpha)

    # Parameters whose name contains 'bn' are excluded from weight decay.
    wd_params, non_wd_params = [], []
    for name, param in model.named_parameters():
        if 'bn' in name:
            non_wd_params.append(param)
        else:
            wd_params.append(param)
    param_list = [
        {'params': wd_params}, {'params': non_wd_params, 'weight_decay': 0}]
    optim = torch.optim.SGD(param_list, lr=args.lr, weight_decay=args.weight_decay,
                            momentum=args.momentum, nesterov=True)
    lr_schdlr = WarmupCosineLrScheduler(
        optim, max_iter=args.n_iters_all, warmup_iter=args.warmup
    )

    best_acc_val = -1
    best_epoch_val = 0
    best_acc_ema_val = -1
    best_epoch_ema_val = 0

    if args.resume is not None:
        logger.info('Continue training from checkpoint {}.'.format(args.resume))
        state = torch.load(args.resume)

        # The checkpoint stores the index of the next epoch to run.
        args.start_epoch = state['epoch']

        logger.info('lr_schdlr before loaded. max_iter: {}. last_epoch: {}. lr: {}. base_lr: {}.'.format(lr_schdlr.max_iter, lr_schdlr.last_epoch, lr_schdlr.get_lr(), lr_schdlr.base_lrs))

        lr_log = [pg['lr'] for pg in optim.param_groups]
        lr_log = sum(lr_log) / len(lr_log)
        logger.info('optim before loaded. lr: {}.'.format(lr_log))

        lr_schdlr.load_state_dict(state['scheduler'])
        optim.load_state_dict(state['optimizer'])

        logger.info('lr_schdlr after loaded. max_iter: {}. last_epoch: {}. lr: {}. base_lr: {}.'.format(lr_schdlr.max_iter, lr_schdlr.last_epoch, lr_schdlr.get_lr(), lr_schdlr.base_lrs))

        lr_log = [pg['lr'] for pg in optim.param_groups]
        lr_log = sum(lr_log) / len(lr_log)
        logger.info('optim after loaded. lr: {}.'.format(lr_log))

        model.load_state_dict(state['state_dict'])
        ema.shadow = state['ema_state_dict']

        best_acc_val = state['best_top1_val']
        best_epoch_val = state['best_epoch_val']
        best_acc_ema_val = state['best_top1_ema_val']
        best_epoch_ema_val = state['best_epoch_ema_val']

        logger.info("Val. best_acc_val: {:.4f} at epoch {}. best_acc_ema_val: {:.4f} at epoch {}.".format(best_acc_val, best_epoch_val, best_acc_ema_val, best_epoch_ema_val))

    train_args = dict(
        model=model,
        criterion=criterion,
        optim=optim,
        lr_schdlr=lr_schdlr,
        ema=ema,
        dltrain_x=dltrain_x,
        dltrain_u=dltrain_u,
        lambda_u=args.lam_u,
        lambda_con=args.lam_contrast,
        n_iters=args.n_iters_per_epoch,
        logger=logger,
        log_interval=args.log_interval,
        threshold=args.thr,
        temperature=args.temperature,
        l2_norm=args.l2_norm
    )

    logger.info(pprint.pformat(args))

    # Built once and shared by both validation passes of every epoch.
    val_criterion = nn.CrossEntropyLoss().cuda()

    logger.info('-----------start training--------------')
    try:
        for epoch in range(args.start_epoch, args.epochs):
            train_loss, loss_x, loss_u, loss_conx, loss_conu, mask_mean = train_one_epoch(epoch, **train_args)
            # torch.cuda.empty_cache()

            writer.add_scalars('train_loss', {
                'loss': train_loss,
                'loss_x': loss_x,
                'loss_u': loss_u,
                'loss_conx': loss_conx,
                'loss_conu': loss_conu,
                'mask_mean': mask_mean
            }, epoch)

            logger.info('start validation.')
            logger.info('testing on val_loader using normal model...')
            top1_val, top5_val, valid_loss_val = evaluate(epoch, model, dlval, val_criterion, logger)

            logger.info('testing on val_loader using ema model...')
            top1_ema_val, top5_ema_val, valid_loss_ema_val = evaluate_ema(epoch, ema, dlval, val_criterion, logger)

            writer.add_scalars('val_acc', {
                'top1_val': top1_val,
                'top5_val': top5_val,
                'top1_ema_val': top1_ema_val,
                'top5_ema_val': top5_ema_val
            }, epoch)
            writer.add_scalars('val_loss', {
                'loss_val': valid_loss_val,
                'loss_ema_val': valid_loss_ema_val
            }, epoch)

            is_best_val = best_acc_val < top1_val
            if is_best_val:
                best_acc_val = top1_val
                best_epoch_val = epoch

            is_best_ema_val = best_acc_ema_val < top1_ema_val
            if is_best_ema_val:
                best_acc_ema_val = top1_ema_val
                best_epoch_ema_val = epoch

            # The rolling checkpoint is overwritten every epoch; 'epoch' holds
            # the index of the next epoch so that --resume continues from there.
            logger.info('saving checkpoint...')
            torch.save({
                'epoch': epoch+1,
                'state_dict': model.state_dict(),
                'ema_state_dict': ema.shadow,   # not ema.model.state_dict()
                'top1_val': top1_val,
                'best_top1_val': best_acc_val,
                'best_epoch_val': best_epoch_val,
                'top1_ema_val': top1_ema_val,
                'best_top1_ema_val': best_acc_ema_val,
                'best_epoch_ema_val': best_epoch_ema_val,
                'optimizer': optim.state_dict(),
                'scheduler': lr_schdlr.state_dict(),
            },
            os.path.join(args.out_dir, args.name + '_checkpoint'))

            if is_best_val:
                logger.info('saving best normal model on val_loader...')
                torch.save(model.state_dict(), os.path.join(args.out_dir, args.name + '_bestval'))

            if is_best_ema_val:
                logger.info('saving best ema model on val_loader...')
                torch.save(ema.shadow, os.path.join(args.out_dir, args.name + '_ema_bestval')) # not ema.model.state_dict()

            if (epoch + 1) % args.save_epoch == 0:
                logger.info('saving models at epoch {}...'.format(epoch))
                torch.save(model.state_dict(), os.path.join(args.out_dir, args.name + '_e{}'.format(epoch)))
                torch.save(ema.shadow, os.path.join(args.out_dir, args.name + '_ema_e{}'.format(epoch)))    # not ema.model.state_dict()

            logger.info("Val Epoch {}. best_acc_val: {:.4f} at epoch {}. best_acc_ema_val: {:.4f} at epoch {}.".format(epoch, best_acc_val, best_epoch_val, best_acc_ema_val, best_epoch_ema_val))
    finally:
        # Close the TensorBoard writer even when training is interrupted or
        # fails, so scalars already written are not lost.
        writer.close()


# Start training only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
