import torch
from torch import nn


class VanillaNetwork(nn.Module):
    """Feed-forward baseline: six linear layers with ReLU in between.

    Widths: num_variable_value -> hidden_dim -> hidden_dim -> hidden_dim
    -> hidden_dim // 2 -> hidden_dim // 4 -> output_size. No activation is
    applied after the final linear layer.
    """

    def __init__(self, num_variable_value, hidden_dim, output_size, *args, **kwargs):
        super().__init__(*args, **kwargs)

        widths = (
            num_variable_value,
            hidden_dim,
            hidden_dim,
            hidden_dim,
            hidden_dim // 2,
            hidden_dim // 4,
            output_size,
        )
        modules = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            modules += [nn.ReLU(), nn.Linear(fan_in, fan_out)]

        # A single Sequential named ``network`` keeps parameter names
        # (network.0.weight, network.2.weight, ...) checkpoint-compatible.
        self.network = nn.Sequential(*modules)

    def forward(self, evidence, **kwargs):
        """Run ``evidence`` through the MLP; extra keyword arguments are ignored."""
        return self.network(evidence)


class BBAttentionNetwork(torch.nn.Module):
    """Score every choice token against a set of evidence tokens.

    Evidence and choice tokens share one embedding table. Each choice token
    attends over the evidence tokens (choice = query, evidence = key/value).
    The attended summary is concatenated with the choice features and fused,
    refined by ``num_layers`` pre-LayerNorm residual blocks, and fed to two
    independent output heads ("optimality" and "decimation").

    Args:
        num_variable_value: Size of the shared embedding vocabulary.
        embed_dim: Embedding width.
        num_layers: Number of residual refinement blocks.
        hidden_dim: Width of the hidden features; must be divisible by
            ``num_heads``.
        output_size: Output width of each head. When it is 1, that trailing
            dimension is squeezed away.
        padding_idx: Embedding index used for padding (or ``None``). Evidence
            positions holding this index are excluded from attention.
        num_heads: Number of attention heads (keyword-only, default 2).
        dropout: Dropout probability inside each residual block
            (keyword-only, default 0.1).
    """

    def __init__(
        self,
        num_variable_value,
        embed_dim,
        num_layers,
        hidden_dim,
        output_size,
        padding_idx,
        *args,
        num_heads=2,
        dropout=0.1,
        **kwargs
    ):
        super().__init__(*args, **kwargs)

        # Shared embedding table for evidence and choice tokens.
        self.embedding_input = nn.Embedding(
            num_variable_value, embed_dim, padding_idx=padding_idx
        )

        # Evidence feature extractor
        self.evid_net = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.ReLU())

        # Choice feature extractor
        self.choice_net = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.ReLU())

        # Cross-attention: choices (query) attend over evidence (key/value).
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim, num_heads=num_heads, batch_first=True
        )

        # Merges [attended evidence, choice features] back down to hidden_dim.
        self.fusion_net = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim), nn.ReLU()
        )

        self.activation = nn.ReLU()
        self.skip_connection_layers = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        self.layer_norms = nn.ModuleList()

        for _ in range(num_layers):
            self.skip_connection_layers.append(nn.Linear(hidden_dim, hidden_dim))
            self.dropouts.append(nn.Dropout(dropout))
            self.layer_norms.append(nn.LayerNorm(hidden_dim))

        # Registered in the same order as before so state_dict keys match.
        self.optimality_head = self._build_head(hidden_dim, output_size)
        self.decimation_head = self._build_head(hidden_dim, output_size)

    @staticmethod
    def _build_head(hidden_dim, output_size):
        """Build a tapering MLP: hidden_dim -> hidden_dim//2 -> hidden_dim//4 -> output_size."""
        return nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Linear(hidden_dim // 4, output_size),
        )

    def _evidence_padding_mask(self, evidence):
        """Return a bool mask (True = ignore) of padded evidence slots, or None.

        The padding embedding row is zero-initialized, but ``evid_net``'s bias
        still maps it to a non-zero feature. Without a mask, padded slots
        would therefore receive attention weight. Rows that are entirely
        padding are left unmasked, because masking every key would make the
        attention softmax produce NaN.
        """
        # Use the Embedding's normalized value (negative indices are resolved).
        pad = self.embedding_input.padding_idx
        if pad is None:
            return None
        mask = evidence == pad
        return mask & ~mask.all(dim=1, keepdim=True)

    def _get_choice_output(self, evid_features, choice, key_padding_mask=None):
        """Score each choice token against the evidence features.

        Args:
            evid_features: Evidence features, (batch_size, num_evid_vars, hidden_dim).
            choice: Choice token indices, (batch_size, choice_size).
            key_padding_mask: Optional bool tensor (batch_size, num_evid_vars);
                ``True`` marks evidence positions that attention must ignore.

        Returns:
            Tuple ``(optimality, decimation)``, each of shape
            (batch_size, choice_size, output_size). The last dimension is
            squeezed when ``output_size == 1``.
        """
        # (batch_size, choice_size, embed_dim)
        choice_enc = self.embedding_input(choice)
        # (batch_size, choice_size, hidden_dim)
        choice_input = self.choice_net(choice_enc)

        # Choice acts as the query, evidence as key and value.
        # Output: (batch_size, choice_size, hidden_dim); the weights are unused.
        attention_output, _ = self.attention(
            query=choice_input,
            key=evid_features,
            value=evid_features,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )

        # (batch_size, choice_size, 2 * hidden_dim) -> (batch_size, choice_size, hidden_dim)
        fused_features = torch.cat([attention_output, choice_input], dim=2)
        fused_output = self.fusion_net(fused_features)

        # Pre-LayerNorm residual blocks: x <- x + act(dropout(linear(norm(x)))).
        # Each block's skip connection uses that block's own input. The previous
        # bookkeeping added the input of the block before it, which was off by one.
        for layer, dp, ly in zip(
            self.skip_connection_layers, self.dropouts, self.layer_norms
        ):
            fused_output = fused_output + self.activation(dp(layer(ly(fused_output))))

        optimality_output = self.optimality_head(fused_output)
        decimation_output = self.decimation_head(fused_output)
        return optimality_output.squeeze(dim=2), decimation_output.squeeze(dim=2)

    def forward(self, evidence, all_choices, **kwargs):
        """Embed the evidence, then score ``all_choices`` against it.

        Args:
            evidence: Evidence token indices, (batch_size, num_evid_vars).
            all_choices: Choice token indices, (batch_size, choice_size).
            **kwargs: Accepted and ignored.

        Returns:
            ``(optimality, decimation)`` as produced by ``_get_choice_output``.
        """
        # (batch_size, num_evid_vars, embed_dim)
        evid_enc = self.embedding_input(evidence)
        # (batch_size, num_evid_vars, hidden_dim)
        evid_features = self.evid_net(evid_enc)
        return self._get_choice_output(
            evid_features,
            choice=all_choices,
            key_padding_mask=self._evidence_padding_mask(evidence),
        )
