from typing import Optional, Sequence

import torch
from torch import nn, Tensor

from .unet_parts import DoubleConv, Down, OutConv


class ResBlock1D(nn.Module):
    """Pre-LayerNorm residual MLP block acting on the last tensor dimension.

    Computes ``block(LayerNorm(x)) + shortcut(x)``, where ``block`` is
    ``Linear -> activation -> Dropout -> Linear -> Dropout``. The LayerNorm is
    applied only on the residual branch, before the MLP. This is the "Pre-LN"
    placement discussed in https://arxiv.org/pdf/2002.04745v1. The summed
    output itself is not normalized.

    Args:
        in_channels: Size of the last dimension of the input. It is used both
            by the LayerNorm and by the first Linear layer.
        mid_channels: Hidden width of the two-layer MLP.
        out_channels: Size of the last dimension of the output.
        activation: Module applied between the two Linear layers. ``None``
            (the default) creates a fresh ``nn.GELU()`` for this block.
        dropout: Dropout probability. It is applied after each Linear layer of
            the MLP and after the projection shortcut, when one is used.
    """

    def __init__(
        self,
        in_channels: int,
        mid_channels: int,
        out_channels: int,
        activation: Optional[nn.Module] = None,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Build the default per instance. A module used directly as a default
        # argument is created once at import time and then shared by every
        # block constructed without an explicit activation.
        if activation is None:
            activation = nn.GELU()

        self.norm = nn.LayerNorm(in_channels)
        self.block = nn.Sequential(
            nn.Linear(in_channels, mid_channels),
            activation,
            nn.Dropout(dropout),
            nn.Linear(mid_channels, out_channels),
            nn.Dropout(dropout),
        )

        if in_channels != out_channels:
            # Linear projection so the shortcut matches the branch output width.
            self.shortcut = nn.Sequential(
                nn.Linear(in_channels, out_channels),
                nn.Dropout(dropout),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x0: Tensor) -> Tensor:
        """Return ``block(norm(x0)) + shortcut(x0)``; the last dim must equal ``in_channels``."""
        x = self.block(self.norm(x0))
        x = x + self.shortcut(x0)
        return x


class ResBlock(nn.Module):
    """2D residual block computing ``activation(conv_block(x) + shortcut(x))``.

    ``conv_block`` is ``Down(in_channels, out_channels)`` when ``downsample``
    is true, otherwise ``DoubleConv(in_channels, out_channels)``. Both come
    from ``unet_parts``. The shortcut is the identity when the channel count
    is unchanged and no downsampling is requested. Otherwise it is a bias-free
    1x1 ``Conv2d`` (stride 2 when downsampling) followed by ``BatchNorm2d``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        activation: Module applied after the residual sum. ``None`` (the
            default) creates a fresh ``nn.GELU()`` for this block.
        downsample: Use ``Down`` for the main path and a stride-2 shortcut.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation: Optional[nn.Module] = None,
        downsample: bool = False,
    ):
        super().__init__()
        # Build the default per instance instead of sharing one module object
        # created at import time across every ResBlock.
        if activation is None:
            activation = nn.GELU()

        if downsample:
            self.conv_block = Down(in_channels, out_channels)
        else:
            self.conv_block = DoubleConv(in_channels, out_channels)

        if in_channels != out_channels or downsample:
            # NOTE(review): a stride-2 1x1 conv yields ceil(H/2) x ceil(W/2).
            # If Down halves with MaxPool2d(2) (floor), odd spatial sizes
            # will fail to add below. TODO confirm against unet_parts.Down.
            stride = 2 if downsample else 1
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

        self.activation = activation

    def forward(self, x0: Tensor) -> Tensor:
        """Return ``activation(conv_block(x0) + shortcut(x0))``."""
        x = self.conv_block(x0)
        x = x + self.shortcut(x0)
        x = self.activation(x)
        return x


class ResNetBackbone(nn.Module):
    """Stack of ``ResBlock`` stages between a ``DoubleConv`` stem and an ``OutConv`` head.

    Layout, in order:

    1. ``DoubleConv(in_channels, init_features)``.
    2. For each entry ``n`` in ``layers``, a stage of ``n`` ``ResBlock`` modules.
    3. ``OutConv(final_features, out_channels)``.

    The first stage keeps ``init_features`` channels and is built with
    ``downsample=False``. The first block of every later stage is built with
    ``downsample=True`` and doubles the channel count. The remaining blocks of
    a stage keep its width. A stage with zero blocks adds nothing, and the
    pending downsample moves to the next non-empty stage.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the ``OutConv`` head.
        layers: Number of ``ResBlock`` modules per stage. Defaults to
            ``(2, 2, 2)``. Any sequence of ints is accepted, including lists.
        init_features: Channel width produced by the stem.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        layers: Sequence[int] = (2, 2, 2),
        init_features: int = 16,
    ):
        super().__init__()

        blocks = [DoubleConv(in_channels, init_features)]
        features = init_features
        # The first stage runs at the stem's width without downsampling.
        downsample = False
        for n_layers in layers:
            for _ in range(n_layers):
                out_features = 2 * features if downsample else features
                blocks.append(ResBlock(features, out_features, downsample=downsample))
                features = out_features
                # Only the first block of a stage downsamples.
                downsample = False
            downsample = True
        blocks.append(OutConv(features, out_channels))
        self.net = nn.Sequential(*blocks)

    def forward(self, x: Tensor) -> Tensor:
        """Run the input through stem, residual stages and head in sequence."""
        return self.net(x)
