#!/usr/bin/env python3 -u

from __future__ import print_function

import argparse
import csv
import os, logging
import copy
import random
from collections import OrderedDict

import numpy as np
import torch
from torch.autograd import Variable, grad
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms

import models
from utils import progress_bar, set_logging_defaults
from datasets import load_dataset

# TensorBoard logging is optional. The flag starts out True and is cleared when
# tensorboardX cannot be imported; every use of `writer` in this file is
# guarded by it, so training proceeds without TensorBoard in that case.
tensorboardX_compat = True
try:
    from tensorboardX import SummaryWriter
except ImportError:
    print ('No tensorboardX package is found. Start training without tensorboardX')
    tensorboardX_compat = False

# Command-line interface. Flag names, types and defaults are part of the
# script's public contract; only inaccurate help texts were corrected.
parser = argparse.ArgumentParser(description='ReMixMatch + OpenCoS Training')

# --- optimisation / run identity ---
parser.add_argument('--lr', default=0.002, type=float, help='learning rate')
parser.add_argument('--model', default="wide_resnet", type=str,
                    help='model type (default: wide_resnet)')
parser.add_argument('--name', default='0', type=str, help='name of run')
parser.add_argument('--batch-size', default=64, type=int, help='batch size')
parser.add_argument('--num_iters', default=50000, type=int, help='total number of training iterations')
parser.add_argument('--decay', default=0, type=float, help='weight decay')
parser.add_argument('--ngpu', default=1, type=int, help='number of gpu')
parser.add_argument('--sgpu', default=0, type=int, help='gpu index (start)')

# --- data locations ---
parser.add_argument('--dataset', default='cifar10', type=str, help='the name for dataset')
parser.add_argument('--udata', default='svhn', type=str, help='type of unlabel data')
parser.add_argument('--tinyroot', default='/data/tinyimagenet/tiny-imagenet-200/', type=str, help='TinyImageNet directory')
parser.add_argument('--imgroot', default='/data/ILSVRC/Data/CLS-LOC/', type=str, help='unlabel data directory')
parser.add_argument('--dataroot', default='/data/', type=str, help='data directory')
parser.add_argument('--saveroot', default='./results', type=str, help='root directory for logs and checkpoints')

# --- semi-supervised objective ---
parser.add_argument('--temp', default=0.5, type=float, help='temperature for sharpening pseudo-labels')
parser.add_argument('--finetune', '-ft', action='store_true', help='finetuning')
parser.add_argument('--pc', default=25, type=int, help='number of samples per class')
parser.add_argument('--ema', action='store_true', help='EMA training')
parser.add_argument('--no_alignment', action='store_true', help='no distribution alignment')
parser.add_argument('--naug', default=8, type=int, help='number of augmentations (passed to load_dataset as naug)')
parser.add_argument('--lmd_u', default=1., type=float, help='Lu loss weight')
parser.add_argument('--lmd_pre', default=1., type=float, help='weight of the pre-mixup unlabeled cross-entropy loss')
parser.add_argument('--lmd_rot', default=0.5, type=float, help='rotation loss weight')
parser.add_argument('--nworkers', default=4, type=int, help='num_workers')
parser.add_argument('--method', default='remixmatch', type=str, help='training method name (passed to load_dataset)')

# --- out-of-class sample detection ---
parser.add_argument('--sim_path', default=None, type=str, help='saved similarity')
parser.add_argument('--model_path', default=None, type=str, help='(unsupervised) pretrained model path')
parser.add_argument('--ood_samples', default=0, type=int, help='number of ood samples in [0,10000,20000,30000,40000]')
parser.add_argument('--fix_optim', action='store_true', help='using optimizer of FixMatch')

# --- auxiliary loss on out-of-class samples and misc switches ---
parser.add_argument('--lmd_unif', default=1., type=float, help='smoothing loss weight')
parser.add_argument('--aux_divide', action='store_true', help='divide bn parameters')
parser.add_argument('--ths', default=1., type=float, help='parameter for threshold')
parser.add_argument('--temp_s2', default=1, type=float, help='temperature for the similarity-based soft labels')
parser.add_argument('--total_unlabel', action='store_true', help='using total unlabel data')
parser.add_argument('--stop_iters', default=None, type=int, help='early stopping')
parser.add_argument('--use_jitter', action='store_true', help='using jitter augmentation for unlabeled data')
parser.add_argument('--simclr_optim', action='store_true', help='using optimizer of SimCLR semi finetune')
parser.add_argument('--no_head', action='store_true', help='not using the mlp head of simclr')
args = parser.parse_args()
use_cuda = torch.cuda.is_available()

# Module-level training state. val() rebinds the accuracy values through
# `global` statements; start_iters is the first index of the training loop.
best_val = 0  # best validation accuracy of the trained model
best_val_ema = 0  # best validation accuracy of the EMA model
start_iters = 0  # first iteration index used by ema_train()
current_val = 0  # most recent validation accuracy of the trained model
current_val_ema = 0  # most recent validation accuracy of the EMA model

cudnn.benchmark = True

