from __future__ import annotations

from typing import Sequence

import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor

from .utils import infer_leading_dims


class MlpModel(nn.Module):
    """Multilayer Perceptron with last layer linear.

    Args:
        input_size (int): number of inputs
        hidden_sizes: size of each hidden layer. A single int means one hidden
            layer; ``None`` or an empty sequence means no hidden layers
            (linear model).
        output_size: linear layer at output, or if ``None``, the last hidden
            size will be the output size and will have nonlinearity applied.
            With no hidden layers and ``None`` here, the model is the identity
            and its output size is ``input_size``.
        hidden_nonlinearity: torch nonlinearity Module class (not Functional),
            instantiated after every hidden layer.
    """

    def __init__(
        self,
        input_size: int,
        hidden_sizes: int | Sequence[int] | None,
        output_size: int | None,
        hidden_nonlinearity: type[nn.Module] = nn.ReLU,
    ):
        super().__init__()
        if isinstance(hidden_sizes, int):
            hidden_sizes = [hidden_sizes]
        elif hidden_sizes is None:
            hidden_sizes = []
        else:
            hidden_sizes = list(hidden_sizes)
        # Width of every layer boundary: input, then each hidden layer.
        layer_sizes = [input_size] + hidden_sizes
        hidden_layers = [
            nn.Linear(n_in, n_out)
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        ]
        sequence = []
        for layer in hidden_layers:
            sequence.extend([layer, hidden_nonlinearity()])
        last_size = layer_sizes[-1]
        if output_size is not None:
            sequence.append(nn.Linear(last_size, output_size))
        self.model = nn.Sequential(*sequence)
        # last_size is input_size when there are no hidden layers, so an
        # empty (identity) model reports its input width instead of failing.
        self._output_size = last_size if output_size is None else output_size

    def forward(self, input: Tensor) -> Tensor:
        """Compute the model on the input, assuming input shape [B,input_size]."""
        return self.model(input)

    @property
    def output_size(self) -> int:
        """Returns the output size of the model."""
        return self._output_size


def conv2d_output_shape(h, w, kernel_size=1, stride=1, padding=0, dilation=1):
    """Returns output H, W after convolution/pooling on input H, W.

    Uses the ``torch.nn.Conv2d`` output-size formula (floor rounding)::

        out = (in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

    Each of ``kernel_size``, ``stride``, ``padding`` and ``dilation`` may be a
    single int applied to both dimensions, or an ``(h, w)`` pair given as a
    tuple or list.
    """

    def _pair(value):
        # Per-dimension values arrive as tuples (e.g. Conv2d attributes) or
        # lists; scalars apply to both height and width.
        return tuple(value) if isinstance(value, (tuple, list)) else (value, value)

    kh, kw = _pair(kernel_size)
    sh, sw = _pair(stride)
    ph, pw = _pair(padding)
    dh, dw = _pair(dilation)
    h = (h + 2 * ph - dh * (kh - 1) - 1) // sh + 1
    w = (w + 2 * pw - dw * (kw - 1) - 1) // sw + 1
    return h, w


