#!/usr/bin/env python
# -*- coding: utf-8 -*-
from lib.datasets.mydatasets import AdvDataset
import torch
import torch.nn as nn
import torch.optim as optim
import argparse, time, json, os
from build_dataset import get_dataloaders, RandomSampler
from config import config
from lib.datasets.mytransforms import TransformConsistency, TransformFixMatch, TransformCausal, TransformTest
from lib.models.wideresenet import FC, Data_Decoder_CIFAR, WideResNet
import torch.nn.functional as F
from misc import get_cosine_schedule_with_warmup, kld, compute_roc 
from torchvision import transforms


parser = argparse.ArgumentParser()
# --alg selects both the transform (transform_info_dict) and the config section (config[args.alg]).
parser.add_argument("--alg", "-a", default="PI", type=str, help="ssl algorithm : [PI, MT, VAT, PL, FM, CS]")
# Weight between the one-vs-all open-set loss (1 - em) and its entropy term (em).
parser.add_argument("--em", default=0.2, type=float, help="coefficient of entropy minimization. If you try VAT + EM, set 0.06")
parser.add_argument("--validation", default=500, type=int, help="validate at this interval (default 500)")
# NOTE(review): the default "crossset" has no entry in data_info_dict, so the
# dataset lookup fails unless --dataset names one of that dict's keys.
parser.add_argument("--dataset", "-d", default="crossset", type=str, help="dataset name")
parser.add_argument("--n_labels", "-n", default=2400, type=int, help="the number of labeled data")
parser.add_argument("--n_unlabels", "-u", default=20000, type=int, help="the number of unlabeled data")
# Accuracy denominators used by the validation / test passes.
parser.add_argument('--n_valid', default=1200, type=int)
parser.add_argument('--n_test', default=1200, type=int)
parser.add_argument('--batch_size', default=100, type=int)
# Number of known classes; test labels >= tot_class are treated as "unknown".
parser.add_argument('--tot_class', "-tot_class", default=6, type=int)
# Number of augmented views; the training loop also uses view 0 (the
# non-augmented batch), i.e. n_augs + 1 views in total.
parser.add_argument('--n_augs', default=2, type=int)
# Overwritten later by config["shared"]["iteration"].
parser.add_argument('--iterations', default=500000, type=int)
parser.add_argument('--num_workers', default=4, type=int)
parser.add_argument("--root", "-r", default="data", type=str, help="dataset dir")
parser.add_argument("--output", "-o", default="./exp_res", type=str, help="output dir")
parser.add_argument("--ps_th", "-ps_th", default=0.95, type=float, help="pseudo label threshold")
parser.add_argument("--aug_iter", "-aug_iter", default=10000, type=int, help="augmentation start iteration")
parser.add_argument("--aug_number", "-aug_number", default=400, type=int, help="number of batches used to build the adversarial augmentation datasets")
parser.add_argument("--rec_coef", "-rec_coef", default=0.001, type=float, help="coefficient for reconstruction loss")
parser.add_argument("--adv_eps", "-adv_eps", default=0.03, type=float, help="maximum value of adversarial perturbation")
parser.add_argument("--adv_step", "-adv_step", default=15, type=int, help="number of steps for projected gradient descent")
parser.add_argument("--gpus", default=1, type=int, help="number of GPUs")  # NOTE: not read anywhere in this file.
parser.add_argument("--seed", "-s", default=0, type=int, help="train seed")
parser.add_argument("--description", "-desc", default="no description", type=str, help="description of the intention of the current experiment")
args = parser.parse_args()

# Per-dataset normalisation statistics and input resolution, used to build the
# train/test transforms and the Normalize applied to adversarial examples.
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2023, 0.1994, 0.2010)
_HALF_RGB = (0.5, 0.5, 0.5)

data_info_dict = {
    "cifar10": dict(mean=_CIFAR_MEAN, std=_CIFAR_STD, image_size=32),
    # NOTE(review): cifar100 reuses the CIFAR-10 statistics; confirm this is intended.
    "cifar100": dict(mean=_CIFAR_MEAN, std=_CIFAR_STD, image_size=32),
    "svhn": dict(mean=_HALF_RGB, std=_HALF_RGB, image_size=32),
    "mnist": dict(mean=(0.1306604762738429,), std=(0.30810780717887876,), image_size=28),
    "cmnist": dict(mean=_HALF_RGB, std=_HALF_RGB, image_size=28),
    "fashionmnist": dict(mean=(0.286,), std=(0.353,), image_size=28),
}

# Transform pipeline used by each --alg value; the consistency-based methods share one.
transform_info_dict = {
    **dict.fromkeys(("MT", "PI", "PL", "VAT"), TransformConsistency),
    "FM": TransformFixMatch,
    "CS": TransformCausal,
}


