import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Tuple

# Initial value of the learnable ``bias_param`` that each lambda network below
# adds to its raw output. The lambda parameter / final layer of those networks
# is zero-initialized, so a freshly built network outputs this value.
BIAS = 10.0

class BaseLambdaNetwork(nn.Module):
    """Common ancestor of the lambda networks defined in this module.

    The class holds no parameters of its own; it only fixes the expectation
    that subclasses provide a ``forward`` implementation.
    """

    def __init__(self):
        """Run the plain ``nn.Module`` initialization."""
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Placeholder that subclasses must override.

        Args:
            x: Input tensor (unused here).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

class SimpleLambdaNetwork(BaseLambdaNetwork):
    """Input-independent lambda: one learnable value broadcast over the batch.

    With ``lagrangian_type == 'vector'`` the learnable value has shape
    ``(1, embedding_size)``; for any other ``lagrangian_type`` it is a 0-dim
    scalar. The input tensor is only used to read the batch size.
    """

    def __init__(self, model_orig: nn.Module, clip_model_name: str, hidden_dim: int = -1, lagrangian_type: str = 'scalar'):
        """Create the lambda and bias parameters.

        Args:
            model_orig: Unused; accepted so all lambda networks can be built
                from the same keyword arguments.
            clip_model_name: Key into ``LinearLambdaNetwork.EMBEDDING_SIZES``;
                only consulted in vector mode, unknown names fall back to 512.
            hidden_dim: Unused; accepted for signature compatibility.
            lagrangian_type: ``'vector'`` for a per-dimension lambda, anything
                else selects the scalar form.
        """
        super().__init__()
        self.lagrangian_type = lagrangian_type

        if lagrangian_type == 'vector':
            embedding_size = LinearLambdaNetwork.EMBEDDING_SIZES.get(clip_model_name, 512)
            self.lambda_param = nn.Parameter(torch.zeros(1, embedding_size))
            self.bias_param = nn.Parameter(torch.full((1, embedding_size), BIAS))
        else:  # scalar
            self.lambda_param = nn.Parameter(torch.tensor(0.0))
            self.bias_param = nn.Parameter(torch.tensor(BIAS))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the non-negative lambda replicated for each sample in ``x``.

        Args:
            x: Any tensor whose first dimension is the batch size.

        Returns:
            Tensor of shape ``(batch_size,)`` in scalar mode or
            ``(batch_size, embedding_size)`` in vector mode.
        """
        batch_size = x.shape[0]
        if self.lambda_param.dim() == 0:
            lambda_value = self.lambda_param.expand(batch_size)
        else:
            # The vector parameter is 2-D (1, embedding_size); expand() needs
            # one size per dimension, so keep the trailing dimension with -1.
            lambda_value = self.lambda_param.expand(batch_size, -1)
        output = lambda_value + self.bias_param
        # Forward value equals relu(output); the correction term is detached,
        # so the gradient w.r.t. `output` is the identity even where it is clamped.
        return output + (output.relu() - output).detach()

class LinearLambdaNetwork(BaseLambdaNetwork):
    """Lambda network: one zero-initialized linear layer on frozen backbone embeddings.

    The layer maps an embedding of width ``embedding_size`` to either a single
    value (scalar mode) or to ``embedding_size`` values (vector mode). A
    learnable bias initialized to ``BIAS`` is added and the result is clamped
    to be non-negative in the forward pass.
    """

    # Embedding width per known model name; names not listed fall back to 512.
    EMBEDDING_SIZES: Dict[str, int] = {
        'ViT-B-32': 512,
        'ViT-B-32-quickgelu': 512,
        'ViT-B-16': 512,
        'ViT-L-14': 768,
        'ViT-L-14-336': 768,
        'hf-hub:laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg': 640,
        'ViT-B-16-laion2B': 512,
        'ViT-B-32-laion2B': 512,
        'dinov2_vitb14_reg_lc': 3840,
        'dinov2_vits14_reg_lc': 1920,
    }

    def __init__(self, model_orig: nn.Module, clip_model_name: str, lagrangian_type: str = 'scalar',
                 hidden_dim: int = -1):
        """Freeze the backbone and build the projection layer.

        Args:
            model_orig: Backbone producing the embeddings. Its parameters are
                frozen in place (``requires_grad = False``).
            clip_model_name: Key into ``EMBEDDING_SIZES``; a warning is printed
                and 512 is used when the name is unknown.
            lagrangian_type: ``'vector'`` for an ``embedding_size``-wide
                output, anything else selects a single output value.
            hidden_dim: Ignored. Accepted (as a trailing argument, so existing
                positional calls are unaffected) so that this class can be
                built from the same keyword arguments as the other lambda
                networks, which take ``hidden_dim``.
        """
        super().__init__()

        if clip_model_name not in self.EMBEDDING_SIZES:
            print(f'Warning: Unsupported model "{clip_model_name}". Defaulting to embedding size 512.')

        self.embedding_size = self.EMBEDDING_SIZES.get(clip_model_name, 512)
        self.lagrangian_type = lagrangian_type
        self.backbone = self._setup_backbone(model_orig)

        if lagrangian_type == 'vector':
            # For vector output, project to the full embedding dimension
            self.projection_layer = nn.Linear(self.embedding_size, self.embedding_size)
            self.bias_param = nn.Parameter(torch.full((1, self.embedding_size), BIAS))
        else:  # scalar
            # For scalar output, project to a single value
            self.projection_layer = nn.Linear(self.embedding_size, 1)
            self.bias_param = nn.Parameter(torch.tensor(BIAS))

        # Zero weights and bias: the initial output is exactly `bias_param`.
        nn.init.zeros_(self.projection_layer.weight)
        nn.init.zeros_(self.projection_layer.bias)

    def _setup_backbone(self, model_orig: nn.Module) -> nn.Module:
        """Disable gradients on every parameter of ``model_orig`` and return it."""
        for param in model_orig.parameters():
            param.requires_grad = False
        return model_orig

    def forward(self,
                vision: torch.Tensor,
                output_normalize: bool,
                embedding_orig: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the non-negative lambda for a batch.

        Args:
            vision: Input passed to the backbone as ``vision=``; only used when
                ``embedding_orig`` is None.
            output_normalize: Forwarded to the backbone as ``output_normalize=``.
            embedding_orig: Optional precomputed embedding. When given, the
                backbone is not called.

        Returns:
            Projection output plus bias, clamped at zero; the last dimension is
            1 in scalar mode and ``embedding_size`` in vector mode.
        """
        if embedding_orig is None:
            # The backbone is frozen, so skip building its autograd graph.
            with torch.no_grad():
                embedding_orig = self.backbone(vision=vision,
                                               output_normalize=output_normalize)

        output = self.projection_layer(embedding_orig) + self.bias_param
        # Forward value equals relu(output); the correction term is detached,
        # so the gradient w.r.t. `output` is the identity even where it is clamped.
        return output + (output.relu() - output).detach()

