""" Set of auxiliary functions and classes. """

from __future__ import print_function
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
from torch.optim import Optimizer
import sys
import torch.nn.functional as F
from torchvision.models.inception import inception_v3
import numpy as np
import math
from scipy.stats import entropy


def error_print(*args, **kwargs):
    """Print one message to both standard error and standard output.

    Takes the same positional and keyword arguments as the built-in
    ``print``, except ``file``, which is supplied here for each stream.
    """
    # stderr is written first, stdout second.
    for stream in (sys.stderr, sys.stdout):
        print(*args, file=stream, **kwargs)


def terminate_on_nan(loss):
    """Stop the program when ``loss`` contains at least one NaN entry.

    Args:
        loss: tensor to inspect with ``torch.isnan``.

    Raises:
        SystemExit: with exit status 1 if any element of ``loss`` is NaN.
            Nothing happens (and ``None`` is returned) otherwise.
    """
    if torch.isnan(loss).any():
        error_print("Terminating program -- NaN detected.")
        # sys.exit() is always available, unlike the interactive exit() helper
        # injected by the ``site`` module. A non-zero status lets calling
        # scripts and schedulers see that the run aborted abnormally.
        sys.exit(1)


def count_pars(model):
    """Return the number of scalar entries in the trainable parameters of ``model``.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def list2string(list_to_parse):
    """Concatenate ``str()`` of every element, each followed by an underscore.

    The result therefore ends with a trailing ``"_"`` whenever the input is
    non-empty; an empty input gives an empty string.
    """
    return "".join(str(item) + "_" for item in list_to_parse)


class EMA(nn.Module):
    """Exponential moving average of a model's trainable parameters.

    The averaged ("shadow") values are kept in an ``nn.ParameterDict`` with
    ``requires_grad=False`` so that they are part of this module's state dict
    and get written out by ``torch.save()``.

    When ``decay`` is not greater than 0, ``update``, ``assign`` and
    ``restore`` do nothing.
    """

    def __init__(self, model, decay):
        """Snapshot every trainable parameter of ``model`` as the initial average.

        Raises:
            ValueError: if ``decay`` lies outside [0, 1].
        """
        super().__init__()
        if decay > 1.0 or decay < 0.0:
            raise ValueError('EMA decay must in [0,1]')
        self.decay = decay
        self.shadow_params = nn.ParameterDict({})
        # Backup of the raw training weights, filled by assign() and keyed by
        # the original (dotted) parameter names.
        self.train_params = {}
        for name, param in self._trainable(model):
            snapshot = nn.Parameter(param.data.clone(), requires_grad=False)
            self.shadow_params[self._dotless(name)] = snapshot

    @staticmethod
    def _dotless(name):
        """Replace the dots of a parameter name by ``'^'`` for use as a ParameterDict key."""
        return name.replace('.', '^')

    @staticmethod
    def _trainable(model):
        """Iterate over the (name, parameter) pairs of ``model`` that require gradients."""
        return ((name, param) for name, param in model.named_parameters() if param.requires_grad)

    @torch.no_grad()
    def update(self, model):
        """Blend the current weights of ``model`` into the shadow average."""
        if not self.decay > 0:
            return
        for name, param in self._trainable(model):
            shadow = self.shadow_params[self._dotless(name)]
            shadow.data = self.decay * shadow.data + (1.0 - self.decay) * param.data

    @torch.no_grad()
    def assign(self, model):
        """Back up the weights of ``model`` and overwrite them with the shadow average."""
        # Decided once, before the loop starts filling the backup dict.
        reuse_backup = bool(self.train_params)
        if not self.decay > 0:
            return
        for name, param in self._trainable(model):
            if reuse_backup:
                # Backup tensors already exist: refresh them in place.
                self.train_params[name].data.copy_(param.data)
            else:
                self.train_params[name] = param.data.clone()
            param.data.copy_(self.shadow_params[self._dotless(name)].data)

    @torch.no_grad()
    def restore(self, model):
        """Copy the weights backed up by ``assign`` back into ``model``."""
        if not self.decay > 0:
            return
        for name, param in self._trainable(model):
            param.data.copy_(self.train_params[name].data)


class Crop2d(nn.Module):
    """Remove ``num`` entries from both ends of dimensions 2 and 3 of the input."""

    def __init__(self, num):
        super().__init__()
        self.num = num

    def forward(self, input):
        """Return the cropped tensor, or ``input`` itself when ``num`` is 0."""
        if self.num == 0:
            return input
        border = slice(self.num, -self.num)
        return input[:, :, border, border]


def weights_init(module):
    """Initialize a module in place according to its class name.

    * class name contains ``'Conv'``: orthogonal initialization of ``weight``;
    * class name contains ``'BatchNorm'``: ``weight`` drawn from N(1, 0.02)
      and ``bias`` set to 0;
    * class name contains ``'Linear'``: ``weight`` drawn from N(0, 0.02).

    Modules matching none of these names are left untouched. A matching
    module whose ``weight`` / ``bias`` is absent or ``None`` (for instance a
    batch-norm layer built with ``affine=False``, or a container whose class
    name merely contains ``'Conv'``) is skipped instead of raising.
    """
    classname = module.__class__.__name__
    weight = getattr(module, 'weight', None)
    bias = getattr(module, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            torch.nn.init.orthogonal_(weight)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)
    elif 'Linear' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)


def WN(module, norm=True):
    """Optionally wrap a convolution module with weight normalization.

    With ``norm`` true, modules whose class name contains ``'ConvTranspose'``
    are weight-normalized over ``dim=1`` and other modules whose class name
    contains ``'Conv'`` over ``dim=0``. Every other module, and every module
    when ``norm`` is false, is returned unchanged.
    """
    if not norm:
        return module

    classname = module.__class__.__name__
    # 'ConvTranspose' must be tested first: such names also contain 'Conv'.
    if 'ConvTranspose' in classname:
        axis = 1
    elif 'Conv' in classname:
        axis = 0
    else:
        return module
    return weight_norm(module, dim=axis, name='weight')


class MaskedConv2d(nn.Conv2d):
    """2D convolution whose kernel is multiplied by a fixed binary mask.

    The mask zeroes every kernel row below the central row and, within the
    central row, every column from the centre onwards (``mask_type='A'``) or
    from the column right of the centre onwards (``mask_type='B'``).
    Remaining positional and keyword arguments go to ``nn.Conv2d``.
    """

    def __init__(self, mask_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert mask_type in {'A', 'B'}
        _, _, kH, kW = self.weight.size()
        mid_row, mid_col = kH // 2, kW // 2
        # Type 'B' keeps the central tap, type 'A' blocks it as well.
        first_blocked_col = mid_col + 1 if mask_type == 'B' else mid_col

        # Same shape, dtype and device as the kernel; stored as a buffer.
        mask = self.weight.data.clone().fill_(1)
        mask[:, :, mid_row, first_blocked_col:] = 0
        mask[:, :, mid_row + 1:] = 0
        self.register_buffer('mask', mask)

    def forward(self, x):
        """Zero the masked kernel entries in place, then run the convolution."""
        self.weight.data.mul_(self.mask)
        return super().forward(x)


class ARCNN(nn.Module):
    """Stack of masked 3x3 convolutions with several output heads.

    A type-'A' masked convolution maps ``z_size`` channels to ``h_size``
    channels. It is followed by ``num_layers`` pairs of (ELU, type-'B' masked
    convolution); the last convolution produces ``z_size * num_outputs``
    channels, the others ``h_size``.
    """

    def __init__(self, num_layers, num_outputs, z_size, h_size):
        super().__init__()
        self.num_outputs = num_outputs
        self.conv_a = MaskedConv2d(mask_type='A', in_channels=z_size, out_channels=h_size,
                                   kernel_size=3, stride=1, padding=1)

        final_idx = num_layers - 1
        stack = []
        for idx in range(num_layers):
            width = z_size * num_outputs if idx == final_idx else h_size
            # NOTE(review): the positional True binds to ELU's first
            # parameter (alpha), not to inplace -- confirm this is intended.
            stack.append(nn.ELU(True))
            stack.append(MaskedConv2d(mask_type='B', in_channels=h_size, out_channels=width,
                                      kernel_size=3, stride=1, padding=1))
        self.conv_b = nn.Sequential(*stack)

    def forward(self, x, context):
        """Add ``context`` to the first convolution's output, run the stack,
        and return the result as a list of chunks along dimension 1
        (``chunk(num_outputs, 1)``).
        """
        hidden = self.conv_a(x) + context
        stacked = self.conv_b(hidden)
        return list(stacked.chunk(self.num_outputs, 1))


class Adamax(Optimizer):
    """Adamax optimizer: the infinity-norm variant of Adam.

    Proposed in `Adam: A Method for Stochastic Optimization`__.

    Arguments:
        params (iterable): parameters to optimize, or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): decay coefficients of the
            running gradient average and of the running infinity norm
            (default: (0.9, 0.999))
        eps (float, optional): value added to the gradient magnitude for
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)

    __ https://arxiv.org/abs/1412.6980
    """

    def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0):
        # Hyper-parameters are validated in a fixed order; the first
        # offending value determines the error that is raised.
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        first_beta, second_beta = betas[0], betas[1]
        if not 0.0 <= first_beta < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(first_beta))
        if not 0.0 <= second_beta < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(second_beta))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        super().__init__(params, dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Arguments:
            closure (callable, optional): re-evaluates the model and returns
                the loss; it is called with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                grad = p.grad
                if grad is None:
                    # Parameter did not receive a gradient: nothing to do.
                    continue
                if grad.is_sparse:
                    raise RuntimeError('Adamax does not support sparse gradients')

                state = self.state[p]
                if not state:
                    # First visit of this parameter: create its buffers.
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['exp_inf'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                first_moment = state['exp_avg']
                inf_norm = state['exp_inf']
                beta1, beta2 = group['betas']
                eps = group['eps']

                state['step'] += 1

                decay = group['weight_decay']
                if decay != 0:
                    # Out-of-place add, so p.grad itself stays untouched.
                    grad = grad.add(p, alpha=decay)

                # Biased first-moment estimate.
                first_moment.mul_(beta1).add_(grad, alpha=1 - beta1)

                # Exponentially weighted infinity norm: element-wise maximum
                # of the decayed norm and the current gradient magnitude,
                # written straight back into the state buffer.
                decayed_norm = inf_norm.mul_(beta2).unsqueeze(0)
                grad_magnitude = grad.abs().add_(eps).unsqueeze_(0)
                candidates = torch.cat([decayed_norm, grad_magnitude], 0)
                # torch.max(..., out=) also needs somewhere to put the argmax
                # indices; they are not used afterwards.
                index_scratch = torch.empty_strided(inf_norm.size(),
                                                    inf_norm.stride(),
                                                    dtype=torch.long,
                                                    layout=inf_norm.layout,
                                                    device=inf_norm.device)
                torch.max(candidates, 0, keepdim=False, out=(inf_norm, index_scratch))

                # Only the first moment is bias-corrected.
                bias_correction = 1 - beta1 ** state['step']
                step_size = group['lr'] / bias_correction

                p.addcdiv_(first_moment, inf_norm, value=-step_size)

        return loss


def one_hot(indices, depth, dim):
    """One-hot encode ``indices`` along a newly inserted dimension.

    Args:
        indices: integer tensor of class indices (it is passed to
            ``scatter_`` as the index tensor).
        depth: size of the new one-hot dimension.
        dim: position at which the one-hot dimension is inserted.

    Returns:
        A tensor of the default floating dtype, shaped like ``indices`` with
        an extra dimension of size ``depth`` at ``dim``, holding 1 at the
        indexed positions and 0 elsewhere. It lives on the same device as
        ``indices``, so the function works on CPU-only hosts and with any
        CUDA device.
    """
    indices = indices.unsqueeze(dim)
    size = list(indices.size())
    size[dim] = depth
    # Allocate next to the indices; scatter_ needs both on the same device.
    y_onehot = torch.zeros(size, device=indices.device)
    y_onehot.scatter_(dim, indices, 1)

    return y_onehot


class Quantize(object):
    """Reduce the bit depth of tensor images that are expected to be in [0, 1].

    For ``nbits`` below 8 the input is scaled by 255, divided by
    ``2 ** (8 - nbits)``, floored, and divided by ``2 ** nbits - 1``.
    Otherwise the input tensor is returned as is.
    """

    def __init__(self, nbits=8):
        self.nbits = nbits

    def __call__(self, tensor):
        if not self.nbits < 8:
            return tensor
        bin_width = 2 ** (8 - self.nbits)
        top_level = 2 ** self.nbits - 1
        # floor() allocates a new tensor, so the in-place division below
        # never touches the caller's input.
        levels = torch.floor(tensor * 255 / bin_width)
        levels.div_(top_level)
        return levels

    def __repr__(self):
        return '{}(nbits={})'.format(type(self).__name__, self.nbits)


def run_cuda_diagnostics(requested_num_gpus):
    """Print a short CUDA / cuDNN report and check that enough GPUs exist.

    Raises:
        AssertionError: if ``requested_num_gpus`` exceeds the number of
            devices reported by ``torch.cuda.device_count()``.
    """
    cuda_ok = torch.cuda.is_available()
    available = torch.cuda.device_count()
    report = (
        ("CUDA available: ", cuda_ok),
        ("Requested num devices: ", requested_num_gpus),
        ("Available num of devices: ", available),
        ("CUDNN backend: ", torch.backends.cudnn.enabled),
    )
    for label, value in report:
        print(label, value)
    assert requested_num_gpus <= available, "Not enough GPUs available."


class Flatten3D(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        leading = x.size(0)
        return x.view(leading, -1)


class Unsqueeze3D(nn.Module):
    """Append two trailing dimensions of size 1 to the input."""

    def forward(self, x):
        return x.unsqueeze(-1).unsqueeze(-1)


class InceptionScorer(nn.Module):
    """Score a batch of images with a pretrained torchvision Inception v3.

    The score of one split is ``exp(mean_i KL(p(y|x_i) || p(y)))``, where
    ``p(y|x_i)`` is the classifier's softmax output for image ``i`` and
    ``p(y)`` is the mean of those outputs over the split. ``score`` returns
    the average over ``splits`` equally sized splits.
    """

    def __init__(self, device, splits=1):
        """
        Args:
            device: device on which the classifier runs and to which inputs
                are moved.
            splits: number of consecutive, equally sized groups the batch is
                divided into.
        """
        super().__init__()
        self.device = device
        classifier = inception_v3(pretrained=True, transform_input=False).to(device)
        self.classifier = classifier.eval()
        self.splits = splits
        self.up = nn.Upsample(size=(299, 299), mode='bilinear').to(device)

    def score(self, image):
        """Return the score of ``image`` as a NumPy float.

        Args:
            image: batch of images indexed as (batch, channel, height, width),
                or a single 3-D image, which gets a leading batch dimension.
                Inputs whose last two dimensions are not 299x299 are
                bilinearly resized.

        Note:
            Each split holds ``N // splits`` images, so up to ``splits - 1``
            trailing images are ignored when ``N`` is not a multiple of
            ``splits``.
        """
        # Convert to the corresponding device.
        image = image.to(self.device)

        # Add the batch dimension if needed.
        if image.dim() == 3:
            image = image.unsqueeze(0)

        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            # Resize if needed.
            if image.shape[2] != 299 or image.shape[3] != 299:
                image = self.up(image)

            logits = self.classifier(image)
            # Class probabilities per image. dim=1 is what the implicit-dim
            # call resolved to for 2-D logits, without the deprecation warning.
            preds = F.softmax(logits, dim=1).cpu().numpy()

        # Mean KL divergence per split.
        num_images = image.size(0)
        split_size = num_images // self.splits
        split_scores = []
        for k in range(self.splits):
            part = preds[k * split_size: (k + 1) * split_size, :]
            marginal = np.mean(part, axis=0)
            # scipy's entropy(pk, qk) is the KL divergence KL(pk || qk).
            divergences = [entropy(conditional, marginal) for conditional in part]
            split_scores.append(np.exp(np.mean(divergences)))

        return np.mean(split_scores)


class Reshape4x4(nn.Module):
    """View the last dimension of the input as (-1, 4, 4), keeping the leading ones."""

    def forward(self, x):
        leading_dims = x.shape[:-1]
        return x.view(*leading_dims, -1, 4, 4)


class Contiguous(nn.Module):
    """Module wrapper around ``Tensor.contiguous()``."""

    def forward(self, x):
        return x.contiguous()


class EncWrapper(nn.Module):
    """Chain a list of encoder modules and split their output into (mu, logvar).

    Implemented for compatibility with disentanglement_lib evaluation.
    """

    def __init__(self, encoder_list):
        super().__init__()
        self.enc_list = nn.Sequential(*encoder_list)

    def forward(self, input):
        """Run the encoders and return the two halves of the output along dim 1.

        The unsplit output is also kept on the instance as ``z_params``.
        """
        encoded = self.enc_list(input)
        self.z_params = encoded
        mu, logvar = encoded.chunk(2, dim=1)
        return mu, logvar