# Copyright (c) 2022 Qualcomm Technologies, Inc.
# All rights reserved.

import torch
from torch import nn
from torch.nn import functional as F
from itertools import chain

from .base import Encoder
from nets import make_mlp
from util import inverse_softplus
import logging
from .slotpair import SlotAutoEncoder, SlotAverage, SlotMatchMean, SlotMatchMax
from .respair import ResNet18Dec, ResNet50Dec
import torchvision.models as models
from torchvision.models.resnet import ResNet50_Weights, ResNet18_Weights


class GaussianEncoder(Encoder):
    """
    Scalar Gaussian encoder / decoder for a VAE.

    Computes mean and std for each data variable, samples via the reparameterization trick, and
    evaluates the log likelihood / posterior.

    To avoid numerical issues with large reconstruction errors, the default initialization ensures
    a large variance in data space.
    """

    def __init__(
        self,
        encoder_type,
        encoder_decoder="encoder",
        hidden=(100,),
        input_features=2,
        output_features=2,
        fix_std=False,
        init_std=1.0,
        min_std=1.0e-3,
        amin=1e-4,
        resolution=64
    ):
        """
        Parameters
        ----------
        encoder_type : str
            "mlp" for an MLP backbone, or a torchvision ResNet name starting with "res"
            (e.g. "resnet18", "resnet50").
        encoder_decoder : str
            "encoder" or "decoder"; only consulted for ResNet backbones.
        hidden : sequence of int
            Layer widths appended after `input_features` and passed to `make_mlp` (MLP only).
        input_features, output_features : int
            Dimensionalities passed to the base `Encoder`.
        fix_std : bool
            If True, the std head is frozen and outputs the constant `init_std`.
        init_std : float
            Initial std; required when `fix_std` is True.
        min_std : float
            Offset added to the learned std (forced to 0 when `fix_std` is True).
        amin : float
            Minimum mask value used for masked average pooling in `mean_std`.
        resolution : int
            Resolution argument forwarded to the ResNet decoders.
        """
        super().__init__(input_features, output_features)

        assert not fix_std or init_std is not None

        self.encoder_type = encoder_type
        self.encoder_decoder = encoder_decoder
        self.maskmin = amin
        self._min_std = min_std if not fix_std else 0.0
        self.resolution = resolution
        self.net, self.mean_head, self.std_head = self._create_nets(hidden, fix_std=fix_std, init_std=init_std)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(
        self,
        inputs,
        num_samples=1,
        eval_likelihood_at=None,
        deterministic=False,
        return_mean=False,
        return_std=False,
        full=True,
        reduction="sum",
    ):
        """
        Forward transformation.

        In an encoder: takes as input the observed data x and returns the latent representation.
        In a decoder: takes as input the latent representation and returns the reconstructed
        data x.

        Parameters:
        -----------
        inputs : torch.Tensor with shape (batchsize, input_features), dtype torch.float
            Data to be encoded or decoded
        num_samples, eval_likelihood_at, deterministic, return_mean, return_std, full, reduction
            Forwarded unchanged to `gaussian_encode`.

        Returns:
        --------
        outputs : torch.Tensor with shape (batchsize, output_features), dtype torch.float
            Encoded or decoded version of the data
        log_likelihood : torch.Tensor
            Log likelihood evaluated at eval_likelihood_at or at outputs.
        mean : torch.Tensor, optional
            If `return_mean` is True, the Gaussian mean.
        std : torch.Tensor, optional
            If `return_std` is True, the Gaussian std.
        """
        mean, std = self.mean_std(inputs)
        return gaussian_encode(
            mean,
            std,
            num_samples,
            eval_likelihood_at,
            deterministic,
            return_mean=return_mean,
            return_std=return_std,
            full=full,
            reduction=reduction,
        )

    def _create_nets(self, hidden, fix_std=False, init_std=None):
        """
        Build the backbone network together with the mean and std heads.

        Returns
        -------
        (main_net, mean_head, std_head)

        Raises
        ------
        NotImplementedError
            If `encoder_type` is neither "mlp" nor a ResNet name starting with "res".
        ValueError
            If a ResNet backbone is requested with an unknown `encoder_decoder`.
        """
        if self.encoder_type == "mlp":
            # `hidden` defaults to a tuple, so convert before concatenating with a list.
            dims = [self.input_features] + list(hidden)
            main_net = make_mlp(dims, final_activation="relu")

        elif self.encoder_type.startswith("res"):
            if self.encoder_decoder == "encoder":
                # Randomly initialized torchvision ResNet with its final avgpool and fc removed,
                # so the backbone returns a spatial feature map (pooled later in `mean_std`).
                model_ft = models.__dict__[self.encoder_type](weights=None)
                dims = [model_ft.fc.in_features]
                main_net = torch.nn.Sequential(*list(model_ft.children())[:-2])
            elif self.encoder_decoder == "decoder":
                decoder_cls = ResNet18Dec if self.encoder_type == "resnet18" else ResNet50Dec
                main_net = decoder_cls(
                    z_dim=self.input_features, nc=self.output_features, resolution=self.resolution
                )
                dims = [self.output_features]
            else:
                raise ValueError(
                    f"Unknown encoder_decoder {self.encoder_decoder!r}; expected 'encoder' or 'decoder'"
                )

        else:
            # Other torchvision families (e.g. ViT) have no `.fc` and are not handled by `mean_std`.
            raise NotImplementedError(f"Unsupported encoder_type {self.encoder_type!r}")

        # Initialize mean head to be close to identity when the dimensions allow it
        mean_head = nn.Linear(dims[-1], self.output_features)
        if dims[-1] == self.output_features:
            with torch.no_grad():
                mean_head.weight.copy_(
                    torch.eye(self.output_features)
                    + 0.1 * torch.randn(self.output_features, self.output_features)
                )

        # Standard deviation head
        if fix_std:
            # Frozen linear layer with zero weights: its output is exactly its bias. No Softplus
            # follows in this branch, so the bias must be the std itself.
            std_head = nn.Linear(dims[-1], self.output_features)
            nn.init.constant_(std_head.weight, 0.0)
            if init_std is not None:
                assert init_std > 0.0
                nn.init.constant_(std_head.bias, init_std)
            for param in std_head.parameters():
                param.requires_grad = False
        else:
            linear_layer = nn.Linear(dims[-1], self.output_features)
            nn.init.normal_(linear_layer.weight, 0.0, 1.0e-3)
            if init_std is not None:
                # softplus(bias) + min_std == init_std requires init_std > min_std
                assert init_std > max(0.0, self._min_std)
                init_value = inverse_softplus(init_std - self._min_std)
                nn.init.constant_(linear_layer.bias, init_value)
            std_head = nn.Sequential(linear_layer, nn.Softplus())

        return main_net, mean_head, std_head

    def mean_std(self, x, s=None):
        """
        Compute the Gaussian mean and std for input `x`.

        Parameters
        ----------
        x : torch.Tensor
            Input to the backbone `self.net`.
        s : torch.Tensor, optional
            Soft instance mask for the ResNet encoder. If it is given, features are averaged
            over spatial dims (2, 3) weighted by `s` (clamped to at least `self.maskmin`).
            Presumably shaped (batch, 1, H, W) at feature-map resolution; TODO confirm.

        Returns
        -------
        mean, std : torch.Tensor
            ResNet encoder: (batch, output_features). ResNet decoder: channel-first maps
            (batch, output_features, H, W). MLP: the heads applied to the MLP features.
        """
        hidden = self.net(x)

        if self.encoder_type == "mlp":
            mean = self.mean_head(hidden)
            std = self._min_std + self.std_head(hidden)
        elif self.encoder_type.startswith("res"):
            if self.encoder_decoder == "encoder":
                if s is not None:
                    assert self.maskmin > 0
                    # Keep a small positive weight outside the instance mask
                    s = torch.clamp(s, min=self.maskmin)
                    # Mask-weighted average pooling over the spatial dims
                    h = (hidden * s).sum((2, 3)) / s.sum((2, 3))
                else:
                    h = self.avgpool(hidden).flatten(1)

                mean = self.mean_head(h)
                std = self._min_std + self.std_head(h)
            else:
                # Apply the linear heads per pixel on channel-last features, then restore NCHW
                hidden_cl = hidden.permute(0, 2, 3, 1)
                mean = self.mean_head(hidden_cl).permute(0, 3, 1, 2)
                std = self._min_std + self.std_head(hidden_cl).permute(0, 3, 1, 2)
        else:
            raise NotImplementedError(f"Unsupported encoder_type {self.encoder_type!r}")

        return mean, std

    def freezable_parameters(self):
        """Returns parameters that should be frozen during training"""
        return chain(self.net.parameters(), self.mean_head.parameters(), self.std_head.parameters())

    def unfreezable_parameters(self):
        """Returns parameters that should not be frozen during training"""
        return []