class Conv2dModel(nn.Module):
    """2-D Convolutional model component, with option for max-pooling vs
    downsampling for strides > 1.  Requires number of input channels, but
    not input shape.  Uses ``torch.nn.Conv2d``.

    Args:
        in_channels: number of channels of the input images.
        channels: output channels of each conv layer (one entry per layer).
        kernel_sizes: kernel size of each conv layer.
        strides: stride of each conv layer. When ``use_maxpool`` is True the
            convs use stride 1 and these strides are applied by max-pooling.
        paddings: padding of each conv layer; defaults to 0 for every layer.
        nonlinearity: nonlinearity Module class (not Functional), instantiated
            after every conv layer.
        use_maxpool: if True, downsample with ``nn.MaxPool2d`` instead of
            strided convolutions.

    Raises:
        ValueError: if ``channels``, ``kernel_sizes``, ``strides`` and
            ``paddings`` do not all have the same length.
    """

    def __init__(
        self,
        in_channels: int,
        channels: Sequence[int],
        kernel_sizes: Sequence[int],
        strides: Sequence[int],
        paddings: Sequence[int] | None = None,
        nonlinearity: type[nn.Module] = nn.ReLU,  # Module, not Functional.
        use_maxpool: bool = False,  # if True: convs use stride 1, maxpool downsample.
    ):
        super().__init__()
        # Copy to a list so any Sequence (e.g. a tuple) can be concatenated below.
        channels = list(channels)
        if paddings is None:
            paddings = [0] * len(channels)
        if not (len(channels) == len(kernel_sizes) == len(strides) == len(paddings)):
            raise ValueError(
                "channels, kernel_sizes, strides and paddings must have the same "
                f"length; got {len(channels)}, {len(kernel_sizes)}, "
                f"{len(strides)}, {len(paddings)}"
            )
        # Kept so conv_out_size() can answer for an empty (identity) stack.
        self._in_channels = in_channels
        layer_in_channels = [in_channels] + channels[:-1]
        ones = [1] * len(strides)
        if use_maxpool:
            conv_strides, maxp_strides = ones, list(strides)
        else:
            conv_strides, maxp_strides = list(strides), ones
        conv_layers = [
            nn.Conv2d(
                in_channels=ic, out_channels=oc, kernel_size=k, stride=s, padding=p
            )
            for (ic, oc, k, s, p) in zip(
                layer_in_channels, channels, kernel_sizes, conv_strides, paddings
            )
        ]
        sequence = []
        for conv_layer, maxp_stride in zip(conv_layers, maxp_strides):
            sequence.extend([conv_layer, nonlinearity()])
            if maxp_stride > 1:
                # kernel_size == stride and no padding: non-overlapping pooling.
                sequence.append(nn.MaxPool2d(maxp_stride))
        self.conv = nn.Sequential(*sequence)

    def forward(self, input: Tensor) -> Tensor:
        """Computes the convolution stack on the input; assumes correct shape
        already: [B,C,H,W]."""
        return self.conv(input)

    def conv_out_size(self, h, w, c=None):
        """Return the flattened output size (C*H*W) for an ``h`` x ``w`` input,
        without actually performing a forward pass through the model.

        Args:
            h, w: input height and width.
            c: channel count to report if the stack has no conv layers;
                defaults to the model's ``in_channels``.
        """
        if c is None:
            c = self._in_channels
        for child in self.conv.children():
            if hasattr(child, "kernel_size"):  # Conv or max-pool layer.
                h, w = conv2d_output_shape(
                    h, w, child.kernel_size, child.stride, child.padding
                )
            # Only conv layers change the channel count.
            c = getattr(child, "out_channels", c)
        return h * w * c


class Conv2dHeadModel(nn.Module):
    """Model component composed of a ``Conv2dModel`` component followed by
    a fully-connected ``MlpModel`` head.  Requires full input image shape to
    instantiate the MLP head.

    Args:
        image_shape: input image shape as ``(C, H, W)``.
        channels, kernel_sizes, strides, paddings, nonlinearity, use_maxpool:
            forwarded to ``Conv2dModel``.
        hidden_sizes: hidden layer sizes of the MLP head (int, sequence or
            ``None``).
        output_size: size of the MLP head's final linear layer, or ``None``.

    If both ``hidden_sizes`` and ``output_size`` are falsy, there is no MLP
    head and the output is the flattened conv features.
    """

    def __init__(
        self,
        image_shape: tuple[int, int, int],
        channels: Sequence[int],
        kernel_sizes: Sequence[int],
        strides: Sequence[int],
        hidden_sizes: int | Sequence[int] | None,
        output_size: int | None = None,
        paddings: Sequence[int] | None = None,
        nonlinearity: type[nn.Module] = nn.ReLU,
        use_maxpool: bool = False,
    ):
        super().__init__()
        c, h, w = image_shape
        self.conv = Conv2dModel(
            in_channels=c,
            channels=channels,
            kernel_sizes=kernel_sizes,
            strides=strides,
            paddings=paddings,
            nonlinearity=nonlinearity,
            use_maxpool=use_maxpool,
        )
        conv_out_size = self.conv.conv_out_size(h, w)
        if hidden_sizes or output_size:
            self.head = MlpModel(
                conv_out_size,
                hidden_sizes,
                output_size=output_size,
                hidden_nonlinearity=nonlinearity,
            )
            if output_size is not None:
                self._output_size = output_size
            else:
                self._output_size = (
                    hidden_sizes if isinstance(hidden_sizes, int) else hidden_sizes[-1]
                )
        else:
            # nn.Identity rather than a lambda: a lambda attribute makes the
            # whole module unpicklable (e.g. torch.save(model)) and hides it
            # from children()/repr.
            self.head = nn.Identity()
            self._output_size = conv_out_size

    def forward(self, input: Tensor) -> Tensor:
        """Compute the convolution and fully connected head on the input;
        assumes correct input shape: [B,C,H,W]."""
        # flatten(1) copies only when needed, so non-contiguous conv outputs
        # (e.g. channels_last) and empty batches work where view() would raise.
        return self.head(self.conv(input).flatten(start_dim=1))

    @property
    def output_size(self) -> int:
        """Returns the final output size after MLP head."""
        return self._output_size


