import torch
import torch.nn as nn
from torchvision import models
from timm import create_model  

# Normalization statistics named for ImageNet, one entry per channel (three
# values each). Presumably in R, G, B order for inputs scaled to [0, 1] —
# confirm against the data pipeline that consumes them.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

def vit_norm_stats():
    """Return the module's ImageNet normalization statistics.

    Returns:
        A ``(mean, std)`` pair made of ``IMAGENET_MEAN`` and ``IMAGENET_STD``.
    """
    stats = (IMAGENET_MEAN, IMAGENET_STD)
    return stats

class ViTForFL(nn.Module):
    """Vision Transformer backbone followed by dropout and a linear classifier.

    The backbone is built either through timm (``vit_tiny``, ``vit_small``,
    ``vit_b_16``) or torchvision (``vit_b_32``). In both cases the backbone's
    own classification head is disabled (``num_classes=0`` for timm, an
    ``nn.Identity`` head for torchvision) and a new ``nn.Linear`` classifier
    is attached on top of the backbone output.

    Args:
        name: Variant to build. One of ``"vit_tiny"``, ``"vit_small"``,
            ``"vit_b_16"`` or ``"vit_b_32"``.
        num_classes: Output size of the linear classifier.
        pretrained: Whether to load pretrained weights into the backbone
            (``IMAGENET1K_V1`` weights for the torchvision variant).
        freeze_patch_and_pos: If true, sets ``requires_grad = False`` on the
            backbone parameters whose names identify the patch projection and
            the positional embedding.
        dropout: Dropout probability applied before the classifier; values
            ``<= 0`` disable that layer. It is also forwarded to timm
            backbones as ``drop_rate``; the torchvision backbone does not
            receive it.

    Raises:
        ValueError: If ``name`` is not one of the supported variants.
    """

    # Public variant name -> timm model identifier.
    _TIMM_VARIANTS = {
        "vit_tiny": "vit_tiny_patch16_224",
        "vit_small": "vit_small_patch16_224",
        "vit_b_16": "vit_base_patch16_224",
    }

    # Parameter-name substrings selected by ``freeze_patch_and_pos``; the two
    # libraries name the patch projection / positional embedding differently.
    _TIMM_FROZEN_KEYS = ("patch_embed", "pos_embed")
    _TORCHVISION_FROZEN_KEYS = ("conv_proj", "encoder.pos_embedding")

    def __init__(self, name="vit_tiny", num_classes=10, pretrained=True,
                 freeze_patch_and_pos=False, dropout=0.0):
        super().__init__()

        # Only look up string names so that an unhashable ``name`` still ends
        # in the ValueError below instead of a TypeError from the dict lookup.
        timm_id = self._TIMM_VARIANTS.get(name) if isinstance(name, str) else None

        if timm_id is not None:
            # num_classes=0 asks timm for a head-less model; the classifier
            # below is sized from ``num_features``.
            # NOTE(review): depending on the timm version, ``drop_rate`` may
            # already apply dropout to the returned features, in which case
            # ``self.dropout`` applies it a second time — confirm.
            self.backbone = create_model(
                timm_id,
                pretrained=pretrained,
                num_classes=0,
                drop_rate=dropout,
            )
            in_dim = self.backbone.num_features
            frozen_keys = self._TIMM_FROZEN_KEYS
        elif name == "vit_b_32":
            weights = models.ViT_B_32_Weights.IMAGENET1K_V1 if pretrained else None
            self.backbone = models.vit_b_32(weights=weights)
            # Read the head's input width before replacing it with Identity.
            in_dim = self.backbone.heads.head.in_features
            self.backbone.heads.head = nn.Identity()
            frozen_keys = self._TORCHVISION_FROZEN_KEYS
        else:
            raise ValueError(f"Unsupported ViT variant: {name}")

        # Registration order (backbone, dropout, classifier) is kept stable so
        # the state_dict layout does not change.
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.classifier = nn.Linear(in_dim, num_classes)

        if freeze_patch_and_pos:
            for param_name, param in self.backbone.named_parameters():
                if any(key in param_name for key in frozen_keys):
                    param.requires_grad = False

    def forward(self, x):
        """Run the backbone, then dropout, then the classifier.

        The same path is used for timm and torchvision backbones.

        Args:
            x: Input passed unchanged to the backbone (presumably an image
                batch of the size the chosen variant expects — confirm).

        Returns:
            The classifier output computed from the backbone features.
        """
        features = self.backbone(x)
        return self.classifier(self.dropout(features))
