from typing import List
import torch
import torch.nn as nn
from .swinv2_model import SwinTransformerV2

class ClassificationModelWrapper(nn.Module):
    """
    Image classifier built on top of a Swin Transformer V2 backbone.

    The backbone's last feature map is reduced by global average pooling to one vector
    per sample. A single linear layer then maps that vector to unnormalized class scores
    (no softmax is applied).
    """

    def __init__(self, model: SwinTransformerV2, number_of_classes: int = 10, output_channels: int = 768) -> None:
        """
        Constructor method
        :param model: (SwinTransformerV2) Backbone whose forward pass returns a list of feature maps
        :param number_of_classes: (int) Number of class scores produced per sample
        :param output_channels: (int) Channel count of the backbone's last feature map (input width of the linear head)
        """
        super().__init__()
        # Sub-modules are registered as backbone -> pooling -> head; this order determines
        # the ordering of state_dict() keys and named_modules()
        self.model: SwinTransformerV2 = model
        # Collapses the spatial dimensions of the last feature map down to 1x1
        self.pooling: nn.Module = nn.AdaptiveAvgPool2d(output_size=1)
        # Linear projection from pooled channels to class scores
        self.classification_head: nn.Module = nn.Linear(output_channels, number_of_classes)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Forward pass
        :param input: (torch.Tensor) Image batch of the shape [batch size, channels, height, width]
        :return: (torch.Tensor) Unnormalized class scores of the shape [batch size, number of classes]
        """
        feature_maps: List[torch.Tensor] = self.model(input)
        # Only the deepest (last) feature map contributes to the prediction
        pooled: torch.Tensor = self.pooling(feature_maps[-1])
        # [batch, C, 1, 1] -> [batch, C]; presumes a 4D feature map with C == output_channels
        embedding: torch.Tensor = torch.flatten(pooled, start_dim=1)
        return self.classification_head(embedding)