# This code is adapted from https://github.com/facebookresearch/suncet

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import os
import argparse
from tqdm import tqdm
import json
# -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS
# -- WARNING: IF DOING DISTRIBUTED TRAINING ON A NON-SLURM CLUSTER, MAKE
# --          SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE
# --          THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE
# --          TO EACH PROCESS
# Under SLURM, expose only the GPU whose index matches this task's local id;
# outside SLURM (variable absent) the device visibility is left untouched.
if 'SLURM_LOCALID' in os.environ:
    os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['SLURM_LOCALID']

import logging
import sys
from collections import OrderedDict

import numpy as np

import torch

import src.resnet as resnet
from src.wideresnet_ssl import WideResNetProj, Online_Classifier

from src.utils import (
    gpu_timer,
    init_distributed,
    WarmupCosineSchedule,
    CSVLogger,
    AverageMeter,
	accuracy,
    ECELoss
)
from src.mod_losses import (
    init_paws_online_loss,
)
from src.data_manager import (
    init_data,
    make_transforms,
    make_multicrop_transform
)
from src.sgd import SGD
from src.lars import LARS
import torch.optim as optim

import apex
from torch.nn.parallel import DistributedDataParallel

from torch.utils.tensorboard import SummaryWriter
from mod_snn_eval import mod_init_pipe, make_embeddings_fix, evaluate_embeddings_fix
# -- module-level logging knobs
# NOTE(review): `log_timings` is not referenced in the visible part of this
# file — confirm whether it is still needed.
log_timings = True
# A progress line is logged every `log_freq` training iterations (see main).
log_freq = 10
# --

# Seed NumPy's and PyTorch's global RNGs once, at import time.
_GLOBAL_SEED = 0
np.random.seed(_GLOBAL_SEED)
torch.manual_seed(_GLOBAL_SEED)
# Enable cuDNN's benchmark mode (kernel auto-tuning).
# NOTE(review): benchmark mode may select different kernels between runs, so
# the fixed seed above presumably does not guarantee bit-exact reproducibility.
torch.backends.cudnn.benchmark = True

# Route INFO-and-above log records to stdout; `logger` is the root logger.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger()


