# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import os

# # -- FOR DISTRIBUTED TRAINING ENSURE ONLY 1 DEVICE VISIBLE PER PROCESS
# try:
#     # -- WARNING: IF DOING DISTRIBUTED TRAINING ON A NON-SLURM CLUSTER, MAKE
#     # --          SURE TO UPDATE THIS TO GET LOCAL-RANK ON NODE, OR ENSURE
#     # --          THAT YOUR JOBS ARE LAUNCHED WITH ONLY 1 DEVICE VISIBLE
#     # --          TO EACH PROCESS
#     os.environ['CUDA_VISIBLE_DEVICES'] = os.environ['SLURM_LOCALID']
# except Exception:
#     pass

import logging
import sys
from collections import OrderedDict

import numpy as np

import torch

import src.resnet as resnet
# import src.wide_resnet as wide_resnet
# from src.wide_resnet import Online_Classifier
from src.utils import (
    gpu_timer,
    init_distributed,
    WarmupCosineSchedule,
    CSVLogger,
    AverageMeter
)
from src.mod_losses import (
    init_paws_loss,
    make_labels_matrix
)
from src.data_manager import (
    init_data,
    make_transforms,
    make_multicrop_transform
)
from src.sgd import SGD
from src.lars import LARS

# import apex
from torch.nn.parallel import DistributedDataParallel

import tensorboard_logger as tb_logger

# --
# Logging / checkpointing cadence.
log_timings = True  # not referenced anywhere else in this file
log_freq = 50  # main_worker writes a log line whenever itr % log_freq == 0
checkpoint_freq = 10  # main_worker assigns a local of the same name, so this value is shadowed there
# --

# Seed NumPy and PyTorch at import time with a fixed value.
_GLOBAL_SEED = 0
np.random.seed(_GLOBAL_SEED)
torch.manual_seed(_GLOBAL_SEED)
torch.backends.cudnn.benchmark = True

# Root logger: INFO and above go to stdout; every function in this file logs through it.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger()


