import torch
import torch.nn as nn
import torch.nn.functional as F


class DSPN(nn.Module):
    """ Deep Set Prediction Networks
    Yan Zhang, Jonathon Hare, Adam Prügel-Bennett
    NeurIPS 2019
    https://arxiv.org/abs/1906.06565

    Predicts a set by running a fixed number of inner gradient-descent steps,
    starting from a learned initial set. Each step minimises the mismatch
    between the encoded set and a target representation, plus a small
    regularizer pulling the set towards the learned starting set.
    """

    def __init__(self, encoder, set_channels, max_set_size, channels, iters, lr):
        """
        encoder: Set encoder. It is called as ``encoder(current_set)`` with a single
            FloatTensor of size (batch_size, max_set_size, set_channels). Its output is
            compared element-wise (MSE) with ``target_repr``, so it must have a compatible
            shape, typically (batch_size, repr_channels).

        set_channels: Number of channels (features) per element of the learned starting set.

        max_set_size: Maximum size of the set (number of elements in the starting set).

        channels: Stored as ``self.channels``; not used by this class itself.

        iters: Number of inner gradient-descent iterations run per forward pass.

        lr: Learning rate of the inner gradient descent in DSPN.
        """
        super().__init__()
        self.encoder = encoder
        self.iters = iters
        self.lr = lr
        self.channels = channels
        # NOTE(review): n_cg and lamb are not read anywhere in this class; presumably
        # used by callers or subclasses. Confirm before removing.
        self.n_cg = 3
        self.lamb = 100

        # Learned initial set of shape (1, max_set_size, set_channels); expanded over the batch in forward.
        self.starting_set = nn.Parameter(0.1 * torch.randn(1, max_set_size, set_channels))

    def objective(self, target_repr, current_set, initial_set):
        """Inner loss minimised by DSPN's gradient descent.

        Returns ``repr_loss + 0.1 * regularizer`` where
        - repr_loss is 0.5 * squared error between ``encoder(current_set)`` and
          ``target_repr``, summed over dim 0 (batch) and averaged over the rest;
        - regularizer is the same quantity between ``current_set`` and
          ``initial_set``, which keeps the set from drifting far from ``initial_set``.
        """
        predicted_repr = self.encoder(current_set)
        # How well the representation of the current set matches the target.
        repr_loss = 0.5 * F.mse_loss(
            predicted_repr, target_repr, reduction='none'
        ).sum(dim=0).mean()
        regularizer = 0.5 * F.mse_loss(
            current_set, initial_set, reduction='none'
        ).sum(dim=0).mean()
        return repr_loss + 0.1 * regularizer

    @torch.enable_grad()
    def forward(self, target_repr, starting_set=None, training=None):
        """
        Conceptually, DSPN simply turns the target_repr feature vector into a set.

        target_repr: Representation that the predicted set should match. FloatTensor of size (batch_size, repr_channels).
        Note that repr_channels can be different from self.channels.
        This can come from a set processed with the same encoder as self.encoder (auto-encoder), or a different
        input completely (normal supervised learning), such as an image encoded into a feature vector.

        starting_set: Optional set to start the descent from. Defaults to the learned
            ``self.starting_set`` copied over the batch.

        training: If True, the inner loop is differentiable (``create_graph=True``), so an
            outer loss can backpropagate into the encoder and the learned starting set.
            If False, intermediates are detached to save memory. Defaults to ``self.training``.

        Returns ``(predicted_set, set_grad)``: the set after ``self.iters`` descent steps and
        the gradient used in the last step (all zeros when ``self.iters == 0``).
        """
        if training is None:
            training = self.training

        # Copy the learned initial set over the batch.
        learned_start = self.starting_set.expand(
            target_repr.size(0), *self.starting_set.size()[1:]
        )
        # NOTE(review): the regularizer always anchors to the learned starting set, even
        # when a custom starting_set is supplied. Confirm this is intended.
        # In eval mode only d(loss)/d(current_set) is needed, so the anchor's graph is dropped.
        anchor = learned_start if training else learned_start.detach()

        current_set = learned_start if starting_set is None else starting_set
        if not training:
            current_set = current_set.detach()

        # Defined up front so that iters == 0 returns a zero gradient instead of crashing.
        set_grad = torch.zeros_like(current_set)

        # Optimise the objective for a fixed number of steps.
        for _ in range(self.iters):
            # torch.autograd.grad needs a grad-tracking input in both train and eval mode.
            # detach() first so a caller-owned tensor is never flipped in place.
            if not current_set.requires_grad:
                current_set = current_set.detach().requires_grad_(True)

            loss = self.objective(target_repr, current_set, anchor)
            set_grad = torch.autograd.grad(
                outputs=loss,
                inputs=[current_set],
                create_graph=training,
            )[0]

            # Gradient-descent update on the set.
            with torch.set_grad_enabled(training):
                current_set = current_set - self.lr * set_grad

            # Save memory in eval mode.
            if not training:
                current_set = current_set.detach()
                set_grad = set_grad.detach()

        return current_set, set_grad
