import math
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class CNNDecoder(nn.Module):
    """
    Convolutional decoder that maps a ViT token sequence back to an image.

    The token sequence is layer-normalised, reshaped into a square feature
    map and passed through a stack of Conv-BatchNorm-ReLU blocks. The last
    ``num_upsamples`` blocks are each followed by a 2x bilinear upsample, so
    the feature map grows from ``img_size // patch_size`` towards
    ``img_size`` without overshooting it. A final Conv + Tanh produces the
    output image in ``[-1, 1]``.

    Args:
        in_dim (int): Dimension of input features (192 for Tiny, 384 for Small, 768 for Base)
        img_size (int): Target image size (default: 224)
        patch_size (int): Patch size used in ViT (default: 16)
        out_channels (int): Number of output channels (default: 3 for RGB)
        hidden_dims (List[int], optional): Custom channel widths for the decoder
            layers. ``hidden_dims[0]`` must equal ``in_dim``. If None, the widths
            are determined automatically from ``in_dim``.

    Raises:
        ValueError: If ``img_size``/``patch_size`` are not positive, if
            ``img_size < patch_size``, or if a custom ``hidden_dims`` is empty
            or does not start with ``in_dim``.
    """
    def __init__(self,
                 in_dim: int,
                 img_size: int = 224,
                 patch_size: int = 16,
                 out_channels: int = 3,
                 hidden_dims: Optional[List[int]] = None):
        super().__init__()

        if img_size <= 0 or patch_size <= 0:
            raise ValueError(
                f"img_size and patch_size must be positive, got "
                f"img_size={img_size}, patch_size={patch_size}")
        if img_size < patch_size:
            raise ValueError(
                f"img_size ({img_size}) must be >= patch_size ({patch_size})")

        # Side length of the square token grid, e.g. 14 for 224/16
        self.feature_size = img_size // patch_size
        self.num_patches = self.feature_size ** 2
        self.target_size = img_size
        self.pre_norm = nn.LayerNorm(in_dim)

        # Number of 2x upsampling steps needed to go from the token grid to
        # the image, e.g. 4 for 14->28->56->112->224. If the ratio is not a
        # power of two this is rounded down and forward() resizes the final
        # output to img_size.
        self.num_upsamples = int(math.log2(img_size // self.feature_size))

        if hidden_dims is None:
            hidden_dims = self._auto_hidden_dims(in_dim, self.num_upsamples)
        else:
            # Work on a copy so the caller's list is never aliased.
            hidden_dims = list(hidden_dims)
            if not hidden_dims:
                raise ValueError("hidden_dims must contain at least one entry")
            if hidden_dims[0] != in_dim:
                # The first conv consumes the in_dim-channel feature map, so a
                # mismatch could only ever fail later inside forward().
                raise ValueError(
                    f"hidden_dims[0] ({hidden_dims[0]}) must equal in_dim ({in_dim})")

        # One Conv-BN-ReLU block per consecutive pair of channel widths
        self.layers = nn.ModuleList(
            self._conv_block(c_in, c_out)
            for c_in, c_out in zip(hidden_dims[:-1], hidden_dims[1:])
        )

        # Final output layer
        self.final_conv = nn.Sequential(
            nn.Conv2d(hidden_dims[-1], out_channels, 3, padding=1),
            nn.Tanh()  # Normalize output to [-1, 1]
        )

        # Initialize weights
        self.apply(self._init_weights)

    @staticmethod
    def _auto_hidden_dims(in_dim: int, num_upsamples: int) -> List[int]:
        """Pick decoder channel widths from the input feature dimension."""
        # Scale base dimension (and extra same-resolution layers) with in_dim
        if in_dim <= 192:  # ViT-Tiny
            base_dim, extra_layers = 128, 0
        elif in_dim <= 384:  # ViT-Small
            base_dim, extra_layers = 256, 1
        else:  # ViT-Base
            base_dim, extra_layers = 512, 2

        # Start from in_dim; extra layers halve the width each time
        dims = [in_dim]
        current_dim = in_dim
        for _ in range(extra_layers):
            current_dim = current_dim // 2
            dims.append(current_dim)

        # base_dim plus (num_upsamples - 1) halvings gives exactly
        # num_upsamples blocks after the extra layers.
        dims.append(base_dim)
        current_dim = base_dim
        for _ in range(num_upsamples - 1):
            current_dim = max(current_dim // 2, 32)  # Ensure minimum of 32 channels
            dims.append(current_dim)
        return dims

    @staticmethod
    def _conv_block(in_channels: int, out_channels: int) -> nn.Sequential:
        """Build one 3x3 Conv -> BatchNorm -> ReLU block (spatial size preserved)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def _init_weights(self, m):
        """Kaiming-initialise conv weights; reset BatchNorm to identity scale/shift."""
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def forward(self, x, return_intermediates=False):
        """
        Forward pass of the decoder.

        Args:
            x: Input tensor of shape [B, 1 + N, C] where:
               B = batch size
               N = number of patch tokens (must be a perfect square)
               C = embedding dimension
               The first token along dim 1 is always discarded (presumably
               the CLS token - confirm against the encoder in use).
            return_intermediates: If True, returns intermediate feature maps

        Returns:
            Reconstructed image of shape [B, out_channels, img_size, img_size]
            If return_intermediates=True, also returns a list with the output
            of every decoder layer (after its upsample, where one is applied)
        """
        intermediates = []
        x = x[:, 1:]  # drop the leading token
        x = self.pre_norm(x)
        # Reshape from [B, N, C] to [B, C, H, W]
        B, N, C = x.shape
        H = W = int(math.sqrt(N))
        x = x.permute(0, 2, 1).reshape(B, C, H, W)

        # Only the last `num_upsamples` layers upsample. Any earlier layers
        # refine features at the token-grid resolution, so the feature map
        # never grows past the target size and has to be shrunk back.
        first_upsampling = max(len(self.layers) - self.num_upsamples, 0)
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx >= first_upsampling:
                x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
            if return_intermediates:
                intermediates.append(x)

        # Final convolution
        x = self.final_conv(x)
        # Ensure output size matches target size (needed when there are fewer
        # layers than upsampling steps or the scale is not a power of two)
        if x.size(-1) != self.target_size:
            x = F.interpolate(x, size=self.target_size, mode='bilinear', align_corners=False)

        if return_intermediates:
            return x, intermediates
        return x

    @staticmethod
    def get_default_hidden_dims(in_dim: int) -> List[int]:
        """
        Get a simple halving schedule of hidden dimensions for an input dimension.

        Returns ``[in_dim, base, base // 2, ...]`` with
        ``base = max(in_dim // 4, 64)``, halving until the width is <= 32.

        Note: this is an independent heuristic and does NOT reproduce the
        widths the constructor picks when ``hidden_dims`` is None.
        """
        base_dim = max(in_dim // 4, 64)
        dims = [in_dim, base_dim]
        current_dim = base_dim
        while current_dim > 32:
            current_dim = current_dim // 2
            dims.append(current_dim)
        return dims
    