import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from .layers import AttentionPooling, WeightedAveragePooling, MaxPooling

class PostPoolingDetector(nn.Module):
    """Classification + box-regression head applied to pooled backbone tokens.

    The backbone's ``forward_features`` output is reduced to one vector per
    sample using the configured pooling mode, and that vector is fed to two
    independent linear heads: one producing class logits and one producing
    four box values squashed by a sigmoid.
    """

    # Pooling modes understood by ``forward``; checked at construction time so
    # that a misspelled mode fails immediately instead of on the first batch.
    _SUPPORTED_POOLING = ('cls', 'avg', 'sum', 'attention', 'weighted_avg', 'max')
    _UNSUPPORTED_POOLING_MSG = (
        "Unsupported pooling_type. Choose from 'cls', 'avg', 'sum', "
        "'attention', 'weighted_avg', or 'max'."
    )

    def __init__(self, base_model: nn.Module, num_classes: int,
                 pooling_type: str = 'avg', num_iterations: int = 10):
        """
        Args:
            base_model (nn.Module): Pre-trained ViT model that outputs token features.
                Must expose ``embed_dim`` and ``forward_features``; for
                ``pooling_type='weighted_avg'`` it must also expose
                ``patch_embed.num_patches``.
            num_classes (int): Number of object classes (20 for Pascal VOC).
            pooling_type (str): One of 'cls', 'avg', 'sum', 'attention', 'weighted_avg', 'max'.
                Matched case-insensitively.
            num_iterations (int): Stored on the instance and printed at
                construction; it is not used by ``forward`` in this class.

        Raises:
            ValueError: If ``pooling_type`` is not one of the supported modes.
        """
        super().__init__()
        self.base_model = base_model
        self.pooling_type = pooling_type.lower()
        if self.pooling_type not in self._SUPPORTED_POOLING:
            # Fail fast: previously an unknown mode was only detected in forward().
            raise ValueError(self._UNSUPPORTED_POOLING_MSG)

        self.classifier = nn.Linear(base_model.embed_dim, num_classes)
        self.bbox_regressor = nn.Linear(base_model.embed_dim, 4)
        self.num_iterations = num_iterations
        print("Number of iterations is {}" .format(self.num_iterations))

        # Only the learnable / module-based pooling modes need extra state.
        if self.pooling_type == "attention":
            self.attention_pooling = AttentionPooling(d_model=base_model.embed_dim)
        elif self.pooling_type == "weighted_avg":
            # NOTE(review): the +1 presumably accounts for a single class token
            # prepended to the patch tokens -- confirm against the backbone.
            n_tokens = base_model.patch_embed.num_patches + 1
            self.weighted_avg_pooling = WeightedAveragePooling(n_tokens, dtype=torch.float32)
        elif self.pooling_type == "max":
            self.max_pooling = MaxPooling()

    def forward(self, x):
        """Run the backbone, pool its tokens, and apply both heads.

        Args:
            x: Input passed unchanged to ``base_model.forward_features``.
                The result is indexed as a 3-D tensor
                (assumed ``(batch, tokens, embed_dim)`` -- confirm for the
                backbone in use).

        Returns:
            tuple: ``(cls_logits, bbox_preds)`` where ``cls_logits`` is the raw
            output of the classification head and ``bbox_preds`` is the
            4-value regression output after a sigmoid. The coordinate
            convention of the four values is not defined in this class.

        Raises:
            ValueError: If ``self.pooling_type`` holds an unsupported mode
                (only possible if it was changed after construction).
        """
        tokens = self.base_model.forward_features(x)
        if self.pooling_type == 'cls':
            # Token 0 alone represents the sample.
            pooled = tokens[:, 0, :]
        elif self.pooling_type == 'avg':
            # 'avg', 'sum' and 'max' skip token 0 and reduce over the rest.
            pooled = tokens[:, 1:, :].mean(dim=1)
        elif self.pooling_type == 'sum':
            pooled = tokens[:, 1:, :].sum(dim=1)
        elif self.pooling_type == "max":
            pooled = self.max_pooling(tokens[:, 1:, :])
        elif self.pooling_type == "attention":
            # 'attention' and 'weighted_avg' receive every token, token 0 included.
            pooled = self.attention_pooling(tokens)
        elif self.pooling_type == "weighted_avg":
            pooled = self.weighted_avg_pooling(tokens)
        else:
            raise ValueError(self._UNSUPPORTED_POOLING_MSG)

        cls_logits = self.classifier(pooled)
        bbox_preds = self.bbox_regressor(pooled)
        bbox_preds = torch.sigmoid(bbox_preds)
        return cls_logits, bbox_preds