"""Backbone models with feature extraction capabilities."""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from typing import Dict, Tuple, Optional


class FeatureExtractor(nn.Module):
    """Wrap a model so one forward pass returns both features and logits.

    A forward hook on the submodule named ``feature_layer`` captures that
    submodule's output while ``self.model(x)`` runs. The captured tensor is
    returned unchanged (no flattening), so its shape is whatever that layer
    produces.
    """

    def __init__(self, model: nn.Module, feature_layer: str = "avgpool"):
        """Initialize feature extractor.

        Args:
            model: Base model (e.g., ResNet).
            feature_layer: Name, as reported by ``model.named_modules()``, of the
                submodule whose output is returned as the features.
        """
        super().__init__()
        self.model = model
        self.feature_layer = feature_layer
        # Output captured by the hook during the most recent forward pass.
        self.features = None
        # Handle of the registered hook; stays None if ``feature_layer`` is not
        # the name of any submodule. Kept so the hook can be removed later.
        self._hook_handle = None
        self._register_hooks()

    def _register_hooks(self):
        """Register a forward hook on ``feature_layer`` that stores its output."""
        # Avoid stacking duplicate hooks if this is ever called again.
        existing = getattr(self, "_hook_handle", None)
        if existing is not None:
            existing.remove()
            self._hook_handle = None

        def hook(module, inputs, output):
            self.features = output

        for name, module in self.model.named_modules():
            if name == self.feature_layer:
                self._hook_handle = module.register_forward_hook(hook)
                break

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass returning features and logits.

        Args:
            x: Input tensor.

        Returns:
            Tuple of (features, logits).

        Raises:
            RuntimeError: if the hook captured nothing and the model offers no
                ``forward_features`` fallback.
        """
        # Clear any activation from a previous call so a stale tensor is never
        # returned when the hooked layer does not run during this call.
        self.features = None
        logits = self.model(x)
        features = self.features

        if features is None:
            features = self._extract_features_manual(x)

        return features, logits

    def _extract_features_manual(self, x: torch.Tensor) -> torch.Tensor:
        """Fallback feature extraction used when the hook captured nothing.

        Calls ``self.model.forward_features(x)`` when the model defines it.
        Gradients are kept, so fallback features behave like hooked ones
        (e.g. when used in a feature-based loss).

        Raises:
            RuntimeError: if the model has no ``forward_features`` method, since
                there is then no way to produce meaningful features.
        """
        forward_features = getattr(self.model, "forward_features", None)
        if forward_features is None:
            raise RuntimeError(
                f"Could not capture features from layer '{self.feature_layer}': "
                "no submodule with that name produced output and the model has "
                "no 'forward_features' method."
            )
        return forward_features(x)


class ResNet50Backbone(nn.Module):
    """ResNet-50 trunk (all children except ``fc``) followed by a dropout + linear head."""

    def __init__(
        self,
        num_classes: int,
        pretrained: bool = True,
        dropout: float = 0.1,
        feature_dim: int = 2048
    ):
        """Initialize ResNet-50 backbone.

        Args:
            num_classes: Number of output classes.
            pretrained: Whether to load ImageNet-pretrained weights.
            dropout: Dropout rate applied before the linear head.
            feature_dim: Input width of the head. It must equal the trunk's
                output width (``fc.in_features`` of torchvision's ResNet-50).

        Raises:
            ValueError: if ``feature_dim`` does not match the trunk's output width.
        """
        super().__init__()

        # Prefer the ``weights=`` API; ``pretrained=`` is deprecated in newer
        # torchvision. IMAGENET1K_V1 is the checkpoint that ``pretrained=True``
        # maps to, so the loaded weights are unchanged. Old torchvision versions
        # without ``ResNet50_Weights`` fall back to the legacy argument.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            resnet = models.resnet50(weights=weights)
        else:
            resnet = models.resnet50(pretrained=pretrained)

        in_features = resnet.fc.in_features
        if feature_dim != in_features:
            raise ValueError(
                f"ResNet-50 produces {in_features}-dim features, "
                f"but feature_dim={feature_dim} was given"
            )

        # Every child except the final classification layer (``fc``).
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # Custom head
        self.head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(feature_dim, num_classes)
        )

        self.feature_dim = feature_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Input tensor.

        Returns:
            Logits.
        """
        return self.head(self.get_features(x))

    def get_features(self, x: torch.Tensor) -> torch.Tensor:
        """Extract features without the classification head.

        Args:
            x: Input tensor.

        Returns:
            Trunk output flattened to ``(batch, -1)``.
        """
        # ``flatten`` (unlike ``view``) also works on non-contiguous outputs.
        return torch.flatten(self.backbone(x), 1)

    def get_features_and_logits(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get both features and logits from a single trunk pass.

        Args:
            x: Input tensor.

        Returns:
            Tuple of (features, logits).
        """
        features = self.get_features(x)
        logits = self.head(features)
        return features, logits


class ResNet50WithFeatures(nn.Module):
    """ResNet-50 with its stages exposed as attributes, plus a dropout + linear head.

    Submodules keep torchvision's names (``conv1``, ``bn1``, ``layer1`` ...
    ``layer4``, ``avgpool``); ``freeze_backbone_layers`` relies on them.
    """

    def __init__(
        self,
        num_classes: int,
        pretrained: bool = True,
        dropout: float = 0.1,
        feature_dim: int = 2048
    ):
        """Initialize ResNet-50 with features.

        Args:
            num_classes: Number of output classes.
            pretrained: Whether to load ImageNet-pretrained weights.
            dropout: Dropout rate applied before the linear head.
            feature_dim: Input width of the head. It must equal the trunk's
                output width (``fc.in_features`` of torchvision's ResNet-50).

        Raises:
            ValueError: if ``feature_dim`` does not match the trunk's output width.
        """
        super().__init__()

        # Prefer the ``weights=`` API; ``pretrained=`` is deprecated in newer
        # torchvision. IMAGENET1K_V1 is the checkpoint that ``pretrained=True``
        # maps to, so the loaded weights are unchanged. Old torchvision versions
        # without ``ResNet50_Weights`` fall back to the legacy argument.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            resnet = models.resnet50(weights=weights)
        else:
            resnet = models.resnet50(pretrained=pretrained)

        in_features = resnet.fc.in_features
        if feature_dim != in_features:
            raise ValueError(
                f"ResNet-50 produces {in_features}-dim features, "
                f"but feature_dim={feature_dim} was given"
            )

        # Reuse the torchvision stages directly (``fc`` is intentionally dropped).
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        self.avgpool = resnet.avgpool

        # Custom head
        self.head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(feature_dim, num_classes)
        )

        self.feature_dim = feature_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass returning logits.

        Args:
            x: Input tensor.

        Returns:
            Logits.
        """
        return self.head(self.get_features(x))

    def get_features(self, x: torch.Tensor) -> torch.Tensor:
        """Extract features from the penultimate layer.

        Args:
            x: Input tensor.

        Returns:
            Pooled trunk output flattened to ``(batch, -1)``.
        """
        # Stem
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Residual stages
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # ``flatten`` (unlike ``view``) also works on non-contiguous tensors,
        # e.g. under channels_last memory format.
        return torch.flatten(x, 1)

    def get_features_and_logits(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get both features and logits from a single trunk pass.

        Args:
            x: Input tensor.

        Returns:
            Tuple of (features, logits).
        """
        features = self.get_features(x)
        logits = self.head(features)
        return features, logits


def get_backbone(
    backbone_name: str,
    num_classes: int,
    pretrained: bool = True,
    **kwargs
) -> nn.Module:
    """Build a backbone model from its name.

    Only ``"resnet50"`` is currently recognised; it yields a
    ``ResNet50WithFeatures``.

    Args:
        backbone_name: Name of backbone (resnet50).
        num_classes: Number of output classes.
        pretrained: Whether to use pretrained weights.
        **kwargs: Extra keyword arguments forwarded to the model constructor
            (for ``ResNet50WithFeatures``: ``dropout``, ``feature_dim``).

    Returns:
        Backbone model.

    Raises:
        ValueError: for any unrecognised ``backbone_name``.
    """
    # Guard clause: reject unsupported names before building anything.
    if backbone_name != "resnet50":
        raise ValueError(f"Unknown backbone: {backbone_name}")

    return ResNet50WithFeatures(
        num_classes=num_classes,
        pretrained=pretrained,
        **kwargs
    )


def freeze_backbone_layers(model: nn.Module, num_layers: int = 0):
    """Freeze the ResNet stem plus the first ``num_layers`` residual stages.

    Only ``ResNet50WithFeatures`` is supported. For any other model nothing is
    changed and a message is printed.

    Args:
        model: Model to freeze layers in.
        num_layers: Number of residual stages (``layer1`` .. ``layer4``, in that
            order) to freeze in addition to the stem. ``0`` freezes only the
            stem (``conv1`` and ``bn1``); ``4`` or more freezes the whole
            backbone. The head is always (re)set to trainable.

    Note:
        Setting ``requires_grad = False`` does not stop BatchNorm layers from
        updating their running statistics while the model is in training mode.
        Put those modules in ``eval()`` as well if fixed statistics are needed.
    """
    if not isinstance(model, ResNet50WithFeatures):
        print("Freezing not implemented for this model type")
        return

    # The stem is always frozen (``maxpool`` has no parameters; it is listed
    # only for completeness).
    to_freeze = [model.conv1, model.bn1, model.maxpool]
    stages = [model.layer1, model.layer2, model.layer3, model.layer4]
    # Stage k (1-based) is frozen when num_layers >= k.
    to_freeze.extend(
        stage for index, stage in enumerate(stages, start=1) if num_layers >= index
    )

    for module in to_freeze:
        for param in module.parameters():
            param.requires_grad = False

    # Keep head trainable
    for param in model.head.parameters():
        param.requires_grad = True

    print(f"Froze {num_layers} backbone layers")


def unfreeze_all_layers(model: nn.Module):
    """Make every parameter of ``model`` trainable again.

    Sets ``requires_grad=True`` on all parameters returned by
    ``model.parameters()``, then prints a confirmation.

    Args:
        model: Model to unfreeze.
    """
    # nn.Module.requires_grad_ applies the flag to every parameter in place.
    model.requires_grad_(True)
    print("Unfroze all layers")
