import torch.nn as nn
import torch
import torch.nn.functional as F


class SequenceInitializer(nn.Module):
    """Causal temporal-convolution initializer for sequence inputs.

    Runs a small two-layer network over the last (time) dimension:
    a ``window_size``-wide 1-D convolution into ``hidden_channels`` channels,
    a ReLU, then a 1x1 convolution back to ``d_embed`` channels. The input is
    left-padded by ``window_size - 1`` steps. As a result, output step ``t``
    depends only on input steps ``<= t``, and the sequence length is preserved.

    Args:
        d_embed: Number of input and output channels.
        window_size: Kernel size of the first (temporal) convolution.
            Defaults to 3.
        hidden_channels: Number of channels in the hidden layer.
            Defaults to 100.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """

    def __init__(self, d_embed, window_size=3, hidden_channels=100):
        super().__init__()
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self.window_size = window_size
        self.tcn = nn.Sequential(
            nn.Conv1d(d_embed, hidden_channels, kernel_size=window_size),
            nn.ReLU(),
            nn.Conv1d(hidden_channels, d_embed, kernel_size=1),
        )

    def forward(self, x):
        """Apply the causal TCN.

        Args:
            x: Tensor assumed to have shape (bsz, d_embed, seq_len).

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Pad only on the left of the time axis. The following valid
        # convolution with kernel ``window_size`` is then causal and
        # length-preserving.
        return self.tcn(F.pad(x, (self.window_size - 1, 0)))

class MultiscaleInitializer(nn.Module):
    """Build an initial multi-resolution feature list from a single input map.

    Branch 0 is the input passed through a resolution-preserving stack:
    3x3 conv, GroupNorm with 4 groups, ReLU, then 1x1 conv. Each later branch
    ``i`` is computed from branch ``i - 1`` by a stride-2 3x3 convolution
    (``num_channels[i-1] -> num_channels[i]`` channels), a ReLU and a 1x1 conv.

    Args:
        num_channels: Per-branch channel counts. Its length sets the number of
            branches. Because of the 4-group GroupNorm, ``num_channels[0]``
            must be divisible by 4.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.num_channels = num_channels
        self.num_branches = len(num_channels)
        base = num_channels[0]
        self.conv = nn.Sequential(
            nn.Conv2d(base, base, kernel_size=3, padding=1),
            nn.GroupNorm(4, base),
            nn.ReLU(),
            nn.Conv2d(base, base, kernel_size=1),
        )
        stages = [
            self._downsample_stage(c_in, c_out)
            for c_in, c_out in zip(num_channels[:-1], num_channels[1:])
        ]
        # Slot 0 is a None placeholder so that low_convs[i] produces branch i.
        self.low_convs = nn.ModuleList([None, *stages])

    @staticmethod
    def _downsample_stage(c_in, c_out):
        """Return a stride-2 3x3 conv -> ReLU -> 1x1 conv stage (c_in -> c_out channels)."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2),
            nn.ReLU(),
            nn.Conv2d(c_out, c_out, kernel_size=1),
        )

    def forward(self, x):
        """Return ``num_branches`` feature maps, branch 0 first.

        ``x`` is presumably (N, num_channels[0], H, W); confirm against callers.
        """
        feat = self.conv(x)
        pyramid = [feat]
        for idx in range(1, self.num_branches):
            # Each branch is derived from the previous (finer) one.
            feat = self.low_convs[idx](feat)
            pyramid.append(feat)
        return pyramid