from __future__ import print_function

import argparse
import os
import shutil
import time
import random

import numpy as np

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torch.nn.functional as F

import models.wideresnet as models1
import models.wideresnet2 as models2
import dataset.cifar10 as cifar10
import dataset.imagenet as imagenet
import dataset.mnist as mnist
import dataset.svhn as svhn
from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig
from tensorboardX import SummaryWriter

# Command-line interface. Parsed once at import time; the resulting ``args``
# namespace is read as a module-level global by the functions below.
parser = argparse.ArgumentParser(description='PyTorch MixMatch Training')
# Optimization options
parser.add_argument('--epochs', default=128, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--batch-size', default=64, type=int, metavar='N',
                    help='train batchsize')
parser.add_argument('--lr', '--learning-rate', default=0.002, type=float,
                    metavar='LR', help='initial learning rate')
# Checkpoints
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
# Miscs
parser.add_argument('--manualSeed', type=int, default=0, help='manual seed')
# Device options
parser.add_argument('--gpu', default='0', type=str,
                    help='id(s) for CUDA_VISIBLE_DEVICES')
# Method options
parser.add_argument('--n-labeled', type=int, default=250,
                    help='Number of labeled data')
parser.add_argument('--train-iteration', type=int, default=1024,
                    help='Number of iteration per epoch')
parser.add_argument('--out', default='result',
                    help='Directory to output the result')
parser.add_argument('--alpha', default=0.75, type=float)
parser.add_argument('--lambda-u', default=75, type=float)
parser.add_argument('--T', default=0.5, type=float)
parser.add_argument('--ema-decay', default=0.999, type=float)
parser.add_argument('--dataset', default='mnist', type=str)

args = parser.parse_args()
# Snapshot of the parsed options as a plain dict (used for the LR printout).
state = dict(args._get_kwargs())

# Use CUDA
# CUDA_VISIBLE_DEVICES has to be set before the first CUDA query below.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
use_cuda = torch.cuda.is_available()

# Random seed
# With the current default of 0 this branch is never taken; it only matters
# if the option's default is changed to None.
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
# Seed every RNG the script draws from. NumPy drives the mixup coefficient,
# while PyTorch drives weight initialisation, DataLoader shuffling and
# torch.randperm, so seeding NumPy alone does not make runs repeatable.
random.seed(args.manualSeed)
np.random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)

best_acc = 0  # best validation accuracy seen so far (updated by main)


def _prepare_datasets(name):
    """Build the transforms for dataset *name* and load its four splits.

    Returns:
        The tuple ``(train_labeled_set, train_unlabeled_set, val_set,
        test_set)`` produced by the dataset module's loader function.

    Raises:
        ValueError: if *name* is not one of the supported datasets.
    """
    if name == 'cifar10':
        print('==> Preparing cifar10')
        transform_train = transforms.Compose([
            cifar10.RandomPadandCrop(32),
            cifar10.RandomFlip(),
            cifar10.ToTensor(),
        ])
        transform_val = transforms.Compose([
            cifar10.ToTensor(),
        ])
        return cifar10.get_cifar10(
            './data', args.n_labeled, transform_train=transform_train,
            transform_val=transform_val)

    if name == 'imagenet':
        print('==> Preparing imagenet to attack cifar10')
        transform_train = transforms.Compose([
            transforms.Resize(32),
            imagenet.RandomPadandCrop(32),
            imagenet.RandomFlip(),
            imagenet.ToTensor(),
        ])
        transform_val = transforms.Compose([
            imagenet.ToTensor(),
        ])
        splits = imagenet.get_cifar10(
            '/scratch/ssd002/datasets/imagenet256/', args.n_labeled,
            transform_train=transform_train, transform_val=transform_val)
        train_labeled_set, train_unlabeled_set, val_set, test_set = splits
        print("labeled", args.n_labeled)
        print(len(train_labeled_set))
        print(len(train_unlabeled_set))
        # Per the original note, the test split here is the CIFAR test set.
        return train_labeled_set, train_unlabeled_set, val_set, test_set

    if name == 'mnist':
        print('==> Preparing mnist')
        transform_train = transforms.Compose([
            mnist.RandomPadandCrop(28),
            mnist.RandomFlip(),
            mnist.ToTensor(),
        ])
        transform_val = transforms.Compose([
            mnist.ToTensor(),
        ])
        return mnist.get_mnist(
            './data', args.n_labeled,
            transform_train=transform_train,
            transform_val=transform_val)

    if name == 'svhn':
        print('==> Preparing svhn')
        transform_train = transforms.Compose([
            svhn.RandomPadandCrop(32),
            svhn.RandomFlip(),
            svhn.ToTensor(),
        ])
        transform_val = transforms.Compose([
            svhn.ToTensor(),
        ])
        return svhn.get_svhn(
            './data', args.n_labeled, transform_train=transform_train,
            transform_val=transform_val)

    raise ValueError(
        "Unsupported --dataset %r; expected one of 'cifar10', 'imagenet', "
        "'mnist', 'svhn'" % (name,))