def gaussian_encode(
    mean,
    std,
    num_samples=1,
    eval_likelihood_at=None,
    deterministic=False,
    return_mean=False,
    return_std=False,
    full=True,
    reduction="sum",
):
    """
    Draw a sample from a factorized Gaussian and score a point under it.

    Parameters:
    -----------
    mean : torch.Tensor
        Gaussian mean; the sample has the shape of `mean`.
    std : torch.Tensor
        Gaussian standard deviation, multiplied elementwise with the noise (broadcastable to `mean`).
    num_samples : int
        Currently unused: exactly one sample is drawn regardless of its value.
    eval_likelihood_at : torch.Tensor, optional
        Point at which to evaluate the log likelihood; defaults to the drawn sample.
    deterministic : bool
        If True, return `mean` itself instead of a random sample (no noise is drawn).
    return_mean, return_std : bool
        Append `mean` and/or `std` (in that order) to the returned tuple.
    full, reduction
        Forwarded to `gaussian_log_likelihood`.

    Returns:
    --------
    tuple
        `(z, log_likelihood[, mean][, std])`, where `log_likelihood` is whatever
        `gaussian_log_likelihood` returns for the chosen `reduction`.
    """
    # Reparameterization trick: z = mean + std * eps with eps ~ N(0, I)
    z = mean if deterministic else mean + std * torch.randn_like(mean)

    point = z if eval_likelihood_at is None else eval_likelihood_at
    log_likelihood = gaussian_log_likelihood(point, mean, std, full=full, reduction=reduction)

    extras = [tensor for tensor, wanted in ((mean, return_mean), (std, return_std)) if wanted]
    return (z, log_likelihood, *extras)


