import numpy as np
import torch
import torch.nn as nn
import torch.fft as fft
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import math


class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal position table exposed as a non-trainable buffer.

    Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
    where ``w_k = exp(-2k * ln(10000) / d_model)``. The table is registered
    as the buffer ``pe`` with shape ``(1, max_len, d_model)``.

    Args:
        d_model: Width of each position vector. Odd values are supported;
            the last sine column then has no cosine partner.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()
        angles = position * div_term

        # A freshly created tensor never requires grad and the table is kept
        # as a buffer, so no explicit requires_grad handling is needed.
        pe = torch.zeros(max_len, d_model).float()
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the angles to the number of cosine slots actually available.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return the first ``x.size(1)`` rows of the table.

        Only the size of axis 1 of ``x`` is read; its values are ignored.
        The result has shape ``(1, min(x.size(1), max_len), d_model)``.
        """
        return self.pe[:, :x.size(1)]


class TokenEmbedding(nn.Module):
    """Project input channels to ``d_model`` with a circular 1-D convolution.

    The convolution uses kernel size 3, circular padding and no bias, and its
    weight is initialised with Kaiming-normal (``fan_in``, ``leaky_relu``).

    Args:
        c_in: Number of input channels (size of the last axis of the input).
        d_model: Number of output channels.
    """

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # Compare versions numerically: as plain strings '1.10.0' sorts
        # before '1.5.0', which would select the legacy padding by mistake.
        # NOTE(review): padding=2 is presumably a workaround for circular
        # padding semantics in torch releases older than 1.5.0 -- confirm.
        modern = self._version_tuple(torch.__version__) >= (1, 5, 0)
        padding = 1 if modern else 2
        self.token_conv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                    kernel_size=3, padding=padding,
                                    padding_mode='circular', bias=False)
        nn.init.kaiming_normal_(
            self.token_conv.weight, mode='fan_in', nonlinearity='leaky_relu')

    @staticmethod
    def _version_tuple(version):
        """Return the leading ``(major, minor, patch)`` integers of a version.

        Local build tags (``+cu118``) and pre-release suffixes (``a0``,
        ``rc1``) are ignored; missing components count as zero.
        """
        release = str(version).split('+')[0]
        numbers = []
        for piece in release.split('.')[:3]:
            digits = ''
            for ch in piece:
                if not ch.isdigit():
                    break
                digits += ch
            numbers.append(int(digits) if digits else 0)
        numbers += [0] * (3 - len(numbers))
        return tuple(numbers)

    def forward(self, x):
        """Embed ``x`` along its last axis.

        ``x`` is permuted with ``(0, 2, 1)`` before the convolution, so its
        last axis must have size ``c_in``; the result is transposed back so
        that ``d_model`` is the last axis.
        """
        x = self.token_conv(x.permute(0, 2, 1)).transpose(1, 2)
        return x


class FixedEmbedding(nn.Module):
    """Embedding lookup whose table is a frozen sinusoidal encoding.

    Row ``i`` holds ``sin(i * w_k)`` in even columns and ``cos(i * w_k)`` in
    odd columns, with ``w_k = exp(-2k * ln(10000) / d_model)``. The table is
    stored as ``emb.weight`` with ``requires_grad=False``.

    Args:
        c_in: Number of rows (distinct indices) in the table.
        d_model: Width of each embedding vector. Odd values are supported;
            the last sine column then has no cosine partner.
    """

    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()
        angles = position * div_term

        w = torch.zeros(c_in, d_model).float()
        w[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the angles to the number of cosine slots actually available.
        w[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        self.emb = nn.Embedding(c_in, d_model)
        # Replace the randomly initialised weight with the frozen table.
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        """Look up rows for the integer indices in ``x``, detached from autograd."""
        return self.emb(x).detach()


class TemporalEmbedding(nn.Module):
    """Sum of per-field lookups over integer calendar features.

    One embedding table is created per field: hour (24 rows), weekday (7),
    day (32) and month (13). A minute table (4 rows) is added only when
    ``freq == 't'``.

    Args:
        d_model: Width of every embedding vector.
        embed_type: ``'fixed'`` selects ``FixedEmbedding``; any other value
            selects a trainable ``nn.Embedding``.
        freq: Only the value ``'t'`` is inspected, to enable the minute table.
    """

    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super().__init__()

        if embed_type == 'fixed':
            table_cls = FixedEmbedding
        else:
            table_cls = nn.Embedding

        # (attribute name, number of rows); registration follows this order.
        fields = [
            ('hour_embed', 24),
            ('weekday_embed', 7),
            ('day_embed', 32),
            ('month_embed', 13),
        ]
        if freq == 't':
            fields.insert(0, ('minute_embed', 4))

        for attr, rows in fields:
            setattr(self, attr, table_cls(rows, d_model))

    def forward(self, x):
        """Embed the calendar fields stored along the last axis of ``x``.

        ``x`` is cast to ``long`` and indexed as ``x[:, :, k]`` with
        ``k = 0..3`` for month, day, weekday and hour, plus ``k = 4`` for
        minute when the minute table exists.
        """
        idx = x.long()

        month_part = self.month_embed(idx[:, :, 0])
        day_part = self.day_embed(idx[:, :, 1])
        weekday_part = self.weekday_embed(idx[:, :, 2])
        hour_part = self.hour_embed(idx[:, :, 3])

        # Without a minute table the minute contribution is a plain zero.
        minute_part = 0.
        if hasattr(self, 'minute_embed'):
            minute_part = self.minute_embed(idx[:, :, 4])

        return hour_part + weekday_part + day_part + month_part + minute_part


class TimeFeatureEmbedding(nn.Module):
    """Bias-free linear projection of time features to ``d_model``.

    The input width is chosen from ``freq`` through a fixed table; an
    unknown ``freq`` raises ``KeyError``.

    Args:
        d_model: Output width of the projection.
        embed_type: Accepted but not used by this class.
        freq: Frequency code, one of ``h, t, s, m, a, w, d, b``.
    """

    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super().__init__()
        # Number of time features expected for each frequency code.
        features_per_freq = {
            'h': 4,
            't': 5,
            's': 6,
            'm': 1,
            'a': 1,
            'w': 2,
            'd': 3,
            'b': 3,
        }
        in_width = features_per_freq[freq]
        self.embed = nn.Linear(in_width, d_model, bias=False)

    def forward(self, x):
        """Apply the linear projection to the last axis of ``x``."""
        projected = self.embed(x)
        return projected


class DataEmbedding(nn.Module):
    """Value + temporal + positional embedding followed by dropout.

    Args:
        c_in: Passed to ``TokenEmbedding`` as its input channel count.
        d_model: Embedding width shared by all sub-embeddings.
        embed_type: ``'timeF'`` selects ``TimeFeatureEmbedding`` for the
            temporal part; any other value selects ``TemporalEmbedding``.
        freq: Forwarded to the temporal embedding.
        dropout: Dropout probability applied to the summed embedding.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super().__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)

        if embed_type == 'timeF':
            temporal = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            temporal = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        self.temporal_embedding = temporal

        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed ``x``; the temporal term is skipped when ``x_mark`` is None."""
        embedded = self.value_embedding(x)
        if x_mark is not None:
            embedded = embedded + self.temporal_embedding(x_mark)
        embedded = embedded + self.position_embedding(x)
        return self.dropout(embedded)


class InvertedDataEmbedding(nn.Module):
    """Linear embedding applied after swapping the last two axes of the input.

    Args:
        c_in: Input width of the linear layer; must equal ``x.size(1)``
            because the layer acts on that axis after the permutation.
        d_model: Output width of the linear layer.
        embed_type: Accepted but not used by this class.
        freq: Accepted but not used by this class.
        dropout: Dropout probability applied to the result.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super().__init__()
        self.value_embedding = nn.Linear(c_in, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed ``x`` and, when given, ``x_mark`` as extra rows.

        Both inputs are permuted with ``(0, 2, 1)``. When ``x_mark`` is not
        None its permuted form is concatenated to that of ``x`` along axis 1
        before the linear layer is applied.
        """
        tokens = x.permute(0, 2, 1)
        if x_mark is not None:
            mark_tokens = x_mark.permute(0, 2, 1)
            tokens = torch.cat([tokens, mark_tokens], 1)
        return self.dropout(self.value_embedding(tokens))


class DataEmbeddingWithoutPositional(nn.Module):
    """Value + temporal embedding followed by dropout.

    A ``PositionalEmbedding`` is still constructed and registered as
    ``position_embedding``, but ``forward`` does not use it.

    Args:
        c_in: Passed to ``TokenEmbedding`` as its input channel count.
        d_model: Embedding width shared by all sub-embeddings.
        embed_type: ``'timeF'`` selects ``TimeFeatureEmbedding`` for the
            temporal part; any other value selects ``TemporalEmbedding``.
        freq: Forwarded to the temporal embedding.
        dropout: Dropout probability applied to the summed embedding.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super().__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)

        if embed_type == 'timeF':
            temporal = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            temporal = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        self.temporal_embedding = temporal

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Return dropout of value embedding of ``x`` plus temporal embedding of ``x_mark``."""
        value_part = self.value_embedding(x)
        temporal_part = self.temporal_embedding(x_mark)
        return self.dropout(value_part + temporal_part)


class DataEmbeddingWithoutPositionalTemporal(nn.Module):
    """Value embedding followed by dropout.

    Positional and temporal embeddings are still constructed and registered
    as ``position_embedding`` and ``temporal_embedding``, but ``forward``
    uses neither of them.

    Args:
        c_in: Passed to ``TokenEmbedding`` as its input channel count.
        d_model: Embedding width shared by all sub-embeddings.
        embed_type: ``'timeF'`` selects ``TimeFeatureEmbedding`` for the
            temporal part; any other value selects ``TemporalEmbedding``.
        freq: Forwarded to the temporal embedding.
        dropout: Dropout probability applied to the value embedding.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super().__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)

        if embed_type == 'timeF':
            temporal = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            temporal = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        self.temporal_embedding = temporal

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Return dropout of the value embedding of ``x``; ``x_mark`` is ignored."""
        return self.dropout(self.value_embedding(x))


class DataEmbeddingWithoutTemporal(nn.Module):
    """Value + positional embedding followed by dropout.

    A temporal embedding is still constructed and registered as
    ``temporal_embedding``, but ``forward`` does not use it.

    Args:
        c_in: Passed to ``TokenEmbedding`` as its input channel count.
        d_model: Embedding width shared by all sub-embeddings.
        embed_type: ``'timeF'`` selects ``TimeFeatureEmbedding`` for the
            temporal part; any other value selects ``TemporalEmbedding``.
        freq: Forwarded to the temporal embedding.
        dropout: Dropout probability applied to the summed embedding.
    """

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super().__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)

        if embed_type == 'timeF':
            temporal = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            temporal = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        self.temporal_embedding = temporal

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Return dropout of value plus positional embedding; ``x_mark`` is ignored."""
        value_part = self.value_embedding(x)
        position_part = self.position_embedding(x)
        return self.dropout(value_part + position_part)


