import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models import vit_b_16, ViT_B_16_Weights

class EHREncoderFactory:
    """Factory that builds an EHR encoder from a string identifier"""

    @staticmethod
    def create_encoder(encoder_type, **kwargs):
        """
        Instantiate the EHR encoder selected by ``encoder_type``.

        Args:
            encoder_type: 'lstm' or 'transformer' (compared case-insensitively)
            **kwargs: forwarded unchanged to the selected encoder's constructor

        Returns:
            An ``LSTMEncoder`` or ``TransformerEncoder`` instance.

        Raises:
            ValueError: if ``encoder_type`` is not a supported name.
        """
        # Normalise once so every comparison below sees the same lowercase key.
        kind = encoder_type.lower()

        if kind == 'lstm':
            return LSTMEncoder(**kwargs)
        if kind == 'transformer':
            return TransformerEncoder(**kwargs)

        # The message reports the caller's original spelling, not the normalised key.
        raise ValueError(f"Unsupported EHR encoder type: {encoder_type}. Supported types: 'lstm', 'transformer'")

class CXREncoderFactory:
    """Factory that builds a CXR encoder from a string identifier"""

    @staticmethod
    def create_encoder(encoder_type, hidden_size=256, pretrained=True, **kwargs):
        """
        Instantiate the CXR encoder selected by ``encoder_type``.

        Args:
            encoder_type: 'resnet50' or 'vit_b_16' (compared case-insensitively)
            hidden_size: output feature dimension passed to the encoder
            pretrained: whether the encoder should load pretrained weights
            **kwargs: accepted for call-site compatibility; NOT forwarded to
                the encoder constructors

        Returns:
            A ``ResNet50Encoder`` or ``ViTEncoder`` instance.

        Raises:
            ValueError: if ``encoder_type`` is not a supported name.
        """
        # Normalise once so every comparison below sees the same lowercase key.
        kind = encoder_type.lower()

        if kind == 'resnet50':
            return ResNet50Encoder(hidden_size=hidden_size, pretrained=pretrained)
        if kind == 'vit_b_16':
            return ViTEncoder(hidden_size=hidden_size, pretrained=pretrained)

        # The message reports the caller's original spelling, not the normalised key.
        raise ValueError(f"Unsupported CXR encoder type: {encoder_type}. Supported types: 'resnet50', 'vit_b_16'")

