

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, Optional, Tuple


import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, Optional, Tuple


class AdvancedForecastingLoss(nn.Module):
    """Composite loss for multi-step forecasts.

    ``forward`` returns a dict of individual terms plus ``'total'``, where

        total = alpha * mse + beta * huber + gamma * trend
                + magnitude_penalty + variance_penalty
                + 0.1 * reconstruction (only when trend/residual are given)

    The L1, quantile, trend-smoothness and residual-sparsity terms are
    reported in the dict but are not part of ``'total'``.

    Args:
        horizon: Number of steps the temporal/trend weights are precomputed for.
        alpha: Weight of the (temporally weighted) MSE term.
        beta: Weight of the (temporally weighted) Huber term.
        gamma: Weight of the trend term.
        delta: Huber delta.
        decay_rate: Time constant of the ``exp(-t / decay_rate)`` step weights.
        use_temporal_weighting: Weight per-step errors with the decaying weights.
        use_trend_loss: Include the first-difference (trend) term.
        use_quantile_loss: Report the pinball loss averaged over ``quantiles``.
        quantiles: Quantile levels; defaults to ``[0.1, 0.5, 0.9]``.
        use_magnitude_penalty: Penalise predictions much smaller than the target.
        magnitude_weight: Stored on the instance; ``_magnitude_penalty`` does
            not read it (it uses a fixed factor of 5).
        min_magnitude_ratio: Ratio mean|pred| / mean|target| below which the
            magnitude penalty is applied.
    """

    def __init__(self,
                 horizon: int = 24,
                 alpha: float = 0.5,
                 beta: float = 0.3,
                 gamma: float = 0.2,
                 delta: float = 1.0,       # Huber delta
                 decay_rate: float = 12.0,
                 use_temporal_weighting: bool = True,
                 use_trend_loss: bool = True,
                 use_quantile_loss: bool = True,
                 quantiles: Optional[list] = None,
                 use_magnitude_penalty: bool = True,
                 magnitude_weight: float = 0.01,
                 min_magnitude_ratio: float = 0.8):
        super().__init__()
        self.horizon = horizon
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.delta = delta
        self.decay_rate = decay_rate
        self.use_temporal_weighting = use_temporal_weighting
        self.use_trend_loss = use_trend_loss
        self.use_quantile_loss = use_quantile_loss
        # A fresh list per instance: a list literal as the default argument
        # would be shared (and mutable) across every instance.
        self.quantiles = [0.1, 0.5, 0.9] if quantiles is None else list(quantiles)

        self.use_magnitude_penalty = use_magnitude_penalty
        self.magnitude_weight = magnitude_weight
        self.min_magnitude_ratio = min_magnitude_ratio

        self.mse_loss = nn.MSELoss(reduction='none')
        self.huber_loss = nn.HuberLoss(reduction='none', delta=delta)
        self.l1_loss = nn.L1Loss(reduction='none')

        if use_temporal_weighting:
            self.temporal_weights = torch.exp(-torch.arange(horizon) / decay_rate)

        if use_trend_loss:
            self.trend_weights = self._compute_trend_weights(horizon)

    def _compute_trend_weights(self, horizon: int) -> torch.Tensor:
        """Return ``horizon`` weights rising linearly from 0.5 to 1.5."""
        return torch.linspace(0.5, 1.5, horizon)

    def _temporal_weighted_loss(self, loss: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Reduce a per-element loss of shape (batch, steps) to a scalar.

        Without temporal weighting this is a plain mean. Otherwise each step
        ``t`` is weighted by ``exp(-t / decay_rate)`` and the result is the
        weighted average. ``target`` is accepted for interface compatibility
        and is not used.
        """
        if not self.use_temporal_weighting:
            return loss.mean()

        n_steps = loss.size(1)
        weights = self.temporal_weights
        if n_steps > weights.numel():
            # Input is longer than the precomputed horizon: extend the same
            # decay curve instead of failing in expand_as.
            weights = torch.exp(-torch.arange(n_steps) / self.decay_rate)
        # Inputs shorter than the horizon use the leading part of the curve.
        weights = weights[:n_steps].to(loss.device)
        weights = weights.unsqueeze(0).expand_as(loss)
        return (loss * weights).sum() / weights.sum()

    def _trend_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """MSE between first differences plus 0.1 * squared gap in mean |diff|."""
        if not self.use_trend_loss:
            return torch.tensor(0.0, device=pred.device)

        if pred.size(1) < 2:
            # A single step has no differences; the mean over an empty tensor
            # would be NaN and poison the total loss.
            return torch.tensor(0.0, device=pred.device)

        pred_diff = torch.diff(pred, dim=1)
        target_diff = torch.diff(target, dim=1)

        trend_loss = F.mse_loss(pred_diff, target_diff)

        pred_trend_strength = torch.abs(pred_diff).mean()
        target_trend_strength = torch.abs(target_diff).mean()
        strength_loss = F.mse_loss(pred_trend_strength, target_trend_strength)

        return trend_loss + 0.1 * strength_loss

    def _quantile_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Pinball loss of ``pred`` against ``target`` averaged over ``self.quantiles``."""
        if not self.use_quantile_loss or not self.quantiles:
            return torch.tensor(0.0, device=pred.device)

        error = target - pred
        quantile_loss = 0.0
        for q in self.quantiles:
            quantile_loss = quantile_loss + torch.max(q * error, (q - 1) * error).mean()

        return quantile_loss / len(self.quantiles)

    def _magnitude_penalty(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Penalise predictions whose mean magnitude is too small.

        Returns ``5 * exp(-ratio)`` when ``ratio = mean|pred| / mean|target|``
        is below ``min_magnitude_ratio`` and zero otherwise (also zero when the
        target magnitude is below 1e-8).
        """
        zero = torch.tensor(0.0, device=pred.device)
        if not self.use_magnitude_penalty:
            return zero

        pred_magnitude = torch.mean(torch.abs(pred))
        target_magnitude = torch.mean(torch.abs(target))

        if target_magnitude < 1e-8:
            return zero

        magnitude_ratio = pred_magnitude / target_magnitude

        if magnitude_ratio < self.min_magnitude_ratio:
            return torch.exp(-magnitude_ratio) * 5
        return zero

    def _variance_penalty(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Penalise near-constant predictions.

        Returns ``0.01 * exp(-10 * ratio)`` when ``ratio = var(pred) / var(target)``
        is below 0.1 and zero otherwise (also zero when var(target) < 1e-8).
        """
        zero = torch.tensor(0.0, device=pred.device)
        pred_var = torch.var(pred)
        target_var = torch.var(target)

        if target_var < 1e-8:
            return zero

        var_ratio = pred_var / target_var
        if var_ratio < 0.1:
            return torch.exp(-var_ratio * 10) * 0.01
        return zero

    @staticmethod
    def _align_shapes(pred: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Bring ``pred`` and ``target`` to matching 2-D (batch, steps) tensors.

        1-D inputs gain a step axis (broadcast against a 2-D partner); inputs
        of any other rank are flattened per sample; a step-count mismatch is
        resolved by truncating both to the shorter length.
        """
        if pred.dim() == 1 and target.dim() == 1:
            return pred.unsqueeze(1), target.unsqueeze(1)
        if pred.dim() == 1 and target.dim() == 2:
            return pred.unsqueeze(1).expand(-1, target.size(1)), target
        if pred.dim() == 2 and target.dim() == 1:
            return pred, target.unsqueeze(1).expand(-1, pred.size(1))
        if not (pred.dim() == 2 and target.dim() == 2):
            # reshape (not view) so non-contiguous inputs are accepted too.
            pred = pred.reshape(pred.size(0), -1)
            target = target.reshape(target.size(0), -1)
        n_steps = min(pred.size(1), target.size(1))
        return pred[:, :n_steps], target[:, :n_steps]

    def forward(self, pred: torch.Tensor, target: torch.Tensor,
                trend: Optional[torch.Tensor] = None,
                residual: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Compute all loss terms.

        Args:
            pred: Predictions; aligned with ``target`` by ``_align_shapes``.
            target: Ground truth.
            trend: Optional trend component; used only together with ``residual``.
            residual: Optional residual component.

        Returns:
            Dict with ``'mse'``, ``'huber'``, ``'l1'``, ``'variance_penalty'``
            and ``'total'``; ``'trend'``, ``'quantile'`` and
            ``'magnitude_penalty'`` when enabled; ``'reconstruction'``,
            ``'trend_smoothness'`` and ``'residual_sparsity'`` when both
            ``trend`` and ``residual`` are given.
        """
        pred, target = self._align_shapes(pred, target)

        losses = {}

        losses['mse'] = self._temporal_weighted_loss(self.mse_loss(pred, target), target)
        losses['huber'] = self._temporal_weighted_loss(self.huber_loss(pred, target), target)
        losses['l1'] = self._temporal_weighted_loss(self.l1_loss(pred, target), target)

        if self.use_trend_loss:
            losses['trend'] = self._trend_loss(pred, target)

        if self.use_quantile_loss:
            losses['quantile'] = self._quantile_loss(pred, target)

        if self.use_magnitude_penalty:
            losses['magnitude_penalty'] = self._magnitude_penalty(pred, target)

        losses['variance_penalty'] = self._variance_penalty(pred, target)

        if trend is not None and residual is not None:
            losses['reconstruction'] = F.mse_loss(trend + residual, pred)

            # Mean squared first difference of the trend component.
            trend_diff = torch.diff(trend, dim=1)
            losses['trend_smoothness'] = F.mse_loss(trend_diff, torch.zeros_like(trend_diff))

            losses['residual_sparsity'] = torch.abs(residual).mean()

        total_loss = (self.alpha * losses['mse'] +
                      self.beta * losses['huber'] +
                      self.gamma * losses.get('trend', torch.tensor(0.0, device=pred.device)))

        # Out-of-place accumulation keeps the autograd graph free of in-place ops.
        if 'magnitude_penalty' in losses:
            total_loss = total_loss + losses['magnitude_penalty']

        total_loss = total_loss + losses['variance_penalty']

        if 'reconstruction' in losses:
            total_loss = total_loss + 0.1 * losses['reconstruction']

        losses['total'] = total_loss

        return losses



class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss on temperature-scaled cosine similarities.

    Args:
        temperature: Divisor applied to the cosine-similarity matrix.
        margin: Scaled similarities above this value are penalised for
            negative pairs.
    """

    def __init__(self, temperature: float = 0.07, margin: float = 0.5):
        super().__init__()
        self.margin = margin
        self.temperature = temperature

    def forward(self, features1: torch.Tensor, features2: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """Return the attraction term plus the repulsion term.

        Both feature tensors are L2-normalised along dim 1 and compared with a
        matrix product, so they are expected to be 2-D.

        NOTE(review): the positive mask is the outer product of ``labels``
        with itself, so a pair counts as positive only where both labels are
        non-zero; confirm that ``labels`` is meant to be a 0/1 indicator
        rather than a class id.
        """
        z1 = F.normalize(features1, dim=1)
        z2 = F.normalize(features2, dim=1)

        scaled_sim = z1.mm(z2.t()) / self.temperature

        pos_mask = labels.unsqueeze(1) * labels.unsqueeze(0)
        neg_mask = 1 - pos_mask

        # Positive pairs: reward high similarity (averaged over the mask).
        attract = (-scaled_sim * pos_mask).sum() / (pos_mask.sum() + 1e-8)

        # Negative pairs: penalise only the part of the similarity above the margin.
        over_margin = torch.clamp(scaled_sim - self.margin, min=0)
        repel = (over_margin * neg_mask).sum() / (neg_mask.sum() + 1e-8)

        return attract + repel


class FocalLoss(nn.Module):
    """Focal loss: cross-entropy scaled by ``alpha * (1 - p_t) ** gamma``.

    Args:
        alpha: Constant factor applied to every sample.
        gamma: Focusing exponent.
    """

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean focal loss.

        ``pred`` holds class logits along dim 1. A 1-D ``target`` is treated
        as class indices and expanded to one-hot; any other ``target`` is
        passed to ``F.cross_entropy`` unchanged.
        """
        if target.dim() == 1:
            target = F.one_hot(target, num_classes=pred.size(1)).to(pred.dtype)

        per_sample_ce = F.cross_entropy(pred, target, reduction='none')

        # exp(-CE) recovers the probability assigned to the true class.
        prob_true = (-per_sample_ce).exp()
        modulation = (1 - prob_true).pow(self.gamma) * self.alpha

        return (modulation * per_sample_ce).mean()


class MultiScaleLoss(nn.Module):
    """Weighted sum of ``AdvancedForecastingLoss`` totals over several scales.

    Args:
        scales: Forecast lengths to evaluate; defaults to ``[24, 12, 6, 3, 1]``.
        scale_weights: Weight of each scale's total, in the same order as
            ``scales``; defaults to ``[1.0, 0.8, 0.6, 0.4, 0.2]``.

    NOTE(review): ``base_loss`` is built with its default arguments
    (horizon=24) while predictions at the other scales have fewer steps;
    confirm that ``AdvancedForecastingLoss`` accepts inputs shorter than its
    horizon.
    """

    def __init__(self, scales: Optional[list] = None,
                 scale_weights: Optional[list] = None):
        super().__init__()
        # Fresh lists per instance: list literals as default arguments would be
        # stored on (and shared between) every instance created with defaults.
        self.scales = [24, 12, 6, 3, 1] if scales is None else list(scales)
        self.scale_weights = ([1.0, 0.8, 0.6, 0.4, 0.2] if scale_weights is None
                              else list(scale_weights))
        self.base_loss = AdvancedForecastingLoss()

    def forward(self, predictions: Dict[str, torch.Tensor],
                target: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Evaluate every scale present in ``predictions``.

        Args:
            predictions: Mapping that may contain ``'forecast_scale_<scale>'``
                entries; scales without an entry are skipped.
            target: Ground truth indexed as ``target[:, step]``.

        Returns:
            Dict with one ``'scale_<scale>'`` entry (the base-loss dict) per
            evaluated scale and ``'total'``, the weighted sum of their totals
            (the float ``0.0`` when no scale was evaluated).
        """
        total_loss = 0.0
        scale_losses = {}

        for i, scale in enumerate(self.scales):
            scale_key = f'forecast_scale_{scale}'
            if scale_key not in predictions:
                continue

            if scale != target.size(1):
                # Subsample the target at `scale` evenly spaced step indices.
                indices = torch.linspace(0, target.size(1) - 1, scale,
                                         dtype=torch.long, device=target.device)
                target_scale = target[:, indices]
            else:
                target_scale = target

            scale_loss = self.base_loss(predictions[scale_key], target_scale)
            total_loss = total_loss + self.scale_weights[i] * scale_loss['total']
            scale_losses[f'scale_{scale}'] = scale_loss

        scale_losses['total'] = total_loss
        return scale_losses


class DifferentialMultiScaleLoss(nn.Module):
    """MSE/Huber loss with trend, magnitude and prefix-scale terms.

    ``forward`` returns a dict whose ``'total'`` is

        alpha * mse + beta * huber + gamma * trend
        + delta * magnitude_penalty + 0.1 * multiscale

    Args:
        horizon: Stored on the instance; not read by the methods of this class.
        scales: Prefix lengths for the multi-scale term; defaults to
            ``[1, 3, 6, 12, 24]``.
        alpha: Weight of the MSE term.
        beta: Weight of the Huber term.
        gamma: Weight of the trend-consistency term.
        delta: Weight of the magnitude penalty.
        huber_delta: Delta of the Huber losses.
        use_magnitude_penalty: Enable the magnitude penalty.
        magnitude_weight: Scale factor inside the magnitude penalty.
    """

    def __init__(self,
                 horizon: int = 24,
                 scales: Optional[list] = None,
                 alpha: float = 0.4,
                 beta: float = 0.3,
                 gamma: float = 0.2,
                 delta: float = 0.1,
                 huber_delta: float = 1.0,
                 use_magnitude_penalty: bool = True,
                 magnitude_weight: float = 0.01):
        super().__init__()
        self.horizon = horizon
        # Fresh list per instance instead of a shared mutable default.
        self.scales = [1, 3, 6, 12, 24] if scales is None else list(scales)
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.delta = delta
        self.huber_delta = huber_delta
        self.use_magnitude_penalty = use_magnitude_penalty
        self.magnitude_weight = magnitude_weight

        self.mse_loss = nn.MSELoss(reduction='none')
        self.huber_loss = nn.HuberLoss(reduction='none', delta=huber_delta)
        self.l1_loss = nn.L1Loss(reduction='none')

        # Longer prefixes weigh more: 1.0 for scale 1, +0.1 per extra step.
        self.scale_weights = torch.ones(len(self.scales))
        for i, scale in enumerate(self.scales):
            self.scale_weights[i] = 1.0 + 0.1 * (scale - 1)

    def _magnitude_penalty(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Penalise predictions whose mean magnitude collapses towards zero.

        Returns ``magnitude_weight * exp(-10 * ratio)`` when
        ``ratio = mean|pred| / mean|target|`` is below 0.1 and zero otherwise
        (also zero when the target magnitude is below 1e-8).
        """
        zero = torch.tensor(0.0, device=pred.device)
        if not self.use_magnitude_penalty:
            return zero

        pred_magnitude = torch.mean(torch.abs(pred))
        target_magnitude = torch.mean(torch.abs(target))

        if target_magnitude < 1e-8:
            return zero

        magnitude_ratio = pred_magnitude / target_magnitude

        if magnitude_ratio < 0.1:
            return torch.exp(-magnitude_ratio * 10) * self.magnitude_weight
        return zero

    def _trend_consistency_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """MSE between first differences plus 0.1 * squared gap in mean |diff|."""
        if pred.size(1) < 2:
            # A single step has no differences; the mean over an empty tensor
            # would be NaN and poison the total loss.
            return torch.tensor(0.0, device=pred.device)

        pred_diff = torch.diff(pred, dim=1)
        target_diff = torch.diff(target, dim=1)

        trend_loss = F.mse_loss(pred_diff, target_diff)

        pred_trend_strength = torch.abs(pred_diff).mean()
        target_trend_strength = torch.abs(target_diff).mean()
        strength_loss = F.mse_loss(pred_trend_strength, target_trend_strength)

        return trend_loss + 0.1 * strength_loss

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Compute all terms for ``pred`` and ``target`` indexed as ``[:, step]``.

        Returns:
            Dict with ``'mse'``, ``'huber'``, ``'l1'``, ``'multiscale'``,
            ``'trend'``, ``'magnitude_penalty'`` and ``'total'``. ``'l1'`` is
            reported only and is not part of ``'total'``.
        """
        losses = {}

        losses['mse'] = self.mse_loss(pred, target).mean()
        losses['huber'] = self.huber_loss(pred, target).mean()
        losses['l1'] = self.l1_loss(pred, target).mean()

        # Each scale scores the first `scale` steps; scales longer than the
        # prediction are skipped but still count in the divisor below.
        multiscale_loss = 0.0
        for i, scale in enumerate(self.scales):
            if scale <= pred.size(1):
                pred_scale = pred[:, :scale]
                target_scale = target[:, :scale]

                scale_mse = F.mse_loss(pred_scale, target_scale)
                scale_huber = F.huber_loss(pred_scale, target_scale, delta=self.huber_delta)

                scale_loss = 0.7 * scale_mse + 0.3 * scale_huber
                multiscale_loss = multiscale_loss + self.scale_weights[i] * scale_loss

        # max(..., 1) avoids a ZeroDivisionError when `scales` is empty.
        losses['multiscale'] = multiscale_loss / max(len(self.scales), 1)
        losses['trend'] = self._trend_consistency_loss(pred, target)
        losses['magnitude_penalty'] = self._magnitude_penalty(pred, target)
        total_loss = (self.alpha * losses['mse'] +
                      self.beta * losses['huber'] +
                      self.gamma * losses['trend'] +
                      self.delta * losses['magnitude_penalty'] +
                      0.1 * losses['multiscale'])

        losses['total'] = total_loss

        return losses

class MaskedMultiTaskLoss(nn.Module):
    """Regression loss over observed entries plus optional missingness BCE.

    ``target[..., 0]`` is read as the true value and ``target[..., 1]`` as the
    missing flag (1 = missing); only entries with flag 0 contribute to the
    regression term.

    Args:
        loss_type: ``"huber"``, ``"mae"``, or anything else for squared error.
        delta: Passed as ``beta`` to ``F.smooth_l1_loss`` for ``"huber"``.
        lambda_miss: Weight of the missingness BCE term in the total.
        use_quality: Apply per-feature quality weights when they are supplied.
        eps: Added to the quality range to avoid division by zero.

    NOTE(review): the ``"huber"`` option calls ``smooth_l1_loss``, which per
    the PyTorch documentation equals the Huber loss divided by ``beta``; the
    two agree only for ``delta == 1``. Confirm this is intended.
    """

    def __init__(self,
                 loss_type: str = "mse",
                 delta: float = 1.0,
                 lambda_miss: float = 0.3,
                 use_quality: bool = True,
                 eps: float = 1e-8):
        super().__init__()
        self.loss_type = loss_type
        self.delta = delta
        self.lambda_miss = lambda_miss
        self.use_quality = use_quality
        self.eps = eps

    def masked_regression_loss(self, y_pred, target, feature_quality=None):
        """Sum of element-wise losses on observed entries / number observed.

        The divisor is clamped to at least 1. When ``use_quality`` is set and
        ``feature_quality`` is given, it is min-max normalised, mapped into
        [0.3, 1.0] and used as a weight broadcast over the leading axis.
        """
        y_true, miss_flag = target[..., 0], target[..., 1]

        observed = (1.0 - miss_flag).to(y_pred.dtype)
        n_observed = observed.sum().clamp_min(1.0)

        if self.loss_type == "huber":
            per_elem = F.smooth_l1_loss(y_pred, y_true, reduction="none", beta=self.delta)
        elif self.loss_type == "mae":
            per_elem = (y_pred - y_true).abs()
        else:  # mse
            per_elem = (y_pred - y_true).pow(2)

        weight = 1.0
        if self.use_quality and feature_quality is not None:
            lo, hi = feature_quality.min(), feature_quality.max()
            normalised = (feature_quality - lo) / (hi - lo + self.eps)
            weight = (0.3 + 0.7 * normalised).unsqueeze(0).expand_as(per_elem)

        return (per_elem * observed * weight).sum() / n_observed

    def missingness_bce_loss(self, m_pred, target):
        """Mean binary cross-entropy between ``m_pred`` and the missing flags."""
        flags = target[..., 1].to(m_pred.dtype)
        return F.binary_cross_entropy(m_pred, flags, reduction="mean")

    def forward(self, output, target, feature_quality=None):
        """Return ``(total_loss, info)``.

        ``output`` is either a dict with ``"y_pred"`` and ``"m_pred"``
        (regression plus ``lambda_miss`` * missingness BCE) or a bare
        prediction tensor (regression only). A 4-D ``target`` is squeezed
        along dim -3 first. ``info`` holds Python floats under ``"L_total"``,
        ``"L_reg"`` and ``"L_miss"``.
        """
        if len(target.shape) == 4:
            target = target.squeeze(-3)

        if not isinstance(output, dict):
            reg_only = self.masked_regression_loss(output, target, feature_quality)
            return reg_only, {"L_total": reg_only.item(),
                              "L_reg": reg_only.item(),
                              "L_miss": 0.0}

        y_pred, m_pred = output["y_pred"], output["m_pred"]
        reg_term = self.masked_regression_loss(y_pred, target, feature_quality)
        miss_term = self.missingness_bce_loss(m_pred, target)
        combined = reg_term + self.lambda_miss * miss_term

        return combined, {"L_total": combined.item(),
                          "L_reg": reg_term.item(),
                          "L_miss": miss_term.item()}

def create_loss_function(loss_type: str = "advanced", **kwargs) -> nn.Module:
    """Build a loss module by name, forwarding ``kwargs`` to its constructor.

    Recognised names: ``"advanced"`` and ``"adaptive"`` (both build
    ``AdvancedForecastingLoss``), ``"multiscale"``, ``"contrastive"``,
    ``"focal"`` and ``"differential_multiscale"``.

    Raises:
        ValueError: If ``loss_type`` is not one of the names above.
    """
    # "adaptive" is an alias of "advanced".
    if loss_type in ("advanced", "adaptive"):
        return AdvancedForecastingLoss(**kwargs)
    if loss_type == "multiscale":
        return MultiScaleLoss(**kwargs)
    if loss_type == "contrastive":
        return ContrastiveLoss(**kwargs)
    if loss_type == "focal":
        return FocalLoss(**kwargs)
    if loss_type == "differential_multiscale":
        return DifferentialMultiScaleLoss(**kwargs)
    raise ValueError(f"Unknown loss type: {loss_type}")

def compute_Maskedmetrics(output, target):
    """Compute evaluation metrics for masked regression outputs.

    Args:
        output: Mapping with ``"y_pred"`` and optionally ``"m_pred"``
            (predicted missingness probability, thresholded at 0.5).
        target: Tensor whose last axis holds ``[true value, missing flag]``;
            a 4-D target is squeezed along dim 1 first. The inline shape
            comments assume ``[B, F]`` per channel — TODO confirm with callers.

    Returns:
        Dict of Python floats: masked ``mse``/``mae``/``rmse`` over observed
        entries, missing/observation rates, per-sample coverage statistics,
        feature/sample sparsity, prediction stability (when available),
        missingness-prediction quality (when ``m_pred`` is given) and error
        at the missing positions (when any exist).
    """
    if len(target.shape) == 4:
        target = target.squeeze(1)
    y_pred = output["y_pred"]          # [B, F]
    y_true = target[..., 0]            # [B, F]
    miss_flag = target[..., 1]         # [B, F]
    obs_mask = (1.0 - miss_flag).to(y_pred.dtype)

    m_pred = output.get("m_pred", None)

    num_obs = obs_mask.sum().clamp_min(1.0)
    num_missing = miss_flag.sum()
    total_elements = miss_flag.numel()

    # Errors are averaged over observed entries only.
    error = y_pred - y_true
    mse = (error ** 2 * obs_mask).sum() / num_obs
    mae = (error.abs() * obs_mask).sum() / num_obs
    rmse = torch.sqrt(mse)

    missing_rate = (num_missing / total_elements).item()
    observation_rate = (num_obs / total_elements).item()

    coverage_per_sample = obs_mask.sum(dim=1) / obs_mask.size(1)   # [B]
    avg_coverage = coverage_per_sample.mean().item()
    # The sample standard deviation of a single value is NaN; report 0.0 so a
    # batch of one does not poison downstream aggregation.
    if coverage_per_sample.numel() > 1:
        coverage_std = coverage_per_sample.std().item()
    else:
        coverage_std = 0.0

    missing_pred_quality = {}
    if m_pred is not None:
        missing_pred_binary = (m_pred > 0.5).float()
        missing_pred_quality['missing_pred_acc'] = (
            (missing_pred_binary == miss_flag).float().mean().item()
        )

        # Best effort: scikit-learn may be absent, and AUC is undefined when
        # only one class is present; both cases fall back to 0.0.
        try:
            from sklearn.metrics import roc_auc_score, f1_score
            # detach() so tensors that still require grad can be converted;
            # without it the conversion raises and the scores silently
            # become 0.0.
            flags_np = miss_flag.detach().cpu().numpy().flatten()
            missing_pred_auc = roc_auc_score(
                flags_np,
                m_pred.detach().cpu().numpy().flatten()
            )
            missing_pred_f1 = f1_score(
                flags_np,
                missing_pred_binary.detach().cpu().numpy().flatten()
            )
            missing_pred_quality['missing_pred_auc'] = float(missing_pred_auc)
            missing_pred_quality['missing_pred_f1'] = float(missing_pred_f1)
        except Exception:
            missing_pred_quality['missing_pred_auc'] = 0.0
            missing_pred_quality['missing_pred_f1'] = 0.0

    y_true = y_true.squeeze(1) if y_true.ndim == 3 else y_true
    miss_flag = miss_flag.squeeze(1) if miss_flag.ndim == 3 else miss_flag
    imputation_quality = {}
    if num_missing > 0:
        missing_positions = miss_flag.bool()
        if missing_positions.any():
            # Error at the positions flagged as missing.
            y_pred_missing = y_pred[missing_positions]
            y_true_missing = y_true[missing_positions]

            imputation_mse = F.mse_loss(y_pred_missing, y_true_missing).item()
            imputation_mae = F.l1_loss(y_pred_missing, y_true_missing).item()

            imputation_quality['imputation_mse'] = imputation_mse
            imputation_quality['imputation_mae'] = imputation_mae
            imputation_quality['imputation_rmse'] = float(np.sqrt(imputation_mse))

            relative_error = (torch.abs(y_pred_missing - y_true_missing)
                              / (torch.abs(y_true_missing) + 1e-8))
            imputation_quality['imputation_mape'] = relative_error.mean().item() * 100

    sparsity_metrics = {}
    feature_sparsity = 1.0 - obs_mask.sum(dim=0) / obs_mask.size(0)  # [F]
    sparsity_metrics['avg_feature_sparsity'] = feature_sparsity.mean().item()
    sparsity_metrics['max_feature_sparsity'] = feature_sparsity.max().item()
    sparsity_metrics['min_feature_sparsity'] = feature_sparsity.min().item()

    # Per-sample observation rate equals the coverage computed above.
    sample_sparsity = 1.0 - coverage_per_sample  # [B]
    sparsity_metrics['avg_sample_sparsity'] = sample_sparsity.mean().item()
    sparsity_metrics['max_sample_sparsity'] = sample_sparsity.max().item()
    sparsity_metrics['min_sample_sparsity'] = sample_sparsity.min().item()

    stability_metrics = {}
    if obs_mask.size(0) > 1:
        # Unobserved entries are zeroed before the variance is taken.
        pred_obs = y_pred * obs_mask
        pred_var = pred_obs.var(dim=0)  # [F]
        observed_features = obs_mask.sum(dim=0) > 0
        if observed_features.any():
            stability_metrics['prediction_variance'] = pred_var[observed_features].mean().item()

        if obs_mask.size(1) > 1:
            obs_mask_flat = obs_mask.flatten()
            if obs_mask_flat.sum() > 1:
                observed_preds = pred_obs.flatten()[obs_mask_flat.bool()]
                pred_centered = observed_preds - observed_preds.mean()
                if len(pred_centered) > 1:
                    # Lag-1 correlation of the flattened observed predictions.
                    pred_autocorr = torch.corrcoef(
                        torch.stack([pred_centered[:-1], pred_centered[1:]])
                    )[0, 1]
                    if not torch.isnan(pred_autocorr):
                        stability_metrics['prediction_autocorr'] = pred_autocorr.item()

    metrics = {
        "mse": mse.item(),
        "mae": mae.item(),
        "rmse": rmse.item(),

        "missing_rate": missing_rate,
        "observation_rate": observation_rate,
        "avg_coverage": avg_coverage,
        "coverage_std": coverage_std,

        **sparsity_metrics,
        **stability_metrics,
    }

    metrics.update(missing_pred_quality)

    metrics.update(imputation_quality)

    return metrics