# NOTE(review): neither arch_dict nor input_channel_dict is read elsewhere in this file.
arch_dict = dict(mnist="lenet", svhn="resnet34", cifar10="resnet34", cifar100="resnet50")
input_channel_dict = dict(mnist=1, cmnist=3, svhn=3, cifar10=3, cifar100=3)
    
def compute_rec(reconstructed_inputs, original_inputs):
    """Return the mean squared error between a reconstruction and its original input."""
    loss = F.mse_loss(input=reconstructed_inputs, target=original_inputs, reduction="mean")
    return loss

def compute_entropy(out, mask=None):
    """Return the batch-mean Shannon entropy of the softmax of ``out``.

    Args:
        out: logits; the softmax / entropy is taken over dim 1.
        mask: optional per-sample weights (length = ``out.size(0)``). Each
            sample's entropy is multiplied by its weight before averaging over
            the whole batch. ``None`` means no masking; previously it raised a
            TypeError.

    Returns:
        Scalar tensor ``mean_i(mask_i * H(softmax(out_i)))``.
    """
    per_sample = -(out.softmax(1) * F.log_softmax(out, 1)).sum(1)
    if mask is not None:
        per_sample = per_sample * mask
    return per_sample.mean()

def optimize_content_classification(inputs, class_label, disentangle=True, require_feature=False, require_known_mask=False, known_mask=None):
    """Run one optimisation step of the content branch (``gen_c`` + ``cls_c``).

    ``cls_c`` is trained with cross-entropy on ``class_label``. Entries equal to
    -1 are ignored and contribute 0, but still count in the mean.

    Args:
        inputs: input batch fed to ``generator['gen_c']``.
        class_label: class indices, -1 for "no label".
        disentangle: also feed the content features to ``cls_s`` with
            ``reverse=True`` (presumably a gradient-reversal layer -- confirm in
            FC) and add the entropy of that style prediction to the loss.
        require_feature: also return the content features.
        require_known_mask: weight each sample's cross-entropy by ``known_mask``.
        known_mask: per-sample float weights. Also masks the entropy term when
            given; when ``None`` every sample is weighted by 1.

    Returns:
        ``(content, loss)`` if ``require_feature`` else ``loss``.

    Raises:
        ValueError: if ``require_known_mask`` is set but ``known_mask`` is None.
    """
    if require_known_mask and known_mask is None:
        raise ValueError("require_known_mask=True needs a known_mask")

    content = generator['gen_c'](inputs)
    content_pred = classifier["cls_c"](content)

    per_sample_ce = F.cross_entropy(content_pred, class_label, reduction="none", ignore_index=-1)
    if require_known_mask:
        per_sample_ce = per_sample_ce * known_mask
    content_cls_loss = per_sample_ce.mean()

    if disentangle:
        style_pred = classifier["cls_s"](content, reverse=True)
        # Without a mask every sample counts; multiplying by None used to raise TypeError.
        entropy_mask = known_mask if known_mask is not None else torch.ones(content.size(0), device=content.device)
        content_cls_loss = content_cls_loss + compute_entropy(style_pred, entropy_mask)

    content_cls_loss.backward()
    # Only the content encoder and its classifier are updated; group_opt_step clears all grads afterwards.
    group_opt_step(['gen_c', 'cls_c'])

    if require_feature:
        return content, content_cls_loss
    return content_cls_loss