def main(args):
    """Train the encoder with the online PAWS loss plus an online classifier.

    The function reads its whole configuration from ``args``, builds the
    data loaders, model, optimizer and schedulers, optionally resumes from a
    checkpoint, and then runs the epoch loop. Each iteration feeds one
    supervised (support) mini-batch and one unsupervised mini-batch through
    the encoder, optimises ``paws loss + me-max regulariser + classifier
    cross-entropy`` and, when enabled, updates an EMA or SWA teacher copy of
    the encoder. At the end of every epoch rank 0 writes checkpoints
    (latest / best / periodic), evaluates the online classifier on the test
    loader and, if ``args.eval_snn`` is set, runs the soft-nearest-neighbour
    evaluation.

    Side effects: creates the tensorboard directory, a per-rank CSV log, a
    JSON dump of ``args`` and checkpoint files under ``args.folder``; sets
    ``args.mom`` and ``args.num_swa``.

    Args:
        args: argparse-style namespace. Attributes read here include
            ``dataset``, ``folder``, ``tag``, ``device``, ``epochs``, ``lr``,
            ``use_ema``, ``use_swa``, ``monitor_stats`` and ``eval_snn``
            among others.

    Raises:
        ValueError: if ``args.dataset`` is neither ``'cifar100'`` nor
            ``'cifar10'``, or if ``args.use_ema_for_online`` is set while
            neither ``args.use_ema`` nor ``args.use_swa`` is enabled.
        AssertionError: if the training loss becomes NaN.
    """
    checkpoint_freq = args.checkpoint_freq

    # -- create tensorboard logging
    tb_dir = args.folder + '/' + args.tag + '_tb'
    os.makedirs(tb_dir, exist_ok=True)
    writer = SummaryWriter(tb_dir)
    eval_snn = args.eval_snn
    # ----------------------------------------------------------------------- #
    #  PASSED IN PARAMS FROM CONFIG FILE
    # ----------------------------------------------------------------------- #
    # -- META
    proj_dim = args.proj_dim
    load_model = args.load_checkpoint
    r_file = args.read_checkpoint
    copy_data = args.copy_data
    use_fp16 = args.use_fp16
    use_pred_head = False
    device = torch.device(args.device)
    torch.cuda.set_device(device)

    # A teacher network (EMA or SWA copy of the encoder) exists only when one
    # of these flags is set; fail early instead of hitting an unbound
    # teacher representation in the middle of the first training step.
    use_teacher = args.use_ema or args.use_swa
    if args.use_ema_for_online and not use_teacher:
        raise ValueError(
            'use_ema_for_online requires use_ema or use_swa to be enabled')

    # -- CRITERION
    reg = args.me_max
    if args.no_me_max:
        reg = False

    temperature = args.temperature
    sharpen = args.sharpen
    class_mlp = args.class_mlp
    class_depth = args.class_depth
    if args.dataset == 'cifar100':
        classes_per_batch = 100
        model_name = 'wide_resnet28w8'
        s_batch_size = args.sup_bs
        supervised_views = args.supervised_views
        u_batch_size = 256
        color_jitter = 0.5
        unique_classes = False
    elif args.dataset == 'cifar10':
        classes_per_batch = 10
        model_name = 'wide_resnet28w2'
        s_batch_size = 64
        supervised_views = 2
        u_batch_size = 256
        color_jitter = 0.5
        unique_classes = False
    else:
        raise ValueError("not implemented")

    # -- DATA
    unlabeled_frac = args.unlabeled_frac
    normalize = True
    root_path = 'datasets/'
    image_folder = args.dataset + '-data/'
    dataset_name = args.dataset
    subset_path = args.dataset + '_newsplits/'
    multicrop = args.multicrop
    label_smoothing = args.lab_smooth
    ft_dataset_name = args.dataset + '_fine_tune'

    data_seed = args.data_seed
    if 'cifar10' in dataset_name:
        crop_scale = (0.75, 1.0) if multicrop > 0 else (0.5, 1.0)
        mc_scale = (0.3, 0.75)
        mc_size = 18
    else:
        crop_scale = (0.14, 1.0) if multicrop > 0 else (0.08, 1.0)
        mc_scale = (0.05, 0.14)
        mc_size = 96

    # -- OPTIMIZATION
    wd = float(args.weight_decay)
    num_epochs = args.epochs
    warmup = args.warmup
    start_lr = args.lr/4
    lr = args.lr
    final_lr = args.lr / args.lr_decay_factor
    mom = 0.9
    nesterov = False

    head_factor = args.head_factor
    version = args.version

    # -- LOGGING
    folder = args.folder
    tag = args.tag

    if 'imagenet' in dataset_name:
        num_classes = 1000
    elif 'cifar100' in dataset_name:
        num_classes = 100
    else:
        num_classes = 10
    # ----------------------------------------------------------------------- #
    snn_pipe_kwargs = dict(
        subset_path=subset_path,
        root_path=root_path,
        image_folder=image_folder,
        unlabeled_frac=unlabeled_frac,
        dataset_name=ft_dataset_name,
        model_name=model_name,
        use_pred=use_pred_head,
        normalize=normalize,
        device_str=device,
        split_seed=data_seed)
    # The train-side loader is only needed for the SNN evaluation.
    snn_train_data_loader, snn_train_data_sampler = None, None
    if eval_snn:
        print("retrieving snn train set")
        snn_train_data_loader, snn_train_data_sampler = mod_init_pipe(
            True, **snn_pipe_kwargs)
    # The test loader is also consumed by the per-epoch online-classifier
    # evaluation, so it must exist even when eval_snn is disabled.
    print("retrieving snn test set")
    snn_test_data_loader, _ = mod_init_pipe(False, **snn_pipe_kwargs)

    # -- init torch distributed backend
    world_size, rank = init_distributed()
    logger.info(f'Initialized (rank/world-size) {rank}/{world_size}')

    # -- log/checkpointing paths
    log_file = os.path.join(folder, f'{tag}_r{rank}.csv')
    save_path = os.path.join(folder, f'{tag}' + '-ep{epoch}.pth.tar')
    latest_path = os.path.join(folder, f'{tag}-latest.pth.tar')
    best_path = os.path.join(folder, f'{tag}' + '-best.pth.tar')
    load_path = None
    if load_model:
        load_path = os.path.join(folder, r_file) if r_file is not None else latest_path

    # -- make csv_logger
    # One column per value passed to csv_logger.log() below (six in total).
    csv_logger = CSVLogger(log_file,
                           ('%d', 'epoch'),
                           ('%d', 'itr'),
                           ('%.5f', 'paws-xent-loss'),
                           ('%.5f', 'paws-me_max-reg'),
                           ('%.5f', 'online-ce-loss'),
                           ('%d', 'time (ms)'))
    # -- save args dict; do not clobber the dump of a previous run
    json_path = os.path.join(folder, f'{tag}.json')
    if os.path.exists(json_path):
        json_path = os.path.splitext(json_path)[0] + "_1.json"
    with open(json_path, "w") as dictfile:
        json.dump(vars(args), dictfile, sort_keys=True)

    # -- init model
    encoder, classifier, ema_encoder = init_model(
        device=device,
        model_name=model_name,
        use_pred=use_pred_head,
        output_dim=proj_dim,
        class_mlp=class_mlp,
        version=version,
        class_depth=class_depth,
        num_classes=num_classes,
        hidden_dim=args.hidden_dim,
        dropout=args.dropout)
    print("Total encoder params: {:.2f}M".format(
            sum(p.numel() for p in encoder.parameters())/1e6))
    print("Total classifier params: {:.2f}M".format(
                sum(p.numel() for p in classifier.parameters())/1e6))

    if world_size > 1:
        process_group = apex.parallel.create_syncbn_process_group(0)
        encoder = apex.parallel.convert_syncbn_model(encoder, process_group=process_group)

    online_paws = init_paws_online_loss(
        multicrop=multicrop,
        T=sharpen,
        me_max=reg,
        class_temp=args.class_temp)

    onclass = init_classifier_loss()

    # -- make data transforms
    transform, init_transform = make_transforms(
        dataset_name=dataset_name,
        subset_path=subset_path,
        unlabeled_frac=unlabeled_frac,
        training=True,
        split_seed=data_seed,
        crop_scale=crop_scale,
        basic_augmentations=args.use_basic,
        color_jitter=color_jitter,
        normalize=normalize)
    multicrop_transform = (multicrop, None)
    if multicrop > 0:
        multicrop_transform = make_multicrop_transform(
                dataset_name=dataset_name,
                num_crops=multicrop,
                size=mc_size,
                crop_scale=mc_scale,
                normalize=normalize,
                color_distortion=color_jitter)

    # -- init data-loaders/samplers
    print("retrieving supervised and unsupervised")

    (unsupervised_loader,
     unsupervised_sampler,
     supervised_loader,
     supervised_sampler) = init_data(
         dataset_name=dataset_name,
         transform=transform,
         init_transform=init_transform,
         supervised_views=supervised_views,
         u_batch_size=u_batch_size,
         s_batch_size=s_batch_size,
         unique_classes=unique_classes,
         classes_per_batch=classes_per_batch,
         multicrop_transform=multicrop_transform,
         world_size=world_size,
         rank=rank,
         root_path=root_path,
         image_folder=image_folder,
         training=True,
         copy_data=copy_data)
    # Iterator over the supervised loader; created lazily by load_imgs() and
    # re-created whenever it is exhausted.
    iter_supervised = None
    ipe = len(unsupervised_loader)
    logger.info(f'iterations per epoch: {ipe}')

    # -- init optimizer and scheduler
    scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)

    encoder, optimizer, scheduler, classifier = init_opt(
            encoder=encoder,
            weight_decay=wd,
            start_lr=start_lr,
            ref_lr=lr,
            final_lr=final_lr,
            ref_mom=mom,
            nesterov=nesterov,
            iterations_per_epoch=ipe,
            warmup=warmup,
            num_epochs=num_epochs,
            classifier=classifier,
            head_factor=head_factor,
            class_wd=args.class_wd,
            classifier_exclude_LARS=args.class_ex_lars)

    if world_size > 1:
        encoder = DistributedDataParallel(encoder, broadcast_buffers=False)

    if args.use_mom_scheduler:
        mom_scheduler = MomentumScheduler(
            init_mom=args.ema_decay,
            final_mom=args.final_mom,
            warmup_from=0,
            warmup_epochs=args.mom_warmup_epochs,
            total_epochs=num_epochs,
            iter_per_epoch=ipe)
    else:
        mom_scheduler = None
        args.mom = args.ema_decay

    start_epoch = 0
    num_swa = 0

    # -- load training checkpoint
    if load_model:
        encoder, optimizer, start_epoch, classifier, ema_encoder, num_swa = load_checkpoint(
            r_path=load_path,
            encoder=encoder,
            classifier=classifier,
            opt=optimizer,
            scaler=scaler,
            use_fp16=use_fp16,
            ema_encoder=ema_encoder
            )
        # Fast-forward the per-iteration schedulers to the resumed position.
        for _ in range(start_epoch * ipe):
            scheduler.step()
            if mom_scheduler is not None:
                args.mom = mom_scheduler.get_mom()

    args.num_swa = num_swa

    # Number of support samples at the head of every concatenated batch; the
    # encoder output is split at this index into supports and unlabeled views.
    num_support = supervised_views * s_batch_size * classes_per_batch

    # -- TRAINING LOOP
    best_loss = None
    for epoch in range(start_epoch, num_epochs):
        logger.info('Epoch %d' % (epoch + 1))

        # -- update distributed-data-loader epoch
        unsupervised_sampler.set_epoch(epoch)
        if supervised_sampler is not None:
            supervised_sampler.set_epoch(epoch)

        loss_meter = AverageMeter()
        ploss_meter = AverageMeter()
        rloss_meter = AverageMeter()
        celoss_meter = AverageMeter()

        time_meter = AverageMeter()
        data_meter = AverageMeter()
        vm_snn_meter = AverageMeter()
        vm_head_meter = AverageMeter()
        mp_snn_meter = AverageMeter()
        mp_head_meter = AverageMeter()
        acc_snn_meter = AverageMeter()
        acc_head_meter = AverageMeter()
        ece_meter = AverageMeter()
        # Same order as the monitoring statistics returned by train_step().
        stat_meters = (vm_snn_meter, vm_head_meter, mp_snn_meter,
                       mp_head_meter, acc_snn_meter, acc_head_meter,
                       ece_meter)

        for itr, udata in enumerate(unsupervised_loader):
            encoder.train()
            classifier.train()
            if ema_encoder is not None:
                ema_encoder.train()

            def load_imgs():
                # The supervised iterator lives in main()'s scope and must
                # survive across iterations and epochs.
                nonlocal iter_supervised

                # -- unsupervised imgs (last element of the batch is the label)
                uimgs = [u.to(device, non_blocking=True) for u in udata[:-1]]
                ulab = udata[-1].to(device, non_blocking=True)

                # -- supervised imgs
                if iter_supervised is None:
                    iter_supervised = iter(supervised_loader)
                    logger.info(f'len.supervised_loader: {len(iter_supervised)}')
                try:
                    sdata = next(iter_supervised)
                except StopIteration:
                    # Supervised loader exhausted: start a fresh pass.
                    iter_supervised = iter(supervised_loader)
                    logger.info(f'len.supervised_loader: {len(iter_supervised)}')
                    sdata = next(iter_supervised)

                simgs = [s.to(device, non_blocking=True) for s in sdata[:-1]]
                hard_lab = sdata[-1].to(device, non_blocking=True)
                new_labels_matrix = smoothen_labels(
                    hard_lab,
                    num_classes=num_classes,
                    smoothing=label_smoothing,
                    device=device)
                # One copy of the label matrix per supervised view.
                labels = torch.cat([new_labels_matrix for _ in range(supervised_views)])
                # -- concatenate supervised imgs and unsupervised imgs
                imgs = simgs + uimgs
                return imgs, labels, hard_lab, ulab
            (imgs, labels, hard_lab, ulab), dtime = gpu_timer(load_imgs)
            data_meter.update(dtime)

            def train_step():

                with torch.cuda.amp.autocast(enabled=use_fp16):
                    optimizer.zero_grad()

                    # NOTE(review): `encoder.fc` is read from the possibly
                    # DistributedDataParallel-wrapped encoder — confirm this
                    # attribute is reachable when world_size > 1.
                    rep = encoder(imgs)
                    z = encoder.fc(rep)

                    if use_teacher:
                        tar_rep = ema_encoder(imgs)
                        tar_z = ema_encoder.fc(tar_rep)

                    # Compute paws loss in full precision
                    with torch.cuda.amp.autocast(enabled=False):

                        # Step 1. convert representations to fp32
                        z = z.float()
                        # Step 2. determine anchor views/supports and their
                        #         corresponding target views/supports
                        anchor_supports = z[:num_support]
                        anchor_views = z[num_support:]
                        if use_teacher:
                            tar_z = tar_z.float()
                            target_supports = tar_z[:num_support]
                            target_views = tar_z[num_support:]
                        else:
                            target_supports = z[:num_support].detach()
                            target_views = z[num_support:].detach()
                        # Swap the two unsupervised views so that each anchor
                        # is matched against the other view as its target.
                        target_views = torch.cat([
                            target_views[u_batch_size:2*u_batch_size],
                            target_views[:u_batch_size]], dim=0)

                        # Step 3. compute paws loss with me-max regularization
                        paws_kwargs = dict(
                            anchor_views=anchor_views,
                            anchor_supports=anchor_supports,
                            anchor_support_labels=labels,
                            target_views=target_views,
                            target_supports=target_supports,
                            target_support_labels=labels,
                            online_head=classifier,
                            use_online_classifier=False,
                            monitor_stats=args.monitor_stats,
                            tau_t=args.tau_t, tau_s=temperature)
                        if args.monitor_stats:
                            # Pre-projection representations are only needed
                            # for the monitoring statistics.
                            anchor_view_rep = rep[num_support:].float()
                            if use_teacher:
                                target_view_rep = tar_rep[num_support:].float()
                            else:
                                target_view_rep = anchor_view_rep.detach()
                            target_view_rep = torch.cat([
                                target_view_rep[u_batch_size:2*u_batch_size],
                                target_view_rep[:u_batch_size]], dim=0)
                            (ploss, me_max, vm_snn, vm_head, mp_snn, mp_head,
                             acc_snn, acc_head, ece) = online_paws(
                                ulab_true=ulab,
                                anchor_views_rep=anchor_view_rep,
                                target_views_rep=target_view_rep,
                                sharpen_online_targets=args.sharpen_online_targets,
                                **paws_kwargs)
                        else:
                            (ploss, me_max) = online_paws(
                                ulab_true=None, **paws_kwargs)

                        loss = ploss + me_max

                    # The classifier is trained on detached support
                    # representations, so its loss does not reach the encoder.
                    if args.use_ema_for_online:
                        srep = tar_rep[:num_support].detach()
                    else:
                        srep = rep[:num_support].detach()

                    logits = classifier(srep)

                    if args.x_soft_label:
                        celoss = onclass(logits/args.class_temp, labels)
                    else:
                        # Repeat the hard labels once per supervised view so
                        # they line up with the support rows of `logits`.
                        hard_labs = torch.cat(
                            [hard_lab] * supervised_views, dim=0)
                        celoss = torch.nn.functional.cross_entropy(logits/args.class_temp, hard_labs)

                tloss = loss + celoss
                scaler.scale(tloss).backward()

                lr_stats = scaler.step(optimizer)

                scaler.update()
                scheduler.step()
                if mom_scheduler is not None:
                    args.mom = mom_scheduler.get_mom()

                # update teacher
                if args.use_swa:
                    with torch.no_grad():
                        if epoch < args.swa_warmup:
                            # Copy values instead of re-binding `.data`:
                            # re-binding would make teacher and student share
                            # storage, and the in-place SWA average below
                            # would then also overwrite the student weights.
                            for param_s, param_t in zip(encoder.parameters(), ema_encoder.parameters()):
                                param_t.data.copy_(param_s.detach().data)

                        elif itr % args.swa_freq == 0:  # update swa every few iterations
                            args.num_swa += 1
                            for param_s, param_t in zip(encoder.parameters(), ema_encoder.parameters()):
                                param_t.data.mul_(args.num_swa).add_(param_s.detach().data).div_(args.num_swa + 1)
                elif args.use_ema:
                    with torch.no_grad():
                        for param_q, param_k in zip(encoder.parameters(), ema_encoder.parameters()):
                            param_k.data.mul_(args.mom).add_((1 - args.mom) * param_q.detach().data)

                stats = None
                if args.monitor_stats:
                    stats = (float(vm_snn), float(vm_head), float(mp_snn),
                             float(mp_head), float(acc_snn), float(acc_head),
                             float(ece))
                return (float(loss), float(ploss), float(me_max),
                        float(celoss), lr_stats, stats)

            (loss, ploss, rloss, celoss, lr_stats, stats), etime = gpu_timer(train_step)
            loss_meter.update(loss)
            ploss_meter.update(ploss)
            rloss_meter.update(rloss)
            celoss_meter.update(celoss)
            if stats is not None:
                for meter, value in zip(stat_meters, stats):
                    meter.update(value)
            time_meter.update(etime)

            if (itr % log_freq == 0) or np.isnan(loss) or np.isinf(loss):
                logger.info('pseudoacc %.4f %.4f'
                            % (acc_snn_meter.avg, acc_head_meter.avg))
                logger.info('[%d, %5d] loss: %.3f (%.3f %.3f %.3f) '
                            'ece %.3f %.4f '
                            'time (%d ms; %d ms)'
                            % (epoch + 1, itr,
                               loss_meter.avg,
                               ploss_meter.avg,
                               rloss_meter.avg,
                               celoss_meter.avg,
                               ece_meter.avg,
                               args.tau_t,
                               time_meter.avg,
                               data_meter.avg))
                csv_logger.log(epoch + 1, itr,
                               ploss_meter.avg,
                               rloss_meter.avg,
                               celoss_meter.avg,
                               time_meter.avg)
                if lr_stats is not None:
                    logger.info('[%d, %5d] lr_stats: %.3f (%.2e, %.2e)'
                                % (epoch + 1, itr,
                                   lr_stats.avg,
                                   lr_stats.min,
                                   lr_stats.max))

            assert not np.isnan(loss), 'loss is nan'

        # -- logging/checkpointing
        logger.info('avg. loss %.3f' % loss_meter.avg)
        writer.add_scalar('train/1.paws_train_loss', ploss_meter.avg, epoch+1)
        writer.add_scalar('train/2.memax_train_loss', rloss_meter.avg, epoch+1)
        writer.add_scalar('train/3.online_class_loss', celoss_meter.avg, epoch+1)
        writer.add_scalar('train/4.learning_rate', scheduler.get_last_lr()[0], epoch)

        if args.monitor_stats:
            writer.add_scalar('train/6.view_match_snn', vm_snn_meter.avg, epoch+1)
            writer.add_scalar('train/7.view_match_head', vm_head_meter.avg, epoch+1)
            writer.add_scalar('train/8.max_prob_mean_snn', mp_snn_meter.avg, epoch+1)
            writer.add_scalar('train/9.max_prob_mean_head', mp_head_meter.avg, epoch+1)
            writer.add_scalar('train/10.pseudoacc_snn', acc_snn_meter.avg, epoch+1)
            writer.add_scalar('train/11.pseudoacc_head', acc_head_meter.avg, epoch+1)
            writer.add_scalar('train/12.ECE', ece_meter.avg, epoch+1)

        if mom_scheduler is not None:
            writer.add_scalar('train/15.momentum', args.mom, epoch)

        if rank == 0:
            save_dict = {
                'encoder': encoder.state_dict(),
                'classifier': classifier.state_dict(),
                'opt': optimizer.state_dict(),
                'epoch': epoch + 1,
                'unlabel_prob': unlabeled_frac,
                'loss': loss_meter.avg,
                's_batch_size': s_batch_size,
                'u_batch_size': u_batch_size,
                'world_size': world_size,
                'lr': lr,
                'temperature': temperature,
                'amp': scaler.state_dict(),
            }
            if ema_encoder is not None:
                save_dict['ema_encoder'] = ema_encoder.state_dict()
            if args.use_swa:
                save_dict['num_swa'] = args.num_swa

            torch.save(save_dict, latest_path)
            if best_loss is None or best_loss > loss_meter.avg:
                best_loss = loss_meter.avg
                logger.info('updating "best" checkpoint')
                torch.save(save_dict, best_path)
            if (epoch + 1) % checkpoint_freq == 0 \
                    or (epoch + 1) % 10 == 0 and epoch < checkpoint_freq:
                torch.save(save_dict, save_path.format(epoch=f'{epoch + 1}'))

            # Evaluation uses the teacher network whenever one is maintained.
            eval_encoder = ema_encoder if use_teacher else encoder

            # Evaluate test accuracy of online classifier
            top1_meter = AverageMeter()
            top5_meter = AverageMeter()
            with torch.no_grad():
                encoder.eval()
                classifier.eval()
                if ema_encoder is not None:
                    ema_encoder.eval()
                for inputs, targets in snn_test_data_loader:
                    inputs = inputs.to(device)
                    targets = targets.to(device)
                    outputs = classifier(eval_encoder(inputs))
                    prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
                    top1_meter.update(prec1.item(), inputs.shape[0])
                    top5_meter.update(prec5.item(), inputs.shape[0])
            writer.add_scalar('online_test/1.top_1', top1_meter.avg, epoch)
            writer.add_scalar('online_test/2.top_5', top5_meter.avg, epoch)

            if eval_snn:
                with torch.no_grad():
                    encoder.eval()
                    if ema_encoder is not None:
                        ema_encoder.eval()
                    embs, labs = make_embeddings_fix(
                        device,
                        snn_train_data_loader,
                        snn_train_data_sampler,
                        encoder=eval_encoder,
                        encoder_fc=eval_encoder.fc)
                    top1, top5, ece = evaluate_embeddings_fix(
                        device,
                        snn_test_data_loader,
                        encoder=eval_encoder,
                        encoder_fc=eval_encoder.fc,
                        labs=labs,
                        embs=embs,
                        num_classes=num_classes,
                        temp=args.tau_t)
                print(top1, top5, ece)
                writer.add_scalar('snn_test/1.top_1', top1, epoch)
                writer.add_scalar('snn_test/2.top_5', top5, epoch)
                writer.add_scalar('snn_test/3.test_ECE', ece, epoch)

            print("online classifier: {:.2f}; {:.2f}".format(top1_meter.avg, top5_meter.avg))

    writer.close()


