import torchvision
import torch
import torch.nn as nn

class ImagenetModel(nn.Module):
    """Torchvision backbone whose classifier is swapped for a new linear head.

    The backbone is looked up by name in ``torchvision.models`` and built by
    calling the builder with no arguments (so the builder's own defaults
    apply). Its original classification layer is replaced by ``nn.Identity``
    and a new ``nn.Linear(feature_dim, n_outputs)`` is attached, optionally
    followed by a softmax or sigmoid.

    Args:
        n_outputs (int): Output size of the new linear head.
        output_activation (str): ``'softmax'`` (over dim 1) or ``'sigmoid'``;
            any other value (including the default ``'linear'``) applies no
            activation.
        backbone (str): Name of a builder in ``torchvision.models``. The name
            must contain one of ``'resnet'``, ``'densenet'``, ``'vit'`` or
            ``'swin'``; the first match in that order decides which attribute
            holds the classifier.

    Raises:
        ValueError: If ``backbone`` does not contain a supported family name.
        AttributeError: If ``torchvision.models`` has no builder of that name.
    """

    # Name fragments of the model families whose head layout is handled below.
    _SUPPORTED_FAMILIES = ('resnet', 'densenet', 'vit', 'swin')

    def __init__(self, n_outputs, output_activation='linear', backbone='resnet50'):
        super().__init__()

        # Validate the name *before* building anything, so an unsupported
        # architecture is rejected without paying for its construction.
        if not any(family in backbone for family in self._SUPPORTED_FAMILIES):
            raise ValueError(f'Unsupported backbone: {backbone}')

        self.backbone = getattr(torchvision.models, backbone)()

        # Record the width of the features feeding the original classifier,
        # then neutralise that classifier so the backbone emits raw features.
        if 'resnet' in backbone:
            feature_dim = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()
        elif 'densenet' in backbone:
            feature_dim = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Identity()
        elif 'vit' in backbone:
            feature_dim = self.backbone.heads.head.in_features
            self.backbone.heads = nn.Identity()
        else:
            # Only 'swin' can reach this branch thanks to the check above.
            feature_dim = self.backbone.head.in_features
            self.backbone.head = nn.Identity()

        self.output_layer = nn.Linear(feature_dim, n_outputs)
        if output_activation == 'softmax':
            self.output_activation = nn.Softmax(dim=1)
        elif output_activation == 'sigmoid':
            self.output_activation = nn.Sigmoid()
        else:
            self.output_activation = nn.Identity()

    def forward(self, x):
        """Run the backbone, flatten its features and apply the new head.

        Args:
            x (torch.Tensor): Input batch accepted by the chosen backbone.

        Returns:
            torch.Tensor: Head output after the configured activation.
        """
        features = self.backbone(x)
        # Flattening from dim 1 leaves an already 2-D (B, feature_dim) tensor
        # untouched, so it is applied unconditionally instead of checking the
        # backbone's class on every call.
        features = torch.flatten(features, 1)
        return self.output_activation(self.output_layer(features))


