from argparse import Namespace, ArgumentParser

import os
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets
import torchvision.transforms as transforms
from utils import datautils
import models
from utils import utils
import numpy as np
import PIL
from tqdm import tqdm
import sklearn
from utils.lars_optimizer import LARS
import scipy
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

from utils.datautils import PairViewGenerator

import copy
import pdb
import math
import random

class BaseSSL(nn.Module):
    """Shared scaffolding for the self-supervised models in this file.

    Stores the hyper-parameter namespace, builds the CIFAR-10 / ImageNet
    datasets and their loaders, and provides checkpoint helpers together with
    the ``add_model_hparams`` machinery used to declare command-line options.
    ``forward`` and ``transforms`` are placeholders that subclasses override.
    """
    DATA_ROOT = os.environ.get('DATA_ROOT', os.path.dirname(os.path.abspath(__file__)) + '/data')

    def __init__(self, hparams):
        """Keep ``hparams`` and, for ImageNet, resolve the dataset location.

        The ``IMAGENET_PATH`` environment variable takes precedence over
        ``hparams.data_path``.
        """
        super().__init__()
        self.hparams = hparams
        if hparams.data == 'imagenet':
            self.IMAGENET_PATH = os.environ.get('IMAGENET_PATH', self.hparams.data_path)
            print(f"IMAGENET_PATH = {self.IMAGENET_PATH}")

    def get_ckpt(self):
        """Return a checkpoint dict holding the state dict and the hparams."""
        return {
            'state_dict': self.state_dict(),
            'hparams': self.hparams,
        }

    @classmethod
    def load(cls, ckpt, device=None):
        """Rebuild a model from a checkpoint produced by ``get_ckpt``.

        Options missing from ``ckpt['hparams']`` are filled in with the parser
        defaults (the stored namespace is updated in place by argparse).
        The class being loaded must accept a ``device`` keyword in its
        constructor.
        """
        parser = ArgumentParser()
        cls.add_model_hparams(parser)
        hparams = parser.parse_args([], namespace=ckpt['hparams'])

        res = cls(hparams, device=device)
        res.load_state_dict(ckpt['state_dict'])
        return res

    @classmethod
    def default(cls, device=None, **kwargs):
        """Build a model from parser defaults, overridden by ``kwargs``.

        The class being built must accept a ``device`` keyword in its
        constructor.
        """
        parser = ArgumentParser()
        cls.add_model_hparams(parser)
        hparams = parser.parse_args([], namespace=Namespace(**kwargs))
        return cls(hparams, device=device)

    def forward(self, x):
        """Placeholder; subclasses provide the real forward pass."""
        pass

    def transforms(self):
        """Placeholder; subclasses return ``(train_transform, test_transform)``."""
        pass

    def samplers(self):
        """Return ``(train_batch_sampler, test_batch_sampler)``; none by default."""
        return None, None

    def prepare_data(self, taskaug=None):
        """Create ``self.trainset`` / ``self.testset`` for ``hparams.data``.

        ``taskaug`` is accepted but not used here.

        Raises:
            NotImplementedError: if the dataset is neither ``'cifar'`` nor
                ``'imagenet'``.
        """
        train_transform, test_transform = self.transforms()
        if self.hparams.data == 'cifar':
            self.trainset = datasets.CIFAR10(root=self.DATA_ROOT, train=True, download=True, transform=train_transform)
            self.testset = datasets.CIFAR10(root=self.DATA_ROOT, train=False, download=True, transform=test_transform)
        elif self.hparams.data == 'imagenet':
            traindir = os.path.join(self.IMAGENET_PATH, 'train')
            valdir = os.path.join(self.IMAGENET_PATH, 'val')
            self.trainset = datasets.ImageFolder(traindir, transform=train_transform)
            self.testset = datasets.ImageFolder(valdir, transform=test_transform)
        else:
            raise NotImplementedError(f'Unsupported dataset: {self.hparams.data!r}')

    def dataloaders(self, iters=None):
        """Build the train and test ``DataLoader`` objects.

        Args:
            iters: when given, the train batch sampler is wrapped in
                ``datautils.ContinousSampler`` with this value.

        Returns:
            ``(train_loader, test_loader)``.
        """
        train_batch_sampler, test_batch_sampler = self.samplers()
        if iters is not None:
            train_batch_sampler = datautils.ContinousSampler(
                train_batch_sampler,
                iters
            )

        train_loader = torch.utils.data.DataLoader(
            self.trainset,
            num_workers=self.hparams.workers,
            pin_memory=True,
            batch_sampler=train_batch_sampler,
        )
        test_loader = torch.utils.data.DataLoader(
            self.testset,
            num_workers=self.hparams.workers,
            pin_memory=True,
            batch_sampler=test_batch_sampler,
        )

        return train_loader, test_loader

    @staticmethod
    def _parse_bool(value):
        """Convert a command-line token into a real boolean.

        ``type=bool`` would map every non-empty string (including
        ``"False"``) to ``True``. Unknown spellings raise ``ValueError``,
        which argparse reports as an invalid argument value.
        """
        if isinstance(value, bool):
            return value
        token = str(value).strip().lower()
        if token in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if token in ('0', 'false', 'f', 'no', 'n', 'off', ''):
            return False
        raise ValueError(f'expected a boolean value, got {value!r}')

    @staticmethod
    def add_parent_hparams(add_model_hparams):
        """Decorator that registers the parents' options before the class's own.

        The wrapped function is meant to be used under ``@classmethod``; it
        first calls ``add_model_hparams`` on every direct base class and then
        the decorated function itself.
        """
        def with_parent_options(cls, parser):
            for base in cls.__bases__:
                base.add_model_hparams(parser)
            add_model_hparams(cls, parser)
        return with_parent_options

    @classmethod
    def add_model_hparams(cls, parser):
        """Register the options shared by every model on ``parser``."""
        parser.add_argument('--data', help='Dataset to use', default='cifar')
        parser.add_argument('--drop', type=cls._parse_bool, help='whether drop a path in the resnet', default=False)
        parser.add_argument('--arch', default='ResNet50', help='Encoder architecture')
        parser.add_argument('--batch_size', default=256, type=int, help='The number of unique images in the batch')
        parser.add_argument('--aug', default=True, type=cls._parse_bool, help='Applies random augmentations if True')