def load_checkpoint(
    r_path,
    encoder,
    opt,
    scaler,
    classifier,
    use_fp16=False,
    ema_encoder=None,
    use_swa=None
):
    """Restore training state from the checkpoint file at ``r_path``.

    The encoder, classifier and optimizer are always restored. The teacher
    (``ema_encoder``), when given, is restored from the checkpoint's
    ``'ema_encoder'`` entry if present and otherwise initialised from the
    freshly loaded encoder weights. The AMP grad-scaler state is restored
    only when ``use_fp16`` is true.

    Args:
        r_path: path of the checkpoint file passed to ``torch.load``.
        encoder: module receiving ``checkpoint['encoder']``.
        opt: optimizer receiving ``checkpoint['opt']``.
        scaler: grad scaler receiving ``checkpoint['amp']`` (fp16 only).
        classifier: module receiving ``checkpoint['classifier']``.
        use_fp16: whether to restore the grad-scaler state.
        ema_encoder: optional teacher module; ``None`` skips teacher loading.
        use_swa: whether to restore the SWA sample counter. ``None`` (the
            default) keeps the legacy behaviour of reading ``use_swa`` from a
            module-level ``args`` namespace; if no such global exists the
            counter is treated as disabled.

    Returns:
        Tuple ``(encoder, opt, epoch, classifier, ema_encoder, num_swa)``
        where ``epoch`` is the epoch stored in the checkpoint and ``num_swa``
        is the stored SWA counter, or 0 when SWA is off or not stored.
    """
    if use_swa is None:
        # Backward compatibility: earlier callers relied on a global `args`.
        # Look it up defensively so the function also works when imported
        # from another module or called before `args` has been created.
        use_swa = getattr(globals().get('args'), 'use_swa', False)

    checkpoint = torch.load(r_path, map_location='cpu')
    epoch = checkpoint['epoch']

    # -- loading encoder
    encoder.load_state_dict(checkpoint['encoder'])
    logger.info(f'loaded encoder from epoch {epoch}')

    # -- loading classifier
    classifier.load_state_dict(checkpoint['classifier'])
    logger.info(f'loaded classifier from epoch {epoch}')

    # -- loading ema
    if ema_encoder is not None:
        if 'ema_encoder' in checkpoint:
            ema_encoder.load_state_dict(checkpoint['ema_encoder'])
            logger.info(f'loaded EMA encoder from epoch {epoch}')
        else:
            # Checkpoint has no teacher: start it from the loaded encoder.
            ema_encoder.load_state_dict(encoder.state_dict())
            logger.info('loaded EMA encoder with encoder checkpoint weights')

    # -- loading optimizer
    opt.load_state_dict(checkpoint['opt'])

    # -- loading swa parameters
    num_swa = checkpoint.get('num_swa', 0) if use_swa else 0

    if use_fp16:
        scaler.load_state_dict(checkpoint['amp'])

    logger.info(f'loaded optimizers from epoch {epoch}')
    logger.info(f'read-path: {r_path}')
    # Drop the CPU copy of the checkpoint before training resumes.
    del checkpoint
    return encoder, opt, epoch, classifier, ema_encoder, num_swa


