import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from .layers import AttentionPooling, WeightedAveragePooling

class PostPoolingClassifier(nn.Module):
    """Linear classification head applied to pooled ViT backbone tokens."""

    # Pooling strategies understood by ``_pool``.
    _POOLING_TYPES = ("cls", "avg", "sum", "attention", "weighted_avg", "max")
    _POOLING_ERROR = (
        "Unsupported pooling_type. Choose from 'cls', 'avg', 'sum', "
        "'attention', 'weighted_avg', or 'max'."
    )

    def __init__(self, base_model, num_classes, pooling_type='avg', num_iterations=10):
        """
        Classification model based on a ViT backbone.

        Args:
            base_model: a ViT model (e.g. from timm). Must expose ``embed_dim``
                and ``forward_features``; ``patch_embed.num_patches`` is also
                read for ``"weighted_avg"`` pooling. If it exposes
                ``num_prefix_tokens`` (as timm ViTs do), that many leading
                tokens are treated as non-patch tokens and excluded from
                "avg"/"sum"/"max" pooling; otherwise a single prefix (class)
                token is assumed.
            num_classes: number of logits produced by the linear classifier.
            pooling_type: one of "cls", "avg", "sum", "attention",
                "weighted_avg", "max".
            num_iterations: stored as ``self.num_iterations``; not used by
                this class itself.

        Raises:
            ValueError: if ``pooling_type`` is not a supported strategy.
        """
        super().__init__()
        # Fail at construction instead of on the first forward pass.
        if pooling_type not in self._POOLING_TYPES:
            raise ValueError(self._POOLING_ERROR)
        self.base_model = base_model
        self.pooling_type = pooling_type
        self.classifier = nn.Linear(base_model.embed_dim, num_classes)
        self.num_iterations = num_iterations
        # Number of leading non-patch tokens (class / distillation / register
        # tokens). Falls back to 1 to keep the historical single-CLS behaviour
        # for backbones that do not report it.
        self.num_prefix_tokens = getattr(base_model, "num_prefix_tokens", 1)
        if self.pooling_type == "attention":
            self.attention_pooling = AttentionPooling(d_model=base_model.embed_dim)
        elif self.pooling_type == "weighted_avg":
            # One weight per token in the full sequence (prefix + patches).
            n_tokens = base_model.patch_embed.num_patches + self.num_prefix_tokens
            self.weighted_avg_pooling = WeightedAveragePooling(n_tokens, dtype=torch.float32)

    def forward(self, x):
        """Run the backbone, pool its token sequence, and return class logits."""
        tokens = self.base_model.forward_features(x)
        return self.classifier(self._pool(tokens))

    def _pool(self, tokens):
        """Reduce the backbone token sequence to one vector per sample.

        Raises:
            ValueError: if ``self.pooling_type`` is not a supported strategy.
        """
        # Assumes tokens is (batch, seq, embed_dim) with prefix tokens first
        # — TODO confirm for non-timm backbones. getattr tolerates instances
        # pickled before ``num_prefix_tokens`` was stored.
        n_prefix = getattr(self, "num_prefix_tokens", 1)
        if self.pooling_type == 'cls':
            # NOTE(review): index 0 is only the class token if the backbone has one.
            return tokens[:, 0, :]
        if self.pooling_type == 'avg':
            return tokens[:, n_prefix:, :].mean(dim=1)
        if self.pooling_type == 'sum':
            return tokens[:, n_prefix:, :].sum(dim=1)
        if self.pooling_type == "attention":
            return self.attention_pooling(tokens)
        if self.pooling_type == 'weighted_avg':
            return self.weighted_avg_pooling(tokens)
        if self.pooling_type == "max":
            pooled, _ = tokens[:, n_prefix:, :].max(dim=1)
            return pooled
        raise ValueError(self._POOLING_ERROR)
    

if __name__ == "__main__":
    # This module only defines PostPoolingClassifier; it has no standalone
    # entry point, so running it directly does nothing.
    ...