import torch
import torch.nn as nn
import torch.nn.functional as F


class NewsMLP(nn.Module):
    """
    3-layer MLP for the NEWS dataset with fixed GloVe embeddings,
    adaptive average pooling, batch normalization, Softsign activation,
    and always returning intermediates.

    Architecture (D = ``embedding_dim``, P = ``pooled_length``)::

        tokens [B, T] -> frozen embedding            [B, T, D]
        -> AdaptiveAvgPool1d over T, flatten         [B, P*D]
        -> fc1 -> BN -> Softsign -> Dropout          [B, 4D]
        -> fc2 -> BN -> Softsign -> Dropout          [B, D]
        -> fc3                                       [B, num_classes]

    Args:
        device: Device the embedded tokens are moved to in ``forward``.
        embedding_weights: Pretrained embedding matrix of shape
            ``[vocab_size, embedding_dim]`` (tensor, NumPy array or nested
            list). It is copied as float32, so the module never aliases the
            caller's object.
        embedding_dim: Width of each embedding vector; must equal the second
            dimension of ``embedding_weights``.
        num_classes: Number of output logits.
        dropout_rate: Dropout probability applied after each hidden layer.
        pooled_length: Number of positions the sequence axis is pooled to
            (default 20, the previously hard-coded value).

    Raises:
        ValueError: If the embedding matrix width differs from
            ``embedding_dim``.
    """
    def __init__(
        self,
        device,
        embedding_weights,
        embedding_dim=300,
        num_classes=7,
        dropout_rate=0.0,
        pooled_length=20,
    ):
        super().__init__()
        self.device = device

        # Copy the pretrained matrix as float32. ``torch.tensor`` on an
        # existing tensor works but emits a UserWarning, so tensors take the
        # explicit detach/clone path instead.
        if isinstance(embedding_weights, torch.Tensor):
            weights = embedding_weights.detach().to(torch.float32).clone()
        else:
            weights = torch.tensor(embedding_weights, dtype=torch.float32)
        # Fail fast instead of with an opaque shape error inside fc1 later.
        # (Non-2-D input is still rejected by from_pretrained itself.)
        if weights.dim() == 2 and weights.size(1) != embedding_dim:
            raise ValueError(
                f"embedding_weights has width {weights.size(1)}, "
                f"but embedding_dim={embedding_dim}"
            )

        # Embedding layer (frozen)
        self.embedding = nn.Embedding.from_pretrained(weights, freeze=True)
        # Adaptive pooling maps any sequence length to exactly pooled_length
        self.pooled_length = pooled_length
        self.pool = nn.AdaptiveAvgPool1d(pooled_length)

        # First dense: pooled_length*embedding_dim -> 4*embedding_dim
        self.fc1 = nn.Linear(pooled_length * embedding_dim, 4 * embedding_dim)
        self.bn1 = nn.BatchNorm1d(4 * embedding_dim)

        # Second dense: 4*embedding_dim -> embedding_dim
        self.fc2 = nn.Linear(4 * embedding_dim, embedding_dim)
        self.bn2 = nn.BatchNorm1d(embedding_dim)

        # Final classification layer: embedding_dim -> num_classes
        self.fc3 = nn.Linear(embedding_dim, num_classes)

        # Dropout
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, idx=None, layer_indexes=None, CND_reg_only_last_layer=False, apply_mask=False, por_neuron=1.0):
        """
        Run the network and collect intermediate activations.

        Args:
            x: LongTensor of token indices, shape ``[batch_size, seq_length]``.
            idx, layer_indexes, CND_reg_only_last_layer, apply_mask, por_neuron:
                Accepted so the call signature matches
                ``FullyConnectedNN_preact_CND.forward``; this model ignores
                them.

        Returns:
            ``(logits, intermediates)``. ``logits`` has shape
            ``[B, num_classes]``. ``intermediates`` concatenates along dim 1:
            the flattened pooled input ``[B, P*D]``, the pre-BatchNorm output
            of fc1 ``[B, 4D]``, and the pre-BatchNorm output of fc2 ``[B, D]``.
        """
        # Embed: [B, T] -> [B, T, D]
        emb = self.embedding(x).to(self.device)

        # AdaptiveAvgPool1d pools the last axis, so move the sequence there:
        # [B, T, D] -> [B, D, T] -> [B, D, P] -> [B, P, D]
        pooled = self.pool(emb.permute(0, 2, 1)).permute(0, 2, 1).contiguous()
        # Flatten positions and features: [B, P, D] -> [B, P*D]
        out = pooled.view(pooled.size(0), -1)
        intermediates = [out]

        # First dense + BN + Softsign (record the pre-BN activation)
        out = self.fc1(out)  # [B, 4*D]
        intermediates.append(out)
        out = self.dropout(F.softsign(self.bn1(out)))

        # Second dense + BN + Softsign (record the pre-BN activation)
        out = self.fc2(out)  # [B, D]
        intermediates.append(out)
        out = self.dropout(F.softsign(self.bn2(out)))

        # Final classification
        logits = self.fc3(out)  # [B, num_classes]
        return logits, torch.cat(intermediates, dim=1)