def init_model(
    device,
    model_name='resnet50',
    use_pred=False,
    output_dim=128,
    class_mlp=False,
    num_classes=100,
    version='fixmatch',
    class_depth=2,
    hidden_dim=128,
    dropout=0.0
):
    """Build the encoder, the optional teacher encoder and the online classifier.

    If ``model_name`` contains ``'wide_resnet'`` the encoder is a
    ``WideResNetProj`` whose widen factor is 8 when ``args.dataset`` is
    ``'cifar100'`` and 2 otherwise. Any other ``model_name`` is looked up in
    the ``src.resnet`` module and called without arguments; in that case
    ``num_classes``, ``hidden_dim``, ``output_dim``, ``version`` and
    ``dropout`` do not affect the encoder.

    A teacher (EMA/SWA) encoder of the same architecture is created when
    ``args.use_ema`` or ``args.use_swa`` is set. It starts from the student's
    weights and has gradients disabled on every parameter.

    NOTE: this function reads the module-level ``args`` namespace
    (``dataset``, ``use_ema``, ``use_swa``), which is only bound when the file
    runs as a script. ``use_pred`` is accepted for interface compatibility
    but is not used; no extra projection or prediction head is attached here.

    :param device: device all returned modules are moved to
    :param model_name: architecture name (see above)
    :param use_pred: unused
    :param output_dim: ``output_dim`` passed to ``WideResNetProj``
    :param class_mlp: ``mlp`` flag passed to ``Online_Classifier``
    :param num_classes: number of classes for encoder and classifier
    :param version: ``version`` passed to ``WideResNetProj``
    :param class_depth: ``depth`` passed to ``Online_Classifier``
    :param hidden_dim: ``hidden_dim`` passed to ``WideResNetProj``
    :param dropout: ``dropout`` passed to ``WideResNetProj``
    :return: ``(encoder, classifier, ema_encoder)``; ``ema_encoder`` is
        ``None`` when neither EMA nor SWA is enabled
    """
    # -- pick one factory so student and teacher are built identically
    if 'wide_resnet' in model_name:
        widefac = 8 if args.dataset == 'cifar100' else 2

        def build_encoder():
            return WideResNetProj(
                num_classes=num_classes,
                hidden_dim=hidden_dim,
                output_dim=output_dim,
                widen_factor=widefac,
                version=version,
                dropout=dropout)
    else:
        def build_encoder():
            return resnet.__dict__[model_name]()

    # -- student first, then teacher (same construction order as before)
    encoder = build_encoder()
    ema_encoder = build_encoder() if (args.use_ema or args.use_swa) else None

    if ema_encoder is not None:
        # the teacher is never updated by back-propagation
        for p in ema_encoder.parameters():
            p.requires_grad = False
        print("set no grads for teacher")

        # initialize to same weights as student
        ema_encoder.load_state_dict(encoder.state_dict())
        ema_encoder.to(device)
    encoder.to(device)

    # -- online classification head
    classifier = Online_Classifier(
        num_classes=num_classes,
        model_name=model_name,
        mlp=class_mlp,
        depth=class_depth)
    classifier.to(device)

    return encoder, classifier, ema_encoder