class LearnablePositionalEncoding(nn.Module):
    """Trainable additive positional table for batch-first transformer inputs"""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 500):
        """
        Args:
            d_model: size of the last (embedding) dimension of the input
            dropout: dropout probability applied after adding the encoding
            max_len: number of positions stored in the table
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # One row per position, shared across the batch via the leading 1-dim.
        table = torch.rand(1, max_len, d_model)
        self.pe = nn.Parameter(table)
        # Re-draw the initial values in place from U(-0.1, 0.1).
        with torch.no_grad():
            self.pe.uniform_(-0.1, 0.1)

    def forward(self, x):
        """Add the first ``x.size(1)`` positional rows to ``x`` and apply dropout.

        ``x`` is indexed as (batch_size, seq_len, embedding_dim).
        """
        seq_len = x.size(1)
        encoded = x + self.pe[:, :seq_len]
        return self.dropout(encoded)

class LSTMEncoder(nn.Module):
    """LSTM-based EHR encoder (MedFuse style)"""

    def __init__(self, input_size, num_classes, hidden_size=256, num_layers=2,
                 dropout=0.3, bidirectional=True, **kwargs):
        """
        Args:
            input_size: number of features per time step
            num_classes: number of outputs produced by the classifier head
            hidden_size: LSTM hidden size and output feature dimension
            num_layers: number of stacked LSTM layers
            dropout: dropout probability (inter-layer LSTM dropout and the
                dropout applied around the feature projection)
            bidirectional: whether the LSTM runs in both directions
            **kwargs: accepted and ignored
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        # Lift raw features to the hidden width before the recurrent stack.
        self.input_projection = nn.Linear(input_size, hidden_size)

        # Inter-layer dropout is only passed when there is more than one layer.
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=inter_layer_dropout,
            bidirectional=bidirectional,
            batch_first=True
        )

        # A bidirectional LSTM yields two hidden vectors per sample.
        directions = 2 if bidirectional else 1
        self.feature_projection = nn.Linear(hidden_size * directions, hidden_size)
        self.classifier = nn.Linear(hidden_size, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, seq_lengths, output_prob=True):
        """
        Encode a padded batch and classify it.

        Args:
            x: [batch_size, seq_len, input_size]
            seq_lengths: [batch_size] sequence lengths (list or tensor)
            output_prob: whether to apply sigmoid

        Returns:
            Tuple ``(feat, prediction)``: the projected feature of width
            ``hidden_size`` and the classifier output (sigmoid-activated when
            ``output_prob`` is true, raw otherwise).
        """
        projected = self.input_projection(x)
        device = projected.device

        # Anything that is not already a tensor becomes a long tensor on the
        # input's device; tensors are only moved when they live elsewhere.
        if not isinstance(seq_lengths, torch.Tensor):
            seq_lengths = torch.tensor(seq_lengths, dtype=torch.long, device=device)
        elif seq_lengths.device != device:
            seq_lengths = seq_lengths.to(device)

        # Keep every length inside [1, seq_len] so packing cannot fail.
        lengths = seq_lengths.clamp(min=1, max=projected.size(1))

        # Lengths are handed to the packing routine as a CPU tensor.
        packed = nn.utils.rnn.pack_padded_sequence(
            projected, lengths.cpu(), batch_first=True, enforce_sorted=False
        )

        # Only the final hidden state is used; per-step outputs are discarded.
        _, (h_n, _) = self.lstm(packed)

        if self.bidirectional:
            # h_n: [num_layers * 2, batch_size, hidden_size]; the last two
            # entries are the top layer's forward and backward states.
            summary = torch.cat((h_n[-2], h_n[-1]), dim=1)
        else:
            # h_n: [num_layers, batch_size, hidden_size]
            summary = h_n[-1]

        # Dropout -> projection -> dropout, in that order.
        feat = self.dropout(self.feature_projection(self.dropout(summary)))

        logits = self.classifier(feat)
        prediction = logits.sigmoid() if output_prob else logits

        return feat, prediction

    def get_output_dim(self):
        """Return the output dimension of the encoder"""
        return self.hidden_size

class TransformerEncoder(nn.Module):
    """Transformer-based EHR encoder (DRFUSE style)

    Embeds each time step, adds a learnable positional encoding, runs a
    batch-first transformer encoder with padded positions masked out, and
    mean-pools the outputs over the valid (non-padded) time steps only.
    """

    def __init__(self, input_size, num_classes, d_model=256, n_head=8, n_layers=2,
                 dropout=0.3, max_len=500, **kwargs):
        """
        Args:
            input_size: number of features per time step
            num_classes: number of outputs produced by the classifier head
            d_model: transformer width and output feature dimension
            n_head: number of attention heads per layer
            n_layers: number of stacked encoder layers
            dropout: dropout probability used inside the encoder layers and
                on the pooled feature
            max_len: number of positions held by the positional encoding;
                inputs must not have more time steps than this
            **kwargs: accepted and ignored
        """
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len

        # Input embedding and positional encoding
        self.input_embedding = nn.Linear(input_size, d_model)
        self.pos_encoder = LearnablePositionalEncoding(d_model, dropout=0, max_len=max_len)

        # Transformer encoder layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_head,
            batch_first=True,
            dropout=dropout,
            dim_feedforward=d_model * 4  # Standard transformer scaling
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        # Output layers
        self.classifier = nn.Linear(d_model, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, seq_lengths, output_prob=True):
        """
        Encode a padded batch and classify it.

        Args:
            x: [batch_size, seq_len, input_size]
            seq_lengths: [batch_size] sequence lengths (list or tensor)
            output_prob: whether to apply sigmoid

        Returns:
            Tuple ``(feat, prediction)``: the pooled feature of width
            ``d_model`` and the classifier output (sigmoid-activated when
            ``output_prob`` is true, raw otherwise).
        """
        # Project input to hidden dimension
        x = self.input_embedding(x)
        x = self.pos_encoder(x)

        # Ensure seq_lengths is a tensor living on the same device as x
        if not isinstance(seq_lengths, torch.Tensor):
            seq_lengths = torch.tensor(seq_lengths, dtype=torch.long, device=x.device)
        else:
            seq_lengths = seq_lengths.to(x.device)

        # Clamp to [1, seq_len]: at least one valid step per sample keeps the
        # attention softmax and the pooling denominator well defined.
        seq_lengths = torch.clamp(seq_lengths, min=1, max=x.size(1))

        # Boolean key-padding mask of shape [batch_size, seq_len], True at
        # padded positions. It is built against the padded length of x (not
        # the longest length in the batch) so it always matches x, and it is
        # boolean so the encoder excludes those keys instead of treating the
        # mask as an additive score bias.
        positions = torch.arange(x.size(1), device=x.device)
        key_padding_mask = positions.unsqueeze(0) >= seq_lengths.unsqueeze(1)

        # Transformer encoding
        transformer_out = self.transformer_encoder(x, src_key_padding_mask=key_padding_mask)

        # Masked mean pooling: weight 1 on valid steps, 0 on padded steps.
        valid = (~key_padding_mask).unsqueeze(2).to(transformer_out.dtype)
        feat = (valid * transformer_out).sum(dim=1) / valid.sum(dim=1)
        feat = self.dropout(feat)

        # Classification
        prediction = self.classifier(feat)
        if output_prob:
            prediction = prediction.sigmoid()

        return feat, prediction

    def get_output_dim(self):
        """Return the output dimension of the encoder"""
        return self.d_model

