#!/usr/bin/env python3 -u

from __future__ import print_function

import argparse
import csv
import os, logging
import copy
import random
from collections import OrderedDict

import numpy as np
import torch
from torch.autograd import Variable, grad
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms

import models
from utils import progress_bar, set_logging_defaults
from datasets import load_dataset

# tensorboardX is optional: when it cannot be imported the script still
# trains, it just skips every SummaryWriter call.
try:
    from tensorboardX import SummaryWriter
except ImportError:
    print('No tensorboardX package is found. Start training without tensorboardX')
    tensorboardX_compat = False
else:
    tensorboardX_compat = True

# Command-line interface. The parsed namespace `args` is read as a module
# global by the functions defined further down in this script.
parser = argparse.ArgumentParser(description='UASD Training')
parser.add_argument('--lr', default=0.002, type=float, help='learning rate')
parser.add_argument('--model', default="wide_resnet", type=str,
                    help='model type (default: wide_resnet)')
parser.add_argument('--name', default='0', type=str, help='name of run')
parser.add_argument('--batch-size', default=64, type=int, help='batch size')
# The training loop counts optimizer steps, not epochs, so the help text says so.
parser.add_argument('--num_iters', default=50000, type=int, help='total iterations to run')
parser.add_argument('--decay', default=0, type=float, help='weight decay')
parser.add_argument('--ngpu', default=1, type=int, help='number of gpu')
parser.add_argument('--sgpu', default=0, type=int, help='gpu index (start)')
parser.add_argument('--dataset', default='cifar10', type=str, help='the name for dataset')
parser.add_argument('--udata', default='svhn', type=str, help='type of unlabel data')
parser.add_argument('--tinyroot', default='/data/tinyimagenet/tiny-imagenet-200/', type=str, help='TinyImageNet directory')
parser.add_argument('--imgroot', default='/data/ILSVRC/Data/CLS-LOC/', type=str, help='unlabel data directory')
parser.add_argument('--dataroot', default='/data/', type=str, help='data directory')
# Logs and checkpoints are written under <saveroot>/<dataset>/<model>/<name>.
parser.add_argument('--saveroot', default='./results', type=str, help='root directory for logs and checkpoints')
parser.add_argument('--finetune', '-ft', action='store_true', help='finetuning')
parser.add_argument('--pc', default=25, type=int, help='number of samples per class')
parser.add_argument('--no_alignment', action='store_true', help='no distribution alignment')
# NOTE(review): --naug and --method are only forwarded to load_dataset(); their
# help texts ('superclass indices', 'data directory') look copy-pasted, but the
# real meaning is defined outside this file -- confirm before rewording.
parser.add_argument('--naug', default=8, type=int, help='superclass indices')
# NOTE(review): --lmd_u, --lmd_pre and --lmd_rot are not read anywhere in the
# visible part of this script; --lmd_pre repeats the help text of --lmd_u.
parser.add_argument('--lmd_u', default=1., type=float, help='Lu loss weight')
parser.add_argument('--lmd_pre', default=1., type=float, help='Lu loss weight')
parser.add_argument('--lmd_rot', default=0.5, type=float, help='rotation loss weight')
parser.add_argument('--nworkers', default=4, type=int, help='num_workers')
parser.add_argument('--method', default='default', type=str, help='data directory')

parser.add_argument('--model_path', default=None, type=str, help='(unsupervised) pretrained model path')
parser.add_argument('--ood_samples', default=0, type=int, help='number of ood samples in [0,10000,20000,30000,40000]')
parser.add_argument('--fix_optim', action='store_true', help='using optimizer of FixMatch')
parser.add_argument('--stop_iters', default=None, type=int, help='early stopping')
parser.add_argument('--use_jitter', action='store_true', help='using jitter augmentation for unlabeled data')
parser.add_argument('--no_rampup', action='store_true', help='do not use rampup')
parser.add_argument('--simclr_optim', action='store_true', help='using optimizer of SimCLR semi finetune')
args = parser.parse_args()
use_cuda = torch.cuda.is_available()

