import math
from typing import Callable, Optional, Sequence, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import init
from torch.nn.common_types import _size_2_t
from torch.nn.modules.utils import _pair, _reverse_repeat_tuple
from torch.nn.parameter import Parameter


class BatchedConv2D(nn.Module):
    """``n_envs`` independent 2-D convolutions evaluated in a single call.

    ``weight`` has shape ``(n_envs, out_channels, in_channels, *kernel_size)``
    and ``bias`` (when enabled) has shape ``(n_envs, out_channels)``.
    ``forward`` runs ``F.conv2d(input[e], weight[e], bias[e], ...)`` for every
    environment ``e`` via ``torch.vmap``. Dimension 0 of ``input`` must
    therefore equal ``n_envs``, e.g. ``(n_envs, batch, in_channels, H, W)``.

    The remaining arguments follow the conventions of ``torch.nn.Conv2d``
    (``groups`` is fixed to 1).
    """

    def __init__(
        self,
        n_envs: int,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        bias: bool = True,
        padding_mode: str = "zeros",  # TODO: refine this type
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.n_envs = n_envs

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = padding if isinstance(padding, str) else _pair(padding)
        self.dilation = _pair(dilation)
        self.padding_mode = padding_mode
        self.groups = 1

        factory_kwargs = {"device": device, "dtype": dtype}

        # `_reversed_padding_repeated_twice` is the padding to be passed to
        # `F.pad` if needed (e.g., for non-zero padding types that are
        # implemented as two ops: padding + conv). `F.pad` accepts paddings in
        # reverse order than the dimension.
        if isinstance(self.padding, str):
            self._reversed_padding_repeated_twice = [0, 0] * len(self.kernel_size)
            if padding == "same":
                for d, k, i in zip(self.dilation, self.kernel_size, range(len(self.kernel_size) - 1, -1, -1)):
                    total_padding = d * (k - 1)
                    left_pad = total_padding // 2
                    self._reversed_padding_repeated_twice[2 * i] = left_pad
                    self._reversed_padding_repeated_twice[2 * i + 1] = total_padding - left_pad
        else:
            self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2)

        self.weight = Parameter(torch.empty((n_envs, out_channels, in_channels, *self.kernel_size), **factory_kwargs))
        if bias:
            self.bias = Parameter(torch.empty(n_envs, out_channels, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        # `torch.empty` leaves the storage uninitialised; fill it explicitly.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise every environment's kernel and bias.

        Each ``weight[e]`` gets ``kaiming_uniform_(a=sqrt(5))``. The bias is
        drawn from ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``, where ``fan_in`` is
        computed from a single environment's kernel.
        """
        # Initialise per environment: on the full 5-D tensor the fan-in would be
        # computed from the wrong dimension (dim 1 is out_channels, not in_channels).
        for i in range(self.n_envs):
            init.kaiming_uniform_(self.weight[i], a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight[0])
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    @property
    def _nenv_conv_forward(self) -> Callable:
        """``_conv_forward`` vmapped over the leading environment dimension.

        The wrapper is built on demand instead of being stored on the instance,
        so the module stays picklable and deep copies do not call back into the
        original object. The bias is mapped only when it exists, because
        ``torch.vmap`` cannot map over a ``None`` argument.
        """
        bias_in_dim = None if self.bias is None else 0
        return torch.vmap(self._conv_forward, in_dims=(0, 0, bias_in_dim))

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
        """Apply a single environment's convolution.

        For non-``"zeros"`` padding modes the input is padded explicitly with
        ``F.pad`` and then convolved without further padding.
        """
        if self.padding_mode != "zeros":
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), weight, bias, self.stride, _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)

    def forward(self, input: Tensor) -> Tensor:
        """Convolve ``input[e]`` with ``weight[e]`` (and ``bias[e]``) for every environment ``e``."""
        return self._nenv_conv_forward(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        return (
            f"n_envs={self.n_envs}, in_channels={self.in_channels}, out_channels={self.out_channels}, "
            f"kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding}, "
            f"dilation={self.dilation}, bias={self.bias is not None}, padding_mode={self.padding_mode}"
        )


class BatchedLinear(nn.Module):
    """``n_envs`` independent linear layers evaluated in a single call.

    ``weight`` has shape ``(n_envs, out_features, in_features)`` and ``bias``
    (when enabled) has shape ``(n_envs, out_features)``. ``forward`` computes
    ``weight[e] @ input[e]`` for every environment ``e`` via ``torch.vmap``, so
    features come *before* the sample dimension. For example, an input of shape
    ``(n_envs, in_features, N)`` yields ``(n_envs, out_features, N)``.
    """

    def __init__(
        self,
        n_envs: int,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()

        self.n_envs = n_envs

        self.in_features = in_features
        self.out_features = out_features

        factory_kwargs = {"device": device, "dtype": dtype}
        self.weight = Parameter(torch.empty((n_envs, out_features, in_features), **factory_kwargs))
        if bias:
            self.bias = Parameter(torch.empty(n_envs, out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    @property
    def nenv_bmm(self) -> Callable[[Tensor, Tensor], Tensor]:
        """``torch.matmul`` vmapped over the leading environment dimension.

        The wrapper is built on demand rather than stored on the instance; a
        stored vmapped closure cannot be pickled, which would break
        ``torch.save(module)``.
        """
        return torch.vmap(torch.matmul)

    def reset_parameters(self) -> None:
        """Initialise each environment's weight and bias.

        Each ``weight[e]`` gets ``kaiming_uniform_(a=sqrt(5))``. The bias is
        drawn from ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``.
        """
        # Per-environment init so the fan-in is taken from (out, in), not from n_envs.
        for i in range(self.n_envs):
            init.kaiming_uniform_(self.weight[i], a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight[0])
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        """Return ``weight[e] @ input[e] (+ bias[e])`` stacked over environments."""
        out = self.nenv_bmm(self.weight, input)
        if self.bias is None:
            return out
        # (n_envs, out_features) -> (n_envs, out_features, 1) so the bias
        # broadcasts over the trailing sample dimension.
        return out + self.bias[..., None]

    def extra_repr(self) -> str:
        return f"n_envs={self.n_envs}, in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"


class PrintModule(nn.Module):
    """Identity layer that prints a label followed by the incoming tensor's shape.

    Useful inside ``nn.Sequential`` to trace intermediate shapes while debugging.
    """

    def __init__(self, text: str = "") -> None:
        super().__init__()
        self.text = text

    def forward(self, x: Tensor) -> Tensor:
        label = self.text
        shape = x.shape
        print(label, shape)
        # The input is returned untouched (same object).
        return x


class Unsqueeze(nn.Module):
    """Layer form of ``Tensor.unsqueeze``: inserts a size-1 axis at ``dim``."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        return torch.unsqueeze(x, self.dim)


class Reshape(nn.Module):
    """Layer form of ``Tensor.reshape`` with a fixed target ``shape``.

    The shape is applied as-is, so it must cover any batch dimension
    (``-1`` may be used for one inferred axis).
    """

    def __init__(self, shape: Sequence[int]) -> None:
        super().__init__()
        self.shape = shape

    def forward(self, x: Tensor):
        return torch.reshape(x, self.shape)


class SwapDims(nn.Module):
    """Layer form of ``Tensor.swapdims``: exchanges axes ``dim0`` and ``dim1``."""

    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x: Tensor) -> Tensor:
        return torch.swapdims(x, self.dim0, self.dim1)
