from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import torchvision
except ImportError:
    torchvision = None

Tensor = torch.Tensor


def _handle_time_dim(x: Tensor) -> Tensor:
    """Return a 4-D tensor, collapsing a leading fifth axis if present.

    A 4-D input is handed back unchanged. A 5-D input is reduced by taking
    the mean over dimension 0 (presumably the time/frame axis -- TODO confirm
    against the data pipeline).

    Raises:
        ValueError: If ``x`` is neither 4-D nor 5-D.
    """
    rank = x.dim()
    if rank == 4:
        return x
    if rank == 5:
        # Averaging away axis 0 of a 5-D tensor always leaves four axes.
        return x.mean(0)
    raise ValueError(f"Expected 4D or 5D tensor, got shape {tuple(x.shape)}")


# ----------------------------------------------------------------------
# 1) MLP
# ----------------------------------------------------------------------

class MmWaveAnn_MLP(nn.Module):
    """Fully connected classifier applied to the flattened (C, H, W) input.

    Args:
        input_shape: ``(C, H, W)`` of one sample; the first linear layer takes
            ``C * H * W`` features.
        num_classes: Size of the output layer.
        hidden_dims: Widths of the hidden layers; each is followed by a ReLU.
    """

    def __init__(
        self,
        input_shape: Tuple[int, int, int],
        num_classes: int,
        hidden_dims=(1024, 128),
    ):
        super().__init__()
        channels, height, width = input_shape

        # Widths of every activation vector from the input up to the last
        # hidden layer; consecutive pairs define the hidden Linear layers.
        widths = [channels * height * width, *hidden_dims]

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU(inplace=True))
        # Output projection has no activation: the network returns raw scores.
        stack.append(nn.Linear(widths[-1], num_classes))

        self.net = nn.Sequential(*stack)

    def forward(self, x: Tensor) -> Tensor:
        """Return class scores of shape (B, num_classes)."""
        frames = _handle_time_dim(x)      # (B, C, H, W)
        flat = frames.flatten(1)          # (B, C*H*W)
        return self.net(flat)


# ----------------------------------------------------------------------
# 2) LeNet-style CNN
# ----------------------------------------------------------------------

class MmWaveAnn_LeNet(nn.Module):
    """Two-stage convolutional classifier with batch norm and a small MLP head.

    Each stage is Conv(5x5, padding 2) -> BatchNorm -> ReLU -> MaxPool(2).
    The feature map is then adaptively average-pooled to 4x4, so the head
    always receives ``64 * 4 * 4`` features.

    Args:
        input_shape: ``(C, H, W)``; only the channel count ``C`` is used.
        num_classes: Size of the output layer.
    """

    def __init__(self, input_shape: Tuple[int, int, int], num_classes: int):
        super().__init__()
        in_channels, _, _ = input_shape

        def stage(n_in: int, n_out: int):
            # One conv stage; the 2x2 max-pool halves both spatial sizes.
            return [
                nn.Conv2d(n_in, n_out, kernel_size=5, padding=2),
                nn.BatchNorm2d(n_out),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ]

        self.features = nn.Sequential(*stage(in_channels, 32), *stage(32, 64))

        # Fixed 4x4 output keeps the classifier input size constant.
        self.avgpool = nn.AdaptiveAvgPool2d((4, 4))

        head = [
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x: Tensor) -> Tensor:
        """Return class scores of shape (B, num_classes)."""
        frames = _handle_time_dim(x)
        pooled = self.avgpool(self.features(frames))
        return self.classifier(pooled)


# ----------------------------------------------------------------------
# 3) ResNet 
# ----------------------------------------------------------------------

def _make_resnet_backbone(
    depth: int,
    in_channels: int,
    num_classes: int,
) -> nn.Module:
    """Build an untrained torchvision ResNet adapted to the given I/O sizes.

    Args:
        depth: ResNet variant to build; 18, 50 and 101 are supported.
        in_channels: Channel count of the input. When it is not 3, the stem
            convolution is replaced by a new one with the same geometry.
        num_classes: Output size of the replacement fully connected layer.

    Raises:
        RuntimeError: If torchvision could not be imported.
        ValueError: If ``depth`` is not one of the supported values.
    """
    if torchvision is None:
        raise RuntimeError("torchvision is required for ResNet models")

    # Supported depths and the torchvision constructor that builds each one.
    variants = ((18, "resnet18"), (50, "resnet50"), (101, "resnet101"))
    ctor_name = next((name for d, name in variants if depth == d), None)
    if ctor_name is None:
        raise ValueError(f"Unsupported ResNet depth: {depth}")
    model = getattr(torchvision.models, ctor_name)(weights=None)

    if in_channels != 3:
        # Copy the stem geometry from the stock layer; only the input
        # channel count changes.
        stem = model.conv1
        model.conv1 = nn.Conv2d(
            in_channels,
            stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            bias=False,
        )

    # Swap the classification layer for one sized to num_classes.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


