"""Nonlinear (MLP) probe vs linear probe for node decoding from GRU hidden states."""

import json
import sys
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.linear_model import LogisticRegression

# Seeds to evaluate; main() calls run_seed() once per entry, and each seed
# selects the checkpoint file that run_seed() loads.
SEEDS = range(5)
# NOTE(review): not referenced anywhere in the visible code -- presumably the
# GRU hidden-state size; confirm it is still needed.
HIDDEN_DIM = 128
# Number of output classes handed to the MLP probe (run_seed passes it as
# n_classes to train_mlp_probe).
N_NODES = 127
# Width of the MLP probe's single hidden layer.
MLP_HIDDEN = 256
# Dropout probability used inside MLPProbe.
DROPOUT = 0.1
# Adam learning rate for MLP probe training.
LR = 1e-3
# Upper bound on the number of MLP training epochs.
MAX_EPOCHS = 100
# Early stopping: number of consecutive epochs without a new best validation
# loss after which training stops.
PATIENCE = 10
# Mini-batch size of the MLP training DataLoader.
BATCH_SIZE = 256
# Display names for observation classes, indexed by class id; used as result
# keys in per_class_accuracy() and as row labels in main().
OBS_CLASS_NAMES = ["Root", "Depth 1", "Depth 2-4", "Depth 5", "Leaves"]