class ComplexFrequencyEmbedding(nn.Module):
    """Embed the zero-padded FFT of the input with two real linear layers.

    The real and imaginary parts of the spectrum are projected separately
    and recombined into a complex tensor.

    Args:
        seq_len: Expected size of axis 1 of the input; both linear layers
            take ``2 * seq_len`` input features.
        d_model: Output width of both linear layers.
    """

    def __init__(self, seq_len, d_model):
        super().__init__()
        spectrum_width = 2 * seq_len
        self.linear_real = nn.Linear(spectrum_width, d_model)
        self.linear_imag = nn.Linear(spectrum_width, d_model)

    def forward(self, x, x_mark):
        """Return the complex embedding of ``x``; ``x_mark`` is ignored.

        ``x`` is permuted with ``(0, 2, 1)`` and transformed with an FFT of
        length twice the size of its (new) last axis.
        """
        series = x.permute(0, 2, 1)
        length = series.size(-1)
        spectrum = torch.fft.fft(series, n=2 * length)

        real_proj = self.linear_real(spectrum.real)
        imag_proj = self.linear_imag(spectrum.imag)
        return torch.complex(real_proj, imag_proj)


class InterpolatedFrequencyEmbedding(nn.Module):
    """Resample the zero-padded FFT of the input to ``d_model`` bins.

    The module has no trainable parameters: the spectrum is linearly
    interpolated along the frequency axis, separately for its real and
    imaginary parts.

    Args:
        seq_len: Stored on the instance; not used by the computation.
        d_model: Number of frequency bins in the output.
    """

    def __init__(self, seq_len, d_model):
        super(InterpolatedFrequencyEmbedding, self).__init__()
        self.seq_len = seq_len
        self.d_model = d_model

    def forward(self, x, x_mark):
        """Return the resampled complex spectrum of ``x``; ``x_mark`` is ignored.

        ``x`` is permuted with ``(0, 2, 1)`` and transformed with an FFT of
        length twice the size of its (new) last axis. The result has
        ``d_model`` entries on its last axis.
        """
        x = x.permute(0, 2, 1)
        length = x.size(-1)
        x_fft = torch.fft.fft(x, n=2 * length)
        return self.resample_fft(x_fft, self.d_model)

    def resample_fft(self, x_fft, new_length):
        """Linearly interpolate a 3-D complex tensor to ``new_length`` bins.

        ``F.interpolate`` with ``mode='linear'`` accepts only 3-D input and
        resamples its last axis, so the spectrum is passed as-is; inserting
        an extra axis would make the call fail.
        """
        real_interpolated = F.interpolate(
            x_fft.real, size=new_length, mode='linear', align_corners=False)
        imag_interpolated = F.interpolate(
            x_fft.imag, size=new_length, mode='linear', align_corners=False)
        return torch.complex(real_interpolated, imag_interpolated)


