import torch
import torch.nn as nn
from attrdict import AttrDict
from torch.distributions import MultivariateNormal, Normal
from gpytorch.kernels import RBFKernel, ScaleKernel

import logging
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import normal

import warnings

import abc
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

import operator
from functools import reduce

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import rv_discrete
from torch.distributions import Normal
from torch.distributions.independent import Independent
import math

import torch
import torch.nn as nn
from torch.nn import functional as F

import warnings

import torch
import torch.nn as nn

class GaussianConv2d(nn.Module):
    """Depthwise 2d convolution with a separable, Gaussian-shaped filter.

    The filter is the outer product of two softmax-normalized vectors, one per
    spatial axis. Each vector is computed from the negative squared offset to
    the kernel centre, scaled by a learnable scalar (`weights_x`, `weights_y`).
    The same filter is applied independently to every input channel.

    Parameters
    ----------
    kernel_size : int, optional
        Side of the square kernel. Has to be odd.

    **kwargs :
        Additional arguments forwarded to `F.conv2d` (e.g. `padding`).
    """

    def __init__(self, kernel_size=5, **kwargs):
        super().__init__()
        assert kernel_size % 2 == 1
        self.kwargs = kwargs
        self.kernel_sizes = (kernel_size, kernel_size)

        # column vector holding -(offset to the kernel centre)**2
        offsets = torch.arange(0, kernel_size).view(-1, 1).float() - kernel_size // 2
        self.exponent = -(offsets ** 2)

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)create the two learnable scale factors, both initialized to 1."""
        self.weights_x = nn.Parameter(torch.tensor([1.0]))
        self.weights_y = nn.Parameter(torch.tensor([1.0]))

    def forward(self, X):
        """Convolve every channel of `X` with the current Gaussian filter."""
        # `exponent` is a plain tensor attribute (not a buffer) => it is moved
        # by hand and cached so that later calls are no-ops
        exponent = self.exponent.to(X.device)
        self.exponent = exponent

        column = torch.softmax(exponent * self.weights_x, dim=0)
        row = torch.softmax(exponent * self.weights_y, dim=0).T

        n_chan = X.size(1)
        # outer product => 2d kernel, shared (expanded) across channels
        kernel = (column @ row).view(1, 1, *self.kernel_sizes)
        kernel = kernel.expand(n_chan, 1, *self.kernel_sizes)

        return F.conv2d(X, kernel, groups=n_chan, **self.kwargs)


class ConvBlock(nn.Module):
    """Normalization, then activation, then a depth separable convolution.

    Parameters
    ----------
    in_chan : int
        Number of input channels.

    out_chan : int
        Number of output channels.

    Conv : nn.Module
        Convolutional layer (uninitialized), wrapped by `make_depth_sep_conv`.

    kernel_size : int, optional
        Size of the kernel. The padding is set to `kernel_size // 2`.

    dilation : int, optional
        Accepted for interface compatibility; it is currently not forwarded to
        the convolution.

    activation : callable, optional
        Activation applied before the convolution.

    Normalization : nn.Module, optional
        Normalization layer (uninitialized), built with `in_chan`.

    **kwargs :
        Additional arguments to the depth separable convolution.
    """

    def __init__(
        self,
        in_chan,
        out_chan,
        Conv,
        kernel_size=5,
        dilation=1,
        activation=nn.ReLU(),
        Normalization=nn.Identity,
        **kwargs
    ):
        super().__init__()
        self.activation = activation

        DepthSepConv = make_depth_sep_conv(Conv)
        self.conv = DepthSepConv(
            in_chan, out_chan, kernel_size, padding=kernel_size // 2, **kwargs
        )
        self.norm = Normalization(in_chan)

        self.reset_parameters()

    def reset_parameters(self):
        """Delegate the initialization to `weights_init`."""
        weights_init(self)

    def forward(self, X):
        """Apply norm -> activation -> convolution to `X`."""
        normalized = self.norm(X)
        activated = self.activation(normalized)
        return self.conv(activated)


class ResConvBlock(nn.Module):
    """Residual block made of one or two depth separable convolutions.

    The residual connection is added after the last depthwise convolution and
    before the last pointwise one, such that the number of output channels can
    differ from the number of input channels.

    Parameters
    ----------
    in_chan : int
        Number of input channels.

    out_chan : int
        Number of output channels.

    Conv : nn.Module
        Convolutional layer (uninitialized).

    kernel_size : int, optional
        Size of the kernel. Has to be odd.

    activation : callable, optional
        Activation applied before each (depthwise) convolution.

    Normalization : nn.Module, optional
        Normalization layer (uninitialized), built with `in_chan`.

    is_bias : bool, optional
        Whether the convolutions use a bias.

    n_conv_layers : {1, 2}, optional
        Number of convolutional layers.

    Raises
    ------
    ValueError
        If `kernel_size` is even.
    """

    def __init__(
        self,
        in_chan,
        out_chan,
        Conv,
        kernel_size=5,
        activation=nn.ReLU(),
        Normalization=nn.Identity,
        is_bias=True,
        n_conv_layers=1,
    ):
        super().__init__()
        self.activation = activation
        self.n_conv_layers = n_conv_layers
        assert self.n_conv_layers in [1, 2]

        is_even_kernel = kernel_size % 2 == 0
        if is_even_kernel:
            raise ValueError("`kernel_size={}`, but should be odd.".format(kernel_size))

        # arguments shared by all the convolutions that keep the spatial size
        same_size = dict(padding=kernel_size // 2, bias=is_bias)

        if self.n_conv_layers == 2:
            self.norm1 = Normalization(in_chan)
            DepthSepConv = make_depth_sep_conv(Conv)
            self.conv1 = DepthSepConv(in_chan, in_chan, kernel_size, **same_size)

        self.norm2 = Normalization(in_chan)
        self.conv2_depthwise = Conv(
            in_chan, in_chan, kernel_size, groups=in_chan, **same_size
        )
        self.conv2_pointwise = Conv(in_chan, out_chan, 1, bias=is_bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Delegate the initialization to `weights_init`."""
        weights_init(self)

    def forward(self, X):
        """Apply the block to `X` (channels on the 2nd dimension)."""
        hidden = X
        if self.n_conv_layers == 2:
            hidden = self.conv1(self.activation(self.norm1(hidden)))

        hidden = self.conv2_depthwise(self.activation(self.norm2(hidden)))

        # residual is added before the pointwise convolution, which is the only
        # layer allowed to change the number of channels
        hidden = hidden + X

        # original author's note: `contiguous` is needed here
        return self.conv2_pointwise(hidden.contiguous())


class ResNormalizedConvBlock(ResConvBlock):
    """Residual block operating jointly on a signal and its confidence.

    The input is split in two halves along the channel dimension: a signal and
    a confidence. The convolved (confidence weighted) signal is divided by the
    convolved confidence, and an updated confidence is concatenated back to the
    output. No normalization layer is used.

    Notes
    -----
    - NOTE(review): `forward` always uses `self.conv1`, which `ResConvBlock`
      only creates when `n_conv_layers == 2`. With the default
      `n_conv_layers=1` the attribute does not exist - confirm that callers
      always pass `n_conv_layers=2`.
    - NOTE(review): `init_param_` is not defined in the visible part of this
      file - confirm it is available at module level.

    Parameters
    ----------
    in_chan : int
        Number of channels given to the parent block.

    out_chan : int
        Number of output channels of the pointwise convolution.

    Conv : nn.Module
        Convolutional layer (uninitialized).

    kernel_size : int, optional
        Size of the kernel. Has to be odd.

    activation : callable, optional
        Activation applied before the convolutions of the signal.

    is_bias : bool, optional
        Whether the convolutions use a bias.

    **kwargs :
        Additional arguments to `ResConvBlock`.
    """

    def __init__(
        self,
        in_chan,
        out_chan,
        Conv,
        kernel_size=5,
        activation=nn.ReLU(),
        is_bias=True,
        **kwargs
    ):
        super().__init__(
            in_chan,
            out_chan,
            Conv,
            kernel_size=kernel_size,
            activation=activation,
            is_bias=is_bias,
            Normalization=nn.Identity,
            **kwargs
        )  # make sure no normalization

    def reset_parameters(self):
        """Initialize the weights and (re)create `bias` and `temperature`."""
        weights_init(self)
        # offset added inside the sigmoid of the confidence update
        self.bias = nn.Parameter(torch.tensor([0.0]))

        # scale (after a softplus) of the density in the confidence update
        self.temperature = nn.Parameter(torch.tensor([0.0]))
        init_param_(self.temperature)

    def forward(self, X):
        """Return the concatenation (on dim 1) of the new signal and confidence."""
        # first half of the channels is the signal, second half the confidence
        signal, conf_1 = X.chunk(2, dim=1)
        # make sure confidence is in [0, 1] (might not be the case due to the
        # pointwise transformation of a previous block)
        conf_1 = conf_1.clamp(min=0, max=1)
        X = signal * conf_1

        numerator = self.conv1(self.activation(X))
        numerator = self.conv2_depthwise(self.activation(numerator))
        # the confidence goes through the same convolutions, without activation
        density = self.conv2_depthwise(self.conv1(conf_1))
        # lower bound on the density to avoid dividing by 0
        out = numerator / torch.clamp(density, min=1e-5)

        # adds residual before point wise => output can change number of channels

        # make sure that confidence cannot decrease and cannot be greater than 1
        conf_2 = conf_1 + torch.sigmoid(
            density * F.softplus(self.temperature) + self.bias
        )
        conf_2 = conf_2.clamp(max=1)
        out = out + X

        # the same pointwise convolution is applied to signal and confidence
        out = self.conv2_pointwise(out)
        conf_2 = self.conv2_pointwise(conf_2)

        return torch.cat([out, conf_2], dim=1)


class CNN(nn.Module):
    """Simple multilayer CNN: a sequence of convolutional blocks.

    Parameters
    ----------
    n_channels : int or list
        Number of channels, same for input and output. If list then needs to
        be of size `n_blocks + 1`, and gives the number of channels before and
        after each block.

    ConvBlock : nn.Module
        Convolutional block (uninitialized), called as
        `ConvBlock(in_chan, out_chan, **kwargs)`.

    n_blocks : int, optional
        Number of convolutional blocks.

    is_chan_last : bool, optional
        Whether the channels of the input are on the last dimension, in which
        case they are moved to the 2nd dimension for the convolutions and moved
        back afterwards.

    **kwargs :
        Additional arguments to `ConvBlock`.
    """

    def __init__(self, n_channels, ConvBlock, n_blocks=3, is_chan_last=False, **kwargs):

        super().__init__()
        self.n_blocks = n_blocks
        self.is_chan_last = is_chan_last
        self.in_out_channels = self._get_in_out_channels(n_channels, n_blocks)

        blocks = [
            ConvBlock(chan_in, chan_out, **kwargs)
            for chan_in, chan_out in self.in_out_channels
        ]
        self.conv_blocks = nn.ModuleList(blocks)

        # never return representation for vanilla conv
        self.is_return_rep = False

        self.reset_parameters()

    def reset_parameters(self):
        """Delegate the initialization to `weights_init`."""
        weights_init(self)

    def _get_in_out_channels(self, n_channels, n_blocks):
        """Return a list of tuple of input and output channels."""
        n_expected = n_blocks + 1

        if isinstance(n_channels, int):
            channels = [n_channels] * n_expected
        else:
            channels = list(n_channels)

        assert len(channels) == n_expected, "{} != {}".format(
            len(channels), n_blocks + 1
        )

        # consecutive pairs: (c0, c1), (c1, c2), ...
        return [(channels[i], channels[i + 1]) for i in range(len(channels) - 1)]

    def forward(self, X):
        """Apply all blocks to `X`; optionally also return a representation."""
        if self.is_chan_last:
            X = channels_to_2nd_dim(X)

        X, representation = self.apply_convs(X)

        if self.is_chan_last:
            X = channels_to_last_dim(X)

        return (X, representation) if self.is_return_rep else X

    def apply_convs(self, X):
        """Sequentially apply the blocks. The representation is always None."""
        for block in self.conv_blocks:
            X = block(X)
        return X, None