# Data: labeled train split, unlabeled train set, and labeled test split.
_labeled_trainset, _unlabeled_trainset, _labeled_testset = load_dataset(args.dataset, args.dataroot, batch_size=args.batch_size, pc=str(args.pc), method=args.method, naug=args.naug, uroot=args.udata, tinyroot=args.tinyroot, imgroot=args.imgroot, ood_samples=args.ood_samples, use_jitter=args.use_jitter)
_labeled_num_class = _labeled_trainset.num_classes
print('Numclass: ', _labeled_num_class)
print('==> Preparing dataset: {}'.format(args.dataset))
print('Number of label dataset: ' ,len(_labeled_trainset))
print('Number of unlabel dataset: ',len(_unlabeled_trainset))
print('Number of test dataset: ',len(_labeled_testset))

# All run artefacts (logs, similarity tensors, checkpoints) are written below
# <saveroot>/<dataset>/<model>/<name>.
logdir = os.path.join(args.saveroot, args.dataset, args.model, args.name)
set_logging_defaults(logdir, args)
logger = logging.getLogger('main')
logname = os.path.join(logdir, 'log.csv')
if tensorboardX_compat:
    writer = SummaryWriter(logdir=logdir)


if use_cuda:
    # Make the first requested GPU the current device so bare .cuda() calls land on it.
    torch.cuda.set_device(args.sgpu)
    print(torch.cuda.device_count())
    print('Using CUDA..')

# Hard-label cross entropy; used for the rotation loss and the validation loss.
criterion = nn.CrossEntropyLoss()

