import torch
import torch.nn as nn
import torch.nn.functional as F




class Extract_T(nn.Module):
    """Moving-average smoother applied along dim 1 of a 3-D tensor.

    The input is extended on both ends along dim 1 by repeating its first
    and last slices, then averaged with ``nn.AvgPool1d``. With ``stride=1``
    the output has the same length along dim 1 as the input.
    """

    def __init__(self, kernel_size, stride=1):
        """Build the pooling layer.

        Args:
            kernel_size: Width of the averaging window; must be odd.
            stride: Step of the averaging window (default 1).
        """
        super().__init__()
        assert kernel_size % 2 == 1, "Only odd kernel size is supported."
        self.kernel_size = kernel_size
        # No built-in padding: the edges are extended by hand in forward().
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        """Return the moving average of ``x`` along dim 1.

        Args:
            x: 3-D tensor, laid out as [B, T, C].

        Returns:
            Tensor with the same layout as ``x``, averaged over dim 1.
        """
        # Number of copies of the edge slice added on each side.
        pad = (self.kernel_size - 1) // 2
        head = x[:, :1, :].repeat(1, pad, 1)
        tail = x[:, -1:, :].repeat(1, pad, 1)
        padded = torch.cat((head, x, tail), dim=1)
        # AvgPool1d pools over the last dim, so move dim 1 there and back.
        pooled = self.avg(padded.transpose(1, 2))
        return pooled.transpose(1, 2)




class Decomposition(nn.Module):
    """Split a tensor into a smoothed part and the residual around it.

    The smoothed part comes from ``Extract_T`` with stride 1; the residual
    is the input minus that smoothed part.
    """

    def __init__(self, kernel_size):
        """Create the smoother.

        Args:
            kernel_size: Window width handed to ``Extract_T``.
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.extract_t = Extract_T(kernel_size, stride=1)

    def forward(self, x):
        """Decompose ``x``.

        Returns:
            Tuple ``(s, t)`` where ``t`` is ``self.extract_t(x)`` and
            ``s`` is ``x - t``.
        """
        smoothed = self.extract_t(x)
        residual = x - smoothed
        return residual, smoothed




if __name__ == "__main__":
    # Smoke test: run a random 4 x 61 x 8 tensor through the decomposition
    # and print the shapes of the two returned tensors.
    sample = torch.randn(4, 61, 8)
    model = Decomposition(kernel_size=3)
    residual, smoothed = model(sample)
    print(residual.shape, smoothed.shape)