"""
Code adapted from
https://github.com/cambridge-mlg/convcnp/blob/master/convcnp/set_conv.py
"""
from attrdict import AttrDict
import numpy as np
import torch
from torch.distributions.normal import Normal
import torch.nn as nn

from krt.models.conv.modules import ConvDeepSet, SimpleConv, UNet
from krt.models.conv.utils import to_multiple

# Public names exported by this module.
__all__ = ['ConvCNP']


class ConvCNP(nn.Module):
    """One-dimensional ConvCNP model.

    The context set is embedded onto a uniform grid by a ``ConvDeepSet``
    layer, passed through a convolutional network, and read out at the target
    locations by two further ``ConvDeepSet`` layers: one for the predictive
    mean and one for the predictive standard deviation.

    Args:
        dim_x (int): Dimensionality of the inputs. Must be 1.
        dim_y (int): Dimensionality of the outputs. Must be 1.
        architecture (str): Convolutional architecture to place on the
            functional representation (rho). Either ``'simple'``
            (``SimpleConv``) or ``'unet'`` (``UNet``).
        learn_length_scale (bool): Learn the length scale.
        points_per_unit (int): Number of points per unit interval on input.
            Used to discretize function.

    Raises:
        ValueError: If ``architecture`` is not one of the supported names.
    """

    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        architecture: str,
        learn_length_scale: bool = True,
        points_per_unit: int = 64,
    ):
        super().__init__()
        assert dim_x == 1, 'ConvCNP only supports dim_x == 1.'
        assert dim_y == 1, 'ConvCNP only supports dim_y == 1.'
        self.activation = nn.Sigmoid()
        self.sigma_fn = nn.Softplus()
        if architecture == 'simple':
            self.conv_net = SimpleConv()
        elif architecture == 'unet':
            self.conv_net = UNet()
        else:
            raise ValueError(f'Unrecognized architecture: {architecture}')
        # The grid length must be a multiple of this so that it survives the
        # halving layers of the conv net (see the shape check in `forward`).
        self.multiplier = 2 ** self.conv_net.num_halving_layers

        # Compute initialisation.
        self.points_per_unit = points_per_unit
        init_length_scale = 2.0 / self.points_per_unit

        # Encoder: context set -> functional representation on the grid.
        self.l0 = ConvDeepSet(
            in_channels=1,
            out_channels=self.conv_net.in_channels,
            learn_length_scale=learn_length_scale,
            init_length_scale=init_length_scale,
            use_density=True
        )
        # Decoders: grid representation -> mean / (pre-softplus) sigma at the
        # target locations.
        self.mean_layer = ConvDeepSet(
            in_channels=self.conv_net.out_channels,
            out_channels=1,
            learn_length_scale=learn_length_scale,
            init_length_scale=init_length_scale,
            use_density=False
        )
        self.sigma_layer = ConvDeepSet(
            in_channels=self.conv_net.out_channels,
            out_channels=1,
            learn_length_scale=learn_length_scale,
            init_length_scale=init_length_scale,
            use_density=False
        )
        self.device = 'cpu'

    def to(self, device):
        """Move the model to ``device`` and remember it in ``self.device``.

        ``self.device`` is only refreshed by calls to this method; the
        computations in `forward` and `seq_ll` therefore take the device from
        their input tensors instead of relying on it.

        Args:
            device: Target device, forwarded to the sub-layers and to
                ``nn.Module.to``.

        Returns: Whatever ``nn.Module.to`` returns.
        """
        self.device = device
        # The ConvDeepSet layers are moved explicitly in addition to the
        # generic `nn.Module.to` call below.
        self.l0.to(device)
        self.mean_layer.to(device)
        self.sigma_layer.to(device)
        return super().to(device)

    @staticmethod
    def _grid_bounds(x, x_out, use_training_range_for_min_max):
        """Return ``(x_min, x_max)`` of the discretization grid as floats.

        The bounds cover every context and target input, optionally the
        interval [-2, 2] as well, and are padded by 0.1 on both sides.
        """
        # `.item()` yields plain Python floats, so the arithmetic below does
        # not depend on NumPy and also works for tensors that require grad.
        lows = [torch.min(x).item(), torch.min(x_out).item()]
        highs = [torch.max(x).item(), torch.max(x_out).item()]
        if use_training_range_for_min_max:
            # This was the code in the original repo. However, this will cause
            # problems if there are translational shifts at test time.
            lows.append(-2.)
            highs.append(2.)
        return min(lows) - 0.1, max(highs) + 0.1

    def forward(
        self,
        batch: AttrDict,
        use_training_range_for_min_max: bool = True,
    ):
        """Forward

        Args:
            batch: Batch containing:
                xt: Target x points with shape (batch, L_T, x dim).
                xc: Condition x points with shape (batch, L_C, x dim).
                yc: Condition y points with shape (batch, L_C, y dim)
            use_training_range_for_min_max: Whether we expect the training
                data to be within -2 and 2.

        Returns: Output containing
            mean: Mean prediction w shape (batch, L_T, y_dim)
            std: Standard deviation prediction w shape (batch, L_T, y_dim)

        Raises:
            RuntimeError: If the conv net changes the length of the grid.
        """
        x = batch.xc
        y = batch.yc
        x_out = batch.xt
        # Ensure that `x`, `y`, and `x_out` are rank-3 tensors.
        if len(x.shape) == 2:
            x = x.unsqueeze(2)
        if len(y.shape) == 2:
            y = y.unsqueeze(2)
        if len(x_out.shape) == 2:
            x_out = x_out.unsqueeze(2)

        # Determine the grid on which to evaluate functional representation.
        x_min, x_max = self._grid_bounds(
            x, x_out, use_training_range_for_min_max)
        num_points = int(to_multiple(self.points_per_unit * (x_max - x_min),
                                     self.multiplier))
        # Put the grid on the inputs' device: `self.device` is only updated
        # by `to()` and may not reflect where the data actually lives.
        x_grid = torch.linspace(x_min, x_max, num_points).to(x.device)
        x_grid = x_grid[None, :, None].repeat(x.shape[0], 1, 1)

        # Apply first layer and conv net. Take care to put the axis ranging
        # over the data last.
        h = self.activation(self.l0(x, y, x_grid))
        h = h.permute(0, 2, 1)
        h = h.reshape(h.shape[0], h.shape[1], num_points)
        h = self.conv_net(h)
        h = h.reshape(h.shape[0], h.shape[1], -1).permute(0, 2, 1)

        # Check that shape is still fine!
        if h.shape[1] != x_grid.shape[1]:
            raise RuntimeError('Shape changed.')

        # Produce means and standard deviations. Softplus keeps the standard
        # deviation positive.
        mean = self.mean_layer(x_grid, h, x_out)
        sigma = self.sigma_fn(self.sigma_layer(x_grid, h, x_out))
        return AttrDict({
            'mean': mean,
            'std': sigma,
        })

    def loss(self, batch, model_out, reduce_ll=True):
        """Negative Gaussian log likelihood of the targets.

        Args:
            batch: Batch providing the target values as ``batch.yt``.
            model_out: Output of `forward`, providing ``mean`` and ``std``.
            reduce_ll: If true, average the log likelihood (summed over the
                last axis) into a scalar; otherwise keep it unreduced.

        Returns: AttrDict with
            tar_ll: Target log likelihood.
            loss: Negated ``tar_ll``.
            stats: Empty dictionary.
        """
        pred_tar = Normal(model_out.mean, model_out.std)
        tar_ll = pred_tar.log_prob(batch.yt).sum(-1)
        if reduce_ll:
            tar_ll = tar_ll.mean()
        outs = AttrDict()
        outs.tar_ll = tar_ll
        outs.loss = -tar_ll
        outs.stats = {}
        return outs

    def predict(
        self,
        xc,
        yc,
        xt,
        use_training_range_for_min_max: bool = True,
    ) -> Normal:
        """Predict y for the x given the condition.

        The arguments are passed to `forward` unchanged as ``xc``, ``yc`` and
        ``xt``, so they follow the shapes documented there.

        Args:
            xc: The condition x points.
            yc: The condition y points.
            xt: The x points to predict for.
            use_training_range_for_min_max: Whether we expect the training
                data to be within -2 and 2.

        Returns: Normal distribution built from the predicted mean and
            standard deviation.
        """
        model_out = self.forward(AttrDict({
            'xc': xc,
            'yc': yc,
            'xt': xt,
        }), use_training_range_for_min_max=use_training_range_for_min_max)
        return Normal(model_out.mean, model_out.std)

    def seq_ll(
        self,
        xc: torch.Tensor,
        yc: torch.Tensor,
        xt: torch.Tensor,
        yt: torch.Tensor,
        autoreg: bool = True,
        use_training_range_for_min_max: bool = True,
        **kwargs
    ) -> torch.Tensor:
        """Get the log likelihood of the target set given the condition set.

        Args:
            xc: The x conditional points w shape (batch, L_C, D_X)
            yc: The y conditional points w shape (batch, L_C, D_Y)
            xt: The x target points w shape (batch, L_T, D_X)
            yt: The y target points w shape (batch, L_T, D_Y).
            autoreg: Whether to do autoregressive approach to compute the
                joint log likelihood. If this is false then it is assumed that
                target set is conditionally independent.
            use_training_range_for_min_max: Whether we expect the training
                data to be within -2 and 2.
            **kwargs: Accepted and ignored.

        Returns: Log likelihood of each sequence w shape (batch,)
        """
        B, LT, _ = yt.shape
        if not autoreg:
            with torch.no_grad():
                dist = self.predict(
                    xc, yc, xt,
                    use_training_range_for_min_max=use_training_range_for_min_max,
                )
            return dist.log_prob(yt).sum(-1).sum(-1)

        # Accumulate on the targets' device rather than `self.device`, which
        # is only updated by `to()`.
        lls = torch.zeros(B, device=yt.device)
        for lt in range(LT):
            # Targets before index `lt` are appended to the context set.
            curr_xc = torch.cat([xc, xt[:, :lt]], dim=1)
            curr_yc = torch.cat([yc, yt[:, :lt]], dim=1)
            # All remaining targets are predicted although only the first is
            # scored: context plus targets then cover the same x locations
            # at every step, so `forward` derives the same grid bounds.
            curr_xt = xt[:, lt:]
            curr_yt = yt[:, lt:]
            with torch.no_grad():
                dist = self.predict(
                    curr_xc, curr_yc, curr_xt,
                    use_training_range_for_min_max=use_training_range_for_min_max,
                )
            lls += dist.log_prob(curr_yt).sum(dim=-1)[:, 0]
        return lls

    @property
    def num_params(self):
        """Number of parameters in model."""
        return sum(param.numel() for param in self.parameters())