def optimize_content_openclassification(inputs, class_label, negative=False, required_mask=False, mask=None):
    """Run one optimisation step of the one-vs-all open-set head (``gen_c`` + ``cls_c_o``).

    The head's ``2 * C`` logits are reshaped to ``(batch, 2, C)`` and
    softmax-normalised over dim 1. Index 0 is the "inlier for class c"
    probability and index 1 the "outlier for class c" probability.

    Modes (``required_mask`` takes precedence):
      * ``required_mask=True``: OpenMatch-style loss. Samples with ``mask == 0``
        get the one-vs-all loss, weighted by ``1 - args.em``: their class is
        pulled towards inlier and the hardest other class is pushed towards
        outlier. Samples with ``mask == 1`` get the per-sample mean open-set
        entropy, weighted by ``args.em``.
      * ``negative=True``: push each sample's labelled class towards outlier.

    Labels equal to -1 get no positive target. Previously they indexed the
    last class.

    Returns:
        The scalar loss that was back-propagated.

    Raises:
        ValueError: if neither mode is selected (previously an
            UnboundLocalError), or ``required_mask`` is set without ``mask``.
    """
    if not (required_mask or negative):
        raise ValueError("optimize_content_openclassification needs required_mask=True or negative=True")
    if required_mask and mask is None:
        raise ValueError("required_mask=True needs a mask")

    content = generator['gen_c'](inputs)
    open_logits = classifier["cls_c_o"](content)
    # adapted from "OpenMatch: https://github.com/VisionLearningGroup/OP_Match".
    open_probs = F.softmax(open_logits.view(open_logits.size(0), 2, -1), 1)
    n_samples, n_classes = open_probs.size(0), open_probs.size(2)

    open_target = torch.zeros((n_samples, n_classes), device=class_label.device)
    # Only real labels get a one-hot entry: indexing with -1 would mark the last class.
    valid = class_label >= 0
    rows = torch.arange(n_samples, device=class_label.device)
    open_target[rows[valid], class_label[valid]] = 1
    open_target_nega = 1 - open_target

    nll_outlier = -torch.log(open_probs[:, 1, :] + 1e-8)
    if required_mask:
        nll_inlier = -torch.log(open_probs[:, 0, :] + 1e-8)
        ova_loss = torch.sum(nll_inlier * open_target, 1) + torch.max(nll_outlier * open_target_nega, 1)[0]
        content_opencls_loss = (1 - args.em) * torch.mean(ova_loss * (1. - mask))
        # Entropy per sample (averaged over classes), then masked, so only
        # mask == 1 samples contribute. Previously the batch mean was taken
        # before masking.
        entropy = torch.sum(-open_probs * torch.log(open_probs + 1e-8), 1)
        content_opencls_loss = content_opencls_loss + args.em * torch.mean(torch.mean(entropy, 1) * mask)
    else:
        content_opencls_loss = torch.mean(torch.sum(nll_outlier * open_target, 1))

    content_opencls_loss.backward()
    group_opt_step(['gen_c', 'cls_c_o'])

    return content_opencls_loss
    

def optimize_style_classification(inputs, aug_label, disentangle=True, require_feature=False, require_known_mask=False, known_mask=None):
    """Run one optimisation step of the style branch (``gen_s`` + ``cls_s``).

    ``cls_s`` is trained with cross-entropy to predict the augmentation index
    ``aug_label``. Entries equal to -1 are ignored and contribute 0, but still
    count in the mean.

    Args:
        inputs: input batch fed to ``generator['gen_s']``.
        aug_label: augmentation indices, -1 for "ignore".
        disentangle: also feed the style features to ``cls_c`` with
            ``reverse=True`` (presumably a gradient-reversal layer -- confirm in
            FC) and add the entropy of that class prediction to the loss.
        require_feature: also return the style features.
        require_known_mask: weight each sample's cross-entropy by ``known_mask``.
        known_mask: per-sample float weights. Also masks the entropy term when
            given; when ``None`` every sample is weighted by 1.

    Returns:
        ``(style, loss)`` if ``require_feature`` else ``loss``.

    Raises:
        ValueError: if ``require_known_mask`` is set but ``known_mask`` is None.
    """
    if require_known_mask and known_mask is None:
        raise ValueError("require_known_mask=True needs a known_mask")

    style = generator['gen_s'](inputs)
    style_pred = classifier["cls_s"](style)

    per_sample_ce = F.cross_entropy(style_pred, aug_label, reduction="none", ignore_index=-1)
    if require_known_mask:
        per_sample_ce = per_sample_ce * known_mask
    style_cls_loss = per_sample_ce.mean()

    if disentangle:
        content_pred = classifier["cls_c"](style, reverse=True)
        # Without a mask every sample counts; multiplying by None used to raise TypeError.
        entropy_mask = known_mask if known_mask is not None else torch.ones(style.size(0), device=style.device)
        style_cls_loss = style_cls_loss + compute_entropy(content_pred, entropy_mask)

    style_cls_loss.backward()
    # Only the style encoder and its classifier are updated; group_opt_step clears all grads afterwards.
    group_opt_step(['gen_s', 'cls_s'])

    if require_feature:
        return style, style_cls_loss
    return style_cls_loss



def optimize_reconstruction(inputs, content, style):
    """Take one decoder step that reconstructs ``inputs`` from (content, style) features.

    Both feature tensors are detached, so only ``reconstructor['dec']``
    receives gradients. The loss is the reconstruction MSE scaled by
    ``args.rec_coef``, and that scaled loss is returned.
    """
    decoded = reconstructor['dec'](content.detach(), style.detach())
    loss = args.rec_coef * compute_rec(decoded, inputs)

    loss.backward()
    group_opt_step(['dec'])
    return loss


    
