import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    """Single affine layer that maps an input vector to per-class logits."""

    def __init__(self, input_dim: int, num_classes: int) -> None:
        """Create the projection layer.

        Args:
            input_dim: Size of the last dimension of the tensor given to
                ``forward``; it is used as-is for the layer's input width.
                NOTE(review): an earlier comment here said this equals
                feature dimension + black-box logits dimension -- if so, the
                caller must pass the already-summed width; confirm at the
                call site.
            num_classes: Number of logits produced per sample.
        """
        super().__init__()
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the logits obtained by applying ``self.fc`` to ``x``."""
        logits = self.fc(x)
        return logits

class LinearHead(nn.Module):
    """Plain linear classifier over the features (used by the Scratch method)."""

    def __init__(self, input_dim: int, num_classes: int) -> None:
        """Create the classifier layer.

        Args:
            input_dim: Size of the last dimension of the feature tensor.
            num_classes: Number of logits produced per sample.
        """
        super().__init__()
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits computed by ``self.fc`` from ``x``."""
        logits = self.fc(x)
        return logits

class ConcatHead(nn.Module):
    """Baseline: classify the concatenation of features and black-box logits."""

    def __init__(self, input_dim: int, num_classes: int) -> None:
        """Create the classifier over the concatenated input.

        Args:
            input_dim: Size of the last dimension of the feature tensor.
            num_classes: Number of output logits. The layer's input width is
                ``input_dim + num_classes``, so the black-box logits passed
                to ``forward`` must have ``num_classes`` entries along their
                last dimension.
        """
        super().__init__()
        # Layer input width = feature width + black-box logits width.
        joint_dim = input_dim + num_classes
        self.fc = nn.Linear(in_features=joint_dim, out_features=num_classes)

    def forward(self, features: torch.Tensor, bb_logits: torch.Tensor) -> torch.Tensor:
        """Join both inputs along the last dimension and project to logits."""
        joint = torch.cat((features, bb_logits), dim=-1)
        logits = self.fc(joint)
        return logits

class WeightedEnsemble(nn.Module):
    """Baseline: blend black-box logits and a linear head with one learned scalar."""

    def __init__(self, input_dim: int, num_classes: int) -> None:
        """Create the linear head and the unconstrained mixing parameter.

        Args:
            input_dim: Size of the last dimension of the feature tensor.
            num_classes: Number of logits produced by the linear head.
        """
        super().__init__()
        self.scratch_head = nn.Linear(in_features=input_dim, out_features=num_classes)
        # Stored unconstrained and squashed with a sigmoid in forward();
        # starting at 0 gives an initial mixing weight of sigmoid(0) = 0.5.
        self.raw_w = nn.Parameter(torch.zeros(1))

    def forward(self, features: torch.Tensor, bb_logits: torch.Tensor) -> torch.Tensor:
        """Return ``w * bb_logits + (1 - w) * head(features)``, w = sigmoid(raw_w)."""
        mix = torch.sigmoid(self.raw_w)
        head_logits = self.scratch_head(features)
        bb_term = mix * bb_logits
        head_term = (1 - mix) * head_logits
        return bb_term + head_term

class ResidualModel(nn.Module):
    """
    Our method: BB + Residual.

    A two-layer MLP computes a correction from the features, and that
    correction is added to the black-box logits.

    Key technique: Zero-Init. The weight and bias of the last MLP layer are
    set to zero at construction, so the residual starts at zero and the
    freshly built model returns the black-box logits unchanged (no negative
    transfer initially).
    """

    def __init__(
        self,
        input_dim: int,
        num_classes: int,
        hidden_dim: int = 512,
        dropout: float = 0.3,
    ) -> None:
        """Build the residual MLP and zero-initialize its output layer.

        Args:
            input_dim: Size of the last dimension of the feature tensor.
            num_classes: Number of residual logits; must match the last
                dimension of the ``bb_logits`` given to ``forward`` (or be
                broadcastable to it) for the final addition to work.
            hidden_dim: Width of the hidden layer.
            dropout: Probability used by the dropout layer after the ReLU.
                Defaults to 0.3, the value that used to be hard-coded.
        """
        super().__init__()
        # Layer order (and therefore the Sequential indices 0..3) is kept
        # fixed so the parameter names in state_dict stay the same.
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )
        # Zero-Initialization of the output layer: residual == 0 at start.
        output_layer = self.mlp[-1]
        nn.init.zeros_(output_layer.weight)
        nn.init.zeros_(output_layer.bias)

    def forward(self, features: torch.Tensor, bb_logits: torch.Tensor) -> torch.Tensor:
        """Return ``bb_logits`` plus the residual computed from ``features``."""
        residual = self.mlp(features)
        return bb_logits + residual
