

import torch
import torch.nn as nn


class ResidualMLPHead(nn.Module):
    """Per-task prediction head: input LayerNorm -> residual MLP block -> LayerNorm -> Linear.

    Computes ``output_linear(output_norm(dropout(relu(hidden_norm(
    hidden_linear(x_n) + hidden_residual(x_n))))))`` with ``x_n = input_norm(x)``.
    ``hidden_residual`` is a Linear projection when ``in_dim != hidden_dim`` and an
    Identity otherwise.

    Submodule attribute names are kept stable so state_dict keys of the form
    ``task_heads.<task>.<attr>.*`` stay loadable. The class lives at module level
    (rather than inside ``MultiTaskWrapper.__init__``) so that whole-model pickling,
    e.g. ``torch.save(model)``, works.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.1):
        super().__init__()
        # Construction order is deliberate: with a fixed RNG seed it reproduces the
        # same initial weights as building these modules one after another inline.
        self.input_norm = nn.LayerNorm(in_dim)
        self.hidden_linear = nn.Linear(in_dim, hidden_dim)
        self.hidden_norm = nn.LayerNorm(hidden_dim)
        self.hidden_activation = nn.ReLU()
        self.hidden_dropout = nn.Dropout(dropout)
        self.hidden_residual = (
            nn.Linear(in_dim, hidden_dim) if in_dim != hidden_dim else nn.Identity()
        )
        self.output_norm = nn.LayerNorm(hidden_dim)
        self.output_linear = nn.Linear(hidden_dim, out_dim)
        self._reset_linear_parameters()

    def _reset_linear_parameters(self):
        """Xavier-init the projections (gain 0.1 hidden/residual, 1.0 output); zero biases."""
        # nn.init functions already run under torch.no_grad().
        nn.init.xavier_uniform_(self.hidden_linear.weight, gain=0.1)
        nn.init.zeros_(self.hidden_linear.bias)
        if isinstance(self.hidden_residual, nn.Linear):
            nn.init.xavier_uniform_(self.hidden_residual.weight, gain=0.1)
            nn.init.zeros_(self.hidden_residual.bias)
        nn.init.xavier_uniform_(self.output_linear.weight, gain=1.0)
        nn.init.zeros_(self.output_linear.bias)

    def forward(self, x):
        x_norm = self.input_norm(x)
        # An Identity residual returns x_norm unchanged, so one call covers both cases.
        hidden = self.hidden_linear(x_norm) + self.hidden_residual(x_norm)
        hidden = self.hidden_dropout(self.hidden_activation(self.hidden_norm(hidden)))
        return self.output_linear(self.output_norm(hidden))


class MultiTaskWrapper(nn.Module):
    """Run a base model once and feed its features to one independent head per task.

    Args:
        base_model: callable invoked as ``base_model(images, text_item, packed_ts,
            static_data, task='classification')``. It may return a tensor, a
            ``(features,)`` / ``(features, align_loss)`` tuple, or a dict
            (see ``_extract_features`` for the lookup order).
        task_names: names of the tasks; one ``ResidualMLPHead`` is created per name.
        shared_dim: feature width the heads consume.
        out_dim: output width of every head.
        hidden_dim: hidden width inside each head (default 64).
        dropout: dropout probability inside each head (default 0.1).
        input_dim: optional width of the base model's (flattened) features. When it
            is given and differs from ``shared_dim``, the ``input_dim -> shared_dim``
            projection is created here. It is then part of ``parameters()`` and
            ``state_dict()`` before the first forward pass, so an optimizer built
            right after construction trains it. Without it, the projection is
            created lazily on the first forward pass.
    """

    # Dict keys searched (in order) after 'features' when locating the feature tensor.
    _SECONDARY_FEATURE_KEYS = ('hidden', 'embedding', 'representation')
    # Tensor-valued dict entries that must never be mistaken for features.
    _NON_FEATURE_KEYS = frozenset({'clip_loss', 'align_loss', 'eamc_info', 'graph_info'})

    def __init__(self, base_model, task_names, shared_dim=128, out_dim=1,
                 hidden_dim=64, dropout=0.1, input_dim=None):
        super().__init__()
        self.base_model = base_model
        self.task_names = task_names
        self.shared_dim = shared_dim
        self.out_dim = out_dim

        self.task_heads = nn.ModuleDict()
        for task_name in task_names:
            self.task_heads[task_name] = ResidualMLPHead(shared_dim, hidden_dim, out_dim, dropout)

        # Built after the heads so the heads' initialisation does not depend on it.
        if input_dim is not None and input_dim != shared_dim:
            self._feature_projection = nn.Linear(input_dim, shared_dim)

        print(f"Created MultiTaskWrapper with {len(task_names)} tasks:")
        for task_name in task_names:
            print(f"  - {task_name}: independent output head (shared_dim={shared_dim} -> out_dim={out_dim})")

    def _log_once(self, message):
        """Print ``message`` only if no feature-source message was printed before."""
        if not hasattr(self, '_extracted_key_warned'):
            print(message)
            self._extracted_key_warned = True

    def _last_features_tensor(self):
        """Return ``base_model.last_features`` when it exists and is a tensor, else None."""
        candidate = getattr(self.base_model, 'last_features', None)
        return candidate if isinstance(candidate, torch.Tensor) else None

    @staticmethod
    def _coerce_align_loss(value, base_data):
        """Normalise an auxiliary loss to a tensor, or None when unusable.

        Tensors pass through unchanged. Python numbers become a float tensor on
        ``base_data``'s device (only when ``base_data`` is a tensor). Anything else
        becomes None.
        """
        if isinstance(value, torch.Tensor):
            return value
        if isinstance(value, (int, float)) and isinstance(base_data, torch.Tensor):
            return torch.tensor(float(value), device=base_data.device)
        return None

    def _features_from_dict(self, output):
        """Pick the feature tensor from a dict output, or return None if none qualifies.

        Priority: 'features'; then 'hidden'/'embedding'/'representation'; then
        'logits' (if a dict, ``base_model.last_features``; if a tensor, the logits
        themselves); then ``base_model.last_features``; then the first tensor value
        whose key is not a known loss/info key.
        """
        candidate = output.get('features')
        if isinstance(candidate, torch.Tensor):
            self._log_once(f"[MultiTaskWrapper] Using 'features' from model output (shape: {candidate.shape})")
            return candidate

        for key in self._SECONDARY_FEATURE_KEYS:
            candidate = output.get(key)
            if isinstance(candidate, torch.Tensor):
                self._log_once(f"[MultiTaskWrapper] Using '{key}' from model output (shape: {candidate.shape})")
                return candidate

        logits = output.get('logits')
        if isinstance(logits, dict):
            last = self._last_features_tensor()
            if last is not None:
                self._log_once(f"[MultiTaskWrapper] Using base_model.last_features (shape: {last.shape})")
                return last
        elif isinstance(logits, torch.Tensor):
            self._log_once(f"[MultiTaskWrapper] WARNING: Using 'logits' as features (shape: {logits.shape}). This is unusual.")
            return logits

        last = self._last_features_tensor()
        if last is not None:
            self._log_once(f"[MultiTaskWrapper] Using base_model.last_features as fallback (shape: {last.shape})")
            return last

        for key, value in output.items():
            if isinstance(value, torch.Tensor) and key not in self._NON_FEATURE_KEYS:
                self._log_once(f"[MultiTaskWrapper] WARNING: Using '{key}' from model output as features (shape: {value.shape})")
                return value
        return None

    def _extract_features(self, base_output):
        """Split the base model output into ``(feature_tensor, align_loss_or_None)``.

        Raises:
            TypeError: unsupported output type, or the extracted features are not a tensor.
            ValueError: a tuple of unsupported length, or no feature tensor could be found.
        """
        if isinstance(base_output, tuple):
            if len(base_output) not in (1, 2):
                raise ValueError(f"Expected a (features,) or (features, align_loss) tuple, "
                                 f"got a tuple of length {len(base_output)}")
            base_data = base_output[0]
            raw_loss = base_output[1] if len(base_output) == 2 else None
            align_loss = self._coerce_align_loss(raw_loss, base_data)
        elif isinstance(base_output, dict):
            base_data = self._features_from_dict(base_output)
            # Explicit None check: `a or b` would call bool() on a tensor (which raises
            # for multi-element tensors) and would drop a legitimate zero-valued loss.
            raw_loss = base_output.get('clip_loss')
            if raw_loss is None:
                raw_loss = base_output.get('align_loss')
            align_loss = self._coerce_align_loss(raw_loss, base_data)
        elif isinstance(base_output, torch.Tensor):
            base_data, align_loss = base_output, None
        else:
            raise TypeError(f"Unexpected base_output type: {type(base_output)}. Expected tuple, dict, or tensor.")

        if base_data is None:
            raise ValueError(f"Could not extract features from base_output. Type: {type(base_output)}. "
                             f"If dict, keys: {list(base_output.keys()) if isinstance(base_output, dict) else 'N/A'}")
        if not isinstance(base_data, torch.Tensor):
            raise TypeError(f"Extracted base_data is not a tensor. Type: {type(base_data)}")
        return base_data, align_loss

    def _create_feature_projection(self, in_dim, features):
        """Lazily build and register the ``in_dim -> shared_dim`` projection.

        The projection is placed on the device of ``features`` and uses the dtype
        of the heads' parameters, so a half/bfloat16 model stays consistent.
        """
        reference = next(self.task_heads.parameters(), None)
        dtype = reference.dtype if reference is not None else torch.get_default_dtype()
        projection = nn.Linear(in_dim, self.shared_dim).to(device=features.device, dtype=dtype)
        print(f"[MultiTaskWrapper] WARNING: created feature projection {in_dim} -> {self.shared_dim} "
              f"during forward; pass input_dim={in_dim} to the constructor so it is registered "
              f"before the optimizer is built.")
        self._feature_projection = projection
        return projection

    def _process_features(self, features):
        """Reshape ``features`` to 2-D ``[batch, shared_dim]``.

        - 0-D becomes ``[1, 1]``; 1-D ``[D]`` becomes ``[1, D]``.
        - 3-D keeps position 0 along dim 1. This assumes ``[batch, seq, dim]`` with a
          leading summary token — TODO confirm against the base model.
        - Higher ranks are flattened after the batch dim.
        - If the resulting width is not ``shared_dim``, a linear projection is applied.

        Raises:
            TypeError: ``features`` is not a tensor.
            ValueError: 3-D input with an empty dim 1, or a width that differs from the
                one the existing projection was built for.
        """
        if not isinstance(features, torch.Tensor):
            raise TypeError(f"Expected torch.Tensor, got {type(features)}")

        if features.dim() == 0:
            features = features.reshape(1, 1)
        elif features.dim() == 1:
            features = features.unsqueeze(0)
        elif features.dim() == 3:
            if features.shape[1] == 0:
                # Averaging an empty dim would silently produce NaN features.
                raise ValueError(f"Cannot pool 3-D features with an empty dim 1 (shape {tuple(features.shape)})")
            features = features[:, 0, :]
        elif features.dim() > 3:
            # flatten (unlike view) also accepts non-contiguous inputs and empty batches.
            features = features.flatten(start_dim=1)

        in_dim = features.shape[1]
        if in_dim != self.shared_dim:
            projection = getattr(self, '_feature_projection', None)
            if projection is None:
                projection = self._create_feature_projection(in_dim, features)
            else:
                expected = getattr(projection, 'in_features', in_dim)
                if expected != in_dim:
                    raise ValueError(f"Feature width changed: projection expects {expected}, got {in_dim}")
            features = projection(features)
        return features

    def forward(self, images, text_item, packed_ts, static_data, task=None):
        """Run the base model once and apply every task head to the shared features.

        ``task`` is accepted for interface compatibility but is not used. The base
        model is always called with ``task='classification'`` and all heads run.

        Returns:
            ``{task_name: logits}``, or ``({task_name: logits}, align_loss)`` when the
            base model produced an auxiliary alignment loss. Heads with a single
            output column are squeezed from ``[batch, 1]`` to ``[batch]``.
        """
        base_output = self.base_model(images, text_item, packed_ts, static_data, task='classification')
        base_features, align_loss = self._extract_features(base_output)
        features = self._process_features(base_features)

        task_outputs = {}
        for task_name in self.task_names:
            task_logits = self.task_heads[task_name](features)
            if task_logits.dim() == 2 and task_logits.shape[1] == 1:
                task_logits = task_logits.squeeze(-1)
            task_outputs[task_name] = task_logits

        if align_loss is not None:
            return task_outputs, align_loss
        return task_outputs