def optimize_consistency(inputs, targets):
    """Take one step pulling the open-set head's predictions on ``inputs`` towards ``targets``.

    The head output and ``targets`` are both reshaped to ``(batch, 2, -1)`` and
    softmax-normalised over dim 1. The loss is the squared difference summed
    over dims 1 and 2, then averaged over the batch. Only ``gen_c`` and
    ``cls_c_o`` are stepped.

    NOTE(review): not called anywhere in this file.
    """
    def to_open_probs(t):
        return F.softmax(t.view(t.size(0), 2, -1), 1)

    probs = to_open_probs(classifier['cls_c_o'](generator['gen_c'](inputs)))
    target_probs = to_open_probs(targets)
    squared_gap = torch.abs(probs - target_probs) ** 2
    consistency_loss = squared_gap.sum(1).sum(1).mean()

    consistency_loss.backward()
    group_opt_step(['gen_c', 'cls_c_o'])
    return consistency_loss


def _perturb_input(inputs_i, delta, steps, anchor_encoder, attacked_encoder, attacked_head, labels, targeted):
    """Optimise ``delta`` in place to move ``attacked_head``'s prediction while keeping ``anchor_encoder``'s features fixed.

    Each step does the following:
      * Computes ``+/-CE(attacked_head(attacked_encoder(x)), labels)`` plus the
        MSE between ``anchor_encoder(x)`` and its value at step 0, where
        ``x = inputs_i + delta``. The CE term is minimised when ``targeted``,
        otherwise maximised; labels equal to -1 are ignored.
      * Normalises the per-sample gradient to unit L2 norm. Samples with a
        zero gradient get a random direction instead.
      * Takes an SGD step (lr 0.007).
      * Projects ``inputs_i + delta`` into [0, 1] and each sample's ``delta``
        into an L2 ball of radius ``args.adv_eps``.

    Gradients are taken w.r.t. ``delta`` only, so no ``.grad`` is accumulated
    on the model parameters.

    Returns:
        The detached loss of the last step (a zero scalar if ``steps == 0``).
    """
    optimizer = torch.optim.SGD([delta], lr=0.007)
    batch_size = len(inputs_i)
    anchor_init = None
    loss = torch.zeros((), device=inputs_i.device)
    for _ in range(steps):
        optimizer.zero_grad()
        with torch.enable_grad():
            x = inputs_i + delta
            anchor_feat = anchor_encoder(x)
            pred = attacked_head(attacked_encoder(x))
            if anchor_init is None:
                anchor_init = anchor_feat.clone().detach()
            ce = F.cross_entropy(pred, labels, reduction="none", ignore_index=-1).mean()
            loss = (ce if targeted else -ce) + F.mse_loss(anchor_feat, anchor_init)
            grad, = torch.autograd.grad(loss, delta)

        grad_norm = grad.reshape(batch_size, -1).norm(p=2, dim=1)
        grad = grad / grad_norm.view(-1, *([1] * (grad.dim() - 1)))
        zero_norm = grad_norm == 0
        if zero_norm.any():
            grad[zero_norm] = torch.randn_like(grad[zero_norm])
        delta.grad = grad
        optimizer.step()

        # Projection: valid pixel range first, then the per-sample L2 budget.
        delta.data.add_(inputs_i)
        delta.data.clamp_(0, 1).sub_(inputs_i)
        delta.data.renorm_(p=2, dim=0, maxnorm=args.adv_eps)
    return loss.detach()