# Module-level training state; val() rebinds best_val and current_val.
best_val = 0  # best validation accuracy
start_iters = 0  # start from epoch 0 or last checkpoint epoch
current_val = 0

cudnn.benchmark = True

# Data
_labeled_trainset, _unlabeled_trainset, _labeled_testset = load_dataset(
    args.dataset,
    args.dataroot,
    batch_size=args.batch_size,
    pc=str(args.pc),
    method=args.method,
    naug=args.naug,
    uroot=args.udata,
    tinyroot=args.tinyroot,
    imgroot=args.imgroot,
    ood_samples=args.ood_samples,
    use_jitter=args.use_jitter,
)
_labeled_num_class = _labeled_trainset.num_classes
print('Numclass: ', _labeled_num_class)
print('==> Preparing dataset: {}'.format(args.dataset))
print('Number of label dataset: ', len(_labeled_trainset))
print('Number of unlabel dataset: ', len(_unlabeled_trainset))
print('Number of test dataset: ', len(_labeled_testset))

# Everything a run produces (logs, checkpoints, tensorboard events) goes here.
logdir = os.path.join(args.saveroot, args.dataset, args.model, args.name)
set_logging_defaults(logdir, args)
logger = logging.getLogger('main')
logname = os.path.join(logdir, 'log.csv')
if tensorboardX_compat:
    writer = SummaryWriter(logdir=logdir)

if use_cuda:
    torch.cuda.set_device(args.sgpu)
    print(torch.cuda.device_count())
    print('Using CUDA..')

criterion = nn.CrossEntropyLoss()

def cycle(iterable):
    """Yield the items of `iterable` forever, re-iterating it on every pass.

    Unlike ``itertools.cycle`` nothing is cached: each pass starts a fresh
    iteration over `iterable`.
    """
    while True:
        yield from iterable

class MergeDataset(torch.utils.data.Dataset):
    """Pair two equally long datasets item by item.

    Item ``i`` is the first two fields of ``dataset1[i]`` followed by every
    field of ``dataset2[i]``; both items must support ``+`` concatenation
    (e.g. tuples).
    """

    def __init__(self, dataset1, dataset2):
        assert len(dataset1)==len(dataset2)
        self.dataset1, self.dataset2 = dataset1, dataset2

    def __getitem__(self, i):
        head = self.dataset1[i][:2]
        tail = self.dataset2[i]
        return head + tail

    def __len__(self):
        # Both datasets have the same length, so either one would do.
        return len(self.dataset1)

def get_softmax(net, dataloader):
    """Return the softmax outputs of `net` for every sample in `dataloader`.

    Args:
        net: model to evaluate. It is switched to eval mode and left there;
            the caller is responsible for calling ``net.train()`` afterwards.
        dataloader: iterable of ``(inputs, targets)`` batches. Only `inputs`
            is used.

    Returns:
        A CPU tensor holding the per-batch ``softmax(net(inputs), dim=1)``
        results concatenated along dim 0, in loader order.
    """
    net.eval()

    # Run on whatever device the model lives on instead of hard-coding CUDA,
    # so the function also works for a CPU model.
    first_param = next(net.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    emb = []
    with torch.no_grad():
        # Targets are not needed for inference, so they are never moved.
        for inputs, _ in dataloader:
            outputs = net(inputs.to(device))
            emb.append(torch.softmax(outputs, dim=1).cpu())

    return torch.cat(emb, dim=0)

def _eval_transform(model_name):
    """Return the augmentation-free preprocessing used for soft-label inference."""
    if model_name in ['resnet50']:
        return transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ])
    if model_name in ['CIFAR_ResNet50', 'wide_resnet']:
        return transforms.Compose([
                transforms.Resize(32),
                transforms.ToTensor(),
            ])
    raise ValueError(
        'load_dataset_softlabel has no evaluation transform for model {!r}'.format(model_name))


def _indices_kept(candidates, excluded, upper):
    """Return the sorted unique ints in [0, upper) that are in `candidates` but not in `excluded`."""
    kept = np.setdiff1d(candidates, excluded)
    return kept[(kept >= 0) & (kept < upper)].tolist()