class ResNet50Encoder(nn.Module):
    """CXR encoder built on a torchvision ResNet50 backbone"""

    def __init__(self, hidden_size=256, pretrained=True):
        """
        Args:
            hidden_size: output feature dimension of the replaced final layer
            pretrained: load ``ResNet50_Weights.IMAGENET1K_V2`` when true,
                otherwise build the backbone without pretrained weights
        """
        super().__init__()
        self.hidden_size = hidden_size

        weights = ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        backbone = resnet50(weights=weights)
        # Swap the stock final layer for a projection to hidden_size.
        backbone.fc = nn.Linear(in_features=2048, out_features=hidden_size)
        self.backbone = backbone

    def forward(self, x):
        """Run ``x`` through the backbone and return its output."""
        features = self.backbone(x)
        return features

    def get_output_dim(self):
        """Return the output dimension of the encoder"""
        return self.hidden_size

class ViTEncoder(nn.Module):
    """CXR encoder built on a torchvision ViT-B/16 backbone"""

    def __init__(self, hidden_size=256, pretrained=True):
        """
        Args:
            hidden_size: output feature dimension of the replaced head
            pretrained: load ``ViT_B_16_Weights.IMAGENET1K_V1`` when true,
                otherwise build the backbone without pretrained weights
        """
        super().__init__()
        self.hidden_size = hidden_size

        weights = ViT_B_16_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = vit_b_16(weights=weights)

        # ViT-B/16 hidden size is 768; replace the classifier head with a
        # projection from that width to hidden_size.
        backbone.heads.head = nn.Linear(768, hidden_size)
        self.backbone = backbone

    def forward(self, x):
        """Run ``x`` through the backbone and return its output."""
        features = self.backbone(x)
        return features

    def get_output_dim(self):
        """Return the output dimension of the encoder"""
        return self.hidden_size

# Convenience functions
def create_ehr_encoder(encoder_type, **kwargs):
    """Module-level shortcut for ``EHREncoderFactory.create_encoder``.

    Args:
        encoder_type: encoder identifier, passed through unchanged
        **kwargs: passed through unchanged to the factory

    Returns:
        Whatever ``EHREncoderFactory.create_encoder`` returns.
    """
    build = EHREncoderFactory.create_encoder
    return build(encoder_type, **kwargs)

def create_cxr_encoder(encoder_type, **kwargs):
    """Module-level shortcut for ``CXREncoderFactory.create_encoder``.

    Args:
        encoder_type: encoder identifier, passed through unchanged
        **kwargs: passed through unchanged to the factory

    Returns:
        Whatever ``CXREncoderFactory.create_encoder`` returns.
    """
    build = CXREncoderFactory.create_encoder
    return build(encoder_type, **kwargs)