class MLPProbe(nn.Module):
    """One-hidden-layer MLP classifier used as the nonlinear probe.

    The network is Linear -> ReLU -> Dropout(DROPOUT) -> Linear and outputs
    raw (unnormalised) class logits.
    """

    def __init__(self, input_dim, hidden_dim, n_classes):
        """Build the probe.

        Args:
            input_dim: Number of features in each input vector.
            hidden_dim: Width of the single hidden layer.
            n_classes: Number of output logits.
        """
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(DROPOUT),
            nn.Linear(hidden_dim, n_classes),
        ]
        # Layers are passed positionally so the state-dict keys stay
        # "net.0.*" and "net.3.*".
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the class logits produced by the network for input ``x``."""
        logits = self.net(x)
        return logits


def reconstruct_traj_ids(timesteps):
    """Recover a trajectory ID for every sample from its timestep.

    A timestep equal to 0 is treated as the first sample of a new trajectory,
    so the running count of zeros seen so far, shifted down by one, labels
    each sample with a 0-based trajectory ID.  Samples that precede the first
    zero therefore receive the ID -1.

    Args:
        timesteps: Array of per-sample timesteps, in recording order.

    Returns:
        Integer array with one trajectory ID per sample.
    """
    is_start = timesteps == 0
    n_started = np.cumsum(is_start)
    return n_started - 1


def trajectory_split(traj_ids, train_frac=0.8):
    """Split samples into train/test sets by trajectory ID.

    The sorted unique trajectory IDs are cut at ``int(train_frac * n_unique)``:
    samples of the first group form the training set, all remaining samples
    form the test set.  The split is deterministic (no shuffling), and every
    sample of a trajectory lands on the same side.

    The original docstring states that this matches the split used in
    analyze_structural_hidden.py; that cannot be checked from this file.

    Args:
        traj_ids: 1-D array-like with one trajectory ID per sample.
        train_frac: Fraction of unique trajectories assigned to training.

    Returns:
        Tuple ``(train_mask, test_mask)`` of boolean arrays, one entry per
        sample; the two masks are complements of each other.
    """
    unique_tids = np.unique(traj_ids)
    n_train = int(train_frac * len(unique_tids))
    # np.isin is vectorised and always yields a boolean mask, so the
    # complement below is also valid for empty input (a mask built from an
    # empty Python list would be float64, which `~` rejects).
    train_mask = np.isin(traj_ids, unique_tids[:n_train])
    return train_mask, ~train_mask


def train_mlp_probe(X_train, y_train, X_val, y_val, n_classes):
    """Train an MLPProbe with Adam and early stopping on validation loss.

    Training runs on CPU for at most MAX_EPOCHS epochs and stops once the
    validation loss has failed to reach a new minimum for PATIENCE
    consecutive epochs.  The weights from the epoch with the lowest
    validation loss are restored before the final predictions are made.

    Args:
        X_train: Training features, shape (n_samples, n_features).
        y_train: Integer class labels for ``X_train``.
        X_val: Validation features used for early stopping and evaluation.
        y_val: Integer class labels for ``X_val``.
        n_classes: Number of output classes of the probe.

    Returns:
        Tuple ``(best_val_acc, val_preds, n_epochs)``:
        ``best_val_acc`` is the validation accuracy at the epoch with the
        lowest validation loss (not necessarily the highest accuracy seen),
        ``val_preds`` is a NumPy array of predicted classes for ``X_val``
        from the restored weights, and ``n_epochs`` is the number of epochs
        actually run.

    Raises:
        RuntimeError: If no epoch produced a finite, improving validation
            loss (NaN loss, empty validation set, or MAX_EPOCHS == 0), so
            there are no best weights to restore.
    """
    device = torch.device("cpu")
    model = MLPProbe(X_train.shape[1], MLP_HIDDEN, n_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()

    train_ds = TensorDataset(torch.tensor(X_train, dtype=torch.float32),
                             torch.tensor(y_train, dtype=torch.long))
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

    X_val_t = torch.tensor(X_val, dtype=torch.float32).to(device)
    y_val_t = torch.tensor(y_val, dtype=torch.long).to(device)

    best_val_loss = float("inf")
    best_val_acc = 0.0
    best_state = None
    patience_counter = 0
    # Counted explicitly so the value is defined even if the loop never runs.
    epochs_run = 0

    for _ in range(MAX_EPOCHS):
        epochs_run += 1

        model.train()
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            loss = criterion(model(X_batch), y_batch)
            loss.backward()
            optimizer.step()

        # Evaluate on the whole validation set with dropout disabled.
        model.eval()
        with torch.no_grad():
            val_logits = model(X_val_t)
            val_loss = criterion(val_logits, y_val_t).item()
            val_acc = (val_logits.argmax(dim=1) == y_val_t).float().mean().item()

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_acc = val_acc
            patience_counter = 0
            # Clone the tensors: state_dict() returns references that later
            # optimizer steps would otherwise overwrite in place.
            best_state = {k: v.clone() for k, v in model.state_dict().items()}
            continue

        patience_counter += 1
        if patience_counter >= PATIENCE:
            break

    if best_state is None:
        # A NaN loss never compares below inf, so no snapshot was taken;
        # fail with a clear message instead of inside load_state_dict(None).
        raise RuntimeError(
            f"MLP probe training produced no valid validation loss after "
            f"{epochs_run} epoch(s); cannot restore best weights"
        )

    model.load_state_dict(best_state)
    model.eval()
    with torch.no_grad():
        val_preds = model(X_val_t).argmax(dim=1).cpu().numpy()

    return best_val_acc, val_preds, epochs_run


def per_class_accuracy(y_true, y_pred, obs_classes, n_obs_classes=5):
    """Break prediction accuracy down by observation class.

    For every class id in ``range(n_obs_classes)`` that has at least one
    sample, the accuracy over that class's samples is reported together with
    a chance level of ``1 / (number of distinct true labels in the class)``.
    Classes without samples are left out of the result.

    Args:
        y_true: Array of true labels, one per sample.
        y_pred: Array of predicted labels, aligned with ``y_true``.
        obs_classes: Array of observation-class ids, aligned with ``y_true``.
        n_obs_classes: Number of class ids to consider, starting at 0.

    Returns:
        Dict keyed by the class's entry in OBS_CLASS_NAMES; each value holds
        "accuracy", "chance", "n_nodes" and "n_samples".
    """
    breakdown = {}
    for class_idx in range(n_obs_classes):
        in_class = obs_classes == class_idx
        n_samples = in_class.sum()
        if n_samples == 0:
            # Nothing observed for this class -- no entry is emitted.
            continue

        targets = y_true[in_class]
        guesses = y_pred[in_class]
        distinct_nodes = len(np.unique(targets))
        hit_rate = np.mean(guesses == targets)

        breakdown[OBS_CLASS_NAMES[class_idx]] = {
            "accuracy": float(hit_rate),
            "chance": float(1.0 / max(distinct_nodes, 1)),
            "n_nodes": int(distinct_nodes),
            "n_samples": int(n_samples),
        }
    return breakdown


def run_seed(seed):
    """Fit the linear and the MLP probe for one seed and report both.

    Loads ``checkpoints/structural_hidden_analysis_seed{seed}.pt``, reads its
    "hidden_states", "positions", "obs_classes" and "timesteps" entries,
    splits the samples by trajectory, then fits a logistic-regression probe
    and an MLP probe that predict the position from the hidden state.
    A one-line summary is printed.

    Args:
        seed: Seed identifying the checkpoint; also used to seed the torch
            RNG so the MLP probe is reproducible.

    Returns:
        Dict with the keys "seed", "n_visited_nodes", "chance", "linear_acc",
        "mlp_acc", "mlp_epochs", "linear_per_class" and "mlp_per_class".
    """
    ckpt_path = f"checkpoints/structural_hidden_analysis_seed{seed}.pt"
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    hidden_states = ckpt["hidden_states"]
    positions = ckpt["positions"]
    obs_classes = ckpt["obs_classes"]
    timesteps = ckpt["timesteps"]

    traj_ids = reconstruct_traj_ids(timesteps)
    train_mask, test_mask = trajectory_split(traj_ids)

    X_train, y_train = hidden_states[train_mask], positions[train_mask]
    X_test, y_test = hidden_states[test_mask], positions[test_mask]
    obs_test = obs_classes[test_mask]

    # Chance level is uniform guessing over the distinct positions that occur
    # anywhere in this checkpoint (train and test together).
    n_visited = len(np.unique(positions))
    chance = 1.0 / n_visited

    clf = LogisticRegression(max_iter=1000, solver="lbfgs", C=1.0)
    clf.fit(X_train, y_train)
    # Predict once and derive the accuracy from the same predictions, rather
    # than running inference a second time through clf.score().
    linear_preds = clf.predict(X_test)
    linear_acc = float(np.mean(linear_preds == y_test))
    linear_per_class = per_class_accuracy(y_test, linear_preds, obs_test)

    # Seed the torch RNG so weight init, dropout and batch shuffling of the
    # MLP probe are reproducible for a given seed.
    torch.manual_seed(int(seed))
    # NOTE(review): the test split doubles as the early-stopping validation
    # set, so mlp_acc is selected on the data it is reported on while the
    # linear probe gets no such selection.  Consider carving a validation
    # split out of the training trajectories.
    mlp_acc, mlp_preds, n_epochs = train_mlp_probe(X_train, y_train, X_test, y_test, N_NODES)
    mlp_per_class = per_class_accuracy(y_test, mlp_preds, obs_test)

    result = {
        "seed": int(seed),
        "n_visited_nodes": int(n_visited),
        "chance": float(chance),
        "linear_acc": float(linear_acc),
        "mlp_acc": float(mlp_acc),
        "mlp_epochs": int(n_epochs),
        "linear_per_class": linear_per_class,
        "mlp_per_class": mlp_per_class,
    }

    print(f"Seed {seed}: linear={linear_acc:.1%}  MLP={mlp_acc:.1%}  "
          f"(chance={chance:.1%}, MLP epochs={n_epochs})")
    return result


def main():
    """Run both probes for every seed, print summaries and save the results.

    Calls run_seed() for each entry of SEEDS, prints the mean and standard
    deviation of the overall accuracies plus a per-class table averaged over
    seeds, and writes the per-seed results and the summary to
    ``checkpoints/nonlinear_probe_results.json``.
    """
    all_results = [run_seed(seed) for seed in SEEDS]

    linear_accs = [r["linear_acc"] for r in all_results]
    mlp_accs = [r["mlp_acc"] for r in all_results]
    chance = np.mean([r["chance"] for r in all_results])
    linear_mean = np.mean(linear_accs)
    mlp_mean = np.mean(mlp_accs)

    # The "+" sign option prints the gap as "+1.2%" or "-1.2%"; a hard-coded
    # "+" in front of the value would render a negative gap as "+-1.2%".
    print(f"linear {linear_mean:.1%} +/- {np.std(linear_accs):.1%}, "
          f"MLP {mlp_mean:.1%} +/- {np.std(mlp_accs):.1%} "
          f"({mlp_mean - linear_mean:+.1%}), chance {chance:.1%}")

    print("\nPer-class accuracy (mean across seeds) ")
    print(f"{'Class':<12} {'Linear':>8} {'MLP':>8} {'Chance':>8} {'N nodes':>8}")
    for cls_name in OBS_CLASS_NAMES:
        lin_stats = [r["linear_per_class"][cls_name]
                     for r in all_results if cls_name in r["linear_per_class"]]
        if not lin_stats:
            # Class absent from every seed: skip before averaging, because
            # int(np.mean([])) would raise on the resulting NaN.
            continue
        mlp_stats = [r["mlp_per_class"][cls_name]
                     for r in all_results if cls_name in r["mlp_per_class"]]

        lin_acc = np.mean([s["accuracy"] for s in lin_stats])
        mlp_acc = np.mean([s["accuracy"] for s in mlp_stats])
        cls_chance = np.mean([s["chance"] for s in lin_stats])
        n_nodes = int(np.mean([s["n_nodes"] for s in lin_stats]))
        print(f"{cls_name:<12} {lin_acc:>7.1%} {mlp_acc:>7.1%} "
              f"{cls_chance:>7.1%} {n_nodes:>8d}")

    out_path = "checkpoints/nonlinear_probe_results.json"
    summary = {
        "chance": float(chance),
        "linear_mean": float(linear_mean),
        "linear_std": float(np.std(linear_accs)),
        "mlp_mean": float(mlp_mean),
        "mlp_std": float(np.std(mlp_accs)),
    }
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump({"seeds": all_results, "summary": summary}, f, indent=2)
    print(f"\nResults saved to {out_path}")


if __name__ == "__main__":
    main()
