import torch.nn as nn
from torch.nn.utils import weight_norm
import torch
from .RevIN import RevIN

class TCNBlock(nn.Module):
    """A single 1-D convolution followed by a ReLU.

    Args:
        in_channels: channel count of the tensor fed to the convolution.
        out_channels: channel count produced by the convolution.
        kernel_size: width of the convolution kernel.
        stride: step of the convolution along the sequence axis.
        dilation: spacing between kernel taps.
        padding: zero padding applied to both sequence ends (default ``0``).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation, padding=0):
        super().__init__()
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
        )
        self.conv1 = nn.Conv1d(in_channels, out_channels, **conv_options)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(conv1(x))``."""
        return self.relu(self.conv1(x))


class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of time series

    The input is edge-padded along dim 1 (the first and last time steps are
    repeated) with ``kernel_size - 1`` steps in total before average pooling,
    so that with ``stride=1`` the output has the same length as the input for
    both odd and even ``kernel_size``.

    Args:
        kernel_size: width of the averaging window.
        stride: stride of the average pooling.
    """
    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        """Smooth a 3-D tensor along dim 1 and return it in the same layout."""
        # The pooling window consumes kernel_size - 1 steps in total. Splitting
        # that evenly only works for odd kernels; for even kernels the extra
        # step goes to the front, otherwise the output would be one step short
        # of the input. Odd kernels get the same symmetric padding as before.
        total_pad = self.kernel_size - 1
        end_pad = total_pad // 2
        front_pad = total_pad - end_pad
        front = x[:, 0:1, :].repeat(1, front_pad, 1)
        end = x[:, -1:, :].repeat(1, end_pad, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last axis, so move the time axis there and back.
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x


class series_decomp(nn.Module):
    """
    Series decomposition block

    Splits a series into the output of a stride-1 moving average and the
    remainder left after subtracting that average from the input.
    """
    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = moving_avg(kernel_size, stride=1)

    def forward(self, x):
        """Return ``(x - smoothed, smoothed)`` for the input ``x``."""
        trend = self.moving_avg(x)
        return x - trend, trend


class f_emb_block(nn.Module):
    """Strided convolutional embedding stage that keeps the channel count.

    The stage chains, without any activation in between:
    a dense convolution (kernel 8, stride 4), a per-channel convolution
    (kernel 2, padding 1, ``groups=dim``), batch normalisation and a 1x1
    convolution.

    Args:
        dim: number of input and output channels.
    """

    def __init__(self, dim):
        super().__init__()
        stages = [
            nn.Conv1d(dim, dim, kernel_size=8, stride=4, groups=1),
            nn.Conv1d(dim, dim, kernel_size=2, stride=1, groups=dim, padding=1),
            nn.BatchNorm1d(dim),
            nn.Conv1d(dim, dim, kernel_size=1, stride=1),
        ]
        self.te = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the four stacked layers."""
        return self.te(x)

class f_embed(nn.Module):
    """1x1 channel projection followed by a stack of ``f_emb_block`` stages.

    Args:
        input_size: number of input channels.
        output_size: number of channels after the projection (and of every stage).
        layers: how many ``f_emb_block`` stages to stack (default ``1``).
    """

    def __init__(self, input_size, output_size, layers=1):
        super().__init__()
        self.emb = nn.Conv1d(input_size, output_size, kernel_size=1, stride=1)
        self.te = nn.ModuleList()
        for _ in range(layers):
            self.te.append(f_emb_block(output_size))

    def forward(self, x):
        """Project the channels, then apply each stage in order."""
        out = self.emb(x)
        for stage in self.te:
            out = stage(out)
        return out

