import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import timm
from config import Config

class MaskedResNet(nn.Module):
    """ResNet wrapper that fills masked pixels before the forward pass.

    1. Neighborhood padding at the input: every masked-out pixel is replaced by
       the mean of the kept pixels in a ``k x k`` window around it.
    2. Layer-wise masking (simplified as propagating the padded input through
       the unmodified backbone).

    The output of the backbone submodule named ``Config.CNN_LAYER`` is captured
    with a forward hook and returned alongside the logits.
    """

    def __init__(self, model_name='resnet50'):
        """Build the pretrained backbone and hook the configured layer.

        Args:
            model_name: name of a ``torchvision.models`` constructor such as
                ``'resnet50'`` or ``'resnet18'``. The default builds resnet50.
        """
        super().__init__()

        # Honor ``model_name`` rather than always constructing resnet50.
        self.backbone = getattr(models, model_name)(pretrained=True)
        self.backbone.eval()

        # Layer name -> output captured by the hook during the latest forward.
        self.features = {}

        self.layer_name = Config.CNN_LAYER
        getattr(self.backbone, self.layer_name).register_forward_hook(
            self._get_hook(self.layer_name)
        )

    def _get_hook(self, name):
        """Return a forward hook that stores the module output under ``name``."""
        def hook(model, input, output):
            self.features[name] = output
        return hook

    def neighborhood_padding(self, x, mask, kernel_size=None):
        """Replace masked-out pixels with the mean of their kept neighbors.

        Args:
            x: image batch, presumably ``[B, C, H, W]``.
            mask: keep-mask broadcastable against ``x`` (e.g. ``[B, 1, H, W]``).
                Non-zero/True marks a kept pixel and 0/False a pixel to fill.
                Bool, integer and floating masks are accepted.
            kernel_size: odd window size. Defaults to
                ``Config.PADDING_KERNEL_SIZE``.

        Returns:
            A tensor shaped like ``x``. Kept pixels are unchanged. Each masked
            pixel becomes the average of the kept pixels in its window, or 0 when
            the window contains no kept pixel.

        Raises:
            ValueError: if ``kernel_size`` is not a positive odd integer. An even
                window changes the pooled spatial size and cannot be blended
                back into ``x``.
        """
        k = Config.PADDING_KERNEL_SIZE if kernel_size is None else kernel_size
        if k < 1 or k % 2 == 0:
            raise ValueError(f"kernel_size must be a positive odd int, got {k!r}")
        pad = k // 2

        # Cast to x's dtype so bool/int masks support arithmetic (``1 - mask``
        # raises for bool tensors) and the output keeps x's dtype.
        mask_f = mask.to(x.dtype)
        x_masked = x * mask_f

        # With count_include_pad=False both maps are divided by the same number
        # of in-bounds pixels per window, so that factor cancels in sum / count.
        window_sum = F.avg_pool2d(
            x_masked, kernel_size=k, stride=1, padding=pad, count_include_pad=False
        ) * (k * k)
        window_count = F.avg_pool2d(
            mask_f, kernel_size=k, stride=1, padding=pad, count_include_pad=False
        ) * (k * k)

        # Windows with no kept pixel divide by 1, so they fill with 0 rather than NaN.
        window_count = torch.where(
            window_count == 0, torch.ones_like(window_count), window_count
        )
        filled = window_sum / window_count

        return x_masked + filled * (1 - mask_f)

    def forward(self, x, mask=None):
        """Run the backbone, optionally on a neighborhood-padded input.

        Args:
            x: image batch, presumably ``[B, C, H, W]``.
            mask: optional keep-mask, presumably ``[B, 1, h, w]``. It is resized
                with nearest-neighbor interpolation when its spatial size differs
                from ``x``.

        Returns:
            ``(logits, features)``, where ``features`` is the hooked layer's output.
        """
        self.features = {}

        if mask is not None:
            if mask.shape[-2:] != x.shape[-2:]:
                # F.interpolate needs a floating tensor; bool masks would fail.
                mask = F.interpolate(
                    mask.to(x.dtype), size=x.shape[-2:], mode='nearest'
                )
            x = self.neighborhood_padding(x, mask)

        logits = self.backbone(x)
        return logits, self.features[self.layer_name]


class MaskedViT(nn.Module):
    """ViT wrapper that averages kept patch tokens at an intermediate block.

    The output of ``backbone.blocks[Config.VIT_LAYER_INDEX]`` is captured with a
    forward hook. ``get_patch_embeddings`` averages the patch tokens of that
    output whose patches are kept by a pixel mask.
    """

    def __init__(self, model_name='vit_base_patch16_224'):
        """Create the pretrained timm model and hook the configured block."""
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=True)
        self.backbone.eval()
        self.target_layer_idx = Config.VIT_LAYER_INDEX
        # Output of the hooked block from the most recent forward pass.
        self.features = None

        # Register hook
        self.backbone.blocks[self.target_layer_idx].register_forward_hook(self._hook)

    def _hook(self, module, input, output):
        # Store the hooked block's output; assumed [Batch, Tokens, Dim] for ViT blocks.
        self.features = output

    def forward(self, x):
        """Return ``(logits, hooked_block_output)`` for input batch ``x``."""
        self.features = None
        logits = self.backbone(x)
        return logits, self.features

    def _patch_grid(self, num_patches):
        """Return the ``(rows, cols)`` patch grid holding ``num_patches`` tokens."""
        # Prefer the backbone's own grid (supports non-square inputs) when consistent.
        patch_embed = getattr(self.backbone, 'patch_embed', None)
        grid = getattr(patch_embed, 'grid_size', None)
        if grid is not None and int(grid[0]) * int(grid[1]) == num_patches:
            return int(grid[0]), int(grid[1])
        side = int(round(num_patches ** 0.5))
        if side * side != num_patches:
            raise ValueError(
                f"cannot infer a patch grid for {num_patches} patch tokens"
            )
        return side, side

    def get_patch_embeddings(self, x, mask):
        """Average the hooked block's patch tokens whose patches are kept by ``mask``.

        Args:
            x: input batch for the backbone.
            mask: keep-mask, presumably ``[B, 1, H_img, W_img]``. Non-zero/True
                marks kept pixels; bool masks are accepted. It is downsampled to
                the patch grid with nearest-neighbor interpolation.

        Returns:
            ``[B, Dim]``: the mean of the kept patch tokens. Rows where no patch
            is kept are zero.

        Raises:
            RuntimeError: if the forward hook did not capture any features.
            ValueError: if the patch-token count does not form a grid.
        """
        _, feats = self.forward(x)
        if feats is None:
            raise RuntimeError("forward hook on the target block did not fire")

        # Skip all prefix tokens (CLS, distillation, registers), not just index 0.
        num_prefix = getattr(self.backbone, 'num_prefix_tokens', 1)
        patch_tokens = feats[:, num_prefix:, :]

        B, N, _ = patch_tokens.shape
        grid = self._patch_grid(N)

        # Resize the mask to the patch grid; interpolate needs a floating tensor.
        mask_grid = F.interpolate(mask.to(patch_tokens.dtype), size=grid, mode='nearest')
        # Row-major flatten of the grid matches the patch-token order.
        mask_flat = mask_grid.reshape(B, N, 1)

        sum_tokens = (patch_tokens * mask_flat).sum(dim=1)
        count_tokens = mask_flat.sum(dim=1)
        # Avoid NaN for samples with no kept patch.
        count_tokens = torch.where(
            count_tokens == 0, torch.ones_like(count_tokens), count_tokens
        )

        return sum_tokens / count_tokens  # [B, D]