import math
import torch
from torch import nn


class PeriodPadding(nn.Module):
    """Zero-pad the last dimension on the right up to a multiple of ``period``.

    The pad width is derived once, at construction time, from ``input_length``.
    When ``input_length`` is already a multiple of ``period`` the module is an
    ``nn.Identity`` and returns its input untouched.

    Args:
        period: block length the padded size must be divisible by.
        input_length: size of the last dimension of the tensors that will be
            passed to ``forward``.
    """

    def __init__(self, period, input_length):
        super().__init__()
        self.period = period
        leftover = input_length % self.period
        if leftover == 0:
            # Already aligned: nothing to append.
            self.pad = nn.Identity()
        else:
            # Append just enough zeros on the right to complete the last block.
            self.pad = nn.ConstantPad1d((0, self.period - leftover), 0)

    def forward(self, x):
        """Apply the padding chosen in ``__init__`` to the last dimension of ``x``."""
        return self.pad(x)


class SkipConnectionLayer(nn.Module):
    """Linear map applied across positions that share the same phase of a period.

    The last dimension (length ``input_size``) is zero-padded to a multiple of
    ``skip_size`` and reshaped to ``(input_splits, skip_size)``.  For every
    phase (offset inside a period) the ``input_splits`` values taken with
    stride ``skip_size`` are mapped to ``hidden_splits`` values.  The result is
    laid back out in the original interleaved order and cut to ``hidden_size``.

    Args:
        input_size: length of the last dimension of the input.
        hidden_size: length of the last dimension of the output.
        skip_size: period (stride) used to group positions.
        share_kernel: when True a single ``nn.Linear`` is used for all phases;
            otherwise each of the ``skip_size`` phases has its own layer.
    """

    def __init__(self, input_size, hidden_size, skip_size, share_kernel=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.skip_size = skip_size
        self.share_kernel = share_kernel

        self.pad_inp = PeriodPadding(self.skip_size, self.input_size)
        # Number of whole periods after padding, on the input and output side.
        self.input_splits = math.ceil(self.input_size / self.skip_size)
        self.hidden_splits = math.ceil(self.hidden_size / self.skip_size)

        if not self.share_kernel:
            per_phase_layers = []
            for _ in range(self.skip_size):
                per_phase_layers.append(nn.Linear(self.input_splits, self.hidden_splits))
            self.layer = nn.ModuleList(per_phase_layers)
        else:
            self.layer = nn.Linear(self.input_splits, self.hidden_splits)

        self.flatten = nn.Flatten(start_dim=-2)

    def forward(self, x):
        """Map a 3-D tensor ``(d0, d1, input_size)`` to ``(d0, d1, hidden_size)``."""
        padded = self.pad_inp(x)
        dim0, dim1 = padded.size(0), padded.size(1)
        # (d0, d1, input_splits, skip_size) -> (d0, d1, skip_size, input_splits)
        phases = padded.view(dim0, dim1, self.input_splits, self.skip_size)
        phases = phases.permute(0, 1, 3, 2)

        if self.share_kernel:
            mapped = self.layer(phases)
        else:
            # One layer per phase; keep the phase axis (size 1) so the pieces
            # can be concatenated back along it.
            pieces = []
            for idx, phase_layer in enumerate(self.layer):
                pieces.append(phase_layer(phases[:, :, idx:idx + 1, :]))
            mapped = torch.cat(pieces, dim=-2)

        # (d0, d1, skip_size, hidden_splits) -> (d0, d1, hidden_splits * skip_size)
        merged = self.flatten(mapped.permute(0, 1, 3, 2))
        # Drop the surplus positions introduced by rounding up to whole periods.
        return merged[..., :self.hidden_size]


class SplitConnectionLayer(nn.Module):
    """Linear map applied independently to consecutive chunks of the input.

    The last dimension (length ``input_size``) is zero-padded to a multiple of
    ``split_size`` and cut into ``input_splits`` contiguous chunks.  Each chunk
    of ``split_size`` values is mapped to ``hidden_split_size`` values; the
    chunk outputs are concatenated and cut to ``hidden_size``.

    Args:
        input_size: length of the last dimension of the input.
        hidden_size: length of the last dimension of the output.
        split_size: length of each contiguous chunk.
        share_kernel: when True a single ``nn.Linear`` is used for all chunks;
            otherwise each of the ``input_splits`` chunks has its own layer.
    """

    def __init__(self, input_size, hidden_size, split_size, share_kernel=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.split_size = split_size
        self.share_kernel = share_kernel

        self.pad_inp = PeriodPadding(self.split_size, self.input_size)
        # Number of chunks after padding, and the output width of each chunk
        # (rounded up so that all chunks together cover hidden_size).
        self.input_splits = math.ceil(self.input_size / self.split_size)
        self.hidden_split_size = math.ceil(self.hidden_size / self.input_splits)

        if not self.share_kernel:
            per_chunk_layers = []
            for _ in range(self.input_splits):
                per_chunk_layers.append(nn.Linear(self.split_size, self.hidden_split_size))
            self.layer = nn.ModuleList(per_chunk_layers)
        else:
            self.layer = nn.Linear(self.split_size, self.hidden_split_size)

        self.flatten = nn.Flatten(start_dim=-2)

    def forward(self, x):
        """Map a 3-D tensor ``(d0, d1, input_size)`` to ``(d0, d1, hidden_size)``."""
        padded = self.pad_inp(x)
        dim0, dim1 = padded.size(0), padded.size(1)
        # (d0, d1, input_splits, split_size): one row per contiguous chunk.
        chunks = padded.view(dim0, dim1, self.input_splits, self.split_size)

        if self.share_kernel:
            mapped = self.layer(chunks)
        else:
            # One layer per chunk; keep the chunk axis (size 1) so the pieces
            # can be concatenated back along it.
            pieces = []
            for idx, chunk_layer in enumerate(self.layer):
                pieces.append(chunk_layer(chunks[:, :, idx:idx + 1, :]))
            mapped = torch.cat(pieces, dim=-2)

        # (d0, d1, input_splits, hidden_split_size) -> (d0, d1, input_splits * hidden_split_size)
        merged = self.flatten(mapped)
        # Drop the surplus positions introduced by rounding hidden_split_size up.
        return merged[..., :self.hidden_size]


class SkipMLP(nn.Module):
    """Two-layer MLP applied across positions that share the same phase of a period.

    The input's last dimension (length ``input_size``) is zero-padded to a
    multiple of ``skip_size`` and viewed as ``(num_skips, skip_size)``.  For
    each phase, the ``num_skips`` values taken with stride ``skip_size`` go
    through ``Linear -> SELU -> Dropout -> Linear`` (``num_skips -> hidden_size
    -> num_skips``).  The output is restored to the original layout and cut
    back to ``input_size``.

    Args:
        input_size: length of the last dimension of the input and output.
        hidden_size: width of the hidden layer of each encoder.
        skip_size: period (stride) used to group positions.
        dropout: dropout probability inside each encoder.
        kernel_share: when True one encoder (``self.encoder``) serves all
            phases; otherwise phase ``i`` uses its own ``self.encoder_{i}``.
    """

    def __init__(self, input_size, hidden_size, skip_size, dropout, kernel_share=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.skip_size = skip_size
        self.dropout = dropout
        self.pad = PeriodPadding(self.skip_size, input_size)
        # Number of whole periods after padding.
        self.num_skips = math.ceil(input_size / self.skip_size)
        self.kernel_share = kernel_share

        if not kernel_share:
            # One independent encoder per phase, registered as encoder_0, encoder_1, ...
            for i in range(self.skip_size):
                setattr(self, f'encoder_{i}', self._make_encoder(dropout))
        else:
            self.encoder = self._make_encoder(dropout)

        self.flatten = nn.Flatten(start_dim=-2)

    def _make_encoder(self, dropout):
        """Build one ``num_skips -> hidden_size -> num_skips`` encoder."""
        return nn.Sequential(
            nn.Linear(self.num_skips, self.hidden_size),
            nn.SELU(),
            nn.Dropout(dropout),
            nn.Linear(self.hidden_size, self.num_skips),
        )

    def forward(self, x):
        """
        input: x: (batch_size, N, input_size)
        output: (batch_size, N, input_size)
        """
        # (batch_size, N, input_size) -> (batch_size, N, num_skips * skip_size)
        padded = self.pad(x)
        batch_size, n_series = padded.size(0), padded.size(1)
        # -> (batch_size, N, num_skips, skip_size)
        by_phase = padded.view(batch_size, n_series, self.num_skips, self.skip_size)
        # -> (batch_size, N, skip_size, num_skips): each row holds one phase
        by_phase = by_phase.permute(0, 1, 3, 2)

        if self.kernel_share:
            encoded = self.encoder(by_phase)
        else:
            pieces = []
            for i in range(self.skip_size):
                phase_encoder = getattr(self, f'encoder_{i}')
                pieces.append(phase_encoder(by_phase[:, :, i:i + 1, :]))
            encoded = torch.cat(pieces, dim=-2)

        # Back to (batch_size, N, num_skips, skip_size), then merge the last two
        # axes and remove the positions that were only padding.
        restored = self.flatten(encoded.permute(0, 1, 3, 2))
        return restored[..., :self.input_size]


class SplitMLP(nn.Module):
    """Two-layer MLP applied independently to consecutive chunks of the input.

    The input's last dimension (length ``input_size``) is zero-padded to a
    multiple of ``split_size`` and cut into ``num_splits`` contiguous chunks.
    Each chunk goes through ``Linear -> SELU -> Dropout -> Linear``
    (``split_size -> hidden_size -> split_size``).  The chunk outputs are
    concatenated and cut back to ``input_size``.

    Args:
        input_size: length of the last dimension of the input and output.
        hidden_size: width of the hidden layer of each encoder.
        split_size: length of each contiguous chunk.
        dropout: dropout probability inside each encoder.
        kernel_share: when True one encoder (``self.encoder``) serves all
            chunks; otherwise chunk ``i`` uses its own ``self.encoder_{i}``.
    """

    def __init__(self, input_size, hidden_size, split_size, dropout, kernel_share=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.split_size = split_size
        self.pad = PeriodPadding(self.split_size, input_size)
        # Number of chunks after padding.
        self.num_splits = math.ceil(input_size / self.split_size)
        self.kernel_share = kernel_share

        if not kernel_share:
            # One independent encoder per chunk, registered as encoder_0, encoder_1, ...
            for i in range(self.num_splits):
                setattr(self, f'encoder_{i}', self._make_encoder(dropout))
        else:
            self.encoder = self._make_encoder(dropout)
        self.flatten = nn.Flatten(start_dim=-2)

    def _make_encoder(self, dropout):
        """Build one ``split_size -> hidden_size -> split_size`` encoder."""
        return nn.Sequential(
            nn.Linear(self.split_size, self.hidden_size),
            nn.SELU(),
            nn.Dropout(dropout),
            nn.Linear(self.hidden_size, self.split_size),
        )

    def forward(self, x):
        """
        input: x: (batch_size, N, input_size)
        output: (batch_size, N, input_size)
        """
        # (batch_size, N, input_size) -> (batch_size, N, num_splits * split_size)
        padded = self.pad(x)
        batch_size, n_series = padded.size(0), padded.size(1)
        # -> (batch_size, N, num_splits, split_size): one row per chunk
        chunks = padded.view(batch_size, n_series, self.num_splits, self.split_size)

        if self.kernel_share:
            encoded = self.encoder(chunks)
        else:
            pieces = []
            for i in range(self.num_splits):
                chunk_encoder = getattr(self, f'encoder_{i}')
                pieces.append(chunk_encoder(chunks[:, :, i:i + 1, :]))
            encoded = torch.cat(pieces, dim=-2)

        # Merge the chunk axes and remove the positions that were only padding.
        merged = self.flatten(encoded)
        return merged[..., :self.input_size]


class MLPBlock(nn.Module):
    """Plain two-layer perceptron over the last dimension.

    Structure: ``Linear(input_size, hidden_size) -> ReLU -> Dropout ->
    Linear(hidden_size, input_size)``, so the output has the same last
    dimension as the input.

    Args:
        input_size: size of the last dimension of the input and output.
        hidden_size: width of the hidden layer.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, input_size, hidden_size, dropout):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.dropout = dropout

        # Created in the same order as they are applied so that parameter
        # initialisation consumes the random generator in layer order.
        expand = nn.Linear(self.input_size, self.hidden_size)
        activation = nn.ReLU()
        regularise = nn.Dropout(dropout)
        project = nn.Linear(self.hidden_size, self.input_size)
        self.encoder = nn.Sequential(expand, activation, regularise, project)

    def forward(self, x):
        """
        input: x: (batch_size, N, input_size)
        output: (batch_size, N, input_size)
        """
        return self.encoder(x)


class SingleLayerMLP(nn.Module):
    """One encoder block followed by a linear prediction head.

    The encoder keeps the last dimension at ``input_size``; the head then maps
    ``input_size -> output_size``.

    Args:
        input_size: size of the last dimension of the input.
        hidden_size: hidden width passed to the encoder.
        output_size: size of the last dimension of the output.
        dropout: dropout probability passed to the encoder.
        mlp_type: selects the encoder:
            ``'mlp'``     -> ``MLPBlock`` with hidden width ``hidden_size * s_size``;
            ``'skip'``    -> ``SkipMLP`` with a shared kernel;
            ``'skip-i'``  -> ``SkipMLP`` with ``kernel_share=False``;
            ``'split'``   -> ``SplitMLP`` with a shared kernel;
            ``'split-i'`` -> ``SplitMLP`` with ``kernel_share=False``;
            ``'ss'``      -> ``SSMLPBlock``.
        s_size: skip/split size for the period-aware encoders, or the hidden
            width multiplier for ``'mlp'``.  May be left as ``None`` only for
            ``'mlp'``, where it is treated as a multiplier of 1.

    Raises:
        ValueError: if ``mlp_type`` is not one of the values listed above.
        TypeError: if ``s_size`` is ``None`` for an encoder that needs it.
    """

    def __init__(self, input_size, hidden_size, output_size, dropout, mlp_type, s_size=None):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout = dropout
        self.mlp_type = mlp_type
        self.s_size = s_size

        # Encoder first, head second: keeps parameter-initialisation order.
        self.encoder = self._build_encoder()
        self.prediction = nn.Linear(self.input_size, self.output_size)

    def _build_encoder(self):
        """Instantiate the encoder selected by ``self.mlp_type``."""
        kind = self.mlp_type

        if kind == 'mlp':
            # Here s_size only scales the hidden width, so a missing value
            # simply means "no scaling" instead of failing on `int * None`.
            width_factor = 1 if self.s_size is None else self.s_size
            return MLPBlock(
                input_size=self.input_size,
                hidden_size=self.hidden_size * width_factor,
                dropout=self.dropout
            )

        if kind not in ('skip', 'skip-i', 'split', 'split-i', 'ss'):
            raise ValueError(f"mlp_type {self.mlp_type} is not supported")

        if self.s_size is None:
            # Without this check the failure surfaces as an opaque
            # `int % NoneType` TypeError deep inside the padding module.
            raise TypeError(
                f"s_size must be an integer for mlp_type '{kind}', got None"
            )

        if kind in ('skip', 'skip-i'):
            return SkipMLP(
                input_size=self.input_size,
                hidden_size=self.hidden_size,
                skip_size=self.s_size,
                dropout=self.dropout,
                kernel_share=(kind == 'skip')
            )
        if kind in ('split', 'split-i'):
            return SplitMLP(
                input_size=self.input_size,
                hidden_size=self.hidden_size,
                split_size=self.s_size,
                dropout=self.dropout,
                kernel_share=(kind == 'split')
            )
        # Only 'ss' is left.
        return SSMLPBlock(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            skip_size=self.s_size,
            dropout=self.dropout
        )

    def forward(self, x):
        """
        input: x: (batch_size, N, input_size)
        output: (batch_size, N, output_size)
        """
        encoded = self.encoder(x)
        return self.prediction(encoded)


class SSMLPBlock(nn.Module):
    """Chain of a ``SkipMLP`` and a ``SplitMLP`` sharing one period length.

    Both sub-modules are built with the same ``input_size``, ``hidden_size``
    and ``dropout``; ``skip_size`` is used as the skip size of the first and
    the split size of the second.  The stage that runs second is wrapped in a
    residual connection around the output of the stage that runs first.

    Args:
        input_size: size of the last dimension of the input and output.
        hidden_size: hidden width of both sub-modules.
        skip_size: period used as skip size and as split size.
        dropout: dropout probability of both sub-modules.
        skip_first: when True the order is skip -> split, otherwise
            split -> skip.
    """

    def __init__(self, input_size, hidden_size, skip_size, dropout, skip_first=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.skip_size = skip_size
        self.split_size = skip_size
        self.skip_first = skip_first

        shared_kwargs = dict(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            dropout=dropout,
        )
        self.skip_mlp = SkipMLP(skip_size=self.skip_size, **shared_kwargs)
        self.split_mlp = SplitMLP(split_size=self.split_size, **shared_kwargs)

    def forward(self, x):
        """
        input: x: (batch_size, N, input_size)
        output: (batch_size, N, input_size)
        """
        if self.skip_first:
            first_stage, second_stage = self.skip_mlp, self.split_mlp
        else:
            first_stage, second_stage = self.split_mlp, self.skip_mlp

        intermediate = first_stage(x)
        # Residual connection around the second stage only.
        return second_stage(intermediate) + intermediate


class SSNet(nn.Module):
    """Period-weighted skip mixing, stacked skip/split blocks, linear head.

    Forward pass:
      1. one ``SkipMLP`` branch per period processes the input;
      2. branch outputs are combined with softmax weights derived from
         ``period_strength``;
      3. one residual ``SSMLPBlock`` per period refines the result;
      4. two stacked linear layers map ``seq_len -> pred_len``.

    Args:
        seq_len: length of the input's last dimension.
        pred_len: length of the output's last dimension.
        n_vars: number of per-variable branches built (and indexed along
            dimension 1) when ``channel_independence`` is True.
        hidden_size: hidden width of the sub-modules.
        periods: list of period lengths (concatenated with ``+``, so it must
            be a list).
        period_strength: one strength value per period; normalised by its
            maximum and turned into mixing weights with a softmax.
        dropout: dropout probability of the sub-modules.
        channel_independence: when True every variable gets its own
            ``SkipMLP`` in each branch.
    """

    def __init__(self,
                 seq_len,
                 pred_len,
                 n_vars,
                 hidden_size,
                 periods,
                 period_strength,
                 dropout,
                 channel_independence=False):
        super().__init__()
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.n_vars = n_vars
        self.hidden_size = hidden_size
        self.periods = periods
        self.dropout = dropout
        self.channel_independence = channel_independence

        strength = torch.tensor(period_strength, dtype=torch.float32)
        # Scale by the maximum, then keep it as a buffer so it follows
        # `.to()` / `.cuda()` with the rest of the model.  Non-persistent, so
        # the state_dict layout is exactly what it was when this was a plain
        # tensor attribute.
        self.register_buffer('period_strength', strength / torch.max(strength), persistent=False)
        self.ss_periods = self.periods + self.periods[:-1][::-1]

        if not self.channel_independence:
            self.skip_layer = nn.ModuleList([
                SkipMLP(
                    input_size=self.seq_len,
                    hidden_size=self.hidden_size,
                    skip_size=self.ss_periods[i],
                    dropout=self.dropout
                ) for i in range(len(self.periods))
            ])

        else:
            # NOTE(review): every branch here uses periods[0] as its skip size,
            # whereas the shared-channel path above uses one period per branch.
            # Left unchanged because switching to the per-branch period would
            # change parameter shapes of existing checkpoints -- confirm intent.
            self.skip_layer = nn.ModuleList([
                nn.ModuleList([
                    SkipMLP(
                        input_size=self.seq_len,
                        hidden_size=self.hidden_size,
                        skip_size=periods[0],
                        dropout=self.dropout
                    ) for _ in range(self.n_vars)
                ]) for _ in range(len(self.periods))
            ])

        self.ss_mlp_layers = nn.ModuleList([
            SSMLPBlock(
                input_size=self.seq_len,
                hidden_size=self.hidden_size,
                skip_size=period,
                dropout=self.dropout
            ) for period in self.periods
        ])

        self.output_layer = nn.Sequential(
            nn.Linear(self.seq_len, self.periods[0] * self.hidden_size),
            nn.Linear(self.periods[0] * self.hidden_size, self.pred_len)
        )

    def _mix_periods(self, x):
        """Run every skip branch on ``x`` and return their softmax-weighted sum."""
        branch_outputs = []
        for branch in self.skip_layer:
            if not self.channel_independence:
                branch_outputs.append(branch(x))
                continue
            # Per-variable layers: fill a zero tensor slice by slice, so
            # channels beyond n_vars (if any) stay zero as before.
            per_variable = torch.zeros_like(x)
            for j, variable_layer in enumerate(branch):
                per_variable[:, j: j + 1, :] = variable_layer(x[:, j: j + 1, :])
            branch_outputs.append(per_variable)

        stacked = torch.stack(branch_outputs, dim=-1)
        # `.to` is a no-op once the buffer lives on the model's device; it is
        # kept for instances restored from pickles that predate the buffer.
        weights = torch.softmax(self.period_strength, dim=0).to(x.device)
        # The 1-D weights broadcast over the trailing (period) axis.
        return torch.sum(stacked * weights, -1)

    def forward(self, x):
        """
        input: x: (batch_size, N, seq_len)
        output: (batch_size, N, pred_len)
        """
        out_x = self._mix_periods(x)

        for ss_mlp_layer in self.ss_mlp_layers:
            out_x = ss_mlp_layer(out_x) + out_x

        return self.output_layer(out_x)


class SSNet_SS_Only(nn.Module):
    """Variant whose forward pass uses only the ``SSMLPBlock`` stack and the head.

    The constructor builds the same sub-modules as the full model
    (``skip_layer``, ``ss_mlp_layers``, ``output_layer``) and the same
    ``period_strength`` tensor, but ``forward`` only runs ``ss_mlp_layers``
    followed by ``output_layer``; ``skip_layer`` and ``period_strength`` are
    not read there.

    Args:
        seq_len: length of the input's last dimension.
        pred_len: length of the output's last dimension.
        n_vars: number of per-variable layers built when
            ``channel_independence`` is True.
        hidden_size: hidden width of the sub-modules.
        periods: list of period lengths (concatenated with ``+``, so it must
            be a list).
        period_strength: one strength value per period; stored normalised by
            its maximum.
        dropout: dropout probability of the sub-modules.
        channel_independence: when True ``skip_layer`` holds one ``SkipMLP``
            per variable for each period.
    """

    def __init__(self,
                 seq_len,
                 pred_len,
                 n_vars,
                 hidden_size,
                 periods,
                 period_strength,
                 dropout,
                 channel_independence=False):
        super().__init__()
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.n_vars = n_vars
        self.hidden_size = hidden_size
        self.periods = periods
        self.dropout = dropout
        self.channel_independence = channel_independence

        raw_strength = torch.tensor(period_strength, dtype=torch.float32)
        # Scale so that the largest strength becomes 1.
        self.period_strength = raw_strength / torch.max(raw_strength)
        self.ss_periods = self.periods + self.periods[:-1][::-1]

        # Built (and initialised) before the other sub-modules even though
        # forward() does not use it: this keeps parameter names and the order
        # of random initialisation unchanged.
        skip_branches = []
        for i in range(len(self.periods)):
            if not self.channel_independence:
                skip_branches.append(SkipMLP(
                    input_size=self.seq_len,
                    hidden_size=self.hidden_size,
                    skip_size=self.ss_periods[i],
                    dropout=self.dropout
                ))
            else:
                per_variable = []
                for _ in range(self.n_vars):
                    per_variable.append(SkipMLP(
                        input_size=self.seq_len,
                        hidden_size=self.hidden_size,
                        skip_size=periods[0],
                        dropout=self.dropout
                    ))
                skip_branches.append(nn.ModuleList(per_variable))
        self.skip_layer = nn.ModuleList(skip_branches)

        ss_blocks = []
        for period in self.periods:
            ss_blocks.append(SSMLPBlock(
                input_size=self.seq_len,
                hidden_size=self.hidden_size,
                skip_size=period,
                dropout=self.dropout
            ))
        self.ss_mlp_layers = nn.ModuleList(ss_blocks)

        head_width = self.periods[0] * self.hidden_size
        self.output_layer = nn.Sequential(
            nn.Linear(self.seq_len, head_width),
            nn.Linear(head_width, self.pred_len)
        )

    def forward(self, x):
        """
        input: x: (batch_size, N, seq_len)
        output: (batch_size, N, pred_len)
        """
        hidden = x
        # Each block is applied with a residual connection around it.
        for block in self.ss_mlp_layers:
            hidden = block(hidden) + hidden

        return self.output_layer(hidden)


class SSNet_Skip_Only(nn.Module):
    """Variant whose forward pass uses only the period-weighted skip mixing and the head.

    Forward pass:
      1. one ``SkipMLP`` branch per period processes the input;
      2. branch outputs are combined with softmax weights derived from
         ``period_strength``;
      3. two stacked linear layers map ``seq_len -> pred_len``.

    ``ss_mlp_layers`` is still constructed but is not used in ``forward``.

    Args:
        seq_len: length of the input's last dimension.
        pred_len: length of the output's last dimension.
        n_vars: number of per-variable branches built (and indexed along
            dimension 1) when ``channel_independence`` is True.
        hidden_size: hidden width of the sub-modules.
        periods: list of period lengths (concatenated with ``+``, so it must
            be a list).
        period_strength: one strength value per period; normalised by its
            maximum and turned into mixing weights with a softmax.
        dropout: dropout probability of the sub-modules.
        channel_independence: when True every variable gets its own
            ``SkipMLP`` in each branch.
    """

    def __init__(self,
                 seq_len,
                 pred_len,
                 n_vars,
                 hidden_size,
                 periods,
                 period_strength,
                 dropout,
                 channel_independence=False):
        super().__init__()
        self.seq_len = seq_len
        self.pred_len = pred_len
        self.n_vars = n_vars
        self.hidden_size = hidden_size
        self.periods = periods
        self.dropout = dropout
        self.channel_independence = channel_independence

        strength = torch.tensor(period_strength, dtype=torch.float32)
        # Scale by the maximum, then keep it as a buffer so it follows
        # `.to()` / `.cuda()` with the rest of the model.  Non-persistent, so
        # the state_dict layout is exactly what it was when this was a plain
        # tensor attribute.
        self.register_buffer('period_strength', strength / torch.max(strength), persistent=False)
        self.ss_periods = self.periods + self.periods[:-1][::-1]

        if not self.channel_independence:
            self.skip_layer = nn.ModuleList([
                SkipMLP(
                    input_size=self.seq_len,
                    hidden_size=self.hidden_size,
                    skip_size=self.ss_periods[i],
                    dropout=self.dropout
                ) for i in range(len(self.periods))
            ])

        else:
            # NOTE(review): every branch here uses periods[0] as its skip size,
            # whereas the shared-channel path above uses one period per branch.
            # Left unchanged because switching to the per-branch period would
            # change parameter shapes of existing checkpoints -- confirm intent.
            self.skip_layer = nn.ModuleList([
                nn.ModuleList([
                    SkipMLP(
                        input_size=self.seq_len,
                        hidden_size=self.hidden_size,
                        skip_size=periods[0],
                        dropout=self.dropout
                    ) for _ in range(self.n_vars)
                ]) for _ in range(len(self.periods))
            ])

        # Not used by forward(); retained so parameter names and the order of
        # random initialisation stay the same as before.
        self.ss_mlp_layers = nn.ModuleList([
            SSMLPBlock(
                input_size=self.seq_len,
                hidden_size=self.hidden_size,
                skip_size=period,
                dropout=self.dropout
            ) for period in self.periods
        ])

        self.output_layer = nn.Sequential(
            nn.Linear(self.seq_len, self.periods[0] * self.hidden_size),
            nn.Linear(self.periods[0] * self.hidden_size, self.pred_len)
        )

    def _mix_periods(self, x):
        """Run every skip branch on ``x`` and return their softmax-weighted sum."""
        branch_outputs = []
        for branch in self.skip_layer:
            if not self.channel_independence:
                branch_outputs.append(branch(x))
                continue
            # Per-variable layers: fill a zero tensor slice by slice, so
            # channels beyond n_vars (if any) stay zero as before.
            per_variable = torch.zeros_like(x)
            for j, variable_layer in enumerate(branch):
                per_variable[:, j: j + 1, :] = variable_layer(x[:, j: j + 1, :])
            branch_outputs.append(per_variable)

        stacked = torch.stack(branch_outputs, dim=-1)
        # `.to` is a no-op once the buffer lives on the model's device; it is
        # kept for instances restored from pickles that predate the buffer.
        weights = torch.softmax(self.period_strength, dim=0).to(x.device)
        # The 1-D weights broadcast over the trailing (period) axis.
        return torch.sum(stacked * weights, -1)

    def forward(self, x):
        """
        input: x: (batch_size, N, seq_len)
        output: (batch_size, N, pred_len)
        """
        out_x = self._mix_periods(x)
        return self.output_layer(out_x)