class MNISTModel(nn.Module):
    """Small 2D CNN with an MLP head for single-channel images.

    Three stride-2 3x3 convolutions (1 -> 16 -> 32 -> 64 channels), each
    followed by ReLU and ``Dropout2d(0.1)``, feed a flatten and three linear
    layers (64*4*4 -> 64 -> 32 -> n_outputs) with ReLU and ``Dropout(0.2)``
    between them. The first linear layer expects a 4x4 spatial map after the
    convolutions, which is what a 28x28 input produces (28 -> 14 -> 7 -> 4).

    Args:
        n_outputs (int): Size of the final linear layer.
        output_activation (str): ``'softmax'`` (over dim 1) or ``'sigmoid'``;
            any other value (including the default ``'linear'``) applies no
            activation.
    """

    def __init__(self, n_outputs, output_activation='linear'):
        super().__init__()

        stack = []

        # Convolutional stage: consecutive channel pairs, halving H and W each time.
        channel_plan = (1, 16, 32, 64)
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            stack.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2))
            stack.append(nn.ReLU())
            stack.append(nn.Dropout2d(0.1))

        stack.append(nn.Flatten())

        # Fully connected stage up to (but excluding) the output layer.
        width_plan = (64 * 4 * 4, 64, 32)
        for f_in, f_out in zip(width_plan[:-1], width_plan[1:]):
            stack.append(nn.Linear(f_in, f_out))
            stack.append(nn.ReLU())
            stack.append(nn.Dropout(0.2))

        stack.append(nn.Linear(width_plan[-1], n_outputs))

        # Final activation; anything unrecognised means "no activation".
        if output_activation == 'sigmoid':
            final_act = nn.Sigmoid()
        elif output_activation == 'softmax':
            final_act = nn.Softmax(dim=1)
        else:
            final_act = nn.Identity()
        stack.append(final_act)

        # Despite the name, this Sequential holds the whole network (convs + MLP).
        self.convlayers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the network and drop a trailing singleton dimension, if any.

        Args:
            x (torch.Tensor): Input batch fed to the first ``Conv2d`` (1 channel).

        Returns:
            torch.Tensor: Network output with ``squeeze(-1)`` applied, so a
            ``(B, 1)`` result becomes ``(B,)`` while wider outputs are unchanged.
        """
        prediction = self.convlayers(x)
        return prediction.squeeze(-1)



import torch
import torch.nn as nn

class MNIST3DModel(nn.Module):
    """Tiny 3D CNN with an MLP head for volumetric inputs."""

    def __init__(self, n_outputs, output_activation='linear', in_channels=1,
                 mlp_hidden=(64, 32), conv_dropout=0.1, mlp_dropout=0.2,
                 pool_size=(4, 4, 4)):
        """
        Tiny 3D CNN for volumes like OrganMNIST3D (default 28x28x28).

        Three stride-2 3x3x3 convolutions (in_channels -> 16 -> 32 -> 64), each
        followed by ReLU and Dropout3d, are adaptively average-pooled to
        ``pool_size`` and flattened (``self.convlayers``). An MLP
        (``self.head``) maps the flattened features to ``n_outputs``.

        Args:
            n_outputs (int): output dim (e.g., 1 for Cox log-risk, K for classes/bins).
            output_activation (str): 'linear' | 'sigmoid' | 'softmax'. Any value
                other than 'sigmoid' or 'softmax' applies no activation.
            in_channels (int): 1 for grayscale volumes (default), or 3 if you replicated.
            mlp_hidden (sequence of int): hidden sizes of the FC layers before
                the output layer. Any length is accepted (the default gives
                two hidden layers); an empty sequence yields a single linear
                output layer.
            conv_dropout (float): dropout prob after conv blocks (3D).
            mlp_dropout (float): dropout prob in the MLP.
            pool_size (int or tuple): adaptive pool target (D, H, W) before
                flatten, defaults to (4, 4, 4). An int is used for all three
                dimensions.

        Raises:
            ValueError: If ``pool_size`` is a sequence whose length is not 3.
        """
        super().__init__()

        if output_activation == 'softmax':
            output_act = nn.Softmax(dim=1)
        elif output_activation == 'sigmoid':
            output_act = nn.Sigmoid()
        else:
            output_act = nn.Identity()

        # nn.AdaptiveAvgPool3d accepts a bare int, so accept one here too and
        # expand it; the flattened size below needs all three extents.
        if isinstance(pool_size, int):
            pool_dims = (pool_size, pool_size, pool_size)
        else:
            pool_dims = tuple(pool_size)
            if len(pool_dims) != 3:
                raise ValueError(
                    f'pool_size must be an int or a sequence of 3 ints, got {pool_size!r}'
                )

        conv_channels = (in_channels, 16, 32, 64)

        feature_layers = []
        for c_in, c_out in zip(conv_channels[:-1], conv_channels[1:]):
            feature_layers += [
                nn.Conv3d(c_in, c_out, kernel_size=3, padding=1, stride=2),
                nn.ReLU(inplace=True),
                nn.Dropout3d(p=conv_dropout),
            ]
        # Make the flattened size stable across inputs (e.g., 28→14→7→4 each dim)
        feature_layers += [
            nn.AdaptiveAvgPool3d(output_size=pool_dims),
            nn.Flatten(),
        ]
        self.convlayers = nn.Sequential(*feature_layers)

        flat_dim = conv_channels[-1] * pool_dims[0] * pool_dims[1] * pool_dims[2]

        # Build Linear -> ReLU -> Dropout for every hidden size, then the
        # output layer. With the default two hidden sizes this gives the same
        # layer indices (0..7) as a hand-written two-hidden-layer head.
        head_layers = []
        prev_dim = flat_dim
        for hidden_dim in tuple(mlp_hidden):
            head_layers += [
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(p=mlp_dropout),
            ]
            prev_dim = hidden_dim
        head_layers += [nn.Linear(prev_dim, n_outputs), output_act]
        self.head = nn.Sequential(*head_layers)

        # Kaiming init for conv layers (weights only; biases keep their defaults)
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """Extract pooled conv features and map them through the MLP head.

        Args:
            x (torch.Tensor): input of shape [B, C, D, H, W].

        Returns:
            torch.Tensor: output of shape [B, n_outputs].
        """
        # x: [B, C, D, H, W]
        x = self.convlayers(x)
        x = self.head(x)
        return x  # [B, n_outputs]
