"""
Task-specific loss functions for different downstream tasks.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union, Dict, Any


class ClassificationLoss(nn.Module):
    """
    Classification loss with support for various loss types.

    Supported ``loss_type`` values:
        'ce'    - ``nn.CrossEntropyLoss`` built with ``label_smoothing`` and
                  ``class_weights``.
        'bce'   - ``nn.BCEWithLogitsLoss`` with ``class_weights`` passed as its
                  ``weight`` argument.
        'focal' - focal loss computed by ``_focal_loss``. Only ``focal_alpha``
                  and ``focal_gamma`` are used on this path; ``label_smoothing``
                  and ``class_weights`` are not applied.

    Raises:
        ValueError: if ``loss_type`` is not one of the values above.
    """

    def __init__(self,
                 loss_type: str = 'ce',
                 num_classes: Optional[int] = None,
                 label_smoothing: float = 0.0,
                 focal_alpha: float = 1.0,
                 focal_gamma: float = 2.0,
                 class_weights: Optional[torch.Tensor] = None):
        super().__init__()
        self.loss_type = loss_type
        self.num_classes = num_classes
        self.label_smoothing = label_smoothing
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        self.class_weights = class_weights

        if loss_type == 'ce':
            self.criterion = nn.CrossEntropyLoss(
                label_smoothing=label_smoothing,
                weight=class_weights
            )
        elif loss_type == 'bce':
            self.criterion = nn.BCEWithLogitsLoss(weight=class_weights)
        elif loss_type == 'focal':
            self.criterion = None  # Will be computed manually
        else:
            raise ValueError(f"Unsupported loss type: {loss_type}")

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute classification loss.

        Args:
            pred: [batch_size, num_classes] - predicted logits
            target: [batch_size] or [batch_size, num_classes] - target labels

        Returns:
            Classification loss scalar
        """
        if self.loss_type == 'focal':
            return self._focal_loss(pred, target)
        return self.criterion(pred, target)

    def _focal_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute focal loss: ``alpha * (1 - pt) ** gamma * ce``, averaged.

        ``ce`` is the unreduced cross-entropy and ``pt = exp(-ce)``, i.e. the
        probability the model assigns to the target. Three input layouts are
        handled:

        * ``pred`` has a single logit column -> binary case, ``ce`` is the
          element-wise BCE-with-logits against ``target`` reshaped to
          ``pred``'s shape.
        * ``target`` has the same rank as ``pred`` -> dense (one-hot or soft)
          targets, ``ce = -(target * log_softmax(pred)).sum(dim=1)``.
        * otherwise -> ``target`` holds class indices and ``ce`` comes from
          ``F.cross_entropy(..., reduction='none')``.
        """
        if pred.size(1) == 1:
            # A single logit is a sigmoid/binary problem; softmax CE over one
            # class would be identically zero.
            dense_target = target.reshape(pred.shape).to(pred.dtype)
            ce_loss = F.binary_cross_entropy_with_logits(
                pred, dense_target, reduction='none'
            )
        elif target.dim() == pred.dim():
            log_prob = F.log_softmax(pred, dim=1)
            ce_loss = -(target.to(log_prob.dtype) * log_prob).sum(dim=1)
        else:
            ce_loss = F.cross_entropy(pred, target, reduction='none')

        # ce_loss and pt share one shape, so the modulating factor lines up
        # element-for-element with the loss it scales.
        pt = torch.exp(-ce_loss)
        focal_loss = self.focal_alpha * (1 - pt).pow(self.focal_gamma) * ce_loss

        return focal_loss.mean()


class RegressionLoss(nn.Module):
    """
    Regression loss with support for various loss types.

    Supported ``loss_type`` values: 'mse', 'mae', 'huber' (uses
    ``huber_delta``), 'smooth_l1' (uses ``smooth_l1_beta``) and 'quantile'
    (pinball loss, uses ``quantile_tau``).

    Raises:
        ValueError: if ``loss_type`` is not one of the supported values.
    """

    # Loss types dispatched by forward(); checked when the module is built so
    # a misspelled name fails immediately rather than on the first batch.
    _SUPPORTED_LOSS_TYPES = ('mse', 'mae', 'huber', 'smooth_l1', 'quantile')

    def __init__(self,
                 loss_type: str = 'mse',
                 huber_delta: float = 1.0,
                 quantile_tau: float = 0.5,
                 smooth_l1_beta: float = 1.0):
        super().__init__()
        if loss_type not in self._SUPPORTED_LOSS_TYPES:
            raise ValueError(f"Unsupported loss type: {loss_type}")
        self.loss_type = loss_type
        self.huber_delta = huber_delta
        self.quantile_tau = quantile_tau
        self.smooth_l1_beta = smooth_l1_beta

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute regression loss.

        Args:
            pred: [batch_size, *] - predicted values
            target: [batch_size, *] - target values

        Returns:
            Regression loss scalar

        Raises:
            ValueError: if ``self.loss_type`` was changed to an unsupported
                value after construction.
        """
        if self.loss_type == 'mse':
            return F.mse_loss(pred, target)
        elif self.loss_type == 'mae':
            return F.l1_loss(pred, target)
        elif self.loss_type == 'huber':
            return F.huber_loss(pred, target, delta=self.huber_delta)
        elif self.loss_type == 'smooth_l1':
            return F.smooth_l1_loss(pred, target, beta=self.smooth_l1_beta)
        elif self.loss_type == 'quantile':
            return self._quantile_loss(pred, target)
        else:
            raise ValueError(f"Unsupported loss type: {self.loss_type}")

    def _quantile_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute quantile (pinball) loss.

        With ``error = target - pred``, under-predictions (``error >= 0``) cost
        ``tau * error`` and over-predictions cost ``(tau - 1) * error``; the
        result is the mean over all elements.
        """
        error = target - pred
        loss = torch.where(error >= 0,
                           self.quantile_tau * error,
                           (self.quantile_tau - 1) * error)
        return loss.mean()


class GenerationLoss(nn.Module):
    """
    Loss for text generation tasks.

    Token-level cross-entropy over flattened logits, skipping positions whose
    target equals ``ignore_index``. When ``length_penalty`` differs from 1.0
    the loss is additionally scaled by
    ``num_valid_tokens ** (length_penalty - 1)``.

    Raises:
        ValueError: if ``loss_type`` is anything other than 'ce'.
    """

    def __init__(self,
                 loss_type: str = 'ce',
                 ignore_index: int = -100,
                 label_smoothing: float = 0.0,
                 length_penalty: float = 1.0):
        super().__init__()
        self.loss_type = loss_type
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing
        self.length_penalty = length_penalty

        if loss_type == 'ce':
            self.criterion = nn.CrossEntropyLoss(
                ignore_index=ignore_index,
                label_smoothing=label_smoothing
            )
        else:
            raise ValueError(f"Unsupported loss type: {loss_type}")

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute generation loss.

        Args:
            pred: [batch_size, seq_len, vocab_size] - predicted logits
            target: [batch_size, seq_len] - target token ids

        Returns:
            Generation loss scalar
        """
        vocab_size = pred.size(-1)

        # reshape() rather than view(): shifted inputs such as logits[:, :-1]
        # and labels[:, 1:] are non-contiguous, and view() raises on those.
        pred_flat = pred.reshape(-1, vocab_size)
        target_flat = target.reshape(-1)

        loss = self.criterion(pred_flat, target_flat)

        # Apply length penalty if needed
        if self.length_penalty != 1.0:
            # Count non-ignored tokens
            valid_tokens = (target_flat != self.ignore_index).sum().float()
            if valid_tokens > 0:
                loss = loss * (valid_tokens ** (self.length_penalty - 1))

        return loss


