## This code is adapted from https://github.com/kekmodel/FixMatch-pytorch

import argparse
import logging
import math
import os
import random
import shutil
import time
from collections import OrderedDict
import sys
import copy

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast

from dataset.cifar import DATASET_GETTERS
from utils import AverageMeter, accuracy, ECELoss


# Module-level logger; its handlers/level are configured by logging.basicConfig in main().
logger = logging.getLogger(__name__)
# Best test accuracies observed so far. main() and train() rebind these through
# `global` (main() restores best_acc from a checkpoint when resuming).
best_acc = 0
best_vote_acc = 0


class ThresholdScheduler:
    """Per-iteration threshold schedule: linear warmup followed by a constant.

    The schedule ramps linearly from ``init_thres`` to ``final_thres`` over
    ``warmup_epochs * steps_per_epoch`` iterations and then stays at
    ``final_thres`` for the remaining ``total_steps`` iterations.

    Args:
        warmup_epochs: number of epochs spent in the linear ramp.
        init_thres: value returned on the first call to ``get_threshold``.
        final_thres: value reached at the end of warmup and held afterwards.
        total_steps: total number of iterations the schedule should cover.
        steps_per_epoch: iterations per epoch, used to convert epochs to steps.
    """

    def __init__(self, warmup_epochs, init_thres, final_thres, total_steps, steps_per_epoch):
        warmup_iter = steps_per_epoch * warmup_epochs
        warmup_schedule = np.linspace(init_thres, final_thres, warmup_iter)
        # A warmup longer than the whole run would give a negative length here,
        # which numpy rejects; in that case there is simply no constant phase.
        decay_iter = max(total_steps - warmup_iter, 0)
        constant_schedule = np.ones(decay_iter) * final_thres
        self.thres_schedule = np.concatenate((warmup_schedule, constant_schedule))
        # Kept so that calls past the end of the schedule can keep returning it.
        self.final_thres = np.float64(final_thres)
        self.iter = int(-1)

    def get_threshold(self):
        """Advance the schedule by one step and return the current threshold.

        Once the precomputed schedule is exhausted the final threshold is
        returned instead of raising ``IndexError``.
        """
        self.iter += 1
        if self.iter >= len(self.thres_schedule):
            return self.final_thres
        return self.thres_schedule[self.iter]

def save_checkpoint(state, is_best, checkpoint, filename='checkpoint.pth.tar'):
    """Serialize ``state`` into directory ``checkpoint`` under ``filename``.

    When ``is_best`` is truthy the freshly written file is additionally copied
    to ``model_best.pth.tar`` in the same directory.
    """
    target = os.path.join(checkpoint, filename)
    torch.save(state, target)
    if not is_best:
        return
    best_copy = os.path.join(checkpoint, 'model_best.pth.tar')
    shutil.copyfile(target, best_copy)

def save_checkpoint_epoch(state, checkpoint, epoch):
    """Write ``state`` to ``<checkpoint>/ep<epoch>.pth.tar``."""
    per_epoch_name = ''.join(('ep', str(epoch), '.pth.tar'))
    torch.save(state, os.path.join(checkpoint, per_epoch_name))

