import torch
import torch.nn as nn
import torchvision.models as models

class ResNet18(nn.Module):
    """ResNet-18 backbone followed by a latent projection and a classifier.

    The torchvision ResNet-18 is used without its final fully connected
    layer. Its pooled output is flattened, projected to ``latent_dim`` by
    ``feature_layer`` and finally mapped to class logits by ``classifier``.

    Args:
        latent_dim: Size of the feature vector produced by ``feature_layer``.
        num_classes: Number of logits produced by ``classifier``.
        input_channels: Number of channels of the input images. When this is
            not 3 the first convolution is replaced by a freshly (randomly)
            initialized one, so that layer carries no pretrained weights.
        pretrained: If true, the backbone starts from torchvision's ImageNet
            weights; otherwise it is randomly initialized.
    """

    def __init__(
        self,
        latent_dim: int = 512,
        num_classes: int = 10,
        input_channels: int = 3,
        pretrained: bool = False,
    ) -> None:
        super().__init__()

        resnet = self._load_backbone(pretrained)

        # The stock first convolution expects 3 channels; swap it out for
        # any other channel count.
        if input_channels != 3:
            resnet.conv1 = self._conv_with_in_channels(resnet.conv1, input_channels)

        # Encoder: every child of the backbone except the final FC layer.
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])

        # Projection from the backbone feature size to the latent dimension.
        self.feature_layer = nn.Linear(resnet.fc.in_features, latent_dim)

        # Linear classification head on top of the latent features.
        self.classifier = nn.Linear(latent_dim, num_classes)

    @staticmethod
    def _load_backbone(pretrained: bool) -> nn.Module:
        """Build a torchvision ResNet-18 without triggering deprecation warnings."""
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:
            # Older torchvision releases only understand the boolean flag.
            return models.resnet18(pretrained=pretrained)
        # ``pretrained=True`` historically meant the IMAGENET1K_V1 weights.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return models.resnet18(weights=weights)

    @staticmethod
    def _conv_with_in_channels(conv: nn.Conv2d, in_channels: int) -> nn.Conv2d:
        """Return a new, randomly initialized conv mirroring ``conv`` for ``in_channels`` inputs."""
        return nn.Conv2d(
            in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            padding_mode=conv.padding_mode,
            # ``bias`` must be a bool; ``conv.bias`` is a Parameter or None.
            bias=conv.bias is not None,
        )

    def forward(self, x: torch.Tensor, return_features: bool = False) -> torch.Tensor:
        """Run the network on a batch of images.

        Args:
            x: Input batch; its channel count must match the
                ``input_channels`` the model was built with.
            return_features: If true, return the latent features instead of
                the classification logits.

        Returns:
            The output of ``feature_layer`` when ``return_features`` is true,
            otherwise the output of ``classifier``.
        """
        # Flatten everything after the batch dimension before projecting.
        pooled = torch.flatten(self.encoder(x), 1)
        features = self.feature_layer(pooled)

        if return_features:
            return features

        return self.classifier(features)

def main():
    """Demonstrate classification, feature extraction and custom input channels.

    The models are put in evaluation mode and run under ``torch.no_grad()``
    so the demo neither updates BatchNorm running statistics nor builds
    autograd graphs for the forward passes.
    """
    # --- Example 1: Standard classification ---
    print("--- Example 1: Standard classification ---")
    # Model with 3 input channels (like RGB images) and 100 classes
    model_rgb = ResNet18(num_classes=100, input_channels=3, pretrained=True)
    model_rgb.eval()
    input_rgb = torch.randn(16, 3, 224, 224)

    with torch.no_grad():
        # Get classification logits
        logits = model_rgb(input_rgb)
        print(f"Shape of the output logits for RGB input: {logits.shape}")  # Expected: (16, 100)

        # --- Example 2: Feature extraction ---
        print("\n--- Example 2: Feature extraction ---")
        # Get the feature vector
        features = model_rgb(input_rgb, return_features=True)
        print(f"Shape of the feature vector: {features.shape}")  # Expected: (16, 512)

    # --- Example 3: Custom input channels ---
    print("\n--- Example 3: Custom input channels ---")
    # Model for grayscale images (1 input channel)
    # Note: pretrained=True will still load weights for all layers except the first one.
    model_grayscale = ResNet18(num_classes=10, input_channels=1, pretrained=True)
    model_grayscale.eval()
    input_grayscale = torch.randn(16, 1, 224, 224)
    with torch.no_grad():
        logits_grayscale = model_grayscale(input_grayscale)
    print(f"Shape of output logits for grayscale input: {logits_grayscale.shape}")  # Expected: (16, 10)


if __name__ == '__main__':
    main()