from torch import nn
from dilated_causal_cnn import DilatedCausalConvEncoder

class MultiResolutionEncoder(nn.Module):
    """Encode a 3-D tensor with a ``DilatedCausalConvEncoder`` followed by dropout.

    The constructor builds the convolutional stack with a channel list of
    ``depth`` copies of ``hidden_dims`` followed by ``output_dims``. ``forward``
    moves axis 0 to the end before the stack and undoes that permutation
    afterwards, so the caller's axis order is preserved.

    Args:
        input_dims: First positional argument to ``DilatedCausalConvEncoder``
            (presumably the number of input channels — TODO confirm).
        output_dims: Final entry of the channel list passed to the encoder.
        hidden_dims: Value repeated ``depth`` times at the start of the
            channel list.
        depth: Number of ``hidden_dims`` entries in the channel list.
        kernel_size: Forwarded to ``DilatedCausalConvEncoder``. Defaults to 3,
            the value that was previously hard-coded.
        dropout: Probability for the ``nn.Dropout`` applied to the encoder
            output. Defaults to 0.1, the value that was previously hard-coded.
    """

    def __init__(self, input_dims, output_dims, hidden_dims=16, depth=8,
                 kernel_size=3, dropout=0.1):
        super().__init__()
        self.input_dims = input_dims
        self.output_dims = output_dims
        self.hidden_dims = hidden_dims
        self.depth = depth
        self.kernel_size = kernel_size

        # Channel list: `depth` hidden entries, then the output width.
        self.temporal_encoder = DilatedCausalConvEncoder(
            input_dims,
            [hidden_dims] * depth + [output_dims],
            kernel_size=kernel_size,
        )
        self.repr_dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Encode ``x`` and return the result in the caller's axis order.

        ``x`` must be 3-D, because the permutes below name three axes.
        Presumably it is (seq_len, batch, input_dims), which would hand the
        encoder a (batch, channels, seq_len) tensor — TODO confirm with callers.
        """
        # (d0, d1, d2) -> (d1, d2, d0): move axis 0 last for the conv stack.
        x = x.permute(1, 2, 0)
        x = self.repr_dropout(self.temporal_encoder(x))
        # Inverse permutation: (d1, d2', d0) -> (d0, d1, d2').
        return x.permute(2, 0, 1)