import torch.nn as nn
from torch.nn.utils import weight_norm
import torch
from .RevIN import RevIN

class TCNBlock(nn.Module):
    """A single 1-D convolution followed by ReLU.

    All constructor arguments are forwarded unchanged to ``nn.Conv1d``
    (``kernel_size``, ``stride``, ``dilation`` and ``padding``), so the output
    length follows the standard ``nn.Conv1d`` formula.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation, padding=0):
        super().__init__()
        # The attribute name ``conv1`` is part of the state_dict keys; keep it.
        self.conv1 = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(conv1(x))``."""
        return self.relu(self.conv1(x))


class f_emb_block(nn.Module):
    """Strided 1-D embedding block that keeps the channel count at ``dim``.

    The layers are held in one ``nn.Sequential`` named ``te``; its indices are
    part of the state_dict keys. In order, the layers are:

    * a kernel-8 / stride-4 conv,
    * a kernel-2 depthwise conv with padding 1,
    * BatchNorm,
    * a 1x1 conv.

    For an input of length L the output length is ``floor((L - 8) / 4) + 2``.
    """

    def __init__(self, dim):
        super().__init__()
        stages = [
            nn.Conv1d(dim, dim, kernel_size=8, stride=4, groups=1),
            # Depthwise (groups=dim); padding 1 with kernel 2 adds one time step.
            nn.Conv1d(dim, dim, kernel_size=2, stride=1, groups=dim, padding=1),
            nn.BatchNorm1d(dim),
            nn.Conv1d(dim, dim, kernel_size=1, stride=1),
        ]
        self.te = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the ``te`` stack."""
        return self.te(x)


class f_embed(nn.Module):
    """Channel projection followed by a stack of ``f_emb_block`` layers.

    Args:
        input_size: number of input channels.
        output_size: number of channels after the projection (and of every block).
        layers: how many ``f_emb_block`` layers to stack after the projection.
        usehead: if True, project with a kernel-2 / stride-2 conv, which halves the
            length (floor), followed by a 1x1 conv. Otherwise project with a
            single 1x1 conv.
    """

    def __init__(self, input_size, output_size, layers=1, usehead=False):
        super().__init__()
        if not usehead:
            self.emb = nn.Conv1d(input_size, output_size, kernel_size=1, stride=1)
        else:
            self.emb = nn.Sequential(
                nn.Conv1d(input_size, output_size, kernel_size=2, stride=2, padding=0),  # date 24 08 15
                nn.Conv1d(output_size, output_size, kernel_size=1, stride=1),
            )
        self.te = nn.ModuleList(f_emb_block(output_size) for _ in range(layers))

    def forward(self, x):
        """Project ``x`` with ``emb`` and then apply each block in ``te`` in order."""
        out = self.emb(x)
        for block in self.te:
            out = block(out)
        return out


class tcn_emb_block(nn.Module):
    """Embedding block built from two ``TCNBlock`` layers; keeps ``dim`` channels.

    The layers live in one ``nn.Sequential`` named ``te``; its indices are part
    of the state_dict keys. In order, the layers are:

    * a kernel-2 conv with padding 4, which lengthens the sequence by 7,
    * a kernel-8 / stride-4 / dilation-4 conv with padding 10,
    * BatchNorm,
    * a 1x1 conv.

    For an input of length L the output length is ``floor((L - 2) / 4) + 1``
    (from the ``nn.Conv1d`` output-length formula).
    """

    def __init__(self, dim):
        super().__init__()
        stages = [
            TCNBlock(dim, dim, kernel_size=2, stride=1, dilation=1, padding=4),
            TCNBlock(dim, dim, kernel_size=8, stride=4, dilation=4, padding=10),
            nn.BatchNorm1d(dim),
            nn.Conv1d(dim, dim, kernel_size=1, stride=1),
        ]
        self.te = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the ``te`` stack."""
        return self.te(x)