class MmWaveAnn_ResNet18(nn.Module):
    """ResNet-18 classifier built by ``_make_resnet_backbone``.

    Args:
        input_shape: ``(C, H, W)``; only the channel count ``C`` is used.
        num_classes: Number of output classes.
    """

    def __init__(self, input_shape: Tuple[int, int, int], num_classes: int):
        super().__init__()
        in_channels, _, _ = input_shape
        self.backbone = _make_resnet_backbone(18, in_channels, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Normalise the input to 4-D and run it through the backbone."""
        frames = _handle_time_dim(x)
        return self.backbone(frames)


class MmWaveAnn_ResNet50(nn.Module):
    """ResNet-50 classifier built by ``_make_resnet_backbone``.

    Args:
        input_shape: ``(C, H, W)``; only the channel count ``C`` is used.
        num_classes: Number of output classes.
    """

    def __init__(self, input_shape: Tuple[int, int, int], num_classes: int):
        super().__init__()
        in_channels, _, _ = input_shape
        self.backbone = _make_resnet_backbone(50, in_channels, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Normalise the input to 4-D and run it through the backbone."""
        frames = _handle_time_dim(x)
        return self.backbone(frames)


class MmWaveAnn_ResNet101(nn.Module):
    """ResNet-101 classifier built by ``_make_resnet_backbone``.

    Args:
        input_shape: ``(C, H, W)``; only the channel count ``C`` is used.
        num_classes: Number of output classes.
    """

    def __init__(self, input_shape: Tuple[int, int, int], num_classes: int):
        super().__init__()
        in_channels, _, _ = input_shape
        self.backbone = _make_resnet_backbone(101, in_channels, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Normalise the input to 4-D and run it through the backbone."""
        frames = _handle_time_dim(x)
        return self.backbone(frames)


# ----------------------------------------------------------------------
# 4) RNN / GRU / LSTM / BiLSTM
# ----------------------------------------------------------------------

class _BaseRNN(nn.Module):
    """Recurrent classifier that reads a (B, C, H, W) input row by row.

    The height axis H is the sequence axis: step ``t`` feeds the recurrent
    cell the ``C * W`` values of row ``t`` (all channels, all columns). The
    summary of the sequence is passed through a linear layer to produce the
    class scores.

    Args:
        input_shape: ``(C, H, W)`` of one sample; ``C * W`` is the per-step
            feature size of the recurrent cell.
        num_classes: Size of the output layer.
        cell_type: ``"rnn"``, ``"gru"`` or ``"lstm"``; any other key raises
            ``KeyError``.
        bidirectional: Whether the recurrent cell runs in both directions.
        hidden_size: Hidden state size per direction.
        num_layers: Number of stacked recurrent layers.
    """

    def __init__(
        self,
        input_shape: Tuple[int, int, int],
        num_classes: int,
        cell_type: str = "rnn",
        bidirectional: bool = False,
        hidden_size: int = 128,
        num_layers: int = 1,
    ):
        super().__init__()
        c, h, w = input_shape
        self.h = h
        self.feature_dim = c * w
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        rnn_cls = {
            "rnn": nn.RNN,
            "gru": nn.GRU,
            "lstm": nn.LSTM,
        }[cell_type]

        self.rnn = rnn_cls(
            input_size=self.feature_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional,
            batch_first=False,   # (T,B,F)
        )
        self.fc = nn.Linear(hidden_size * self.num_directions, num_classes)

    @staticmethod
    def _rows_to_sequence(x: Tensor) -> Tensor:
        """Rearrange (B, C, H, W) into a time-major (H, B, C*W) sequence.

        The axes must be permuted before merging C and W: a plain
        ``view(b, c * w, h)`` only reinterprets the (C, H, W) memory layout
        and would mix values from different rows into one time step.
        ``reshape`` (rather than ``view``) also accepts non-contiguous input.
        """
        b, c, h, w = x.shape
        return x.permute(2, 0, 1, 3).reshape(h, b, c * w)

    def _sequence_summary(self, out: Tensor) -> Tensor:
        """Pick the per-sample summary from the (T, B, hidden * dirs) output.

        For a unidirectional cell this is the last time step. For a
        bidirectional cell the forward half is taken at the last step and the
        backward half at step 0, because the backward pass finishes at the
        start of the sequence; its value at the last step has seen only one
        input.
        """
        if not self.bidirectional:
            return out[-1]                          # (B, Hid)
        hid = self.rnn.hidden_size
        forward_final = out[-1, :, :hid]
        backward_final = out[0, :, hid:]
        return torch.cat((forward_final, backward_final), dim=1)  # (B, 2*Hid)

    def forward(self, x: Tensor) -> Tensor:
        """Return class scores of shape (B, num_classes).

        The sequence length H may differ from the one given at construction;
        the per-step feature size ``C * W`` must match ``self.feature_dim``
        or the recurrent cell will reject the input.
        """
        x = _handle_time_dim(x)              # (B,C,H,W)
        seq = self._rows_to_sequence(x)      # (H,B,C*W)
        out, _ = self.rnn(seq)               # (H,B,Hid*dir)
        return self.fc(self._sequence_summary(out))


class MmWaveAnn_RNN(_BaseRNN):
    """``_BaseRNN`` configured with a unidirectional ``nn.RNN`` cell.

    Args:
        input_shape: ``(C, H, W)`` of one sample.
        num_classes: Number of output classes.
        hidden_size: Hidden state size of the recurrent cell.
        num_layers: Number of stacked recurrent layers.
    """

    def __init__(
        self,
        input_shape: Tuple[int, int, int],
        num_classes: int,
        hidden_size: int = 128,
        num_layers: int = 1,
    ):
        # Only the cell choice is fixed here; sizes are passed straight through.
        super().__init__(
            input_shape=input_shape,
            num_classes=num_classes,
            cell_type="rnn",
            bidirectional=False,
            hidden_size=hidden_size,
            num_layers=num_layers,
        )


class MmWaveAnn_GRU(_BaseRNN):
    """``_BaseRNN`` configured with a unidirectional ``nn.GRU`` cell.

    Args:
        input_shape: ``(C, H, W)`` of one sample.
        num_classes: Number of output classes.
        hidden_size: Hidden state size of the recurrent cell.
        num_layers: Number of stacked recurrent layers.
    """

    def __init__(
        self,
        input_shape: Tuple[int, int, int],
        num_classes: int,
        hidden_size: int = 128,
        num_layers: int = 1,
    ):
        # Only the cell choice is fixed here; sizes are passed straight through.
        super().__init__(
            input_shape=input_shape,
            num_classes=num_classes,
            cell_type="gru",
            bidirectional=False,
            hidden_size=hidden_size,
            num_layers=num_layers,
        )


class MmWaveAnn_LSTM(_BaseRNN):
    """``_BaseRNN`` configured with a unidirectional ``nn.LSTM`` cell.

    Args:
        input_shape: ``(C, H, W)`` of one sample.
        num_classes: Number of output classes.
        hidden_size: Hidden state size of the recurrent cell.
        num_layers: Number of stacked recurrent layers.
    """

    def __init__(
        self,
        input_shape: Tuple[int, int, int],
        num_classes: int,
        hidden_size: int = 128,
        num_layers: int = 1,
    ):
        # Only the cell choice is fixed here; sizes are passed straight through.
        super().__init__(
            input_shape=input_shape,
            num_classes=num_classes,
            cell_type="lstm",
            bidirectional=False,
            hidden_size=hidden_size,
            num_layers=num_layers,
        )


class MmWaveAnn_BiLSTM(_BaseRNN):
    """``_BaseRNN`` configured with a bidirectional ``nn.LSTM`` cell.

    Args:
        input_shape: ``(C, H, W)`` of one sample.
        num_classes: Number of output classes.
        hidden_size: Hidden state size per direction.
        num_layers: Number of stacked recurrent layers.
    """

    def __init__(
        self,
        input_shape: Tuple[int, int, int],
        num_classes: int,
        hidden_size: int = 128,
        num_layers: int = 1,
    ):
        # Same cell as the LSTM variant, but run in both directions.
        super().__init__(
            input_shape=input_shape,
            num_classes=num_classes,
            cell_type="lstm",
            bidirectional=True,
            hidden_size=hidden_size,
            num_layers=num_layers,
        )


# ----------------------------------------------------------------------
# 5) CNN + GRU 
# ----------------------------------------------------------------------

class MmWaveAnn_CNN_GRU(nn.Module):
    """CNN feature extractor followed by a GRU over the feature-map rows.

    Two Conv(3x3) -> BatchNorm -> ReLU -> MaxPool(2) stages produce a
    (B, 64, H', W') feature map. Its rows are then read as a sequence of
    length H' with ``64 * W'`` features per step, and the GRU output at the
    last step is mapped to class scores.

    Args:
        input_shape: ``(C, H, W)`` of one sample. ``W`` fixes the GRU input
            size, so inputs must keep the same width at run time.
        num_classes: Size of the output layer.
        hidden_size: GRU hidden state size.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(
        self,
        input_shape: Tuple[int, int, int],
        num_classes: int,
        hidden_size: int = 128,
        num_layers: int = 1,
    ):
        super().__init__()
        c, h, w = input_shape
        self.h = h

        self.cnn = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )

        # The two 2x2 max-pools reduce the width to W // 4, and the CNN ends
        # with 64 channels, so each row carries 64 * (W // 4) features.
        self.gru = nn.GRU(
            input_size=64 * max(1, w // 4),
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=False,
            batch_first=False,      # (T,B,F)
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    @staticmethod
    def _rows_to_sequence(feat: Tensor) -> Tensor:
        """Rearrange (B, C, H, W) into a time-major (H, B, C*W) sequence.

        The axes must be permuted before merging C and W: a plain
        ``view(b, c * w, h)`` only reinterprets the (C, H, W) memory layout
        and would mix values from different rows into one time step.
        """
        b, c, h, w = feat.shape
        return feat.permute(2, 0, 1, 3).reshape(h, b, c * w)

    def forward(self, x: Tensor) -> Tensor:
        """Return class scores of shape (B, num_classes)."""
        x = _handle_time_dim(x)              # (B,C,H,W)
        feat = self.cnn(x)                   # (B,64,Hc,Wc)
        seq = self._rows_to_sequence(feat)   # (Hc,B,64*Wc)

        out, _ = self.gru(seq)               # (Hc,B,Hid)
        last = out[-1]                       # (B,Hid)
        return self.fc(last)




class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping patches and embed each one.

    A single convolution whose kernel size and stride both equal
    ``patch_size`` maps every patch to an ``emb_size`` vector.

    Args:
        in_channels: Channel count of the input.
        emb_size: Length of each patch embedding.
        patch_size: ``(height, width)`` of one patch.
    """

    def __init__(
        self,
        in_channels: int,
        emb_size: int,
        patch_size: Tuple[int, int],
    ):
        super().__init__()
        self.patch_size = patch_size
        # kernel == stride, so neighbouring patches never overlap.
        self.proj = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, x: Tensor) -> Tensor:
        """Map (B, C, H, W) to a token sequence of shape (B, N, emb)."""
        grid = self.proj(x)                            # (B, emb, H', W')
        # Unpacking four sizes also rejects any projection that is not 4-D.
        _batch, _emb, _rows, _cols = grid.shape
        tokens = grid.flatten(start_dim=2)             # (B, emb, N)
        return tokens.transpose(1, 2)                  # (B, N, emb)


class TransformerEncoder(nn.Module):
    """Stack of ``depth`` identical ``nn.TransformerEncoderLayer`` blocks.

    Args:
        emb_size: Token embedding size (``d_model``).
        depth: Number of encoder layers.
        nhead: Number of attention heads per layer.
        dim_feedforward: Hidden size of each layer's feed-forward sub-block.
        dropout: Dropout probability used inside each layer.
    """

    def __init__(
        self,
        emb_size: int,
        depth: int = 4,
        nhead: int = 4,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        layer_config = {
            "d_model": emb_size,
            "nhead": nhead,
            "dim_feedforward": dim_feedforward,
            "dropout": dropout,
            # Inputs are laid out batch-first: (B, N, emb).
            "batch_first": True,
        }
        prototype = nn.TransformerEncoderLayer(**layer_config)
        self.encoder = nn.TransformerEncoder(prototype, num_layers=depth)

    def forward(self, x: Tensor) -> Tensor:
        """Encode a (B, N, emb) token sequence; the shape is preserved."""
        encoded = self.encoder(x)
        return encoded


class MmWaveAnn_ViT(nn.Module):
    """Small vision-transformer classifier.

    The input is split into patches and embedded, passed through a
    transformer encoder, averaged over the patch tokens and projected to
    class scores. No class token or positional embedding is used.

    Args:
        input_shape: ``(C, H, W)`` of one sample.
        num_classes: Size of the output layer.
        emb_size: Patch embedding size.
        depth: Number of encoder layers.
        nhead: Number of attention heads per layer.
        patch_size: ``(height, width)`` of one patch. When ``None``, each
            side is one tenth of the matching input side, but at least 4.
    """

    def __init__(
        self,
        input_shape: Tuple[int, int, int],
        num_classes: int,
        emb_size: int = 256,
        depth: int = 4,
        nhead: int = 4,
        patch_size: Tuple[int, int] = None,
    ):
        super().__init__()
        channels, height, width = input_shape

        if patch_size is None:
            # Default: about ten patches per side, never smaller than 4 pixels.
            patch_size = tuple(max(4, side // 10) for side in (height, width))

        self.patch_embed = PatchEmbedding(c_in := channels, emb_size, patch_size) if False else PatchEmbedding(channels, emb_size, patch_size)
        self.encoder = TransformerEncoder(emb_size, depth=depth, nhead=nhead)
        self.cls_head = nn.Linear(emb_size, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Return class scores of shape (B, num_classes)."""
        frames = _handle_time_dim(x)                 # (B, C, H, W)
        tokens = self.patch_embed(frames)            # (B, N, emb)
        encoded = self.encoder(tokens)               # (B, N, emb)
        pooled = encoded.mean(dim=1)                 # average-pool over the patch tokens
        return self.cls_head(pooled)


class MmWaveAnn_LeNet5(nn.Module):
    """Classic LeNet-5 layout: two Conv(5x5)-Tanh-MaxPool stages and a 3-layer head.

    The size of the first fully connected layer is derived from
    ``input_shape`` by running a dummy tensor through the convolutional
    stages, so inputs at run time must have the same ``(C, H, W)``.

    Unlike a plain classifier, ``forward`` returns both the logits and their
    softmax. It passes the input straight to the convolutions, so it expects
    a 4-D (B, C, H, W) tensor.

    Args:
        input_shape: ``(C, H, W)`` of one sample.
        num_classes: Size of the output layer.

    Raises:
        ValueError: If ``input_shape`` does not have exactly three entries.
    """

    def __init__(self, input_shape: Tuple[int, int, int], num_classes: int):
        super().__init__()

        # Explicit check instead of ``assert``: asserts are stripped when
        # Python runs with -O, and this is validation of caller input.
        if len(input_shape) != 3:
            raise ValueError(
                f"input_shape must be (C, H, W), got {tuple(input_shape)}"
            )
        in_channels, H, W = input_shape

        self.num_classes = num_classes
        self.in_channels = in_channels

        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 6, kernel_size=5),  # (C, H, W) -> (6, H-4, W-4)
            nn.Tanh(),
            nn.MaxPool2d(kernel_size=2),               # -> (6, (H-4)/2, (W-4)/2)
            nn.Conv2d(6, 16, kernel_size=5),           # -> (16, ...)
            nn.Tanh(),
            nn.MaxPool2d(kernel_size=2),               # -> (16, H', W')
        )

        # Probe the feature extractor once to learn the flattened size
        # 16 * H' * W' for this input shape.
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, H, W)
            flatten_dim = self.features(dummy).numel()

        self.classifier = nn.Sequential(
            nn.Linear(flatten_dim, 120),
            nn.Tanh(),
            nn.Linear(120, 84),
            nn.Tanh(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(logits, probas)``, each of shape (B, num_classes).

        ``probas`` is the softmax of ``logits`` over the class dimension.
        """
        x = self.features(x)
        x = torch.flatten(x, 1)
        logits = self.classifier(x)
        probas = F.softmax(logits, dim=1)
        return logits, probas