class MomentumScheduler:
    """Per-iteration momentum schedule: linear warmup followed by a cosine ramp.

    The schedule is precomputed as a NumPy array in ``cos_schedule``:

    * ``warmup_epochs * iter_per_epoch`` values going linearly from
      ``warmup_from`` to ``init_mom`` (both ends included), then
    * ``(total_epochs - warmup_epochs) * iter_per_epoch`` values following a
      half cosine that starts at ``init_mom`` and moves towards ``final_mom``.

    ``iter`` holds the index of the most recent ``get_mom`` call (``-1``
    before the first call).
    """

    def __init__(self, init_mom=0.996, final_mom=1, warmup_from=0, warmup_epochs=0, total_epochs = 800, iter_per_epoch=1):
        # -- linear warmup segment (empty when warmup_epochs == 0)
        warmup_iter = iter_per_epoch * warmup_epochs
        warmup_sch = np.linspace(warmup_from, init_mom, warmup_iter)

        # -- cosine segment covering the remaining iterations
        num_iter = iter_per_epoch * (total_epochs - warmup_epochs)
        cos_sch = final_mom + 0.5 * (init_mom - final_mom) * (1 + np.cos(np.pi * np.arange(num_iter) / num_iter))

        self.cos_schedule = np.concatenate([warmup_sch, cos_sch])
        self.iter = int(-1)

    def get_mom(self):
        """Advance one iteration and return the momentum for it.

        Calls beyond the end of the precomputed schedule keep returning the
        last scheduled value instead of raising ``IndexError``, so a training
        loop that runs a few extra iterations does not crash.

        :raises IndexError: if the schedule is empty
        """
        self.iter += 1
        # clamp only the lookup; ``iter`` itself keeps counting calls
        last_index = len(self.cos_schedule) - 1
        return self.cos_schedule[min(self.iter, last_index)]


