
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            # If no node features, use constant
            if batch.x is None:
                batch.x = torch.ones(batch.num_nodes, 1, device=device)

            # Split into lists for the MV-HH pipeline
            X_list, E_list, y_list = split_batch_to_graphs(batch)
            treeG_list, S_list = [], []
            # Build hierarchy and Haar bases (only the first graph of the batch, index 0, is used)
            treeG, S_assign = Make_tree_real2(
                X_list[0], E_list[0], encoder,
                levels=args.levels, ratio=args.ratio,
                temp=args.temp, tau=args.tau
            )
            treeG = HaarGOB_with_Sassign(treeG, S_assign)

            # Extract U and features per level
            U_list    = [ torch.from_numpy(np.stack(treeG[l]['u'], axis=0).T).float().to(device)
                          for l in range(args.levels-1) ]
            feats_list = [ torch.from_numpy(treeG[l]['features']).float().to(device)
                           for l in range(args.levels-1) ]

            # Forward through classifier
            logits, H_list, Hhat_list = model(U_list, feats_list, treeG, S_assign_list=S_assign)
            # Graph-level output by averaging node logits
            graph_logits = logits.mean(dim=0, keepdim=True)  # [1, C]
            true_label   = y_list[0].view(1).to(device)      # graph label

            # Metrics
            pred_class = graph_logits.argmax(dim=-1)
            acc = (pred_class == true_label).float().item() * 100
            L_div = loss_diversity_from_S(S_assign, device=device).item()
            L_rec = loss_reconstruction_from_lists(H_list, Hhat_list).item()

            print(f"[{args.dataset.capitalize()}] Acc: {acc:.2f}%  "
                  f"L_div: {L_div:.4f}  L_rec: {L_rec:.2f}")

# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