def data_augmentation(generator_all, classifier_all, inputs_i, labels_all, domain_targeted=False, class_targeted=False, benign_steps=2):
    """Compute benign and malign adversarial variants of a batch by intervening on the input.

    * Benign: change the style prediction ``classifier_s(generator_s(x))`` on
      ``domain_label`` while keeping the content features
      ``generator_c(x)`` close to their initial value. Runs for
      ``benign_steps`` steps, starting from a zero perturbation.
    * Malign: change the class prediction ``classifier_c(generator_c(x))`` on
      ``class_label`` while keeping the style features ``generator_s(x)``
      close to their initial value. Runs for ``args.adv_step`` steps, starting
      from ``args.adv_eps * randn`` noise.

    The models are switched to eval mode with batch-stat updates disabled
    while the perturbations are optimised. Train mode and batch-stat updates
    are always re-enabled afterwards, even on error.

    Args:
        generator_all: ``[generator_c, generator_s]``; both must provide
            ``update_batch_stats``.
        classifier_all: ``[classifier_c, classifier_s]``.
        inputs_i: input batch. The projection clamps ``inputs_i + delta`` to
            [0, 1]. NOTE(review): confirm the inputs are un-normalised.
        labels_all: ``[class_label, domain_label]``; -1 entries are ignored.
        domain_targeted / class_targeted: minimise (True) rather than maximise
            (False) the corresponding cross-entropy.
        benign_steps: number of benign optimisation steps.

    Returns:
        ``(benign_x, malign_x, class_labels, benign_loss, malign_loss)``:
        the perturbed images as detached CPU tensors clamped to [0, 1],
        ``class_label`` flattened to a Python list, and the detached losses of
        the last step of each phase.
    """
    generator_c, generator_s = generator_all
    classifier_c, classifier_s = classifier_all
    class_label, domain_label = labels_all
    modules = (generator_c, generator_s, classifier_c, classifier_s)

    generator_c.update_batch_stats(False)
    generator_s.update_batch_stats(False)
    for module in modules:
        module.eval()

    # Created in this order so the random initialisation consumes the RNG exactly as before.
    delta_benign = args.adv_eps * torch.zeros_like(inputs_i).detach()
    delta_malign = args.adv_eps * torch.randn_like(inputs_i).detach()
    delta_benign.requires_grad_()
    delta_malign.requires_grad_()

    try:
        benign_loss = _perturb_input(inputs_i, delta_benign, benign_steps,
                                     anchor_encoder=generator_c, attacked_encoder=generator_s,
                                     attacked_head=classifier_s, labels=domain_label,
                                     targeted=domain_targeted)
        malign_loss = _perturb_input(inputs_i, delta_malign, args.adv_step,
                                     anchor_encoder=generator_s, attacked_encoder=generator_c,
                                     attacked_head=classifier_c, labels=class_label,
                                     targeted=class_targeted)
    finally:
        generator_c.update_batch_stats(True)
        generator_s.update_batch_stats(True)
        for module in modules:
            module.train()

    with torch.no_grad():
        benign_x = torch.clamp(inputs_i + delta_benign, 0.0, 1.0)
        malign_x = torch.clamp(inputs_i + delta_malign, 0.0, 1.0)
    return (benign_x.to('cpu').detach(), malign_x.to('cpu').detach(),
            class_label.to('cpu').detach().flatten().tolist(), benign_loss, malign_loss)



def reset_grad():
    """Clear the accumulated gradients held by every registered optimizer."""
    for optimizer in optims.values():
        optimizer.zero_grad()

def group_opt_step(opt_keys):
    """Step only the optimizers named in ``opt_keys``, then clear every optimizer's gradients."""
    for optimizer in (optims[key] for key in opt_keys):
        optimizer.step()
    reset_grad()

def group_opt_lr_decay(factor=0.2):
    """Multiply the learning rate of every parameter group of every optimizer by ``factor``, then clear gradients.

    Previously only ``param_groups[0]`` was decayed, with a fixed factor of 0.2.

    NOTE(review): not called in this file. If the schedulers are LambdaLR-based,
    their next ``step()`` recomputes lr from the base lr and discards this decay.
    """
    for optimizer in optims.values():
        for group in optimizer.param_groups:
            group["lr"] *= factor
    reset_grad()

def group_scheduler_step():
    """Advance the learning-rate scheduler paired with each optimizer by one step."""
    for name in optims:
        scheduler[name].step()

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    # Let cuDNN benchmark and pick the fastest convolution algorithms.
    torch.backends.cudnn.benchmark = True

# `condition` collects the experiment record that is written to JSON at the end.
condition = {}
condition["description"] = args.description
exp_name = ""

shared_cfg = config["shared"]
# The run length comes from the shared config, not from --iterations.
args.iterations = shared_cfg["iteration"]
condition['argparse'] = {}
print("dataset : {}".format(args.dataset))
alg_cfg = config[args.alg]
print("parameters : ", alg_cfg)
condition["h_parameters"] = alg_cfg
exp_name += str(args.dataset) + "_causal_ssl_"
condition["val_acc"] = []
condition["test_acc"] = []
condition["test_roc"] = []

# Fail fast with an actionable message instead of a bare KeyError (the default
# --dataset "crossset", for instance, has no statistics entry).
if args.dataset not in data_info_dict:
    raise ValueError("unknown --dataset {!r}; expected one of {}".format(args.dataset, sorted(data_info_dict)))
if args.alg not in transform_info_dict:
    raise ValueError("unknown --alg {!r}; expected one of {}".format(args.alg, sorted(transform_info_dict)))

dataset_info = data_info_dict[args.dataset]
transform_fn = transform_info_dict[args.alg](mean=dataset_info["mean"], std=dataset_info["std"], size_image=dataset_info["image_size"])
# Normalisation applied to the adversarial examples, which data_augmentation returns clamped to [0, 1].
rec_transform = transforms.Normalize(mean=dataset_info["mean"], std=dataset_info["std"])
transform_test = TransformTest(mean=dataset_info["mean"], std=dataset_info["std"])