def _build_loaders(train_labeled_set, train_unlabeled_set, val_set, test_set):
    """Wrap the four splits in DataLoaders with the script's fixed settings.

    Training loaders shuffle and drop the last incomplete batch; the
    validation and test loaders keep dataset order and every sample.
    """
    labeled_trainloader = data.DataLoader(train_labeled_set,
                                          batch_size=args.batch_size,
                                          shuffle=True, num_workers=0,
                                          drop_last=True)
    unlabeled_trainloader = data.DataLoader(train_unlabeled_set,
                                            batch_size=args.batch_size,
                                            shuffle=True, num_workers=0,
                                            drop_last=True)
    val_loader = data.DataLoader(val_set, batch_size=args.batch_size,
                                 shuffle=False, num_workers=0)
    test_loader = data.DataLoader(test_set, batch_size=args.batch_size,
                                  shuffle=False, num_workers=0)
    return labeled_trainloader, unlabeled_trainloader, val_loader, test_loader


def _create_model(dataset, ema=False):
    """Instantiate the 10-class WideResNet used for *dataset*.

    MNIST uses the architecture from ``models.wideresnet2``; every other
    dataset uses ``models.wideresnet``. With ``ema=True`` the parameters are
    detached so the optimizer/autograd never updates them directly.
    """
    architecture = models2 if dataset == 'mnist' else models1
    model = architecture.WideResNet(num_classes=10)
    # Respect the same flag train()/validate() use, so the script also runs
    # on machines without a GPU instead of failing inside .cuda().
    if use_cuda:
        model = model.cuda()

    if ema:
        for param in model.parameters():
            param.detach_()

    return model


def main():
    """Train a MixMatch model end to end.

    Loads the dataset selected by ``args.dataset``, builds the model and its
    EMA copy, optionally resumes from ``args.resume``, then for each epoch
    trains, evaluates the EMA model on the labeled-train / validation / test
    loaders, logs the metrics (text log and TensorBoard) and writes a
    checkpoint. Updates the module-level ``best_acc``.

    Raises:
        ValueError: if ``args.dataset`` is not supported.
        FileNotFoundError: if ``args.resume`` is set but is not a file.
    """
    global best_acc

    if not os.path.isdir(args.out):
        mkdir_p(args.out)

    # Data
    (labeled_trainloader, unlabeled_trainloader,
     val_loader, test_loader) = _build_loaders(
        *_prepare_datasets(args.dataset))

    # Model
    if args.dataset == 'mnist':
        print("==> creating WRN for MNIST")
    else:
        print("==> creating WRN-28-2")

    model = _create_model(args.dataset)
    ema_model = _create_model(args.dataset, ema=True)

    cudnn.benchmark = True
    print('    Total params: %.2fM' % (
                sum(p.numel() for p in model.parameters()) / 1000000.0))

    train_criterion = SemiLoss()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    ema_optimizer = WeightEMA(model, ema_model, alpha=args.ema_decay)
    # Honour --start-epoch (default 0); a resumed checkpoint overrides it.
    start_epoch = args.start_epoch

    # Resume
    titles = {
        'cifar10': 'noisy-cifar-10',
        'imagenet': 'noisy-cifar-10',
        'mnist': 'noisy-mnist',
        'svhn': 'svhn',
    }
    title = titles[args.dataset]
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        # Explicit raise rather than assert: asserts vanish under ``-O``.
        if not os.path.isfile(args.resume):
            raise FileNotFoundError('Error: no checkpoint directory found!')
        args.out = os.path.dirname(args.resume)
        # Allow a GPU-written checkpoint to be loaded on a CPU-only machine.
        checkpoint = torch.load(args.resume,
                                map_location=None if use_cuda else 'cpu')
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        ema_model.load_state_dict(checkpoint['ema_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.out, 'log.txt'), title=title,
                        resume=True)
    else:
        logger = Logger(os.path.join(args.out, 'log.txt'), title=title)
        logger.set_names(
            ['Train Loss', 'Train Loss X', 'Train Loss U', 'Valid Loss',
             'Valid Acc.', 'Test Loss', 'Test Acc.'])

    writer = SummaryWriter(args.out)
    test_accs = []
    # Train and val
    for epoch in range(start_epoch, args.epochs):
        print(
            '\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr']))

        train_loss, train_loss_x, train_loss_u = train(labeled_trainloader,
                                                       unlabeled_trainloader,
                                                       model, optimizer,
                                                       ema_optimizer,
                                                       train_criterion, epoch,
                                                       use_cuda)
        # All evaluation is done with the EMA weights, not the raw model.
        _, train_acc = validate(labeled_trainloader, ema_model, criterion,
                                epoch, use_cuda, mode='Train Stats')
        val_loss, val_acc = validate(val_loader, ema_model, criterion, epoch,
                                     use_cuda, mode='Valid Stats')
        test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch,
                                       use_cuda, mode='Test Stats ')

        # TensorBoard x-axis: number of optimisation steps completed so far.
        step = args.train_iteration * (epoch + 1)

        writer.add_scalar('losses/train_loss', train_loss, step)
        writer.add_scalar('losses/valid_loss', val_loss, step)
        writer.add_scalar('losses/test_loss', test_loss, step)

        writer.add_scalar('accuracy/train_acc', train_acc, step)
        writer.add_scalar('accuracy/val_acc', val_acc, step)
        writer.add_scalar('accuracy/test_acc', test_acc, step)

        # append logger file
        logger.append(
            [train_loss, train_loss_x, train_loss_u, val_loss, val_acc,
             test_loss, test_acc])

        # save model
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        # Pass the directory explicitly: on resume args.out was reassigned
        # above, and the checkpoint must land next to the resumed log.
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'ema_state_dict': ema_model.state_dict(),
            'acc': val_acc,
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
        }, is_best, checkpoint=args.out)
        # NOTE(review): relies on validate() returning a tensor accuracy
        # (it stores accuracy(...)[0] unconverted) - confirm in utils.
        test_accs.append(test_acc.cpu())
    logger.close()
    writer.close()

    print('Best acc:')
    print(best_acc)

    # Mean test accuracy over the last (up to) 20 epochs.
    print('Mean acc:')
    print(np.mean(test_accs[-20:]))


