# This source code is provided for the purposes of scientific reproducibility
# under the following limited license from Element AI Inc. The code is an
# implementation of the N-BEATS model (Oreshkin et al., N-BEATS: Neural basis
# expansion analysis for interpretable time series forecasting,
# https://arxiv.org/abs/1905.10437). The copyright to the source code is
# licensed under the Creative Commons - Attribution-NonCommercial 4.0
# International license (CC BY-NC 4.0):
# https://creativecommons.org/licenses/by-nc/4.0/.  Any commercial use (whether
# for the benefit of third parties or internally in production) requires an
# explicit license. The subject-matter of the N-BEATS model and associated
# materials are the property of Element AI Inc. and may be subject to patent
# protection. No license to patents is granted hereunder (whether express or
# implied). Copyright © 2020 Element AI Inc. All rights reserved.

"""
Loss functions for PyTorch.
"""
import torch
import torch as t
import torch.nn as nn
import numpy as np
import pdb
import random

from torch.fft import fft
import torch.nn.functional as F
def divide_no_nan(a, b):
    """
    a/b where every non-finite result (NaN, +Inf, -Inf) is replaced by 0.

    :param a: Numerator tensor.
    :param b: Denominator tensor (broadcast against ``a`` by ``/``).
    :return: New tensor holding the element-wise quotient with non-finite entries set to 0.
    """
    result = a / b
    # isfinite covers NaN (0/0) as well as both +Inf and -Inf (x/0 for x > 0 or x < 0);
    # comparing against np.inf alone missed the negative case.
    result[~t.isfinite(result)] = .0
    return result