def main_worker(gpu, args, params):
    """Train the encoder with the PAWS loss plus an online classifier.

    Runs inside one process bound to one GPU.  The function initialises the
    NCCL process group, builds the student encoder (and a gradient-free
    teacher when ``args.use_ema`` or ``args.use_swa`` is set), the data
    loaders, the optimizer and the schedulers, resumes from
    ``args.checkpoint_dir / 'latest.pth.tar'`` when that file exists, and
    then runs the training loop.  Rank 0 additionally writes TensorBoard
    scalars and the ``latest`` / ``best`` / ``ep{N}`` checkpoints.

    Args:
        gpu: CUDA device index used by this process.
        args: Run-time namespace.  Fields read here: ``dist_url``,
            ``world_size``, ``rank``, ``input_pred``, ``input_lr``,
            ``num_epochs``, ``checkpoint_dir`` (must support ``/`` and
            ``.is_file()``), ``log_dir``, ``no_online``,
            ``use_mom_scheduler``, ``ema_decay``, ``final_ema_decay``,
            ``mom_warmup_epochs``, ``use_ema``, ``use_swa``, ``swa_warmup``.
            ``args.ema_mom`` and ``args.num_swa`` are written.
        params: Nested config mapping with the sections ``meta``,
            ``criterion``, ``data`` and ``optimization``.

    Raises:
        ValueError: if ``args.input_pred`` is not -1, 0 or 1.
        AssertionError: if the training loss becomes NaN.
    """

    torch.distributed.init_process_group(
        backend='nccl', init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank)

    print(' '.join(sys.argv))

    torch.cuda.set_device(gpu)
    torch.backends.cudnn.benchmark = True

    # init_model() looks `args` up as a module-level name and nothing in this
    # module binds one; publish the namespace so that lookup succeeds.  An
    # already-bound module-level `args` is left untouched.
    globals().setdefault('args', args)

    def unwrap(model):
        # DDP keeps the real network under `.module`; with a single process
        # the model is never wrapped, so it is returned as is.
        return model.module if args.world_size > 1 else model

    # ----------------------------------------------------------------------- #
    #  PASSED IN PARAMS FROM CONFIG FILE
    # ----------------------------------------------------------------------- #
    # -- META
    model_name = params['meta']['model_name']
    output_dim = params['meta']['output_dim']
    copy_data = params['meta']['copy_data']
    use_fp16 = params['meta']['use_fp16']
    # -1 defers to the config file; 0/1 force the prediction head off/on
    if args.input_pred == -1:
        use_pred_head = params['meta']['use_pred_head']
    elif args.input_pred == 0:
        use_pred_head = False
    elif args.input_pred == 1:
        use_pred_head = True
    else:
        raise ValueError("only 0 or 1")

    # -- CRITERION
    reg = params['criterion']['me_max']
    supervised_views = params['criterion']['supervised_views']
    classes_per_batch = params['criterion']['classes_per_batch']
    s_batch_size = params['criterion']['supervised_imgs_per_class']
    u_batch_size = params['criterion']['unsupervised_batch_size']
    temperature = params['criterion']['temperature']
    sharpen = params['criterion']['sharpen']

    # -- DATA
    unlabeled_frac = params['data']['unlabeled_frac']
    color_jitter = params['data']['color_jitter_strength']
    normalize = params['data']['normalize']
    root_path = params['data']['root_path']
    image_folder = params['data']['image_folder']
    dataset_name = params['data']['dataset']
    subset_path = params['data']['subset_path']
    unique_classes = params['data']['unique_classes_per_rank']
    multicrop = params['data']['multicrop']
    label_smoothing = params['data']['label_smoothing']
    data_seed = None
    if 'cifar10' in dataset_name:
        data_seed = params['data']['data_seed']
        crop_scale = (0.75, 1.0) if multicrop > 0 else (0.5, 1.0)
        mc_scale = (0.3, 0.75)
        mc_size = 18
    else:
        crop_scale = (0.14, 1.0) if multicrop > 0 else (0.08, 1.0)
        mc_scale = (0.05, 0.14)
        mc_size = 96

    # -- OPTIMIZATION
    wd = float(params['optimization']['weight_decay'])
    num_epochs = args.num_epochs
    warmup = params['optimization']['warmup']
    start_lr = params['optimization']['start_lr']
    # -1 defers to the config file; any other value overrides the reference
    # lr and derives the final lr as 1/100 of it
    if args.input_lr == -1:
        lr = params['optimization']['lr']
        final_lr = params['optimization']['final_lr']
    else:
        lr = args.input_lr
        final_lr = lr/100
    mom = params['optimization']['momentum']
    nesterov = params['optimization']['nesterov']

    # -- number of label classes, derived from the dataset name
    if 'imagenet' in dataset_name:
        num_classes = 1000
    elif 'cifar100' in dataset_name:
        num_classes = 100
    else:
        num_classes = 10

    if args.rank == 0:
        tblogger = tb_logger.Logger(logdir=args.log_dir, flush_secs=2)
    # ----------------------------------------------------------------------- #

    # -- log/checkpointing paths
    log_file = args.checkpoint_dir / f'r{args.rank}.csv'
    latest_path = args.checkpoint_dir / 'latest.pth.tar'
    best_path = args.checkpoint_dir / 'best.pth.tar'

    # -- make csv_logger; one column is declared per value handed to
    #    csv_logger.log() in the training loop (previously only the first
    #    five of the eleven logged values had a column)
    csv_logger = CSVLogger(log_file,
                           ('%d', 'epoch'),
                           ('%d', 'itr'),
                           ('%.5f', 'paws-xent-loss'),
                           ('%.5f', 'paws-me_max-reg'),
                           ('%d', 'time (ms)'),
                           ('%.5f', 'online-ce-loss'),
                           ('%.5f', 'acc-snn'),
                           ('%.5f', 'max-prob-snn'),
                           ('%.5f', 'ece'),
                           ('%.5f', 'online-acc'),
                           ('%.5f', 'online-acc-reduced'))

    # -- init model (encoder_t is None when neither EMA nor SWA is enabled)
    encoder, encoder_t = init_model(
        device=gpu,
        model_name=model_name,
        use_pred=use_pred_head,
        output_dim=output_dim)

    if args.world_size > 1:
        encoder = torch.nn.SyncBatchNorm.convert_sync_batchnorm(encoder)
        if encoder_t is not None:
            encoder_t = torch.nn.SyncBatchNorm.convert_sync_batchnorm(encoder_t)

        encoder = torch.nn.parallel.DistributedDataParallel(
                encoder, device_ids=[gpu])
        if encoder_t is not None:
            encoder_t = torch.nn.parallel.DistributedDataParallel(
                    encoder_t, device_ids=[gpu])

    encoder_t_noddp = unwrap(encoder_t) if encoder_t is not None else None

    if encoder_t is not None:
        # start the teacher from the student's weights; unwrap() keeps this
        # working when the student is not DDP-wrapped (world_size == 1)
        encoder_t_noddp.load_state_dict(unwrap(encoder).state_dict())

        # teacher doesn't need gradients
        for p in encoder_t.parameters():
            p.requires_grad = False

    # -- init losses
    paws = init_paws_loss(
        multicrop=multicrop,
        tau_t=temperature,
        tau_s=temperature,
        T=sharpen,
        me_max=reg)
    # -- assume support images are sampled with ClassStratifiedSampler
    # NOTE(review): the returned matrix is never read; the per-batch label
    # matrices are rebuilt by smoothen_labels().  The call is kept because
    # make_labels_matrix is not visible from this file.
    make_labels_matrix(
        num_classes=classes_per_batch,
        s_batch_size=s_batch_size,
        world_size=args.world_size,
        device=gpu,
        unique_classes=unique_classes,
        smoothing=label_smoothing)

    # -- make data transforms
    transform, init_transform = make_transforms(
        dataset_name=dataset_name,
        subset_path=subset_path,
        unlabeled_frac=unlabeled_frac,
        training=True,
        split_seed=data_seed,
        crop_scale=crop_scale,
        basic_augmentations=False,
        color_jitter=color_jitter,
        normalize=normalize)
    multicrop_transform = (multicrop, None)
    if multicrop > 0:
        multicrop_transform = make_multicrop_transform(
                dataset_name=dataset_name,
                num_crops=multicrop,
                size=mc_size,
                crop_scale=mc_scale,
                normalize=normalize,
                color_distortion=color_jitter)

    # -- init data-loaders/samplers
    (unsupervised_loader,
     unsupervised_sampler,
     supervised_loader,
     supervised_sampler) = init_data(
         dataset_name=dataset_name,
         transform=transform,
         init_transform=init_transform,
         supervised_views=supervised_views,
         u_batch_size=u_batch_size,
         s_batch_size=s_batch_size,
         unique_classes=unique_classes,
         classes_per_batch=classes_per_batch,
         multicrop_transform=multicrop_transform,
         world_size=args.world_size,
         rank=args.rank,
         root_path=root_path,
         image_folder=image_folder,
         training=True,
         copy_data=copy_data)
    # created lazily by load_imgs() and kept alive across epochs
    iter_supervised = None
    ipe = len(unsupervised_loader)
    logger.info(f'iterations per epoch: {ipe}')

    # -- init optimizer and scheduler
    scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
    # lr multiplier of the online classifier; 0 freezes it
    hfac = 0. if args.no_online else 0.1

    encoder, optimizer, scheduler = init_opt(
        encoder=encoder,
        weight_decay=wd,
        start_lr=start_lr,
        ref_lr=lr,
        final_lr=final_lr,
        ref_mom=mom,
        nesterov=nesterov,
        iterations_per_epoch=ipe,
        warmup=warmup,
        num_epochs=num_epochs,
        head_factor=hfac,
        class_wd=0.)

    if args.use_mom_scheduler:
        mom_scheduler = MomentumScheduler(
            init_mom=args.ema_decay,
            final_mom=args.final_ema_decay,
            warmup_from=0,
            warmup_epochs=args.mom_warmup_epochs,
            total_epochs=num_epochs,
            iter_per_epoch=ipe)
    else:
        mom_scheduler = None
        args.ema_mom = args.ema_decay

    start_epoch = 0
    num_swa = 0

    # -- automatically resume from the latest checkpoint if it exists
    if latest_path.is_file():
        encoder, optimizer, start_epoch, encoder_t, num_swa = load_checkpoint(
            r_path=latest_path,
            encoder=encoder,
            opt=optimizer,
            scaler=scaler,
            use_fp16=use_fp16,
            ema_encoder=encoder_t,
            use_swa=args.use_swa)
        # fast-forward both schedules to where the checkpoint left off
        for _ in range(start_epoch):
            for _ in range(ipe):
                scheduler.step()
                if mom_scheduler is not None:
                    args.ema_mom = mom_scheduler.get_mom()

    args.num_swa = num_swa
    CELoss = torch.nn.CrossEntropyLoss().cuda(gpu)
    # student without the DDP wrapper: used for the online classifier and as
    # the source of the teacher updates
    student = unwrap(encoder)

    # -- TRAINING LOOP
    best_loss = None
    for epoch in range(start_epoch, num_epochs):
        # keep a numbered checkpoint every 10 epochs, and every epoch near
        # the end of training
        if epoch > num_epochs - 10:
            checkpoint_freq = 1
        else:
            checkpoint_freq = 10
        logger.info('Epoch %d' % (epoch + 1))

        # -- update distributed-data-loader epoch
        unsupervised_sampler.set_epoch(epoch)
        if supervised_sampler is not None:
            supervised_sampler.set_epoch(epoch)

        loss_meter = AverageMeter()
        ploss_meter = AverageMeter()
        rloss_meter = AverageMeter()
        celoss_meter = AverageMeter()

        # NOTE(review): time_meter is never updated and nan_counter is never
        # incremented; both are logged with their initial values.
        time_meter = AverageMeter()
        data_meter = AverageMeter()
        mp_snn_meter = AverageMeter()
        acc_snn_meter = AverageMeter()
        ece_meter = AverageMeter()
        online_acc_meter = AverageMeter()
        online_acc_meter2 = AverageMeter()
        nan_counter = 0

        for itr, udata in enumerate(unsupervised_loader):

            def load_imgs():
                # the iterator lives in main_worker's scope so that it
                # survives across iterations and epochs
                nonlocal iter_supervised
                # -- unsupervised imgs
                uimgs = [u.to(gpu, non_blocking=True) for u in udata[:-1]]
                ulabels = udata[-1].to(gpu, non_blocking=True)
                # -- supervised imgs
                if iter_supervised is None:
                    iter_supervised = iter(supervised_loader)
                    logger.info(f'len.supervised_loader: {len(iter_supervised)}')
                try:
                    sdata = next(iter_supervised)
                except StopIteration:
                    # supervised loader exhausted: start a new pass
                    iter_supervised = iter(supervised_loader)
                    logger.info(f'len.supervised_loader: {len(iter_supervised)}')
                    sdata = next(iter_supervised)
                simgs = [s.to(gpu, non_blocking=True) for s in sdata[:-1]]
                hard_lab = sdata[-1].to(gpu, non_blocking=True)
                # smoothen_labels() also remaps `hard_lab` in place, so the
                # hard labels built below are the remapped ones
                new_labels_matrix = smoothen_labels(
                    hard_lab,
                    num_classes=num_classes,
                    max_classes=classes_per_batch*args.world_size,
                    smoothing=label_smoothing,
                    device=gpu)
                labels = torch.cat([new_labels_matrix for _ in range(supervised_views)])
                hard_labels = torch.cat([hard_lab for _ in range(supervised_views)])

                # -- concatenate supervised imgs and unsupervised imgs
                imgs = simgs + uimgs
                return imgs, labels, hard_labels, ulabels
            (imgs, labels, hard_lab, ulabels), dtime = gpu_timer(load_imgs)
            data_meter.update(dtime)

            with torch.cuda.amp.autocast(enabled=use_fp16):
                optimizer.zero_grad()

                # --
                # h: representations of 'imgs' before head
                # z: representations of 'imgs' after head
                # -- If use_pred_head=False, then encoder.pred (prediction
                #    head) is None, and _forward_head just returns the
                #    identity, z=h
                h, z, mid = encoder(imgs, return_before_head=True)
                if encoder_t is not None:
                    tar_h, tar_z, tar_mid = encoder_t(imgs, return_before_head=True)
                else:
                    # no teacher: targets are the student's own detached outputs
                    tar_h, tar_z, tar_mid = h.detach(), z.detach(), mid.detach()
                # Compute paws loss in full precision
                with torch.cuda.amp.autocast(enabled=False):

                    # Step 1. convert representations to fp32
                    h, z, mid = h.float(), z.float(), mid.float()
                    tar_h, tar_mid = tar_h.float(), tar_mid.float()
                    # Step 2. determine anchor views/supports and their
                    #         corresponding target views/supports
                    # --
                    num_support = supervised_views * s_batch_size * classes_per_batch

                    # --
                    anchor_supports = z[:num_support]
                    anchor_views = z[num_support:]
                    # --
                    target_supports = tar_h[:num_support].detach()
                    target_views = tar_h[num_support:].detach()

                    # swap the first two blocks of u_batch_size target views
                    target_views = torch.cat([
                        target_views[u_batch_size:2*u_batch_size],
                        target_views[:u_batch_size]], dim=0)

                    # Step 3. compute paws loss with me-max regularization
                    (ploss, me_max, acc_snn, mp_snn, ece) = paws(
                        anchor_views=anchor_views,
                        anchor_supports=anchor_supports,
                        anchor_support_labels=labels,
                        target_views=target_views,
                        target_supports=target_supports,
                        target_support_labels=labels,
                        ulab_true=ulabels)
                    loss = ploss + me_max

                    # online classifier on detached support features
                    logits = student.classifier(tar_mid[:num_support].detach())

                    celoss = CELoss(logits, hard_lab)
                    online_acc = torch.sum(torch.eq(torch.argmax(logits, dim=1), hard_lab)) / logits.size(0)
                    loss += celoss

            scaler.scale(loss).backward()
            lr_stats = scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            if mom_scheduler is not None:
                args.ema_mom = mom_scheduler.get_mom()

            # -- update the teacher
            with torch.no_grad():
                if args.use_ema:
                    for param_q, param_k in zip(student.parameters(), encoder_t_noddp.parameters()):
                        param_k.data.mul_(args.ema_mom).add_((1 - args.ema_mom) * param_q.detach().data)

                elif args.use_swa and epoch < args.swa_warmup:
                    # SWA warm-up: the teacher mirrors the student.  copy_()
                    # is required here: assigning `.data` would make both
                    # models share storage, and the in-place averaging below
                    # would then overwrite the student's weights.
                    for param_q, param_k in zip(student.parameters(), encoder_t_noddp.parameters()):
                        param_k.data.copy_(param_q.detach().data)

                elif args.use_swa:
                    # running average of the student weights
                    args.num_swa += 1
                    for param_q, param_k in zip(student.parameters(), encoder_t_noddp.parameters()):
                        param_k.data.mul_(args.num_swa).add_(param_q.detach().data).div_(args.num_swa + 1)

            loss_meter.update(float(loss))
            ploss_meter.update(float(ploss))
            rloss_meter.update(float(me_max))
            celoss_meter.update(float(celoss))
            acc_snn_meter.update(float(acc_snn))
            mp_snn_meter.update(float(mp_snn))
            ece_meter.update(float(ece))
            online_acc_meter.update(float(online_acc))
            # meter 2 tracks the accuracy after the reduce onto rank 0
            torch.distributed.reduce(online_acc.div_(args.world_size), 0)
            online_acc_meter2.update(float(online_acc))

            if (itr % log_freq == 0) or np.isnan(float(loss)) or np.isinf(float(loss)):
                csv_logger.log(epoch + 1, itr,
                               ploss_meter.avg,
                               rloss_meter.avg,
                               time_meter.avg,
                               celoss_meter.avg,
                               acc_snn_meter.avg,
                               mp_snn_meter.avg,
                               ece_meter.avg,
                               online_acc_meter.avg,
                               online_acc_meter2.avg)
                logger.info('[%d, %5d] loss: %.3f (%.3f %.3f) '
                            '(%d ms; %d ms)'
                            % (epoch + 1, itr,
                               loss_meter.avg,
                               ploss_meter.avg,
                               rloss_meter.avg,
                               time_meter.avg,
                               data_meter.avg))
                logger.info('ce: %.3f mp: %.3f pacc: %.3f ece: %.3f; onacc %.3f %.3f '
                            % (celoss_meter.avg,
                               mp_snn_meter.avg,
                               acc_snn_meter.avg,
                               ece_meter.avg,
                               online_acc_meter.avg,
                               online_acc_meter2.avg))
                if lr_stats is not None:
                    logger.info('[%d, %5d] lr_stats: %.3f (%.2e, %.2e)'
                                % (epoch + 1, itr,
                                   lr_stats.avg,
                                   lr_stats.min,
                                   lr_stats.max))

            assert not np.isnan(float(loss)), 'loss is nan'

        # -- logging/checkpointing
        logger.info('avg. loss %.3f' % loss_meter.avg)
        if args.rank == 0:
            tblogger.log_value('train/1.paws_train_loss', ploss_meter.avg, epoch+1)
            tblogger.log_value('train/2.memax_train_loss', rloss_meter.avg, epoch+1)
            tblogger.log_value('train/3.online_class_loss', celoss_meter.avg, epoch+1)
            tblogger.log_value('train/4.learning_rate', scheduler.get_last_lr()[0], epoch)
            tblogger.log_value('train/5.max_prob_mean_snn', mp_snn_meter.avg, epoch+1)
            tblogger.log_value('train/6.pseudoacc_snn', acc_snn_meter.avg, epoch+1)
            tblogger.log_value('train/7.ECE', ece_meter.avg, epoch+1)
            tblogger.log_value('train/8.online_class_acc', online_acc_meter.avg, epoch+1)
            tblogger.log_value('train/8.online_class_acc2', online_acc_meter2.avg, epoch+1)
            if args.use_mom_scheduler:
                tblogger.log_value('train/9.momentum', args.ema_mom, epoch+1)
            tblogger.log_value('train/10.NaN_counter', nan_counter, epoch+1)

            save_dict = {
                'encoder': encoder.state_dict(),
                'opt': optimizer.state_dict(),
                'epoch': epoch + 1,
                'unlabel_prob': unlabeled_frac,
                'loss': loss_meter.avg,
                's_batch_size': s_batch_size,
                'u_batch_size': u_batch_size,
                'world_size': args.world_size,
                'lr': lr,
                'temperature': temperature,
                'amp': scaler.state_dict(),
            }
            if args.use_swa:
                save_dict['num_swa'] = args.num_swa
            if encoder_t is not None:
                save_dict['ema_encoder'] = encoder_t.state_dict()
            torch.save(save_dict, latest_path)
            if best_loss is None or best_loss > loss_meter.avg:
                best_loss = loss_meter.avg
                logger.info('updating "best" checkpoint')
                torch.save(save_dict, best_path)
            # checkpoint_freq is either 1 or 10 here, so this single test is
            # equivalent to the former two-clause condition
            if (epoch + 1) % checkpoint_freq == 0:
                torch.save(save_dict, args.checkpoint_dir / f'ep{epoch+1}.pth.tar')