def _next_batch(batch_iter, loader):
    """Return ``(batch, iterator)``, restarting *loader* when it is exhausted.

    Only StopIteration triggers the restart, so genuine data-loading errors
    and KeyboardInterrupt propagate instead of being silently retried.
    """
    try:
        batch = next(batch_iter)
    except StopIteration:
        batch_iter = iter(loader)
        batch = next(batch_iter)
    return batch, batch_iter


def train(labeled_trainloader, unlabeled_trainloader, model, optimizer,
          ema_optimizer, criterion, epoch, use_cuda):
    """Run ``args.train_iteration`` MixMatch optimisation steps.

    Each step draws one labeled batch and one unlabeled batch (the loaders
    are cycled when exhausted), guesses sharpened labels for the unlabeled
    data, mixes everything with mixup, and minimises ``Lx + w * Lu`` as
    returned by *criterion*. After every optimizer step the EMA weights are
    updated through *ema_optimizer*.

    Args:
        labeled_trainloader: loader yielding ``(inputs, targets)`` batches.
        unlabeled_trainloader: loader yielding ``((view1, view2), _)``
            batches, i.e. two inputs per unlabeled sample.
        model: network being trained.
        optimizer: optimizer stepping *model*.
        ema_optimizer: object whose ``step()`` updates the EMA weights.
        criterion: callable returning ``(Lx, Lu, w)``.
        epoch: zero-based epoch index, used for the loss-weight ramp-up.
        use_cuda: move batches to the GPU when true.

    Returns:
        Tuple ``(loss_avg, loss_x_avg, loss_u_avg)`` of running averages.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    ws = AverageMeter()
    end = time.time()

    bar = Bar('Training', max=args.train_iteration)
    labeled_train_iter = iter(labeled_trainloader)
    unlabeled_train_iter = iter(unlabeled_trainloader)

    model.train()
    for batch_idx in range(args.train_iteration):
        (inputs_x, targets_x), labeled_train_iter = _next_batch(
            labeled_train_iter, labeled_trainloader)
        ((inputs_u, inputs_u2), _), unlabeled_train_iter = _next_batch(
            unlabeled_train_iter, unlabeled_trainloader)

        # measure data loading time
        data_time.update(time.time() - end)

        batch_size = inputs_x.size(0)

        # Transform label to one-hot (10 classes are hard-coded here).
        targets_x = torch.zeros(batch_size, 10).scatter_(
            1, targets_x.view(-1, 1).long(), 1)

        if use_cuda:
            inputs_x, targets_x = inputs_x.cuda(), targets_x.cuda(
                non_blocking=True)
            inputs_u = inputs_u.cuda()
            inputs_u2 = inputs_u2.cuda()

        with torch.no_grad():
            # compute guessed labels of unlabeled samples: average the
            # predictions on the two views, then sharpen with temperature T.
            outputs_u = model(inputs_u)
            outputs_u2 = model(inputs_u2)
            p = (torch.softmax(outputs_u, dim=1)
                 + torch.softmax(outputs_u2, dim=1)) / 2
            pt = p ** (1 / args.T)
            targets_u = pt / pt.sum(dim=1, keepdim=True)
            targets_u = targets_u.detach()

        # mixup
        all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2], dim=0)
        all_targets = torch.cat([targets_x, targets_u, targets_u], dim=0)

        lam = np.random.beta(args.alpha, args.alpha)
        # Keep the mixed sample closer to its own (un-permuted) source.
        lam = max(lam, 1 - lam)

        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        mixed_input = lam * input_a + (1 - lam) * input_b
        mixed_target = lam * target_a + (1 - lam) * target_b

        # interleave labeled and unlabeled samples between batches to get
        # correct batchnorm calculation
        mixed_input = list(torch.split(mixed_input, batch_size))
        mixed_input = interleave(mixed_input, batch_size)

        logits = [model(mixed_input[0])]
        for chunk in mixed_input[1:]:
            logits.append(model(chunk))

        # put interleaved samples back
        logits = interleave(logits, batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        # The last argument is the fractional epoch, used for the ramp-up.
        Lx, Lu, w = criterion(logits_x, mixed_target[:batch_size], logits_u,
                              mixed_target[batch_size:],
                              epoch + batch_idx / args.train_iteration)

        loss = Lx + w * Lu

        # record loss
        losses.update(loss.item(), inputs_x.size(0))
        losses_x.update(Lx.item(), inputs_x.size(0))
        losses_u.update(Lu.item(), inputs_x.size(0))
        ws.update(w, inputs_x.size(0))

        # compute gradient and do the optimizer step, then refresh the EMA
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema_optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Loss_x: {loss_x:.4f} | Loss_u: {loss_u:.4f} | W: {w:.4f}'.format(
            batch=batch_idx + 1,
            size=args.train_iteration,
            data=data_time.avg,
            bt=batch_time.avg,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=losses.avg,
            loss_x=losses_x.avg,
            loss_u=losses_u.avg,
            w=ws.avg,
        )
        bar.next()
    bar.finish()

    return (losses.avg, losses_x.avg, losses_u.avg,)


def validate(valloader, model, criterion, epoch, use_cuda, mode):
    """Evaluate *model* on every batch of *valloader* without gradients.

    Args:
        valloader: loader yielding ``(inputs, targets)`` batches.
        model: network to evaluate; it is switched to eval mode and left so.
        criterion: loss called as ``criterion(outputs, targets)``.
        epoch: accepted for signature symmetry with ``train``; not used.
        use_cuda: move batches to the GPU when true.
        mode: label shown on the progress bar.

    Returns:
        Tuple ``(loss_avg, top1_avg)``. ``top1_avg`` is whatever the top-1
        meter accumulates from ``accuracy(...)[0]``, stored unconverted.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    bar = Bar(f'{mode}', max=len(valloader))
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(valloader):
            # measure data loading time
            data_time.update(time.time() - end)

            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss (top-1 only)
            prec1 = accuracy(outputs, targets, topk=(1,))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1[0], inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f}'.format(
                batch=batch_idx + 1,
                size=len(valloader),
                data=data_time.avg,
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss=losses.avg,
                top1=top1.avg,
            )
            bar.next()
        bar.finish()
    return (losses.avg, top1.avg)


