import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoderLayer, TransformerEncoder

class ConvPreEncoder(nn.Module):
    """
    Small convolutional front-end that resizes its output to a fixed size.

    Two 3x3 convolutions (stride 1, padding 1, so the spatial size is kept by
    the convs themselves), each followed by an in-place ReLU, and then a
    bilinear resize to ``output_dim``.

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by the second convolution.
        output_dim: target ``(height, width)`` passed to ``F.interpolate``.
    """
    def __init__(self, in_channels=3, out_channels=3, output_dim=(224, 224)):
        super().__init__()
        hidden_channels = 16
        stages = [
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        ]
        # Stage positions 0..3 are preserved so state_dict keys remain
        # conv_block.0.* and conv_block.2.* (checkpoint compatibility).
        self.conv_block = nn.Sequential(*stages)
        self.output_dim = output_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(review): a 2-element `size` for F.interpolate implies a 4-D
        # (batch, channels, height, width) input — confirm against callers.
        features = self.conv_block(x)
        resized = F.interpolate(features, size=self.output_dim, mode='bilinear', align_corners=False)
        return resized

class LinearHead(nn.Module):
    """
    Classification head: a single linear layer applied to the first sequence position.

    Args:
        input_dim: size of the last dimension of the incoming tensor.
        num_classes: number of output logits.
    """
    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        # Attribute name fixes the state_dict keys (output_proj.weight / .bias).
        self.output_proj = nn.Linear(input_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # assumes x is (batch, seq_len, input_dim); index 0 along dim 1 is
        # treated as the [CLS]-like summary token.
        leading_token = x.select(1, 0)
        return self.output_proj(leading_token)

class SimpleEncoder(nn.Module):
    """
    Transformer encoder followed by a linear classifier on the first token.

    Defined at module level (not inside ``get_encoder``) so that instances are
    picklable — e.g. ``torch.save(model)`` or sending the model to worker
    processes — and so every model built by ``get_encoder`` shares one class.
    pickle cannot locate a class that is local to a function.

    The attribute names ``encoder`` and ``output_proj`` determine the
    ``state_dict`` keys and are intentionally kept stable.
    """

    def __init__(self, encoder: nn.Module, output_proj: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.output_proj = output_proj

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # get_encoder builds the encoder with batch_first=True, so x is expected
        # as (batch, seq_len, d_model); position 0 serves as the pooled token.
        x = self.encoder(x)
        return self.output_proj(x[:, 0])


def get_encoder(
    size: str,
    image_dim: int,
    num_classes: int,
    label_elmes: bool,
    orig: bool,
    *,
    nhead: int = 8,
    num_layers: int = 1,
) -> nn.Module:
    """
    Return a TransformerEncoder + linear projection classifier.

    Args:
        size: unused; retained for signature parity with the caller.
        image_dim: model width (``d_model``) of the encoder. PyTorch's
            multi-head attention splits it across heads, so it must be
            divisible by ``nhead``.
        num_classes: number of output logits.
        label_elmes: unused; retained for signature parity with the caller.
        orig: unused; retained for signature parity with the caller.
        nhead: attention heads per encoder layer (default 8, the previously
            hard-coded value).
        num_layers: number of stacked encoder layers (default 1, the
            previously hard-coded value).

    Returns:
        A ``SimpleEncoder`` mapping batch-first ``(batch, seq_len, image_dim)``
        input to ``(batch, num_classes)`` logits taken from the first token.
    """
    d_model = image_dim
    encoder_layer = TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
    # TransformerEncoder deep-copies encoder_layer num_layers times.
    encoder = TransformerEncoder(encoder_layer, num_layers=num_layers)
    output_proj = nn.Linear(d_model, num_classes)
    return SimpleEncoder(encoder, output_proj)
