import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from .layers import AttentionPooling, WeightedAveragePooling, MaxPooling


class InpaintingViT(nn.Module):
    """Predict a masked image patch from pooled ViT token features.

    The input is encoded with ``base_model.forward_features``. The resulting
    token sequence is pooled into a single vector according to
    ``pooling_strategy``, and that vector is fed through ``self.decoder``.
    """

    # Pooling strategies understood by ``forward``. They are kept in one place
    # so the constructor check and the error message cannot drift out of sync.
    _POOLING_STRATEGIES = ("cls", "avg", "sum", "attention", "weighted_avg", "max")

    def __init__(self, base_model, pooling_strategy='avg', mask_size=32, num_iterations=15):
        """
        Args:
            base_model: a ViT model (from timm).
                - It must expose ``embed_dim`` and ``forward_features``.
                - ``patch_embed.num_patches`` is also read for the
                  "weighted_avg" strategy.
                - If it has a ``num_prefix_tokens`` attribute, that many
                  leading tokens are treated as non-patch tokens. Otherwise a
                  single leading CLS token is assumed.
            pooling_strategy: one of "cls", "avg", "sum", "attention",
                "weighted_avg", "max" (case-insensitive).
            mask_size: size of the square patch to reconstruct. It is only
                stored here; the default decoder does not depend on it.
            num_iterations: stored as ``self.num_iterations`` and printed. It
                is not used by this class's own methods.

        Raises:
            ValueError: if ``pooling_strategy`` is not a supported strategy.
        """
        super().__init__()
        self.vit = base_model
        hidden_size = base_model.embed_dim
        self.num_iterations = num_iterations
        print(f'Num Iterations is: {self.num_iterations}')
        self.pooling_strategy = pooling_strategy.lower()
        self.mask_size = mask_size

        # Fail at construction time rather than on the first forward pass.
        if self.pooling_strategy not in self._POOLING_STRATEGIES:
            raise self._invalid_strategy_error(self.pooling_strategy)

        if self.pooling_strategy == "attention":
            self.attention_pooling = AttentionPooling(d_model=hidden_size)
        elif self.pooling_strategy == "weighted_avg":
            # Sized to the full sequence returned by forward_features (patch
            # tokens + prefix tokens). This equals num_patches + 1 for a
            # standard single-CLS ViT. NOTE(review): presumably the layer's
            # first argument is the sequence length; confirm against
            # layers.WeightedAveragePooling.
            seq_len = base_model.patch_embed.num_patches + self._num_prefix_tokens(base_model)
            self.weighted_avg_pooling = WeightedAveragePooling(seq_len, dtype=torch.float32)
        elif self.pooling_strategy == "max":
            self.max_pooling = MaxPooling()

        # Decoder: will be modified externally. The default maps
        # hidden_size -> hidden_size. A replacement is expected to output the
        # flattened patch of size 3 * mask_size * mask_size.
        self.decoder = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        )

    @staticmethod
    def _num_prefix_tokens(vit):
        """Return the number of leading non-patch tokens of ``vit`` (default 1)."""
        return getattr(vit, 'num_prefix_tokens', 1)

    @classmethod
    def _invalid_strategy_error(cls, strategy):
        """Build the ValueError raised for an unsupported pooling strategy."""
        choices = ", ".join(repr(name) for name in cls._POOLING_STRATEGIES)
        return ValueError(f"Invalid pooling strategy {strategy!r}. Choose from {choices}.")

    def forward(self, x):
        """
        x: masked image batch as accepted by ``self.vit.forward_features``,
           e.g. shape (B, 3, 224, 224).
        Returns: ``self.decoder(pooled)``.
           - For the cls/avg/sum strategies, ``pooled`` is (B, hidden_size).
           - With the default decoder, the output is therefore also
             (B, hidden_size).
           - Once the decoder has been replaced externally, the output is
             the predicted patch flattened to (B, 3*mask_size*mask_size).

        Raises:
            ValueError: if ``self.pooling_strategy`` is not supported.
        """
        last_hidden = self.vit.forward_features(x)
        strategy = self.pooling_strategy
        # Token-wise reductions skip the leading prefix tokens. For a
        # standard ViT this is a single CLS token.
        patch_tokens = last_hidden[:, self._num_prefix_tokens(self.vit):, :]
        if strategy == 'cls':
            pooled = last_hidden[:, 0, :]
        elif strategy == 'avg':
            pooled = patch_tokens.mean(dim=1)
        elif strategy == 'sum':
            pooled = patch_tokens.sum(dim=1)
        elif strategy == "max":
            pooled = self.max_pooling(patch_tokens)
        elif strategy == "attention":
            # The learned poolers receive the full sequence, including the
            # prefix tokens.
            pooled = self.attention_pooling(last_hidden)
        elif strategy == "weighted_avg":
            pooled = self.weighted_avg_pooling(last_hidden)
        else:
            raise self._invalid_strategy_error(strategy)
        return self.decoder(pooled)



if __name__ == "__main__":
    # No standalone entry point: running this file directly does nothing.
    ...