class tcn_embed(nn.Module):
    """Channel projection followed by a stack of ``tcn_emb_block`` layers.

    Args:
        input_size: number of input channels.
        output_size: number of channels after the projection (and of every block).
        layers: how many ``tcn_emb_block`` layers to stack after the projection.
        usehead: if True, project with a kernel-2 / stride-2 conv, which halves the
            length (floor), followed by a 1x1 conv. Otherwise project with a
            single 1x1 conv.
    """

    def __init__(self, input_size, output_size, layers=1, usehead=False):
        super().__init__()
        if not usehead:
            self.emb = nn.Conv1d(input_size, output_size, kernel_size=1, stride=1)
        else:
            self.emb = nn.Sequential(
                nn.Conv1d(input_size, output_size, kernel_size=2, stride=2, padding=0),  # date 24 08 15
                nn.Conv1d(output_size, output_size, kernel_size=1, stride=1),
            )
        self.te = nn.ModuleList(tcn_emb_block(output_size) for _ in range(layers))

    def forward(self, x):
        """Project ``x`` with ``emb`` and then apply each block in ``te`` in order."""
        out = self.emb(x)
        for block in self.te:
            out = block(out)
        return out


class DSConv_block(nn.Module):
    """Depthwise-separable 1-D conv block with two conv branches and a 1-channel map.

    Pipeline:

    1. depthwise conv -> GELU -> pointwise conv -> dropout;
    2. two parallel branches (``ffn1 -> ffn2`` and ``ffn3 -> ffn4``) are summed;
    3. the sum is combined with the single-channel output of ``cattn``;
    4. BatchNorm.

    A residual connection to the block input is added only when the output shape
    equals the input shape.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the pointwise conv and by the block.
        kernel_size: depthwise kernel size; ``cattn`` uses ``4 * kernel_size + 1``.
            With an even ``kernel_size`` the depthwise conv shortens the sequence
            by one step, so the residual connection is skipped.
        depth_multiplier: channel multiplier of the depthwise conv; also scales the
            hidden width of the ``ffn1``/``ffn2`` branch.
        drop: dropout probability applied after the pointwise conv.
    """

    def __init__(self, in_channels, out_channels, kernel_size, depth_multiplier=1, drop=0.5):
        super(DSConv_block, self).__init__()
        hidden = out_channels * depth_multiplier // 2
        # Submodule creation order is kept as-is: it fixes the state_dict order and
        # the order in which seeded random initialisation is consumed.
        # Per-channel (depthwise) conv, then a 1x1 pointwise channel mix.
        self.depthwise = nn.Conv1d(in_channels, in_channels * depth_multiplier, kernel_size,
                                   padding=(kernel_size - 1) // 2, groups=in_channels)
        self.pointwise = nn.Conv1d(in_channels * depth_multiplier, out_channels, 1)
        # Branch 1: 1x1 bottleneck out_channels -> hidden -> out_channels.
        self.ffn1 = nn.Conv1d(out_channels, hidden, 1)
        self.ffn2 = nn.Conv1d(hidden, out_channels, 1)
        # Branch 2: two kernel-2 convs; padding 1 then 0 restores the input length.
        self.ffn3 = nn.Conv1d(out_channels, out_channels // 2, 2, 1, 1)
        self.ffn4 = nn.Conv1d(out_channels // 2, out_channels, 2, 1, 0)
        # Single-channel map over a wide (odd) window; the padding keeps the length.
        self.cattn = nn.Conv1d(out_channels, 1, kernel_size * 4 + 1,
                               padding=(kernel_size * 4) // 2, groups=1)
        self.gelu = nn.GELU()
        self.norm = nn.BatchNorm1d(out_channels)
        self.drop = nn.Dropout(drop)

    def forward(self, x, cattn=True):
        """Apply the block.

        Args:
            x: channels-first input, e.g. ``(batch, in_channels, length)``, as
                consumed by the ``nn.Conv1d`` layers.
            cattn: if truthy, the single-channel ``cattn`` map multiplies the branch
                sum (broadcast over channels); otherwise it is added to it.

        Returns:
            Tensor with ``out_channels`` channels. When that shape equals the input
            shape, it includes the input as a residual.
        """
        residual = x
        h = self.drop(self.pointwise(self.gelu(self.depthwise(x))))

        # Both modes use the same three projections; only the way the
        # single-channel map is combined with the branch sum differs.
        branch = self.ffn2(self.ffn1(h)) + self.ffn4(self.ffn3(h))
        gate = self.cattn(h)
        out = branch * gate if cattn else branch + gate
        out = self.norm(out)

        # Residual only when shapes line up. This is silently skipped when the
        # channel count changes or an even kernel_size shortened the sequence.
        if out.shape == residual.shape:
            out = residual + out
        return out


class DepthwiseSeparableConv(nn.Module):
    """U-Net-style 1-D encoder/decoder built from ``DSConv_block`` layers.

    Encoder, per level: the current tensor is stored as a skip, then goes through
    a stride-2 downsampling conv and a ``DSConv_block``. Level 0 uses a
    ``ds_kernel_size`` "stem"; deeper levels use a kernel-2 conv.

    Decoder, per level: a ``DSConv_block`` and a stride-2 ``ConvTranspose1d``,
    then the matching skip is added, deepest first. When ``revin`` is set, the
    whole network is wrapped in ``RevIN`` 'norm' / 'denorm' calls.

    NOTE(review): skips carry ``in_channels`` while the decoder produces
    ``out_channels``, so the additions only work when these match (or broadcast).
    NOTE(review): the stride-2 down/up pairs only reproduce the skip length when
    every level sees an even length, i.e. length divisible by ``2 ** layers``.
    This assumes an even ``ds_kernel_size`` and length-preserving
    ``DSConv_block``s (odd ``kernel_size``); other inputs fail at the addition.

    Args:
        in_channels: input channel count (also the RevIN feature count).
        out_channels: channel count of the deepest encoder block and the decoder.
        kernel_size: kernel size forwarded to every ``DSConv_block``.
        ds_kernel_size: kernel size of the level-0 downsampling stem.
        depth_multiplier: forwarded to every ``DSConv_block``.
        layers: number of encoder/decoder levels.
        revin: whether to wrap the network in RevIN normalisation.
        affine, subtract_last: forwarded to ``RevIN``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, ds_kernel_size=8, depth_multiplier=1, layers=1, revin=True, affine=True,
                 subtract_last=False):
        super().__init__()
        last = layers - 1

        # Each conv is wrapped in nn.Sequential so state_dict keys stay
        # ``downsample_layers.<i>.0.*`` / ``upconv.<i>.0.*``. Modules are created
        # in the same order as before, so seeded initialisation is unchanged.
        stem_padding = (ds_kernel_size - 2) // 2
        self.downsample_layers = nn.ModuleList(
            [nn.Sequential(nn.Conv1d(in_channels, in_channels, kernel_size=ds_kernel_size, stride=2, padding=stem_padding))]
            + [nn.Sequential(nn.Conv1d(in_channels, in_channels, kernel_size=2, stride=2, padding=0)) for _ in range(last)]
        )

        # Encoder blocks: only the deepest one switches to out_channels.
        self.blocks = nn.ModuleList(
            DSConv_block(in_channels, out_channels if level == last else in_channels, kernel_size, depth_multiplier)
            for level in range(layers)
        )
        self.blocks1 = nn.ModuleList(
            DSConv_block(out_channels, out_channels, kernel_size, depth_multiplier)
            for _ in range(layers)
        )
        self.upconv = nn.ModuleList(
            nn.Sequential(nn.ConvTranspose1d(out_channels, out_channels, kernel_size=2, stride=2))
            for _ in range(layers)
        )

        self.revin = revin
        if revin:
            self.revin_layer = RevIN(in_channels, affine=affine, subtract_last=subtract_last)

    def forward(self, x, cattn=True):
        """Run the encoder/decoder.

        Args:
            x: channels-first tensor, e.g. ``(batch, in_channels, length)``. RevIN
                is called on its channels-last permutation.
            cattn: forwarded to every ``DSConv_block``.

        Returns:
            Tensor in the same channels-first layout.
        """
        if self.revin:
            x = self.revin_layer(x.permute(0, 2, 1), 'norm').permute(0, 2, 1)

        skips = []
        for down, block in zip(self.downsample_layers, self.blocks):
            skips.append(x)
            x = block(down(x), cattn)

        # Decoder consumes the skips deepest-first.
        for block, up, skip in zip(self.blocks1, self.upconv, reversed(skips)):
            x = up(block(x, cattn)) + skip

        if self.revin:
            x = self.revin_layer(x.permute(0, 2, 1), 'denorm').permute(0, 2, 1)
        return x
    
    
class DepthwiseSeparableConv1(nn.Module):
    """U-Net-style 1-D encoder/decoder that first embeds the input to ``out_channels``.

    Pipeline:

    1. The input is transposed to channels-first.
    2. ``emb`` (kernel-3 conv + BatchNorm) lifts it from ``in_channels`` to
       ``out_channels``; RevIN normalisation is optionally applied.
    3. ``layers`` encoder levels: stride-2 downsampling conv + ``DSConv_block1``.
    4. ``layers`` decoder levels: ``DSConv_block1`` + stride-2 transposed conv +
       skip addition.

    ``head`` is a kernel-2, padding-1 conv, which lengthens a sequence by one
    step. It is applied to the outermost upsampled tensor when the input length
    is odd, compensating for the step lost by the stride-2 down/up pair (assuming
    ``DSConv_block1`` keeps the sequence length).

    NOTE(review): ``DSConv_block1`` is not defined in the part of this module
    reviewed here. Confirm it exists elsewhere in the module; otherwise
    construction raises NameError.
    NOTE(review): only the outermost level handles odd lengths; deeper levels
    still need even lengths at their inputs for the skip additions to line up.
    NOTE(review): ``remb`` is created (and holds parameters) but is not used in
    ``forward``, so the output keeps ``out_channels`` channels.

    Args:
        in_channels: channel count of the input (last axis before the transpose).
        out_channels: working channel count of the whole network.
        kernel_size: kernel size forwarded to every ``DSConv_block1``.
        ds_kernel_size: kernel size of the level-0 downsampling stem.
        depth_multiplier: forwarded to every ``DSConv_block1``.
        layers: number of encoder/decoder levels.
        revin: whether to wrap the encoder/decoder in RevIN normalisation.
        affine, subtract_last: forwarded to ``RevIN``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, ds_kernel_size=8, depth_multiplier=1, layers=1, revin=True, affine=True,
                 subtract_last=False):
        super(DepthwiseSeparableConv1, self).__init__()
        # Submodule creation order is unchanged: it fixes state_dict order and the
        # order in which seeded random initialisation is consumed.
        self.downsample_layers = nn.ModuleList()
        self.emb = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(out_channels),
        )
        stem = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, kernel_size=ds_kernel_size, stride=2, padding=(ds_kernel_size-2)//2),
        )
        self.downsample_layers.append(stem)
        for _ in range(layers - 1):
            downsample_layer = nn.Sequential(
                nn.Conv1d(out_channels, out_channels, kernel_size=2, stride=2, padding=0),
            )
            self.downsample_layers.append(downsample_layer)

        self.blocks = nn.ModuleList([
            DSConv_block1(out_channels, out_channels, kernel_size, depth_multiplier)
            for _ in range(layers)
        ])

        self.blocks1 = nn.ModuleList([
            DSConv_block1(out_channels, out_channels, kernel_size, depth_multiplier)
            for _ in range(layers)
        ])

        self.upconv = nn.ModuleList()
        for _ in range(layers):
            up = nn.Sequential(
                nn.ConvTranspose1d(out_channels, out_channels, kernel_size=2, stride=2),
                )
            self.upconv.append(up)

        # Adds one time step; used to fix odd input lengths at the outermost level.
        self.head = nn.Conv1d(out_channels, out_channels, kernel_size=2, stride=1, padding=1)
        # Not used in forward (see NOTE in the class docstring).
        self.remb = nn.Conv1d(out_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.revin = revin
        if self.revin:
            self.revin_layer = RevIN(out_channels, affine=affine, subtract_last=subtract_last)

    def forward(self, x, cattn=True):
        """Run the embedding and the encoder/decoder.

        Args:
            x: tensor laid out as ``(batch, length, in_channels)``. It is
                transposed to channels-first before ``emb``.
            cattn: forwarded to every ``DSConv_block1``.

        Returns:
            Channels-first tensor ``(batch, out_channels, length)``; it is not
            transposed back to the input layout.
        """
        x = x.transpose(2, 1)
        _, _, length = x.shape  # unpacking also enforces a 3-D input
        x = self.emb(x)
        if self.revin:
            # RevIN is called on a channels-last view.
            x = x.permute(0, 2, 1)
            x = self.revin_layer(x, 'norm')
            x = x.permute(0, 2, 1)

        num_levels = len(self.blocks)
        skips = []
        for i in range(num_levels):
            skips.append(x)
            x = self.downsample_layers[i](x)
            x = self.blocks[i](x, cattn)

        # Inner decoder levels; skips are consumed deepest-first.
        for i in range(num_levels - 1):
            x = self.blocks1[i](x, cattn)
            x = self.upconv[i](x) + skips[-i - 1]

        # Outermost decoder level.
        x = self.blocks1[num_levels - 1](x, cattn)
        x = self.upconv[num_levels - 1](x)
        if length % 2 != 0:
            # Odd input length: the stride-2 down/up pair lost one step.
            x = self.head(x)
        x = x + skips[0]

        if self.revin:
            x = x.permute(0, 2, 1)
            x = self.revin_layer(x, 'denorm')
            x = x.permute(0, 2, 1)

        return x