import torch
from torch import nn
from torch.nn import functional as F, init

class ResidualBlock(nn.Module):
    """A general-purpose residual block. Works only with 1-dim inputs.

    The block returns ``inputs + f(inputs)``, where ``f`` is::

        [batch norm ->] activation -> linear ->
        [batch norm ->] activation -> dropout -> linear

    When a context tensor is supplied to ``forward``, the output of ``f`` is
    additionally gated by a linear projection of the context through a GLU.

    Args:
        features: Width of the block's input and output.
        context_features: Width of the optional context tensor, or ``None``
            to build the block without a context layer.
        activation: Callable applied before each linear layer.
        dropout_probability: Dropout probability used before the second
            linear layer.
        use_batch_norm: If true, apply ``BatchNorm1d`` before each activation.
        zero_initialization: If true, draw the weight and bias of the last
            linear layer from U(-1e-3, 1e-3) so the residual branch starts
            close to zero.
    """

    def __init__(self, features, context_features=None, activation=F.relu, dropout_probability=0.0, use_batch_norm=False, zero_initialization=True):
        super().__init__()
        self.activation = activation

        self.use_batch_norm = use_batch_norm
        if use_batch_norm:
            self.batch_norm_layers = nn.ModuleList([nn.BatchNorm1d(features, eps=1e-3) for _ in range(2)])
        # Always define the attribute so that forward() can report a
        # misconfiguration clearly instead of failing on a missing attribute.
        if context_features is not None:
            self.context_layer = nn.Linear(context_features, features)
        else:
            self.context_layer = None
        self.linear_layers = nn.ModuleList([nn.Linear(features, features) for _ in range(2)])
        self.dropout = nn.Dropout(p=dropout_probability)
        if zero_initialization:
            init.uniform_(self.linear_layers[-1].weight, -1e-3, 1e-3)
            init.uniform_(self.linear_layers[-1].bias, -1e-3, 1e-3)

    def forward(self, inputs, context=None):
        """Apply the residual block.

        Args:
            inputs: Tensor whose dimension 1 has size ``features``.
            context: Optional tensor whose dimension 1 has size
                ``context_features``. When ``None`` the context gate is
                skipped, even if the block owns a context layer.

        Returns:
            ``inputs`` plus the (optionally context-gated) residual branch.

        Raises:
            ValueError: If ``context`` is given but the block was built with
                ``context_features=None``.
        """
        if context is not None and self.context_layer is None:
            raise ValueError(
                "A context tensor was passed to a ResidualBlock constructed with context_features=None."
            )

        temps = inputs
        if self.use_batch_norm:
            temps = self.batch_norm_layers[0](temps)
        temps = self.activation(temps)
        temps = self.linear_layers[0](temps)
        if self.use_batch_norm:
            temps = self.batch_norm_layers[1](temps)
        temps = self.activation(temps)
        temps = self.dropout(temps)
        temps = self.linear_layers[1](temps)
        if context is not None:
            # GLU splits the concatenation in half along dim 1 and computes
            # temps * sigmoid(context_layer(context)).
            temps = F.glu(torch.cat((temps, self.context_layer(context)), dim=1), dim=1)
        return inputs + temps


