from typing import Dict, List, Optional, Type, Union

import torch
from torch import Tensor, nn


def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Return a bias-free 3x3 convolution with zero-padding of 1.

    With this padding the output spatial size is ``ceil(in_size / stride)``,
    i.e. unchanged for ``stride=1``.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Return a bias-free pointwise (1x1) convolution.

    Used to change the channel count and, with ``stride > 1``, to subsample
    the spatial grid without padding.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions (used by ResNet-18/34).

    ``expansion`` is the ratio of the block's output channels to ``planes``;
    it is 1 because both convolutions produce ``planes`` channels.
    """

    expansion = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        use_bn: bool = True,
    ):
        """
        Args:
            inplanes: Number of input channels.
            planes: Number of output channels of both convolutions.
            stride: Stride of the first 3x3 convolution.
            downsample: Module applied to the block input to build the shortcut.
                If None, the input is added to the residual branch unchanged.
            use_bn: Whether to use ``nn.BatchNorm2d`` after each convolution;
                when False an ``nn.Identity`` is used instead.
        """
        super().__init__()
        # nn.Identity accepts (and ignores) the channel argument, so both
        # choices can be constructed the same way.
        norm = nn.BatchNorm2d if use_bn else nn.Identity
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm(planes)
        self.downsample = downsample

    def forward(self, x: Tensor) -> Tensor:
        """Return ``relu(residual(x) + shortcut(x))``."""
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        # In-place add is safe: ``out`` is a freshly computed activation, not ``x``.
        out += identity
        return self.relu(out)


class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce, 3x3 (strided), 1x1 expand.

    The block outputs ``planes * expansion`` channels (``expansion = 4``); the
    stride is applied by the middle 3x3 convolution.
    """

    expansion = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        use_bn: bool = True,
    ):
        """
        Args:
            inplanes: Number of input channels.
            planes: Width of the inner 1x1/3x3 convolutions; the output has
                ``planes * expansion`` channels.
            stride: Stride of the 3x3 convolution.
            downsample: Module applied to the block input to build the shortcut.
                If None, the input is added to the residual branch unchanged.
            use_bn: Whether to use ``nn.BatchNorm2d`` after each convolution;
                when False an ``nn.Identity`` is used instead.
        """
        super().__init__()
        norm = nn.BatchNorm2d if use_bn else nn.Identity
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = norm(planes)
        # Same layer as before (3x3, padding 1, no bias), built via the shared helper.
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = norm(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = norm(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x: Tensor) -> Tensor:
        """Return ``relu(residual(x) + shortcut(x))``."""
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        # In-place add is safe: ``out`` is a freshly computed activation, not ``x``.
        out += identity
        return self.relu(out)


class ResNet(nn.Module):
    """Implement a ResNet variant tailored for CIFAR-style inputs.

    The stem is a single 3x3, stride-1 convolution with no max-pooling, so the
    spatial resolution is reduced only by the stride-2 first blocks of
    ``layer2``, ``layer3`` and ``layer4``. Global average pooling, dropout and
    a linear classifier follow the last stage.
    """

    def __init__(self, resnet_name: str, num_classes: int = 10, dropout_p: float = 0.5, use_bn: bool = True):
        """
        Args:
            resnet_name: Identifier for the ResNet depth; one of ``ResNet18``,
                ``ResNet34``, ``ResNet50``, ``ResNet101`` or ``ResNet152``.
            num_classes: Number of target classes.
            dropout_p: Dropout probability applied before the classifier.
            use_bn: Whether to include batch-normalisation in convolutional blocks;
                when False every normalisation layer is an ``nn.Identity``.

        Raises:
            ValueError: If ``resnet_name`` is not a supported identifier.
        """
        super().__init__()
        # ``block`` is a block *class* (not an instance); ``layers`` is the number
        # of blocks in each of the four stages.
        resnet_configs: Dict[str, Dict[str, Union[List[int], Type[nn.Module]]]] = {
            "ResNet18": {"layers": [2, 2, 2, 2], "block": BasicBlock},
            "ResNet34": {"layers": [3, 4, 6, 3], "block": BasicBlock},
            "ResNet50": {"layers": [3, 4, 6, 3], "block": Bottleneck},
            "ResNet101": {"layers": [3, 4, 23, 3], "block": Bottleneck},
            "ResNet152": {"layers": [3, 8, 36, 3], "block": Bottleneck},
        }
        if resnet_name not in resnet_configs:
            raise ValueError(f"ResNet name '{resnet_name}' not recognized. Available: {list(resnet_configs.keys())}")

        cfg = resnet_configs[resnet_name]
        block, layers = cfg["block"], cfg["layers"]

        norm = nn.BatchNorm2d if use_bn else nn.Identity
        # Running channel count; ``_make_layer`` advances it stage by stage.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = norm(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, 64, layers[0], stride=1, use_bn=use_bn)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, use_bn=use_bn)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, use_bn=use_bn)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, use_bn=use_bn)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout_p)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        self._initialize_weights()

    def _make_layer(self, block, planes: int, blocks: int, stride: int, use_bn: bool) -> nn.Sequential:
        """Build one stage of ``blocks`` residual blocks of class ``block``.

        Only the first block receives ``stride`` and the ``downsample`` shortcut
        (1x1 conv + norm), created when the stride or channel count changes.
        Side effect: sets ``self.inplanes`` to ``planes * block.expansion`` so
        the next stage is wired to this stage's output width.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            norm = nn.BatchNorm2d if use_bn else nn.Identity
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample, use_bn)]
        self.inplanes = planes * block.expansion
        # Remaining blocks keep stride 1 and an identity shortcut.
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, use_bn=use_bn))
        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Map a batched 3-channel image tensor to logits of shape ``(N, num_classes)``.

        Assumes ``x`` is ``(N, 3, H, W)`` (e.g. 32x32 for CIFAR).
        """
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        return x

    def _initialize_weights(self) -> None:
        """Initialise parameters in place.

        Convs: Kaiming-normal (fan_out, ReLU gain). BatchNorm/GroupNorm:
        weight 1, bias 0. Linear: weights ~ N(0, 0.01), bias 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                # Convs built by this file are bias-free; guard covers subclasses.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # Affine parameters may be None (e.g. affine=False).
                if hasattr(m, "weight") and m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if hasattr(m, "bias") and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
