import torch as t
import torch.distributions as td
import torch.nn as nn


class KG:
    r"""
    Lightweight holder for the three pieces of a covariance matrix.  You shouldn't
    need this unless you are developing your own kernels.

    arg:
        - **ii:** :math:`P_\text{i}\times P_\text{i}` covariance matrix among the inducing points.  ``shape=(samples, inducing_batch, inducing_batch)``
        - **it:** :math:`P_\text{i}\times P_\text{t}` cross-covariance between the inducing points and the test/train points.  ``shape=(samples, inducing_batch, mbatch)``
        - **tt:** :math:`P_\text{t}` diagonal variances of the test/train points.  Documented as ``shape=(samples, 1, mbatch)`` (unconfirmed).

    The constructor stores its arguments unchanged; no shape checking is performed.
    """
    def __init__(self, ii, it, tt):
        # Keep each block verbatim under an attribute of the same name.
        self.ii, self.it, self.tt = ii, it, tt


class InducingAdd(nn.Module):
    """
    Prepends a set of inducing inputs to every incoming batch.

    The inducing inputs are stored without the input's leading (dimension 0)
    axis; ``forward`` expands them along that axis and concatenates them in
    front of the input along dimension 1.

    args:
        inducing_batch (int): number of inducing points; must equal the first
            dimension of the inducing inputs.

    keyword args:
        inducing_data (Optional[torch.Tensor]): initial inducing inputs.  They are
            detached, copied and cast to ``float32``.
        inducing_shape (Optional[torch.Size]): shape of standard-normal randomly
            initialised inducing inputs.
        fixed (bool): if ``True`` the inducing inputs are registered as a buffer
            (not trained); otherwise they are an ``nn.Parameter``.

    Exactly one of ``inducing_data`` or ``inducing_shape`` must be given.
    """
    def __init__(self, inducing_batch, inducing_data=None, inducing_shape=None, fixed=False):
        super().__init__()
        assert (inducing_data is not None) != (inducing_shape is not None), \
            "specify exactly one of inducing_data or inducing_shape"

        if inducing_data is None:
            i_d = t.randn(*inducing_shape)
        else:
            # Detach so a buffer never keeps the caller's autograd graph alive.
            i_d = inducing_data.detach().clone().to(t.float32)

        # Honour `fixed` for both initialisation routes; previously it was
        # silently ignored when only `inducing_shape` was supplied.
        if fixed:
            self.register_buffer("inducing_data", i_d)
        else:
            self.inducing_data = nn.Parameter(i_d)

        assert inducing_batch == self.inducing_data.shape[0], \
            "inducing_batch must equal the first dimension of the inducing inputs"
        # Inputs carry one more leading dimension than the stored inducing inputs.
        self.rank = 1 + len(self.inducing_data.shape)

    def forward(self, x):
        """
        Return ``x`` with the inducing inputs concatenated in front along dim 1.

        ``x`` must have exactly ``self.rank`` dimensions; the inducing inputs are
        expanded along ``x``'s dimension 0 before concatenation.
        """
        assert self.rank == len(x.shape), \
            f"expected input with {self.rank} dims, got {len(x.shape)}"

        inducing_data = self.inducing_data.expand(x.shape[0], *self.inducing_data.shape)
        return t.cat([inducing_data, x], 1)


class InducingRemove(nn.Module):
    """
    Discards the first ``inducing_batch`` entries along dimension 1 of its input.

    args:
        inducing_batch (int): how many leading entries along dim 1 to drop.
    """
    def __init__(self, inducing_batch):
        super().__init__()
        self.inducing_batch = inducing_batch

    def forward(self, x):
        # Keep everything from index `inducing_batch` onwards on dim 1.
        keep = slice(self.inducing_batch, None)
        return x[:, keep]


def InducingWrapper(net, inducing_batch, *args, **kwargs):
    """
    Wraps ``net`` so that learned inducing inputs are processed alongside the incoming test/train data.

    The returned module first prepends the inducing inputs to the incoming data along dimension 1
    (``InducingAdd``), then runs ``net`` on the combined batch, and finally strips the first
    ``inducing_batch`` entries along dimension 1 (``InducingRemove``).  This leaves only the outputs
    at the test/train locations.

    args:
        net (nn.Module): The underlying function approximator, represented as PyTorch modules, to be wrapped.
        inducing_batch (int): The number of inducing points; must equal the first dimension of the inducing inputs.

    Keyword Args:
        inducing_shape (Optional[torch.Size]): The size of the inducing inputs, including `inducing_batch` as the first dimension.  Default: ``None``.
        inducing_data (Optional[torch.Tensor]): The values of the inducing inputs. Useful to e.g. initialize the inducing points on top of datapoints.  Default: ``None``.
        fixed (bool): Do we fix the inducing point locations?  Default: ``False``.

    Must specify one and only one of `inducing_shape` or `inducing_data`.

    Returns:
        nn.Sequential: ``nn.Sequential(InducingAdd(...), net, InducingRemove(inducing_batch))``.

    Example:
        >>> import bayesfunc as bf
        >>> import torch as t
        >>> import torch.nn as nn
        >>>
        >>> in_features = 20
        >>> hidden_features = 50
        >>> out_features = 30
        >>>
        >>> m1 = bf.GILinear(in_features, hidden_features, inducing_batch=100)
        >>> m2 = bf.GILinear(hidden_features, out_features, inducing_batch=100)
        >>> net = nn.Sequential(m1, m2)
        >>>
        >>> net = bf.InducingWrapper(net, 100, inducing_shape=(100, in_features))
        >>> output, _, _ = bf.propagate(net, t.randn(3, 128, in_features))
        >>> output.shape
        torch.Size([3, 128, 30])
    """
    ia = InducingAdd(inducing_batch, *args, **kwargs)
    ir = InducingRemove(inducing_batch)
    return nn.Sequential(ia, net, ir)


