import torch
import torch.nn as nn

from stork.loss_stacks import LossStack
import stork

class CSTLossStack(LossStack):
    """Base class for the regression loss stacks in this module.

    Subclasses assign ``self.loss_fn``, a callable invoked as
    ``loss_fn(output, target)`` or, when density weighting is active,
    ``loss_fn(output, target, weight=weight)``. Every ``compute_loss`` call
    also stores R^2 scores in ``self.metrics`` (named by ``get_metric_names``).
    """

    def __init__(self,
                 mask=None,
                 density_weighting_func=False):
        """
        Args:
            mask: optional tensor; when not None, both output and target are
                multiplied by it (broadcast with ``expand_as``) before the loss.
            density_weighting_func: optional callable ``f(target) -> weight``
                whose result is forwarded to ``loss_fn`` as ``weight``; any
                falsy value (the default ``False``) disables weighting.
        """
        super().__init__()
        self.mask = mask
        self.loss_fn = None  # to be defined in the child class
        self.density_weighting_func = density_weighting_func

    def get_R2(self, pred, target):
        """Compute per-feature R^2 over batch and time, rounded to 3 decimals.

        Args:
            pred: predicted series, shape (batch_size, timestep, nb_outputs).
            target: ground-truth series, shape (batch_size, timestep, nb_outputs).

        Returns:
            ``[r2_0, mean_r2]`` when there is a single output feature, otherwise
            ``[r2_0, r2_1, mean_r2]``. Only the first two features are reported
            individually; the mean covers all features. A constant target
            feature (SST == 0) yields inf/nan for that feature.
        """
        # Julian Rossbroich
        # modified july 2024

        # Sums run over batch and time (dims 0 and 1), giving one value per
        # feature; the target mean used for SST is taken across all samples too.
        ssr = torch.sum((target - pred) ** 2, dim=(0, 1))
        sst = torch.sum((target - torch.mean(target, dim=(0, 1))) ** 2, dim=(0, 1))
        r2 = (1 - ssr / sst).detach().cpu().numpy()

        if r2.shape[0] == 1:
            return [float(r2[0].round(3)), float(r2.mean().round(3))]
        return [float(r2[0].round(3)), float(r2[1].round(3)), float(r2.mean().round(3))]

    def get_metric_names(self, outputnum=2):
        """Return names matching the list produced by ``get_R2``.

        Args:
            outputnum: number of regression output features. A single output
                gives two metrics (``r2x``, mean ``r2``); any other count gives
                three (``r2x``, ``r2y``, mean ``r2``), mirroring ``get_R2``.
        """
        # Julian Rossbroich
        # modified july 2024
        if outputnum == 1:
            return ["r2x", "r2"]
        return ["r2x", "r2y", "r2"]

    def compute_loss(self, output, target):
        """Apply optional masking and density weighting, record R^2, return the loss.

        The loss itself is whatever ``self.loss_fn`` computes (MSE, RMSE, MAE, ...).

        Raises:
            NotImplementedError: if the subclass never assigned ``self.loss_fn``.
        """
        if self.loss_fn is None:
            raise NotImplementedError(
                f"{type(self).__name__} must set self.loss_fn before computing a loss."
            )

        if self.mask is not None:
            output = output * self.mask.expand_as(output)
            target = target * self.mask.expand_as(target)

        weight = self.density_weighting_func(target) if self.density_weighting_func else None

        self.metrics = self.get_R2(output, target)
        # Only pass `weight` when one was computed, so loss callables without a
        # `weight` keyword (e.g. torch.nn loss modules) keep working.
        if weight is None:
            return self.loss_fn(output, target)
        return self.loss_fn(output, target, weight=weight)

    def predict(self, output):
        """Identity: regression outputs are used as predictions unchanged."""
        return output

    def __call__(self, output, targets):
        """Alias for ``compute_loss``."""
        return self.compute_loss(output, targets)


class MeanSquareError(CSTLossStack):
    """Mean squared error, optionally scaled element-wise by a density weight."""

    def __init__(self, mask=None, density_weighting_func=False):
        super().__init__(mask=mask, density_weighting_func=density_weighting_func)
        self.loss_fn = self._weighted_MSEloss

    def _weighted_MSEloss(self, output, target, weight=None):
        """Return mean(weight * (output - target)^2); weight defaults to 1."""
        sq_err = (output - target) ** 2
        if weight is None:
            return torch.mean(sq_err)
        return torch.mean(weight * sq_err)


class RootMeanSquareError(CSTLossStack):
    """Root-mean-square error, optionally scaled element-wise by a density weight."""

    def __init__(self, mask=None, density_weighting_func=False):
        super().__init__(mask=mask, density_weighting_func=density_weighting_func)
        self.loss_fn = self._weighted_RMSEloss

    def _weighted_RMSEloss(self, output, target, weight=None):
        """Return sqrt(mean(weight * (output - target)^2)); weight defaults to 1."""
        squared = (output - target) ** 2
        terms = squared if weight is None else weight * squared
        return torch.sqrt(torch.mean(terms))


class MeanAbsoluteError(CSTLossStack):
    """Mean absolute error, optionally scaled element-wise by a density weight."""

    def __init__(self, mask=None, density_weighting_func=False):
        super().__init__(mask=mask, density_weighting_func=density_weighting_func)
        self.loss_fn = self._weighted_MAEloss

    def _weighted_MAEloss(self, output, target, weight=None):
        """Return mean(weight * |output - target|); weight defaults to 1."""
        abs_err = torch.abs(output - target)
        if weight is not None:
            abs_err = weight * abs_err
        return torch.mean(abs_err)