def get_embed(net, head, dataloader):
    """Embed every sample of ``dataloader`` with ``net.feature`` (and ``head``).

    Args:
        net: model exposing a ``feature(inputs)`` method. It may be wrapped in
            ``torch.nn.DataParallel``; the wrapped module is then used
            directly, because the wrapper does not expose ``feature``.
        head: optional callable applied to the features; ``None`` skips it.
        dataloader: iterable of ``(inputs, targets)`` batches. Targets are
            ignored.

    Returns:
        A CPU tensor holding the per-batch outputs concatenated along dim 0.
    """
    net.eval()

    # DataParallel only forwards __call__, so a custom method such as
    # `feature` has to be reached through the wrapped module.
    backbone = net.module if isinstance(net, torch.nn.DataParallel) else net

    # Run on whatever device the model lives on instead of assuming CUDA.
    first_param = next(backbone.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    emb = []
    with torch.no_grad():
        for inputs, _ in dataloader:
            outputs = backbone.feature(inputs.to(device))
            if head is not None:
                outputs = head(outputs)
            emb.append(outputs.cpu())

    return torch.cat(emb, dim=0)

def get_embed_center(net, head, dataloader):
    """Embed a labeled loader and compute the mean embedding of each class.

    Args:
        net: model exposing a ``feature(inputs)`` method. It may be wrapped in
            ``torch.nn.DataParallel``; the wrapped module is then used
            directly, because the wrapper does not expose ``feature``.
        head: optional callable applied to the features; ``None`` skips it.
        dataloader: iterable of ``(inputs, targets)`` batches whose targets
            are integer class indices below ``_labeled_num_class``.

    Returns:
        ``(emb, centers)``: ``emb`` is the CPU tensor of all outputs
        concatenated along dim 0; ``centers`` has one row per class holding
        the mean output of the samples carrying that label.

    Raises:
        ValueError: if the loader yields no batch, or if some class has no
            sample (its center would be undefined).
    """
    net.eval()

    # DataParallel only forwards __call__, so a custom method such as
    # `feature` has to be reached through the wrapped module.
    backbone = net.module if isinstance(net, torch.nn.DataParallel) else net

    # Run on whatever device the model lives on instead of assuming CUDA.
    first_param = next(backbone.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    emb = []
    class_sums = None
    class_counts = torch.zeros(_labeled_num_class, dtype=torch.long)
    with torch.no_grad():
        for inputs, targets in dataloader:
            outputs = backbone.feature(inputs.to(device))
            if head is not None:
                outputs = head(outputs)
            outputs = outputs.cpu()
            emb.append(outputs)

            # Accumulate per-class sums and counts in one shot per batch
            # rather than one Python-level step per sample.
            labels = targets.cpu().long().view(-1)
            if class_sums is None:
                class_sums = torch.zeros(
                    (_labeled_num_class,) + tuple(outputs.shape[1:]),
                    dtype=outputs.dtype)
            class_sums.index_add_(0, labels, outputs)
            class_counts += torch.bincount(labels, minlength=_labeled_num_class)

    if class_sums is None:
        raise ValueError('dataloader yielded no batches; cannot compute class centers')
    missing = (class_counts == 0).nonzero().view(-1).tolist()
    if missing:
        raise ValueError('no labeled samples for class indices: {}'.format(missing))

    # Broadcast the counts over every non-class dimension of the sums.
    count_shape = (-1,) + (1,) * (class_sums.dim() - 1)
    centers = class_sums / class_counts.view(count_shape).to(class_sums.dtype)
    emb = torch.cat(emb, dim=0)

    return emb, centers

def cosine_similarity(x1, x2, eps=1e-12):
    """Return the matrix of cosine similarities between rows of two 2-D tensors.

    Entry ``[i, j]`` is ``<x1[i], x2[j]> / max(|x1[i]| * |x2[j]|, eps)``, so
    all-zero rows produce a similarity of 0 instead of a division by zero.
    """
    row_norms_1 = x1.norm(p=2, dim=1, keepdim=True)
    row_norms_2 = x2.norm(p=2, dim=1, keepdim=True)
    dot_products = torch.mm(x1, x2.t())
    # Outer product of the row norms, floored at eps.
    denominators = (row_norms_1 * row_norms_2.t()).clamp(min=eps)
    return dot_products / denominators

class MergeDataset(torch.utils.data.Dataset):
    """Index-aligned pairing of two equally long datasets.

    Item ``i`` is the first two elements of ``dataset1[i]`` followed by all
    elements of ``dataset2[i]`` (sequence concatenation).
    """

    def __init__(self, dataset1, dataset2):
        size_1, size_2 = len(dataset1), len(dataset2)
        assert size_1 == size_2
        self.dataset1, self.dataset2 = dataset1, dataset2

    def __len__(self):
        return len(self.dataset1)

    def __getitem__(self, i):
        leading = self.dataset1[i][:2]
        trailing = self.dataset2[i]
        return leading + trailing

class Projector(nn.Module):
    """Two-layer MLP head: Linear -> ReLU -> Linear, all of width ``dimsize``."""

    def __init__(self, dimsize=2048):
        super().__init__()

        # Attribute names are the state-dict keys expected by saved projector weights.
        self.linear_1 = nn.Linear(dimsize, dimsize)
        self.linear_2 = nn.Linear(dimsize, dimsize)

    def forward(self, x):
        hidden = F.relu(self.linear_1(x))
        return self.linear_2(hidden)

def load_dataset_softlabel(_labeled_num_class, _labeled_trainset, _unlabeled_trainset, sim_path=None):
    """Return cosine similarities of samples to the labeled class centers.

    When ``sim_path`` is given, ``similarity_l.pt`` and ``similarity_u.pt``
    are loaded from that directory and returned without building any model.
    Otherwise a pretrained network (``args.model_path``) and, unless
    ``args.no_head`` is set, its projector head (``args.model_path +
    '_projector'``) embed both sets; the similarities are computed, written
    to ``logdir`` and returned.

    Args:
        _labeled_num_class: number of labeled classes.
        _labeled_trainset: labeled set exposing ``base_dataset`` and ``indices``.
        _unlabeled_trainset: unlabeled set exposing ``datasets``.
        sim_path: optional directory holding previously saved similarities.

    Returns:
        ``(sim_l, sim_u)``: similarities of the labeled and unlabeled samples
        to each class center.

    Raises:
        ValueError: if ``args.model`` is not one of the supported families.
    """
    print('Identifying out-of-class samples...')

    # Cached similarities need neither the pretrained model nor the data loaders.
    if sim_path is not None:
        sim_l = torch.load(os.path.join(sim_path, 'similarity_l.pt'))
        sim_u = torch.load(os.path.join(sim_path, 'similarity_u.pt'))
        return sim_l, sim_u

    # Per-family backbone, evaluation transform, classifier keys and feature width.
    if args.model in ['resnet50', 'resnet50_auxbn']:
        net_t = models.load_model("resnet50", 1000)
        transform_test = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ])
        classifier = ['fc.weight', 'fc.bias']
        dimsize = 2048
    elif args.model in ['CIFAR_ResNet50', 'CIFAR_ResNet50_AuxBN']:
        net_t = models.load_model('CIFAR_ResNet50', _labeled_num_class)
        transform_test = transforms.Compose([
                transforms.Resize(32),
                transforms.ToTensor(),
            ])
        classifier = ['linear.weight', 'linear.bias']
        dimsize = 2048
    elif args.model in ['wide_resnet', 'wide_resnet_auxbn']:
        net_t = models.load_model('wide_resnet', _labeled_num_class)
        transform_test = transforms.Compose([
                transforms.Resize(32),
                transforms.ToTensor(),
            ])
        classifier = ['linear.weight', 'linear.bias']
        dimsize = 128
    else:
        raise ValueError('unsupported model for similarity extraction: {}'.format(args.model))

    # Read the checkpoint once; weights are stored under 'model' or 'net'.
    saved = torch.load(args.model_path, map_location='cpu')
    try:
        pretrained_dict = saved['model']
    except KeyError:
        pretrained_dict = saved['net']

    model_dict = net_t.state_dict()
    new_state_dict = OrderedDict()
    for k, v in pretrained_dict.items():
        name = k[7:] if k[:6] == 'module' else k  # remove `module.`
        new_state_dict[name] = v
    # Keep backbone weights only: drop unknown keys and the classifier layer.
    new_state_dict = {k: v for k, v in new_state_dict.items() if k in model_dict and k not in classifier}

    model_dict.update(new_state_dict)
    net_t.load_state_dict(model_dict)
    net_t.cuda()

    if args.no_head:
        head_t = None
    else:
        head_t = Projector(dimsize = dimsize)
        head_dict = head_t.state_dict()
        pretrained_dict = torch.load(args.model_path+'_projector', map_location='cpu')['model']
        new_state_dict = OrderedDict()
        for k, v in pretrained_dict.items():
            name = k[7:] if k[:6] == 'module' else k  # remove `module.`
            new_state_dict[name] = v
        new_state_dict = {k: v for k, v in new_state_dict.items() if k in head_dict}
        head_dict.update(new_state_dict)
        head_t.load_state_dict(head_dict)
        head_t.cuda()

    if args.ngpu > 1:
        # NOTE(review): DataParallel does not expose custom methods such as
        # `feature`; the embedding helpers must reach it through `.module`.
        device_ids = list(range(args.sgpu, args.sgpu + args.ngpu))
        net_t = torch.nn.DataParallel(net_t, device_ids=device_ids)
        if not args.no_head:
            head_t = torch.nn.DataParallel(head_t, device_ids=device_ids)

    # Deterministic, un-augmented views of both sets (copies, so the training
    # datasets keep their own transforms).
    labeled_trainset = copy.deepcopy(_labeled_trainset.base_dataset) # full dataset
    labeled_trainset.transform = transform_test
    labeled_trainset = torch.utils.data.Subset(labeled_trainset, _labeled_trainset.indices) # slicing
    labeled_trainloader = torch.utils.data.DataLoader(labeled_trainset, batch_size=256, shuffle=False, num_workers=4)
    unlabeled_trainset = copy.deepcopy(_unlabeled_trainset)
    for u_dataset in unlabeled_trainset.datasets:
        if isinstance(u_dataset, torch.utils.data.Subset):
            u_dataset.dataset.transform = transform_test
        else:
            u_dataset.transform = transform_test
    unlabeled_trainloader = torch.utils.data.DataLoader(unlabeled_trainset, batch_size=256, shuffle=False, num_workers=4)

    emb_l, center_l = get_embed_center(net_t, head_t, labeled_trainloader)
    emb_u = get_embed(net_t, head_t, unlabeled_trainloader)
    sim_l = cosine_similarity(emb_l, center_l) # N_l x C
    sim_u = cosine_similarity(emb_u, center_l) # N_u x C
    torch.save(sim_l, os.path.join(logdir, 'similarity_l.pt'))
    torch.save(sim_u, os.path.join(logdir, 'similarity_u.pt'))

    return sim_l, sim_u


