import torch
import torch.nn as nn
from typing import Optional


class ConvEncoder(nn.Module):
    """Stack of 1-D convolutions over a ``(batch, seq_len, embedding_dim)`` sequence.

    Every layer applies ``Conv1d -> BatchNorm1d -> ReLU``. When ``conditional`` is
    true, a separate ``Conv1d -> ReLU -> BatchNorm1d`` branch encodes a conditioning
    sequence and its output is added to the output of the first layer.

    Args:
        embedding_dim: Channel count of the input ``x`` (its last dimension).
        hidden_dim: Channel count produced by every convolution.
        kernel_size: Convolution width. Padding is ``kernel_size // 2``, which keeps
            the sequence length unchanged only for odd kernels; an even kernel
            lengthens the sequence by one per layer.
        num_layers: Number of conv layers in the main stack; must be >= 1.
        conditional: Whether ``forward`` requires a conditioning tensor.
        n_tokens_cond: Channel count of the conditioning tensor. When ``None`` the
            conditioning branch expects ``embedding_dim`` channels.

    Raises:
        ValueError: If ``num_layers`` is less than 1.
    """

    def __init__(self, embedding_dim, hidden_dim, kernel_size, num_layers, conditional=False,
                 n_tokens_cond=None):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        padding = kernel_size // 2
        self.conditional = conditional

        # Attribute names and construction order are kept as-is so existing
        # state_dict checkpoints load and seeded initialisation is reproducible.
        self.conv_layers = nn.ModuleList()
        self.conv_layers.append(
            nn.Conv1d(embedding_dim, hidden_dim, kernel_size, padding=padding)
        )

        if self.conditional:
            cond_dim = embedding_dim if n_tokens_cond is None else n_tokens_cond
            self.cond = nn.Sequential(
                nn.Conv1d(cond_dim, hidden_dim, kernel_size, padding=padding),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim),
            )

        for _ in range(1, num_layers):
            self.conv_layers.append(
                nn.Conv1d(hidden_dim, hidden_dim, kernel_size, padding=padding)
            )

        self.activation = nn.ReLU()
        self.batch_norm = nn.ModuleList([nn.BatchNorm1d(hidden_dim) for _ in range(num_layers)])

    def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Encode ``x``, optionally adding the encoded ``cond`` after the first layer.

        Args:
            x: Tensor of shape ``(batch, seq_len, embedding_dim)``.
            cond: Conditioning tensor of shape ``(batch, seq_len, n_tokens_cond)``
                (or ``embedding_dim`` channels when ``n_tokens_cond`` was ``None``).
                Required when the encoder was built with ``conditional=True`` and
                ignored otherwise. Its sequence length must match ``x`` so the
                branch outputs can be added.

        Returns:
            Tensor of shape ``(batch, seq_len', hidden_dim)``.

        Raises:
            ValueError: If the encoder is conditional and ``cond`` is missing.
        """
        if self.conditional and cond is None:
            raise ValueError("ConvEncoder was built with conditional=True; forward() requires `cond`")

        # Conv1d / BatchNorm1d expect channels in dim 1: (batch, channels, seq_len).
        x = x.transpose(1, 2)
        for i, (conv, bn) in enumerate(zip(self.conv_layers, self.batch_norm)):
            x = self.activation(bn(conv(x)))
            if i == 0 and self.conditional:
                x = x + self.cond(cond.transpose(1, 2))
        return x.transpose(1, 2)


class ConvDecoder(nn.Module):
    """Stack of 1-D convolutions mapping ``hidden_dim`` channels to ``output_dim``.

    The first ``num_layers - 1`` layers apply ``Conv1d -> BatchNorm1d -> ReLU`` at
    width ``hidden_dim``. The final layer is a bare ``Conv1d`` projecting to
    ``output_dim`` with no normalization or activation.

    Args:
        hidden_dim: Channel count of the input ``x`` and of intermediate layers.
        output_dim: Channel count of the output.
        kernel_size: Convolution width. Padding is ``kernel_size // 2``, which keeps
            the sequence length unchanged only for odd kernels; an even kernel
            lengthens the sequence by one per layer.
        num_layers: Total number of conv layers, including the output projection;
            must be >= 1.

    Raises:
        ValueError: If ``num_layers`` is less than 1.
    """

    def __init__(self, hidden_dim, output_dim, kernel_size, num_layers):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        padding = kernel_size // 2
        # Hidden layers first, output projection last: same parameter keys and
        # initialisation order as before.
        self.conv_layers = nn.ModuleList(
            [nn.Conv1d(hidden_dim, hidden_dim, kernel_size, padding=padding) for _ in range(num_layers - 1)]
        )
        self.conv_layers.append(
            nn.Conv1d(hidden_dim, output_dim, kernel_size, padding=padding)
        )

        self.activation = nn.ReLU()
        self.batch_norm = nn.ModuleList([nn.BatchNorm1d(hidden_dim) for _ in range(num_layers - 1)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode ``x`` of shape ``(batch, seq_len, hidden_dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len', output_dim)``.
        """
        # Conv1d / BatchNorm1d expect channels in dim 1: (batch, channels, seq_len).
        x = x.transpose(1, 2)
        # zip stops after the num_layers - 1 normalized layers (batch_norm is shorter).
        for conv, bn in zip(self.conv_layers, self.batch_norm):
            x = self.activation(bn(conv(x)))
        # Output projection: raw conv output, no BatchNorm or ReLU.
        x = self.conv_layers[-1](x)
        return x.transpose(1, 2)