class ResidualNet(nn.Module):
    """A general-purpose residual network. Works only with 1-dim inputs.

    The network is an initial linear layer, followed by ``num_blocks``
    ``ResidualBlock`` modules of width ``hidden_features``, followed by a
    final linear layer. When ``context_features`` is given, the initial layer
    is sized for the inputs concatenated with the context, and every block is
    built with a context layer of that width.
    """

    def __init__(
        self,
        in_features,
        out_features,
        hidden_features,
        context_features=None,
        num_blocks=2,
        activation=F.relu,
        dropout_probability=0.0,
        use_batch_norm=False,
    ):
        super().__init__()
        self.hidden_features = hidden_features
        self.context_features = context_features

        # The first layer consumes the inputs, widened by the context when one
        # is configured.
        first_layer_width = in_features
        if context_features is not None:
            first_layer_width = in_features + context_features
        self.initial_layer = nn.Linear(first_layer_width, hidden_features)

        # Every block shares the same configuration.
        block_config = {
            "features": hidden_features,
            "context_features": context_features,
            "activation": activation,
            "dropout_probability": dropout_probability,
            "use_batch_norm": use_batch_norm,
        }
        residual_stack = []
        for _ in range(num_blocks):
            residual_stack.append(ResidualBlock(**block_config))
        self.blocks = nn.ModuleList(residual_stack)

        self.final_layer = nn.Linear(hidden_features, out_features)

    def forward(self, inputs, context=None):
        """Run the network, feeding ``context`` (if any) to every stage.

        With a context, the inputs and context are concatenated along
        dimension 1 before the initial layer, and the same context is passed
        to each residual block.
        """
        first_layer_input = inputs if context is None else torch.cat((inputs, context), dim=1)
        hidden = self.initial_layer(first_layer_input)
        for residual in self.blocks:
            hidden = residual(hidden, context=context)
        return self.final_layer(hidden)


class S_ConvNet(nn.Module):
    """Small convolutional classifier that also exposes its feature vector.

    Three stride-2 convolutions are followed by global average pooling, a
    fully connected layer producing the feature vector ``xa``, a second fully
    connected layer, and a linear classifier head.

    Args:
        num_classes: Number of output classes.
        image_channel_size: Number of channels in the input images.
        channel_size: Output channels of the first convolution; the second
            and third convolutions use twice and four times this width.
        xa_dim: Width of the ``xa`` feature vector and of the hidden fully
            connected layer.

    The defaults reproduce the fixed 3 -> 64 -> 128 -> 256 -> 512 layout.
    """

    def __init__(self, num_classes=10, *, image_channel_size=3, channel_size=64, xa_dim=512):
        super().__init__()
        self.xa_dim = xa_dim
        self.image_channel_size = image_channel_size
        self.channel_size = channel_size
        self.num_classes = num_classes

        # Convolutional layers; each one halves the spatial size
        # (kernel 4, stride 2, padding 1) and the width doubles per layer.
        conv_options = {"kernel_size": 4, "stride": 2, "padding": 1}
        self.conv1 = nn.Conv2d(image_channel_size, channel_size, **conv_options)
        self.conv2 = nn.Conv2d(channel_size, channel_size * 2, **conv_options)
        self.conv3 = nn.Conv2d(channel_size * 2, channel_size * 4, **conv_options)

        # Adaptive pooling instead of a hard flatten, so the fully connected
        # input width depends only on the channel count, not the image size.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))  # Output shape: [B, C, 1, 1]
        self.fc1 = nn.Linear(channel_size * 4, self.xa_dim)  # Input = last conv output channels
        self.fc2 = nn.Linear(self.xa_dim, self.xa_dim)
        self.fc_classifier = nn.Linear(self.xa_dim, self.num_classes)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return ``(class probabilities, xa features, logits)`` for ``x``."""
        xa = self.forward_to_xa(x)
        classes_p, logits = self.forward_from_xa(xa)
        return classes_p, xa, logits

    def forward_to_xa(self, x):
        """Map an image batch to the ``xa`` feature vector of width ``xa_dim``."""
        xa = F.leaky_relu(self.conv1(x))
        xa = F.leaky_relu(self.conv2(xa))
        xa = F.leaky_relu(self.conv3(xa))
        xa = self.pool(xa)               # shape: [B, C, 1, 1]
        xa = torch.flatten(xa, 1)        # shape: [B, C]
        xa = F.leaky_relu(self.fc1(xa))  # shape: [B, xa_dim]
        return xa

    def forward_from_xa(self, xa):
        """Map ``xa`` features to ``(class probabilities, logits)``."""
        xb = F.leaky_relu(self.fc2(xa))
        logits = self.fc_classifier(xb)
        classes_p = self.softmax(logits)
        return classes_p, logits