def gaussian_log_likelihood(x, mean, std, full=True, reduction="sum"):
    """
    Computes the log likelihood of a multivariate factorized Gaussian.

    The Gaussian log likelihood is
    `log p(x) = sum_i [- log_std_i - 0.5 * log (2 pi) - 0.5 (x_i - mu_i)^2 / exp(log_std_i)^2]`.

    Parameters
    ----------
    x, mean, std : torch.Tensor
        Evaluation point, mean and standard deviation; dim 0 is the batch dimension.
    full : bool
        Include the constant `-0.5 * log(2 pi)` term.
    reduction : {"sum", "mean", "none"}
        How to reduce over all non-batch dims. "sum"/"mean" return shape (batch, 1);
        "none" returns the elementwise log likelihood.

    Raises
    ------
    ValueError
        If `reduction` is not one of the supported values.
    """
    if reduction not in ("sum", "mean", "none"):
        raise ValueError(f"Unknown likelihood reduction {reduction}")

    var = std**2
    log_likelihood = -F.gaussian_nll_loss(mean, x, var, full=full, reduction="none")

    if reduction == "none":
        return log_likelihood

    feature_dims = list(range(1, x.dim()))
    if not feature_dims:
        # 1-D input: every entry is its own sample. torch.sum/mean with dim=[] would reduce
        # over the batch too, so just add the trailing singleton dim.
        return log_likelihood.unsqueeze(1)

    reduce_fn = torch.sum if reduction == "sum" else torch.mean
    return reduce_fn(log_likelihood, dim=feature_dims).unsqueeze(1)