class tcn_emb_block(nn.Module):
    """Embedding stage built from two ``TCNBlock`` layers, batch norm and a 1x1 conv.

    The first ``TCNBlock`` uses kernel 2 with padding 4; the second uses
    kernel 8, stride 4, dilation 4 and padding 10, so the stage shortens the
    sequence by roughly a factor of four. The channel count stays at ``dim``.

    Args:
        dim: number of input and output channels.
    """

    def __init__(self, dim):
        super().__init__()
        stages = [
            TCNBlock(dim, dim, kernel_size=2, stride=1, dilation=1, padding=4),
            TCNBlock(dim, dim, kernel_size=8, stride=4, dilation=4, padding=10),
            nn.BatchNorm1d(dim),
            nn.Conv1d(dim, dim, kernel_size=1, stride=1),
        ]
        self.te = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stacked layers."""
        return self.te(x)

class tcn_embed(nn.Module):
    """1x1 channel projection followed by a stack of ``tcn_emb_block`` stages.

    Args:
        input_size: number of input channels.
        output_size: number of channels after the projection (and of every stage).
        layers: how many ``tcn_emb_block`` stages to stack (default ``1``).
    """

    def __init__(self, input_size, output_size, layers=1):
        super().__init__()
        self.emb = nn.Conv1d(input_size, output_size, kernel_size=1, stride=1)
        self.te = nn.ModuleList()
        for _ in range(layers):
            self.te.append(tcn_emb_block(output_size))

    def forward(self, x):
        """Project the channels, then apply each stage in order."""
        out = self.emb(x)
        for stage in self.te:
            out = stage(out)
        return out

class DSConv_block(nn.Module):
    """Depthwise-separable convolution block with an optional gated feed-forward stage.

    Pipeline of ``forward``: depthwise conv -> GELU -> pointwise (1x1) conv ->
    dropout, then one or two feed-forward branches, batch norm and, when the
    output shape equals the input shape, a residual connection.

    Feed-forward branches (both read the dropout output):
      * ``ffn1`` -> ``ffn2``: 1x1 convs expanding the channels by 4 and back.
      * ``ffn3`` -> ``ffn4``: kernel-2 convs through ``out_channels // 4``
        channels; ``ffn3`` pads by one (one extra step) and ``ffn4`` does not,
        so the sequence length is restored.
    ``cattn`` is a convolution with kernel ``4 * kernel_size + 1`` that produces
    a single channel; it is used as a multiplicative gate broadcast over channels.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        kernel_size: kernel width of the depthwise convolution.
        depth_multiplier: channel multiplier of the depthwise convolution.
        drop: dropout probability applied after the pointwise convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, depth_multiplier=1, drop=0.5):
        super(DSConv_block, self).__init__()
        self.depthwise = nn.Conv1d(in_channels, in_channels * depth_multiplier, kernel_size,
                                   padding=(kernel_size - 1) // 2, groups=in_channels)
        self.pointwise = nn.Conv1d(in_channels * depth_multiplier, out_channels, 1)
        self.ffn1 = nn.Conv1d(out_channels, out_channels * 4, 1)
        self.ffn2 = nn.Conv1d(out_channels * 4, out_channels, 1)
        self.ffn3 = nn.Conv1d(out_channels, out_channels//4, 2, 1, 1)
        self.ffn4 = nn.Conv1d(out_channels//4, out_channels, 2, 1, 0)
        self.cattn = nn.Conv1d(out_channels, 1, kernel_size*4+1,
                                   padding=(kernel_size*4) // 2, groups=1)
        self.gelu = nn.GELU()
        self.norm = nn.BatchNorm1d(out_channels)
        self.drop = nn.Dropout(drop)

    def forward(self, x, cattn=True):
        """Apply the block.

        Args:
            x: input tensor accepted by ``nn.Conv1d`` with ``in_channels`` channels.
            cattn: when true, sum both feed-forward branches and scale them by
                the ``cattn`` gate; when false, use only the ``ffn3``/``ffn4``
                branch.

        Returns:
            The normalised block output, plus ``x`` when both shapes match.
        """
        residual = x
        x = self.depthwise(x)
        x = self.gelu(x)
        x = self.pointwise(x)
        x = self.drop(x)

        # The kernel-2 branch is needed by both modes.
        local_branch = self.ffn4(self.ffn3(x))
        if cattn:
            channel_branch = self.ffn2(self.ffn1(x))
            x = (channel_branch + local_branch) * self.cattn(x)
        else:
            # ffn1/ffn2 do not contribute to this mode, so they are not evaluated.
            x = local_branch

        x = self.norm(x)
        # The skip connection is only possible when the block kept the shape
        # (same channel count and same sequence length).
        if x.shape == residual.shape:
            x = residual + x
        return x

class DSConv_block1(nn.Module):
    """Depthwise-separable convolution block with three gated feed-forward branches.

    Pipeline of ``forward``: depthwise conv -> GELU -> pointwise (1x1) conv ->
    dropout, then the feed-forward branches, batch norm and, when the output
    shape equals the input shape, a residual connection.

    Feed-forward branches (all read the dropout output):
      * ``ffn1`` -> ``ffn2``: 1x1 convs through
        ``out_channels * depth_multiplier // 2`` channels.
      * ``ffn3`` -> ``ffn4``: kernel-2 convs; ``ffn3`` pads by one and ``ffn4``
        does not, so the sequence length is restored.
      * ``ffn5`` -> ``ffn6``: kernel-3, padding-1 convs through
        ``out_channels // 2`` channels.
    ``cattn``, ``cattn1`` and ``cattn2`` are convolutions with kernels
    ``4k + 1``, ``6k + 1`` and ``8k + 1`` (``k = kernel_size``) that each
    produce a single channel used as a multiplicative gate for one branch.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        kernel_size: kernel width of the depthwise convolution.
        depth_multiplier: channel multiplier of the depthwise convolution.
        drop: dropout probability applied after the pointwise convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, depth_multiplier=1, drop=0.5):
        super(DSConv_block1, self).__init__()
        self.depthwise = nn.Conv1d(in_channels, in_channels * depth_multiplier, kernel_size,
                                   padding=(kernel_size - 1) // 2, groups=in_channels)
        self.pointwise = nn.Conv1d(in_channels * depth_multiplier, out_channels, 1)
        self.ffn1 = nn.Conv1d(out_channels, out_channels * depth_multiplier//2, 1)
        self.ffn2 = nn.Conv1d(out_channels * depth_multiplier//2, out_channels, 1)
        self.ffn3 = nn.Conv1d(out_channels, out_channels, 2, 1, 1)
        self.ffn4 = nn.Conv1d(out_channels, out_channels, 2, 1, 0)
        self.ffn5 = nn.Conv1d(out_channels, out_channels//2, 3, 1, 1)
        self.ffn6 = nn.Conv1d(out_channels//2, out_channels, 3, 1, 1)
        self.cattn = nn.Conv1d(out_channels, 1, kernel_size*4+1,
                                   padding=(kernel_size*4) // 2, groups=1)
        self.cattn1 = nn.Conv1d(out_channels, 1, kernel_size*6+1,
                                   padding=(kernel_size*6) // 2, groups=1)
        self.cattn2 = nn.Conv1d(out_channels, 1, kernel_size*8+1,
                                   padding=(kernel_size*8) // 2, groups=1)
        self.gelu = nn.GELU()
        self.norm = nn.BatchNorm1d(out_channels)
        self.drop = nn.Dropout(drop)

    def forward(self, x, cattn=True):
        """Apply the block.

        Args:
            x: input tensor accepted by ``nn.Conv1d`` with ``in_channels`` channels.
            cattn: when true, combine the three branches, each scaled by its own
                gate; when false, use only the ``ffn3``/``ffn4`` branch.

        Returns:
            The normalised block output, plus ``x`` when both shapes match.
        """
        residual = x
        x = self.depthwise(x)
        x = self.gelu(x)
        x = self.pointwise(x)
        x = self.drop(x)

        # The kernel-2 branch is needed by both modes.
        k2_branch = self.ffn4(self.ffn3(x))
        if cattn:
            k1_branch = self.ffn2(self.ffn1(x))
            k3_branch = self.ffn6(self.ffn5(x))
            gate_k1 = self.cattn(x)
            gate_k2 = self.cattn1(x)
            gate_k3 = self.cattn2(x)
            x = (k1_branch*gate_k1 + k2_branch*gate_k2 + k3_branch*gate_k3)
        else:
            # ffn1/ffn2 do not contribute to this mode, so they are not evaluated.
            x = k2_branch

        x = self.norm(x)
        # The skip connection is only possible when the block kept the shape
        # (same channel count and same sequence length).
        if x.shape == residual.shape:
            x = residual + x
        return x

class DepthwiseSeparableConv(nn.Module):
    """Downsample-then-upsample stack of ``DSConv_block`` layers with skip connections.

    ``forward`` swaps dims 1 and 2 of its input before the first ``Conv1d``,
    so the input is laid out as (batch, length, in_channels). The result keeps
    the (batch, channels, length) layout with ``out_channels`` channels.

    Args:
        in_channels: channels of the input series.
        out_channels: channel width used throughout the network.
        kernel_size: depthwise kernel width passed to every ``DSConv_block``.
        ds_kernel_size: kernel width of the first (stem) downsampling conv.
        depth_multiplier: passed to every ``DSConv_block``.
        layers: number of down/up levels.
        revin: whether to wrap the network in ``RevIN`` 'norm'/'denorm' calls.
        affine: forwarded to ``RevIN``.
        subtract_last: forwarded to ``RevIN``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, ds_kernel_size=8, depth_multiplier=1, layers=1, revin=True, affine=True,
                 subtract_last=False):
        super().__init__()
        # Sub-modules are created and registered in a fixed order; that order
        # determines parameter ordering and the order random init is drawn in.
        self.downsample_layers = nn.ModuleList()
        self.emb = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(out_channels),
        )

        # Level 0 halves the length with a wide stride-2 conv, deeper levels
        # with a kernel-2 stride-2 conv.
        stem_conv = nn.Conv1d(out_channels, out_channels, kernel_size=ds_kernel_size,
                              stride=2, padding=(ds_kernel_size - 2) // 2)
        self.downsample_layers.append(nn.Sequential(stem_conv))
        for _ in range(layers - 1):
            halving_conv = nn.Conv1d(out_channels, out_channels, kernel_size=2, stride=2, padding=0)
            self.downsample_layers.append(nn.Sequential(halving_conv))

        self.blocks = nn.ModuleList()
        for _ in range(layers):
            self.blocks.append(DSConv_block(out_channels, out_channels, kernel_size, depth_multiplier))

        self.blocks1 = nn.ModuleList()
        for _ in range(layers):
            self.blocks1.append(DSConv_block(out_channels, out_channels, kernel_size, depth_multiplier))

        self.upconv = nn.ModuleList([
            nn.Sequential(nn.ConvTranspose1d(out_channels, out_channels, kernel_size=2, stride=2))
            for _ in range(layers)
        ])

        # head: kernel 2 with padding 1 makes the sequence one step longer.
        self.head = nn.Conv1d(out_channels, out_channels, kernel_size=2, stride=1, padding=1)
        # remb is defined but not applied anywhere in forward().
        self.remb = nn.Conv1d(out_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.revin = revin
        if self.revin:
            self.revin_layer = RevIN(out_channels, affine=affine, subtract_last=subtract_last)

    def forward(self, x):
        """Embed ``x``, run the down path and the up path, and return the result."""
        x = x.transpose(2, 1)
        _, _, seq_len = x.shape
        x = self.emb(x)
        if self.revin:
            # RevIN is called on the (batch, length, channels) view.
            x = self.revin_layer(x.permute(0, 2, 1), 'norm').permute(0, 2, 1)

        depth = len(self.blocks)
        skips = []
        # Down path: remember the input of every level for the skip connections.
        for level in range(depth):
            skips.append(x)
            x = self.downsample_layers[level](x)
            x = self.blocks[level](x)

        # Up path: all levels but the last add the most recent unused skip.
        for level in range(depth - 1):
            x = self.blocks1[level](x)
            x = self.upconv[level](x) + skips[-level - 1]

        # Last level: for an odd input length the upsampled sequence goes
        # through `head`, which adds one step, before the first skip is added.
        x = self.blocks1[depth - 1](x)
        x = self.upconv[depth - 1](x)
        if seq_len % 2 != 0:
            x = self.head(x)
        x = x + skips[0]

        if self.revin:
            x = self.revin_layer(x.permute(0, 2, 1), 'denorm').permute(0, 2, 1)

        return x

class reDepthwiseSeparableConv(nn.Module):
    """Upsample-then-downsample stack of ``DSConv_block`` layers with skip connections.

    This mirrors the down-then-up layout: every level first doubles the length
    with a transposed convolution, and the second half of the network halves
    it again with the stride-2 convolutions. ``forward`` swaps dims 1 and 2 of
    its input before the first ``Conv1d``, so the input is laid out as
    (batch, length, in_channels); the result stays in (batch, channels, length)
    layout with ``out_channels`` channels.

    Args:
        in_channels: channels of the input series.
        out_channels: channel width used throughout the network.
        kernel_size: depthwise kernel width passed to every ``DSConv_block``.
        ds_kernel_size: kernel width of the first (stem) stride-2 conv.
        depth_multiplier: passed to every ``DSConv_block``.
        layers: number of up/down levels.
        revin: whether to wrap the network in ``RevIN`` 'norm'/'denorm' calls.
        affine: forwarded to ``RevIN``.
        subtract_last: forwarded to ``RevIN``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, ds_kernel_size=8, depth_multiplier=1, layers=1, revin=True, affine=True,
                 subtract_last=False):
        super().__init__()
        # Sub-modules are created and registered in a fixed order; that order
        # determines parameter ordering and the order random init is drawn in.
        self.downsample_layers = nn.ModuleList()
        self.emb = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(out_channels),
        )

        stem_conv = nn.Conv1d(out_channels, out_channels, kernel_size=ds_kernel_size,
                              stride=2, padding=(ds_kernel_size - 2) // 2)
        self.downsample_layers.append(nn.Sequential(stem_conv))
        for _ in range(layers - 1):
            halving_conv = nn.Conv1d(out_channels, out_channels, kernel_size=2, stride=2, padding=0)
            self.downsample_layers.append(nn.Sequential(halving_conv))

        self.blocks = nn.ModuleList()
        for _ in range(layers):
            self.blocks.append(DSConv_block(out_channels, out_channels, kernel_size, depth_multiplier))

        self.blocks1 = nn.ModuleList()
        for _ in range(layers):
            self.blocks1.append(DSConv_block(out_channels, out_channels, kernel_size, depth_multiplier))

        self.upconv = nn.ModuleList([
            nn.Sequential(nn.ConvTranspose1d(out_channels, out_channels, kernel_size=2, stride=2))
            for _ in range(layers)
        ])

        self.revin = revin
        if self.revin:
            self.revin_layer = RevIN(out_channels, affine=affine, subtract_last=subtract_last)

    def forward(self, x):
        """Embed ``x``, run the up path and the down path, and return the result."""
        x = x.transpose(2, 1)
        # The values are unused; the unpacking only insists on a 3-D input.
        _batch, _channels, _length = x.shape
        x = self.emb(x)
        if self.revin:
            # RevIN is called on the (batch, length, channels) view.
            x = self.revin_layer(x.permute(0, 2, 1), 'norm').permute(0, 2, 1)

        # Up path: remember the input of every level for the skip connections.
        skips = []
        for upsample, block in zip(self.upconv, self.blocks):
            skips.append(x)
            x = block(upsample(x))

        # Down path: skips are consumed newest-first.
        for block, downsample, skip in zip(self.blocks1, self.downsample_layers, reversed(skips)):
            x = downsample(block(x)) + skip

        if self.revin:
            x = self.revin_layer(x.permute(0, 2, 1), 'denorm').permute(0, 2, 1)

        return x
    
class DepthwiseSeparableConv1(nn.Module):
    """Downsample-then-upsample stack of ``DSConv_block1`` layers with skip connections.

    Unlike the variants with an embedding layer, ``forward`` feeds its input
    straight into ``Conv1d`` layers, so the input is laid out as
    (batch, in_channels, length). The down path keeps ``in_channels`` channels
    until its last block, which maps to ``out_channels``; the up path works on
    ``out_channels`` channels and adds the stored down-path inputs back in.

    Args:
        in_channels: channels of the input series.
        out_channels: channels used by the up path.
        kernel_size: depthwise kernel width passed to every ``DSConv_block1``.
        ds_kernel_size: kernel width of the first (stem) downsampling conv.
        depth_multiplier: passed to every ``DSConv_block1``.
        layers: number of down/up levels.
        revin: whether to wrap the network in ``RevIN`` 'norm'/'denorm' calls.
        affine: forwarded to ``RevIN``.
        subtract_last: forwarded to ``RevIN``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, ds_kernel_size=8, depth_multiplier=1, layers=1, revin=True, affine=True,
                 subtract_last=False):
        super().__init__()
        # Sub-modules are created and registered in a fixed order; that order
        # determines parameter ordering and the order random init is drawn in.
        self.downsample_layers = nn.ModuleList()

        stem_conv = nn.Conv1d(in_channels, in_channels, kernel_size=ds_kernel_size,
                              stride=2, padding=(ds_kernel_size - 2) // 2)
        self.downsample_layers.append(nn.Sequential(stem_conv))
        for _ in range(layers - 1):
            halving_conv = nn.Conv1d(in_channels, in_channels, kernel_size=2, stride=2, padding=0)
            self.downsample_layers.append(nn.Sequential(halving_conv))

        # Only the deepest down-path block changes the channel count.
        self.blocks = nn.ModuleList()
        for level in range(layers):
            block_out = out_channels if level == layers - 1 else in_channels
            self.blocks.append(DSConv_block1(in_channels, block_out, kernel_size, depth_multiplier))

        self.blocks1 = nn.ModuleList()
        for _ in range(layers):
            self.blocks1.append(DSConv_block1(out_channels, out_channels, kernel_size, depth_multiplier))

        self.upconv = nn.ModuleList([
            nn.Sequential(nn.ConvTranspose1d(out_channels, out_channels, kernel_size=2, stride=2))
            for _ in range(layers)
        ])

        self.revin = revin
        if self.revin:
            self.revin_layer = RevIN(in_channels, affine=affine, subtract_last=subtract_last)

    def forward(self, x):
        """Run the down path and the up path on ``x`` and return the result."""
        if self.revin:
            # RevIN is called on the (batch, length, channels) view.
            x = self.revin_layer(x.permute(0, 2, 1), 'norm').permute(0, 2, 1)

        # Down path: remember the input of every level for the skip connections.
        skips = []
        for downsample, block in zip(self.downsample_layers, self.blocks):
            skips.append(x)
            x = block(downsample(x))

        # Up path: skips are consumed newest-first.
        for block, upsample, skip in zip(self.blocks1, self.upconv, reversed(skips)):
            x = upsample(block(x)) + skip

        if self.revin:
            x = self.revin_layer(x.permute(0, 2, 1), 'denorm').permute(0, 2, 1)

        return x