def load_checkpoint(
    r_path,
    encoder,
    opt,
    scaler,
    use_fp16=False,
    ema_encoder = None,
    use_swa = False
):
    """Restore training state from the checkpoint stored at ``r_path``.

    The checkpoint is read onto the CPU and its contents are loaded in place
    into the objects passed in.

    Args:
        r_path: Path of the checkpoint file.
        encoder: Model receiving ``checkpoint['encoder']``.
        opt: Optimizer receiving ``checkpoint['opt']``.
        scaler: AMP grad scaler; only touched when ``use_fp16`` is true, in
            which case it receives ``checkpoint['amp']``.
        use_fp16: Whether to restore the grad scaler.
        ema_encoder: Optional teacher model; when given it receives
            ``checkpoint['ema_encoder']``.
        use_swa: Whether to read the stored SWA counter.

    Returns:
        ``(encoder, opt, epoch, ema_encoder, num_swa)`` where ``epoch`` is the
        value stored in the checkpoint and ``num_swa`` is the stored counter
        when ``use_swa`` is set and the key exists, otherwise 0.
    """
    state = torch.load(r_path, map_location='cpu')
    epoch = state['epoch']

    # -- student weights
    encoder.load_state_dict(state['encoder'])
    logger.info(f'loaded encoder from epoch {epoch}')

    # -- teacher weights, only when a teacher model was supplied
    if ema_encoder is not None:
        ema_encoder.load_state_dict(state['ema_encoder'])
        logger.info(f'loaded EMA encoder from epoch {epoch}')

    # -- optimizer, plus the AMP scaler under mixed precision
    opt.load_state_dict(state['opt'])
    if use_fp16:
        scaler.load_state_dict(state['amp'])

    num_swa = state['num_swa'] if (use_swa and 'num_swa' in state) else 0

    logger.info(f'loaded optimizers from epoch {epoch}')
    logger.info(f'read-path: {r_path}')
    return encoder, opt, epoch, ema_encoder, num_swa


