"""
Module for modality routing and expert models.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class ModalityRouter(nn.Module):
    """Score the four modality paths from the joint numeric/text features.

    The numeric and text inputs are concatenated along dim 1 and passed
    through a three-layer MLP; the four resulting logits are normalised
    with a softmax over dim 1.
    """

    def __init__(self, input_dim_numeric=3, input_dim_text=2, hidden_dim=32):
        super().__init__()
        joint_dim = input_dim_numeric + input_dim_text
        layers = [
            nn.Linear(joint_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 4),  # one logit per modality path
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x_numeric, x_text):
        """Return routing probabilities for the modality paths.

        Args:
            x_numeric (Tensor): Numeric features; concatenated with
                ``x_text`` along dim 1.
            x_text (Tensor): Text features.

        Returns:
            Tensor: Softmax (over dim 1) of the four path logits.
        """
        joint = torch.cat((x_numeric, x_text), dim=1)
        path_logits = self.mlp(joint)
        return F.softmax(path_logits, dim=1)


class TaskRouter(nn.Module):
    """Router that decides between STL and MTL for one modality path.

    Inputs whose feature width differs from ``input_dim`` are tolerated:
    extra trailing features are dropped and missing ones are zero-filled.
    """

    def __init__(self, input_dim, hidden_dim=32):
        super().__init__()
        self.input_dim = input_dim
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2)  # STL vs MTL
        )

    def forward(self, x):
        """Forward pass for task router.

        Args:
            x (Tensor): Input features for a modality path, with the batch
                on dim 0. Tensors with more than two dimensions are
                flattened to ``(batch, features)``.

        Returns:
            Tensor: Softmax (over dim 1) of the two STL/MTL logits.
        """
        if x.dim() > 2:
            # flatten() copes with an empty batch, where reshape(batch, -1)
            # cannot infer the size of the -1 dimension.
            x = x.flatten(start_dim=1)

        width = x.shape[1]
        if width > self.input_dim:
            # Keep only the first input_dim features.
            x = x[:, :self.input_dim]
        elif width < self.input_dim:
            # new_zeros inherits dtype AND device from x, so the padded
            # tensor is not silently promoted to float32.
            padding = x.new_zeros(x.shape[0], self.input_dim - width)
            x = torch.cat([x, padding], dim=1)

        logits = self.mlp(x)
        return F.softmax(logits, dim=1)


class ModalityTransform(nn.Module):
    """Build the inputs for the four modality paths.

    T1 is the raw text features and N1 the raw numeric features. T2 appends
    sin(MLP(numeric)) to the text features, and N2 appends sin(MLP(text)) to
    the numeric features.
    """

    def __init__(self, input_dim_numeric=3, input_dim_text=6, hidden_dim=32):
        super().__init__()
        self.input_dim_numeric = input_dim_numeric
        self.input_dim_text = input_dim_text
        # Cross-modal projections used by the two fusion paths.
        self.numeric_to_text = self._projection(
            input_dim_numeric, hidden_dim, input_dim_text
        )
        self.text_to_numeric = self._projection(
            input_dim_text, hidden_dim, input_dim_numeric
        )

    @staticmethod
    def _projection(in_dim, hidden_dim, out_dim):
        """Return a Linear -> ReLU -> Linear stack mapping in_dim to out_dim."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x_numeric, x_text):
        """Produce the per-path feature tensors.

        Args:
            x_numeric (Tensor): Numeric features.
            x_text (Tensor): Text features.

        Returns:
            tuple: ``(T1, N1, T2, N2)`` where T1 and N1 are the inputs
            themselves and T2/N2 are concatenations along dim 1.
        """
        numeric_as_text = self.numeric_to_text(x_numeric).sin()
        text_as_numeric = self.text_to_numeric(x_text).sin()

        fused_text = torch.cat((x_text, numeric_as_text), dim=1)        # T2
        fused_numeric = torch.cat((x_numeric, text_as_numeric), dim=1)  # N2

        # T1 and N1 pass the inputs through untouched.
        return x_text, x_numeric, fused_text, fused_numeric


class ExpertModel(nn.Module):
    """Base class for expert models (STL and MTL).

    Encodes input features with a two-layer ReLU MLP for downstream
    prediction heads. Inputs whose feature width differs from ``input_dim``
    are tolerated: extra trailing features are dropped and missing ones are
    zero-filled.
    """

    def __init__(self, input_dim, hidden_dim=32):
        super().__init__()
        self.input_dim = input_dim
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

    def forward(self, x):
        """Forward pass for expert model encoder.

        Args:
            x (Tensor): Input features with the batch on dim 0. Tensors with
                more than two dimensions are flattened to
                ``(batch, features)``.

        Returns:
            Tensor: Encoded features of width ``hidden_dim``.
        """
        if x.dim() > 2:
            # flatten() copes with an empty batch, where reshape(batch, -1)
            # cannot infer the size of the -1 dimension.
            x = x.flatten(start_dim=1)

        width = x.shape[1]
        if width > self.input_dim:
            # Keep only the first input_dim features.
            x = x[:, :self.input_dim]
        elif width < self.input_dim:
            # new_zeros inherits dtype AND device from x, so the padded
            # tensor is not silently promoted to float32.
            padding = x.new_zeros(x.shape[0], self.input_dim - width)
            x = torch.cat([x, padding], dim=1)

        return self.encoder(x)