class UnetCNN(CNN):
    """Unet [1]: down blocks with pooling, a bottleneck, and up blocks that are
    concatenated (on channels) with the output of the matching down block.

    Parameters
    ----------
    n_channels : int
        Number of channels of the first block; multiplied by 2 at every down
        block (up to `max_nchannels`).

    ConvBlock : nn.Module
        Convolutional block (uninitialized).

    Pool : nn.Module
        Pooling layer (uninitialized), built with `pooling_size`.

    upsample_mode : str
        `mode` given to `F.interpolate` (called with `align_corners=True`).

    max_nchannels : int, optional
        Bound on the number of channels of the intermediate layers.

    pooling_size : int, optional
        Size of the pooling and scale factor of the upsampling.

    is_force_same_bottleneck : bool, optional
        Whether to replace, during training, the bottleneck of the two halves
        of the batch by their average.

    is_return_rep : bool, optional
        Whether `forward` also returns the bottleneck averaged over all
        non batch / channel dimensions.

    **kwargs :
        Additional arguments to `CNN` and `ConvBlock`.

    References
    ----------
    [1] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net:
        Convolutional networks for biomedical image segmentation."
    """

    def __init__(
        self,
        n_channels,
        ConvBlock,
        Pool,
        upsample_mode,
        max_nchannels=256,
        pooling_size=2,
        is_force_same_bottleneck=False,
        is_return_rep=False,
        **kwargs
    ):

        # has to be set before the parent constructor, which already calls
        # `_get_in_out_channels`
        self.max_nchannels = max_nchannels
        super().__init__(n_channels, ConvBlock, **kwargs)

        self.pooling_size = pooling_size
        self.pooling = Pool(self.pooling_size)
        self.upsample_mode = upsample_mode
        self.is_force_same_bottleneck = is_force_same_bottleneck
        # overrides the `False` set by the parent constructor
        self.is_return_rep = is_return_rep

    def apply_convs(self, X):
        """Run the U-net and return the output and the bottleneck representation."""
        n_down = self.n_blocks // 2
        skips = []

        # Down
        for i in range(n_down):
            X = self.conv_blocks[i](X)
            skips.append(X)
            X = self.pooling(X)

        # Bottleneck
        X = self.conv_blocks[n_down](X)
        # Representation before forcing same bottleneck
        representation = X.view(*X.shape[:2], -1).mean(-1)

        if self.is_force_same_bottleneck and self.training:
            # forces the u-net to use the bottleneck by giving additional
            # information there: the bottleneck should be a global
            # representation => should not depend on the sample you chose.
            # Both halves of the batch are replaced by their average.
            half = X.size(0) // 2
            averaged = (X[:half, ...] + X[half:, ...]) / 2
            X = torch.cat([averaged, averaged], dim=0)

        # Up
        for depth, i in enumerate(range(n_down + 1, self.n_blocks), start=1):
            X = F.interpolate(
                X,
                mode=self.upsample_mode,
                scale_factor=self.pooling_size,
                align_corners=True,
            )
            # skip connections are consumed from the deepest to the shallowest
            skip = skips[-depth]
            X = torch.cat((X, skip), dim=1)  # concat on channels
            X = self.conv_blocks[i](X)

        return X, representation

    def _get_in_out_channels(self, n_channels, n_blocks):
        """Return a list of tuple of input and output channels for a Unet."""
        # doubles at every down layer, as in vanilla U-net
        factor_chan = 2

        assert n_blocks % 2 == 1, "n_blocks={} not odd".format(n_blocks)

        # e.g. if n_channels=16, n_blocks=5: [16, 32, 64]
        n_levels = n_blocks // 2 + 1
        down_channels = [factor_chan ** i * n_channels for i in range(n_levels)]
        # e.g.: [16, 32, 64, 64, 32, 16]
        channel_list = down_channels + down_channels[::-1]

        # bound max number of channels by self.max_nchannels (besides first and
        # last dim as this is input / output and should not be changed)
        bounded_interior = [min(c, self.max_nchannels) for c in channel_list[1:-1]]
        channel_list = channel_list[:1] + bounded_interior + channel_list[-1:]

        # e.g.: [(16, 32), (32,64), (64, 64), (64, 32), (32, 16)]
        in_out_channels = super()._get_in_out_channels(channel_list, n_blocks)

        # e.g.: [(16, 32), (32,64), (64, 64), (128, 32), (64, 16)] due to concat
        first_up = len(in_out_channels) // 2 + 1
        for idx in range(first_up, len(in_out_channels)):
            in_chan, out_chan = in_out_channels[idx]
            in_out_channels[idx] = (in_chan * 2, out_chan)

        return in_out_channels

class MLP(nn.Module):
    """General MLP class.

    Parameters
    ----------
    input_size : int
        Size of the last dimension of the input.

    output_size : int
        Size of the last dimension of the output.

    hidden_size : int, optional
        Number of hidden neurones. Increased to `min(input_size, output_size)`
        (with a warning) when smaller than both.

    n_hidden_layers : int, optional
        Number of hidden layers.

    activation : callable, optional
        Activation function. E.g. `nn.ReLU()`.

    is_bias : bool, optional
        Whether to use biases in the linear layers.

    dropout : float, optional
        Dropout rate. No dropout layer is used when it is not positive.

    is_force_hid_smaller : bool, optional
        Whether to decrease `hidden_size` to `max(input_size, output_size)`
        (with a warning) when larger than both.

    is_res : bool, optional
        Whether to use residual connections around the hidden layers.
    """

    def __init__(
        self,
        input_size,
        output_size,
        hidden_size=32,
        n_hidden_layers=1,
        activation=nn.ReLU(),
        is_bias=True,
        dropout=0,
        is_force_hid_smaller=False,
        is_res=False,
    ):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.n_hidden_layers = n_hidden_layers
        self.is_res = is_res

        smallest_io = min(self.output_size, self.input_size)
        largest_io = max(self.output_size, self.input_size)

        if is_force_hid_smaller and self.hidden_size > largest_io:
            self.hidden_size = largest_io
            txt = "hidden_size={} larger than output={} and input={}. Setting it to {}."
            warnings.warn(
                txt.format(hidden_size, output_size, input_size, self.hidden_size)
            )
        elif self.hidden_size < smallest_io:
            self.hidden_size = smallest_io
            txt = (
                "hidden_size={} smaller than output={} and input={}. Setting it to {}."
            )
            warnings.warn(
                txt.format(hidden_size, output_size, input_size, self.hidden_size)
            )

        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()
        self.activation = activation

        self.to_hidden = nn.Linear(self.input_size, self.hidden_size, bias=is_bias)
        # the first hidden layer is `to_hidden` => only `n_hidden_layers - 1` more
        n_extra_layers = self.n_hidden_layers - 1
        self.linears = nn.ModuleList(
            [
                nn.Linear(self.hidden_size, self.hidden_size, bias=is_bias)
                for _ in range(n_extra_layers)
            ]
        )
        self.out = nn.Linear(self.hidden_size, self.output_size, bias=is_bias)

        self.reset_parameters()

    def forward(self, x):
        """Apply the MLP on the last dimension of `x`."""
        hidden = self.dropout(self.activation(self.to_hidden(x)))

        for layer in self.linears:
            activated = self.activation(layer(hidden))
            if self.is_res:
                # residual is added before the dropout
                activated = activated + hidden
            hidden = self.dropout(activated)

        return self.out(hidden)

    def reset_parameters(self):
        """Initialize all the linear layers with `linear_init`."""
        linear_init(self.to_hidden, activation=self.activation)
        for hidden_layer in self.linears:
            linear_init(hidden_layer, activation=self.activation)
        # no activation follows the output layer
        linear_init(self.out)
        
class UnsharedExpRBF(nn.Module):
    """Exponential radial basis function with two separate length scales.

    Two kernels `exp(- (dist / sigma_j) ** p)` are computed, one per entry of
    `length_scale_param`. The second one is summed over the second to last
    dimension to give a density, which is used to normalize the first one.

    Parameters
    ----------
    x_dim : int
        Accepted for interface compatibility with the other radial basis
        functions; not used.

    max_dist : float, optional
        Distance used to initialize the length scales: the kernel at
        `max_dist` is initially equal to `max_dist_weight`.

    max_dist_weight : float, optional
        Initial value of the kernel at distance `max_dist`.

    p : int, optional
        Order of the norm and power of the exponent.

    **kwargs :
        Ignored.
    """

    def __init__(
        self,
        x_dim,
        max_dist=1 / 256,
        max_dist_weight=0.99,
        p=2,
        **kwargs
    ):
        super().__init__()

        self.max_dist = max_dist
        self.max_dist_weight = max_dist_weight
        # placeholder, replaced in `reset_parameters`
        self.length_scale_param = nn.Parameter(torch.tensor([0.0] * 2))
        self.p = p
        self.reset_parameters()

    def reset_parameters(self):
        """Set both length scales such that kernel(max_dist) = max_dist_weight."""
        weights_init(self)
        # exp(- (max_dist / sigma).pow(p)) = max_dist_weight
        # => sigma = max_dist / ((- log(max_dist_weight))**(1/p))
        neg_log_weight = -math.log(self.max_dist_weight)
        sigma_init = self.max_dist / (neg_log_weight ** (1 / self.p))
        # the parameter goes through a softplus in `forward`
        # => store inverse_softplus(sigma) = log(exp(sigma) - 1)
        param_init = math.log(math.exp(sigma_init) - 1)
        self.length_scale_param = nn.Parameter(torch.tensor([param_init] * 2))

    def forward(self, diff):
        """Return the normalized weights and the density.

        NOTE(review): `SetConv` calls this with
        `diff` of size [batch_size, n_queries, n_keys, kq_size] - confirm.
        """
        # norm over the last dimension, which is kept with size 1
        dist = torch.norm(diff, p=self.p, dim=-1, keepdim=True)

        # compute exponent making sure no division by 0
        sigma = 1e-5 + F.softplus(self.length_scale_param)

        # broadcasting with the 2 length scales => last dimension has size 2
        kernels = torch.exp(-(dist / sigma).pow(self.p))

        # density uses the second length scale, summed on second to last dim
        density = kernels[..., 1:].sum(dim=-2)

        # weights use the first length scale, normalized by the density
        weights = kernels[..., 0:1] / (density.unsqueeze(2) + 1e-8)

        return weights, density


