"""
Neural Network Models for Multi-Task Learning

This module provides various neural network architectures for multi-task learning scenarios.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Any


class FeatureEncoder(nn.Module):
    """
    Fully connected feature encoder for multi-task learning.
    
    Flattens every dimension after the batch dimension and passes the result
    through Linear -> BatchNorm1d -> ReLU -> Dropout.
    """
    
    def __init__(self, input_dim: int = 512, hidden_dim: int = 64, dropout_rate: float = 0.5):
        """
        Build the encoder.
        
        Args:
            input_dim: Number of features after flattening the input.
            hidden_dim: Width of the hidden layer (and of the output).
            dropout_rate: Dropout probability applied after the activation.
        """
        super().__init__()
        # NOTE(review): not read inside this file; presumably consumed by callers
        # to tell this encoder apart from the ResNet-based ones.
        self.with_resnet = False
        
        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        ]
        self.hidden_layer = nn.Sequential(*stages)
        
        self._initialize_weights()
    
    def _initialize_weights(self):
        """Draw the linear weights from N(0, 0.005^2) and set the linear bias to 0.1."""
        projection = self.hidden_layer[0]
        nn.init.normal_(projection.weight, mean=0.0, std=0.005)
        nn.init.constant_(projection.bias, val=0.1)
    
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch.
        
        Args:
            inputs: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened into ``input_dim`` features.
                
        Returns:
            Tensor of shape (batch, hidden_dim).
        """
        flat = inputs.flatten(start_dim=1)
        return self.hidden_layer(flat)