class STLExpert(ExpertModel):
    """Single Task Learning expert.

    Adds one linear head on top of the shared ExpertModel encoder and emits
    a single prediction column.
    """

    def __init__(self, input_dim, hidden_dim=32):
        super().__init__(input_dim, hidden_dim)
        self.head = nn.Linear(in_features=hidden_dim, out_features=1)

    def forward(self, x):
        """Encode ``x`` and apply the single prediction head.

        Args:
            x (Tensor): Input features.

        Returns:
            Tensor: Output of the one-unit linear head.
        """
        encoded = super().forward(x)
        prediction = self.head(encoded)
        return prediction


class MTLExpert(ExpertModel):
    """Multi Task Learning expert.

    Shares the ExpertModel encoder between two independent linear heads,
    one per task.
    """

    def __init__(self, input_dim, hidden_dim=32):
        super().__init__(input_dim, hidden_dim)
        self.head1 = nn.Linear(in_features=hidden_dim, out_features=1)
        self.head2 = nn.Linear(in_features=hidden_dim, out_features=1)

    def forward(self, x):
        """Encode ``x`` once and apply both task heads.

        Args:
            x (Tensor): Input features.

        Returns:
            tuple: ``(head1 output, head2 output)`` computed from the same
            encoded features.
        """
        encoded = super().forward(x)
        first_task = self.head1(encoded)
        second_task = self.head2(encoded)
        return first_task, second_task


class RoutingModel(nn.Module):
    """Complete routing model with modality and task routing.

    Integrates the modality transform, the modality router, one task router
    per modality path and an (STL, MTL) expert pair per path, and mixes the
    expert outputs with the routing probabilities.
    """

    def __init__(self, input_dim_numeric=3, input_dim_text=2, hidden_dim=32):
        super().__init__()
        self.modality_transform = ModalityTransform(
            input_dim_numeric, input_dim_text, hidden_dim
        )
        self.modality_router = ModalityRouter(
            input_dim_numeric, input_dim_text, hidden_dim
        )

        # Feature width of each path as produced by ModalityTransform.
        # T2 is x_text concatenated with numeric_to_text(x_numeric), whose
        # output width is input_dim_text, so T2 is 2 * input_dim_text wide
        # (not text + numeric). Declaring the wrong width made the routers
        # and experts silently zero-pad or truncate the T2 features.
        # NOTE: this changes the T2 layer shapes whenever the two input
        # dims differ, so checkpoints saved with the old shapes will not
        # load into those layers.
        input_dims = [
            input_dim_text,  # T1
            input_dim_numeric,  # N1
            input_dim_text * 2,  # T2
            input_dim_numeric * 2,  # N2
        ]
        self.task_routers = nn.ModuleList([
            TaskRouter(dim, hidden_dim) for dim in input_dims
        ])

        # One (STL, MTL) expert pair per modality path.
        self.experts = nn.ModuleList([
            nn.ModuleList([
                STLExpert(dim, hidden_dim),
                MTLExpert(dim, hidden_dim)
            ])
            for dim in input_dims
        ])

    @staticmethod
    def _hard_sample(probs, tau):
        """Draw a one-hot Gumbel-softmax sample from routing probabilities."""
        # The epsilon keeps log() finite when a probability underflows to 0.
        return F.gumbel_softmax(torch.log(probs + 1e-10), tau=tau, hard=True)

    def forward(self, x_numeric, x_text, hard_routing=False, tau=0.1):
        """Forward pass for the full routing model.

        Args:
            x_numeric (Tensor): Numeric features.
            x_text (Tensor): Text features.
            hard_routing (bool): If True, replace the modality and task
                probabilities with hard Gumbel-softmax samples. Sampling is
                random, so outputs then vary from call to call.
            tau (float): Gumbel-softmax temperature, used only when
                ``hard_routing`` is True.

        Returns:
            tuple: ``(final_pred1, final_pred2, modality_probs, task_probs)``
            where ``task_probs`` is a list with one tensor per modality path.
        """
        # Path inputs in the order T1, N1, T2, N2.
        paths = self.modality_transform(x_numeric, x_text)

        modality_probs = self.modality_router(x_numeric, x_text)  # [batch_size, 4]
        if hard_routing:
            modality_probs = self._hard_sample(modality_probs, tau)

        final_pred1 = torch.zeros_like(x_numeric[:, 0:1])
        final_pred2 = torch.zeros_like(x_numeric[:, 0:1])
        task_probs = []

        for i, (path, task_router, (stl_expert, mtl_expert)) in enumerate(
            zip(paths, self.task_routers, self.experts)
        ):
            task_prob = task_router(path)  # [batch_size, 2]
            if hard_routing:
                task_prob = self._hard_sample(task_prob, tau)
            task_probs.append(task_prob)

            stl_pred = stl_expert(path)
            mtl_pred1, mtl_pred2 = mtl_expert(path)

            # Joint weight of (path i, STL) and (path i, MTL).
            mod_prob = modality_probs[:, i:i+1]  # [batch_size, 1]
            stl_weight = mod_prob * task_prob[:, 0:1]
            mtl_weight = mod_prob * task_prob[:, 1:2]

            # NOTE(review): the single STL output feeds BOTH task
            # predictions; confirm this is the intended design.
            final_pred1 += stl_weight * stl_pred
            final_pred2 += stl_weight * stl_pred

            final_pred1 += mtl_weight * mtl_pred1
            final_pred2 += mtl_weight * mtl_pred2

        return final_pred1, final_pred2, modality_probs, task_probs