l_loader, u_loader, val_loader, test_loader = \
    get_dataloaders(args, transform_fn=transform_fn, transform_test=transform_test)

###### models #######
# Two encoders: gen_c feeds the class heads (cls_c, cls_c_o), gen_s feeds the
# augmentation-index head (cls_s).
generator = nn.ModuleDict({
    'gen_c': WideResNet(),
    'gen_s': WideResNet(),
}).to(device)


# classifier models output no-softmax predictions
# cls_s predicts the augmentation (view) index, which the training loop draws
# from range(args.n_augs + 1). The argument parser defines --n_augs but never
# --num_augs, so `args.num_augs` raised AttributeError. Honour it if something
# else set it; otherwise use the n_augs + 1 labels the loop actually produces.
num_style_classes = getattr(args, "num_augs", args.n_augs + 1)
classifier = nn.ModuleDict({
    'cls_c': FC(z_dim=generator['gen_c'].feature_dim, num_classes=args.tot_class),
    'cls_s': FC(z_dim=generator['gen_s'].feature_dim, num_classes=num_style_classes),
    # One-vs-all open-set head: two logits (inlier / outlier) per known class.
    'cls_c_o': FC(z_dim=generator['gen_c'].feature_dim, num_classes=2*args.tot_class),
}).to(device)

# The decoder receives content and style features together, hence twice the feature size.
reconstructor = nn.ModuleDict({
    'dec': Data_Decoder_CIFAR(z_dim=generator['gen_c'].feature_dim * 2),
}).to(device)


###### load parameters #######
# To resume from a checkpoint, load the saved state dicts here, e.g.:
# generator.load_state_dict(torch.load(os.path.join(args.output, 'best_model_generator_modify_adv2022-05-25_07-15-23.pth')))
# classifier.load_state_dict(torch.load(os.path.join(args.output, 'best_model_classifier_modify_adv2022-05-25_07-15-25.pth')))
# reconstructor.load_state_dict(torch.load(os.path.join(args.output, 'best_model_reconstructor_modify_adv2022-05-25_07-15-25.pth')))

###### optimizer #######
# One Nesterov-SGD optimizer (lr 3e-2, momentum 0.9) per sub-network. Each is
# paired with a cosine schedule with zero warm-up steps over
# shared_cfg["iteration"] steps.
_trainable_modules = {
    'gen_c': generator['gen_c'],
    'gen_s': generator['gen_s'],
    'cls_c': classifier['cls_c'],
    'cls_s': classifier['cls_s'],
    'cls_c_o': classifier['cls_c_o'],
    'dec': reconstructor["dec"],
}
optims = {
    name: optim.SGD(module.parameters(), lr=3e-2, momentum=0.9, nesterov=True)
    for name, module in _trainable_modules.items()
}
scheduler = {
    name: get_cosine_schedule_with_warmup(opt, 0, shared_cfg["iteration"])
    for name, opt in optims.items()
}

# Parameter count of the content encoder only.
trainable_parameters = sum(p.data.nelement() for p in generator["gen_c"].parameters())
print("trainable parameters : {}".format(trainable_parameters))

print()
iteration, maximum_val_acc = 0, 0
# Lists of per-batch adversarial datasets; replaced by ConcatDatasets once collection finishes.
benign_dataset, malign_dataset = [], []

# Placeholder losses so the progress line can be printed before every term has been computed.
(content_aug_loss, content_neg_loss, content_opencls_loss, content_cls_loss,
 style_cls_loss, open_consistency_loss, rec_loss) = (torch.zeros(1).to(device) for _ in range(7))

# Number of adversarial batches collected so far.
k = 0
def _next_adv_batch(batch_iter, loader):
    """Return ``(batch, iterator)``, restarting ``loader`` once ``batch_iter`` is exhausted."""
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        batch_iter = iter(loader)
        return next(batch_iter), batch_iter