def init_opt(
    encoder,
    iterations_per_epoch,
    start_lr,
    ref_lr,
    ref_mom,
    nesterov,
    warmup,
    num_epochs,
    weight_decay=1e-6,
    final_lr=0.0,
    classifier=None,
    head_factor=1.,
    class_wd=0,
    classifier_exclude_LARS=False
):
    """Create the LARS-wrapped SGD optimizer and its warmup-cosine LR schedule.

    Encoder parameters are split by name: those whose name contains
    ``'bias'`` or ``'bn'`` go into a group with ``weight_decay`` 0 and
    ``LARS_exclude`` set; all others use the optimizer-wide settings. When a
    ``classifier`` is given, its parameters form a third group with learning
    rate ``ref_lr * head_factor``, weight decay ``class_wd`` and
    ``LARS_exclude`` equal to ``classifier_exclude_LARS``.

    The scheduler is attached to the inner SGD optimizer before that
    optimizer is wrapped in ``LARS``.

    NOTE(review): ``ref_mom`` is accepted but never used; SGD momentum is
    fixed at 0.9 below. Confirm whether that is intended.

    :return: ``(encoder, optimizer, scheduler, classifier)``
    """
    def _is_bias_or_bn(name):
        return ('bias' in name) or ('bn' in name)

    # -- encoder parameters, split into decayed / non-decayed groups
    decayed = (p for n, p in encoder.named_parameters()
               if not _is_bias_or_bn(n))
    undecayed = (p for n, p in encoder.named_parameters()
                 if _is_bias_or_bn(n))
    param_groups = [
        {'params': decayed},
        {'params': undecayed,
         'LARS_exclude': True,
         'weight_decay': 0},
    ]

    # -- optional classifier head with its own (scaled) learning rate
    if classifier is not None:
        head_group = {
            'params': classifier.parameters(),
            'lr': ref_lr*head_factor,
            'weight_decay': class_wd,
            'LARS_exclude': classifier_exclude_LARS,
        }
        param_groups.append(head_group)

    optimizer = SGD(
        param_groups,
        weight_decay=weight_decay,
        momentum=0.9,
        nesterov=nesterov,
        lr=ref_lr)

    # -- schedule lengths are expressed in iterations, not epochs
    warmup_steps = warmup*iterations_per_epoch
    total_steps = num_epochs*iterations_per_epoch
    scheduler = WarmupCosineSchedule(
        optimizer,
        warmup_steps=warmup_steps,
        start_lr=start_lr,
        ref_lr=ref_lr,
        final_lr=final_lr,
        T_max=total_steps)

    optimizer = LARS(optimizer, trust_coefficient=0.001)
    return encoder, optimizer, scheduler, classifier