class ResNetFeatureEncoder(nn.Module):
    """
    ResNet-based feature encoder for image data.
    
    This module uses a pre-trained ResNet-18 backbone, global average pooling
    and a custom head (Linear -> BatchNorm1d -> ReLU -> Dropout).
    
    The backbone comes from LibMTL when it is importable, otherwise from
    torchvision. The head pools the backbone output with ``AdaptiveAvgPool2d``
    and feeds 512 features to a linear layer. The backbone must therefore
    produce a 512-channel spatial feature map. The torchvision model is
    stripped of its ``avgpool``/``fc`` layers to satisfy that.
    """
    
    def __init__(self, num_classes: int = 65, hidden_dim: int = 512, dropout_rate: float = 0.5):
        """
        Initialize the ResNet feature encoder.
        
        Args:
            num_classes: Number of output classes. Not used by this module;
                kept so existing call sites keep working.
            hidden_dim: Hidden layer dimension
            dropout_rate: Dropout rate for regularization
        """
        super(ResNetFeatureEncoder, self).__init__()
        
        # Prefer LibMTL's resnet18 when available.
        try:
            from LibMTL.LibMTL.model import resnet18
            self.resnet_backbone = resnet18(pretrained=True)
        except ImportError:
            # torchvision's ResNet.forward ends with avgpool + flatten + fc and
            # returns (N, 1000) logits. The 2D pooling below rejects that 2-D
            # tensor, so keep only the convolutional trunk.
            from torchvision.models import resnet18
            self.resnet_backbone = self._strip_classifier_head(resnet18(pretrained=True))
        
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        
        self.hidden_layer = nn.Sequential(
            nn.Linear(512, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )
        
        # Initialize weights
        self._initialize_weights()
    
    @staticmethod
    def _strip_classifier_head(model: nn.Module) -> nn.Module:
        """
        Return *model*'s children, minus ``avgpool`` and ``fc``, as a Sequential.
        
        Children are re-registered in their original order under their
        original names, so parameter keys such as ``layer1.0.conv1.weight`` are
        unchanged. The submodules are shared with *model*, not copied.
        
        Args:
            model: Module whose children are meant to run one after another in
                registration order (true for torchvision's ResNet).
                
        Returns:
            A module mapping images to the feature map produced before pooling.
        """
        trunk = nn.Sequential()
        for name, child in model.named_children():
            if name in ("avgpool", "fc"):
                continue
            trunk.add_module(name, child)
        return trunk
        
    def _initialize_weights(self):
        """Draw the head's linear weights from N(0, 0.005^2) and set its bias to 0.1."""
        nn.init.normal_(self.hidden_layer[0].weight, 0, 0.005)
        nn.init.constant_(self.hidden_layer[0].bias, 0.1)
        
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the ResNet encoder.
        
        Args:
            inputs: Image batch (presumably (N, 3, H, W) -- confirm with the
                data pipeline).
            
        Returns:
            Encoded features of shape (N, hidden_dim).
        """
        features = self.resnet_backbone(inputs)
        pooled_features = torch.flatten(self.adaptive_pool(features), 1)
        return self.hidden_layer(pooled_features)


class BertFeatureEncoder(nn.Module):
    """
    Feature encoder for precomputed text embeddings (e.g. BERT outputs).
    
    Flattens every dimension after the batch dimension and applies
    Linear -> ReLU -> Dropout (no batch normalization).
    """
    
    def __init__(self, input_dim: int = 768, hidden_dim: int = 64, dropout_rate: float = 0.1):
        """
        Build the encoder.
        
        Args:
            input_dim: Number of features after flattening the input
                (768 by default, the BERT-base embedding width).
            hidden_dim: Width of the hidden layer (and of the output).
            dropout_rate: Dropout probability applied after the activation.
        """
        super().__init__()
        
        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
        ]
        self.hidden_layer = nn.Sequential(*stages)
        
        self._initialize_weights()
    
    def _initialize_weights(self):
        """Draw the linear weights from N(0, 0.005^2) and set the linear bias to 0.1."""
        projection = self.hidden_layer[0]
        nn.init.normal_(projection.weight, mean=0.0, std=0.005)
        nn.init.constant_(projection.bias, val=0.1)
    
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of embeddings.
        
        Args:
            inputs: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened into ``input_dim`` features.
                
        Returns:
            Tensor of shape (batch, hidden_dim).
        """
        flat = inputs.flatten(start_dim=1)
        return self.hidden_layer(flat)


class FullResNetEncoder(nn.Module):
    """
    ResNet-18 whose ``fc`` attribute is replaced by a new 512 -> output_dim layer.
    
    The backbone comes from LibMTL when importable, otherwise from torchvision.
    ``forward`` runs the whole backbone module as-is.
    """
    
    def __init__(self, output_dim: int = 64):
        """
        Build the encoder.
        
        Args:
            output_dim: Output width of the replacement ``fc`` layer.
        """
        super().__init__()
        # NOTE(review): not read inside this file; presumably consumed by callers
        # to tell ResNet-based encoders apart.
        self.with_resnet = True
        
        try:
            from LibMTL.LibMTL.model import resnet18
            backbone = resnet18(pretrained=True)
        except ImportError:
            from torchvision.models import resnet18
            backbone = resnet18(pretrained=True)
        
        # NOTE(review): torchvision's ResNet applies ``fc`` inside forward, so the
        # output there is (N, output_dim). Whether LibMTL's resnet18 also calls
        # ``fc`` in its forward is not visible here; confirm this before relying on
        # the output width when LibMTL is installed.
        backbone.fc = nn.Linear(512, output_dim)
        self.resnet = backbone
    
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Run the backbone (with its replaced ``fc``) on *inputs*.
        
        Args:
            inputs: Image batch (presumably (N, 3, H, W); confirm with the
                data pipeline).
                
        Returns:
            Whatever the backbone module returns for *inputs*.
        """
        encoded = self.resnet(inputs)
        return encoded


class MultiTaskResNet18(nn.Module):
    """
    Multi-task ResNet18 model for image classification.
    
    A pre-trained ResNet-18 backbone is followed by global average pooling and
    a shared head (Linear -> BatchNorm1d -> ReLU -> Dropout).
    
    The backbone comes from LibMTL when it is importable, otherwise from
    torchvision. The head pools the backbone output with ``AdaptiveAvgPool2d``
    and feeds 512 features to a linear layer. The backbone must therefore
    produce a 512-channel spatial feature map. The torchvision model is
    stripped of its ``avgpool``/``fc`` layers to satisfy that.
    """
    
    def __init__(self, hidden_dim: int = 64, dropout_rate: float = 0.5):
        """
        Initialize the multi-task ResNet18.
        
        Args:
            hidden_dim: Hidden layer dimension
            dropout_rate: Dropout rate for regularization
        """
        super(MultiTaskResNet18, self).__init__()
        
        # Prefer LibMTL's resnet18 when available.
        try:
            from LibMTL.LibMTL.model import resnet18
            self.resnet_backbone = resnet18(pretrained=True)
        except ImportError:
            # torchvision's ResNet.forward ends with avgpool + flatten + fc and
            # returns (N, 1000) logits. The 2D pooling below rejects that 2-D
            # tensor, so keep only the convolutional trunk.
            from torchvision.models import resnet18
            self.resnet_backbone = self._strip_classifier_head(resnet18(pretrained=True))
        
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        
        self.hidden_layer = nn.Sequential(
            nn.Linear(512, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )
        
        # Initialize weights
        self._initialize_weights()
    
    @staticmethod
    def _strip_classifier_head(model: nn.Module) -> nn.Module:
        """
        Return *model*'s children, minus ``avgpool`` and ``fc``, as a Sequential.
        
        Children are re-registered in their original order under their
        original names, so parameter keys such as ``layer1.0.conv1.weight`` are
        unchanged. The submodules are shared with *model*, not copied.
        
        Args:
            model: Module whose children are meant to run one after another in
                registration order (true for torchvision's ResNet).
                
        Returns:
            A module mapping images to the feature map produced before pooling.
        """
        trunk = nn.Sequential()
        for name, child in model.named_children():
            if name in ("avgpool", "fc"):
                continue
            trunk.add_module(name, child)
        return trunk
        
    def _initialize_weights(self):
        """Draw the head's linear weights from N(0, 0.005^2) and set its bias to 0.1."""
        nn.init.normal_(self.hidden_layer[0].weight, 0, 0.005)
        nn.init.constant_(self.hidden_layer[0].bias, 0.1)
        
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the multi-task ResNet18.
        
        Args:
            inputs: Image batch (presumably (N, 3, H, W); confirm with the
                data pipeline).
            
        Returns:
            Encoded features of shape (N, hidden_dim).
        """
        features = self.resnet_backbone(inputs)
        pooled_features = torch.flatten(self.adaptive_pool(features), 1)
        return self.hidden_layer(pooled_features)


class SimpleConvolutionalModel(nn.Module):
    """
    Small two-stage CNN classifier.
    
    Each stage is a convolution, a ReLU and a 3x3 / stride-2 max-pool. The
    result is globally average-pooled and projected to ``num_classes`` scores.
    """
    
    def __init__(self, num_classes: int = 2):
        """
        Build the layers.
        
        Args:
            num_classes: Width of the output layer.
        """
        super().__init__()
        
        # Stage 1: 3 -> 16 channels, 3x3 kernel; padding keeps height/width.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        
        # Stage 2: 16 -> 16 channels, 5x5 kernel; padding keeps height/width.
        self.conv2 = nn.Conv2d(16, 16, kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        
        # Collapse the spatial dims, then map the 16 channels to class scores.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.output_layer = nn.Linear(16, num_classes)
    
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Classify a batch of images.
        
        Args:
            inputs: Tensor with 3 input channels (presumably (N, 3, H, W)).
                
        Returns:
            Unnormalized scores of shape (N, num_classes).
        """
        hidden = inputs
        for conv, pool in ((self.conv1, self.pool1), (self.conv2, self.pool2)):
            hidden = pool(F.relu(conv(hidden)))
        
        pooled = self.global_pool(hidden).flatten(start_dim=1)
        return self.output_layer(pooled)

