import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from ..base import BaseFuseTrainer
from ..registry import ModelRegistry
from ..base.base_encoder import ResNet50Encoder

@ModelRegistry.register('resnet')
class ResNetModel(BaseFuseTrainer):
    """
    Single-modal ResNet model for CXR image classification

    - Encodes images with ResNet50Encoder and classifies them with a single
      linear layer
    - Supports the 'mortality', 'phenotype' and 'los' prediction tasks
    - Relies on BaseFuseTrainer for `classification_loss` and for the
      hyperparameter handling (`save_hyperparameters` / `hparams`); no
      training/validation/test step is defined in this class itself
    """

    def __init__(self, hparams):
        """
        Build the model for the task named in `hparams`.

        Args:
            hparams: Hyperparameter container. Must provide `task`, and must
                provide `num_classes` when the task is 'phenotype'. Optional
                entries read by this class: `hidden_size`, `pretrained`,
                `dropout`, `lr` and `patience`.

        Raises:
            ValueError: If `task` is not 'mortality', 'phenotype' or 'los'.
        """
        super().__init__()
        self.save_hyperparameters(hparams)
        self.task = self.hparams.task

        # Set task-specific number of classes
        if self.task == 'phenotype':
            self.num_classes = self.hparams.num_classes
        elif self.task == 'mortality':
            self.num_classes = 1
        elif self.task == 'los':
            self.num_classes = 7  # LoS has 7 classes (bins 2-8, excluding 0,1)
        else:
            raise ValueError(f"Unsupported task: {self.task}. Only 'mortality', 'phenotype', and 'los' are supported")

        self._init_model_components()

    def _init_model_components(self):
        """Initialize the ResNet encoder, the classification head and dropout"""
        # Resolve each optional hyperparameter exactly once so the value used
        # to build a layer and the value reported below cannot drift apart.
        hidden_size = getattr(self.hparams, 'hidden_size', 256)
        pretrained = getattr(self.hparams, 'pretrained', True)
        dropout_rate = getattr(self.hparams, 'dropout', 0.3)

        # NOTE: submodules are assigned in the same order as before
        # (encoder, classifier, dropout) so parameter ordering is unchanged.

        # ResNet50 backbone
        self.resnet_encoder = ResNet50Encoder(
            hidden_size=hidden_size,
            pretrained=pretrained
        )

        # Classification head, sized from whatever the encoder reports
        self.classifier = nn.Linear(
            self.resnet_encoder.get_output_dim(),
            self.num_classes
        )

        # Dropout for regularization
        self.dropout = nn.Dropout(dropout_rate)

        print(f"ResNet model initialized for {self.task} task")
        print(f"  - Hidden size: {hidden_size}")
        print(f"  - Pretrained: {pretrained}")
        print(f"  - Dropout: {dropout_rate}")
        print(f"  - Number of classes: {self.num_classes}")

    def forward(self, batch):
        """
        Forward pass for the ResNet model

        Args:
            batch: Dictionary containing 'cxr_imgs' and, optionally, 'labels'.
                When 'labels' is missing or None (e.g. at inference time) no
                loss is computed.

        Returns:
            Dictionary with:
                'loss': output of `classification_loss` on the raw logits, or
                    None when no labels were supplied
                'predictions': element-wise sigmoid of the logits
                'labels': the labels taken from the batch (None if absent)
                'features': encoder output after dropout
        """
        # Extract inputs
        images = batch['cxr_imgs']  # presumably [batch_size, 3, H, W] - TODO confirm
        # Labels are optional so the model can also be run on unlabeled batches.
        labels = batch.get('labels')

        # Forward pass through ResNet encoder, then dropout
        features = self.dropout(self.resnet_encoder(images))

        # Classification (raw logits)
        logits = self.classifier(features)

        # Apply sigmoid for probability output
        # NOTE(review): sigmoid is applied for every task, including the
        # 7-class 'los' task; confirm that downstream metrics expect
        # independent per-class scores rather than a softmax distribution.
        predictions_prob = torch.sigmoid(logits)

        # The loss is computed on the logits, not on the sigmoid output.
        loss = None
        if labels is not None:
            loss = self.classification_loss(logits, labels)

        return {
            'loss': loss,
            'predictions': predictions_prob,
            'labels': labels,
            'features': features
        }

    def configure_optimizers(self):
        """
        Configure optimizer and learning rate scheduler

        Returns:
            Dictionary with an AdamW 'optimizer' over all parameters
            (learning rate from `hparams.lr`, default 0.0001) and an
            'lr_scheduler' entry wrapping ReduceLROnPlateau, which halves the
            learning rate after `hparams.patience` (default 10) epochs without
            improvement of the monitored value.
        """
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=getattr(self.hparams, 'lr', 0.0001)
        )

        # NOTE(review): `verbose` is deprecated in recent PyTorch releases;
        # confirm against the pinned torch version before upgrading.
        plateau_scheduler = ReduceLROnPlateau(
            optimizer,
            factor=0.5,
            patience=getattr(self.hparams, 'patience', 10),
            mode='min',
            verbose=True
        )

        scheduler = {
            "scheduler": plateau_scheduler,
            # Presumably logged by the base trainer under this exact name;
            # verify, since the scheduler cannot step without it.
            "monitor": "loss/validation_epoch",
            "interval": "epoch",
            "frequency": 1
        }

        return {"optimizer": optimizer, "lr_scheduler": scheduler}