class RunningMeanStdModel(nn.Module):
    """Adapted from OpenAI baselines.  Maintains a running estimate of mean
    and variance of data along each dimension, accessible in the `mean` and
    `var` attributes.  Supports multi-GPU training by all-reducing statistics
    across GPUs.

    Args:
        shape: shape of a single data sample (the trailing dims of the tensor
            passed to ``update``). An int is treated as a 1-D shape.
    """

    def __init__(self, shape):
        super().__init__()
        if isinstance(shape, int):
            # update() needs len(self.shape), which an int does not support.
            shape = (shape,)
        self.register_buffer("mean", torch.zeros(shape))
        self.register_buffer("var", torch.ones(shape))
        self.register_buffer("count", torch.zeros(()))
        self.shape = shape

    def update(self, x: Tensor) -> None:
        """Fold a batch of samples into the running mean and variance.

        ``x`` is expected to end with ``self.shape``. Its leading dims are
        resolved by ``infer_leading_dims`` into ``T`` and ``B`` (presumably
        up to two leading dims, [T,B] — TODO confirm) and flattened into a
        single sample axis.

        When ``torch.distributed`` is initialized this performs a collective
        ``all_reduce``, so every process must call it. The combination assumes
        each process contributes the same number of samples.
        """
        _, T, B, _ = infer_leading_dims(x, len(self.shape))
        # reshape (not view) so non-contiguous inputs also work.
        x = x.reshape(T * B, *self.shape)
        batch_mean = x.mean(dim=0)
        batch_var = x.var(dim=0, unbiased=False)
        batch_count = T * B
        if dist.is_initialized():  # Assume need all-reduce.
            # Pooled variance over equal-size groups is the mean of the
            # within-group variances PLUS the variance of the group means;
            # averaging the variances alone would underestimate it.
            stats = torch.stack([batch_mean, batch_var, batch_mean * batch_mean])
            dist.all_reduce(stats)
            world_size = dist.get_world_size()
            stats /= world_size
            batch_mean = stats[0]
            # clamp guards against tiny negative values from float rounding.
            between_var = (stats[2] - batch_mean * batch_mean).clamp_min(0)
            batch_var = stats[1] + between_var
            batch_count *= world_size
        if self.count == 0:
            self.mean.copy_(batch_mean)
            self.var.copy_(batch_var)
        else:
            # Parallel (Chan et al.) combination of two sets of moments.
            delta = batch_mean - self.mean
            total = self.count + batch_count
            self.mean.copy_(self.mean + delta * batch_count / total)
            m_a = self.var * self.count
            m_b = batch_var * batch_count
            M2 = m_a + m_b + delta**2 * self.count * batch_count / total
            self.var.copy_(M2 / total)
        self.count += batch_count