def cycle(iterable):
    """Yield the items of ``iterable`` forever, re-iterating it after each pass.

    Unlike ``itertools.cycle`` nothing is cached: every pass iterates
    ``iterable`` afresh.

    Raises:
        ValueError: if a full pass yields no item. Without this check an
            empty iterable (for example a ``drop_last=True`` loader over
            fewer samples than one batch) would spin forever.
    """
    while True:
        yielded = False
        for x in iterable:
            yielded = True
            yield x
        if not yielded:
            raise ValueError('cycle() received an iterable that yields no items')

def _remap_auxbn_state_dict(model_keys, pretrained, shortcut_name, classifier):
    """Build a state dict for an AuxBN model from single-BN pretrained weights.

    Every key in ``model_keys`` is mapped to a key of ``pretrained``:

    * keys containing ``shortcut_name``: the segment that starts three
      characters after the end of ``shortcut_name`` is inspected; if it begins
      with ``bn_aux`` or ``bn`` that segment (and the dot before it) is
      removed, e.g. ``...shortcut.1.bn_aux.weight`` -> ``...shortcut.1.weight``;
    * other keys containing ``bn``: the segment four characters after the
      first ``bn`` is treated the same way, e.g. ``bn1.bn.weight`` ->
      ``bn1.weight``;
    * keys without ``bn`` are looked up unchanged, except classifier keys,
      which are skipped.

    Raises:
        KeyError: if a batch-norm key has an unexpected layout or the mapped
            key is missing from ``pretrained``.
    """
    remapped = OrderedDict()
    marker_len = len(shortcut_name)
    for k in model_keys:
        pos = k.find(shortcut_name)
        if pos != -1:
            seg = pos + marker_len + 3   # first char after "<shortcut_name>.<c>."
            stem = k[:seg - 1]           # key up to (excluding) the dot before the segment
            if k[seg:seg + 6] == 'bn_aux':
                v = pretrained[stem + k[seg + 6:]]
            elif k[seg:seg + 2] == 'bn':
                v = pretrained[stem + k[seg + 2:]]
            else:
                v = pretrained[k]
        else:
            pos = k.find('bn')
            if pos == -1:
                if k in classifier:
                    continue
                v = pretrained[k]
            elif k[pos+4:pos+10] == 'bn_aux':
                v = pretrained[k[:pos+3] + k[pos+10:]]
            elif k[pos+4:pos+6] == 'bn':
                v = pretrained[k[:pos+3] + k[pos+6:]]
            else:
                raise KeyError('unexpected batch-norm key layout: {}'.format(k))
        remapped[k] = v
    return remapped


