import torch
import numpy as np
from typing import List, Optional
import torch.nn as nn
# # Initialize a List[int] named dim_mult with the values [1, 2, 2]
# dim_mult: List[int] = [1, 2, 2] 
# in_dim : int = 512
# dims = [32 * m for m in np.cumprod(dim_mult)]
# print(dims)
# dims = [in_dim] + dims
# print(dims)
class Downsample1d(nn.Module):
    """Shrink the last (length) axis of a 1-D feature map with a strided conv.

    Wraps a single ``nn.Conv1d`` with kernel size 3, stride 2 and padding 1
    that keeps the channel count at ``dim``. For an input of length ``L`` the
    output length is ``(L - 1) // 2 + 1``, i.e. ``ceil(L / 2)``.

    Args:
        dim: Number of input channels (also used as the output channel count).
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Run the strided convolution over ``x`` and return the result."""
        reduced = self.conv(x)
        return reduced


class Upsample1d(nn.Module):
    """Double the last (length) axis of a 1-D feature map with a transposed conv.

    Wraps a single ``nn.ConvTranspose1d`` with kernel size 4, stride 2 and
    padding 1 that keeps the channel count at ``dim``. For an input of length
    ``L`` the output length is ``(L - 1) * 2 - 2 + 4 = 2 * L``.

    Args:
        dim: Number of input channels (also used as the output channel count).
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Run the transposed convolution over ``x`` and return the result."""
        expanded = self.conv(x)
        return expanded
    
# Quick shape check of the two resampling modules on a tiny sequence.
# The literal below has shape (1, 1, 2): batch=1, channels=1, length=2.
x = torch.tensor([[[1, 2]]], dtype=torch.float32)

# Conv1d(k=3, s=2, p=1): length (2 - 1) // 2 + 1 = 1 -> torch.Size([1, 1, 1]).
downsample = Downsample1d(1)
print(downsample(x).shape)

# ConvTranspose1d(k=4, s=2, p=1): length (2 - 1) * 2 - 2 + 4 = 4 -> torch.Size([1, 1, 4]).
upsample = Upsample1d(1)
print(upsample(x).shape)