def set_seed(args):
    """Seed Python, NumPy and PyTorch RNGs from ``args.seed``.

    CUDA generators are seeded as well when ``args.n_gpu`` is positive.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)


def get_cosine_schedule_with_warmup(optimizer,
                                    num_warmup_steps,
                                    num_training_steps,
                                    num_cycles=7./16.,
                                    last_epoch=-1):
    """Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    The learning-rate multiplier rises linearly from 0 to 1 during the first
    ``num_warmup_steps`` steps and afterwards follows
    ``cos(pi * num_cycles * progress)`` (clipped at 0), where ``progress`` is
    the fraction of post-warmup steps completed.
    """
    # Denominators are fixed for the lifetime of the scheduler; both are
    # floored at 1 to avoid division by zero.
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def _lr_lambda(current_step):
        still_warming_up = current_step < num_warmup_steps
        if still_warming_up:
            return float(current_step) / warmup_span
        progress = float(current_step - num_warmup_steps) / decay_span
        cosine = math.cos(math.pi * num_cycles * progress)
        return max(0., cosine)

    return LambdaLR(optimizer, _lr_lambda, last_epoch)

def main():
    """Parse command-line options, build all training objects and run training.

    Steps, in order: argument parsing, device/logging setup, optional seeding,
    dataset and DataLoader construction, model creation (optionally loading
    pretrained weights), optimizer/scheduler/EMA setup, optional checkpoint
    resume, and finally a call to ``train``.

    Raises:
        ValueError: if ``--local_rank`` is not -1 (only single-GPU training is
            supported) or if an unsupported ``--arch`` is requested.
    """
    parser = argparse.ArgumentParser(description='PyTorch FixMatch Training')
    parser.add_argument('--gpu_id', default='0', type=int,
                        help='id(s) for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of workers')
    parser.add_argument('--dataset', default='cifar100', type=str,
                        choices=['cifar10', 'cifar100'],
                        help='dataset name')
    parser.add_argument('--num_labeled', type=int, default=4000,
                        help='number of labeled data')
    parser.add_argument("--expand_labels", action="store_true",
                        help="expand labels to fit eval steps")
    # The help text previously read 'dataset name' (copy-paste slip).
    parser.add_argument('--arch', default='wideresnet', type=str,
                        choices=['wideresnet', 'resnext','resnet18'],
                        help='model architecture')
    parser.add_argument("--non_lin_clas", action="store_true")
    parser.add_argument('--num_epochs', default=1024, type=int)
    parser.add_argument("--clas_bnorm", action="store_true")
    parser.add_argument('--depth', type=int, default=2,
                        help='depth of non linear classifier')
    parser.add_argument('--eval_step', default=1024, type=int,
                        help='number of eval steps to run')
    parser.add_argument('--start_epoch', default=0, type=int,
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--batch_size', default=64, type=int,
                        help='train batchsize')
    parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
                        help='initial learning rate')
    parser.add_argument('--warmup', default=0, type=float,
                        help='warmup epochs (unlabeled data based)')
    parser.add_argument('--wdecay', default=5e-4, type=float,
                        help='weight decay')
    parser.add_argument('--nesterov', action='store_true', default=True,
                        help='use nesterov momentum')
    parser.add_argument('--use_ema', action='store_true',default=True,
                        help='use EMA model')
    parser.add_argument('--ema_decay', default=0.999, type=float,
                        help='EMA decay rate')
    parser.add_argument('--mu', default=7, type=int,
                        help='coefficient of unlabeled batch size')
    parser.add_argument('--lambda_u', default=1, type=float,
                        help='coefficient of unlabeled loss')
    parser.add_argument('--T', default=1, type=float,
                        help='pseudo label temperature')
    parser.add_argument('--threshold', default=0.95, type=float,
                        help='pseudo label threshold')
    parser.add_argument('--out', default='result',
                        help='directory to output the result')
    parser.add_argument('--resume', default='', type=str,
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--seed', default=None, type=int,
                        help="random seed")
    parser.add_argument("--amp", action="store_true",
                        help="use 16-bit (mixed) precision through NVIDIA apex AMP")
    parser.add_argument("--opt_level", type=str, default="O1",
                        help="apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--no_progress', action='store_true',
                        help="don't use progress bar")
    parser.add_argument("--pretrained_model", action="store_true")
    parser.add_argument("--pt_path", default='', type=str)

    parser.add_argument('--dropout', default=0., type=float)

    parser.add_argument('--hidden_dim', default=512, type=int)
    parser.add_argument('--feat_dim', default=256, type=int)
    parser.add_argument('--proj_depth', default=1, type=int)
    parser.add_argument("--user", type=str, default="user")
    parser.add_argument("--use_same_idx", action="store_true", default = True)

    parser.add_argument('--temp', default=0.5, type=float)
    parser.add_argument('--bayes',type=str,choices=['none','vote','avg'],default='none')
    parser.add_argument('--optim_separate',action='store_true')
    parser.add_argument('--bayes_sch',action='store_true')

    parser.add_argument('--bayes_samples',type=int,default=50)
    parser.add_argument('--kl',type=float,default=1.)
    parser.add_argument('--bayes_lr',type=float,default=1.)
    parser.add_argument('--std_threshold', default=0.02, type=float,
                        help='pseudo label threshold in std of bayes output')

    parser.add_argument('--prior_mu',type=float,default=0)
    parser.add_argument('--prior_sig',type=float,default=1)

    parser.add_argument('--vote_threshold', default=0.95, type=float)
    parser.add_argument('--quantile', default=-1, type=float)
    parser.add_argument('--q_warmup', default=-1, type=int)
    parser.add_argument('--q_queue', action='store_true')
    parser.add_argument('--save_ep', action='store_true', default=True)
    parser.add_argument('--flipout', action='store_true')
    parser.add_argument('--reparam', action='store_true')
    parser.add_argument('--save_buffer_sd', action='store_true')
    parser.add_argument('--quan_scheduler', action='store_true')

    parser.add_argument('--quansch_warmup',type=int, default=10)
    parser.add_argument('--init_quan',type=float, default=0.1)
    parser.add_argument('--final_quan',type=float, default=0.9)


    args = parser.parse_args()
    global best_acc
    global best_vote_acc

    args.total_steps = args.num_epochs * args.eval_step

    args.path_to_npy = './fm_npy/'

    def create_model(args):
        """Instantiate the (Bayesian or plain) WideResNet described by ``args``.

        Returns ``(model, proj)``; ``proj`` is always ``None`` here.
        """
        if args.arch == 'wideresnet':
            import models.wideresnet_ssl as models
            if args.bayes != 'none':
                model = models.BayesCEWideResNet(num_classes=args.num_classes,
                                                depth=args.model_depth,
                                                widen_factor=args.model_width,
                                                dropout=0.,
                                                prior_mu=args.prior_mu,
                                                prior_sigma=args.prior_sig,
                                                flipout=args.flipout,
                                                reparam = args.reparam,
                                                save_buffer_sd=args.save_buffer_sd
                                                )
            else:
                model = models.SupCEWideResNet(num_classes=args.num_classes,
                                            depth=args.model_depth,
                                            widen_factor=args.model_width,
                                            dropout=0.,
                                            )
            proj = None

            if args.pretrained_model:
                model_dict = model.state_dict()
                pt_path = args.pt_path
                ckpt = torch.load(pt_path, map_location='cpu')
                state_dict = ckpt['model']
                new_state_dict = {}
                for k, v in state_dict.items():
                    # Strip the DataParallel/DDP prefix so keys line up.
                    k = k.replace("module.", "")
                    if k not in model_dict: continue # only update necessary keys
                    new_state_dict[k] = v
                model_dict.update(new_state_dict)
                model.load_state_dict(model_dict)

                print("loaded pretrained model from ", pt_path)

        else:
            raise ValueError("only wideresnet allowed!")

        print("Total params: {:.2f}M".format(
            sum(p.numel() for p in model.parameters())/1e6))
        return model, proj

    if args.local_rank == -1:
        device = torch.device('cuda', args.gpu_id)
        args.world_size = 1
        args.n_gpu = torch.cuda.device_count()
    else:
        # Distributed training is intentionally disabled. Raising a plain
        # string is itself a TypeError in Python 3, so raise a real exception
        # carrying the intended message.
        raise ValueError("only single gpu training")

    args.device = device
    print(args.n_gpu)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    logger.warning(
        f"Process rank: {args.local_rank}, "
        f"device: {args.device}, "
        f"n_gpu: {args.n_gpu}, "
        f"distributed training: {bool(args.local_rank != -1)}, "
        f"16-bits training: True",)

    print(dict(args._get_kwargs()))

    if args.seed is not None:
        set_seed(args)

    if args.local_rank in [-1, 0]:
        os.makedirs(args.out, exist_ok=True)
        args.writer = SummaryWriter(args.out)

    # Dataset-dependent model hyper-parameters read later by create_model().
    if args.dataset == 'cifar10':
        args.num_classes = 10
        if args.arch == 'wideresnet':
            args.model_depth = 28
            args.model_width = 2
        elif args.arch == 'resnext':
            args.model_cardinality = 4
            args.model_depth = 28
            args.model_width = 4

    elif args.dataset == 'cifar100':
        args.num_classes = 100
        if args.arch == 'wideresnet':
            args.model_depth = 28
            args.model_width = 8
        elif args.arch == 'resnext':
            args.model_cardinality = 8
            args.model_depth = 29
            args.model_width = 64

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    labeled_dataset, unlabeled_dataset, test_dataset, labeled_idx = DATASET_GETTERS[args.dataset](
        args, './data',trans='fixmatch')
    print(labeled_idx)
    # Persist the labeled indices next to the run outputs.
    np.save(args.out + '/idx.npy',labeled_idx)
    if args.local_rank == 0:
        torch.distributed.barrier()

    train_sampler = RandomSampler if args.local_rank == -1 else DistributedSampler

    labeled_trainloader = DataLoader(
        labeled_dataset,
        sampler=train_sampler(labeled_dataset),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        drop_last=True)

    # The unlabeled batch is mu times larger than the labeled one.
    unlabeled_trainloader = DataLoader(
        unlabeled_dataset,
        sampler=train_sampler(unlabeled_dataset),
        batch_size=args.batch_size*args.mu,
        num_workers=args.num_workers,
        drop_last=True)

    test_loader = DataLoader(
        test_dataset,
        sampler=SequentialSampler(test_dataset),
        batch_size=args.batch_size,
        num_workers=args.num_workers)

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    model,proj = create_model(args)

    if args.local_rank == 0:
        torch.distributed.barrier()

    model.to(args.device)
    if proj is not None:
        proj.to(args.device)

    # Parameters whose name contains one of these substrings get no weight decay.
    no_decay = ['bias', 'bn']

    if args.bayes != 'none' and args.optim_separate:
        # The Bayesian head (model.fc) gets its own optimizer below, so SGD
        # only sees the encoder parameters.
        grouped_parameters = [
            {'params': [p for n, p in model.encoder.named_parameters() if not any(
                nd in n for nd in no_decay)], 'weight_decay': args.wdecay},
            {'params': [p for n, p in model.encoder.named_parameters() if any(
                nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    else:
        grouped_parameters = [
            {'params': [p for n, p in model.named_parameters() if not any(
                nd in n for nd in no_decay)], 'weight_decay': args.wdecay},
            {'params': [p for n, p in model.named_parameters() if any(
                nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

    if proj is not None:
        grouped_parameters += [{'params':proj.parameters(), 'weight_decay':args.wdecay}]

    optimizer = optim.SGD(grouped_parameters,
                            lr=args.lr,
                            momentum=0.9, nesterov=args.nesterov)

    if args.bayes != 'none' and args.optim_separate:
        bayes_optimizer = optim.Adam(model.fc.parameters(), lr=args.bayes_lr)
        if args.bayes_sch:
            bayes_scheduler = get_cosine_schedule_with_warmup(
                bayes_optimizer, args.warmup, args.total_steps)
        else:
            bayes_scheduler = None
    else:
        bayes_optimizer = None
        bayes_scheduler = None
    args.epochs = math.ceil(args.total_steps / args.eval_step)

    scheduler = get_cosine_schedule_with_warmup(
            optimizer, args.warmup, args.total_steps)

    if args.use_ema:
        from models.ema import ModelEMA
        ema_model = ModelEMA(args, model, args.ema_decay, args.device)
    else:
        ema_model = None

    if args.quan_scheduler:
        quantile_sch = ThresholdScheduler(args.quansch_warmup,args.init_quan,args.final_quan,args.total_steps,args.eval_step)
    else:
        quantile_sch = None
    # NOTE: this overrides any --start_epoch given on the command line; the
    # start epoch only becomes non-zero when resuming from a checkpoint.
    args.start_epoch = 0

    quan_queue = []

    if args.resume:
        print("==> Resuming from checkpoint..")
        assert os.path.isfile(
            args.resume), "Error: no checkpoint directory found!"
        args.out = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        if args.use_ema:
            ema_model.ema.load_state_dict(checkpoint['ema_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        if bayes_optimizer is not None:
            bayes_optimizer.load_state_dict(checkpoint['bayes_optimizer'])
        if bayes_scheduler is not None:
            bayes_scheduler.load_state_dict(checkpoint['bayes_scheduler'])
        quan_queue = checkpoint['quan_queue']

        # Fast-forward the quantile schedule by one step per training
        # iteration that was already completed before the checkpoint.
        if quantile_sch is not None:
            for _ in range(args.start_epoch * args.eval_step):
                args.quantile = quantile_sch.get_threshold()

    scaler = GradScaler()

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank],
            output_device=args.local_rank, find_unused_parameters=True)

    print("***** Running training *****")
    print(f"  Task = {args.dataset}@{args.num_labeled}")
    print(f"  Num Epochs = {args.epochs}")
    print(f"  Batch size per GPU = {args.batch_size}")
    print(
        f"  Total train batch size = {args.batch_size*args.world_size}")
    print(f"  Total optimization steps = {args.total_steps}")

    model.zero_grad()
    train(args, labeled_trainloader, unlabeled_trainloader, test_loader,
          model, optimizer, ema_model, scheduler, scaler=scaler, proj=proj,
          bayes_optimizer=bayes_optimizer, bayes_scheduler=bayes_scheduler,
          quantile_sch=quantile_sch, quan_queue=quan_queue)

def bayes_predict(args, bayeslayer, reps):
    '''Monte-Carlo estimate of the predictive distribution of ``bayeslayer``.

    ``bayeslayer(reps)`` is evaluated ``args.bayes_samples`` times; the first
    element of each result is soft-maxed over the last dimension. Returns the
    element-wise mean and standard deviation across those samples. No
    gradients are recorded.
    '''
    with torch.no_grad():
        sampled_probs = []
        with autocast():
            for _ in range(args.bayes_samples):
                logits = bayeslayer(reps)[0]
                sampled_probs.append(logits.softmax(dim=-1))
        stacked = torch.stack(sampled_probs)
        prob_mean = stacked.mean(dim=0)
        prob_std = stacked.std(dim=0)
    return prob_mean, prob_std

def bayes_vote(args, bayeslayer, reps):
    '''Majority-vote prediction over ``args.bayes_samples`` stochastic passes.

    ``bayeslayer(reps)`` is evaluated ``args.bayes_samples`` times; each pass
    votes for the arg-max class of its soft-maxed first output.

    Returns:
        modevote_pred: the most frequent class per input (``Tensor.mode``).
        agreement: fraction of passes (in [0, 1]) that voted for that class.
    '''
    with torch.no_grad():
        with autocast():
            outputs = [bayeslayer(reps)[0].softmax(dim=-1) for _ in range(args.bayes_samples)]
        outputs = torch.stack(outputs)
        _, preds = torch.max(outputs,dim=-1)
        modevote_pred = preds.mode(0)[0] ## Todo: need tie-breaker?
        # Normalise by the number of sampled passes (dim 0 of ``preds``). The
        # previous code divided by the batch size, which made the ratio
        # meaningless whenever batch size != number of samples.
        agreement = (preds == modevote_pred).sum(0) / preds.shape[0]
    return modevote_pred, agreement

def bayes_vote_predict(args, bayeslayer, reps):
    '''Combined majority vote and mean/std statistics over stochastic passes.

    ``bayeslayer(reps)`` is evaluated ``args.bayes_samples`` times and the
    first output of each pass is soft-maxed over the last dimension.

    Returns:
        modevote_pred: the most frequent arg-max class per input.
        agreement: fraction of passes (in [0, 1]) that voted for that class.
        mean_output: element-wise mean of the sampled probabilities.
        std_output: element-wise standard deviation of the sampled probabilities.
    '''
    with torch.no_grad():
        with autocast():
            outputs = [bayeslayer(reps)[0].softmax(dim=-1) for _ in range(args.bayes_samples)]
        outputs = torch.stack(outputs)

        _, preds = torch.max(outputs,dim=-1)
        modevote_pred = preds.mode(0)[0] ## Todo: need tie-breaker?
        # Normalise by the number of sampled passes (dim 0 of ``preds``), not
        # by the batch size as the previous code did.
        agreement = (preds == modevote_pred).sum(0) / preds.shape[0]

        mean_output = torch.mean(outputs, 0)
        std_output = torch.std(outputs,0)

    return modevote_pred, agreement, mean_output, std_output

def train(args, labeled_trainloader, unlabeled_trainloader, test_loader,
          model, optimizer, ema_model, scheduler, scaler, proj=None,
          bayes_optimizer=None,bayes_scheduler=None,
          quantile_sch=None, quan_queue=None):
    """Run the semi-supervised training loop with per-epoch evaluation.

    Each epoch performs ``args.eval_step`` optimisation steps on a labeled
    batch plus a weakly/strongly augmented unlabeled batch, evaluates the
    (EMA) model on ``test_loader``, logs scalars to ``args.writer`` and saves
    a checkpoint into ``args.out``.

    Args:
        args: parsed options; ``args.quantile`` and ``args.std_threshold`` are
            updated in place during training.
        labeled_trainloader / unlabeled_trainloader / test_loader: data loaders.
        model: network exposing ``encoder`` and ``fc`` sub-modules.
        optimizer, scheduler: main optimizer and its LR scheduler.
        ema_model: EMA wrapper used when ``args.use_ema`` is set.
        scaler: ``GradScaler`` used for mixed-precision steps.
        proj: optional extra module that is put in train mode and zero-graded.
        bayes_optimizer, bayes_scheduler: optional separate optimizer/scheduler.
        quantile_sch: optional ``ThresholdScheduler`` driving ``args.quantile``.
        quan_queue: optional list holding recent quantile thresholds; it is
            mutated in place and stored in checkpoints. A fresh list is used
            when omitted.
    """
    global best_acc
    global best_vote_acc

    # Avoid a shared mutable default: each call without a queue gets its own.
    if quan_queue is None:
        quan_queue = []

    test_accs = []
    end = time.time()

    if args.world_size > 1:
        labeled_epoch = 0
        unlabeled_epoch = 0
        labeled_trainloader.sampler.set_epoch(labeled_epoch)
        unlabeled_trainloader.sampler.set_epoch(unlabeled_epoch)

    labeled_iter = iter(labeled_trainloader)
    unlabeled_iter = iter(unlabeled_trainloader)

    model.train()
    if proj is not None:
        proj.train()
    for epoch in range(args.start_epoch, args.epochs):
        stime = time.time()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        losses_x = AverageMeter()
        losses_u = AverageMeter()
        mask_probs = AverageMeter()
        pseudolab_acc = AverageMeter()
        pseudolab_acc_s = AverageMeter()

        impurity_rate = AverageMeter()
        ece_meter = AverageMeter()
        pmax_meter = AverageMeter()
        ece_s_meter = AverageMeter()
        pmax_s_meter = AverageMeter()

        if args.bayes != 'none':
            losses_kl = AverageMeter()

        ece_score = ECELoss()

        if not args.no_progress:
            p_bar = tqdm(range(args.eval_step),
                         disable=args.local_rank not in [-1, 0])
        for batch_idx in range(args.eval_step):
            # Loaders are cycled indefinitely: restart an iterator when it is
            # exhausted. Only StopIteration triggers a restart so that genuine
            # loader errors and KeyboardInterrupt propagate.
            try:
                inputs_x, targets_x = next(labeled_iter)
            except StopIteration:
                if args.world_size > 1:
                    labeled_epoch += 1
                    labeled_trainloader.sampler.set_epoch(labeled_epoch)
                labeled_iter = iter(labeled_trainloader)
                inputs_x, targets_x = next(labeled_iter)

            try:
                (inputs_u_w, inputs_u_s), targets_u_true = next(unlabeled_iter)
            except StopIteration:
                if args.world_size > 1:
                    unlabeled_epoch += 1
                    unlabeled_trainloader.sampler.set_epoch(unlabeled_epoch)
                unlabeled_iter = iter(unlabeled_trainloader)

                (inputs_u_w, inputs_u_s), targets_u_true = next(unlabeled_iter)


            data_time.update(time.time() - end)
            batch_size = inputs_x.shape[0]
            # One forward pass over [labeled | unlabeled-weak | unlabeled-strong].
            inputs = torch.cat((inputs_x, inputs_u_w, inputs_u_s)).to(args.device)

            targets_x = targets_x.to(args.device)
            targets_u_true = targets_u_true.to(args.device)

            if quantile_sch is not None:
                args.quantile = quantile_sch.get_threshold()


            with autocast():
                rep = model.encoder(inputs)
                if args.bayes != 'none':
                    logits, Lkl = model.fc(rep)
                else:
                    logits = model.fc(rep)


                logits_x = logits[:batch_size]
                logits_u_w, logits_u_s = logits[batch_size:].chunk(2)

                del logits

                Lx = F.cross_entropy(logits_x, targets_x, reduction='mean')
                if args.bayes != 'none':
                    Lx += Lkl / logits_x.shape[0] * args.kl

                if args.bayes != 'none': # sample many models to get prediction and uncertainty
                    rep_u = rep.detach()[batch_size:]
                    if args.bayes == 'vote':
                        rep_u_w = rep_u.chunk(2)[0]
                        # bayes_vote_predict also returns the mean predictive
                        # distribution, which the ECE / max-prob metrics below
                        # need; these were previously undefined in vote mode.
                        targets_u, certainty_u, mean_output_u_w, _ = bayes_vote_predict(
                            args, model.fc, rep_u_w) # softmax applied inside
                        max_probs_u_w, _ = torch.max(mean_output_u_w, dim=-1)

                        mask = certainty_u.ge(args.vote_threshold).float()
                    else:
                        mean_output_u, std_output_u = bayes_predict(args, model.fc, rep_u) # softmax applied inside
                        mean_output_u_w, mean_output_u_s = mean_output_u.chunk(2)
                        std_output_u_w, std_output_u_s = std_output_u.chunk(2)

                        max_probs_u_w, targets_u = torch.max(mean_output_u_w,dim=-1)
                        # Std of the predicted class only.
                        pred_std = torch.gather(std_output_u_w,1,targets_u.view(-1,1)).squeeze(1)

                        # Accept samples with std less than the current threshold;
                        # the threshold update below only affects later iterations.
                        mask = pred_std.le(args.std_threshold)

                        if args.quantile != -1 and epoch > args.q_warmup:
                            new_threshold = torch.quantile(pred_std, args.quantile).item()
                            if args.q_queue:
                                quan_queue.append(new_threshold)
                                if len(quan_queue)>50: quan_queue.pop(0) # maintain last 50 values
                                new_threshold = np.mean(quan_queue)
                            args.std_threshold = new_threshold

                        mask = mask.float()
                else:
                    pseudo_label = torch.softmax(logits_u_w.detach(), dim=-1)
                    max_probs, targets_u = torch.max(pseudo_label, dim=-1)
                    mask = max_probs.ge(args.threshold).float()
                pseudoacc = (targets_u == targets_u_true).float().mean()
                if args.bayes != 'none':
                    ece = ece_score(mean_output_u_w, targets_u_true)
                    pmax_mean = torch.mean(max_probs_u_w)
                else:
                    ece = ece_score(pseudo_label, targets_u_true)
                    pmax_mean = torch.mean(max_probs)
                # Diagnostics on the strongly augmented view.
                pred_s = torch.softmax(logits_u_s.detach(),dim=-1)
                pmax_s, tar_s = torch.max(pred_s,dim=-1)
                pseudoacc_s = (tar_s == targets_u_true).float().mean()
                ece_s = ece_score(pred_s, targets_u_true)
                pmax_s_mean = torch.mean(pmax_s)

                if mask.sum() > 0:
                    # Fraction of accepted pseudo-labels matching the true label.
                    impurity = ((targets_u == targets_u_true) * mask).sum() / mask.sum()
                    impurity_rate.update(impurity.item())

                Lu = (F.cross_entropy(logits_u_s, targets_u,
                                  reduction='none') * mask).mean()

                loss = Lx + args.lambda_u * Lu

            optimizer.zero_grad()
            if bayes_optimizer is not None:
                bayes_optimizer.zero_grad()

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            if bayes_optimizer is not None:
                scaler.step(bayes_optimizer)
            scaler.update()

            losses.update(loss.item())
            losses_x.update(Lx.item())
            losses_u.update(Lu.item())
            if args.bayes != 'none':
                losses_kl.update(Lkl.item())

            scheduler.step()
            if bayes_scheduler is not None:
                bayes_scheduler.step()

            with autocast():
                if args.use_ema:
                    ema_model.update(model)
            model.zero_grad()
            if proj is not None:
                proj.zero_grad()

            batch_time.update(time.time() - end)
            end = time.time()
            mask_probs.update(mask.mean().item())
            pseudolab_acc.update(pseudoacc.item())
            pseudolab_acc_s.update(pseudoacc_s.item())

            ece_meter.update(ece.item())
            pmax_meter.update(pmax_mean.item())
            ece_s_meter.update(ece_s.item())
            pmax_s_meter.update(pmax_s_mean.item())

            if not args.no_progress:
                p_bar.set_description("Train Ep:{epoch}. It:{batch:4}. LR:{lr:.3f}. Lx:{loss_x:.3f}. Lu:{loss_u:.4f}. Mask:{mask:.2f}. quan:{quan:.3f}. ece:{ece:.3f}. acc:{pseudoacc:.3f}. pmax:{maxprob:.3f}. s_ece:{ece_s:.3f}. s_acc:{pseudoacc_s:.3f}. s_pmax:{maxprob_s:.3f}".format(
                    epoch=epoch + 1,
                    batch=batch_idx + 1,
                    lr=scheduler.get_last_lr()[0],
                    loss_x=losses_x.avg,
                    loss_u=losses_u.avg,
                    mask=mask_probs.avg,
                    quan=args.quantile,
                    ece=ece_meter.avg,
                    pseudoacc= pseudolab_acc.avg,
                    maxprob= pmax_meter.avg,
                    ece_s= ece_s_meter.avg,
                    maxprob_s= pmax_s_meter.avg,
                    pseudoacc_s= pseudolab_acc_s.avg))

                p_bar.update()

        if not args.no_progress:
            p_bar.close()

        if args.use_ema:
            test_model = ema_model.ema
        else:
            # Evaluate a copy so test-time eval() does not affect the training model.
            test_model = copy.deepcopy(model)

        if args.local_rank in [-1, 0]:
            if args.bayes != 'none':
                test_loss, test_acc, test_vote_acc, test_ece = test_bayes(args, test_loader, test_model, epoch)
            else:
                test_loss, test_acc, test_ece = test(args, test_loader, test_model, epoch)
                test_vote_acc = 0.
            args.writer.add_scalar('train/1.train_loss', losses.avg, epoch)
            args.writer.add_scalar('train/2.train_loss_x', losses_x.avg, epoch)
            args.writer.add_scalar('train/3.train_loss_u', losses_u.avg, epoch)
            args.writer.add_scalar('train/4.mask', mask_probs.avg, epoch)
            args.writer.add_scalar('train/9.pseudo_acc', pseudolab_acc.avg, epoch)
            args.writer.add_scalar('train/10.impurity_rate', impurity_rate.avg, epoch)
            args.writer.add_scalar('train/11.learning_rate', scheduler.get_last_lr()[0], epoch)
            if args.bayes != 'none':
                args.writer.add_scalar('train/12.train_loss_kl', losses_kl.avg, epoch)
                if bayes_scheduler is not None:
                    args.writer.add_scalar('train/15.bayes_lr', bayes_scheduler.get_last_lr()[0], epoch)
                if args.quantile != -1:
                    args.writer.add_scalar('train/16.std_threshold', args.std_threshold, epoch)
            args.writer.add_scalar('train/17.ECE', ece_meter.avg, epoch)
            args.writer.add_scalar('train/18.max_prob_mean', pmax_meter.avg, epoch)
            args.writer.add_scalar('train/19.ECE_s', ece_s_meter.avg, epoch)
            args.writer.add_scalar('train/20.max_prob_s_mean', pmax_s_meter.avg, epoch)
            args.writer.add_scalar('train/21.pseudo_acc_s', pseudolab_acc_s.avg, epoch)

            args.writer.add_scalar('test/1.test_acc', test_acc, epoch)
            args.writer.add_scalar('test/2.test_loss', test_loss, epoch)
            args.writer.add_scalar('test/3.test_vote_acc', test_vote_acc, epoch)
            args.writer.add_scalar('test/4.test_ECE', test_ece, epoch)

            is_best = test_acc > best_acc
            best_acc = max(test_acc, best_acc)

            best_vote_acc = max(test_vote_acc,best_vote_acc)
            # Unwrap DataParallel/DDP containers before saving.
            model_to_save = model.module if hasattr(model, "module") else model
            if args.use_ema:
                ema_to_save = ema_model.ema.module if hasattr(
                    ema_model.ema, "module") else ema_model.ema
            save_dict = {
                'epoch': epoch + 1,
                'state_dict': model_to_save.state_dict(),
                'ema_state_dict': ema_to_save.state_dict() if args.use_ema else None,
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'quan_queue': quan_queue
            }
            if bayes_optimizer is not None:
                save_dict['bayes_optimizer'] = bayes_optimizer.state_dict()
            if bayes_scheduler is not None:
                save_dict['bayes_scheduler'] = bayes_scheduler.state_dict()

            save_checkpoint(save_dict, is_best, args.out)

            # Additionally keep a snapshot every 50 epochs.
            if args.save_ep and epoch%50 == 0:
                save_checkpoint_epoch(save_dict, args.out, epoch)
            test_accs.append(test_acc)
            print('Best top-1 acc:{:.2f} | Best top-1 vote acc: {:.2f}'.format(best_acc,best_vote_acc))
        line_to_print = (
                f'epoch: {epoch+1} | train_loss: {losses.avg:.3f} | '
                f'test_acc: {test_acc:.3f} | lr: {scheduler.get_last_lr()[0]:.6f}  | '
                f'mask: {mask_probs.avg:.3f} '
                f'pseudo_acc: {pseudolab_acc.avg:.3f} | impurity_rate: {impurity_rate.avg:.3f} | '
                f'time per epoch: {time.time()-stime}'
            )
        print(line_to_print)
        sys.stdout.flush()

    if args.local_rank in [-1, 0]:
        args.writer.close()

def test(args, test_loader, model, epoch):
    """Run one evaluation pass of ``model`` over ``test_loader``.

    Every batch is moved to ``args.device`` and forwarded under ``autocast``.
    The outputs are scored with cross-entropy, top-1/top-5 ``accuracy`` and
    ``ECELoss`` (the latter on the softmax of the outputs). Each statistic is
    accumulated in an ``AverageMeter`` that is updated with the batch size as
    its count, and the averages are printed once the loop is done.

    Args:
        args: Run configuration; ``no_progress``, ``local_rank`` and
            ``device`` are read here.
        test_loader: Iterable yielding ``(inputs, targets)`` batches.
        model: Network called as ``model(inputs)``; it is switched to eval
            mode and left in that mode.
        epoch: Not used by this function.

    Returns:
        Tuple ``(average loss, average top-1 accuracy, average ECE)``.
    """
    show_progress = not args.no_progress

    time_meter = AverageMeter()
    load_meter = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()
    tick = time.time()
    ece_meter = AverageMeter()

    if show_progress:
        # The bar is only rendered on the main process (local_rank -1 or 0).
        test_loader = tqdm(test_loader,
                           disable=args.local_rank not in [-1, 0])
    calibration = ECELoss()

    with torch.no_grad():
        for step, (inputs, targets) in enumerate(test_loader, start=1):
            # Time spent waiting for the loader to hand over this batch.
            load_meter.update(time.time() - tick)
            model.eval()

            inputs, targets = inputs.to(args.device), targets.to(args.device)
            with autocast():
                logits = model(inputs)
                loss = F.cross_entropy(logits, targets)

            ece = calibration(logits.softmax(dim=-1), targets)
            acc1, acc5 = accuracy(logits, targets, topk=(1, 5))

            n_samples = inputs.shape[0]
            batch_stats = (
                (loss_meter, loss),
                (acc1_meter, acc1),
                (acc5_meter, acc5),
                (ece_meter, ece),
            )
            for meter, value in batch_stats:
                meter.update(value.item(), n_samples)

            # Full wall-clock time of the iteration, data loading included.
            time_meter.update(time.time() - tick)
            tick = time.time()
            if show_progress:
                test_loader.set_description("Test Iter: {batch:4}/{iter:4}. Data: {data:.3f}s. Batch: {bt:.3f}s. Loss: {loss:.4f}. top1: {top1:.2f}. top5: {top5:.2f}. ".format(
                    batch=step,
                    iter=len(test_loader),
                    data=load_meter.avg,
                    bt=time_meter.avg,
                    loss=loss_meter.avg,
                    top1=acc1_meter.avg,
                    top5=acc5_meter.avg,
                ))
        if show_progress:
            test_loader.close()

    summary = (
        ("top-1 acc: {:.2f}", acc1_meter),
        ("top-5 acc: {:.2f}", acc5_meter),
        ("ECE: {:.2f}", ece_meter),
    )
    for template, meter in summary:
        print(template.format(meter.avg))

    return loss_meter.avg, acc1_meter.avg, ece_meter.avg


def test_bayes(args, test_loader, model, epoch):
    """Evaluate ``model`` using the predictions of ``bayes_vote_predict``.

    For every batch the encoder features are passed, together with the ``fc``
    head, to ``bayes_vote_predict``. Its ``mean_output`` drives the
    top-1/top-5 accuracy and the ECE, and its ``modevote`` is compared with
    the targets to obtain the vote accuracy. The cross-entropy loss comes from
    a separate regular forward pass ``model(inputs)``, which is expected to
    return a pair ``(outputs, kl)``.

    Args:
        args: Run configuration; ``no_progress``, ``local_rank`` and
            ``device`` are read here, and the object is forwarded to
            ``bayes_vote_predict``.
        test_loader: Iterable yielding ``(inputs, targets)`` batches.
        model: Network exposing ``encoder`` and ``fc`` submodules. It may be
            wrapped in a container that exposes the real network as
            ``.module`` (e.g. ``DataParallel`` / ``DistributedDataParallel``).
            It is switched to eval mode and left in that mode.
        epoch: Not used by this function.

    Returns:
        Tuple ``(average loss, average top-1 accuracy,
        average top-1 vote accuracy, average ECE)``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    top1vote = AverageMeter()
    end = time.time()
    ecemeter = AverageMeter()

    show_progress = not args.no_progress
    if show_progress:
        # The bar is only rendered on the main process (local_rank -1 or 0).
        test_loader = tqdm(test_loader,
                           disable=args.local_rank not in [-1, 0])
    ece_score = ECELoss()

    # Parallel wrappers do not forward attribute access to the wrapped
    # network, so `model.encoder` / `model.fc` would raise AttributeError on a
    # wrapped model. Unwrap the same way the checkpointing code does.
    net = model.module if hasattr(model, "module") else model

    # Eval mode does not change between batches; set it once up front.
    model.eval()
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            data_time.update(time.time() - end)

            inputs = inputs.to(args.device)
            targets = targets.to(args.device)
            with autocast():
                reps = net.encoder(inputs)
                # The second and fourth return values (named uncertainty and
                # std_output by the original code) are not needed here.
                modevote, _, mean_output, _ = bayes_vote_predict(args, net.fc, reps)
                # NOTE(review): unlike test(), no softmax is applied here;
                # presumably mean_output already holds probabilities - confirm.
                ece = ece_score(mean_output, targets)

                # Regular forward pass, used only for the reported loss.
                outputs, _ = model(inputs)
                loss = F.cross_entropy(outputs, targets)
                # Percentage of samples whose vote matches the target.
                voteprec1 = modevote.eq(targets).sum(0) / targets.size(0) * 100

            # Accuracy is measured on `mean_output` from bayes_vote_predict,
            # not on `outputs` from the regular forward pass.
            prec1, prec5 = accuracy(mean_output, targets, topk=(1, 5))

            batch_size = inputs.shape[0]
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)
            top1vote.update(voteprec1.item(), batch_size)
            ecemeter.update(ece.item(), batch_size)

            batch_time.update(time.time() - end)
            end = time.time()
            if show_progress:
                test_loader.set_description("Test Iter: {batch:4}/{iter:4}. Data: {data:.3f}s. Batch: {bt:.3f}s. Loss: {loss:.4f}. top1: {top1:.2f}. top5: {top5:.2f}. ".format(
                    batch=batch_idx + 1,
                    iter=len(test_loader),
                    data=data_time.avg,
                    bt=batch_time.avg,
                    loss=losses.avg,
                    top1=top1.avg,
                    top5=top5.avg,
                ))
        if show_progress:
            test_loader.close()

    print("top-1 acc:{:.2f}, top-1 vote acc:{:.2f}".format(top1.avg, top1vote.avg))
    print("top-5 acc:{:.2f}".format(top5.avg))
    print("ECE:{:.2f}".format(ecemeter.avg))

    return losses.avg, top1.avg, top1vote.avg, ecemeter.avg

# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
