import torch
import torch.nn as nn

class LinearHead(nn.Module):
    """Single-layer linear classifier used by the "Scratch" method.

    Maps a feature tensor whose last dimension is ``input_dim`` (described
    in the original notes as Qwen hidden states) to ``num_classes`` logits.
    """

    def __init__(self, input_dim: int, num_classes: int) -> None:
        super().__init__()
        # The attribute name ``fc`` determines the state-dict keys
        # (``fc.weight`` / ``fc.bias``), so it must stay stable.
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the logits produced by the linear layer for ``x``."""
        logits = self.fc(x)
        return logits

class ConcatHead(nn.Module):
    """Concat baseline: a linear layer over features joined with black-box logits.

    The feature tensor (described in the original notes as Qwen hidden
    states) and the black-box logits are concatenated along their last
    dimension, and a single linear layer maps the result to ``num_classes``
    logits.
    """

    def __init__(self, input_dim: int, num_classes: int) -> None:
        super().__init__()
        # The layer sees the features and the black-box logits side by side,
        # so its input width is the sum of both widths.
        fused_dim = input_dim + num_classes
        self.fc = nn.Linear(in_features=fused_dim, out_features=num_classes)

    def forward(self, features: torch.Tensor, bb_logits: torch.Tensor) -> torch.Tensor:
        """Concatenate both inputs on the last axis and classify the result."""
        return self.fc(torch.cat((features, bb_logits), dim=-1))

class WeightedEnsemble(nn.Module):
    """Weighted baseline: blends black-box and scratch logits with one learned scalar.

    Output = w * BB + (1 - w) * Scratch, where ``w = sigmoid(raw_w)`` so the
    mixing weight always lies in (0, 1). ``raw_w`` starts at zero, which
    makes the initial blend an even 0.5 / 0.5 split.
    """

    def __init__(self, input_dim: int, num_classes: int) -> None:
        super().__init__()
        self.scratch_head = nn.Linear(in_features=input_dim, out_features=num_classes)
        # Unconstrained parameter; it is squashed through a sigmoid in forward().
        self.raw_w = nn.Parameter(torch.zeros(1))

    def forward(self, features: torch.Tensor, bb_logits: torch.Tensor) -> torch.Tensor:
        """Return the convex combination of black-box and scratch logits."""
        bb_share = self.raw_w.sigmoid()
        scratch_share = 1 - bb_share
        scratch_logits = self.scratch_head(features)
        return bb_share * bb_logits + scratch_share * scratch_logits

class ResidualModel(nn.Module):
    """
    Ours: Residual Estimator
    Input: Features, BB_Logits
    Output: BB_Logits + MLP(Features)

    Key technique: Zero-Initialization
    The weight and bias of the MLP's last layer start at zero, so the
    residual is exactly zero and the initial model output equals the
    black-box output. This avoids the negative-transfer risk that a randomly
    initialized correction term would introduce.

    Args:
        input_dim: Size of the last dimension of ``features``.
        num_classes: Number of logits; must match the last dimension of
            ``bb_logits`` so the residual can be added to it.
        hidden_dim: Width of the MLP's hidden layer.
        dropout: Dropout probability applied after the hidden activation.
            Defaults to 0.3, the value that was previously hard-coded.
    """
    def __init__(self, input_dim, num_classes, hidden_dim=1024, dropout=0.3):
        super().__init__()
        # Keep this exact Sequential layout: checkpoints address the two
        # linear layers as ``mlp.0`` and ``mlp.3``.
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

        # Zero-init the output layer so the residual starts at exactly 0.
        output_layer = self.mlp[-1]
        nn.init.zeros_(output_layer.weight)
        nn.init.zeros_(output_layer.bias)

    def forward(self, features, bb_logits):
        """Return ``bb_logits`` plus the residual the MLP computes from ``features``."""
        residual = self.mlp(features)
        return bb_logits + residual