import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

class MLPProbe(nn.Module):
    """Binary probe that maps a feature vector to a probability.

    With ``hidden_dim == 0`` the probe is a single linear layer stored as
    ``self.classifier``. With any other value it is a one-hidden-layer MLP
    (Linear -> ReLU -> Linear) stored as ``self.mlp``. In both cases the
    single output logit is passed through a sigmoid.

    Args:
        input_dim: Size of the last dimension of the input features.
        hidden_dim: Width of the hidden layer; ``0`` selects the purely
            linear probe.
    """

    def __init__(self, input_dim, hidden_dim=0):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        if hidden_dim != 0:
            # Non-linear variant: one hidden layer followed by a scalar head.
            layers = [
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1),
            ]
            self.mlp = nn.Sequential(*layers)
        else:
            # Linear variant: a single scalar head.
            self.classifier = nn.Linear(input_dim, 1)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid outputs for ``x`` with the trailing size-1 axis removed.

        Args:
            x: Tensor whose last dimension equals ``input_dim``.

        Returns:
            Tensor of values in (0, 1) shaped like ``x`` without its last
            dimension.
        """
        # Only one of the two heads exists, depending on the constructor args.
        head = self.classifier if self.hidden_dim == 0 else self.mlp
        scores = head(x).squeeze(-1)
        return self.sigmoid(scores)
    
def _weighted_bce(pred, target, pos_weight):
    """Mean binary cross-entropy with positive targets scaled by ``pos_weight``.

    Both ``pred`` and ``target`` are flattened first, so a 0-d prediction
    coming from a single-sample batch still lines up with its target.
    """
    pred = pred.reshape(-1)
    target = target.reshape(-1).to(pred.dtype)
    pos_weight = pos_weight.to(device=pred.device, dtype=pred.dtype)
    # Per-sample weight: ``pos_weight`` where target == 1, 1.0 where target == 0.
    sample_weight = 1.0 + (pos_weight - 1.0) * target
    return nn.functional.binary_cross_entropy(pred, target, weight=sample_weight)


def train_probe(model, train_loader, val_loader, alpha, pos_weight, lr=1e-4, epochs=200):
    """Train ``model`` with class-weighted binary cross-entropy and Adam.

    The loss of every positive example (target 1) is multiplied by
    ``pos_weight * alpha``; negative examples keep weight 1. Passing that
    value as ``nn.BCELoss(weight=...)`` would instead rescale every element
    identically, leaving the class balance untouched.

    After each epoch the mean per-batch training loss and the validation
    loss are printed.

    Args:
        model: Module returning probabilities in [0, 1], one per sample.
        train_loader: Iterable of ``(x, y)`` batches that supports ``len()``.
        val_loader: Iterable of ``(x, y)`` batches, or ``None`` to skip
            validation (the validation loss is then reported as NaN).
        alpha: Extra multiplier applied to ``pos_weight``.
        pos_weight: Base weight for positive examples (tensor or number);
            ``None`` is treated as 1.0.
        lr: Adam learning rate.
        epochs: Number of passes over ``train_loader``.
    """
    base_weight = 1.0 if pos_weight is None else pos_weight
    scaled_pos_weight = torch.as_tensor(base_weight, dtype=torch.float32) * alpha

    def criterion(pred, target):
        return _weighted_bce(pred, target, scaled_pos_weight)

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for x, y in train_loader:
            optimizer.zero_grad()
            # No blanket ``.squeeze()`` here: it would turn a batch of one
            # into a 0-d tensor; the criterion aligns the shapes itself.
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if val_loader is None:
            val_loss = float("nan")
        else:
            val_loss = evaluate(model, val_loader, criterion)
        print(f"Epoch {epoch}, Train Loss: {total_loss/len(train_loader):.4f}, Val Loss: {val_loss:.4f}")

def evaluate(model, loader, criterion):
    """Return the mean per-batch loss of ``model`` over ``loader``.

    The model is switched to eval mode (and left there) and gradients are
    disabled while iterating.

    Args:
        model: Module called as ``model(x)`` for every batch.
        loader: Iterable of ``(x, y)`` batches.
        criterion: Callable ``criterion(pred, target)`` returning a scalar
            tensor.

    Returns:
        The average of the per-batch losses as a float, or NaN when the
        loader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for x, y in loader:
            target = y.float()
            # Match the prediction to the target's shape. A blanket
            # ``.squeeze()`` would collapse a batch of one to a 0-d tensor.
            pred = model(x).reshape(target.shape)
            total_loss += criterion(pred, target).item()
            num_batches += 1
    if num_batches == 0:
        # Nothing to average; avoid a ZeroDivisionError.
        return float("nan")
    return total_loss / num_batches

if __name__ == "__main__":
    # Script entry point: build a labelled dataset from the saved hidden
    # states of one layer, train a linear probe on it and save the weights.

    # NOTE(security): torch.load unpickles the file; only load trusted files.
    # map_location="cpu" keeps the data on the same device as the probe,
    # which is never moved off the CPU in this script.
    hidden_dict = torch.load("/path/to/save/hidden_states/hidden.pt", map_location="cpu")
    layer_idx = 48
    layer_data = hidden_dict[layer_idx]

    # Each entry holds "correct" and "incorrect" tensors; rows are labelled
    # 1 and 0 respectively.
    # Assumes each tensor is (num_rows, feature_dim) - TODO confirm.
    X_list, y_list = [], []
    for v in layer_data.values():
        cor = v["correct"]
        incor = v["incorrect"]
        if cor.numel() > 0:
            X_list.append(cor)
            y_list.append(torch.ones(cor.size(0)))
        if incor.numel() > 0:
            X_list.append(incor)
            y_list.append(torch.zeros(incor.size(0)))

    if not X_list:
        raise ValueError(f"No hidden states found for layer {layer_idx}")

    # Cast to float32 to match the probe's parameters (the stored
    # activations may be in a lower precision).
    X = torch.cat(X_list, dim=0).float()
    y = torch.cat(y_list, dim=0)

    # Random 80/20 train/validation split.
    perm = torch.randperm(len(y))
    train_size = int(0.8 * len(y))
    train_idx, val_idx = perm[:train_size], perm[train_size:]
    train_ds = TensorDataset(X[train_idx], y[train_idx])
    val_ds = TensorDataset(X[val_idx], y[val_idx])

    train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=64)

    # Derive the class weight from the training labels only, so the
    # validation split does not influence training.
    y_train = y[train_idx]
    num_pos = y_train.sum().item()
    num_neg = len(y_train) - num_pos
    if num_pos == 0 or num_neg == 0:
        raise ValueError(
            "Training split needs both classes to compute pos_weight "
            f"(positives={int(num_pos)}, negatives={int(num_neg)})"
        )
    pos_weight = torch.tensor([num_neg / num_pos])

    input_dim = X.size(1)
    model = MLPProbe(input_dim, hidden_dim=0)

    train_probe(model, train_loader, val_loader, alpha=2.0, pos_weight=pos_weight, lr=1e-4)
    torch.save(model.state_dict(), "/path/to/save/probe/model/model.pth")