def _held_out_indices(dataset_name):
    """Return the base-dataset indices the confidence-threshold subset is drawn from."""
    def load(folder, filename):
        return np.load(os.path.join(folder, filename)).astype(np.int64)

    # Upper bounds of the index ranges scanned by each branch
    # (1281167 is presumably the ImageNet-1k training-set size).
    small_range = 50000
    large_range = 1281167

    if dataset_name == 'cifar10':
        return load('splits', 'cifar10_unlabel_train_idx.npy')
    if dataset_name == 'cifar100':
        return load('splits', 'cifar100_unlabel_train_idx.npy')
    if dataset_name == 'animal':
        # Everything in [0, 50000) that appears in neither split file.
        used = np.union1d(load('splits', 'animal_400pc_idx.npy'),
                          load('splits', 'cifar10_notanimal_unlabel_idx.npy'))
        return _indices_kept(np.arange(small_range), used, small_range)
    if dataset_name == 'tiny@img':
        return _indices_kept(load('splits', 'tiny_imagenet_train_idx.npy'),
                             load('splits', 'tiny_imagenet_25pc_train_idx.npy'),
                             large_range)
    if dataset_name == 'mini@img':
        return _indices_kept(load('splits', 'mini_imagenet_train_idx.npy'),
                             load('splits', 'mini_imagenet_25pc_train_idx.npy'),
                             large_range)
    if dataset_name in ['dog_cls', 'cat_cls', 'frog_cls', 'turtle_cls', 'bird_cls', 'primate_cls', 'fish_cls', 'carb_cls', 'insect_cls', 'reptile_cls', 'aquatic_animal_cls', 'food_cls', 'produce_cls', 'scenry_cls']:
        return _indices_kept(load('splits_img', dataset_name +'_train.npy'),
                             load('splits_img', dataset_name +'_25pc_train.npy'),
                             large_range)
    raise ValueError(
        'load_dataset_softlabel has no held-out split for dataset {!r}'.format(dataset_name))


def load_dataset_softlabel(net, _labeled_num_class, _labeled_trainset, _unlabeled_trainset, sim_path=None):
    """Compute a confidence threshold and soft labels for the unlabeled data.

    The model is evaluated, with an augmentation-free transform, on
    (1) a held-out slice of the labeled base dataset and (2) the whole
    unlabeled set.

    Args:
        net: model used for inference (left in eval mode by get_softmax).
        _labeled_num_class: unused; kept for interface compatibility.
        _labeled_trainset: labeled training subset; its ``base_dataset`` and
            ``indices`` attributes are read.
        _unlabeled_trainset: unlabeled dataset exposing a ``datasets`` list.
            It is deep-copied, so the caller's transforms are untouched.
        sim_path: unused; kept for interface compatibility.

    Returns:
        ``(ths, soft_u)``: `ths` is the mean of the maximum softmax
        probability over the held-out slice (0-dim tensor) and `soft_u` is
        the softmax output for every unlabeled sample, in dataset order.

    Raises:
        ValueError: if ``args.model`` or ``args.dataset`` is not one of the
            configurations handled here (previously an UnboundLocalError).
    """
    transform_test = _eval_transform(args.model)

    labeled_trainset = copy.deepcopy(_labeled_trainset.base_dataset) # full dataset
    labeled_trainset.transform = transform_test

    val_idx = _held_out_indices(args.dataset)

    # Use the trailing held-out indices, 10% of the labeled-set size.
    # NOTE(review): if the labeled set has fewer than 10 samples n_val is 0 and
    # `[-0:]` selects *every* held-out index rather than none; kept as-is.
    n_val = int(len(_labeled_trainset.indices) * 0.1)
    val_trainset = torch.utils.data.Subset(labeled_trainset, val_idx[-n_val:])

    val_trainloader = torch.utils.data.DataLoader(val_trainset, batch_size=256, shuffle=False, num_workers=4)

    unlabeled_trainset = copy.deepcopy(_unlabeled_trainset)
    for u_dataset in unlabeled_trainset.datasets:
        # A Subset delegates item loading to the dataset it wraps, so the
        # transform has to be replaced on the wrapped dataset.
        if isinstance(u_dataset, torch.utils.data.Subset):
            u_dataset.dataset.transform = transform_test
        else:
            u_dataset.transform = transform_test
    # shuffle=False keeps row i of soft_u aligned with unlabeled sample i.
    unlabeled_trainloader = torch.utils.data.DataLoader(unlabeled_trainset, batch_size=256, shuffle=False, num_workers=4)

    soft_l = get_softmax(net, val_trainloader)
    ths = soft_l.max(1)[0].mean()
    soft_u = get_softmax(net, unlabeled_trainloader)

    return ths, soft_u