class FullyConnectedNN_preact_CND(nn.Module):
    """
    Fully connected network for CND with optional embeddings.
    Always returns intermediates.

    Image mode (``embedding_weights is None``): the input is flattened to
    ``input_channels * input_size * input_size`` features.
    Text mode (``embedding_weights`` given): token indices are embedded with a
    trainable copy of ``embedding_weights`` and mean-pooled over the tokens
    whose index is non-zero.

    Each hidden block is ``Linear -> [BatchNorm1d] -> activation -> Dropout``.
    The Linear outputs (pre-BatchNorm, pre-activation) are what ``forward``
    returns as intermediates.

    Args:
        device: Stored as ``self.device``.
        input_channels, input_size: Image geometry used to size the first
            layer in image mode.
        num_classes: Number of output logits.
        L: Number of hidden Linear layers (at least one is always built).
        N: Width of every hidden layer.
        batch_normalization_flag: If True, apply BatchNorm1d after each
            hidden Linear layer.
        dropout_rate: Dropout probability after each hidden activation.
        embedding_weights: Optional ``[vocab_size, dim]`` matrix (tensor,
            NumPy array or nested list) that enables text mode. It is copied
            as float32, so training never mutates the caller's matrix.
        activation_fn: ``'softsign'`` selects Softsign; any other value
            falls back to ReLU.
    """
    def __init__(self, device, input_channels=1, input_size=28, num_classes=10, L=3, N=128,
                 batch_normalization_flag=False, dropout_rate=0.0, embedding_weights=None, activation_fn='relu'):
        super().__init__()
        self.L = L
        self.N = N
        self.batch_normalization_flag = batch_normalization_flag
        self.device = device
        self.dropout_rate = dropout_rate
        self.activation_fn = activation_fn

        self.news_mode = embedding_weights is not None
        if self.news_mode:
            # from_pretrained wraps its argument in a Parameter without
            # copying, so with freeze=False the optimizer would update the
            # caller's matrix in place. Copy it (as float32, to match
            # nn.Linear) first.
            if isinstance(embedding_weights, torch.Tensor):
                weights = embedding_weights.detach().to(torch.float32).clone()
            else:
                weights = torch.tensor(embedding_weights, dtype=torch.float32)
            self.embedding = nn.Embedding.from_pretrained(weights, freeze=False)
            self.embedding_dim = self.embedding.embedding_dim
            self.input_view = self.embedding_dim
        else:  # Images
            self.input_view = input_channels * input_size * input_size

        # Fully connected layers
        self.layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList() if batch_normalization_flag else None
        self.dropout = nn.Dropout(dropout_rate)

        # First layer maps input_view -> N; the remaining L-1 map N -> N.
        in_features = self.input_view
        for _ in range(max(L, 1)):
            self.layers.append(nn.Linear(in_features, N))
            if batch_normalization_flag:
                self.batch_norms.append(nn.BatchNorm1d(N))
            in_features = N

        # Final output layer
        self.output_layer = nn.Linear(N, num_classes)

    def forward(self, x, idx=None, layer_indexes=None, CND_reg_only_last_layer=False, apply_mask=False, por_neuron=1.0):
        """
        Forward pass. Always returns (logits, intermediates).

        Args:
            x: Image tensor reshapeable to ``[-1, input_view]``, or in text
                mode a LongTensor of token indices ``[B, T]``.
            idx, layer_indexes, apply_mask: Accepted for call compatibility;
                not used by this model.
            CND_reg_only_last_layer: If True, record only the last hidden
                layer's pre-activation.
            por_neuron: Fraction of each recorded layer's units to keep
                (the first ``int(N * por_neuron)`` columns).

        Returns:
            ``(logits, intermediates)`` where ``intermediates`` concatenates
            the recorded pre-activation slices along dim 1.
        """
        if getattr(self, "news_mode", False):
            embeddings = self.embedding(x)  # [B, T, D]
            # Mean over tokens; index 0 is treated as padding and excluded.
            mask = (x != 0).float().unsqueeze(-1)  # [B, T, 1]
            summed = (embeddings * mask).sum(dim=1)
            # clamp avoids 0/0 for rows that are entirely padding
            count = mask.sum(dim=1).clamp(min=1e-9)
            x = summed / count

        # reshape (rather than view) also accepts non-contiguous inputs
        x = x.reshape(-1, self.input_view)
        last = len(self.layers) - 1
        intermediates = []

        for i, layer in enumerate(self.layers):
            x = layer(x)

            if not CND_reg_only_last_layer or i == last:
                keep = int(x.shape[1] * por_neuron)
                intermediates.append(x[:, :keep])

            # BatchNorm was previously built but never applied
            if self.batch_norms is not None:
                x = self.batch_norms[i](x)
            x = F.softsign(x) if self.activation_fn == 'softsign' else F.relu(x)
            x = self.dropout(x)

        logits = self.output_layer(x)
        return logits, torch.cat(intermediates, dim=1)
