#!/usr/bin/env python3
import torch
import torchvision
from uimnet import utils
from uimnet.algorithms.erm import ERM
import torch.cuda.amp as amp
import numpy as np


class DeepAE(ERM):
    """
    Deep autoencoder on top of ERM features.

    In addition to the featurizer/classifier pair, an MLP encoder/decoder is
    trained to reconstruct the featurizer output. The per-example
    reconstruction error is exposed through `uncertainty`.
    """
    # Start from a copy of the ERM hyper-parameters and add the autoencoder
    # ones. Each entry is a pair: a fixed value and a callable drawing a random
    # value (presumably default / random-search sampler, as in ERM.HPARAMS --
    # confirm against ERM).
    HPARAMS = dict(ERM.HPARAMS)
    HPARAMS.update({
        "encoder_width": (256, lambda: int(np.random.choice([64, 128, 256]))),
        "encoder_depth": (2, lambda: int(np.random.choice([2, 3, 4]))),
        "zdim": (128, lambda: int(np.random.choice([32, 64, 128]))),
    })

    def __init__(
        self,
        num_classes,
        arch,
        device="cuda",
        seed=0,
        use_mixed_precision=False, sn=False, sn_coef=1, sn_bn=False):
        """
        Build the algorithm; all arguments are forwarded unchanged to ERM.

        Args:
            num_classes: forwarded to ERM (read back as `self.num_classes`
                in `construct_networks`).
            arch: forwarded to ERM (read back as `self.arch`, a key of
                `torchvision.models.__dict__`).
            device: forwarded to ERM.
            seed: forwarded to ERM.
            use_mixed_precision: forwarded to ERM; `update` reads
                `self.use_mixed_precision` to enable autocast.
            sn, sn_coef, sn_bn: forwarded to ERM (presumably spectral
                normalization options -- confirm against ERM).
        """
        super().__init__(
            num_classes,
            arch,
            device,
            seed,
            use_mixed_precision=use_mixed_precision,
            sn=sn,
            sn_coef=sn_coef,
            sn_bn=sn_bn)

        # Reconstruction loss between decoded and original features.
        self.ae_loss = torch.nn.MSELoss()

        # NOTE(review): presumably tells the surrounding framework that this
        # algorithm provides its own `uncertainty` measure -- confirm usage.
        self.has_native_measure = True

    def construct_networks(self):
        """
        Create the four sub-networks used by this algorithm.

        Returns:
            dict with keys 'featurizer', 'classifier', 'encoder', 'decoder'.
        """
        # init network, conveniently decomposed in featurizer and classifier
        featurizer = torchvision.models.__dict__[self.arch](
            num_classes=self.num_classes,
            pretrained=False,
            # important when using large batch sizes (Goyal & al 2017)
            zero_init_residual=True)

        # Replace the backbone's final layer by a stand-alone linear
        # classifier of the same shape, so features can be read out directly.
        classifier = torch.nn.Linear(
            featurizer.fc.in_features,
            featurizer.fc.out_features, bias=True)
        torch.nn.init.normal_(classifier.weight, mean=0, std=0.01)
        featurizer.fc = utils.Identity()

        feature_dim = classifier.in_features
        zdim = self.hparams['zdim']
        hidden = [self.hparams["encoder_width"]] * (
            self.hparams["encoder_depth"] - 1)

        # The decoder uses the same hidden sizes as the encoder, mapping the
        # code of size `zdim` back to the feature dimension.
        encoder = utils.MLP([feature_dim] + hidden + [zdim], batch_norm=True)
        decoder = utils.MLP([zdim] + hidden + [feature_dim], batch_norm=True)

        return dict(featurizer=featurizer,
                    classifier=classifier,
                    encoder=encoder,
                    decoder=decoder)

    def update(self, x, y, epoch=None):
        """
        Run one optimization step on the minibatch (x, y).

        Args:
            x: minibatch inputs, passed through `self.process_minibatch`.
            y: minibatch targets, passed through `self.process_minibatch`.
            epoch: if not None, forwarded to `self.adjust_learning_rate_`
                before the step.

        Returns:
            dict of Python floats with keys 'loss' (classification loss),
            'ae_loss' (reconstruction loss) and 'cost' (total objective
            including the weighted L2 term).
        """
        if epoch is not None:
            self.adjust_learning_rate_(epoch)

        # Reset gradients by dropping them instead of zero-filling.
        for param in self.parameters():
            param.grad = None

        x, y = self.process_minibatch(x, y)
        with amp.autocast(enabled=self.use_mixed_precision):
            h = self.networks['featurizer'](x)
            s = self.networks['classifier'](h)

            loss = self.loss(s, y)

            # Use the same detached features as autoencoder input AND as
            # reconstruction target: otherwise the MSE gradient would still
            # reach the featurizer through the target, which contradicts
            # detaching the input.
            h_target = h.detach()
            h_hat = self.networks['decoder'](
                self.networks['encoder'](h_target))
            ae_loss = self.ae_loss(h_hat, h_target)

            cost = loss + ae_loss + \
                self.hparams['weight_decay'] * self.get_l2_reg()

        self.grad_scaler.scale(cost).backward()

        for optimizer in self.optimizers.values():
            self.grad_scaler.step(optimizer)
        # The scale factor must be updated exactly once per iteration, after
        # every optimizer has stepped; updating inside the loop would unscale
        # later optimizers with a different factor than the one used for
        # backward().
        self.grad_scaler.update()

        return {
            'loss': loss.item(),
            'ae_loss': ae_loss.item(),
            'cost': cost.item(),
        }

    def uncertainty(self, x):
        """
        Return the reconstruction error of the features of `x`.

        The input is moved to `self.device`, featurized, passed through the
        encoder/decoder, and the squared error is averaged over dim 1
        (assumes the featurizer output is (batch, features) -- confirm).
        Higher values mean the autoencoder reconstructs the features worse.
        """
        h = self.networks['featurizer'](x.to(self.device))
        h_hat = self.networks['decoder'](self.networks['encoder'](h))
        return (h - h_hat).pow(2).mean(dim=1)

if __name__ == "__main__":
    # Nothing to execute: this module only defines the DeepAE algorithm.
    pass