def _projection_head(hidden_dim, output_dim):
    """Return the fc-bn-relu, fc-bn-relu, fc projection MLP."""
    return torch.nn.Sequential(OrderedDict([
        ('fc1', torch.nn.Linear(hidden_dim, hidden_dim)),
        ('bn1', torch.nn.BatchNorm1d(hidden_dim)),
        ('relu1', torch.nn.ReLU(inplace=True)),
        ('fc2', torch.nn.Linear(hidden_dim, hidden_dim)),
        ('bn2', torch.nn.BatchNorm1d(hidden_dim)),
        ('relu2', torch.nn.ReLU(inplace=True)),
        ('fc3', torch.nn.Linear(hidden_dim, output_dim))
    ]))


def _prediction_head(output_dim, mx=4):
    """Return the bn-fc-bn-relu-fc prediction head with an ``mx``-fold bottleneck."""
    return torch.nn.Sequential(OrderedDict([
        ('bn1', torch.nn.BatchNorm1d(output_dim)),
        ('fc1', torch.nn.Linear(output_dim, output_dim//mx)),
        ('bn2', torch.nn.BatchNorm1d(output_dim//mx)),
        ('relu', torch.nn.ReLU(inplace=True)),
        ('fc2', torch.nn.Linear(output_dim//mx, output_dim))
    ]))


def init_model(
    device,
    model_name='resnet50',
    use_pred=False,
    output_dim=128,
    use_teacher=None
):
    """Build the student encoder and, optionally, an identically shaped teacher.

    Both networks get a fresh projection head as ``.fc`` and a ``.pred``
    attribute that is either a prediction head or ``None``.

    Args:
        device: Device the model(s) are moved to.
        model_name: Key looked up in the backbone module's ``__dict__``.
            Names containing ``'wide_resnet'`` use ``src.wide_resnet``; all
            others use ``src.resnet``, where ``'w2'`` / ``'w4'`` in the name
            widen the hidden dimension by 2x / 4x.
        use_pred: Whether to attach the bottleneck prediction head.
        output_dim: Output width of the projection head.
        use_teacher: Whether to also build the teacher.  When left as
            ``None`` the decision falls back to a module-level ``args``
            (``args.use_ema or args.use_swa``), which is what this function
            did before the parameter existed; pass it explicitly to avoid
            depending on that global.

    Returns:
        ``(encoder, encoder_t)``; ``encoder_t`` is ``None`` when no teacher
        is built.
    """
    if use_teacher is None:
        # backward-compatible fallback; requires a module-level `args`
        use_teacher = bool(args.use_ema or args.use_swa)

    if 'wide_resnet' in model_name:
        # NOTE(review): the module-level import of src.wide_resnet is
        # commented out, which left `wide_resnet` unbound in this branch.
        # Import it lazily; confirm the module still exists.
        import src.wide_resnet as wide_resnet

        def build_backbone():
            return wide_resnet.__dict__[model_name](dropout_rate=0.0)

        hidden_dim = 128
    else:
        def build_backbone():
            return resnet.__dict__[model_name]()

        hidden_dim = 2048
        if 'w2' in model_name:
            hidden_dim *= 2
        elif 'w4' in model_name:
            hidden_dim *= 4

    # Modules are created in the same order as before (student backbone,
    # teacher backbone, student/teacher projection heads, student/teacher
    # prediction heads) so that seeded initialisation is unchanged.
    encoder = build_backbone()
    encoder_t = build_backbone() if use_teacher else None

    # -- projection head
    encoder.fc = _projection_head(hidden_dim, output_dim)
    if use_teacher:
        encoder_t.fc = _projection_head(hidden_dim, output_dim)

    # -- prediction head (4x bottleneck), or None when disabled
    encoder.pred = _prediction_head(output_dim) if use_pred else None
    if use_teacher:
        encoder_t.pred = _prediction_head(output_dim) if use_pred else None
        encoder_t.to(device)

    encoder.to(device)
    logger.info(encoder)

    return encoder, encoder_t


def init_opt(
    encoder,
    iterations_per_epoch,
    start_lr,
    ref_lr,
    ref_mom,
    nesterov,
    warmup,
    num_epochs,
    weight_decay=1e-6,
    final_lr=0.0,
    head_factor=0.1,
    class_wd=0.
):
    """Build the LARS-wrapped SGD optimizer and its warmup-cosine schedule.

    Parameters of ``encoder`` are split by name into three groups:

    1. everything that is not a bias, not under a name containing ``'bn'``
       and not part of the classifier (default weight decay);
    2. biases and ``'bn'`` parameters outside the classifier (no weight
       decay, flagged ``LARS_exclude``);
    3. parameters whose name contains ``'classifier'`` (learning rate
       ``ref_lr * head_factor``, weight decay ``class_wd``).

    Args:
        encoder: Model whose ``named_parameters()`` are optimised.
        iterations_per_epoch: Scheduler steps per epoch.
        start_lr: Learning rate at the start of warmup.
        ref_lr: Reference (peak) learning rate.
        ref_mom: SGD momentum.
        nesterov: Whether SGD uses Nesterov momentum.
        warmup: Number of warmup epochs.
        num_epochs: Total number of training epochs.
        weight_decay: Weight decay of the first group.
        final_lr: Learning rate handed to the schedule as its final value.
        head_factor: Learning-rate multiplier of the classifier group.
        class_wd: Weight decay of the classifier group.

    Returns:
        ``(encoder, optimizer, scheduler)``; the scheduler is attached to the
        inner SGD optimizer, the returned optimizer is its LARS wrapper.
    """
    def in_classifier(name):
        return 'classifier' in name

    def is_bias_or_bn(name):
        return ('bias' in name) or ('bn' in name)

    # materialise once so every group is a concrete list
    named = list(encoder.named_parameters())
    param_groups = [
        {'params': [p for n, p in named
                    if not is_bias_or_bn(n) and not in_classifier(n)]},
        {'params': [p for n, p in named
                    if is_bias_or_bn(n) and not in_classifier(n)],
         'LARS_exclude': True,
         'weight_decay': 0},
        # online classifier head: scaled lr and its own weight decay
        {'params': [p for n, p in named if in_classifier(n)],
         'lr': ref_lr*head_factor,
         'weight_decay': class_wd},
    ]

    optimizer = SGD(
        param_groups,
        weight_decay=weight_decay,
        # honour the caller's momentum; it used to be hard-coded to 0.9,
        # which silently ignored `ref_mom`
        momentum=ref_mom,
        nesterov=nesterov,
        lr=ref_lr)
    scheduler = WarmupCosineSchedule(
        optimizer,
        warmup_steps=warmup*iterations_per_epoch,
        start_lr=start_lr,
        ref_lr=ref_lr,
        final_lr=final_lr,
        T_max=num_epochs*iterations_per_epoch)
    optimizer = LARS(optimizer, trust_coefficient=0.001)
    return encoder, optimizer, scheduler

def smoothen_labels(hard_labels, num_classes=1000, max_classes=960, smoothing=0.1, device='cpu'):
    """Turn integer labels into a smoothed label matrix with ``max_classes`` columns.

    Every entry of the result is ``smoothing / num_classes`` except the entry
    at each row's label, which is ``1 - (num_classes - 1) * smoothing / num_classes``.

    ``hard_labels`` is modified in place: labels ``>= max_classes`` are
    shifted down by ``num_classes - max_classes`` and anything that ends up
    negative is set to 0.  Callers see the remapped labels afterwards.

    Args:
        hard_labels: 1-D integer label tensor (mutated in place).
        num_classes: Total number of classes, used for the smoothing mass.
        max_classes: Number of columns in the returned matrix.
        smoothing: Label-smoothing amount.
        device: Device of the returned matrix.

    Returns:
        Tensor of shape ``(len(hard_labels), max_classes)`` on ``device``.
    """
    off_value = smoothing/num_classes
    on_value = 1 - off_value * (num_classes-1)

    # fold labels beyond the representable range back into it, in place
    overflow = hard_labels >= max_classes
    hard_labels[overflow] -= (num_classes-max_classes)
    hard_labels.clamp_(min=0)

    smoothed = torch.full((len(hard_labels), max_classes), off_value, device=device)
    smoothed.scatter_(1, hard_labels.unsqueeze(1), on_value)
    return smoothed

class MomentumScheduler:
    """Per-iteration momentum schedule: linear warmup followed by a cosine ramp.

    The full schedule is precomputed in ``cos_schedule``.  The first
    ``iter_per_epoch * warmup_epochs`` entries go linearly from
    ``warmup_from`` to ``init_mom``; the remaining
    ``iter_per_epoch * (total_epochs - warmup_epochs)`` entries follow a
    half cosine that starts at ``init_mom`` and moves towards ``final_mom``.
    """

    def __init__(self, init_mom=0.996, final_mom=1, warmup_from=0, warmup_epochs=0, total_epochs = 800, iter_per_epoch=1):
        n_warmup = iter_per_epoch * warmup_epochs
        n_cosine = iter_per_epoch * (total_epochs - warmup_epochs)

        linear_part = np.linspace(warmup_from, init_mom, n_warmup)

        # half-cosine phase, one entry per post-warmup iteration
        phase = np.pi * np.arange(n_cosine) / n_cosine
        cosine_part = final_mom + 0.5 * (init_mom - final_mom) * (1 + np.cos(phase))

        self.cos_schedule = np.concatenate((linear_part, cosine_part))
        # index of the value handed out last; -1 means nothing yet
        self.iter = -1

    def get_mom(self):
        """Advance by one iteration and return that iteration's momentum.

        Raises:
            IndexError: when called more times than the schedule has entries.
        """
        self.iter += 1
        return self.cos_schedule[self.iter]

# NOTE(review): no `main` is defined or imported in this file, so running the
# module as a script raises NameError on the call below.  `main_worker` is the
# only training entry point defined here; confirm whether a launcher is meant
# to live elsewhere.
if __name__ == "__main__":
    main()