class SimCLR(BaseSSL):
    @classmethod
    @BaseSSL.add_parent_hparams
    def add_model_hparams(cls, parser):
        """Register the SimCLR-specific options on ``parser``.

        Boolean options are parsed with a real string-to-bool converter:
        ``type=bool`` maps every non-empty string (even ``"False"``) to True.
        """
        def _str2bool(value):
            if isinstance(value, bool):
                return value
            token = str(value).strip().lower()
            if token in ('1', 'true', 't', 'yes', 'y', 'on'):
                return True
            if token in ('0', 'false', 'f', 'no', 'n', 'off', ''):
                return False
            # argparse turns ValueError into an "invalid value" usage error
            raise ValueError(f'expected a boolean value, got {value!r}')

        # loss params
        # NOTE(review): '--temperature' is not declared here although __init__
        # reads hparams.temperature; it has to come from the caller's
        # parser/namespace.
        # NOTE(review): several help strings below look copy-pasted and do not
        # describe their option; they are left untouched until the intended
        # meaning is confirmed.
        parser.add_argument('--reg', default=50.0, type=float, help='Temperature in the NTXent loss')
        parser.add_argument('--mix_p', default=1., type=float, help='Temperature in the NTXent loss')
        parser.add_argument('--head', default='contrastive', type=str, help='Temperature in the NTXent loss')
        parser.add_argument('--qcm', default=False, type=_str2bool, help='query cutmix if True')
        parser.add_argument('--adpt', default=False, type=_str2bool, help='query cutmix if True')
        parser.add_argument('--mix_method', default='mixup', type=str, help='query cutmix if True')
        parser.add_argument('--ssm', default=False, type=_str2bool, help='query cutmix if True')
        parser.add_argument('--rottlr', default=-1, type=float, help='query cutmix if True')
        parser.add_argument('--rottlr_p', default=-1, type=float, help='query cutmix if True')
        parser.add_argument('--mixtlr', default=-1, type=float, help='query cutmix if True')
        parser.add_argument('--ori_reg', default=1, type=float, help='query cutmix if True')
        parser.add_argument('--mix_reg', default=1, type=float, help='query cutmix if True')
        parser.add_argument('--maxup', default=False, type=_str2bool, help='query cutmix if True')
        # data params
        parser.add_argument('--multiplier', default=2, type=int)
        parser.add_argument('--rot_div', default=1, type=int)
        parser.add_argument('--bs', default=1, type=int)
        parser.add_argument('--n_supp', default=1, type=int)
        parser.add_argument('--n_query', default=1, type=int)
        parser.add_argument('--maxup_m', default=1, type=int)
        parser.add_argument('--n_proto', default=1, type=int)
        parser.add_argument('--kmeans_iters', default=1, type=int)
        parser.add_argument('--color_dist_s', default=1., type=float, help='Color distortion strength')
        parser.add_argument('--scale_lower', default=0.08, type=float, help='The minimum scale factor for RandomResizedCrop')
        # ddp
        parser.add_argument('--sync_bn', default=True, type=_str2bool,
            help='Syncronises BatchNorm layers between all processes if True'
        )

    def __init__(self, hparams, device=None):
        """Build the encoder, wrap it for (distributed) data parallelism and
        create the loss criteria selected by ``hparams``.

        Args:
            hparams: namespace with the model options; ``dist`` defaults to
                ``'dp'`` when absent.
            device: optional device the encoder is moved to before wrapping.

        Raises:
            NotImplementedError: if ``hparams.dist`` is neither ``'ddp'`` nor
                ``'dp'``.
        """
        super().__init__(hparams)

        self.hparams.dist = getattr(self.hparams, 'dist', 'dp')

        model = models.encoder.EncodeProject(hparams)
        # NOTE(review): `model` is still a local variable at this point and is
        # only assigned to `self.model` further down, so reset_parameters()
        # presumably finds no Conv2d/Linear layers via self.modules() here —
        # confirm whether skipping the custom init is intended.
        self.reset_parameters()
        if device is not None:
            model = model.to(device)
        if self.hparams.dist == 'ddp':
            if self.hparams.sync_bn:
                model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
            # wait for every process before wrapping the model in DDP
            dist.barrier()
            if device is not None:
                model = model.to(device)
            self.model = DDP(model, [hparams.gpu], find_unused_parameters=True)
        elif self.hparams.dist == 'dp':
            self.model = nn.DataParallel(model)
        else:
            raise NotImplementedError

        # optional rotation-prediction loss, enabled by a positive rottlr
        if self.hparams.rottlr > 0:
            self.criterion_rot = models.losses.Xent_rot_random(
                tau=self.hparams.temperature,
                multiplier=self.hparams.multiplier,
                distributed=(self.hparams.dist == 'ddp'),
            )

        # main criterion, chosen by hparams.head; the test criterion is the
        # same object in every branch.
        # NOTE(review): an unrecognised head leaves self.criterion unset.
        if self.hparams.head == 'contrastive':
            self.criterion = models.losses.NTXent(
                tau=hparams.temperature,
                bs=hparams.bs,
                tlr=hparams.rottlr,
                multiplier=hparams.multiplier,
                distributed=(hparams.dist == 'ddp'),
            )
            self.test_criterion = self.criterion
        elif self.hparams.head == 'PN':
            self.criterion = models.losses.NTXent_PN(
                    bs=hparams.bs,
                    n_supp=hparams.n_supp,
                    n_query=hparams.n_query,
                    tau=hparams.temperature,
                    tlr=hparams.rottlr,
                    reg=hparams.reg,
                    multiplier=hparams.multiplier,
                    distributed=(hparams.dist == 'ddp'),
                )

            # same loss as above but with tlr fixed to -1, for mixed inputs
            self.criterion_mix = models.losses.NTXent_PN(
                    bs=hparams.bs,
                    n_supp=hparams.n_supp,
                    n_query=hparams.n_query,
                    tau=hparams.temperature,
                    tlr=-1,
                    reg=hparams.reg,
                    multiplier=hparams.multiplier,
                    distributed=(hparams.dist == 'ddp'),
                )
            self.test_criterion = self.criterion
        elif self.hparams.head == 'R2D2':
            if self.hparams.ssm:
                # with ssm, half of the query views are taken out of both the
                # query count and the multiplier passed to the loss
                n_query = int(hparams.n_query//2)
                multiplier = hparams.multiplier - n_query
                self.criterion = models.losses.NTXent_R2D2(
                    n_supp=hparams.n_supp,
                    n_query=n_query,
                    tau=hparams.temperature,
                    reg=hparams.reg,
                    multiplier=multiplier,
                    distributed=(hparams.dist == 'ddp'),
                )
                self.test_criterion = self.criterion

            else:
                self.criterion = models.losses.NTXent_R2D2(
                    bs=hparams.bs,
                    n_supp=hparams.n_supp,
                    n_query=hparams.n_query,
                    tau=hparams.temperature,
                    tlr=hparams.rottlr,
                    reg=hparams.reg,
                    multiplier=hparams.multiplier,
                    distributed=(hparams.dist == 'ddp'),
                )
                self.test_criterion = self.criterion

    def reset_parameters(self):
        """Re-initialise the weights of every Conv2d and Linear submodule.

        Conv2d weights are drawn from a normal distribution truncated to
        [-2, 2] and scaled using ``weight.shape[1]`` as fan-in; Linear weights
        are drawn from a normal distribution with std 0.01. Biases are left
        untouched.
        """
        def init_conv_weight(weight):
            n_inputs = weight.shape[1]
            # the divisor is presumably the std of a unit normal truncated to
            # [-2, 2] — TODO confirm
            scale = np.sqrt(1. / n_inputs) / .87962566103423978
            sample = scipy.stats.truncnorm.rvs(-2, 2, loc=0, scale=1., size=weight.shape)
            sample = scale * sample
            with torch.no_grad():
                weight.copy_(torch.FloatTensor(sample))

        def init_linear_weight(weight):
            with torch.no_grad():
                weight.normal_(std=0.01)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                init_conv_weight(module.weight)
            elif isinstance(module, nn.Linear):
                init_linear_weight(module.weight)


    def rand_bbox(self, size, lam=0.5):
        """Sample a random box covering roughly ``1 - lam`` of the image area.

        Args:
            size: shape whose last two entries are read as ``H`` (``size[-2]``)
                and ``W`` (``size[-1]``).
            lam: mixing coefficient; the box side ratio is ``sqrt(1 - lam)``.

        Returns:
            ``(bbx1, bby1, bbx2, bby2)`` clipped to ``[0, W]`` / ``[0, H]``.
            The box centre is drawn uniformly from the global NumPy RNG, so
            the clipped box can be smaller than the nominal size.
        """
        W = size[-1]
        H = size[-2]
        cut_rat = np.sqrt(1. - lam)
        # builtin int: the ``np.int`` alias was removed in NumPy 1.24
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)

        # uniform centre; x is drawn before y to keep the RNG sequence stable
        cx = np.random.randint(W)
        cy = np.random.randint(H)

        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)

        return bbx1, bby1, bbx2, bby2


    def step(self, batch, it):
        """Compute the training loss for one batch.

        The branch taken depends on the hyper-parameters:

        * ``rottlr > 0`` and ``qcm``: original loss + mixed-query loss +
          rotation-prediction loss.
        * ``qcm`` only: original loss + mixed-query loss.
        * ``rottlr > 0`` and ``mixtlr > 0``: with probability ``mixtlr`` the
          loss on mixed inputs is added to the loss on the batch extended by
          rotated copies; otherwise the plain loss is used.
        * ``rottlr > 0`` only: original loss + rotation-prediction loss.
        * ``rottlr_p > 0``: paired views are rotated at random before the
          plain loss.
        * otherwise: plain loss.

        The NumPy / torch (and in the last branch ``random``) generators are
        re-seeded with ``it`` inside the branches, so the random choices are
        a function of the iteration number.

        Args:
            batch: ``(x, _)``; only the images are used.
            it: iteration number used as RNG seed.

        Returns:
            dict with ``'loss'`` and ``'contrast_acc'``.
        """
        if self.hparams.rottlr > 0 and self.hparams.qcm:
            np.random.seed(it)
            # `r` is unused here but the draw keeps `prob` the second sample
            r = np.random.rand(1)
            prob = np.random.rand(1)

            np.random.seed(it)
            lam = np.random.beta(1., 1.)
            torch.manual_seed(it)

            x, _ = batch

            ## mix and rotate image batch
            # first n_supp * batch_size samples vs. the remainder
            xs = x[:self.hparams.n_supp * self.hparams.batch_size]
            xq = x[self.hparams.n_supp * self.hparams.batch_size:]

            # batch-reversed copies serve as mixing partners
            xs_r = torch.flip(xs, (0,))
            xq_r = torch.flip(xq, (0,))

            if prob < 0.5:
                # global-level mixtures
                mixed_q = lam*xq+(1-lam)* xq_r
            else:
                # region-level mixtures
                mixed_q = xq.clone()

                bbx1, bby1, bbx2, bby2 = self.rand_bbox(xs.size(), lam)
                mixed_q[:, :, bbx1:bbx2, bby1:bby2] = xq_r[:, :, bbx1:bbx2, bby1:bby2]

                # lam becomes the area fraction that was actually kept
                lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (xs.size()[-1] * xs.size()[-2]))

            x_mixed = mixed_q

            # rotate the first rot_len images by a random multiple of 90 degrees
            x_rots = []
            rot_len = int(x.shape[0]/self.hparams.rot_div)
            ks = torch.tensor(np.random.randint(4, size=rot_len)).cuda()
            for i in range(rot_len):
                x_rots.append(torch.rot90(x[i], ks[i], [1, 2]))
            x_rots = torch.stack(x_rots)

            ## ori & mixup
            x_mixed = torch.cat([xs, xq, mixed_q])
            z = self.model(x_mixed)

            # NOTE(review): the slices below use batch_size rather than
            # n_supp * batch_size — presumably assumes n_supp == 1; confirm.
            zs = z[:self.hparams.batch_size]
            zq = z[self.hparams.batch_size:(2*self.hparams.batch_size)]
            zqmixed = z[(2*self.hparams.batch_size):]

            z_ori = torch.cat([zs, zq])
            z_mixed1 = torch.cat([zs, zqmixed])
            z_mixed2 = torch.cat([zs, torch.flip(zqmixed, (0,))])

            loss_mixed1, acc = self.criterion(z_mixed1)
            loss_mixed2, acc = self.criterion(z_mixed2)

            # the reported accuracy is the one of the unmixed batch
            loss_ori, acc = self.criterion(z_ori)

            ##loss for large rotation
            z_rots = self.model(x_rots, out='rot')
            loss_rot = self.criterion_rot(z_rots, ks)

            ##loss for query mix term
            loss_mix = lam * loss_mixed1 + (1-lam) * loss_mixed2

            rot_reg = self.hparams.rottlr
            loss = self.hparams.ori_reg * loss_ori + rot_reg * loss_rot + self.hparams.mix_reg * loss_mix
        elif self.hparams.qcm:
            np.random.seed(it)
            # `r` is unused here but the draw keeps `prob` the second sample
            r = np.random.rand(1)
            prob = np.random.rand(1)

            np.random.seed(it)
            lam = np.random.beta(1., 1.)
            torch.manual_seed(it)
            x, _ = batch

            xs = x[:self.hparams.n_supp * self.hparams.batch_size]
            xq = x[self.hparams.n_supp * self.hparams.batch_size:]

            xs_r = torch.flip(xs, (0,))
            xq_r = torch.flip(xq, (0,))

            if prob < 0.5:
                # global-level mixtures
                mixed_q = lam*xq+(1-lam)* xq_r
            else:
                # region-level mixtures
                mixed_q = xq.clone()

                bbx1, bby1, bbx2, bby2 = self.rand_bbox(xs.size(), lam)
                mixed_q[:, :, bbx1:bbx2, bby1:bby2] = xq_r[:, :, bbx1:bbx2, bby1:bby2]

                lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (xs.size()[-1] * xs.size()[-2]))

            x_mixed = torch.cat([xs, xq, mixed_q])
            z = self.model(x_mixed)

            zs = z[:self.hparams.batch_size]
            zq = z[self.hparams.batch_size:(2*self.hparams.batch_size)]
            zqmixed = z[(2*self.hparams.batch_size):]

            z_ori = torch.cat([zs, zq])
            z_mixed1 = torch.cat([zs, zqmixed])
            z_mixed2 = torch.cat([zs, torch.flip(zqmixed, (0,))])

            loss_mixed1, acc = self.criterion(z_mixed1)
            loss_mixed2, acc = self.criterion(z_mixed2)

            loss_ori, acc = self.criterion(z_ori)
            loss = self.hparams.ori_reg * loss_ori + lam * loss_mixed1 + (1-lam) * loss_mixed2
        elif self.hparams.rottlr > 0 and self.hparams.mixtlr > 0:
            np.random.seed(it)
            r = np.random.rand(1)
            prob = np.random.rand(1)

            np.random.seed(it)
            lam = np.random.beta(1., 1.)
            # mixing is applied only on a fraction `mixtlr` of the iterations
            if r < self.hparams.mixtlr:
                torch.manual_seed(it)
                x, _ = batch

                xs = x[:self.hparams.n_supp * self.hparams.batch_size]
                xq = x[self.hparams.n_supp * self.hparams.batch_size:]

                xs_r = torch.flip(xs, (0,))
                xq_r = torch.flip(xq, (0,))

                if prob < 0.5:
                    # global-level mixtures
                    mixed_s = lam*xs+(1-lam)* xs_r
                    mixed_q = lam*xq+(1-lam)* xq_r
                else:
                    # region-level mixtures
                    mixed_s = xs.clone()
                    mixed_q = xq.clone()

                    bbx1, bby1, bbx2, bby2 = self.rand_bbox(xs.size(), lam)
                    mixed_s[:, :, bbx1:bbx2, bby1:bby2] = xs_r[:, :, bbx1:bbx2, bby1:bby2]
                    mixed_q[:, :, bbx1:bbx2, bby1:bby2] = xq_r[:, :, bbx1:bbx2, bby1:bby2]

                    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (xs.size()[-1] * xs.size()[-2]))

                x_mixed = torch.cat([mixed_s, mixed_q])
                z_mixed = self.model(x_mixed)

                # a fresh NTXent with tlr=-1 is (re)created for the mixed batch
                self.criterion_mix = models.losses.NTXent(
                    tau=self.hparams.temperature,
                    bs=self.hparams.bs,
                    tlr=-1,
                    multiplier=self.hparams.multiplier,
                    distributed=(self.hparams.dist == 'ddp'),
                )

                loss_mixed, acc = self.criterion_mix(z_mixed)

                # append 90/180/270 degree rotations of the first n_way images
                n_way = int(x.shape[0]/self.hparams.multiplier)
                x_ = x[:n_way]
                x_rots = []
                rotate_labels = [1,2,3]
                for k in rotate_labels:
                    x_rots.append(torch.rot90(x_, k, [2, 3]))
                x_rots = torch.cat(x_rots)
                x = torch.cat([x, x_rots])

                z = self.model(x)
                loss_ori, acc = self.criterion(z)
                loss = self.hparams.ori_reg * loss_ori + loss_mixed
            else:
                torch.manual_seed(it)
                x, _ = batch
                z = self.model(x)
                loss, acc = self.criterion(z)
        elif self.hparams.rottlr > 0:
            torch.manual_seed(it)

            x, _ = batch
            n_way = int(x.shape[0]/self.hparams.multiplier)

            # rotate the first rot_len images by a random multiple of 90 degrees
            x_rots = []
            rot_len = int(x.shape[0]/self.hparams.rot_div)
            ks = torch.tensor(np.random.randint(4, size=rot_len)).cuda()
            for i in range(rot_len):
                x_rots.append(torch.rot90(x[i], ks[i], [1, 2]))
            x_rots = torch.stack(x_rots)

            z = self.model(x)
            z_rots = self.model(x_rots, out='rot')

            ##original loss from contrastive learning or meta-learning
            loss_ori, acc = self.criterion(z)

            ##loss for large rotation
            loss_rot = self.criterion_rot(z_rots, ks)

            rot_reg = (self.hparams.rottlr)
            loss = self.hparams.ori_reg * loss_ori + rot_reg * loss_rot

        elif self.hparams.rottlr_p > 0:
            torch.manual_seed(it)
            np.random.seed(it)

            x, _ = batch
            n_way = int(x.shape[0]/self.hparams.multiplier)

            prob = np.random.rand(n_way)
            ks = np.random.randint(4, size=n_way)
            for i in range(n_way):
                if prob[i] <= self.hparams.rottlr_p:
                    x[i] = torch.rot90(x[i], ks[i], [1, 2])
                    # The paired view sits n_way positions later. This used a
                    # hard-coded 256, which is only right when each batch
                    # holds 256 unique images.
                    # NOTE(review): only the first two views are rotated, as
                    # before; confirm the intent for multiplier > 2.
                    x[i + n_way] = torch.rot90(x[i + n_way], ks[i], [1, 2])

            z = self.model(x)
            loss, acc = self.criterion(z)
        else:
            torch.manual_seed(it)
            random.seed(it)
            x, _ = batch
            z = self.model(x)
            loss, acc = self.criterion(z)
        return {
            'loss': loss,
            'contrast_acc': acc,
        }


    def self_mix(self, data):
        """Overwrite a random box of ``data`` with another box of the same batch.

        A box of at most half the width/height is placed around a random
        centre, a second location of the same size is drawn (re-drawn until
        it differs from the first), and the content found there — rotated by
        a random multiple of 90 degrees when the box is square — is written
        into the first box. ``data`` is modified in place and returned.
        """
        dims = data.size()
        width, height = dims[-1], dims[-2]

        # centre: x is drawn before y to keep the RNG sequence unchanged
        center_x = np.random.randint(width)
        center_y = np.random.randint(height)

        box_w = width // 2
        box_h = height // 2

        x_lo = np.clip(center_x - box_w // 2, 0, width)
        y_lo = np.clip(center_y - box_h // 2, 0, height)
        x_hi = np.clip(center_x + box_w // 2, 0, width)
        y_hi = np.clip(center_y + box_h // 2, 0, height)
        span_x = x_hi - x_lo
        span_y = y_hi - y_lo

        # source corner, re-drawn while it coincides with the target corner
        src_x = np.random.randint(0, width - span_x)
        src_y = np.random.randint(0, height - span_y)
        while src_x == x_lo and src_y == y_lo:
            src_x = np.random.randint(0, width - span_x)
            src_y = np.random.randint(0, height - span_y)

        # a rotation is only drawn for square boxes
        turns = random.sample([0, 1, 2, 3], 1)[0] if span_x == span_y else 0

        patch = data[:, :, src_x:src_x + span_x, src_y:src_y + span_y]
        data[:, :, x_lo:x_hi, y_lo:y_hi] = torch.rot90(patch, turns, [2, 3])

        return data


    def test_forward(self, batch):
        """Evaluate one batch with the test criterion.

        Returns a dict with ``'loss'`` and ``'contrast_acc'``.
        """
        images = batch[0] if False else None  # placeholder removed below
        images, _ = batch
        embeddings = self.model(images)
        test_loss, test_acc = self.test_criterion(embeddings)
        results = {}
        results['loss'] = test_loss
        results['contrast_acc'] = test_acc
        return results

    def encode(self, x):
        """Run the wrapped network on ``x`` requesting its ``'h'`` output."""
        features = self.model(x, out='h')
        return features

    def forward(self, *args, **kwargs):
        """Delegate the call, with all arguments, to the wrapped network."""
        outputs = self.model(*args, **kwargs)
        return outputs

    def train_step(self, batch, it=None):
        """Run ``step`` on ``batch`` and add bookkeeping to its log dict.

        Under DDP the train sampler's epoch is set to ``it``; when ``it`` is
        given, ``'epoch'`` is reported as ``it`` divided by the length of the
        train batch sampler.
        """
        metrics = self.step(batch, it)

        distributed = self.hparams.dist == 'ddp'
        if distributed:
            self.trainsampler.set_epoch(it)
        if it is None:
            return metrics

        metrics['epoch'] = it / len(self.batch_trainsampler)
        return metrics

    def test_step(self, batch):
        """Evaluate one batch; thin alias for ``test_forward``."""
        results = self.test_forward(batch)
        return results

    def samplers(self):
        """Create the train and test batch samplers.

        Distributed samplers are used under DDP, random samplers otherwise.
        Both are wrapped in ``datautils.MultiplyBatchSampler`` whose
        class-level ``MULTILPLIER`` is set from ``hparams.multiplier``. The
        train samplers are also stored on ``self`` for later use.

        Returns:
            ``(train_batch_sampler, test_batch_sampler)``.
        """
        if self.hparams.dist != 'ddp':
            train_base = torch.utils.data.sampler.RandomSampler(self.trainset)
            test_base = torch.utils.data.sampler.RandomSampler(self.testset)
        else:
            train_base = torch.utils.data.distributed.DistributedSampler(self.trainset)
            print(f'Process {dist.get_rank()}: {len(train_base)} training samples per epoch')
            test_base = torch.utils.data.distributed.DistributedSampler(self.testset)
            print(f'Process {dist.get_rank()}: {len(test_base)} test samples')

        # the multiplier is a class attribute, so this affects every instance
        sampler_cls = datautils.MultiplyBatchSampler
        sampler_cls.MULTILPLIER = self.hparams.multiplier

        # need for DDP to sync samplers between processes
        self.trainsampler = train_base
        self.batch_trainsampler = sampler_cls(train_base, self.hparams.batch_size, drop_last=True)
        test_batches = sampler_cls(test_base, self.hparams.batch_size, drop_last=True)

        return self.batch_trainsampler, test_batches

    def transforms(self):
        """Return ``(train_transform, test_transform)`` for ``hparams.data``.

        Both entries are the same random augmentation pipeline: resized crop,
        horizontal flip, colour distortion, tensor conversion and clipping
        (plus Gaussian blur for ImageNet).

        Raises:
            NotImplementedError: for datasets other than ``'cifar'`` and
                ``'imagenet'`` (previously this fell through to an
                ``UnboundLocalError`` at the return statement).
        """
        if self.hparams.data == 'cifar':
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(
                    32,
                    scale=(self.hparams.scale_lower, 1.0),
                    interpolation=PIL.Image.BICUBIC,
                ),
                transforms.RandomHorizontalFlip(),
                datautils.get_color_distortion(s=self.hparams.color_dist_s),
                transforms.ToTensor(),
                datautils.Clip(),
            ])
        elif self.hparams.data == 'imagenet':
            from utils.datautils import GaussianBlur

            im_size = 224
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(
                    im_size,
                    scale=(self.hparams.scale_lower, 1.0),
                    interpolation=PIL.Image.BICUBIC,
                ),
                transforms.RandomHorizontalFlip(0.5),
                datautils.get_color_distortion(s=self.hparams.color_dist_s),
                transforms.ToTensor(),
                GaussianBlur(im_size // 10, 0.5),
                datautils.Clip(),
            ])
        else:
            raise NotImplementedError(
                f'No transforms defined for dataset {self.hparams.data!r}'
            )

        # the test split goes through the same stochastic pipeline
        test_transform = train_transform
        return train_transform, test_transform

    def get_ckpt(self):
        """Return a checkpoint with the unwrapped network's weights and hparams.

        The state dict is taken from ``self.model.module``, i.e. without the
        parallel wrapper's key prefix.
        """
        inner_network = self.model.module
        checkpoint = {}
        checkpoint['state_dict'] = inner_network.state_dict()
        checkpoint['hparams'] = self.hparams
        return checkpoint

    def load_state_dict(self, state, strict=True):
        """Load weights saved either from this object or from the bare network.

        If the first key starts with ``'model.module'`` the state is loaded
        into this module; otherwise it is loaded into ``self.model.module``.

        Args:
            state: state dict to load.
            strict: forwarded to ``nn.Module.load_state_dict``; added so that
                callers can use the standard ``strict=False`` form.

        Returns:
            Whatever the underlying ``load_state_dict`` call returns.
        """
        # an empty state dict no longer raises StopIteration; it is routed to
        # the inner network, which reports missing keys when strict
        first_key = next(iter(state), '')
        if first_key.startswith('model.module'):
            return super().load_state_dict(state, strict=strict)
        return self.model.module.load_state_dict(state, strict=strict)


class SSLEval(BaseSSL):
    @classmethod
    @BaseSSL.add_parent_hparams
    def add_model_hparams(cls, parser):
        """Register the evaluation-specific options on ``parser``.

        ``--finetune`` is parsed with a real string-to-bool converter:
        ``type=bool`` maps every non-empty string (even ``"False"``) to True.
        """
        def _str2bool(value):
            if isinstance(value, bool):
                return value
            token = str(value).strip().lower()
            if token in ('1', 'true', 't', 'yes', 'y', 'on'):
                return True
            if token in ('0', 'false', 'f', 'no', 'n', 'off', ''):
                return False
            # argparse turns ValueError into an "invalid value" usage error
            raise ValueError(f'expected a boolean value, got {value!r}')

        parser.add_argument('--test_bs', default=256, type=int)
        parser.add_argument('--encoder_ckpt', default='', help='Path to the encoder checkpoint')
        parser.add_argument('--precompute_emb_bs', default=-1, type=int,
            help='If it\'s not equal to -1 embeddings are precomputed and fixed before training with batch size equal to this.'
        )
        parser.add_argument('--finetune', default=False, type=_str2bool, help='Finetunes the encoder if True')
        parser.add_argument('--augmentation', default='RandomResizedCrop', help='')
        parser.add_argument('--scale_lower', default=0.08, type=float, help='The minimum scale factor for RandomResizedCrop')

    def __init__(self, hparams, device=None):
        """Load (or randomly create) the encoder and build the classifier head.

        Args:
            hparams: namespace with the evaluation options; ``dist`` defaults
                to ``'dp'`` when absent.
            device: device used for the checkpoint, the encoder and the head.

        Raises:
            NotImplementedError: when fine-tuning under DDP, or when
                ``hparams.arch`` is not ``'linear'``.
        """
        super().__init__(hparams)

        self.hparams.dist = getattr(self.hparams, 'dist', 'dp')

        if hparams.encoder_ckpt != '':
            ckpt = torch.load(hparams.encoder_ckpt, map_location=device)
            # a checkpoint saved under ddp is reloaded as dp ...
            if getattr(ckpt['hparams'], 'dist', 'dp') == 'ddp':
                ckpt['hparams'].dist = 'dp'
            # ... unless this evaluator itself runs under ddp.
            # NOTE(review): SimCLR.__init__ in this file only accepts 'ddp' and
            # 'dp' and raises NotImplementedError otherwise — confirm that the
            # registered model handles a 'gpu:<id>' value.
            if self.hparams.dist == 'ddp':
                ckpt['hparams'].dist = 'gpu:%d' % hparams.gpu

            # the checkpoint's `problem` selects which registered model to load
            self.encoder = models.REGISTERED_MODELS[ckpt['hparams'].problem].load(ckpt, device=device)
        else:
            print('===> Random encoder is used!!!')
            self.encoder = SimCLR.default(device=device)
        self.encoder.to(device)

        if not hparams.finetune:
            # frozen encoder: no gradients and evaluation mode
            for p in self.encoder.parameters():
                p.requires_grad = False
            self.encoder.eval()
        elif hparams.dist == 'ddp':
            raise NotImplementedError
        else:
            self.encoder.train()

        # a dummy forward pass determines the encoder's feature size.
        # NOTE(review): other datasets leave hdim / n_classes undefined.
        if hparams.data == 'cifar':
            hdim = self.encode(torch.ones(32, 3, 32, 32).to(device)).shape[1]
            n_classes = 10
        elif hparams.data == 'imagenet':
            hdim = self.encode(torch.ones(32, 3, 224, 224).to(device)).shape[1]
            n_classes = 1000

        if hparams.arch == 'linear':
            # linear classifier starting from all-zero weights and bias
            model = nn.Linear(hdim, n_classes).to(device)
            model.weight.data.zero_()
            model.bias.data.zero_()
            self.model = model
        else:
            raise NotImplementedError

        if hparams.dist == 'ddp':
            self.model = DDP(model, [hparams.gpu])

    def encode(self, x):
        """Run the encoder's wrapped network on ``x`` requesting its ``'h'`` output."""
        features = self.encoder.model(x, out='h')
        return features

    def step(self, batch):
        """Classify one batch and return its loss and per-sample accuracy.

        For the ``'eval'`` problem on ImageNet the inputs are divided by 255
        first (``batch[0]`` is replaced in place). Inputs are passed through
        the encoder unless embeddings were precomputed.

        Returns:
            dict with ``'loss'`` (cross-entropy) and ``'acc'`` (float tensor of
            per-sample correctness).
        """
        needs_rescale = self.hparams.problem == 'eval' and self.hparams.data == 'imagenet'
        if needs_rescale:
            batch[0] = batch[0] / 255.
        inputs, targets = batch

        # precompute_emb_bs == -1 means the inputs are raw images
        if self.hparams.precompute_emb_bs == -1:
            features = self.encode(inputs)
        else:
            features = inputs

        logits = self.model(features)
        xent = F.cross_entropy(logits, targets)
        correct = (logits.argmax(1) == targets).float()
        return {'loss': xent, 'acc': correct}

    def forward(self, *args, **kwargs):
        """Delegate the call, with all arguments, to the classifier head."""
        outputs = self.model(*args, **kwargs)
        return outputs

    def train_step(self, batch, it=None):
        """Run ``step`` on ``batch`` and add bookkeeping to its log dict.

        When ``it`` is given, ``'epoch'`` is ``it`` divided by the (rounded,
        at least 1) number of iterations per epoch. Under DDP without
        precomputed embeddings the train sampler's epoch is set to ``it``.
        """
        metrics = self.step(batch)

        if it is not None:
            per_epoch = int(np.around(len(self.trainset) / self.hparams.batch_size))
            per_epoch = max(1, per_epoch)
            metrics['epoch'] = it / per_epoch

        uses_dist_sampler = self.hparams.dist == 'ddp' and self.hparams.precompute_emb_bs == -1
        if uses_dist_sampler:
            self.object_trainsampler.set_epoch(it)

        return metrics

    def test_step(self, batch):
        """Evaluate one batch; under DDP the metrics are passed to
        ``utils.gather_metrics`` before being returned."""
        metrics = self.step(batch)
        if self.hparams.dist != 'ddp':
            return metrics
        utils.gather_metrics(metrics)
        return metrics

    def prepare_data(self, taskaug=None):
        """Create the datasets and optionally replace them by fixed embeddings.

        When ``hparams.precompute_emb_bs`` is not -1, both splits are encoded
        once (test split first) and stored as ``TensorDataset`` objects of
        ``(embedding, label)``; this requires ``hparams.aug`` to be off.
        ``taskaug`` is accepted but not used.
        """
        super().prepare_data()

        def embed_dataset(source):
            # sequential, unshuffled pass so that embeddings keep dataset order
            batches = torch.utils.data.DataLoader(
                source,
                num_workers=self.hparams.workers,
                pin_memory=True,
                batch_size=self.hparams.precompute_emb_bs,
                shuffle=False,
            )
            all_embs = []
            all_labels = []
            for images, targets in tqdm(batches):
                if self.hparams.data == 'imagenet':
                    # ImageNet tensors are moved to the GPU and divided by 255
                    images = images.to(torch.device('cuda'))
                    images = images / 255.
                encoded = self.encode(images)
                all_embs.append(utils.tonp(encoded))
                all_labels.append(utils.tonp(targets))
            emb_array = np.concatenate(all_embs)
            label_array = np.concatenate(all_labels)
            return torch.utils.data.TensorDataset(
                torch.FloatTensor(emb_array),
                torch.LongTensor(label_array),
            )

        precompute = self.hparams.precompute_emb_bs != -1
        if precompute:
            print('===> Precompute embeddings:')
            assert not self.hparams.aug
            with torch.no_grad():
                self.encoder.eval()
                self.testset = embed_dataset(self.testset)
                self.trainset = embed_dataset(self.trainset)

        print(f'Train size: {len(self.trainset)}')
        print(f'Test size: {len(self.testset)}')

    def dataloaders(self, iters=None):
        """Build the train and test ``DataLoader`` objects.

        Distributed samplers are used under DDP when embeddings are not
        precomputed; otherwise training is randomly sampled and testing is
        sequential. The un-batched train sampler is kept in
        ``self.object_trainsampler``.

        Args:
            iters: when given, the train batch sampler is wrapped in
                ``datautils.ContinousSampler`` with this value.

        Returns:
            ``(train_loader, test_loader)``.
        """
        use_distributed = self.hparams.dist == 'ddp' and self.hparams.precompute_emb_bs == -1
        if use_distributed:
            train_base = torch.utils.data.distributed.DistributedSampler(self.trainset)
            test_base = torch.utils.data.distributed.DistributedSampler(self.testset, shuffle=False)
        else:
            train_base = torch.utils.data.RandomSampler(self.trainset)
            test_base = torch.utils.data.SequentialSampler(self.testset)

        self.object_trainsampler = train_base
        train_batches = torch.utils.data.BatchSampler(
            self.object_trainsampler,
            batch_size=self.hparams.batch_size, drop_last=False,
        )
        if iters is not None:
            train_batches = datautils.ContinousSampler(train_batches, iters)

        # options shared by both loaders
        common = dict(num_workers=self.hparams.workers, pin_memory=True)

        train_loader = torch.utils.data.DataLoader(
            self.trainset,
            batch_sampler=train_batches,
            **common,
        )
        test_loader = torch.utils.data.DataLoader(
            self.testset,
            sampler=test_base,
            batch_size=self.hparams.test_bs,
            **common,
        )
        return train_loader, test_loader

    def transforms(self):
        """Return ``(train_transform, test_transform)`` for ``hparams.data``.

        For CIFAR the training pipeline is assembled from the names found in
        ``hparams.augmentation``. For ImageNet both pipelines end by
        converting the tensor to bytes in ``[0, 255]``. When ``hparams.aug``
        is off, the test transform is returned for training as well.

        Raises:
            NotImplementedError: for datasets other than ``'cifar'`` and
                ``'imagenet'`` (previously this fell through to an
                ``UnboundLocalError`` at the return statement).
        """
        if self.hparams.data == 'cifar':
            trs = []
            if 'RandomResizedCrop' in self.hparams.augmentation:
                trs.append(
                    transforms.RandomResizedCrop(
                        32,
                        scale=(self.hparams.scale_lower, 1.0),
                        interpolation=PIL.Image.BICUBIC,
                    )
                )
            if 'RandomCrop' in self.hparams.augmentation:
                trs.append(transforms.RandomCrop(32, padding=4, padding_mode='reflect'))
            if 'color_distortion' in self.hparams.augmentation:
                # the distortion strength comes from the encoder's hparams
                trs.append(datautils.get_color_distortion(self.encoder.hparams.color_dist_s))

            train_transform = transforms.Compose(trs + [
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                datautils.Clip(),
            ])
            test_transform = transforms.Compose([
                transforms.ToTensor(),
            ])
        elif self.hparams.data == 'imagenet':
            train_transform = transforms.Compose([
                # default scale/interpolation are used for the ImageNet crop
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                lambda x: (255*x).byte(),
            ])
            test_transform = transforms.Compose([
                datautils.CenterCropAndResize(proportion=0.875, size=224),
                transforms.ToTensor(),
                lambda x: (255 * x).byte(),
            ])
        else:
            raise NotImplementedError(
                f'No transforms defined for dataset {self.hparams.data!r}'
            )
        return train_transform if self.hparams.aug else test_transform, test_transform

    def train(self, mode=True):
        """Switch training mode, leaving a frozen encoder untouched.

        When fine-tuning, the whole module (encoder included) changes mode.
        Otherwise only the classifier head does, so the encoder keeps the mode
        it was given elsewhere.

        Returns:
            ``self``, matching ``nn.Module.train`` so that ``.train()`` and
            ``.eval()`` can be chained or assigned.
        """
        if self.hparams.finetune:
            super().train(mode)
        else:
            # NOTE(review): this path does not update self.training — confirm
            # nothing relies on that flag.
            self.model.train(mode)
        return self

    def get_ckpt(self):
        """Return a checkpoint dict with weights and hparams.

        The full module's state dict is stored when fine-tuning; otherwise
        only the classifier head's.
        """
        if self.hparams.finetune:
            weights = self.state_dict()
        else:
            weights = self.model.state_dict()
        checkpoint = {}
        checkpoint['state_dict'] = weights
        checkpoint['hparams'] = self.hparams
        return checkpoint

    def load_state_dict(self, state, strict=True):
        """Load weights produced by ``get_ckpt``.

        When fine-tuning, ``state`` is loaded into the whole module;
        otherwise into the classifier head (unwrapping ``.module`` if the
        head is wrapped).

        Args:
            state: state dict to load.
            strict: forwarded to ``nn.Module.load_state_dict``; added so that
                callers can use the standard ``strict=False`` form.

        Returns:
            Whatever the underlying ``load_state_dict`` call returns.
        """
        if self.hparams.finetune:
            return super().load_state_dict(state, strict=strict)
        # a parallel wrapper keeps the real head under `.module`
        head = self.model.module if hasattr(self.model, 'module') else self.model
        return head.load_state_dict(state, strict=strict)

class SemiSupervisedEval(SSLEval):
    """Evaluation model trained on a labelled subset of the training set.

    On top of ``SSLEval`` this class shrinks the training set to
    ``hparams.train_size`` samples and, for CIFAR data with
    ``hparams.acc_on_unlabeled``, keeps the remaining training samples in
    ``self.trainset_unlabeled`` so that metrics can be reported on them.
    """

    @classmethod
    @BaseSSL.add_parent_hparams
    def add_model_hparams(cls, parser):
        """Register the semi-supervised options on ``parser``."""

        def parse_bool(value):
            # ``type=bool`` would map every non-empty string, including
            # "False" and "0", to True. Only falsy-looking strings are
            # treated as False here; anything else stays True as before.
            return str(value).strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')

        parser.add_argument('--train_size', default=-1, type=int)
        parser.add_argument('--data_split_seed', default=42, type=int)
        parser.add_argument('--n_augs_train', default=-1, type=int)
        parser.add_argument('--n_augs_test', default=-1, type=int)
        parser.add_argument('--acc_on_unlabeled', default=False, type=parse_bool)

    def prepare_data(self):
        """Create the datasets and reduce the training set to the labelled part.

        Nothing is reduced when the training set already has
        ``hparams.train_size`` samples. For 'cifar'/'cifar100' and for
        ``TensorDataset`` training sets the subset is drawn with
        ``train_test_split`` seeded by ``hparams.data_split_seed``. For any
        other dataset the sample list is read from the local file
        ``imagenet_subsets/<train_size>percent.txt`` (resolved against the
        current working directory), one file name per line.
        """
        # Start the method lookup *after* SSLEval, i.e. deliberately skip
        # SSLEval.prepare_data and run the next implementation in the MRO.
        super(SSLEval, self).prepare_data()

        if len(self.trainset) == self.hparams.train_size:
            return

        idxs, unlabeled_idxs = sklearn.model_selection.train_test_split(
            np.arange(len(self.trainset)),
            train_size=self.hparams.train_size,
            random_state=self.hparams.data_split_seed,
        )
        if self.hparams.data == 'cifar' or self.hparams.data == 'cifar100':
            if self.hparams.acc_on_unlabeled:
                # Copy before the train set is shrunk below, so the held-out
                # samples are taken from the full data.
                self.trainset_unlabeled = copy.deepcopy(self.trainset)
                self.trainset_unlabeled.data = self.trainset.data[unlabeled_idxs]
                self.trainset_unlabeled.targets = np.array(self.trainset.targets)[unlabeled_idxs]
                print(f'Test size (0): {len(self.testset)}')
                print(f'Unlabeled train size (1):  {len(self.trainset_unlabeled)}')

            self.trainset.data = self.trainset.data[idxs]
            self.trainset.targets = np.array(self.trainset.targets)[idxs]

            print('Training dataset size:', len(self.trainset))
        else:
            assert not self.hparams.acc_on_unlabeled
            if isinstance(self.trainset, torch.utils.data.TensorDataset):
                self.trainset.tensors = [t[idxs] for t in self.trainset.tensors]
            else:
                # The subset list is a local file; urlopen() rejects a path
                # without a URL scheme, so it is read with open() instead.
                subset_file = f'imagenet_subsets/{self.hparams.train_size}percent.txt'
                samples = []
                with open(subset_file, encoding='utf-8') as subset:
                    for line in subset:
                        fname = line.strip()
                        if not fname:
                            continue
                        # The class directory is the part of the file name
                        # before the first underscore.
                        class_dir = fname.split('_')[0]
                        samples.append((
                            os.path.join(self.IMAGENET_PATH, 'train', class_dir, fname),
                            self.trainset.class_to_idx[class_dir],
                        ))
                self.trainset.samples = samples

            print('Training dataset size:', len(self.trainset))

    def transforms(self):
        """Return the ``(train, test)`` transforms.

        The transforms of ``SSLEval`` are used unless ``hparams.n_augs_train``
        / ``hparams.n_augs_test`` differ from -1.
        """
        train_transform, test_transform = SSLEval.transforms(self)
        # NOTE(review): ``ens_train_transfom`` and ``ens_test_transform`` are
        # not defined in this class; unless they exist at module scope,
        # any n_augs_* value other than -1 raises NameError -- TODO confirm.
        return (
            train_transform if self.hparams.n_augs_train == -1 else ens_train_transfom,
            test_transform if self.hparams.n_augs_test == -1 else ens_test_transform
        )

    def step(self, batch, it=None):
        """Compute the cross-entropy loss and per-sample accuracy of a batch.

        Args:
            batch: mutable two-element sequence ``[inputs, targets]``. For
                ImageNet evaluation problems the inputs are divided by 255
                and written back into ``batch[0]``.
            it: unused here.

        Returns:
            Dict with ``'loss'`` (cross-entropy of the predictions) and
            ``'acc'`` (float tensor, 1 where the arg-max prediction matches).
        """
        if 'eval' in self.hparams.problem and self.hparams.data == 'imagenet':
            batch[0] = batch[0] / 255.
        h, y = batch
        # 4-D inputs are passed through the encoder first; other inputs go
        # to the model as they are.
        if len(h.shape) == 4:
            h = self.encode(h)
        p = self.model(h)
        loss = F.cross_entropy(p, y)
        acc = (p.argmax(1) == y).float()
        return {
            'loss': loss,
            'acc': acc,
        }

    def test_step(self, batch):
        """Evaluate a batch, optionally split by dataset index.

        Without ``hparams.acc_on_unlabeled`` this defers to the parent
        ``test_step``. Otherwise ``batch`` is ``(x, y, d)`` and the parent
        ``test_step`` is run separately on the samples with ``d == 0`` and
        ``d == 1`` (presumably test set and unlabelled train set, matching
        the sizes printed in ``prepare_data`` -- TODO confirm). Every metric
        is returned under ``'<name>_<d>'``; a metric missing for one of the
        two indices is filled with an empty tensor.
        """
        if not self.hparams.acc_on_unlabeled:
            return super().test_step(batch)
        # TODO: refactor
        x, y, d = batch
        logs = {}
        keys = set()
        for didx in [0, 1]:
            if torch.any(d == didx):
                t = super().test_step([x[d == didx], y[d == didx]])
                for k, v in t.items():
                    keys.add(k)
                    logs[k + f'_{didx}'] = v
        # Guarantee the same set of keys whichever indices were present.
        for didx in [0, 1]:
            for k in keys:
                logs[k + f'_{didx}'] = logs.get(k + f'_{didx}', torch.tensor([]))
        return logs


def configure_optimizers(args, model, cur_iter=-1):
    """Create the optimizer and learning-rate scheduler selected by ``args``.

    Args:
        args: namespace providing ``opt`` ('sgd', 'adam' or 'lars'), ``lr``,
            ``weight_decay``, ``iters``, ``lr_schedule`` ('warmup-anneal',
            'linear', 'const' or 'step') and, for the warmup schedules,
            ``warmup``.
        model: module whose ``named_parameters()`` are optimized.
        cur_iter: handed to the scheduler as ``last_epoch``.

    Returns:
        ``(optimizer, scheduler)``. ``scheduler`` is ``None`` for the 'const'
        schedule. With 'lars' the returned optimizer is the ``LARS`` wrapper,
        while the scheduler is attached to the wrapped SGD optimizer.

    Raises:
        NotImplementedError: for an unknown ``args.opt`` or
            ``args.lr_schedule``.
    """
    total_iters = args.iters
    use_lars = args.opt == 'lars'

    def is_exempt(param_name):
        # Parameters whose name contains 'bn' and, under LARS, those whose
        # name contains 'bias' get no weight decay and no layer adaptation.
        return 'bn' in param_name or (args.opt == 'lars' and 'bias' in param_name)

    if use_lars:
        regular, exempt = [], []
        for param_name, param in model.named_parameters():
            (exempt if is_exempt(param_name) else regular).append(param)
        param_groups = [
            {'params': regular, 'weight_decay': args.weight_decay, 'layer_adaptation': True},
            {'params': exempt, 'weight_decay': 0., 'layer_adaptation': False},
        ]
    else:
        param_groups = [{
            'params': [param for _, param in model.named_parameters()],
            'weight_decay': args.weight_decay,
        }]

    base_lr = args.lr

    lars_wrapper = None
    if args.opt == 'adam':
        optimizer = torch.optim.Adam(param_groups, lr=base_lr)
    elif args.opt in ('sgd', 'lars'):
        # LARS runs on top of a plain momentum-SGD optimizer.
        optimizer = torch.optim.SGD(param_groups, lr=base_lr, momentum=0.9)
        if use_lars:
            lars_wrapper = LARS(optimizer)
    else:
        raise NotImplementedError

    schedule = args.lr_schedule
    if schedule == 'warmup-anneal':
        scheduler = utils.LinearWarmupAndCosineAnneal(
            optimizer, args.warmup, total_iters, last_epoch=cur_iter,
        )
    elif schedule == 'linear':
        scheduler = utils.LinearLR(optimizer, total_iters, last_epoch=cur_iter)
    elif schedule == 'const':
        scheduler = None
    elif schedule == 'step':
        # Milestones at 95% and 97.5% of the run.
        milestones = [int(frac * args.iters) for frac in (950 / 1000, 975 / 1000)]
        scheduler = utils.LinearWarmupAndMultiStep(
            optimizer, args.warmup, total_iters,
            milestones=milestones, gamma=0.2,
            last_epoch=cur_iter,
        )
    else:
        raise NotImplementedError

    # The scheduler drives the inner SGD; callers step the LARS wrapper.
    if use_lars:
        optimizer = lars_wrapper

    return optimizer, scheduler