def train():
    """Build the model and run the semi-supervised training loop.

    Reads its configuration from the module-level `args` and the datasets
    loaded at import time. Each iteration draws one labeled and one
    unlabeled batch; the loss is the labeled cross-entropy plus, once soft
    labels exist, a masked soft-label cross-entropy on the unlabeled batch
    (linearly ramped up unless ``--no_rampup`` is set).

    Every 1000 iterations the function logs the running loss, validates,
    saves a checkpoint when validation accuracy improved, and recomputes
    the soft labels. The soft labels given to the unlabeled loader are the
    average of the softmax outputs from all refreshes so far.

    Raises:
        ValueError: if ``--finetune`` is combined with a model for which no
            classifier-layer names are known (previously a NameError).
    """
    # Model
    print('==> Building model: {}'.format(args.model))
    net = models.load_model(args.model, _labeled_num_class)

    if args.finetune:
        model_dict = net.state_dict()
        # Classifier heads are never copied from the pretrained checkpoint.
        if args.model in ['resnet50', 'resnet50_auxbn']:
            classifier = ['fc.weight', 'fc.bias', 'linear_rot.weight', 'linear_rot.bias']
        elif args.model in ['CIFAR_ResNet50', 'CIFAR_ResNet50_AuxBN', 'wide_resnet', 'wide_resnet_auxbn']:
            classifier = ['linear.weight', 'linear.bias', 'linear_rot.weight', 'linear_rot.bias']
        else:
            raise ValueError('--finetune is not supported for model {!r}'.format(args.model))

        # Read the file once; the weights are stored under 'model' or 'net'.
        pretrained = torch.load(args.model_path, map_location='cpu')
        try:
            pretrained_dict = pretrained['model']
        except KeyError:
            pretrained_dict = pretrained['net']

        # Strip the DataParallel prefix, then keep only backbone weights the
        # current model actually has.
        prefix = 'module.'
        new_state_dict = OrderedDict()
        for k, v in pretrained_dict.items():
            name = k[len(prefix):] if k.startswith(prefix) else k
            if name in model_dict and name not in classifier:
                new_state_dict[name] = v
        model_dict.update(new_state_dict)
        net.load_state_dict(model_dict)

    net.cuda()
    print('    Total params: %.2fM' % (sum(p.numel() for p in net.parameters())/1000000.0))
    # print(net)
    if args.ngpu > 1:
        net = torch.nn.DataParallel(net, device_ids=list(range(args.sgpu, args.sgpu + args.ngpu)))

    if args.simclr_optim:
        assert (not args.fix_optim)
        optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=0, nesterov=True)
    elif args.fix_optim:
        optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.decay, nesterov=True)
    else:
        optimizer = optim.Adam(net.parameters(), lr=args.lr)

    net.train()

    if len(_labeled_trainset) < args.batch_size:
        # Too few labeled samples for a full batch: sample with replacement.
        rand_sampler = torch.utils.data.RandomSampler(_labeled_trainset, num_samples=args.batch_size, replacement=True)
        _labeled_trainloader = torch.utils.data.DataLoader(_labeled_trainset, batch_size=args.batch_size, sampler=rand_sampler, num_workers=0)
    else:
        _labeled_trainloader = torch.utils.data.DataLoader(_labeled_trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.nworkers, drop_last=True)
    _labeled_testloader = torch.utils.data.DataLoader(_labeled_testset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    _labeled_train_iter = iter(cycle(_labeled_trainloader))
    _unlabeled_trainloader = torch.utils.data.DataLoader(_unlabeled_trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.nworkers, drop_last=True)
    _unlabeled_train_iter = iter(cycle(_unlabeled_trainloader))
    train_loss = 0

    ensemble_sum = 0
    ensemble_cnt = 0
    ths = None  # confidence threshold; None until the first soft-label pass
    run_iters = args.num_iters if args.stop_iters is None else args.stop_iters
    train_logger = logging.getLogger('train')
    for batch_idx in range(start_iters, run_iters + 1):
        (inputs, _), targets = next(_labeled_train_iter)
        if ths is None:
            (inputs_o, _), _ = next(_unlabeled_train_iter)
            soft_u = None
        else:
            # After the first refresh the unlabeled loader also yields the
            # ensembled soft label of every sample.
            (inputs_o, _), _, soft_u = next(_unlabeled_train_iter)
            score = soft_u.max(1)[0]
            mask = (score >= ths)
            if use_cuda:
                mask = mask.cuda()
                soft_u = soft_u.cuda()

        if use_cuda:
            inputs = inputs.cuda()
            inputs_o = inputs_o.cuda()
            targets = targets.cuda()

        outputs   = net(inputs)
        outputs_o  = net(inputs_o)

        Lx = criterion(outputs, targets)
        if soft_u is not None:
            # Soft-label cross-entropy, counted only for confident samples.
            Lu = -torch.mean(torch.sum(F.log_softmax(outputs_o, dim=1) * soft_u, dim=1) * mask)
        else:
            Lu = 0

        if args.no_rampup:
            loss = Lx + Lu
        else:
            loss = Lx + Lu * np.clip(batch_idx/args.num_iters, 0.0, 1.0)

        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if args.fix_optim:
            adjust_learning_rate(optimizer, batch_idx+1)

        if batch_idx % 1000 == 0:
            # Only the last evaluations of the run feed the median accuracy.
            median = batch_idx // 1000 > (run_iters // 1000) - 5
            # NOTE: at iteration 0 the sum covers a single step but is still
            # divided by 1000.
            message = '[Iters {}] [Loss {:.3f}]'.format(
                batch_idx,
                train_loss/1000)
            train_logger.info(message)
            print(message)
            if tensorboardX_compat:
                writer.add_scalar("training/loss", train_loss/1000, batch_idx+1)

            train_loss = 0
            save = val(net, batch_idx, _labeled_testloader, median=median)
            if save:
                checkpoint(net, optimizer, best_val, batch_idx)

            infer_freq = 1000

            if batch_idx % infer_freq == 0:
                ths, soft_u = load_dataset_softlabel(net, _labeled_num_class, _labeled_trainset, _unlabeled_trainset)
                # Average the predictions of all refreshes so far.
                ensemble_sum += soft_u
                ensemble_cnt += 1
                soft_en = ensemble_sum / float(ensemble_cnt)
                unlabeled_trainset = copy.deepcopy(_unlabeled_trainset)
                unlabeled_trainset = MergeDataset(unlabeled_trainset, torch.utils.data.TensorDataset(soft_en))
                unlabeled_trainloader = torch.utils.data.DataLoader(unlabeled_trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.nworkers, drop_last=True)
                _unlabeled_train_iter = iter(cycle(unlabeled_trainloader))

            # val() and the soft-label pass leave the model in eval mode.
            net.train()
        else:
            progress_bar(batch_idx % 1000, 1000, 'working...')

    # Record the iteration actually reached, which differs from
    # args.num_iters when --stop_iters ended the run early.
    checkpoint(net, optimizer, current_val, run_iters, last=True)


def interleave_offsets(batch, nu):
    """Return the boundaries that split `batch` items into ``nu + 1`` groups.

    Group sizes differ by at most one; any remainder is given to the last
    groups. The result has ``nu + 2`` entries, starts at 0 and ends at
    `batch`.
    """
    n_groups = nu + 1
    sizes = [batch // n_groups] * n_groups
    leftover = batch - sum(sizes)
    # Hand out the remainder one item at a time, starting from the last group.
    for k in range(1, leftover + 1):
        sizes[-k] += 1

    bounds = [0]
    for size in sizes:
        bounds.append(bounds[-1] + size)
    assert bounds[-1] == batch
    return bounds

def interleave(xy, batch):
    """Exchange chunks between the first tensor of `xy` and each other tensor.

    Every tensor is cut at the boundaries from ``interleave_offsets`` and,
    for ``i >= 1``, chunk ``i`` of tensor 0 is swapped with chunk ``i`` of
    tensor ``i``. The chunks are then concatenated back along dim 0.
    """
    nu = len(xy) - 1
    bounds = interleave_offsets(batch, nu)
    chunked = []
    for tensor in xy:
        chunked.append([tensor[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])])
    for i in range(1, nu + 1):
        chunked[0][i], chunked[i][i] = chunked[i][i], chunked[0][i]
    return [torch.cat(pieces, dim=0) for pieces in chunked]

# median_acc collects the accuracies val() records when called with
# median=True; median_acc_ema is never appended to in the visible code.
median_acc, median_acc_ema = [], []

def val(net, iters, testloader, median=False):
    """Evaluate `net` on `testloader` and update the global accuracy records.

    Logs the mean loss and top-1 accuracy, writes them to tensorboard when
    it is available, appends the accuracy to ``median_acc`` when `median`
    is true, and stores it in ``current_val``. The model is left in eval
    mode.

    Returns:
        True when the accuracy beats ``best_val`` (which is then updated),
        otherwise False.
    """
    global best_val
    global current_val
    net.eval()
    loss_sum = 0.0
    n_correct = 0.0
    n_seen = 0.0

    with torch.no_grad():
        for step, (inputs, targets) in enumerate(testloader):
            if use_cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()

            logits = net(inputs)
            loss_sum += torch.mean(criterion(logits, targets)).item()
            predicted = torch.max(logits, 1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets.data).cpu().sum().float()
            progress_bar(step, len(testloader),
                         'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (loss_sum/(step+1), 100.*n_correct/n_seen, n_correct, n_seen))

    # `step` is the index of the last batch, so step+1 is the batch count.
    mean_loss = loss_sum/(step+1)
    acc = 100.*n_correct/n_seen

    logging.getLogger('test').info('[Loss {:.3f}] [Acc {:.3f}]'.format(
        mean_loss, acc))

    if median:
        median_acc.append(acc.item())
    if tensorboardX_compat:
        writer.add_scalar("validation/loss", mean_loss, iters+1)
        writer.add_scalar("validation/top1_acc", acc, iters+1)

    current_val = acc
    improved = acc > best_val
    if improved:
        best_val = acc
        return True
    return False

def checkpoint(net, optimizer, acc, iters, last=False):
    """Save model, optimizer, accuracy, iteration and RNG state under `logdir`.

    The file is ``last_ckpt.t7`` when `last` is true and ``ckpt.t7``
    otherwise; an existing file of that name is overwritten.
    """
    # Save checkpoint.
    print('Saving..')
    filename = 'last_ckpt.t7' if last else 'ckpt.t7'
    torch.save(
        {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'acc': acc,
            'iters': iters,
            'rng_state': torch.get_rng_state(),
        },
        os.path.join(logdir, filename),
    )

def adjust_learning_rate(optimizer, iters):
    """Set the learning rate of every parameter group from a cosine schedule.

    The rate is ``args.lr * cos(7*pi/16 * iters / (args.num_iters + 1))``.
    """
    fraction = iters/(args.num_iters+1)
    new_lr = args.lr * np.cos(fraction * (7 * np.pi) / (2 * 8))
    for group in optimizer.param_groups:
        group['lr'] = new_lr


# Script entry: run training, then report the best and median accuracy.
train()

print("Best Accuracy : {}".format(best_val))
median_val = np.median(median_acc)
print("Median Accuracy : {}".format(median_val))
logger = logging.getLogger('best')
logger.info('[Acc {:.3f}] [MEDIAN Acc {:.3f}] '.format(best_val, median_val))

if tensorboardX_compat:
    writer.close()