class MultiTaskLoss(nn.Module):
    """
    Multi-task learning loss with learnable weights.

    Each task ``i`` owns a learnable scalar ``w_i`` (``self.task_weights[i]``)
    and contributes ``exp(-w_i) * loss_i + w_i`` to the total. The ``exp``
    keeps the multiplier positive and the additive ``w_i`` term penalises
    growing ``w_i`` to shrink a task's multiplier.

    Note that ``initial_weights`` initialises ``w_i`` directly, so the
    effective multiplier at start is ``exp(-initial_weight)``; the default of
    1.0 therefore starts every task at a multiplier of ``exp(-1)``.
    """

    def __init__(self, task_names: list, initial_weights: Optional[list] = None):
        super().__init__()
        self.task_names = task_names
        self.num_tasks = len(task_names)

        # Learnable task weights
        if initial_weights is None:
            initial_weights = [1.0] * self.num_tasks

        self.task_weights = nn.Parameter(torch.tensor(initial_weights, dtype=torch.float32))

    def forward(self, task_losses: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Compute weighted multi-task loss.

        Args:
            task_losses: dict of {task_name: loss_value}. Keys that are not in
                ``self.task_names`` are ignored; tasks missing from the dict
                contribute nothing.

        Returns:
            Weighted multi-task loss. Always a tensor: when none of the known
            tasks is present the result is a zero that is still connected to
            ``self.task_weights`` in the autograd graph.
        """
        total_loss = None

        for i, task_name in enumerate(self.task_names):
            if task_name not in task_losses:
                continue
            log_weight = self.task_weights[i]
            term = torch.exp(-log_weight) * task_losses[task_name] + log_weight
            # Out-of-place accumulation: never mutates a tensor autograd may
            # still need and tolerates terms of differing shape.
            total_loss = term if total_loss is None else total_loss + term

        if total_loss is None:
            # Keep the return type a tensor (and backward() callable) even
            # when there is nothing to weight.
            return self.task_weights.sum() * 0.0

        return total_loss


class ConsistencyLoss(nn.Module):
    """
    Consistency loss for semi-supervised learning.

    Supported ``consistency_type`` values:
        'mse'    - mean squared error between the two predictions.
        'kl'     - KL divergence between temperature-scaled softmax
                   distributions over the last dimension.
        'cosine' - one minus the cosine similarity of the per-sample
                   flattened predictions, averaged over the batch.

    ``temperature`` is only used by the 'kl' branch.
    """

    def __init__(self, consistency_type: str = 'mse', temperature: float = 0.5):
        super().__init__()
        self.consistency_type = consistency_type
        self.temperature = temperature

    def forward(self, pred1: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor:
        """
        Compute consistency loss between two predictions.

        Args:
            pred1: [batch_size, *] - first prediction
            pred2: [batch_size, *] - second prediction

        Returns:
            Consistency loss scalar

        Raises:
            ValueError: if ``self.consistency_type`` is not supported.
        """
        if self.consistency_type == 'mse':
            return F.mse_loss(pred1, pred2)
        elif self.consistency_type == 'kl':
            # F.kl_div takes log-probabilities as input and probabilities as
            # target, so pred2 provides the reference distribution and pred1
            # the one being compared against it.
            log_prob1 = F.log_softmax(pred1 / self.temperature, dim=-1)
            prob2 = F.softmax(pred2 / self.temperature, dim=-1)
            return F.kl_div(log_prob1, prob2, reduction='batchmean')
        elif self.consistency_type == 'cosine':
            # reshape() rather than view(): predictions that were transposed
            # or sliced upstream are non-contiguous and view() raises on them.
            pred1_norm = F.normalize(pred1.reshape(pred1.size(0), -1), dim=-1)
            pred2_norm = F.normalize(pred2.reshape(pred2.size(0), -1), dim=-1)
            similarity = torch.sum(pred1_norm * pred2_norm, dim=-1)
            return (1 - similarity).mean()
        else:
            raise ValueError(f"Unsupported consistency type: {self.consistency_type}")


class TemporalConsistencyLoss(nn.Module):
    """
    Temporal consistency loss for time series tasks.

    ``consistency_type`` selects which regularisers are summed:
        'smoothness'  - ``smoothness_weight`` * mean absolute first-order
                        difference along the sequence dimension.
        'periodicity' - ``periodicity_weight`` * mean absolute lag-1
                        auto-covariance term (see ``_periodicity_loss``).
        'both'        - the sum of the two.

    Any other value enables neither term and the loss is zero.
    """

    def __init__(self,
                 consistency_type: str = 'smoothness',
                 smoothness_weight: float = 1.0,
                 periodicity_weight: float = 0.1):
        super().__init__()
        self.consistency_type = consistency_type
        self.smoothness_weight = smoothness_weight
        self.periodicity_weight = periodicity_weight

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute temporal consistency loss.

        Args:
            pred: [batch_size, seq_len, feature_dim] - predicted sequence
            target: [batch_size, seq_len, feature_dim] - target sequence.
                Accepted for interface compatibility; it is not used, both
                terms are computed from ``pred`` alone.

        Returns:
            Temporal consistency loss scalar (always a tensor).
        """
        # Start from a tensor so the return type does not depend on which
        # terms are enabled.
        total_loss = pred.new_zeros(())

        if self.consistency_type in ['smoothness', 'both']:
            smoothness_loss = self._temporal_smoothness(pred)
            total_loss = total_loss + self.smoothness_weight * smoothness_loss

        if self.consistency_type in ['periodicity', 'both']:
            periodicity_loss = self._periodicity_loss(pred)
            total_loss = total_loss + self.periodicity_weight * periodicity_loss

        return total_loss

    def _temporal_smoothness(self, pred: torch.Tensor) -> torch.Tensor:
        """Mean absolute difference between consecutive time steps."""
        if pred.size(1) < 2:
            # No consecutive pair exists; the mean of an empty diff would be
            # NaN. Return a zero that stays attached to pred's graph.
            return pred.sum() * 0.0
        diff = pred[:, 1:] - pred[:, :-1]
        return torch.mean(torch.abs(diff))

    def _periodicity_loss(self, pred: torch.Tensor) -> torch.Tensor:
        """
        Simplified periodicity term based on lag-1 auto-covariance.

        The sequence is mean-centred along time, products of neighbouring
        steps are summed over time and features, divided by ``seq_len - 1``,
        and the absolute value is averaged over the batch. Minimising it
        pushes the lag-1 term towards zero. It is not normalised by the
        variance, so it scales with the magnitude of ``pred``.
        """
        seq_len = pred.size(1)
        if seq_len < 2:
            # Lag-1 products need at least two steps; avoid the 0/0 below.
            return pred.sum() * 0.0

        pred_centered = pred - pred.mean(dim=1, keepdim=True)

        # Lag-1 products, summed over time and feature dimensions
        autocorr = torch.sum(pred_centered[:, 1:] * pred_centered[:, :-1], dim=(1, 2))
        autocorr = autocorr / (seq_len - 1)

        return torch.abs(autocorr).mean()