def smoothen_labels(hard_labels, num_classes=100, smoothing=0.1, device='cpu'):
    """Turn integer class labels into label-smoothed target rows.

    Every row holds ``smoothing / num_classes`` in each non-target column and
    ``1 - smoothing * (num_classes - 1) / num_classes`` in the target column,
    so each row sums to 1.

    :param hard_labels: 1-D tensor of class indices in ``[0, num_classes)``;
        any integer dtype and any device are accepted (it is converted to an
        int64 index on ``device``)
    :param num_classes: number of columns of the result
    :param smoothing: total probability mass spread uniformly over all classes
    :param device: device on which the result is created
    :return: tensor of shape ``(len(hard_labels), num_classes)`` in the
        default floating-point dtype, located on ``device``
    """
    off_value = smoothing / num_classes
    on_value = 1 - off_value * (num_classes - 1)

    # scatter_ needs an int64 index living on the same device as the target
    index = hard_labels.to(device=device, dtype=torch.long).unsqueeze(1)

    # allocate directly on the target device instead of CPU-then-copy
    label_matrix = torch.full(
        (len(hard_labels), num_classes),
        off_value,
        dtype=torch.get_default_dtype(),
        device=device)
    label_matrix.scatter_(1, index, on_value)
    return label_matrix



if __name__ == "__main__":
    # Command-line entry point. The parsed namespace is bound to the
    # module-level name ``args``; ``load_checkpoint`` and ``init_model`` read
    # it as a global, so the name must not change.
    parser = argparse.ArgumentParser(description='train PAWS online')

    # -- run identification / output
    parser.add_argument('--tag', default='test', type=str)
    parser.add_argument('--folder', default='results/', type=str)
    # ``--eval_snn`` is store_true with default=True, so on its own it can
    # never be switched off; ``--no_eval_snn`` writes False to the same dest.
    parser.add_argument('--eval_snn', action='store_true', default=True)
    parser.add_argument('--no_eval_snn', dest='eval_snn', action='store_false')
    # parser.add_argument('--model_name', default='wide_resnet28w8', type=str)
    parser.add_argument('--proj_dim', default=512, type=int)
    parser.add_argument('--hidden_dim', default=512, type=int)

    # -- checkpointing, data and device
    parser.add_argument('--load_checkpoint', action='store_true')
    parser.add_argument('--read_checkpoint', default=None, type=str)
    parser.add_argument('--copy_data', action='store_true')
    # same always-True pattern as ``--eval_snn``; ``--no_use_fp16`` disables it
    parser.add_argument('--use_fp16', action='store_true',default=True)
    parser.add_argument('--no_use_fp16', dest='use_fp16', action='store_false')
    parser.add_argument('--device', default='cuda:0', type=str)
    # ``args.me_max`` is always True; ``--no_me_max`` is stored separately as
    # ``args.no_me_max`` (how the two are combined is decided by the caller).
    parser.add_argument('--me_max', action='store_true',default=True)
    parser.add_argument('--no_me_max', action='store_true')
    parser.add_argument('--dataset', default='cifar100', type=str)
    parser.add_argument('--sup_bs', default=6, type=int)
    parser.add_argument('--supervised_views', default=2, type=int)
    parser.add_argument('--temperature', default=0.1, type=float)
    parser.add_argument('--class_temp', default=1., type=float)

    parser.add_argument('--sharpen', default=0.25, type=float)

    # -- online classifier head
    parser.add_argument('--online_head', default='linear', type=str)
    parser.add_argument('--from_proj1', action='store_true')
    parser.add_argument('--class_mlp', action='store_true')
    parser.add_argument('--class_depth', default=2, type=int)
    parser.add_argument('--data_seed', default=0, type=int)

    # -- optimisation schedule
    parser.add_argument('--epochs', default=600, type=int)
    # parser.add_argument('--final_lr', default=0.032, type=float)
    parser.add_argument('--lr_decay_factor', default=100, type=float)

    parser.add_argument('--lr', default=3.2, type=float)

    parser.add_argument('--warmup', default=10, type=int)
    parser.add_argument('--weight_decay', default=1e-6, type=float)
    parser.add_argument('--class_wd', default=1e-6, type=float)
    parser.add_argument('--class_ex_lars', action='store_true')

    # -- loss / training variants
    parser.add_argument('--head_factor', default=1., type=float)
    parser.add_argument('--version', default='fixmatch', type=str)
    parser.add_argument('--unlabeled_frac', default=0.92, type=float)
    parser.add_argument('--use_basic', action='store_true')
    parser.add_argument('--x_no_detach', action='store_true')
    parser.add_argument('--memax_class', action='store_true')
    parser.add_argument('--monitor_stats', action='store_true')
    parser.add_argument('--multicrop', default=6, type=int)
    parser.add_argument('--x_soft_label', action='store_true')
    parser.add_argument('--sharpen_online_targets', action='store_true')
    parser.add_argument('--lab_smooth', default=0.1, type=float)
    parser.add_argument('--dropout', default=0.0, type=float)
    parser.add_argument('--tau_t', default=0.1, type=float)

    # -- EMA / SWA teacher
    parser.add_argument('--use_ema', action='store_true')
    parser.add_argument('--ema_decay', default=0.996, type=float)
    parser.add_argument('--use_ema_for_online', action='store_true')
    parser.add_argument('--checkpoint_freq', default=50, type=int)
    parser.add_argument('--use_mom_scheduler', action='store_true')
    parser.add_argument('--mom_warmup_epochs', default=0, type=int)
    parser.add_argument('--final_mom', default=1, type=float)
    parser.add_argument('--use_swa', action='store_true')
    parser.add_argument('--swa_freq', default=1, type=int)
    parser.add_argument('--swa_warmup', default=0, type=int)


    args = parser.parse_args()

    main(args)