class LinearMLPLambdaNetwork(LinearLambdaNetwork):
    """Lambda network whose head is a three-layer LeakyReLU MLP on frozen backbone embeddings.

    The MLP maps ``embedding_size -> hidden_dim -> hidden_dim -> out``, where
    ``out`` is ``embedding_size`` in vector mode and 1 otherwise. A learnable
    bias initialized to ``BIAS`` is added and the result is clamped to be
    non-negative in the forward pass.
    """

    def __init__(self, model_orig: nn.Module, clip_model_name: str, hidden_dim: int = -1, lagrangian_type: str = 'scalar'):
        """Freeze the backbone and build the MLP head.

        Args:
            model_orig: Backbone producing the embeddings; frozen via
                ``_setup_backbone``.
            clip_model_name: Key into ``EMBEDDING_SIZES``; a warning is printed
                and 512 is used when the name is unknown.
            hidden_dim: Hidden width. Non-positive values select
                ``int(1.5 * embedding_size)``; positive values below
                ``embedding_size // 2`` are raised to that floor.
            lagrangian_type: ``'vector'`` for an ``embedding_size``-wide
                output, anything else selects a single output value.
        """
        # Start the __init__ lookup after BaseLambdaNetwork in the MRO so that
        # LinearLambdaNetwork.__init__ is bypassed and no projection layer is built.
        super(BaseLambdaNetwork, self).__init__()

        known_model = clip_model_name in self.EMBEDDING_SIZES
        if not known_model:
            print(f'Warning: Unsupported model "{clip_model_name}". Defaulting to embedding size 512.')

        self.embedding_size = self.EMBEDDING_SIZES.get(clip_model_name, 512)
        self.lagrangian_type = lagrangian_type
        self.backbone = self._setup_backbone(model_orig)

        width = self.embedding_size

        # Resolve the hidden width: default when unset, otherwise enforce the floor.
        if hidden_dim <= 0:
            hidden_dim = int(1.5 * width)
        else:
            hidden_dim = max(hidden_dim, width // 2)

        vector_output = lagrangian_type == 'vector'
        out_features = width if vector_output else 1

        # Linear layers are created input-to-output, then wrapped in the Sequential.
        fc_in = nn.Linear(width, hidden_dim)
        fc_mid = nn.Linear(hidden_dim, hidden_dim)
        fc_out = nn.Linear(hidden_dim, out_features)
        self.mlp = nn.Sequential(fc_in, nn.LeakyReLU(), fc_mid, nn.LeakyReLU(), fc_out)

        if vector_output:
            initial_bias = torch.full((1, width), BIAS)
        else:
            initial_bias = torch.tensor(BIAS)
        self.bias_param = nn.Parameter(initial_bias)

        # Zero output layer: the initial MLP output is 0, so the network starts at `bias_param`.
        nn.init.zeros_(fc_out.weight)
        nn.init.zeros_(fc_out.bias)

        # The two hidden linear layers get Kaiming-normal weights and zero biases.
        for hidden_layer in (fc_in, fc_mid):
            nn.init.kaiming_normal_(hidden_layer.weight)
            nn.init.zeros_(hidden_layer.bias)

    def forward(self,
                vision: torch.Tensor,
                output_normalize: bool,
                embedding_orig: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the non-negative lambda for a batch.

        Args:
            vision: Input passed to the backbone as ``vision=``; only used when
                ``embedding_orig`` is None.
            output_normalize: Forwarded to the backbone as ``output_normalize=``.
            embedding_orig: Optional precomputed embedding. When given, the
                backbone is not called.

        Returns:
            MLP output plus bias, clamped at zero in the forward value.
        """
        features = embedding_orig
        if features is None:
            # No autograd graph is recorded for the backbone call.
            with torch.no_grad():
                features = self.backbone(vision=vision, output_normalize=output_normalize)

        raw = self.mlp(features) + self.bias_param
        # Value is relu(raw); the detached correction leaves the gradient w.r.t. `raw` as identity.
        clamp_correction = (raw.relu() - raw).detach()
        return raw + clamp_correction

class LambdaNetworkFactory:
    """Builds lambda networks from a string identifier."""

    @staticmethod
    def create_network(network_type: str, **kwargs) -> BaseLambdaNetwork:
        """
        Instantiate the lambda network registered under ``network_type``.

        Args:
            network_type: One of 'simple', 'linear', 'linear_mlp'.
            **kwargs: Forwarded unchanged to the selected class's constructor,
                so they must match its signature. Typical entries:
                - model_orig: Original model for reference
                - clip_model_name: Name of the CLIP model being used
                - hidden_dim: (Optional) Hidden dimension for MLP-based networks
                - lagrangian_type: (Optional) 'scalar' or 'vector'; set to
                  'scalar' here when omitted

        Returns:
            The constructed network.

        Raises:
            ValueError: If ``network_type`` is not a registered name.
        """
        registry = {
            'simple': SimpleLambdaNetwork,
            'linear': LinearLambdaNetwork,
            'linear_mlp': LinearMLPLambdaNetwork,
        }

        network_cls = registry.get(network_type)
        if network_cls is None:
            raise ValueError(f'Unknown network type: {network_type}. '
                             f'Available types: {list(registry.keys())}')

        # Fill in the default Lagrangian form only when the caller gave none.
        kwargs.setdefault('lagrangian_type', 'scalar')

        return network_cls(**kwargs)