import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def Conv1d_with_init(in_channels: int, out_channels: int, kernel_size: int) -> nn.Conv1d:
    """Build an ``nn.Conv1d`` whose weight is re-initialised with Kaiming-normal.

    Only the weight tensor is re-initialised; the bias keeps whatever
    ``nn.Conv1d`` assigned at construction time.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        kernel_size: Kernel size forwarded to ``nn.Conv1d``.

    Returns:
        The freshly constructed and initialised convolution layer.
    """
    conv = nn.Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
    )
    # In-place initialiser: overwrites the default weight values.
    nn.init.kaiming_normal_(conv.weight)
    return conv

class diff_CSDI_PR(nn.Module):
    """Stack of residual blocks wrapped by pointwise input/output projections.

    ``config`` is indexed with the keys ``"channels"``, ``"time_len"``,
    ``"feat_len"`` and ``"layers"``.  ``forward`` takes a 4-D tensor unpacked
    as ``(B, inputdim, K, L)`` and returns a ``(B, K, L)`` tensor that has
    been passed through a sigmoid.
    """

    def __init__(self, config, inputdim=1):
        """Create the projections and ``config["layers"]`` residual blocks.

        Args:
            config: Mapping providing ``channels``, ``time_len``,
                ``feat_len`` and ``layers``.
            inputdim: Channel count expected on axis 1 of the input.
        """
        super().__init__()
        self.channels = config["channels"]
        self.time_len = config["time_len"]
        self.feat_len = config["feat_len"]

        # Kernel-size-1 convolutions: lift the input to ``channels`` maps,
        # and later collapse ``channels`` maps down to a single one.
        self.input_projection1 = Conv1d_with_init(inputdim, self.channels, 1)
        self.output_projection1 = Conv1d_with_init(self.channels, 1, 1)

        block_kwargs = {
            "channels": self.channels,
            "time_len": self.time_len,
            "feat_len": self.feat_len,
        }
        self.residual_layers = nn.ModuleList(
            [ResidualBlock(**block_kwargs) for _ in range(config["layers"])]
        )

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor whose shape unpacks as ``(B, inputdim, K, L)``.

        Returns:
            Tensor of shape ``(B, K, L)`` after a sigmoid.
        """
        batch, in_dim, n_feat, n_time = x.shape
        flat_len = n_feat * n_time

        # Conv1d works on (B, C, length), so the K and L axes are merged.
        hidden = self.input_projection1(x.reshape(batch, in_dim, flat_len))
        hidden = F.relu(hidden).reshape(batch, self.channels, n_feat, n_time)

        skips = []
        for block in self.residual_layers:
            hidden, skip_out = block(hidden)
            skips.append(skip_out)

        # Only the skip outputs feed the head; they are summed and scaled
        # by 1/sqrt(number of blocks).
        n_blocks = len(self.residual_layers)
        merged = torch.stack(skips).sum(dim=0) / math.sqrt(n_blocks)
        merged = merged.reshape(batch, self.channels, flat_len)

        projected = self.output_projection1(merged)  # (B, 1, K*L)
        return torch.sigmoid(projected.reshape(batch, n_feat, n_time))


class ResidualBlock(nn.Module):
    """Gated residual block that mixes along the L axis, then along the K axis.

    ``forward`` takes a ``(B, channel, K, L)`` tensor and returns a pair
    ``(residual_output, skip)``, both of that same shape.
    """

    def __init__(self, channels, time_len, feat_len):
        """Create the two pointwise projections and the two mixing MLPs.

        Args:
            channels: Channel count of the block's input and outputs.
            time_len: Size of the last (L) axis that ``time_layer`` acts on.
            feat_len: Size of the K axis that ``feature_layer`` acts on.
        """
        super().__init__()
        self.mid_projection = Conv1d_with_init(channels, 2 * channels, 1)
        self.output_projection = Conv1d_with_init(channels, 2 * channels, 1)

        # NOTE(review): this evaluates to 0 when channels < 8, which gives
        # the MLPs an empty hidden layer -- confirm callers never do that.
        bottleneck = int(channels / 8)

        self.time_layer = nn.Sequential(
            nn.Linear(time_len, bottleneck),
            nn.GELU(),
            nn.Linear(bottleneck, time_len),
        )
        self.feature_layer = nn.Sequential(
            nn.Linear(feat_len, bottleneck),
            nn.GELU(),
            nn.Linear(bottleneck, feat_len),
        )

    def forward_time(self, y, base_shape):
        """Apply ``time_layer`` along the L axis of a flattened tensor.

        Args:
            y: Tensor of shape ``(B, channel, K*L)``.
            base_shape: The unflattened shape ``(B, channel, K, L)``.

        Returns:
            Tensor of shape ``(B, channel, K*L)``; ``y`` itself when L is 1.
        """
        batch, chans, n_feat, n_time = base_shape
        if n_time == 1:
            return y

        # Fold K into the batch axis so the Linear layers see L last.
        folded = (
            y.reshape(batch, chans, n_feat, n_time)
            .permute(0, 2, 1, 3)
            .reshape(batch * n_feat, chans, n_time)
        )
        mixed = self.time_layer(folded)

        # Undo the fold and flatten (K, L) again.
        unfolded = mixed.reshape(batch, n_feat, chans, n_time).permute(0, 2, 1, 3)
        return unfolded.reshape(batch, chans, n_feat * n_time)

    def forward_feature(self, y, base_shape):
        """Apply ``feature_layer`` along the K axis of a flattened tensor.

        Args:
            y: Tensor of shape ``(B, channel, K*L)``.
            base_shape: The unflattened shape ``(B, channel, K, L)``.

        Returns:
            Tensor of shape ``(B, channel, K*L)``; ``y`` itself when K is 1.
        """
        batch, chans, n_feat, n_time = base_shape
        if n_feat == 1:
            return y

        # Fold L into the batch axis so the Linear layers see K last.
        folded = (
            y.reshape(batch, chans, n_feat, n_time)
            .permute(0, 3, 1, 2)
            .reshape(batch * n_time, chans, n_feat)
        )
        mixed = self.feature_layer(folded)

        # Back to (B, channel, K, L) ordering before flattening.
        unfolded = mixed.reshape(batch, n_time, chans, n_feat).permute(0, 2, 3, 1)
        return unfolded.reshape(batch, chans, n_feat * n_time)

    def forward(self, x):
        """Run the block.

        Args:
            x: Tensor of shape ``(B, channel, K, L)``.

        Returns:
            Tuple ``(out, skip)`` where ``out`` is
            ``(x + residual) / sqrt(2)`` and both tensors have ``x``'s shape.
        """
        base_shape = x.shape
        batch, chans, n_feat, n_time = base_shape
        flat = x.reshape(batch, chans, n_feat * n_time)

        mixed = self.forward_time(flat, base_shape)
        mixed = self.forward_feature(mixed, base_shape)  # (B, channel, K*L)

        # The doubled channels split into a sigmoid gate and a tanh filter.
        gate, filt = torch.chunk(self.mid_projection(mixed), 2, dim=1)
        gated = torch.sigmoid(gate) * torch.tanh(filt)  # (B, channel, K*L)

        # Second doubling: first half is the residual, second half the skip.
        residual, skip = torch.chunk(self.output_projection(gated), 2, dim=1)

        out = (x + residual.reshape(base_shape)) / math.sqrt(2.0)
        return out, skip.reshape(base_shape)