ss = time.time()
# NOTE(review): the adversarial-collection phase consumes args.aug_number - 1
# extra batches from l_loader/u_loader. Depending on the loader lengths, the
# loop may stop before iteration reaches args.iterations -- confirm in build_dataset.
for l_data, u_data in zip(l_loader, u_loader):
    iteration += 1
    l_inputs_list, l_targets = l_data
    l_targets = l_targets.to(device).long()

    # Unlabeled targets are overwritten with -1, the ignore_index used by the losses.
    u_inputs_list, dummy_targets = u_data
    dummy_targets = dummy_targets.to(device).long()
    dummy_targets[:] = -1

    targets = torch.cat([l_targets, dummy_targets], 0)
    unlabeled_mask = (targets == -1)

    collecting_adv_batch = False
    reset_grad()
    # View 0 is the non-augmented batch; the view index doubles as the style label.
    for i in range(args.n_augs + 1):
        inputs_i = torch.cat([l_inputs_list[i], u_inputs_list[i]], 0).to(device).float()
        aug_idx = torch.full_like(targets, i)
        if i == 0:
            if iteration == args.aug_iter:
                # Adversarial-collection phase: each pass turns one batch into
                # benign/malign examples and is not counted as an iteration,
                # until args.aug_number batches have been collected.
                if k == 0:
                    benign_dataset = []
                    malign_dataset = []
                benign_x, malign_x, aug_label, benign_loss, malign_loss = data_augmentation(
                    [generator['gen_c'], generator['gen_s']],
                    [classifier['cls_c'], classifier['cls_s']],
                    inputs_i, [targets, aug_idx])
                # Drop any model gradients left by the adversarial search so
                # they cannot leak into the optimizer steps below.
                reset_grad()
                benign_x = rec_transform(benign_x)
                malign_x = rec_transform(malign_x)
                print("%d-th Max-phase:, benign loss: %g, malign loss: %g" % (k, benign_loss, malign_loss), "\r", end="")
                benign_dataset.append(AdvDataset(benign_x, aug_label))
                malign_dataset.append(AdvDataset(malign_x, aug_label))
                k += 1
                if k == args.aug_number:
                    n_adv_samples = (args.iterations - iteration) * args.batch_size
                    benign_dataset = torch.utils.data.ConcatDataset(benign_dataset)
                    benign_loader = torch.utils.data.DataLoader(benign_dataset, sampler=RandomSampler(len(benign_dataset), n_adv_samples), batch_size=args.batch_size, num_workers=args.num_workers, drop_last=True)
                    malign_dataset = torch.utils.data.ConcatDataset(malign_dataset)
                    malign_loader = torch.utils.data.DataLoader(malign_dataset, sampler=RandomSampler(len(malign_dataset), n_adv_samples), batch_size=args.batch_size, num_workers=args.num_workers, drop_last=True)
                    benign_iter = iter(benign_loader)
                    malign_iter = iter(malign_loader)
                    k = 0
                else:
                    iteration -= 1
                    collecting_adv_batch = True
                    break

            # Pseudo-label confident unlabeled samples that the open-set head deems "known".
            with torch.no_grad():
                content = generator['gen_c'](inputs_i)
                content_pred = classifier["cls_c"](content)
                ood_pred = classifier["cls_c_o"](content).view(inputs_i.size(0), 2, -1).softmax(1)

            pseudo_labels = torch.softmax(content_pred, dim=-1)
            max_probs, max_index = torch.max(pseudo_labels, dim=-1)
            target_range = torch.arange(targets.size(0), device=device)
            # Inlier probability of each sample for its predicted class.
            ood_score = ood_pred[:, 0, :][target_range, max_index]
            known_mask = (ood_score > .5)
            unknown_mask = (ood_score < .5)
            mask = max_probs.ge(args.ps_th) * unlabeled_mask * known_mask
            targets[mask] = max_index[mask]
            aug_idx[unknown_mask] = -1
        content_opencls_loss = optimize_content_openclassification(inputs_i, targets, required_mask=True, mask=unlabeled_mask.float())
        content, content_cls_loss = optimize_content_classification(inputs_i, targets, require_known_mask=True, known_mask=known_mask.float(), require_feature=True)
        style, style_cls_loss = optimize_style_classification(inputs_i, aug_idx, require_known_mask=True, known_mask=known_mask.float(), require_feature=True)

        # reconstruction
        rec_loss = optimize_reconstruction(inputs_i, content, style)

    if collecting_adv_batch:
        # This pass was un-counted above: don't advance the LR schedules, log or validate.
        continue

    if iteration >= args.aug_iter:
        (benign_inputs, benign_label), benign_iter = _next_adv_batch(benign_iter, benign_loader)
        (malign_inputs, malign_label), malign_iter = _next_adv_batch(malign_iter, malign_loader)
        benign_inputs, benign_label = benign_inputs.to(device), benign_label.to(device).long()
        malign_inputs, malign_label = malign_inputs.to(device), malign_label.to(device).long()
        # Unlabeled adversarial samples carry label -1 and are skipped via
        # ignore_index. The current batch's known_mask describes other samples
        # and may have another length, so it must not weight this loss.
        content_aug_loss = optimize_content_classification(benign_inputs, benign_label, disentangle=False)
        content_neg_loss = optimize_content_openclassification(malign_inputs, malign_label, negative=True)

    group_scheduler_step()

    # display
    if iteration == 1 or (iteration % 100) == 0:
        used_time = time.time() - ss
        rest = (args.iterations - iteration)/ iteration * (used_time / 60)
        print("[{}/{}] con:{:.3e}, sty:{:.3e}, rec:{:.3e}, c_o:{:.3e}, neg:{:.3e}, aug:{:.3e}, rst:{:.2f}min, lr:{:.3e}".format(
            iteration, args.iterations, content_cls_loss.item(), style_cls_loss.item(), rec_loss.item(), content_opencls_loss.item(), content_neg_loss.item(), content_aug_loss.item(), rest, optims["gen_c"].param_groups[0]["lr"]), "\r", end="")

    # validation
    if (iteration % args.validation) == 0 or iteration == args.iterations:
        with torch.no_grad():
            generator['gen_c'].eval()
            classifier["cls_c"].eval()
            classifier["cls_c_o"].eval()
            print()
            print("### validation ###")
            sum_acc = 0.
            s = time.time()
            for j, data in enumerate(val_loader):
                inputs, targets = data
                inputs, targets = inputs.to(device).float(), targets.to(device).long()
                output = classifier["cls_c"](generator['gen_c'](inputs))
                pred_label = output.max(1)[1]
                sum_acc += (pred_label == targets).float().sum()
                if ((j+1) % 10) == 0:
                    d_p_s = 10/(time.time()-s)
                    print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format(
                        j+1, len(val_loader), d_p_s, (len(val_loader) - j-1)/d_p_s
                    ), "\r", end="")
                    s = time.time()
            acc = sum_acc/args.n_valid
            print()
            print("validation accuracy : {}".format(acc))
            condition["val_acc"].append(acc.item())
            maximum_val_acc = max(maximum_val_acc, acc.item())
            # test
            print("### test ###")
            sum_acc = 0.
            s = time.time()
            unk_scores, all_labels = [], []
            for j, data in enumerate(test_loader):
                inputs, targets = data
                inputs, targets = inputs.to(device).float(), targets.to(device).long()
                feature = generator['gen_c'](inputs)
                output = F.softmax(classifier["cls_c"](feature), 1)
                output_open = F.softmax(classifier["cls_c_o"](feature).view(targets.size(0), 2, -1), 1)
                pred_range = torch.arange(output_open.size(0), device=device)
                pred_label = output.max(1)[1]
                # our unk_score is different from open match: outlier prob of the predicted class.
                unk_score = output_open[pred_range, 1, pred_label]
                # Collapse every unknown class into the single label tot_class.
                targets[targets >= args.tot_class] = args.tot_class
                kn_idxs = targets < args.tot_class
                sum_acc += (pred_label[kn_idxs] == targets[kn_idxs]).float().sum()
                unk_scores.append(unk_score)
                all_labels.append(targets)
                if ((j+1) % 10) == 0:
                    d_p_s = 10/(time.time()-s)
                    print("[{}/{}] time : {:.1f} data/sec, rest : {:.2f} sec".format(
                        j+1, len(test_loader), d_p_s, (len(test_loader) - j-1)/d_p_s
                    ), "\r", end="")
                    s = time.time()
            print()
            test_acc = sum_acc / args.n_test
            condition["test_acc"].append(test_acc.item())
            print("test accuracy : {}".format(test_acc))
            unk_all = torch.cat(unk_scores, 0).cpu().numpy()
            label_all = torch.cat(all_labels, 0).cpu().numpy()
            roc = compute_roc(unk_all, label_all, num_known=args.tot_class)
            condition["test_roc"].append(roc.item())
        generator['gen_c'].train()
        classifier["cls_c"].train()
        classifier["cls_c_o"].train()

# Last recorded test accuracy. It is None when no validation/test pass ran;
# referencing `test_acc` directly raised NameError in that case.
final_test_acc = condition["test_acc"][-1] if condition["test_acc"] else None
print("test acc : {}".format(final_test_acc))

exp_name += str(time.strftime('%Y-%m-%d_%H-%M-%S', time.localtime(time.time()))) # unique ID
condition["total_time"] = str((time.time() - ss) / 3600) + 'h'
# Store the parsed arguments under condition["argparse"] and write ONE JSON
# document. Two back-to-back json.dump calls produced a file json.load cannot parse.
condition["argparse"] = dict(vars(args))
os.makedirs(args.output, exist_ok=True)
with open(os.path.join(args.output, exp_name + ".json"), "w", encoding="utf-8") as f:
    json.dump(condition, f, indent=2)