# META ENCODERS
class DiscardIthArg(nn.Module):
    """Wrap a module such that its i-th positional argument is dropped, both
    when constructing it and when calling it.

    Parameters
    ----------
    *args :
        Positional arguments to `To`, besides the i-th one which is discarded.

    i : int, optional
        Position of the argument to discard.

    To : nn.Module, optional
        Module (uninitialized) to which the remaining arguments are forwarded.

    **kwargs :
        Keyword arguments to `To`.
    """

    def __init__(self, *args, i=0, To=nn.Identity, **kwargs):
        super().__init__()
        self.i = i
        kept_args = self.filter_args(*args)
        self.destination = To(*kept_args, **kwargs)

    def filter_args(self, *args):
        """Return the list of positional arguments without the i-th one."""
        return [arg for position, arg in enumerate(args) if position != self.i]

    def forward(self, *args, **kwargs):
        """Call the wrapped module without the i-th positional argument."""
        kept_args = self.filter_args(*args)
        return self.destination(*kept_args, **kwargs)


def discard_ith_arg(module, i, **kwargs):
    """Return a constructor of `module` that ignores its i-th positional argument.

    Parameters
    ----------
    module : nn.Module
        Module (uninitialized) to wrap in a `DiscardIthArg`.

    i : int
        Position of the argument to discard.

    **kwargs :
        Keyword arguments given to `module` at every construction.
    """

    def build_without_ith(*args, **call_kwargs):
        return DiscardIthArg(*args, i=i, To=module, **kwargs, **call_kwargs)

    return build_without_ith


class MergeFlatInputs(nn.Module):
    """Extend a module taking a single flat input to take 2 flat inputs,
    which are merged by concatenation or by summation.

    Parameters
    ----------
    FlatModule : nn.Module
        Module (uninitialized) called as `FlatModule(dim, n_out, **kwargs)`.

    x1_dim : int
        Size of the last dimension of the first input.

    x2_dim : int
        Size of the last dimension of the second input.

    n_out : int
        Size of the output.

    is_sum_merge : bool, optional
        Whether to sum the inputs (after resizing the second one with an MLP
        and followed by a ReLU) instead of concatenating them.

    **kwargs :
        Additional arguments to `FlatModule`.
    """

    def __init__(self, FlatModule, x1_dim, x2_dim, n_out, is_sum_merge=False, **kwargs):
        super().__init__()
        self.is_sum_merge = is_sum_merge

        if self.is_sum_merge:
            # transform the second input to be the size of the first one
            self.resizer = MLP(x2_dim, x1_dim)
            merged_dim = x1_dim
        else:
            merged_dim = x1_dim + x2_dim

        self.flat_module = FlatModule(merged_dim, n_out, **kwargs)
        self.reset_parameters()

    def reset_parameters(self):
        """Delegate the initialization to `weights_init`."""
        weights_init(self)

    def forward(self, x1, x2):
        """Merge both inputs and give the result to the flat module."""
        if self.is_sum_merge:
            # use activation because if not 2 linear layers in a row
            # => useless computation
            merged = torch.relu(x1 + self.resizer(x2))
        else:
            merged = torch.cat((x1, x2), dim=-1)

        return self.flat_module(merged)


def merge_flat_input(module, is_sum_merge=False, **kwargs):
    """Return a constructor of `MergeFlatInputs` wrapping `module`.

    Parameters
    ----------
    module : nn.Module
        Module (uninitialized) taking a single flat input.

    is_sum_merge : bool, optional
        Whether to merge the two inputs by summation instead of concatenation.

    **kwargs :
        Keyword arguments given to `module` at every construction.

    Return
    ------
    callable
        Called as `(x_shape, flat_dim, n_out, **kwargs2)`, where `x_shape` has
        to be an int.
    """

    def build_merged(x_shape, flat_dim, n_out, **call_kwargs):
        assert isinstance(x_shape, int)
        return MergeFlatInputs(
            module,
            x_shape,
            flat_dim,
            n_out,
            is_sum_merge=is_sum_merge,
            **call_kwargs,
            **kwargs
        )

    return build_merged

class ExpRBF(nn.Module):
    """Exponential radial basis function with a learnable length scale.

    Computes `exp(- (dist / sigma) ** p)`, normalized over the second to last
    dimension with a softmax.

    Parameters
    ----------
    x_dim : int
        Accepted for interface compatibility with the other radial basis
        functions; not used.

    max_dist : float, optional
        Distance used to initialize the length scale: the kernel at `max_dist`
        is initially equal to `max_dist_weight`.

    max_dist_weight : float, optional
        Initial value of the kernel at distance `max_dist`.

    p : int, optional
        Order of the norm and power of the exponent.

    **kwargs :
        Ignored.
    """

    def __init__(self, x_dim, max_dist=1 / 256, max_dist_weight=0.9, p=2, **kwargs):
        super().__init__()

        self.max_dist = max_dist
        self.max_dist_weight = max_dist_weight
        # placeholder, replaced in `reset_parameters`
        self.length_scale_param = nn.Parameter(torch.tensor([0.0]))
        self.p = p
        self.reset_parameters()

    def reset_parameters(self):
        """Set the length scale such that kernel(max_dist) = max_dist_weight."""
        weights_init(self)
        # exp(- (max_dist / sigma).pow(p)) = max_dist_weight
        # => sigma = max_dist / ((- log(max_dist_weight))**(1/p))
        neg_log_weight = -math.log(self.max_dist_weight)
        sigma_init = self.max_dist / (neg_log_weight ** (1 / self.p))
        # the parameter goes through a softplus in `forward`
        # => store inverse_softplus(sigma) = log(exp(sigma) - 1)
        param_init = math.log(math.exp(sigma_init) - 1)
        self.length_scale_param = nn.Parameter(torch.tensor([param_init]))

    def forward(self, diff):
        """Return the normalized weights and the (unnormalized) density.

        NOTE(review): `SetConv` calls this with
        `diff` of size [batch_size, n_queries, n_keys, kq_size] - confirm.
        """
        # norm over the last dimension, which is kept with size 1
        dist = torch.norm(diff, p=self.p, dim=-1, keepdim=True)

        # compute exponent making sure no division by 0
        sigma = 1e-5 + F.softplus(self.length_scale_param)
        exponent = -(dist / sigma).pow(self.p)

        # numerically stable normalization of the weights by density
        weights = torch.softmax(exponent, dim=-2)

        # sum of the unnormalized kernels over the second to last dimension
        density = torch.exp(exponent).sum(dim=-2)

        return weights, density


class MlpRBF(nn.Module):
    """Radial basis function parametrized by an MLP, restricted to a window.

    Parameters
    ----------
    x_dim : int
        Input size of the MLP.

    is_abs_dist : bool, optional
        Whether the MLP is given the absolute value of the differences.

    window_size : float, optional
        Only differences with an absolute value smaller than `window_size` are
        given a non zero weight.

    **kwargs :
        Ignored.
    """

    def __init__(self, x_dim, is_abs_dist=True, window_size=0.25, **kwargs):
        super().__init__()
        self.is_abs_dist = is_abs_dist
        self.window_size = window_size
        self.mlp = MLP(x_dim, 1, n_hidden_layers=3, hidden_size=16)
        self.reset_parameters()

    def reset_parameters(self):
        """Delegate the initialization to `weights_init`."""
        weights_init(self)

    def forward(self, diff):
        """Return the normalized weights and the density of `diff`."""
        abs_diff = diff.abs()

        # select only points with distance less than window_size
        # (for extrapolation + speed)
        in_window = abs_diff < self.window_size

        if self.is_abs_dist:
            diff = abs_diff

        def positive_mlp(selected):
            # selected elements are given as a batch of size 1 inputs; the
            # absolute value makes the weights non negative
            return self.mlp(selected.unsqueeze(1)).abs().squeeze()

        # sparse operation (apply only on mask) => 2-3x speedup
        weight = mask_and_apply(diff, in_window, positive_mlp)
        # set to 0 points that are further than window
        weight = weight * in_window.float()

        density = weight.sum(dim=-2, keepdim=True)
        normalized = weight / (density + 1e-5)  # don't divide by 0

        return normalized, density.squeeze(-1)


