import torch
import torch.nn as nn
import torch.fft


class moving_avg(nn.Module):
    """Centred moving average along the time axis.

    Wraps ``nn.AvgPool1d`` with stride 1 and symmetric padding of
    ``kernel_size // 2``, so the output has the same length as the input.
    Because ``count_include_pad=False``, windows near the sequence edges are
    averaged over the real samples only, not over the zero padding.
    """

    def __init__(self, kernel_size):
        super().__init__()
        # An odd window keeps the average centred and the length unchanged.
        assert kernel_size % 2 == 1, "kernel_size 为奇数"
        half_window = kernel_size // 2
        self.avg = nn.AvgPool1d(
            kernel_size=kernel_size,
            stride=1,
            padding=half_window,
            count_include_pad=False,  # do not count padded zeros in the mean
        )

    def forward(self, x):
        """Return the moving-average trend of ``x``.

        Args:
            x: tensor laid out as (B, L, C); the average runs over L.

        Returns:
            Tensor in the same (B, L, C) layout holding the smoothed values.
        """
        # AvgPool1d pools over the last axis, so move L there and back.
        channels_first = x.permute(0, 2, 1)  # (B, C, L)
        smoothed = self.avg(channels_first)
        return smoothed.permute(0, 2, 1)  # (B, L, C)


class series_decomp(nn.Module):
    """
    Series decomposition block.

    Splits an input into a trend (its moving average over ``kernel_size``
    steps) and the residual left after removing that trend.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = moving_avg(kernel_size)

    def forward(self, x):
        """Return ``(residual, trend)`` where ``residual = x - trend``."""
        trend = self.moving_avg(x)
        return x - trend, trend


class PhaseLoss(nn.Module):
    """Mean wrapped phase difference between the spectra of two signals.

    Both inputs are transformed with a real FFT along ``dim``. The absolute
    per-bin phase difference is wrapped onto the shorter arc, giving values in
    [0, pi], and then averaged over all bins and batch elements.

    Args:
        dim: axis to transform. Defaults to -1 (the last axis), which is the
            original behaviour. For (B, L, C) inputs, pass ``dim=1`` so the
            FFT runs over time rather than over channels.
    """

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, pred, target):
        """Return the scalar mean wrapped phase difference of ``pred`` vs ``target``."""
        pred_phase = torch.angle(torch.fft.rfft(pred, dim=self.dim))
        target_phase = torch.angle(torch.fft.rfft(target, dim=self.dim))

        # angle() lies in (-pi, pi], so the absolute difference lies in [0, 2*pi].
        phase_diff = torch.abs(pred_phase - target_phase)

        # Wrap onto the shorter arc: d -> min(d, 2*pi - d), i.e. into [0, pi].
        phase_diff = torch.fmod(phase_diff + torch.pi, 2 * torch.pi) - torch.pi
        return phase_diff.abs().mean()

class LogCoshLoss(nn.Module):
    """Mean log-cosh of the prediction error.

    log(cosh(d)) behaves like d**2 / 2 for small |d| and like |d| - log(2) for
    large |d|. The loss is therefore smooth near zero while damping the
    influence of large deviations.
    """

    _LOG_2 = 0.6931471805599453  # math.log(2)

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        """Return the scalar mean of log(cosh(pred - target))."""
        abs_diff = torch.abs(pred - target)
        # Stable identity: log(cosh(d)) = |d| + log1p(exp(-2|d|)) - log(2)
        #                              = |d| + softplus(-2|d|) - log(2).
        # The naive torch.log(torch.cosh(d)) overflows to inf for |d| > ~89 in
        # float32, which also poisons the gradient.
        log_cosh = abs_diff + nn.functional.softplus(-2.0 * abs_diff) - self._LOG_2
        return torch.mean(log_cosh)

class SeasonalTrendLoss(nn.Module):
    """
    Decomposition-based loss.

    Prediction and target are each split by ``series_decomp`` into a seasonal
    residual and a moving-average trend. Inputs are expected in the
    (B, L, C) layout that the decomposition works on.

    - seasonal: MSE between the seasonal parts. A ``PhaseLoss`` module is
      constructed but is not currently added to the total.
    - trend: mean of ``log1p(eps + |pred_trend - true_trend| ** beta)``. This
      is about ``|err| ** beta`` for small errors and grows only
      logarithmically for large ones. A ``LogCoshLoss`` module is also
      available as ``self.trend_loss``, but ``forward`` does not use it.
    - weighting: fixed convex combination
      ``w_s * seasonal + (1 - w_s) * trend``. ``dynamic_weight`` provides a
      softmax-based alternative that ``forward`` does not use.

    Args:
        phase_weight: stored as ``self.phase_weight``; not used by ``forward``
            at present.
        kernel_size: moving-average window of the decomposition (must be odd).
        beta: exponent applied to the absolute trend error.
        w_s: weight of the seasonal term; the trend term gets ``1 - w_s``.
        eps: constant added inside ``log1p`` in the trend term.
    """

    def __init__(self, phase_weight=1.0, *, kernel_size=25, beta=1.5, w_s=0.5, eps=1e-6):
        super().__init__()
        self.mse = nn.MSELoss()
        self.phase = PhaseLoss()
        self.trend_loss = LogCoshLoss()
        self.phase_weight = phase_weight
        self.decomp = series_decomp(kernel_size=kernel_size)
        self.eps = eps
        self.beta = beta
        self.w_s = w_s
        print('新的损失函数')

    def dynamic_weight(self, seasonal_loss, trend_loss):
        """Return softmax weights ``(w_seasonal, w_trend)`` over the two loss values.

        The losses are detached, so the weights carry no gradient. The larger
        loss receives the larger weight, and the two weights sum to 1.
        """
        s = seasonal_loss.detach()
        t = trend_loss.detach()
        # Two-way softmax written as sigmoids: exp(s) / (exp(s) + exp(t)) ==
        # sigmoid(s - t). Unlike the explicit exp form, this cannot produce
        # inf / inf = NaN when a loss is large (> ~88 in float32).
        w_seasonal = torch.sigmoid(s - t)
        w_trend = torch.sigmoid(t - s)
        return w_seasonal, w_trend

    def forward(self, pred, true):
        """Return the weighted sum of the seasonal and trend losses."""
        pred_seasonal, pred_trend = self.decomp(pred)
        true_seasonal, true_trend = self.decomp(true)

        # 1. Seasonal term: MSE on the residual components.
        seasonal_loss = self.mse(pred_seasonal, true_seasonal)

        # 2. Trend term: log1p(eps + |err| ** beta) damps large trend errors.
        trend_err = torch.pow((pred_trend - true_trend).abs(), self.beta)
        trend_loss = torch.mean(torch.log1p(self.eps + trend_err))

        # 3. Fixed convex weighting (dynamic_weight is the adaptive alternative).
        w_seasonal = self.w_s
        w_trend = 1 - self.w_s

        return w_seasonal * seasonal_loss + w_trend * trend_loss