class mape_loss(nn.Module):
    """Masked MAPE loss; the value is a fraction (it is not multiplied by 100)."""

    def __init__(self):
        super().__init__()

    def forward(self, insample: t.Tensor, freq: int,
                forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float:
        """
        MAPE loss as defined in: https://en.wikipedia.org/wiki/Mean_absolute_percentage_error

        ``insample`` and ``freq`` are accepted for a uniform loss signature but are not used.

        :param forecast: Forecast values. Shape: batch, time
        :param target: Target values. Shape: batch, time
        :param mask: 0/1 mask. Shape: batch, time
        :return: Loss value (mean of |forecast - target| * mask / target, with 0/0 and x/0 terms zeroed)
        """
        # mask / target doubles as the per-point weight; divide_no_nan zeroes entries where target == 0.
        inverse_target = divide_no_nan(mask, target)
        scaled_error = (forecast - target) * inverse_target
        return t.mean(t.abs(scaled_error))


class smape_loss(nn.Module):
    """Masked sMAPE loss, scaled to a 0-200 percentage range."""

    def __init__(self):
        super().__init__()

    def forward(self, insample: t.Tensor, freq: int,
                forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float:
        """
        sMAPE loss as defined in https://robjhyndman.com/hyndsight/smape/ (Makridakis 1993)

        ``insample`` and ``freq`` are accepted for a uniform loss signature but are not used.

        :param forecast: Forecast values. Shape: batch, time
        :param target: Target values. Shape: batch, time
        :param mask: 0/1 mask. Shape: batch, time
        :return: Loss value
        """
        numerator = t.abs(forecast - target)
        # ``.data`` takes the denominator out of the autograd graph, so gradients flow
        # only through the numerator.
        denominator = t.abs(forecast.data) + t.abs(target.data)
        ratio = divide_no_nan(numerator, denominator)
        return 200 * t.mean(ratio * mask)

# Determine the dominant (main) frequencies.
def main_freq_part(x, k=10, rfft=True):
    """
    Split ``x`` into the part carried by its ``k`` strongest frequency bins and the rest.

    The transform is taken along dim 1 (assumed to be the time axis, e.g. (B, L, C) —
    TODO confirm against callers). For each position along the other dims, the ``k``
    bins with the largest magnitude are kept, the rest are zeroed, and the result is
    transformed back to the time domain.

    :param x: Real input tensor; FFT is applied along dim 1.
    :param k: Number of frequency bins to keep. Values larger than the number of bins
        are clamped, which keeps every bin.
    :param rfft: If True, use ``rfft``/``irfft``; otherwise use the full ``fft``/``ifft``.
    :return: ``(x - x_filtered, x_filtered)``. ``x_filtered`` is always cast to float32.
    """
    if rfft:
        xf = torch.fft.rfft(x, dim=1)
    else:
        xf = torch.fft.fft(x, dim=1)

    magnitude = xf.abs()
    # topk raises when k exceeds the dimension size; keeping every bin is the natural limit.
    k = min(k, magnitude.size(1))
    indices = torch.topk(magnitude, k, dim=1).indices

    # A real 0/1 mask broadcasts against the complex spectrum.
    mask = torch.zeros_like(magnitude)
    mask.scatter_(1, indices, 1.0)
    xf_filtered = xf * mask

    # Pass the original length explicitly: without ``n``, irfft reconstructs
    # 2 * (bins - 1) samples, which is one short for odd-length inputs and makes
    # ``x - x_filtered`` fail with a shape mismatch.
    length = x.size(1)
    if rfft:
        x_filtered = torch.fft.irfft(xf_filtered, n=length, dim=1).float()
    else:
        x_filtered = torch.fft.ifft(xf_filtered, n=length, dim=1).real.float()

    norm_input = x - x_filtered
    return norm_input, x_filtered

class mase_loss(nn.Module):
    """Masked MASE loss, scaled by the in-sample seasonal-naive error."""

    def __init__(self):
        super().__init__()

    def forward(self, insample: t.Tensor, freq: int,
                forecast: t.Tensor, target: t.Tensor, mask: t.Tensor) -> t.float:
        """
        MASE loss as defined in "Scaled Errors" https://robjhyndman.com/papers/mase.pdf

        ``freq`` is expected to be a positive lag: with ``freq == 0`` the slice
        ``insample[:, :-freq]`` is empty.

        :param insample: Insample values. Shape: batch, time_i
        :param freq: Frequency value
        :param forecast: Forecast values. Shape: batch, time_o
        :param target: Target values. Shape: batch, time_o
        :param mask: 0/1 mask. Shape: batch, time_o
        :return: Loss value
        """
        # Mean absolute difference between each in-sample point and the point ``freq`` steps earlier.
        seasonal_naive_error = t.abs(insample[:, freq:] - insample[:, :-freq])
        scale = t.mean(seasonal_naive_error, dim=1)
        # Series with a zero scale get weight 0 instead of Inf/NaN.
        masked_inv_scale = divide_no_nan(mask, scale[:, None])
        return t.mean(t.abs(target - forecast) * masked_inv_scale)



class moving_avg(nn.Module):
    """Centered moving average along the time axis (stride 1, output length = input length)."""

    def __init__(self, kernel_size):
        super().__init__()
        assert kernel_size % 2 == 1, "kernel_size 为奇数"
        half_window = kernel_size // 2
        # count_include_pad=False: windows at the edges average only real samples,
        # so the zero padding does not drag the trend toward 0.
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=1,
                                padding=half_window, count_include_pad=False)

    def forward(self, x):
        """Input layout is (B, L, C); AvgPool1d pools the last dim, so pool over (B, C, L) and swap back."""
        channels_first = x.permute(0, 2, 1)
        pooled = self.avg(channels_first)
        return pooled.permute(0, 2, 1)

class series_decomp(nn.Module):
    """
    Series decomposition block: residual (seasonal) part plus moving-average trend.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = moving_avg(kernel_size)

    def forward(self, x):
        """Return ``(x - trend, trend)`` where ``trend`` is the moving average of ``x``."""
        trend = self.moving_avg(x)
        return x - trend, trend


class series_decomp_multi(nn.Module):
    """
    Series decomposition block with several moving-average kernels blended by learned weights.

    Each time point's trend is a softmax-weighted sum of the moving averages; the weights
    come from a Linear(1, n_kernels) layer applied to the point's own value.
    """
    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = nn.ModuleList([moving_avg(size) for size in kernel_size])
        self.layer = torch.nn.Linear(1, len(kernel_size))
        print("Using Multi Series Decomposition block")

    def forward(self, x):
        """Return ``(x - trend, trend)`` with ``trend`` the weighted blend of all moving averages."""
        # Stack the per-kernel trends on a new trailing axis: (..., n_kernels).
        candidate_trends = torch.stack([avg(x) for avg in self.moving_avg], dim=-1)
        weights = torch.softmax(self.layer(x.unsqueeze(-1)), dim=-1)
        trend = (candidate_trends * weights).sum(dim=-1)
        return x - trend, trend
    
    
class seasonal_trend_loss(nn.Module):
    """
    Loss that decomposes prediction and target into seasonal and trend parts.

    The seasonal parts are compared with MSE and the trends with ``mean(log1p(eps + |diff|))``.
    The two losses are blended by ``dynamic_weighted_loss``, which gives the larger
    loss more weight when ``beta > 0``.

    :param moving_avg: Kernel size of the fixed moving-average decomposition. Ignored when
        ``strict == 'multi'``.
    :param beta: Temperature of the softmax that weights the seasonal vs. trend loss.
    :param strict: ``'multi'`` selects the learnable multi-kernel decomposition; any other
        value selects the fixed single-kernel one.
    """

    def __init__(self, moving_avg=25, beta=1.0, strict='exp'):
        # The original called super() with an unrelated/undefined class name, which made
        # construction fail; the zero-argument form targets this class.
        super().__init__()
        use_cuda = torch.cuda.is_available()
        if strict == 'multi':
            print("using the learnable decomp")
            moving_avg = [7, 13, 15, 25, 49]
            self.decomp = series_decomp_multi(moving_avg)
            # Previously pinned to cuda:0 unconditionally, which crashed on CPU-only hosts.
            if use_cuda:
                self.decomp = self.decomp.to('cuda:0')
        else:
            print("Using the fixed decomp")
            self.decomp = series_decomp(moving_avg)
        # strict / rfft are stored but not read by this class's methods.
        self.strict = strict
        self.rfft = True
        self.mse = nn.MSELoss()
        self.eps = 1e-8
        # Not referenced by this class's methods. Kept as a plain (unregistered) tensor
        # attribute, as before; moved to CUDA only when available.
        freq_weight = torch.ones(1, 7, 49, requires_grad=True)
        self.freq_weight = freq_weight.cuda() if use_cuda else freq_weight
        self.beta = beta
        print('beta:', beta)

        print("Using dynamic seasonal and trend loss #")

    def dynamic_weighted_loss(self, seasonal_loss, trend_loss):
        """
        Blend two scalar losses with softmax weights computed from their detached values.

        ``lambda_s = softmax(beta * [seasonal, trend])[0]`` and ``lambda_tau`` is the other
        entry. The weights are detached, so gradients flow only through the losses themselves.
        """
        seasonal_loss_detached = seasonal_loss.detach()
        trend_loss_detached = trend_loss.detach()

        # Subtract the max before exponentiating for numerical stability.
        m = torch.max(torch.stack([seasonal_loss_detached, trend_loss_detached]))

        exp_seasonal = torch.exp(self.beta * (seasonal_loss_detached - m))
        exp_trend = torch.exp(self.beta * (trend_loss_detached - m))

        lambda_s = exp_seasonal / (exp_seasonal + exp_trend)
        lambda_tau = exp_trend / (exp_seasonal + exp_trend)

        total_loss = lambda_s * seasonal_loss + lambda_tau * trend_loss

        return total_loss

    def forward(self, pred, true):
        """
        Compute the blended seasonal/trend loss.

        ``pred`` and ``true`` are passed through ``self.decomp``; its moving average expects
        a (B, L, C) layout (TODO confirm against callers).
        """
        pred_seasonal, pred_trend = self.decomp(pred)
        true_seasonal, true_trend = self.decomp(true)

        seasonal_loss = self.mse(pred_seasonal, true_seasonal)
        trend_diff = (pred_trend - true_trend)
        # log1p grows sub-linearly, damping the influence of large trend errors.
        trend_loss = torch.mean(torch.log1p(self.eps + trend_diff.abs()))

        total_loss = self.dynamic_weighted_loss(seasonal_loss, trend_loss)

        return total_loss




def amp_loss(outputs, targets):
    """
    Mean absolute difference between the normalized target autocorrelation and the
    normalized target/output cross-correlation, over all lags.

    Inputs must be 3-D with time on the last axis (the caller permutes B, T, 1 -> B, 1, T).
    The result is NaN if any target or output series has zero norm, because the
    correlations are divided by the norms.

    :param outputs: Predicted series, shape (B, C, T).
    :param targets: Ground-truth series, same shape as ``outputs``.
    :return: Scalar loss tensor.
    """
    _, _, T = outputs.shape
    # Zero-pad to a power of two of at least 2T - 1 so the FFT-based (circular)
    # correlation equals the linear correlation without wrap-around.
    fft_size = 1 << (2 * T - 1).bit_length()
    out_fourier = torch.fft.fft(outputs, fft_size, dim=-1)
    tgt_fourier = torch.fft.fft(targets, fft_size, dim=-1)

    out_norm = torch.norm(outputs, dim=-1, keepdim=True)
    tgt_norm = torch.norm(targets, dim=-1, keepdim=True)

    # The T - 1 wrapped-around lags sit at the tail of the circular result. Index them
    # from fft_size: for T == 1 the old start ``-(T - 1)`` is ``-0 == 0``, which selects
    # the whole buffer instead of nothing.
    tail_start = fft_size - (T - 1)

    # calculate normalized auto correlation
    auto_corr = torch.fft.ifft(tgt_fourier * tgt_fourier.conj(), dim=-1).real
    auto_corr = torch.cat([auto_corr[..., tail_start:], auto_corr[..., :T]], dim=-1)
    nac_tgt = auto_corr / (tgt_norm * tgt_norm)

    # calculate cross correlation
    cross_corr = torch.fft.ifft(tgt_fourier * out_fourier.conj(), dim=-1).real
    cross_corr = torch.cat([cross_corr[..., tail_start:], cross_corr[..., :T]], dim=-1)
    nac_out = cross_corr / (tgt_norm * out_norm)

    return torch.mean(torch.abs(nac_tgt - nac_out))


def ashift_loss(outputs, targets):
    """
    Penalize a non-uniform distribution of the error along the time axis.

    The softmax of ``outputs - targets`` over the last axis is compared with the uniform
    value ``1 / T``. The mean absolute gap is scaled by ``T``, so the loss is 0 when the
    error is constant over time.

    :param outputs: Predicted series, 3-D with time on the last axis.
    :param targets: Ground-truth series, same shape.
    :return: Scalar loss tensor.
    """
    _, _, T = outputs.shape
    error_weights = torch.softmax(outputs - targets, dim=-1)
    deviation = torch.abs(1 / T - error_weights)
    return T * torch.mean(deviation)


def phase_loss(outputs, targets):
    """
    Frequency-domain loss focused on the target's dominant frequency bins.

    FFT is taken over the last axis (inputs must be 3-D). A bin counts as dominant if any
    of the following holds:

    - the target's power there exceeds ``T``;
    - it is among the ``int(sqrt(T))`` strongest target bins;
    - it is the DC bin (index 0).

    On dominant bins the complex absolute error ``|out - tgt|`` is penalized. On all other
    bins the output magnitude ``|out|`` is penalized, pushing spurious frequencies toward 0.
    Each term is divided by the fraction of bins it covers (mean over the whole batch).
    NaNs produced when one of the two sets is empty (0/0) are mapped to 0.

    :param outputs: Predicted series, shape (B, C, T).
    :param targets: Ground-truth series, same shape.
    :return: Scalar loss tensor, scaled by ``1 / sqrt(T)``.
    """
    _, _, T = outputs.shape
    out_fourier = torch.fft.fft(outputs, dim=-1)
    tgt_fourier = torch.fft.fft(targets, dim=-1)
    tgt_power = tgt_fourier.real ** 2 + tgt_fourier.imag ** 2

    dominant = (tgt_power > T).float()
    topk_indices = tgt_power.topk(k=int(T ** 0.5), dim=-1).indices
    dominant.scatter_(-1, topk_indices, 1.)
    dominant[..., 0] = 1.
    dominant = dominant > 0

    # Penalize any output energy outside the dominant bins.
    not_mask = (~dominant).float()
    not_mask /= torch.mean(not_mask)
    zero_error = torch.abs(out_fourier) * not_mask
    zero_error = torch.where(torch.isnan(zero_error), torch.zeros_like(zero_error), zero_error)

    # Match the target's complex spectrum on the dominant bins.
    mask = dominant.float()
    mask /= torch.mean(mask)
    ae = torch.abs(out_fourier - tgt_fourier) * mask
    ae = torch.where(torch.isnan(ae), torch.zeros_like(ae), ae)

    return (torch.mean(zero_error) + torch.mean(ae)) / (T ** .5)


def tildeq_loss(outputs, targets, alpha=.5, gamma=.0, beta=.5):
    """
    Weighted combination ``alpha * ashift + (1 - alpha) * phase + gamma * amp``.

    Inputs are permuted with ``(0, 2, 1)`` before the sub-losses, which treat the last axis
    as time. Inputs are therefore expected as (B, T, C) (TODO confirm against callers).

    :param outputs: Predicted series.
    :param targets: Ground-truth series, same shape.
    :param alpha: Weight of the shift term; ``1 - alpha`` weights the phase term.
    :param gamma: Weight of the amplitude/correlation term. When 0 (the default) that term
        is skipped entirely.
    :param beta: Accepted for signature compatibility; not used.
    :return: Scalar loss tensor.
    :raises AssertionError: if ``outputs`` contains NaN/Inf or the loss is NaN.
    """
    outputs = outputs.permute(0, 2, 1)
    targets = targets.permute(0, 2, 1)
    assert not torch.isnan(outputs).any(), "Nan value detected!"
    assert not torch.isinf(outputs).any(), "Inf value detected!"
    l_ashift = ashift_loss(outputs, targets)
    l_phase = phase_loss(outputs, targets)
    loss = alpha * l_ashift + (1 - alpha) * l_phase
    # amp_loss divides by series norms and returns NaN for an all-zero window. With
    # gamma == 0, ``0 * NaN`` would still poison the total, so only evaluate it when it
    # actually contributes (this also saves two padded FFTs per call).
    if gamma != 0:
        loss = loss + gamma * amp_loss(outputs, targets)

    assert not torch.isnan(loss), "Loss Nan!"
    return loss