def logpq(f):
    """
    Extracts log P(f) - log Q(f) by summing the ``logpq`` attribute of every module in the network.

    Each contribution is consumed: after being added, the module's ``logpq`` is reset to ``None``.
    Modules whose ``logpq`` is already ``None`` (e.g. consumed by an earlier call, or not run in the
    last forward pass) are skipped instead of raising ``TypeError``.

    args:
        f: function approximator written as a pytorch module

    returns:
        The summed contributions, or ``0.`` if no module currently holds a value.
    """
    total = 0.
    for mod in f.modules():
        contribution = getattr(mod, "logpq", None)
        if contribution is not None:
            total += contribution
            # Mark as consumed so the term is not counted twice.
            mod.logpq = None
    return total


def clear_sample(f):
    """Reset ``_sample`` to ``None`` on ``f`` and every submodule of ``f`` that defines it."""
    sampling_modules = (m for m in f.modules() if hasattr(m, "_sample"))
    for module in sampling_modules:
        module._sample = None


def get_sample_dict(f):
    """
    Collect the current ``_sample`` of every module in ``f`` that defines one.

    returns:
        dict mapping each module's qualified name (as given by ``named_modules``) to its ``_sample``,
        which may be ``None``.
    """
    return {
        name: module._sample
        for name, module in f.named_modules()
        if hasattr(module, "_sample")
    }


def set_sample_dict(f, sample_dict, detach=True):
    """
    Install previously drawn samples onto the modules of ``f``.

    args:
        f: the network whose modules receive the samples.
        sample_dict: mapping from qualified module name (as produced by ``named_modules``) to sample.
            Entries may be ``None``, as returned by ``get_sample_dict`` for modules that did not sample.
        detach: if true, tensor samples are detached before being stored, stopping gradient flow.

    raises:
        KeyError: if a name in ``sample_dict`` is not a module of ``f``.
        AssertionError: if the named module has no ``_sample`` attribute.
    """
    mod_dict = dict(f.named_modules())
    for name, sample in sample_dict.items():
        mod = mod_dict[name]
        assert hasattr(mod, "_sample"), f"module {name!r} has no _sample attribute"
        # `None` entries (modules that drew no sample) cannot be detached; pass them through.
        if detach and sample is not None:
            sample = sample.detach()

        mod._sample = sample


def propagate(f, input, sample_dict=None, detach=True):
    r"""
    The ONLY way to run the neural networks defined in bayesfunc.  Use this instead of calling
    ``f(input)`` directly: a bare call skips the sample reset and the ``logpq`` collection done here.

    args:
        f: the bayesfunc function
        input: input to the function

    keyword args:
        sample_dict: optional dictionary of sampled weights, to allow using the same weights for multiple different inputs
        detach: if true, we detach parameters in sample_dict, stopping propagation of gradients

    outputs:
        - **output:** neural network output (as in ``output = f(input)``).
        - **logpq:** :math:`\log P(f) - \log Q(f)` the difference of prior and approximate posterior log-probabilities.
        - **output_sample_dict:** a dictionary containing the ``_sample`` of every module that defines one (entries may be ``None``).  Modules that reuse a preset sample will report the values given in ``sample_dict``.

    warning:
        Only properly implemented for ``GILinear``, ``GIConv2d``, ``FactorisedLinear`` and ``FactorisedConv2D``.  Everything else will run, but will independently sample a new function on every invocation, ignoring the ``sample_dict`` input argument.

    In standard use, ``sample_dict`` is never set and ``output_sample_dict`` is never needed.  These only become useful e.g. in continual learning.
    """
    # Start from a clean slate so stale samples from a previous call are never reused.
    clear_sample(f)
    if sample_dict is not None:
        set_sample_dict(f, sample_dict, detach=detach)
    output = f(input)
    # Snapshot the samples used in this pass, then clear them from the modules.
    output_sample_dict = get_sample_dict(f)
    clear_sample(f)
    return output, logpq(f), output_sample_dict

    
class NormalLearnedScale(nn.Module):
    """
    Maps its input to a ``td.Normal`` centred on that input, with one learned scalar scale.

    The scale is stored as its logarithm (a 0-dim ``nn.Parameter`` initialised to 0, i.e. scale 1)
    and exponentiated in ``forward``, so it stays positive.
    """
    def __init__(self):
        super().__init__()
        self.log_scale = nn.Parameter(t.zeros(()))

    def forward(self, x):
        scale = t.exp(self.log_scale)
        return td.Normal(x, scale)
