import torch
import torch.nn as nn


class Condition(nn.Module):
    """Linear conditioning head that maps an input window onto the horizon.

    A single ``nn.Linear(seq_len, pred_len)`` is applied along the time axis.
    When ``config.task_name == "classification"``, an extra
    ``nn.Linear(feature_dim, num_class)`` is applied per time step and the
    result is averaged over the time axis.

    Config attributes read: ``seq_len``, ``pred_len``, ``task_name`` and,
    for classification only, ``feature_dim`` and ``num_class``.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config
        # Time-axis projection: seq_len steps -> pred_len steps.
        self.dec = nn.Linear(config.seq_len, config.pred_len)
        is_classifier = config.task_name == "classification"
        if is_classifier:
            # Feature-axis classifier, only created for classification tasks.
            self.linear = nn.Linear(config.feature_dim, config.num_class)

    def forward(self, x):
        """Project ``x`` along time and, for classification, to class logits.

        Assumes ``x`` is 3-D with ``seq_len`` on axis 1, i.e. (B, seq_len, C);
        for classification ``C`` must equal ``feature_dim``.

        Returns (B, pred_len, C) or, for classification, (B, 1, num_class).
        """
        # Move time to the last axis so the Linear acts on it, then move it back.
        time_first = x.permute(0, 2, 1)
        projected = self.dec(time_first).permute(0, 2, 1)

        if self.config.task_name != "classification":
            return projected

        logits = self.linear(projected)
        # Average per-step logits over the (projected) time axis, keeping it.
        return logits.mean(dim=1, keepdim=True)


class TransformerCondition(nn.Module):
    """Transformer-based conditioning module with a convolutional pre-stage.

    Pipeline:
      1. project features ``feature_dim -> d_model`` and add a learnable
         positional embedding;
      2. residual circular 1-D convolution (kernel 7, length-preserving),
         GELU, then LayerNorm;
      3. a stack of ``nn.TransformerEncoderLayer`` blocks (batch-first);
      4. linear map of the time axis ``seq_len -> pred_len`` followed by a
         linear map of the feature axis ``d_model -> feature_dim``.

    Config attributes read:
        d_model, feature_dim, seq_len, pred_len, n_heads, mlp_ratio
        cond_n_depth (optional, default 2): number of encoder layers.
        cond_dropout (optional, default 0.1): dropout inside encoder layers.
    """

    def __init__(self, config) -> None:
        super().__init__()

        d_model = config.d_model

        self.input_projection = nn.Linear(config.feature_dim, d_model)

        # Learnable, zero-initialised positional embedding. Its length pins
        # the accepted input length to exactly ``seq_len``.
        self.positional_encoding = nn.Parameter(torch.zeros(1, config.seq_len, d_model))

        # "same" padding keeps the sequence length unchanged; circular
        # padding wraps the window ends instead of zero-filling them.
        self.conv_preprocess = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=7,
            padding="same",
            padding_mode="circular",
        )
        self.conv_norm = nn.LayerNorm(d_model)
        self.conv_activation = nn.GELU()

        # nn.Linear requires an integer width; mlp_ratio is often a float
        # (e.g. 4.0), which would otherwise make layer construction fail.
        dim_feedforward = int(d_model * config.mlp_ratio)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=config.n_heads,
            dim_feedforward=dim_feedforward,
            dropout=getattr(config, "cond_dropout", 0.1),
            activation="gelu",
            batch_first=True,
        )

        num_layers = getattr(config, "cond_n_depth", 2)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )

        self.time_projection = nn.Linear(config.seq_len, config.pred_len)
        self.feature_projection = nn.Linear(d_model, config.feature_dim)

    def forward(self, x):
        """Map a conditioning window onto the prediction horizon.

        Args:
            x: Tensor of shape (B, seq_len, feature_dim).

        Returns:
            Tensor of shape (B, pred_len, feature_dim).
        """
        # (B, seq_len, d_model)
        x_emb = self.input_projection(x) + self.positional_encoding

        # Conv1d is channels-first: run on (B, d_model, seq_len), then restore.
        x_conv_out = self.conv_preprocess(x_emb.permute(0, 2, 1)).permute(0, 2, 1)

        # Residual conv branch followed by post-normalisation.
        x_processed = self.conv_norm(x_emb + self.conv_activation(x_conv_out))

        # (B, seq_len, d_model)
        transformer_out = self.transformer_encoder(x_processed)

        # Time projection on (B, d_model, seq_len) -> (B, d_model, pred_len).
        out_time_proj = self.time_projection(transformer_out.permute(0, 2, 1))
        # Back to (B, pred_len, d_model), then features d_model -> feature_dim.
        return self.feature_projection(out_time_proj.permute(0, 2, 1))