class SetConv(nn.Module):
    """Set convolution: weighted sum of the values, where the weights are
    given by a radial basis function of the key - query differences.

    Parameters
    ----------
    x_dim : int
        Number of spatial dimensions. Currently has to be 1.

    in_channels : int
        Number of input channels (size of the values).

    out_channels : int
        Number of output channels.

    RadialBasisFunc : nn.Module, optional
        Radial basis function (uninitialized), built with `x_dim` and `kwargs`.

    **kwargs :
        Additional arguments to `RadialBasisFunc`.
    """

    def __init__(
        self, x_dim, in_channels, out_channels, RadialBasisFunc=ExpRBF, **kwargs
    ):
        super().__init__()
        assert x_dim == 1, "Currently only supports single spatial dimension `x_dim==1`"
        self.radial_basis_func = RadialBasisFunc(x_dim, **kwargs)
        # + 1 because the density is appended to the channels
        self.resizer = nn.Linear(in_channels + 1, out_channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Delegate the initialization to `weights_init`."""
        weights_init(self)

    def forward(self, keys, queries, values):
        """
        Compute the set convolution between {key, value} and {query}.

        TODO
        ----
        - should sort the keys and queries to not compute differences if outside
        of given receptive field (large memory savings).

        Parameters
        ----------
        keys : torch.Tensor, size=[batch_size, n_keys, kq_size]
        queries : torch.Tensor, size=[batch_size, n_queries, kq_size]
        values : torch.Tensor, size=[batch_size, n_keys, in_channels]

        Return
        ------
        targets : torch.Tensor, size=[batch_size, n_queries, out_channels]
        """
        # broadcasted differences between every query and every key
        # size = [batch_size, n_queries, n_keys, kq_size]
        diff = keys.unsqueeze(1) - queries.unsqueeze(2)

        # weight size = [batch_size, n_queries, n_keys, 1]
        # density size = [batch_size, n_queries, 1]
        weight, density = self.radial_basis_func(diff)

        # weighted sum over the keys, size = [batch_size, n_queries, value_size]
        weighted_values = (weight * values.unsqueeze(1)).sum(dim=2)

        # size = [batch_size, n_queries, value_size+1]
        with_density = torch.cat([weighted_values, density], dim=-1)

        return self.resizer(with_density)
    
def weights_init(module, **kwargs):
    """Initialize a module and all its descendents.

    Convolutions are initialized with `nn.init.kaiming_normal_` (fan out),
    linear layers with `linear_init`, and 2d batch normalizations are reset to
    the identity (weight 1, bias 0). Other modules are left untouched.

    Descendents that define `reset_parameters` and were already flagged with
    `is_resetted` (i.e. that already went through `weights_init`) are skipped,
    because they might have a special initialization. The check is done on
    each descendent rather than on `module` itself: `module` was flagged just
    above, so testing it would skip every sub-module and nothing would ever be
    initialized.

    Parameters
    ----------
    module : nn.Module
       module to initialize.

    **kwargs :
        Additional arguments to `nn.init.kaiming_normal_` and `linear_init`.
    """
    module.is_resetted = True
    for m in module.modules():
        is_already_resetted = (
            m is not module
            and hasattr(m, "reset_parameters")
            and getattr(m, "is_resetted", False)
        )
        if is_already_resetted:
            # don't reset if resetted already (might want special)
            continue

        if isinstance(m, torch.nn.modules.conv._ConvNd):
            # used in https://github.com/brain-research/realistic-ssl-evaluation/
            nn.init.kaiming_normal_(m.weight, mode="fan_out", **kwargs)
        elif isinstance(m, nn.Linear):
            linear_init(m, **kwargs)
        elif isinstance(m, nn.BatchNorm2d):
            # `affine=False` batch norms have no weight / bias to reset
            if m.weight is not None:
                m.weight.data.fill_(1)
            if m.bias is not None:
                m.bias.data.zero_()

def sum_from_nth_dim(t, dim):
    """Sum all dims from `dim`. E.g. sum_from_nth_dim(torch.rand(2,3,4,5), 2).shape = [2,3]"""
    kept_shape = t.shape[:dim]
    # collapse all the dimensions to sum into a single trailing one
    collapsed = t.view(*kept_shape, -1)
    return collapsed.sum(-1)


def logcumsumexp(x, dim):
    """Numerically stable log cumsum exp along `dim`.

    Uses the native `torch.logcumsumexp` when the installed PyTorch provides
    it, and otherwise falls back to a slow workaround that calls
    `torch.logsumexp` on every prefix.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.

    dim : int
        Dimension along which to accumulate.

    Return
    ------
    out : torch.Tensor
        Tensor of the same shape as `x`.
    """
    native = getattr(torch, "logcumsumexp", None)
    if native is not None:
        return native(x, dim)

    # fallback for old versions of PyTorch
    is_last_dim = dim in (-1, x.ndimension() - 1)

    if not is_last_dim:
        x = x.transpose(dim, -1)

    prefixes = [
        torch.logsumexp(x[..., :i], dim=-1, keepdim=True)
        for i in range(1, x.size(-1) + 1)
    ]
    out = torch.cat(prefixes, dim=-1)

    if not is_last_dim:
        out = out.transpose(-1, dim)
    return out


class LightTailPareto(rv_discrete):
    """Discrete distribution whose survival function decreases as 1/k up to
    `alpha`, and geometrically (factor 0.9) afterwards.

    The lower bound of the support (`a` of `rv_discrete`) is used as the
    minimum number of samples.
    """

    def _cdf(self, k, alpha):
        """Cumulative distribution function P(K <= k).

        `alpha` is a factor like in the SUMO paper (original author's note).
        """
        # minimum number of samples = lower bound of support
        min_samples = self.a

        # in the paper they use P(K >= k) but cdf is
        # P(K <= k) = 1 - P(K > k) = 1 - P(K >= k + 1)
        # Shifting by the minimum makes sure has at least `min_samples`
        # samples; clipping makes sure no division by 0
        shifted_k = np.clip((k + 1) - min_samples, a_min=1, a_max=None)
        shifted_alpha = alpha - min_samples

        # tail 1/k but with finite expectation thanks to the geometric decay
        heavy_tail = 1 / shifted_k
        light_tail = (1 / shifted_alpha) * (0.9) ** (shifted_k - shifted_alpha)
        survival = np.where(shifted_k < shifted_alpha, heavy_tail, light_tail)

        return 1 - survival


def isin_range(x, valid_range):
    """Check if all elements of an array / tensor are in a given (inclusive) range."""
    lower, upper = valid_range[0], valid_range[1]
    is_above_lower = x >= lower
    is_below_upper = x <= upper
    return (is_above_lower & is_below_upper).all()


def channels_to_2nd_dim(X):
    """
    Takes a signal with channels on the last dimension (for most operations) and
    returns it with channels on the second dimension (for convolutions).
    """
    last = X.dim() - 1
    # batch first, then channels, then all the remaining dimensions in order
    order = [0, last, *range(1, last)]
    return X.permute(*order)


def channels_to_last_dim(X):
    """
    Takes a signal with channels on the second dimension (for convolutions) and
    returns it with channels on the last dimension (for most operations).
    """
    # batch first, then all the non channel dimensions in order, then channels
    order = [0, *range(2, X.dim()), 1]
    return X.permute(*order)


def mask_and_apply(x, mask, f):
    """Applies a callable on a masked version of a input.

    The elements of `x` selected by `mask` are given (as a 1d tensor) to `f`,
    and the result is written back at the selected positions of a copy of `x`.
    The other elements are unchanged.
    """
    selected = x.masked_select(mask)
    transformed = f(selected)
    return x.masked_scatter(mask, transformed)


def indep_shuffle_(a, axis=-1):
    """
    Shuffle `a` in-place along the given axis.

    Apply `numpy.random.shuffle` to the given axis of `a`.
    Each one-dimensional slice is shuffled independently.

    Credits : https://github.com/numpy/numpy/issues/5173
    """
    # `swapped` is a view of `a` with the axis to shuffle in last position
    # => shuffling its slices in place shuffles `a` in place, too
    swapped = a.swapaxes(axis, -1)
    leading_shape = swapped.shape[:-1]
    for index in np.ndindex(leading_shape):
        np.random.shuffle(swapped[index])


def ratio_to_int(percentage, max_val):
    """Converts a ratio to an integer if it is smaller than 1.

    Values in [1, max_val] are already counts and are only truncated to an
    int. Values in [0, 1) are ratios of `max_val`. Anything else raises a
    ValueError.
    """
    if 1 <= percentage <= max_val:
        return int(percentage)

    if 0 <= percentage < 1:
        return int(percentage * max_val)

    raise ValueError("percentage={} outside of [0,{}].".format(percentage, max_val))


def prod(iterable):
    """Compute the product of all elements in an iterable (1 if it is empty)."""
    result = 1
    for element in iterable:
        result = result * element
    return result


def rescale_range(X, old_range, new_range):
    """Rescale X linearly to be in `new_range` rather than `old_range`.

    Both ranges are given as `(minimum, maximum)`.
    """
    old_min, old_max = old_range[0], old_range[1]
    new_min, new_max = new_range[0], new_range[1]

    # position in the old range, scaled by the ratio of the widths
    scaled = ((X - old_min) * (new_max - new_min)) / (old_max - old_min)
    return scaled + new_min


def MultivariateNormalDiag(loc, scale_diag):
    """Multi variate Gaussian with a diagonal covariance function (on the last dimension).

    Built as independent normal distributions whose last batch dimension is
    reinterpreted as an event dimension.

    Raises
    ------
    ValueError
        If `loc` is a scalar (0 dimensional) tensor.
    """
    is_scalar = loc.dim() < 1
    if is_scalar:
        raise ValueError("loc must be at least one-dimensional.")

    univariate = Normal(loc, scale_diag)
    return Independent(univariate, reinterpreted_batch_ndims=1)


def clamp(
    x,
    minimum=-float("Inf"),
    maximum=float("Inf"),
    is_leaky=False,
    negative_slope=0.01,
    hard_min=None,
    hard_max=None,
):
    """Clamp a tensor to a given range, optionally in a leaky way.

    Parameters
    ----------
    x : torch.Tensor
        Tensor on which to apply the clamping.

    minimum : float, optional
        Minimum value of the (possibly leaky) clamping.

    maximum : float, optional
        Maximum value of the (possibly leaky) clamping.

    is_leaky : bool, optional
        Whether to use a leaky clamping: outside of `[minimum, maximum]` the
        output keeps a slope of `negative_slope` instead of being constant.
        Infinite bounds are never leaky, as there is nothing outside of them.

    negative_slope : float, optional
        Slope outside of the range when `is_leaky`.

    hard_min : float, optional
        Hard minimum, applied on top of the (leaky) clamping. No hard minimum
        if `None`.

    hard_max : float, optional
        Hard maximum, applied on top of the (leaky) clamping. No hard maximum
        if `None`.

    Return
    ------
    clamped : torch.Tensor
        Clamped tensor, same shape as `x`.
    """
    infinity = float("Inf")

    def _bound(limit):
        # a leaky bound on an infinite limit would compute `inf - inf = nan`
        is_infinite = isinstance(limit, (int, float)) and abs(limit) == infinity
        if is_leaky and not is_infinite:
            return limit + negative_slope * (x - limit)
        return torch.zeros_like(x) + limit

    clamped = torch.max(_bound(minimum), torch.min(x, _bound(maximum)))

    if hard_min is not None or hard_max is not None:
        if hard_min is None:
            hard_min = -infinity
        if hard_max is None:
            hard_max = infinity
        # the hard bounds are applied to the already clamped values, such that
        # the leaky clamping is kept inside of the hard range
        clamped = clamp(clamped, minimum=hard_min, maximum=hard_max, is_leaky=False)

    return clamped


class ProbabilityConverter(nn.Module):
    """Map floats to probabilities (between 0 and 1), element-wise.

    Parameters
    ----------
    min_p : float, optional
        Minimum probability, can be useful to set greater than 0 in order to
        keep gradient flowing if the probability is used for convex
        combinations of different parts of the model. The outputs are in
        `[min_p, 1 - min_p]`.

    activation : {"sigmoid", "hard-sigmoid", "leaky-hard-sigmoid"}, optional
        Name of the activation to use to generate the probabilities.
        `sigmoid` has the advantage of being smooth and never exactly 0 or 1,
        which helps gradient flows. `hard-sigmoid` has the advantage of making
        all values between min_p and max_p equiprobable.

    is_train_temperature : bool, optional
        Whether to train the temperature.

    is_train_bias : bool, optional
        Whether to train the bias.

    trainable_dim : int, optional
        Size of the temperature and of the bias.

    initial_temperature : float, optional
        Initial temperature, a higher temperature makes the activation steeper.

    initial_probability : float, optional
        Probability that the input `initial_x` is mapped to at initialization.
        Has to be strictly inside of `(min_p, 1 - min_p)`.

    initial_x : float, optional
        Input that is initially mapped to `initial_probability`.

    bias_transformer : callable, optional
        Transformation applied to the bias before it is used.

    temperature_transformer : callable, optional
        Transformation applied to the temperature before it is used.

    Raises
    ------
    ValueError
        If `activation` is unknown.
    """

    def __init__(
        self,
        min_p=0.0,
        activation="sigmoid",
        is_train_temperature=False,
        is_train_bias=False,
        trainable_dim=1,
        initial_temperature=1.0,
        initial_probability=0.5,
        initial_x=0,
        bias_transformer=nn.Identity(),
        temperature_transformer=nn.Identity(),
    ):

        super().__init__()
        self.min_p = min_p
        self.activation = activation
        self.is_train_temperature = is_train_temperature
        self.is_train_bias = is_train_bias
        self.trainable_dim = trainable_dim
        self.initial_temperature = initial_temperature
        self.initial_probability = initial_probability
        self.initial_x = initial_x
        self.bias_transformer = bias_transformer
        self.temperature_transformer = temperature_transformer

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)create the temperature and the bias from their initial values."""
        # the values are built locally and assigned a single time: assigning a
        # plain tensor to an attribute that is already registered as a
        # parameter raises, which made it impossible to reset twice
        temperature = torch.tensor([self.initial_temperature] * self.trainable_dim)
        if self.is_train_temperature:
            temperature = nn.Parameter(temperature)
        self.temperature = temperature

        initial_bias = self._probability_to_bias(
            self.initial_probability, initial_x=self.initial_x
        )

        bias = torch.tensor([initial_bias] * self.trainable_dim)
        if self.is_train_bias:
            bias = nn.Parameter(bias)
        self.bias = bias

    def forward(self, x):
        """Return the probabilities associated with `x`."""
        # temperature and bias are plain tensors when not trained, so they do
        # not follow the module when it is moved => move them by hand.
        # `Tensor.to` is not in place, its result has to be used.
        temperature = self.temperature_transformer(self.temperature.to(x.device))
        bias = self.bias_transformer(self.bias.to(x.device))

        if self.activation == "sigmoid":
            full_p = torch.sigmoid((x + bias) * temperature)

        elif self.activation in ["hard-sigmoid", "leaky-hard-sigmoid"]:
            # uses 0.2 and 0.5 to be similar to sigmoid
            y = 0.2 * ((x + bias) * temperature) + 0.5

            if self.activation == "leaky-hard-sigmoid":
                # leaky outside of [0.1, 0.9] but always a valid probability
                full_p = clamp(
                    y,
                    minimum=0.1,
                    maximum=0.9,
                    is_leaky=True,
                    negative_slope=0.01,
                    hard_min=0,
                    hard_max=1,
                )
            elif self.activation == "hard-sigmoid":
                full_p = clamp(y, minimum=0.0, maximum=1.0, is_leaky=False)

        else:
            raise ValueError("Unkown activation : {}".format(self.activation))

        p = rescale_range(full_p, (0, 1), (self.min_p, 1 - self.min_p))

        return p

    def _probability_to_bias(self, p, initial_x=0):
        """Compute the bias to use to satisfy the constraints."""
        assert p > self.min_p and p < 1 - self.min_p
        # invert the rescaling done at the end of `forward`
        range_p = 1 - self.min_p * 2
        p = (p - self.min_p) / range_p
        p = torch.tensor(p, dtype=torch.float)

        if self.activation == "sigmoid":
            bias = -(torch.log((1 - p) / p) / self.initial_temperature + initial_x)

        elif self.activation in ["hard-sigmoid", "leaky-hard-sigmoid"]:
            bias = ((p - 0.5) / 0.2) / self.initial_temperature - initial_x

        else:
            raise ValueError("Unkown activation : {}".format(self.activation))

        return bias


def dist_to_device(dist, device):
    """Set a distribution to a given device, in place.

    Both the location and the scale of the base distribution are moved.
    Nothing is done if `dist` is None.

    Parameters
    ----------
    dist : torch.distributions.Distribution or None
        Distribution with a `base_dist` attribute that has a `loc` and a
        `scale`, e.g. the output of `MultivariateNormalDiag`.

    device : torch.device or str
        Device on which to put the parameters of the distribution.
    """
    if dist is None:
        return

    base_dist = dist.base_dist
    base_dist.loc = base_dist.loc.to(device)
    # the scale has to be moved, not replaced by the location
    base_dist.scale = base_dist.scale.to(device)


def make_abs_conv(Conv):
    """Make a convolution have only positive parameters.

    The returned class subclasses `Conv` and convolves with the absolute value
    of its weights. The bias is used as is. The forward pass always calls
    `F.conv2d`.
    """

    class AbsConv(Conv):
        def forward(self, input):
            positive_weight = self.weight.abs()
            return F.conv2d(
                input,
                positive_weight,
                bias=self.bias,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups,
            )

    return AbsConv


def make_padded_conv(Conv, Padder):
    """Make a convolution have any possible padding.

    The returned class subclasses `Conv`. Its `padding` argument is given to
    `Padder`, which is applied on the input before a convolution without
    padding. If `Padder` is None, the padding of `Conv` itself is used.
    """

    class PaddedConv(Conv):
        def __init__(self, *args, Padder=Padder, padding=0, **kwargs):
            if Padder is None:
                # no custom padder => let the convolution do the padding
                Padder = nn.Identity
                conv_padding = padding
            else:
                # the padder does all the padding
                conv_padding = 0

            super().__init__(*args, padding=conv_padding, **kwargs)
            self.padder = Padder(padding)

        def forward(self, X):
            padded = self.padder(X)
            return super().forward(padded)

    return PaddedConv


def make_depth_sep_conv(Conv):
    """Make a convolution module depth separable.

    The returned module applies a depthwise `Conv` (one group per input
    channel) followed by a pointwise `Conv` (kernel of size 1).
    """

    class DepthSepConv(nn.Module):
        """Make a convolution depth separable.

        Parameters
        ----------
        in_channels : int
            Number of input channels.

        out_channels : int
            Number of output channels.

        kernel_size : int

        confidence : bool, optional
            Accepted for interface compatibility; not used.

        bias : bool, optional
            Whether both convolutions use a bias.

        **kwargs :
            Additional arguments to the depthwise `Conv`
        """

        def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            confidence=False,
            bias=True,
            **kwargs
        ):
            super().__init__()
            # spatial filtering, independently for every channel
            depthwise_kwargs = dict(groups=in_channels, bias=bias, **kwargs)
            self.depthwise = Conv(
                in_channels, in_channels, kernel_size, **depthwise_kwargs
            )
            # mixing of the channels, does not get the additional arguments
            self.pointwise = Conv(in_channels, out_channels, 1, bias=bias)
            self.reset_parameters()

        def forward(self, x):
            return self.pointwise(self.depthwise(x))

        def reset_parameters(self):
            weights_init(self)

    return DepthSepConv


class CircularPad2d(nn.Module):
    """Circular (wrap-around) padding of the last two dimensions.

    Parameters
    ----------
    padding : int
        Amount of padding added on each of the four sides.
    """

    def __init__(self, padding):
        super().__init__()
        self.padding = padding

    def forward(self, x):
        amount = self.padding
        # same amount on both sides of the last two dimensions
        pad_spec = (amount, amount, amount, amount)
        return F.pad(x, pad_spec, mode="circular")


class BackwardPDB(torch.autograd.Function):
    """Identity operation that opens `pdb` in the backward pass on non-finite values."""

    @staticmethod
    def forward(ctx, input, name="debugger"):
        # keep the input around so that it can be inspected during backward
        ctx.save_for_backward(input)
        ctx.name = name
        return input

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors

        all_finite = (
            torch.isfinite(grad_output).all() and torch.isfinite(saved_input).all()
        )
        if not all_finite:
            import pdb

            pdb.set_trace()

        # one entry per argument of `forward`: nothing to propagate to `name`
        return grad_output, None


# functional form: `backward_pdb(tensor, name)`
backward_pdb = BackwardPDB.apply

class NeuralProcessFamily(nn.Module, abc.ABC):
    """Abstract base class shared by the neural process models of this file.

    Subclasses implement `encode_globally` and `trgt_dependent_representation`
    (and `latent_path` when a latent path is used). This class chains them in
    `predict` and turns the result into a predictive distribution in `decode`.

    Parameters
    ----------
    x_dim : int
        Dimension of the features.

    y_dim : int
        Dimension of the values to predict.

    encoded_path : {"deterministic", "latent", "both"}
        Path(s) used for encoding. Lower-cased and checked against `_valid_paths`.

    r_dim : int, optional
        Dimension of the representations.

    x_transf_dim : int, optional
        Dimension of the encoded features. `None` uses `x_dim`, `-1` uses `r_dim`.

    is_heteroskedastic : bool, optional
        If `False`, the predicted scale is mean pooled over all targets in `decode`.

    XEncoder : callable, optional
        Called as `XEncoder(x_dim, x_transf_dim)`. Defaults to
        `dflt_Modules["XEncoder"]`.

    Decoder : callable, optional
        Called as `Decoder(x_transf_dim, r_dim, y_dim * 2)`. Defaults to
        `dflt_Modules["Decoder"]`.

    PredictiveDistribution : callable, optional
        Called as `PredictiveDistribution(loc, scale)` to build the prediction.

    p_y_loc_transformer : callable, optional
        Applied to the raw location produced by the decoder.

    p_y_scale_transformer : callable, optional
        Applied to the raw scale produced by the decoder.
    """

    _valid_paths = ["deterministic", "latent", "both"]

    def __init__(
        self,
        x_dim,
        y_dim,
        encoded_path,
        r_dim=128,
        x_transf_dim=-1,
        is_heteroskedastic=True,
        XEncoder=None,
        Decoder=None,
        PredictiveDistribution=Normal,
        p_y_loc_transformer=nn.Identity(),
        p_y_scale_transformer=lambda y_scale: 0.1 + 0.9 * F.softplus(y_scale),
    ):
        super().__init__()

        self.x_dim = x_dim
        self.y_dim = y_dim
        self.r_dim = r_dim
        self.is_heteroskedastic = is_heteroskedastic

        if x_transf_dim is None:
            self.x_transf_dim = self.x_dim
        elif x_transf_dim == -1:
            self.x_transf_dim = self.r_dim
        else:
            self.x_transf_dim = x_transf_dim

        self.encoded_path = encoded_path.lower()
        if self.encoded_path not in self._valid_paths:
            raise ValueError(f"Unknown encoded_path={self.encoded_path}.")

        if XEncoder is None:
            XEncoder = self.dflt_Modules["XEncoder"]

        if Decoder is None:
            Decoder = self.dflt_Modules["Decoder"]

        self.x_encoder = XEncoder(self.x_dim, self.x_transf_dim)

        # times 2 out because loc and scale (mean and var for gaussian)
        self.decoder = Decoder(self.x_transf_dim, self.r_dim, self.y_dim * 2)

        self.PredictiveDistribution = PredictiveDistribution
        self.p_y_loc_transformer = p_y_loc_transformer
        self.p_y_scale_transformer = p_y_scale_transformer

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialize the weights of all submodules."""
        weights_init(self)

    @property
    def dflt_Modules(self):
        """Dictionary of the module constructors used when none are given."""
        dflt_Modules = dict()

        dflt_Modules["XEncoder"] = partial(
            MLP, n_hidden_layers=1, hidden_size=self.r_dim
        )

        dflt_Modules["SubDecoder"] = partial(
            MLP,
            n_hidden_layers=4,
            hidden_size=self.r_dim,
        )

        dflt_Modules["Decoder"] = merge_flat_input(
            dflt_Modules["SubDecoder"], is_sum_merge=True
        )

        return dflt_Modules

    @staticmethod
    def _masked_ll(mask, ll):
        """Weight the log-likelihood by `mask` and sum over the last dimension.

        `mask` has its dim 1 moved last and gets a new leading dimension so that
        it broadcasts against `ll`.
        """
        mask = torch.permute(mask, (0, 2, 3, 1)).unsqueeze(0)
        return (mask * ll).sum(-1)

    def forward(self, batch, num_samples=None, reduce_ll=True):
        """Compute the training loss or the evaluation log-likelihoods of a batch.

        Parameters
        ----------
        batch : AttrDict
            Needs the fields `xc`, `yc`, `xt` (given to `predict`) and `y`. When
            the module is not in training mode `xc` and `xt` are also used to
            weight the log-likelihood. `y`, `xc` and `xt` are permuted with
            `(0, 2, 3, 1)`, so they have to be 4 dimensional.
            NOTE(review): presumably [batch_size, channels, height, width] -
            confirm against the data loader.

        num_samples : int, optional
            Forwarded to `predict`, which does not use it.

        reduce_ll : bool, optional
            Only used outside of training. If `True` the context and target
            log-likelihoods are averaged over their non zero entries, otherwise
            the unreduced tensors are returned.

        Return
        ------
        outs : AttrDict
            Contains `loss` in training mode, `ctx_ll` and `tar_ll` otherwise.
        """
        outs = AttrDict()
        py = self.predict(batch.xc, batch.yc, batch.xt, num_samples=num_samples)

        # work on a permuted local tensor: `batch.y` is left untouched so that
        # the same batch can safely be given to `forward` more than once
        y = torch.permute(batch.y, (0, 2, 3, 1))
        ll = py.log_prob(y)

        if self.training:
            # negative log-likelihood averaged over every entry (no masking)
            outs.loss = -ll.mean()
        else:
            outs.ctx_ll = self._masked_ll(batch.xc, ll)
            if reduce_ll:
                outs.ctx_ll = outs.ctx_ll[outs.ctx_ll.nonzero(as_tuple=True)].mean()

            outs.tar_ll = self._masked_ll(batch.xt, ll)
            if reduce_ll:
                outs.tar_ll = outs.tar_ll[outs.tar_ll.nonzero(as_tuple=True)].mean()

        return outs

    def predict(self, X_cntxt, Y_cntxt, X_trgt, Y_trgt=None, num_samples=None):
        """Return the predictive distribution of the targets given the context.

        Parameters
        ----------
        X_cntxt, Y_cntxt : torch.Tensor
            Context features and values.

        X_trgt : torch.Tensor
            Target features.

        Y_trgt : torch.Tensor, optional
            Target values, only forwarded to `latent_path`.

        num_samples : int, optional
            Not used.

        Return
        ------
        p_yCc : distribution built by `PredictiveDistribution`
        """
        self._validate_inputs(X_cntxt, Y_cntxt, X_trgt, Y_trgt)

        # size = [batch_size, *n_cntxt, x_transf_dim]
        X_cntxt = self.x_encoder(X_cntxt)
        # size = [batch_size, *n_trgt, x_transf_dim]
        X_trgt = self.x_encoder(X_trgt)

        # {R^u}_u
        # size = [batch_size, *n_rep, r_dim]
        R = self.encode_globally(X_cntxt, Y_cntxt)

        if self.encoded_path in ["latent", "both"]:
            z_samples, q_zCc, q_zCct = self.latent_path(X_cntxt, R, X_trgt, Y_trgt)
        else:
            z_samples, q_zCc, q_zCct = None, None, None

        if self.encoded_path == "latent":
            # if only latent path then cannot depend on deterministic representation
            R = None

        # size = [n_z_samples, batch_size, *n_trgt, r_dim]
        R_trgt = self.trgt_dependent_representation(X_cntxt, z_samples, R, X_trgt)

        # p(y|cntxt,trgt)
        # batch shape=[n_z_samples, batch_size, *n_trgt] ; event shape=[y_dim]
        p_yCc = self.decode(X_trgt, R_trgt)

        # only the predictive distribution is returned (not the latent quantities)
        return p_yCc

    def _validate_inputs(self, X_cntxt, Y_cntxt, X_trgt, Y_trgt):
        """Validates the inputs by checking if features are rescaled to [-1,1] during training."""
        if self.training:
            # the check itself uses the wider interval [-2, 2]
            if not (isin_range(X_cntxt, [-2, 2]) and isin_range(X_trgt, [-2, 2])):
                raise ValueError(
                    f"Features during training should be in [-1,1]. {X_cntxt.min()} <= X_cntxt <= {X_cntxt.max()} ; {X_trgt.min()} <= X_trgt <= {X_trgt.max()}."
                )

    @abc.abstractmethod
    def encode_globally(self, X_cntxt, R_cntxt):
        """Encode the context into the representation `R` used by `predict`.

        `predict` calls it with the encoded context features and the context values.
        """
        pass

    @abc.abstractmethod
    def trgt_dependent_representation(self, X_cntxt, z_samples, R, X_trgt):
        """Return the representation given to `decode` for every target."""
        pass

    def latent_path(self, X_cntxt, R, X_trgt, Y_trgt):
        """Latent path, to be overridden when `encoded_path` is "latent" or "both"."""
        raise NotImplementedError(
            f"`latent_path` not implemented. Cannot use encoded_path={self.encoded_path} in such case."
        )

    def decode(self, X_trgt, R_trgt):
        """Compute the predictive distribution from the target representations.

        Parameters
        ----------
        X_trgt : torch.Tensor
            Encoded target features, first argument of the decoder.

        R_trgt : torch.Tensor
            Target representations, second argument of the decoder.

        Return
        ------
        p_yCc : distribution built by `PredictiveDistribution`
        """
        # size = [n_z_samples, batch_size, *n_trgt, y_dim*2]
        p_y_suffstat = self.decoder(X_trgt, R_trgt)

        # size = [n_z_samples, batch_size, *n_trgt, y_dim]
        p_y_loc, p_y_scale = p_y_suffstat.split(self.y_dim, dim=-1)

        p_y_loc = self.p_y_loc_transformer(p_y_loc)
        p_y_scale = self.p_y_scale_transformer(p_y_scale)

        #! should probably pool before p_y_scale_transformer
        if not self.is_heteroskedastic:
            # to make sure not heteroskedastic you pool all the p_y_scale
            # only exact when X_trgt is a constant (e.g. grid case). If not it's a decent approx
            n_z_samples, batch_size, *n_trgt, y_dim = p_y_scale.shape
            p_y_scale = p_y_scale.view(n_z_samples * batch_size, *n_trgt, y_dim)
            p_y_scale = pool_and_replicate_middle(p_y_scale)
            p_y_scale = p_y_scale.view(n_z_samples, batch_size, *n_trgt, y_dim)

        # batch shape=[n_z_samples, batch_size, *n_trgt] ; event shape=[y_dim]
        p_yCc = self.PredictiveDistribution(p_y_loc, p_y_scale)

        return p_yCc

    def set_extrapolation(self, min_max):
        """Set the neural process for extrapolation."""
        pass


class LatentNeuralProcessFamily(NeuralProcessFamily):
    """Base class of the neural processes that use a latent variable `z`.

    Parameters
    ----------
    *args :
        Positional arguments to `NeuralProcessFamily`.

    is_q_zCct : bool, optional
        If `True` and the target values are given to `latent_path`, the latent
        samples are drawn from the distribution inferred from the targets
        instead of the one inferred from the context.

    num_samples : int or scipy random variable, optional
        Number of latent samples, used both in training and evaluation mode. An
        object with a `rvs` method is sampled once per call to `forward`.

    LatentEncoder : callable, optional
        Called as `LatentEncoder(r_dim, z_dim * 2)`. Defaults to
        `dflt_Modules["LatentEncoder"]`.

    LatentDistribution : callable, optional
        Called as `LatentDistribution(loc, scale)` to build the latent distribution.

    q_z_loc_transformer : callable, optional
        Applied to the raw location produced by the latent encoder.

    q_z_scale_transformer : callable, optional
        Applied to the raw scale produced by the latent encoder.

    z_dim : int, optional
        Dimension of the latent variable. `None` uses `r_dim`.

    **kwargs :
        Keyword arguments to `NeuralProcessFamily`.
    """

    _valid_paths = ["latent", "both"]

    def __init__(
        self,
        *args,
        is_q_zCct=False,
        num_samples=1,
        LatentEncoder=None,
        LatentDistribution=Normal,
        q_z_loc_transformer=nn.Identity(),
        q_z_scale_transformer=lambda z_scale: 0.1 + 0.9 * torch.sigmoid(z_scale),
        z_dim=None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.is_q_zCct = is_q_zCct
        # the same number of samples is used in both modes
        self.n_z_samples_train = self.n_z_samples_test = num_samples
        self.z_dim = z_dim if z_dim is not None else self.r_dim

        make_latent_encoder = LatentEncoder
        if make_latent_encoder is None:
            make_latent_encoder = self.dflt_Modules["LatentEncoder"]

        # output is twice z_dim: one half for the loc, the other for the scale
        self.latent_encoder = make_latent_encoder(self.r_dim, self.z_dim * 2)

        if self.encoded_path == "both":
            self.r_z_merger = nn.Linear(self.r_dim + self.z_dim, self.r_dim)

        self.LatentDistribution = LatentDistribution
        self.q_z_loc_transformer = q_z_loc_transformer
        self.q_z_scale_transformer = q_z_scale_transformer

        if self.encoded_path == "latent" and self.z_dim != self.r_dim:
            # projects the z samples to r_dim so that the decoder can consume them
            self.reshaper_z = nn.Linear(self.z_dim, self.r_dim)

        self.reset_parameters()

    @property
    def dflt_Modules(self):
        """Defaults of the parent class extended with the latent encoder."""
        # explicit call to the parent property so that subclasses can extend it
        modules = NeuralProcessFamily.dflt_Modules.__get__(self)
        modules["LatentEncoder"] = partial(
            MLP,
            n_hidden_layers=1,
            hidden_size=self.r_dim,
        )
        return modules

    def forward(self, *args, **kwargs):
        """Fix the number of latent samples for this pass, then run the parent forward."""
        # decided once per call (hence not a property) so that a random number
        # of samples stays constant during the whole pass
        n_samples = self.n_z_samples_train if self.training else self.n_z_samples_test
        try:
            # scipy random variable, i.e., random number of samples
            n_samples = n_samples.rvs()
        except AttributeError:
            # plain number: use it as is
            pass
        self.n_z_samples = n_samples

        return super().forward(*args, **kwargs)

    def _validate_inputs(self, X_cntxt, Y_cntxt, X_trgt, Y_trgt):
        """Same validation as the parent class."""
        super()._validate_inputs(X_cntxt, Y_cntxt, X_trgt, Y_trgt)

    def latent_path(self, X_cntxt, R, X_trgt, Y_trgt):
        """Infer the latent distribution(s) and sample `n_z_samples` latents.

        Return
        ------
        z_samples : torch.Tensor, size=[n_z_samples, batch_size, *n_lat, z_dim]

        q_zCc : distribution
            Latent distribution inferred from the context.

        q_zCct : distribution or None
            Latent distribution inferred from the targets, `None` if not computed.
        """
        # q(z|c)
        # batch shape = [batch_size, *n_lat] ; event shape = [z_dim]
        q_zCc = self.infer_latent_dist(X_cntxt, R)

        q_zCct = None
        if self.is_q_zCct and Y_trgt is not None:
            # when Y_trgt is known the expectation can be taken over
            # q(z|cntxt,trgt) instead of q(z|cntxt). It is computed as q(z|trgt)
            # because the targets contain the context
            R_from_trgt = self.encode_globally(X_trgt, Y_trgt)
            q_zCct = self.infer_latent_dist(X_trgt, R_from_trgt)

        sampling_dist = q_zCc if q_zCct is None else q_zCct

        # size = [n_z_samples, batch_size, *n_lat, z_dim]
        z_samples = sampling_dist.rsample([self.n_z_samples])

        return z_samples, q_zCc, q_zCct

    def infer_latent_dist(self, X, R):
        """Build the latent distribution from the representation `R` (`X` is not used)."""
        # size = [batch_size, *n_lat, z_dim*2]
        q_z_suffstat = self.latent_encoder(self.rep_to_lat_input(R))

        raw_loc, raw_scale = q_z_suffstat.split(self.z_dim, dim=-1)

        # batch shape = [batch_size, *n_lat] ; event shape = [z_dim]
        return self.LatentDistribution(
            self.q_z_loc_transformer(raw_loc),
            self.q_z_scale_transformer(raw_scale),
        )

    def rep_to_lat_input(self, R):
        """Transform the n_rep representations to n_lat inputs."""
        # identity by default, i.e. *n_rep = *n_lat
        return R

    def merge_r_z(self, R, z_samples):
        """
        Combine the deterministic representation with the sampled latents. Assumes
        that n_lat = n_rep.

        Parameters
        ----------
        R : torch.Tensor, size=[batch_size, *, r_dim]
            Global representation values {r^u}_u.

        z_samples : torch.Tensor, size=[n_z_samples, batch_size, *, z_dim]
            Sampled latent values.

        Return
        ------
        out : torch.Tensor, size=[n_z_samples, batch_size, *, r_dim]
        """
        if R.shape != z_samples.shape:
            # replicate R for every latent sample
            leading_dims = z_samples.shape[:-1]
            R = R.unsqueeze(0).expand(*leading_dims, self.r_dim)

        merged = self.r_z_merger(torch.cat((R, z_samples), dim=-1))

        # non linearity so that the merger is not a linear layer followed by a linear layer
        return torch.relu(merged)
    
def linear_init(module, activation="relu"):
    """Initialize a linear layer.

    The bias (if any) is set to zero and the weight is initialized with a
    scheme that depends on the activation applied after the layer.

    Parameters
    ----------
    module : nn.Module
       module to initialize.

    activation : `torch.nn.modules.activation` or str, optional
        Activation that will be used on the `module`. `None` uses a plain
        Xavier uniform initialization. "leaky_relu" / `nn.LeakyReLU`,
        "sigmoid" / `nn.Sigmoid` and "tanh" / `nn.Tanh` get a dedicated
        initialization; anything else is treated as a ReLU.

    Return
    ------
    weight : torch.Tensor
        The initialized weight of `module`.
    """
    weight = module.weight

    if module.bias is not None:
        module.bias.data.zero_()

    if activation is None:
        return nn.init.xavier_uniform_(weight)

    # name of the activation, derived from the argument instead of being fixed
    if isinstance(activation, str):
        activation_name = activation
    elif isinstance(activation, nn.LeakyReLU):
        activation_name = "leaky_relu"
    elif isinstance(activation, nn.Sigmoid):
        activation_name = "sigmoid"
    elif isinstance(activation, nn.Tanh):
        activation_name = "tanh"
    else:
        activation_name = "relu"

    if activation_name == "leaky_relu":
        a = 0 if isinstance(activation, str) else activation.negative_slope
        return nn.init.kaiming_uniform_(weight, a=a, nonlinearity="leaky_relu")

    if activation_name in ["sigmoid", "tanh"]:
        gain = nn.init.calculate_gain(activation_name)
        return nn.init.xavier_uniform_(weight, gain=gain)

    # ReLU, and fallback for every activation that is not recognized
    return nn.init.kaiming_uniform_(weight, nonlinearity="relu")

class MLP(nn.Module):
    """Multi layer perceptron.

    Parameters
    ----------
    input_size : int
        Size of the last dimension of the input.

    output_size : int
        Size of the last dimension of the output.

    hidden_size : int, optional
        Number of hidden units. Raised to `min(output_size, input_size)` when
        smaller than both sizes.

    n_hidden_layers : int, optional
        Number of hidden layers (the first one maps the input to the hidden size).

    activation : callable, optional
        Applied after every hidden layer.

    is_bias : bool, optional
        Whether the linear layers have a bias.

    dropout : float, optional
        Dropout probability. No dropout layer is used when it is not positive.

    is_force_hid_smaller : bool, optional
        If `True`, `hidden_size` is lowered to `max(output_size, input_size)`
        when larger than both sizes.

    is_res : bool, optional
        Whether to add the input of every additional hidden layer to its
        activated output.
    """

    def __init__(
        self,
        input_size,
        output_size,
        hidden_size=32,
        n_hidden_layers=1,
        activation=nn.ReLU(),
        is_bias=True,
        dropout=0,
        is_force_hid_smaller=False,
        is_res=False,
    ):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.n_hidden_layers = n_hidden_layers
        self.is_res = is_res

        largest_io = max(self.output_size, self.input_size)
        smallest_io = min(self.output_size, self.input_size)

        if is_force_hid_smaller and self.hidden_size > largest_io:
            self.hidden_size = largest_io
            warnings.warn(
                "hidden_size={} larger than output={} and input={}. Setting it to {}.".format(
                    hidden_size, output_size, input_size, self.hidden_size
                )
            )
        elif self.hidden_size < smallest_io:
            self.hidden_size = smallest_io
            warnings.warn(
                "hidden_size={} smaller than output={} and input={}. Setting it to {}.".format(
                    hidden_size, output_size, input_size, self.hidden_size
                )
            )

        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = nn.Identity()
        self.activation = activation

        self.to_hidden = nn.Linear(self.input_size, self.hidden_size, bias=is_bias)

        # the first hidden layer is `to_hidden`, hence one layer less here
        extra_layers = []
        for _ in range(self.n_hidden_layers - 1):
            extra_layers.append(
                nn.Linear(self.hidden_size, self.hidden_size, bias=is_bias)
            )
        self.linears = nn.ModuleList(extra_layers)

        self.out = nn.Linear(self.hidden_size, self.output_size, bias=is_bias)

        self.reset_parameters()

    def forward(self, x):
        hidden = self.dropout(self.activation(self.to_hidden(x)))

        for layer in self.linears:
            activated = self.activation(layer(hidden))
            if self.is_res:
                # skip connection around the layer
                activated = activated + hidden
            hidden = self.dropout(activated)

        return self.out(hidden)

    def reset_parameters(self):
        # hidden layers are followed by the activation, the output layer is not
        for hidden_layer in (self.to_hidden, *self.linears):
            linear_init(hidden_layer, activation=self.activation)
        linear_init(self.out)
        
def prod(iterable):
    """Return the product of the elements of `iterable` (1 if it is empty)."""
    result = 1
    for factor in iterable:
        result = result * factor
    return result

class CNP_(NeuralProcessFamily):
    """Conditional neural process: mean-aggregated encoding of the context pairs.

    Parameters
    ----------
    x_dim : int
        Dimension of the features.

    y_dim : int
        Dimension of the values to predict.

    XYEncoder : callable, optional
        Called as `XYEncoder(x_transf_dim, y_dim, r_dim)` to build the module
        encoding every context pair. Defaults to `dflt_Modules["XYEncoder"]`.

    **kwargs :
        Additional arguments to `NeuralProcessFamily`.
    """

    _valid_paths = ["deterministic"]

    def __init__(self, x_dim, y_dim, XYEncoder=None, **kwargs):
        # only a default: subclasses are free to select another path
        kwargs.setdefault("encoded_path", "deterministic")
        super().__init__(x_dim, y_dim, **kwargs)

        make_xy_encoder = XYEncoder
        if make_xy_encoder is None:
            make_xy_encoder = self.dflt_Modules["XYEncoder"]

        self.xy_encoder = make_xy_encoder(self.x_transf_dim, self.y_dim, self.r_dim)

        self.reset_parameters()

    @property
    def dflt_Modules(self):
        """Defaults of the parent class extended with the context pair encoder."""
        # explicit call to the parent property so that subclasses can extend it
        modules = NeuralProcessFamily.dflt_Modules.__get__(self)

        sub_xy_encoder = partial(
            MLP,
            n_hidden_layers=2,
            is_force_hid_smaller=True,
            hidden_size=self.r_dim,
        )
        modules["XYEncoder"] = merge_flat_input(sub_xy_encoder, is_sum_merge=True)

        return modules

    def encode_globally(self, X_cntxt, Y_cntxt):
        """Encode every context pair and average them in a single representation."""
        batch_size, n_cntxt, _ = X_cntxt.shape

        # every context pair is encoded on its own
        # size = [batch_size, n_cntxt, r_dim]
        R_cntxt = self.xy_encoder(X_cntxt, Y_cntxt)

        if n_cntxt == 0:
            # no context: the global representation is arbitrarily set to zero
            return torch.zeros(batch_size, 1, self.r_dim, device=R_cntxt.device)

        # mean aggregation (i.e. n_rep=1)
        # size = [batch_size, 1, r_dim]
        return R_cntxt.mean(dim=1, keepdim=True)

    def trgt_dependent_representation(self, _, __, R, X_trgt):
        """Replicate the global representation for every target."""
        batch_size, n_trgt, x_transf_dim = X_trgt.shape

        # every target is predicted from the same (global) representation
        per_trgt_shape = (batch_size, n_trgt, self.r_dim)

        # n_z_samples=1. size = [1, batch_size, n_trgt, r_dim]
        return R.expand(per_trgt_shape).unsqueeze(0)


class LNP(LatentNeuralProcessFamily, CNP_):
    """Latent neural process: `CNP_` encoder combined with a latent path.

    Parameters
    ----------
    x_dim : int
        Dimension of the features.

    y_dim : int
        Dimension of the values to predict.

    encoded_path : {"latent", "both"}, optional
        Path(s) used for encoding.

    **kwargs :
        Additional arguments to the parent classes.
    """

    def __init__(self, x_dim, y_dim, encoded_path="latent", **kwargs):
        super().__init__(x_dim, y_dim, encoded_path=encoded_path, **kwargs)

    def trgt_dependent_representation(self, _, z_samples, R, X_trgt):
        """Build the representation of every target from the latent samples."""
        batch_size, n_trgt, x_transf_dim = X_trgt.shape
        n_z_samples = z_samples.size(0)

        if self.encoded_path == "both":
            # size = [n_z_samples, batch_size, 1, r_dim]
            R_trgt = self.merge_r_z(R, z_samples)

        elif self.encoded_path == "latent":
            # the samples are projected to r_dim only when the sizes differ
            # size = [n_z_samples, batch_size, 1, r_dim]
            needs_projection = self.z_dim != self.r_dim
            R_trgt = self.reshaper_z(z_samples) if needs_projection else z_samples

        # same representation for every target
        per_trgt_shape = (n_z_samples, batch_size, n_trgt, self.r_dim)
        return R_trgt.expand(per_trgt_shape)

def collapse_z_samples_batch(t):
    """Fold the first two dimensions (n_z_samples, batch_size) of `t` into one."""
    n_samples, n_batch, *trailing = t.shape
    merged_shape = (n_samples * n_batch, *trailing)
    return t.contiguous().view(merged_shape)


def extract_z_samples_batch(t, n_z_samples, batch_size):
    """Split the first dimension of `t` back into (n_z_samples, batch_size).

    Inverse of `collapse_z_samples_batch`.
    """
    merged, *trailing = t.shape
    split_shape = (n_z_samples, batch_size, *trailing)
    return t.view(split_shape)


def replicate_z_samples(t, n_z_samples):
    """Repeat `t` `n_z_samples` times along a new leading dimension (expanded view)."""
    replicated_shape = (n_z_samples, *t.shape)
    return t[None].expand(replicated_shape)


def pool_and_replicate_middle(t):
    """Replace every entry by the mean over all the middle dimensions.

    The mean is taken over all but the first and last dimension and then
    expanded back, so that the output has the same size as the input.
    """
    first, *middle, last = t.shape

    # flatten the middle dimensions to pool them in one go
    # size = [first, 1, last]
    pooled = t.view(first, prod(middle), last).mean(1, keepdim=True)

    # one singleton per pooled dimension, then expand back
    # size = [first, *middle, last]
    singletons = [1] * len(middle)
    return pooled.view(first, *singletons, last).expand(first, *middle, last)

class PowerFunction(nn.Module):
    """Concatenate the powers 0 to `K` of the input along the last dimension.

    Parameters
    ----------
    K : int, optional
        Largest exponent.
    """

    def __init__(self, K=1):
        super().__init__()
        self.K = K

    def forward(self, x):
        powers = [x.pow(exponent) for exponent in range(self.K + 1)]
        return torch.cat(powers, -1)

class ConvCNP(NeuralProcessFamily):
    """Convolutional conditional neural process on a 1D grid of induced points.

    Parameters
    ----------
    x_dim : int
        Dimension of the features.

    y_dim : int
        Dimension of the values to predict.

    density_induced : int, optional
        Number of induced points per unit of the feature range.

    Interpolator : callable, optional
        Called as `Interpolator(x_dim, in_dim, r_dim)` to build the modules that
        map the context to the induced points and the induced points to the targets.

    CNN : callable, optional
        Called as `CNN(r_dim)` to build the module applied on the induced points.

    **kwargs :
        Additional arguments to `NeuralProcessFamily`.
    """

    _valid_paths = ["deterministic"]

    def __init__(
        self,
        x_dim,
        y_dim,
        density_induced=128,
        Interpolator=SetConv,
        CNN=partial(
            CNN,
            ConvBlock=ResConvBlock,
            Conv=nn.Conv2d,
            n_blocks=3,
            Normalization=nn.Identity,
            is_chan_last=True,
            kernel_size=11,
        ),
        **kwargs,
    ):

        # `nn.Identity` stands for "no decoder", anything else deserves a warning
        is_custom_decoder = kwargs.get("Decoder", nn.Identity) != nn.Identity
        if is_custom_decoder:
            logger.warning(
                "`Decoder` was given to `ConvCNP`. To be translation equivariant you should disregard the first argument for example using `discard_ith_arg(Decoder, i=0)`, which is done by default when you DO NOT provide the Decoder."
            )

        # only a default: subclasses are free to select another path
        kwargs.setdefault("encoded_path", "deterministic")
        super().__init__(
            x_dim,
            y_dim,
            x_transf_dim=None,
            XEncoder=nn.Identity,
            **kwargs,
        )

        self.density_induced = density_induced

        # the inputs are between -1 and 1: 0.5 is added on both sides to avoid
        # strong boundary effects
        n_points = int(self.density_induced * 3)
        self.X_induced = torch.linspace(-1.5, 1.5, n_points)
        self.CNN = CNN

        self.cntxt_to_induced = Interpolator(self.x_dim, self.y_dim, self.r_dim)
        self.induced_to_induced = CNN(self.r_dim)
        self.induced_to_trgt = Interpolator(self.x_dim, self.r_dim, self.r_dim)

        self.reset_parameters()

    @property
    def n_induced(self):
        """Current number of induced points."""
        # computed on the fly because `set_extrapolation` replaces the grid
        return len(self.X_induced)

    @property
    def dflt_Modules(self):
        """Defaults of the parent class with a decoder that ignores its first argument."""
        # explicit call to the parent property so that subclasses can extend it
        modules = NeuralProcessFamily.dflt_Modules.__get__(self)

        # the decoder should not depend on x
        modules["Decoder"] = discard_ith_arg(modules["SubDecoder"], i=0)

        return modules

    def _get_X_induced(self, X):
        """Return the induced points replicated for every element of the batch of `X`."""
        batch_size, _, _ = X.shape

        # the stored grid is overwritten, so the transfer effectively happens only once
        self.X_induced = self.X_induced.to(X.device)

        grid = self.X_induced.view(1, -1, 1)
        return grid.expand(batch_size, self.n_induced, self.x_dim)

    def encode_globally(self, X_cntxt, Y_cntxt):
        """Map the context to the induced points and apply the CNN on them."""
        batch_size, n_cntxt, _ = X_cntxt.shape

        # size = [batch_size, n_induced, x_dim]
        X_induced = self._get_X_induced(X_cntxt)

        # size = [batch_size, n_induced, r_dim]
        R_induced = self.cntxt_to_induced(X_cntxt, X_induced, Y_cntxt)

        if n_cntxt == 0:
            # no context: the representation is arbitrarily set to zero, which
            # also zeroes the density channel => makes sense
            zeros_shape = (batch_size, self.n_induced, self.r_dim)
            R_induced = torch.zeros(*zeros_shape, device=R_induced.device)

        # size = [batch_size, n_induced, r_dim]
        return self.induced_to_induced(R_induced)

    def trgt_dependent_representation(self, X_cntxt, z_samples, R_induced, X_trgt):
        """Interpolate the induced representations at the target features."""
        batch_size, n_trgt, _ = X_trgt.shape

        # size = [batch_size, n_induced, x_dim]
        X_induced = self._get_X_induced(X_cntxt)

        # size = [batch_size, n_trgt, r_dim]
        R_trgt = self.induced_to_trgt(X_induced, X_trgt, R_induced)

        # n_z_samples=1. size = [1, batch_size, n_trgt, r_dim]
        return R_trgt.unsqueeze(0)

    def set_extrapolation(self, min_max):
        """
        Rebuild the induced points so that they cover `min_max` (with a margin of
        0.5 on both sides) with the same density as during training (used for
        extrapolation).
        """
        lower = min_max[0] - 0.5
        upper = min_max[1] + 0.5
        n_points = int(self.density_induced * (upper - lower))
        self.X_induced = torch.linspace(lower, upper, n_points)

class CONVCNP2D(NeuralProcessFamily):
    """Convolutional conditional neural process for inputs that are already on a grid.

    The context is given as a mask over the grid, so `encode_globally` receives
    the context mask and the grid values instead of features and values.

    Parameters
    ----------
    x_dim : int
        Number of channels of the masks. Has to be 1 or `y_dim`.

    y_dim : int
        Number of channels of the values.

    Conv : callable, optional
        Called as `Conv(y_dim)` to build the convolution applied to the masked
        values and to the mask.

    CNN : callable, optional
        Called as `CNN(r_dim)` to build the module applied on the grid representation.

    **kwargs :
        Additional arguments to `NeuralProcessFamily`.
    """

    _valid_paths = ["deterministic"]

    def __init__(
        self,
        x_dim,
        y_dim,
        # uses only depth wise + make sure positive to be interpreted as a density
        Conv=lambda y_dim: make_abs_conv(nn.Conv2d)(
            y_dim,
            y_dim,
            groups=y_dim,
            kernel_size=11,
            padding=11 // 2,
            bias=False,
        ),
        CNN=partial(
            CNN,
            ConvBlock=ResConvBlock,
            Conv=nn.Conv2d,
            n_blocks=3,
            Normalization=nn.Identity,
            is_chan_last=True,
            kernel_size=11,
        ),
        **kwargs,
    ):

        assert (
            x_dim == 1 or x_dim == y_dim
        ), "Ensure that featrue masks can be multiplied with Y"

        # `nn.Identity` stands for "no decoder", anything else deserves a warning
        is_custom_decoder = kwargs.get("Decoder", nn.Identity) != nn.Identity
        if is_custom_decoder:
            logger.warning(
                "`Decoder` was given to `ConvCNP`. To be translation equivariant you should disregard the first argument for example using `discard_ith_arg(Decoder, i=0)`, which is done by default when you DO NOT provide the Decoder."
            )

        # only a default: subclasses are free to select another path
        kwargs.setdefault("encoded_path", "deterministic")
        super().__init__(
            x_dim,
            y_dim,
            x_transf_dim=None,
            XEncoder=nn.Identity,
            **kwargs,
        )

        self.CNN = CNN
        self.conv = Conv(y_dim)

        # input is twice y_dim: the signal channels and the confidence channels
        self.resizer = nn.Linear(self.y_dim * 2, self.r_dim)

        self.induced_to_induced = CNN(self.r_dim)

        self.reset_parameters()

    # same defaults as `ConvCNP` (decoder that does not depend on x)
    dflt_Modules = ConvCNP.dflt_Modules

    def cntxt_to_induced(self, mask_cntxt, X):
        """Infer the missing values and compute a density channel."""
        # the inputs are used as is: channels are expected in the second dimension
        # size = [batch_size, y_dim, *grid_shape]
        signal = self.conv(X * mask_cntxt)
        density = self.conv(mask_cntxt.expand_as(X))

        # normalize the signal by the density (clamped to avoid dividing by zero)
        normalized = signal / torch.clamp(density, min=1e-5)

        # size = [batch_size, y_dim * 2, *grid_shape]
        stacked = torch.cat([normalized, density], dim=1)

        # channels last, then resized to r_dim
        # size = [batch_size, *grid_shape, r_dim]
        return self.resizer(channels_to_last_dim(stacked))

    def encode_globally(self, mask_cntxt, X):
        """Compute the grid representation from the context mask and the grid values."""
        # size = [batch_size, *grid_shape, r_dim]
        induced = self.cntxt_to_induced(mask_cntxt, X)
        return self.induced_to_induced(induced)

    def trgt_dependent_representation(self, _, __, R_induced, ___):
        """Use the grid representation directly, adding the dimension of the latent samples."""
        # n_z_samples=1. size = [1, batch_size, n_trgt, r_dim]
        return R_induced.unsqueeze(0)

    def set_extrapolation(self, min_max):
        """Extrapolation is not supported by this model."""
        raise NotImplementedError("GridConvCNP cannot be used for extrapolation.")