def save_checkpoint(state, is_best, checkpoint=None,
                    filename='checkpoint.pth.tar'):
    """Serialize *state* to ``<checkpoint>/<filename>`` with ``torch.save``.

    Args:
        state: object to serialize.
        is_best: when true, also copy the file to ``model_best.pth.tar`` in
            the same directory.
        checkpoint: target directory. ``None`` (the default) means the
            current value of ``args.out``, looked up at call time so that a
            later reassignment of ``args.out`` (e.g. on resume) is honoured.
        filename: file name inside *checkpoint*.
    """
    if checkpoint is None:
        checkpoint = args.out
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath,
                        os.path.join(checkpoint, 'model_best.pth.tar'))


def linear_rampup(current, rampup_length=args.epochs):
    """Return the ramp-up fraction ``current / rampup_length`` clipped to [0, 1].

    A ``rampup_length`` of zero disables the ramp and yields 1.0.
    """
    if rampup_length != 0:
        fraction = np.clip(current / rampup_length, 0.0, 1.0)
        return float(fraction)
    return 1.0


class SemiLoss(object):
    """Callable computing the supervised loss, unsupervised loss and its weight."""

    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch):
        """Return ``(Lx, Lu, w)`` for one mixed batch.

        ``Lx`` is the mean soft-target cross-entropy of ``outputs_x`` against
        ``targets_x``; ``Lu`` is the mean squared difference between
        softmax(``outputs_u``) and ``targets_u``; ``w`` is ``args.lambda_u``
        scaled by ``linear_rampup(epoch)``.
        """
        # Supervised term: cross-entropy with soft targets.
        log_probs_x = F.log_softmax(outputs_x, dim=1)
        per_sample_ce = torch.sum(log_probs_x * targets_x, dim=1)
        Lx = -torch.mean(per_sample_ce)

        # Unsupervised term: mean squared error on class probabilities.
        probs_u = torch.softmax(outputs_u, dim=1)
        squared_error = (probs_u - targets_u) ** 2
        Lu = torch.mean(squared_error)

        weight_u = args.lambda_u * linear_rampup(epoch)
        return Lx, Lu, weight_u


