import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
from .nn import (conv_nd, linear, normalization, timestep_embedding,
                 torch_checkpoint, zero_module)


class ResNetFaceEmbedding(nn.Module):
    """Map an image batch to a fixed-size embedding using a torchvision ResNet trunk.

    The selected ResNet is instantiated with torchvision's default arguments
    (no weights are requested here), its last two children (global average
    pool and ``fc`` classifier) are dropped so the trunk yields a spatial
    feature map, and a small head reduces that map to ``out_channels`` values
    per sample.

    Args:
        backbone: Name of the torchvision ResNet constructor to use. Must be
            one of ``_SUPPORTED_BACKBONES``.
        out_channels: Size of the embedding produced for each sample.
        normalization: Callable invoked as ``normalization(ch)`` that returns
            the first layer of the head. Defaults to the project's
            ``normalization`` helper (presumably a shape-preserving norm
            layer -- confirm against ``.nn``).

    Raises:
        ValueError: If ``backbone`` is not a supported name.
    """

    # All of these torchvision ResNets end with (avgpool, fc), which is what
    # the ``[:-2]`` truncation below relies on.
    _SUPPORTED_BACKBONES = ('resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152')

    def __init__(self, backbone: str = 'resnet18', out_channels: int = 256,
                 normalization=normalization):
        super().__init__()

        # Validate before building anything so a bad name fails with a
        # message that says what was passed and what is accepted.
        if backbone not in self._SUPPORTED_BACKBONES:
            raise ValueError(
                f"Unsupported backbone {backbone!r}; "
                f"expected one of: {', '.join(self._SUPPORTED_BACKBONES)}"
            )

        resnet = getattr(models, backbone)()
        # Channel count of the final feature map equals the classifier's
        # input width (e.g. 512 for resnet18, 2048 for resnet50).
        ch = resnet.fc.in_features

        # Drop the trailing avgpool and fc so the spatial feature map is kept.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        # Head: normalize -> activate -> pool to 1x1 -> 1x1 conv projection
        # -> flatten. Layer order is part of the state_dict layout; keep it.
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(ch, out_channels, kernel_size=1),
            nn.Flatten()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute embeddings for a batch of images.

        Args:
            x: Image batch fed directly to the ResNet trunk (assumed to be
                ``[B, 3, H, W]`` as torchvision ResNets expect -- confirm
                against the caller).

        Returns:
            Tensor of shape ``[B, out_channels]``.
        """
        x = self.backbone(x)  # feature map: [B, C, H', W']
        x = self.out(x)       # pooled + projected + flattened: [B, out_channels]
        return x