def ema_train():
    """Train the model and its EMA copy, validating every 1000 iterations.

    Driven entirely by the module-level ``args``:

    1. build the model and, with ``--finetune``, initialise it from
       ``args.model_path`` (classifier and rotation heads excluded);
    2. split the unlabeled set into in-class / out-of-class samples by
       thresholding the similarities from ``load_dataset_softlabel``;
    3. per iteration, optimise the MixUp loss on labeled + in-class data plus
       the optional pre-mixup, rotation and out-of-class soft-label losses,
       then update the EMA weights;
    4. every 1000 iterations validate both models and checkpoint whichever
       improved; after the loop save "last" checkpoints for both.

    Raises:
        ValueError: with ``--finetune`` and a model family for which no
            classifier-key list is defined.
    """
    # Model
    print('==> Building model: {}'.format(args.model))
    net = models.load_model(args.model, _labeled_num_class, divide=args.aux_divide)

    if args.finetune:
        model_dict = net.state_dict()
        if args.model in ['resnet50', 'resnet50_auxbn']:
            classifier = ['fc.weight', 'fc.bias', 'linear_rot.weight', 'linear_rot.bias']
        elif args.model in ['CIFAR_ResNet50', 'CIFAR_ResNet50_AuxBN', 'wide_resnet', 'wide_resnet_auxbn']:
            classifier = ['linear.weight', 'linear.bias', 'linear_rot.weight', 'linear_rot.bias']
        else:
            raise ValueError('unsupported model for finetuning: {}'.format(args.model))

        # Read the checkpoint once; weights are stored under 'model' or 'net'.
        saved = torch.load(args.model_path, map_location='cpu')
        try:
            pretrained_dict = saved['model']
        except KeyError:
            pretrained_dict = saved['net']

        new_state_dict = OrderedDict()
        for k, v in pretrained_dict.items():
            name = k[7:] if k[:6] == 'module' else k  # remove `module.`
            new_state_dict[name] = v
        new_state_dict = {k: v for k, v in new_state_dict.items() if k not in classifier}

        # AuxBN models hold two batch-norm layers per pretrained one; both are
        # initialised from the same pretrained statistics.
        if args.model in ['resnet50_auxbn']:
            new_state_dict = _remap_auxbn_state_dict(model_dict.keys(), new_state_dict, 'downsample', classifier)
        elif args.model in ['CIFAR_ResNet50_AuxBN', 'wide_resnet_auxbn']:
            new_state_dict = _remap_auxbn_state_dict(model_dict.keys(), new_state_dict, 'shortcut', classifier)

        model_dict.update(new_state_dict)
        net.load_state_dict(model_dict)

    net_ema = copy.deepcopy(net)
    for param in net_ema.parameters():
        param.detach_()

    net.cuda()
    net_ema.cuda()
    print('    Total params: %.2fM' % (sum(p.numel() for p in net.parameters())/1000000.0))
    if args.ngpu > 1:
        device_ids = list(range(args.sgpu, args.sgpu + args.ngpu))
        net = torch.nn.DataParallel(net, device_ids=device_ids)
        net_ema = torch.nn.DataParallel(net_ema, device_ids=device_ids)
    # DataParallel does not expose custom methods such as `rot`; keep a handle
    # on the underlying module for those calls.
    raw_net = net.module if isinstance(net, torch.nn.DataParallel) else net

    if args.simclr_optim:
        assert (not args.fix_optim)
        optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=0, nesterov=True)
    elif args.fix_optim:
        optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.decay, nesterov=True)
    else:
        optimizer = optim.Adam(net.parameters(), lr=args.lr) # weight decay in ema_optimizer

    ema_optimizer= WeightEMA(net, net_ema, alpha=0.999, wd=(not args.fix_optim and not args.simclr_optim))

    net.train()
    net_ema.train()

    if len(_labeled_trainset) < args.batch_size:
        # Fewer labeled samples than one batch: sample with replacement so
        # every batch still has args.batch_size items.
        rand_sampler = torch.utils.data.RandomSampler(_labeled_trainset, num_samples=args.batch_size, replacement=True)
        _labeled_trainloader = torch.utils.data.DataLoader(_labeled_trainset, batch_size=args.batch_size, sampler=rand_sampler, num_workers=0)
    else:
        _labeled_trainloader = torch.utils.data.DataLoader(_labeled_trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.nworkers, drop_last=True)
    _labeled_testloader = torch.utils.data.DataLoader(_labeled_testset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    _labeled_train_iter = iter(cycle(_labeled_trainloader))

    # An unlabeled sample counts as in-class when its best class similarity
    # exceeds (mean - ths * std) of the labeled samples' best similarities.
    sim_l, sim_u = load_dataset_softlabel(_labeled_num_class, _labeled_trainset, _unlabeled_trainset, args.sim_path)
    scores_l = torch.max(sim_l, dim=1)[0]
    mean = torch.mean(scores_l)
    std = torch.std(scores_l)
    sample = (torch.max(sim_u, dim=1)[0] > mean - args.ths * std)

    all_idx = torch.arange(len(_unlabeled_trainset))
    id_sample_idx = all_idx[sample]
    ood_sample_idx = all_idx[~sample]

    if args.total_unlabel:
        id_trainset = _unlabeled_trainset
        print('Number of total unlabel dataset: ',len(id_trainset))
    else:
        id_trainset = torch.utils.data.Subset(_unlabeled_trainset, id_sample_idx)
        print('Number of ID unlabel dataset: ',len(id_trainset))

    # Out-of-class samples carry their similarity row as a third item.
    unlabeled_trainset = copy.deepcopy(_unlabeled_trainset)
    unlabeled_trainset = MergeDataset(unlabeled_trainset, torch.utils.data.TensorDataset(sim_u))
    ood_trainset = torch.utils.data.Subset(unlabeled_trainset, ood_sample_idx)
    print('Number of OOD unlabel dataset: ',len(ood_trainset))

    ood_trainloader = torch.utils.data.DataLoader(ood_trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.nworkers, drop_last=True)
    ood_train_iter = iter(cycle(ood_trainloader))
    id_trainloader = torch.utils.data.DataLoader(id_trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.nworkers, drop_last=True)
    id_train_iter = iter(cycle(id_trainloader))

    train_loss = 0
    train_logger = logging.getLogger('train')

    # Running windows (at most 128 rows) of per-batch mean label
    # distributions, used for distribution alignment.
    label_avg = None
    unlabel_avg = None

    run_iters = args.num_iters if args.stop_iters is None else args.stop_iters
    for batch_idx in range(start_iters, run_iters + 1):
        (inputs, _), targets = next(_labeled_train_iter)
        (inputs_o, inputs_o_strong), _ = next(id_train_iter)
        (inputs_s, _), _, sim_ood = next(ood_train_iter)

        if use_cuda:
            inputs = inputs.cuda()
            inputs_o = inputs_o.cuda()
            inputs_o_strong = [x.cuda() for x in inputs_o_strong]
            inputs_s = inputs_s.cuda()

        # One-hot labels and temperature-scaled soft labels for OOD samples.
        targets = torch.zeros(args.batch_size, _labeled_num_class).scatter_(1, targets.view(-1,1), 1).cuda()
        sim_ood = sim_ood.cuda()
        targets_s = torch.softmax(sim_ood / args.temp_s2, dim=1)

        with torch.no_grad():
            outputs_o = net(inputs_o)
            targets_o = torch.softmax(outputs_o, dim=1)
            if args.no_alignment:
                p = targets_o
            else:
                batch_label_avg = targets.mean(0, keepdim=True)
                batch_unlabel_avg = targets_o.mean(0, keepdim=True)
                if label_avg is None:
                    label_avg = batch_label_avg
                    unlabel_avg = batch_unlabel_avg
                else:
                    # Keep the newest 127 rows and append the current batch,
                    # so the window never exceeds 128 rows.
                    label_avg = torch.cat([label_avg[-127:], batch_label_avg], 0)
                    unlabel_avg = torch.cat([unlabel_avg[-127:], batch_unlabel_avg], 0)

                target_ankor = (1e-6 + label_avg.mean(0)) / (1e-6 + unlabel_avg.mean(0))
                p = targets_o * target_ankor[None].detach()
            p = p / p.sum(dim=1, keepdim=True)
            # Sharpen the guessed distribution with temperature T.
            T = args.temp
            pt = p**(1/T)
            targets_o = pt / pt.sum(dim=1, keepdim=True)
            targets_o = targets_o.detach()

        inputs_o1 = inputs_o_strong[0]
        all_inputs_o = torch.cat(inputs_o_strong, 0)
        all_targets_o = torch.cat([targets_o for _ in range(len(inputs_o_strong))], 0)

        all_inputs = torch.cat([inputs, inputs_o, all_inputs_o], dim=0)
        all_targets = torch.cat([targets, targets_o, all_targets_o], dim=0)

        # MixUp with a coefficient biased towards the first operand.
        alpha = 0.75
        l = np.random.beta(alpha, alpha)
        l = max(l, 1-l)

        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        # Interleave labeled and unlabeled samples between batches before the
        # forward passes.
        mixed_input = list(torch.split(mixed_input, args.batch_size))
        mixed_input = interleave(mixed_input, args.batch_size)

        logits = [net(chunk) for chunk in mixed_input]
        # put interleaved samples back
        logits = interleave(logits, args.batch_size)

        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        Lx = -torch.mean(torch.sum(F.log_softmax(logits_x, dim=1) * mixed_target[:args.batch_size], dim=1))
        Lu = -torch.mean(torch.sum(F.log_softmax(logits_u, dim=1) * mixed_target[args.batch_size:], dim=1))
        loss_xu = Lx + 1.5 * Lu * args.lmd_u

        # unlabel ce loss
        if args.lmd_pre > 0:
            logits_o1 = net(inputs_o1)
            loss_u1_ce = -0.5 * torch.mean(torch.sum(F.log_softmax(logits_o1, dim=1) * targets_o, dim=1)) * args.lmd_pre
        else:
            loss_u1_ce = 0

        if args.lmd_rot > 0:
            # unlabel rotation loss: the batch is cut into four parts, each
            # transformed differently and labeled 0..3.
            batch = inputs_o1.size(0) // 4
            x = inputs_o1[:batch]
            x_90 = inputs_o1[batch:2*batch].transpose(2,3)
            x_180 = inputs_o1[2*batch:3*batch].flip(2,3)
            x_270 = inputs_o1[3*batch:].transpose(2,3).flip(2,3)
            inputs_o_rot = torch.cat((x,x_90,x_180,x_270),0)
            targets_o_rot = torch.tensor([0]*batch + [1]*batch + [2]*batch + [3]*(inputs_o1.size(0)-3*batch)).cuda()
            logits_o_rot = raw_net.rot(inputs_o_rot)

            loss_rot = args.lmd_rot * criterion(logits_o_rot, targets_o_rot)
        else:
            loss_rot = 0

        if args.lmd_unif > 0:
            logits_s = net(inputs_s, aux=True)
            loss_unif = - torch.mean(torch.sum(F.log_softmax(logits_s, dim=1) * targets_s, dim=1)) * args.lmd_unif
        else:
            loss_unif = 0

        loss = loss_xu + loss_u1_ce + loss_rot + loss_unif

        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema_optimizer.step()
        if args.fix_optim:
            adjust_learning_rate(optimizer, batch_idx+1)

        if batch_idx % 1000 == 0:
            # Only the last few evaluations contribute to the median accuracy.
            median = batch_idx // 1000 > (run_iters // 1000) - 5
            train_logger.info('[Iters {}] [Loss {:.3f}]'.format(
                batch_idx,
                train_loss/1000))
            print('[Iters {}] [Loss {:.3f}]'.format(
                batch_idx,
                train_loss/1000))
            if tensorboardX_compat:
                writer.add_scalar("training/loss", train_loss/1000, batch_idx+1)

            train_loss = 0
            ema_optimizer.step(bn=True)
            save = val(net, batch_idx, _labeled_testloader, median=median)
            if save:
                checkpoint(net, optimizer, best_val, batch_idx)
            save = val(net_ema, batch_idx, _labeled_testloader, ema=True, median=median)
            if save:
                checkpoint(net_ema, optimizer, best_val_ema, batch_idx, ema=True)
            net.train()
            net_ema.train()
        else:
            progress_bar(batch_idx % 1000, 1000, 'working...')

    # Record the iteration actually reached, which differs from
    # args.num_iters when --stop_iters ended training early.
    checkpoint(net, optimizer, current_val, run_iters, last=True)
    checkpoint(net_ema, optimizer, current_val_ema, run_iters, ema=True, last=True)


def interleave_offsets(batch, nu):
    """Return the ``nu + 2`` boundaries that cut ``batch`` into ``nu + 1`` groups.

    Group sizes differ by at most one; when ``batch`` is not divisible by
    ``nu + 1`` the larger groups are the trailing ones. The first boundary is
    0 and the last is ``batch``.
    """
    parts = nu + 1
    base, extra = divmod(batch, parts)
    offsets = [0]
    for part in range(parts):
        # The last `extra` groups take one additional element each.
        size = base + 1 if part >= parts - extra else base
        offsets.append(offsets[-1] + size)
    assert offsets[-1] == batch
    return offsets

def interleave(xy, batch):
    """Swap slices between the first tensor of ``xy`` and each of the others.

    Every tensor is cut along dim 0 at the boundaries given by
    ``interleave_offsets(batch, len(xy) - 1)``; slice ``i`` of tensor 0 is then
    exchanged with slice ``i`` of tensor ``i`` and the slices are concatenated
    back. Applying the function twice restores the original order.
    """
    nu = len(xy) - 1
    bounds = interleave_offsets(batch, nu)
    chunked = []
    for tensor in xy:
        chunked.append([tensor[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])])
    for i in range(1, nu + 1):
        chunked[0][i], chunked[i][i] = chunked[i][i], chunked[0][i]
    return [torch.cat(pieces, dim=0) for pieces in chunked]