class HuberLoss(CSTLossStack):
    """Smooth-L1 loss with threshold ``delta`` (a scaled Huber loss).

    NOTE: PyTorch's smooth-L1 with ``beta=delta`` equals the classical Huber
    loss divided by ``delta``; the two coincide for the default ``delta=1``.
    """

    def __init__(self, delta=1.0, mask=None, density_weighting_func=False):
        """
        Args:
            delta: threshold between the quadratic and linear regimes
                (used as smooth-L1 ``beta``).
            mask: optional mask forwarded to ``CSTLossStack``.
            density_weighting_func: must be falsy; density weighting is not
                supported for this loss.

        Raises:
            ValueError: if ``density_weighting_func`` is truthy.
        """
        if density_weighting_func:
            raise ValueError("Density weighting not supported for Huber loss.")

        super().__init__(mask=mask)
        self.delta = delta
        # An nn.SmoothL1Loss module cannot be used directly as loss_fn: its
        # forward() takes no `weight` keyword, which compute_loss may pass.
        self.loss_fn = self._smooth_l1_loss

    def _smooth_l1_loss(self, output, target, weight=None):
        """Return the mean smooth-L1 loss, optionally weighted element-wise."""
        elementwise = nn.functional.smooth_l1_loss(
            output, target, reduction="none", beta=self.delta
        )
        if weight is not None:
            elementwise = weight * elementwise
        return torch.mean(elementwise)


class RootMeanSquareError_with_MaxOverTimeCrossEntropy(CSTLossStack):
    """Multi-task loss: RMSE regression loss + max-over-time cross-entropy loss.

    The last dimension of the network output holds ``_NUM_REG_OUTPUTS``
    regression channels followed by ``num_classes`` classification logits.
    """

    # Width of the regression part of both the output and the target tensors.
    _NUM_REG_OUTPUTS = 2

    def __init__(self,
                 num_classes: int,
                 time_dim: int = 1,
                 alpha: float = 1.0,
                 beta: float = 1.0,
                 mask=None,
                 density_weighting_func=False):
        """
        Args:
            num_classes: number of classes of the classification task.
            time_dim: time-dimension index passed to MaxOverTimeCrossEntropy
                (default 1). NOTE(review): ``compute_loss`` slices ``targets``
                assuming time is dim 1 regardless of this value.
            alpha: weight of the RMSE loss.
            beta: weight of the cross-entropy loss.
            mask: optional mask forwarded to the RMSE sub-loss.
            density_weighting_func: optional density weighting forwarded to
                the RMSE sub-loss.
        """
        super().__init__(mask=mask, density_weighting_func=density_weighting_func)

        # Sub-loss components.
        self.rmse_module = RootMeanSquareError(mask, density_weighting_func)
        self.ce_module = stork.loss_stacks.MaxOverTimeCrossEntropy(time_dimension=time_dim)

        # Multi-task configuration.
        self.num_classes = num_classes
        self.time_dim = time_dim
        self.alpha = alpha
        self.beta = beta

    def split_output(self, output):
        """Split ``output`` along its last dim into (regression, classification) parts."""
        n_reg = self._NUM_REG_OUTPUTS
        assert output.shape[-1] - n_reg == self.num_classes, "输出维度需等于类别数"
        return output[..., :n_reg], output[..., n_reg:]

    def compute_loss(self, output, targets):
        """Return ``alpha * RMSE + beta * cross-entropy`` and set ``self.metrics``.

        Args:
            output: tensor of shape (B, T, 2 + C): the 2 regression channels
                first, then C = ``num_classes`` class logits.
            targets: a single tensor (not a tuple). ``targets[:, :, :2]`` are
                the regression targets (B, T, 2); the class label is read from
                the last channel at time index 0, ``targets[:, 0, -1]``, and
                cast to ``long`` giving shape (B,).
        """
        n_reg = self._NUM_REG_OUTPUTS
        # Split targets into regression targets and per-sample class labels.
        reg_target = targets[:, :, :n_reg]
        cls_target = targets[:, 0, -1:].squeeze(dim=1).long()
        output_reg, output_cls = self.split_output(output)

        rmse_loss = self.rmse_module(output_reg, reg_target)
        ce_loss = self.ce_module(output_cls, cls_target)

        # R^2 metrics from the RMSE module, followed by the CE module's metrics.
        self.metrics = self.rmse_module.metrics + self.ce_module.metrics

        return self.alpha * rmse_loss + self.beta * ce_loss

    def get_metric_names(self, outputnum=2):
        """Return metric names in the same order as ``self.metrics``.

        ``outputnum`` is accepted for signature compatibility but ignored: the
        regression head always has ``_NUM_REG_OUTPUTS`` channels, so the R^2
        metric count is fixed by that width.
        """
        return (self.rmse_module.get_metric_names(self._NUM_REG_OUTPUTS)
                + self.ce_module.get_metric_names())

    def predict(self, output):
        """Return a tuple ``(regression prediction, classification prediction)``."""
        output_reg, output_cls = self.split_output(output)
        return self.rmse_module.predict(output_reg), self.ce_module.predict(output_cls)