class FourierInterpolatedFrequencyEmbedding(nn.Module):
    """Pad or truncate the real FFT to ``d_model`` bins, then scale and shift.

    A real-valued scale (initialised to ones) and a complex bias
    (initialised to zeros), both of shape ``(c_in, d_model)``, are applied
    element-wise to the resized spectrum.

    Args:
        seq_len: Stored on the instance; not used by the computation.
        d_model: Number of frequency bins in the output.
        c_in: First dimension of the scale and bias parameters.
    """

    def __init__(self, seq_len, d_model, c_in):
        super().__init__()
        self.seq_len = seq_len
        self.d_model = d_model

        param_shape = (c_in, d_model)
        self.scalars = nn.Parameter(
            torch.ones(*param_shape), requires_grad=True)
        self.bias = nn.Parameter(
            torch.zeros(*param_shape, dtype=torch.cfloat), requires_grad=True)

    def forward(self, x, x_mark):
        """Return the scaled and shifted spectrum of ``x``; ``x_mark`` is ignored.

        ``x`` is permuted with ``(0, 2, 1)`` and transformed with a real FFT
        of length twice the size of its (new) last axis.
        """
        series = x.permute(0, 2, 1)
        length = series.size(-1)
        spectrum = torch.fft.rfft(series, n=2 * length)

        resized = self.fourier_interpolate(spectrum, self.d_model)
        return resized * self.scalars + self.bias

    def fourier_interpolate(self, x_fft, new_length):
        """Resize the last axis of a 3-D spectrum to ``new_length`` bins.

        When ``new_length`` does not exceed the current number of bins the
        leading bins are returned as a slice. Otherwise the spectrum is
        copied into a ``cfloat`` tensor whose remaining bins are 0.0001.
        """
        batch, channels, bins = x_fft.shape
        if new_length <= bins:
            return x_fft[:, :, :new_length]

        # Bins past the original spectrum keep the small constant fill.
        padded = torch.zeros(batch, channels, new_length,
                             dtype=torch.cfloat, device=x_fft.device)
        padded = padded + 0.0001
        padded[:, :, :bins] = x_fft
        return padded