# Accuracies appended by val() when it is called with median=True, kept
# separately for the trained model and the EMA model; the entry point reports
# their medians after training.
median_acc = []
median_acc_ema = []

def val(net, iters, testloader, ema=False, median=False):
    """Evaluate ``net`` on ``testloader`` and update the module-level records.

    Args:
        net: model to evaluate; it is left in eval mode.
        iters: current training iteration, used as the TensorBoard step.
        testloader: iterable of ``(inputs, targets)`` batches.
        ema: when True, update the EMA records (``best_val_ema``,
            ``current_val_ema``, ``median_acc_ema``) instead of the plain ones.
        median: when True, append the accuracy to the matching median list.

    Returns:
        True if the accuracy is a new best for the selected model, else False.

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    global best_val
    global best_val_ema
    global current_val
    global current_val_ema
    net.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            outputs = net(inputs)
            loss = torch.mean(criterion(outputs, targets))
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            # Keep the counters as Python numbers so the accuracy stored in
            # the globals and in checkpoints is a plain float, not a tensor.
            correct += predicted.eq(targets).sum().item()
            num_batches = batch_idx + 1
            progress_bar(batch_idx, len(testloader),
                         'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (val_loss/num_batches, 100.*correct/total, correct, total))

    if total == 0:
        raise ValueError('validation loader yielded no samples')

    mean_loss = val_loss/num_batches
    acc = 100.*correct/total

    logger = logging.getLogger('test')
    logger.info('[Loss {:.3f}] [Acc {:.3f}]'.format(mean_loss, acc))

    if ema:
        if median:
            median_acc_ema.append(acc)
        if tensorboardX_compat:
            writer.add_scalar("validation/ema_loss", mean_loss, iters+1)
            writer.add_scalar("validation/ema_top1_acc", acc, iters+1)
        current_val_ema = acc
        if acc > best_val_ema:
            best_val_ema = acc
            return True
        return False

    if median:
        median_acc.append(acc)
    if tensorboardX_compat:
        writer.add_scalar("validation/loss", mean_loss, iters+1)
        writer.add_scalar("validation/top1_acc", acc, iters+1)
    current_val = acc
    if acc > best_val:
        best_val = acc
        return True
    return False

def checkpoint(net, optimizer, acc, iters, ema=False, last=False):
    """Save model/optimizer state, accuracy, iteration and RNG state to ``logdir``.

    The file name is ``ckpt.t7`` for the trained model and ``ema_ckpt.t7`` for
    the EMA model; with ``last=True`` the name is prefixed with ``last_``.
    """
    # Save checkpoint.
    print('Saving..')
    state = {
        'net': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'acc': acc,
        'iters': iters,
        'rng_state': torch.get_rng_state()
    }
    if ema:
        filename = 'last_ema_ckpt.t7' if last else 'ema_ckpt.t7'
    else:
        filename = 'last_ckpt.t7' if last else 'ckpt.t7'
    torch.save(state, os.path.join(logdir, filename))

def adjust_learning_rate(optimizer, iters):
    """Set every param group's lr to ``args.lr * cos(7*pi/16 * iters / (num_iters + 1))``."""
    progress = iters/(args.num_iters+1)
    scheduled_lr = args.lr * np.cos(progress * (7 * np.pi) / (2 * 8))
    for group in optimizer.param_groups:
        group['lr'] = scheduled_lr