class WeightEMA(object):
    """Maintains an exponential moving average of a model's state tensors.

    Both models are tracked through the values of their ``state_dict()``.
    ``step()`` also shrinks the live model's float32 tensors by a fixed
    factor (a hand-rolled weight decay of ``0.02 * args.lr``).
    """

    def __init__(self, model, ema_model, alpha=0.999):
        """Store both models and make the live model start from the EMA values."""
        self.model = model
        self.ema_model = ema_model
        self.alpha = alpha
        self.wd = 0.02 * args.lr
        self.params = [tensor for tensor in model.state_dict().values()]
        self.ema_params = [tensor for tensor in ema_model.state_dict().values()]

        # Synchronise the two models: live tensors take the EMA model's values.
        for live, shadow in zip(self.params, self.ema_params):
            live.data.copy_(shadow.data)

    def step(self):
        """Blend the live tensors into the EMA ones, then decay the live ones."""
        decay = self.alpha
        blend = 1.0 - decay
        for live, shadow in zip(self.params, self.ema_params):
            # Only float32 tensors are averaged; other dtypes are left alone.
            if shadow.dtype != torch.float32:
                continue
            shadow.mul_(decay)
            shadow.add_(live * blend)
            # customized weight decay
            live.mul_(1 - self.wd)


def interleave_offsets(batch, nu):
    """Split ``batch`` into ``nu + 1`` near-equal groups and return their bounds.

    The result is a list of ``nu + 2`` cumulative offsets starting at 0 and
    ending at ``batch``; any remainder is given, one item each, to the
    trailing groups.
    """
    parts = nu + 1
    base, extra = divmod(batch, parts)
    first_bigger = parts - extra  # index of the first group that gets +1
    offsets = [0]
    for position in range(parts):
        size = base + 1 if position >= first_bigger else base
        offsets.append(offsets[-1] + size)
    assert offsets[-1] == batch
    return offsets


def interleave(xy, batch):
    """Swap slice ``k`` of the first tensor with slice ``k`` of tensor ``k``.

    Every tensor in ``xy`` is cut along dim 0 at the bounds given by
    ``interleave_offsets(batch, len(xy) - 1)``; for each ``k >= 1`` the
    ``k``-th pieces of ``xy[0]`` and ``xy[k]`` trade places, and the pieces
    are concatenated back into one tensor per entry.
    """
    nu = len(xy) - 1
    bounds = interleave_offsets(batch, nu)
    pieces = []
    for tensor in xy:
        pieces.append(
            [tensor[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])])
    for k in range(1, nu + 1):
        from_first = pieces[0][k]
        pieces[0][k] = pieces[k][k]
        pieces[k][k] = from_first
    return [torch.cat(row, dim=0) for row in pieces]


# Script entry point: start training only when this file is run directly.
if __name__ == '__main__':
    main()