class WeightEMA(object):
    """Maintains ``ema_model`` as an exponential moving average of ``model``.

    ``step()`` blends the parameters (``ema = alpha * ema + (1 - alpha) *
    param``) and then shrinks the model's own parameters by ``1 - wd``, where
    ``wd`` is ``0.02 * args.lr`` when weight decay is enabled and 0 otherwise.
    ``step(bn=True)`` loads the model's full state dict into the EMA model
    while preserving the EMA parameters, so only the non-parameter entries
    (e.g. batch-norm statistics) are taken from the model.
    """

    def __init__(self, model, ema_model, alpha=0.999, wd=False):
        self.model = model
        self.ema_model = ema_model
        self.alpha = alpha
        # Scratch model used by step(bn=True) to hold the EMA parameters.
        self.tmp_model = models.load_model(args.model, _labeled_num_class, divide=args.aux_divide)
        self.wd = 0.02 * args.lr if wd else 0

        # Start the average from the model's current parameters.
        self._copy_parameters(self.model, self.ema_model)

    @staticmethod
    def _copy_parameters(source, target):
        """Copy parameter values from ``source`` into ``target``, pairwise in order."""
        for src_param, dst_param in zip(source.parameters(), target.parameters()):
            dst_param.data.copy_(src_param.data.detach())

    def step(self, bn=False):
        if not bn:
            one_minus_alpha = 1.0 - self.alpha
            for param, ema_param in zip(self.model.parameters(), self.ema_model.parameters()):
                ema_param.data.mul_(self.alpha)
                ema_param.data.add_(param.data.detach() * one_minus_alpha)
                # customized weight decay
                param.data.mul_(1 - self.wd)
            return

        # copy batchnorm stats to ema model: stash the EMA parameters, take the
        # model's whole state dict, then restore the stashed parameters.
        self._copy_parameters(self.ema_model, self.tmp_model)
        self.ema_model.load_state_dict(self.model.state_dict())
        self._copy_parameters(self.tmp_model, self.ema_model)

# Entry point: only EMA training is implemented.
if not args.ema:
    raise NotImplementedError

ema_train()

# Summarise the run: best accuracies and the medians of the recorded
# late-training evaluations.
median_plain = np.median(median_acc)
median_ema = np.median(median_acc_ema)
print("Best Accuracy : {}".format(best_val))
print("Best Accuracy EMA : {}".format(best_val_ema))
print("Median Accuracy : {}".format(median_plain))
print("Median Accuracy EMA : {}".format(median_ema))
logger = logging.getLogger('best')
logger.info('[Acc {:.3f}] [EMA Acc {:.3f}] [MEDIAN Acc {:.3f}] [MEDIAN EMA Acc {:.3f}]'.format(best_val, best_val_ema, median_plain, median_ema))

if tensorboardX_compat:
    writer.close()
