#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
partition_rankmistral_pac_tc.py — PAC-stable partitioning of transformer MLP neurons
====================================================================================

Key features:
  • General top-responsive (TR) modeling (best-friend, top-k, synergy) with deterministic tie-breaks
  • Informative bucketization of utilities to identify choice sets
  • Per-round coalition resampling on the current remaining set R
  • Proper choice-set estimation from sampled coalitions
  • Directed graph construction from estimated choice sets
  • Minimal closed (sink-SCC) peeling per round (Tarjan SCC)
  • OCA uses Pearson correlation
  • PAS uses a layer-local logit (post-MLP) for gradients
  • Random neuron subset selection via `fraction`
  • ε, δ parameters and sample-complexity guidance for PAC Top-Cover
"""

from __future__ import annotations
import argparse
import math
import os
import pickle
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Set

import torch
import torch.nn as nn
import torch.distributed as dist
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftConfig, PeftModel
from tqdm import tqdm


# ─────────────────────────────  Data loader  ──────────────────────────────

def load_ms_marco_data(n_queries: int,
                       n_docs: int,
                       file_path: str = 'dataset/top1000.dev') -> Dict[str, List[str]]:
    """
    Read a tab-separated file whose first four columns are (qid, docid, query, doc).

    Rows with fewer than four columns are skipped. Documents are grouped under
    their query text in file order. At most ``n_queries`` distinct queries are
    admitted (first come, first served) and each keeps at most ``n_docs``
    documents; everything beyond those budgets is ignored.

    Returns:
        dict mapping query text -> list of document texts.
    """
    grouped: Dict[str, List[str]] = {}
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw in handle:
            fields = raw.rstrip('\n').split('\t')
            if len(fields) < 4:
                continue
            query, doc = fields[2], fields[3]
            if query in grouped:
                docs = grouped[query]
            elif len(grouped) < n_queries:
                docs = grouped[query] = []
            else:
                # Query budget exhausted: unseen queries are dropped.
                continue
            if len(docs) < n_docs:
                docs.append(doc)
    return grouped


def pair_text(query: str, doc: str) -> str:
    """Render a (query, document) pair as one prompt string."""
    return "query: {} document: {}".format(query, doc)


# ─────────────────────────────  Model loading  ──────────────────────────────

def load_rankmistral(lora_path: str, device: torch.device):
    """
    Load a LoRA-tuned sequence-classification model together with its tokenizer.

    Steps:
      1. Read the PEFT config at ``lora_path`` to find the base checkpoint.
      2. Load the base model in float16 with a single-logit head, placed
         entirely on ``device``.
      3. Load the tokenizer; if it has no pad token, reuse the EOS token and
         mirror the resulting id into the model config.
      4. Wrap the base model with the LoRA adapters, switch to eval mode and
         try to merge the adapters into the base weights.

    Args:
        lora_path: path or hub id of the PEFT adapter.
        device: target device for every module.

    Returns:
        (model, tokenizer). If the adapter merge fails, the un-merged PEFT
        wrapper is returned and a ``RuntimeWarning`` is emitted.
    """
    import warnings

    cfg = PeftConfig.from_pretrained(lora_path)
    base = AutoModelForSequenceClassification.from_pretrained(
        cfg.base_model_name_or_path,
        num_labels=1,
        torch_dtype=torch.float16,
        # Pass the device spelled out ("cpu", "cuda", "cuda:3"). The previous
        # form used -1 for CPU and ``device.index`` (None for a bare "cuda"
        # device), neither of which is a usable placement target.
        device_map={'': str(device)},
    )
    tok = AutoTokenizer.from_pretrained(cfg.base_model_name_or_path)
    if tok.pad_token is None and hasattr(tok, "eos_token"):
        tok.pad_token = tok.eos_token
    base.config.pad_token_id = tok.pad_token_id
    model = PeftModel.from_pretrained(base, lora_path).eval().to(device)
    # Merge LoRA adapters for evaluation simplicity. Merging stays best-effort,
    # but a failure is now reported instead of being silently swallowed: code
    # that reaches into the module tree expects the merged layout.
    try:
        model = model.merge_and_unload()
    except Exception as exc:
        warnings.warn(
            f"LoRA merge_and_unload() failed ({exc!r}); continuing with the un-merged PEFT model.",
            RuntimeWarning,
        )
    return model, tok


# ─────────────────────────────  OCA valuation  ──────────────────────────────

def _random_subset_indices(n_total: int, n_keep: int, g: torch.Generator) -> torch.Tensor:
    """
    Draw up to ``n_keep`` distinct indices from ``range(n_total)`` using ``g``
    and return them in ascending order.

    The permutation is produced by ``torch.randperm`` with its default device
    and dtype; ``g`` is expected to be a CPU generator because randperm with
    CUDA generators is not universally supported.
    """
    shuffled = torch.randperm(n_total, generator=g)
    return shuffled[:n_keep].sort().values


def collect_oca_phi(model, tok, data, layer, device, fraction,
                    batch=4, max_len=256, seed=0, dtype=torch.float16):
    """
    OCA φ(i,j) = (1 − |cos(W_i, W_j)|) * Corr[a_i, a_j].

    W_i is the column of ``down_proj.weight`` belonging to neuron i, and a_i is
    the output of ``gate_proj`` for neuron i, captured with a forward hook.
    Corr is the Pearson correlation over every token position seen.

    Args:
        model: model exposing ``model.layers[layer].mlp`` with ``gate_proj``
            and ``down_proj`` sub-modules.
        tok: tokenizer; called with a list of strings and ``return_tensors='pt'``.
        data: mapping query -> list of documents; every (query, doc) pair is run.
        layer: index of the decoder layer to inspect.
        device: device for the encoded batches and the running statistics.
        fraction: share of the layer's neurons to keep, in (0, 1].
        batch: number of pairs per forward pass.
        max_len: tokenizer truncation length.
        seed: seed for the random neuron subset.
        dtype: dtype of the returned matrix.

    Returns:
        (phi, keep_idx): ``phi`` is (n_keep, n_keep) with a zero diagonal;
        ``keep_idx`` holds the sorted original neuron indices that were kept.

    NOTE(review): the tokenizer pads each batch and no attention mask is applied
    to the captured activations, so padded positions contribute to the
    statistics as well.
    """
    assert 0. < fraction <= 1.
    g = torch.Generator().manual_seed(seed)  # CPU generator

    # Access layer MLP; adjust attribute path if architecture differs
    mlp = model.model.layers[layer].mlp
    down = mlp.down_proj
    # Weight matrix: shape (d_model, H). We want neuron vectors W_i in shape (H, d_model).
    W = down.weight.t().to(torch.float32)  # (H, d_model)
    H_full = W.size(0)
    n_keep = max(1, int(round(H_full * fraction)))
    keep_idx = _random_subset_indices(H_full, n_keep, g)  # CPU indices
    W = W.index_select(0, keep_idx.to(W.device))

    # Cosine similarity (abs) between neuron weight vectors
    W_norm = W / W.norm(dim=1, keepdim=True).clamp_min(1e-6)
    cos_W = (W_norm @ W_norm.T).abs()  # (n_keep, n_keep)

    # Running first and second moments of the activations, kept in float64
    sum_a = torch.zeros(n_keep, device=device, dtype=torch.float64)
    sum_aaT = torch.zeros(n_keep, n_keep, device=device, dtype=torch.float64)
    count = 0

    cache: dict = {}

    def fwd_hook(_, __, hidden):
        # hidden: (B, L, H_full) or (tokens, H_full); flatten to one row per token
        act = hidden.detach().to(torch.float32).view(-1, H_full)
        cache['a'] = act[:, keep_idx.to(act.device)]  # (tokens, n_keep)

    hook = mlp.gate_proj.register_forward_hook(fwd_hook)
    try:
        pairs = [pair_text(q, d) for q, docs in data.items() for d in docs]
        for i in tqdm(range(0, len(pairs), batch), desc=f'Layer-{layer} OCA', disable=device.index not in {None, 0}):
            enc = tok(pairs[i:i+batch], return_tensors='pt',
                      padding=True, truncation=True, max_length=max_len).to(device)
            with torch.no_grad():
                model(**enc)
            a = cache.pop('a')  # (tokens, n_keep)
            sum_a += a.sum(dim=0, dtype=torch.float64)
            sum_aaT += (a.T @ a).to(torch.float64)
            count += a.size(0)
    finally:
        # Always detach the hook, even if a forward pass failed
        hook.remove()

    # Finish the Pearson computation in float64. E[aa^T] - mu mu^T subtracts two
    # numbers of similar size; casting to float32 before the subtraction (as the
    # previous version did) discards the precision the float64 sums preserved.
    n_obs = max(count, 1)
    mu = sum_a / n_obs                                        # (n_keep,)
    cov = sum_aaT / n_obs - mu.unsqueeze(1) * mu.unsqueeze(0)  # covariance
    std = cov.diag().clamp_min(1e-8).sqrt()
    corr = cov / (std.unsqueeze(1) * std.unsqueeze(0)).clamp_min(1e-8)   # Pearson
    corr = corr.to(torch.float32)

    phi = (1.0 - cos_W) * corr
    phi.fill_diagonal_(0)
    return phi.to(dtype=dtype), keep_idx


# ─────────────────────────────  PAS valuation  ──────────────────────────────

class _LayerCaches:
    """
    Plain holder for two per-layer tensors; both slots start out as ``None``.

    NOTE(review): no code in the visible part of this file instantiates it.
    """

    def __init__(self):
        # gate_out: gate projection output, (tokens, H_full)
        # post_mlp: decoder layer output (post-MLP + residual), (tokens, d_model)
        self.gate_out = self.post_mlp = None


def collect_pas_phi(model, tok, data, layer, device, fraction,
                    batch=4, max_len=256, seed=0, dtype=torch.float16):
    """
    PAS proxy:
        φ(i,j) ≈ - E_x[ (∂ℓ_local/∂a_i · a_i)(∂ℓ_local/∂a_j · a_j) ]
    where a is the gate_proj output and ℓ_local is a layer-local scalar: the
    decoder-layer output projected onto the classification head weight and
    summed over all token positions.

    OOM-safe: the autograd graph is cut at gate_proj by detaching its output and
    re-introducing it as a leaf that requires grad. Model parameters are frozen
    for the duration of the call so no parameter gradients are tracked.

    Args:
        model: model exposing ``model.layers[layer]`` (with ``.mlp.gate_proj``)
            and ``score.weight``.
        tok: tokenizer; called with a list of strings and ``return_tensors='pt'``.
        data: mapping query -> list of documents; every (query, doc) pair is run.
        layer: index of the decoder layer to inspect.
        device: device for the encoded batches and the accumulator.
        fraction: share of the layer's neurons to keep, in (0, 1].
        batch: number of pairs per forward pass.
        max_len: tokenizer truncation length.
        seed: seed for the random neuron subset.
        dtype: dtype of the returned matrix.

    Returns:
        (phi, keep_idx): ``phi`` is (n_keep, n_keep) with a zero diagonal and is
        detached from autograd; ``keep_idx`` holds the sorted original neuron
        indices that were kept.

    Side effects: none that outlive the call. Hooks are removed, the
    ``requires_grad`` flags that were switched off are switched back on, and the
    global grad mode is left as the caller had it.
    """
    assert 0. < fraction <= 1.
    g = torch.Generator().manual_seed(seed)  # CPU generator

    block = model.model.layers[layer]
    mlp = block.mlp
    H_full = mlp.gate_proj.out_features
    n_keep = max(1, int(round(H_full * fraction)))
    keep_idx = _random_subset_indices(H_full, n_keep, g)  # CPU indices

    # assumes model.score is a bias-free Linear(d_model -> 1) — TODO confirm for other heads
    head_w = model.score.weight.detach().squeeze(0).to(device)     # (d_model,)
    head_b = torch.zeros((), device=device, dtype=head_w.dtype)    # scalar 0.0 stands in for the bias

    class _Caches:
        gate_out_raw = None   # raw tensor from gate_proj (B,L,H_full) or (tokens,H_full)
        post_mlp     = None   # (tokens, d_model)
    caches = _Caches()

    # Gate hook: detach upstream and return a new leaf that requires grad
    def hook_gate(_, __, out):
        out_det = out.detach()            # cut the graph upstream of gate_proj
        out_det.requires_grad_(True)      # new leaf so autograd tracks downstream ops
        caches.gate_out_raw = out_det     # keep RAW; no reshape here
        return out_det                    # IMPORTANT: replace module output with the leaf

    # Post-MLP hook: unwrap tuple from decoder layer; keep hidden_states
    def hook_post_mlp(_, __, out):
        hidden = out[0] if isinstance(out, (tuple, list)) else out
        caches.post_mlp = hidden.contiguous().view(-1, hidden.size(-1))  # (tokens, d_model)

    # Freeze parameters (only grads w.r.t. gate_out_raw are needed) and remember
    # which ones were trainable so they can be restored afterwards.
    thawed = [p for p in model.parameters() if p.requires_grad]
    for p in thawed:
        p.requires_grad_(False)

    vv_accum = torch.zeros(n_keep, n_keep, device=device, dtype=torch.float32)
    token_cnt = 0

    h1 = mlp.gate_proj.register_forward_hook(hook_gate)
    h2 = block.register_forward_hook(hook_post_mlp)
    try:
        with torch.enable_grad():  # record ops so autograd can backprop to gate_out_raw
            pairs = [pair_text(q, d) for q, docs in data.items() for d in docs]
            for i in tqdm(range(0, len(pairs), batch), desc=f'Layer-{layer} PAS',
                          disable=device.index not in {None, 0}):
                enc = tok(pairs[i:i+batch], return_tensors='pt',
                          padding=True, truncation=True, max_length=max_len).to(device)

                # Forward pass populates caches.gate_out_raw and caches.post_mlp.
                # The model output itself is not bound to a name, so the graph
                # of the layers after `layer` can be released right away.
                model(**enc)

                # Layer-local scalar using the (biasless) head
                local_logits = (caches.post_mlp @ head_w) + head_b  # (tokens,)
                loss_like = local_logits.sum()

                # Gradients w.r.t. the RAW gate output
                gate_grads_raw = torch.autograd.grad(
                    loss_like, caches.gate_out_raw,
                    retain_graph=False, create_graph=False, allow_unused=False
                )[0]  # shape matches gate_out_raw

                # Accumulate outside autograd. gate_out_raw is a leaf that
                # requires grad, so without no_grad() the in-place sum below
                # would chain a graph through every batch, keeping each batch's
                # activations alive and returning a phi with a grad_fn.
                with torch.no_grad():
                    act_flat = caches.gate_out_raw.view(-1, H_full)   # (tokens, H_full)
                    grad_flat = gate_grads_raw.view(-1, H_full)       # (tokens, H_full)
                    idx = keep_idx.to(act_flat.device)
                    v = (grad_flat[:, idx] * act_flat[:, idx]).float()  # Hadamard vector, (tokens, n_keep)
                    vv_accum += v.T @ v
                    token_cnt += v.size(0)

                # Clear caches for next batch
                caches.gate_out_raw = None
                caches.post_mlp = None

    finally:
        h1.remove()
        h2.remove()
        # Undo the temporary freeze. The grad mode needs no manual reset:
        # the enable_grad() context restores the caller's mode on exit.
        for p in thawed:
            p.requires_grad_(True)

    phi = -(vv_accum / max(token_cnt, 1)).to(dtype=dtype)
    phi.fill_diagonal_(0)
    return phi, keep_idx



# ─────────────────────────────  TR utilities  ──────────────────────────────

def tr_choice_set_and_utility_for_S(
    phi_ij: torch.Tensor,      # (n, n) affinity from i to j
    psi_jk: Optional[torch.Tensor],  # (n, n) partner-partner synergy (may be None)
    S_idx: torch.Tensor,       # indices in [0..n-1] of the coalition S (CPU LongTensor)
    i_pos_in_R: int,           # position of i within S (0..|S|-1)
    model: str,
    k_top: int = 2,
    lam: float = 0.0,
    alpha: float = 100.0,
) -> Tuple[torch.Tensor, int]:
    """
    For a coalition S and a focal player i = S[i_pos_in_R], compute:
      - Ch_i_S : sorted LongTensor of partner indices (subset of S\\{i})
      - bucket : floor(alpha * utility_i(S))

    Models:
      - "best_friend": the single partner with the largest phi[i, j].
      - "topk":        the k = clamp(k_top, 1, |S|-1) partners with the largest
                       phi[i, j]; utility is their mean affinity.
      - "synergy2":    the best singleton {j} (score phi[i, j]) or pair {j, k}
                       (score phi[i, j] + phi[i, k] + lam * psi[j, k]; psi
                       counts as 0 when ``psi_jk`` is None).

    Ties are broken deterministically in favour of smaller indices.

    All affinities needed for S are fetched with one gather and handled as
    Python floats, so there is one device read per call instead of one per
    element, and singleton and pair scores use the same precision.

    If S has no member other than i, returns (empty LongTensor, 0).
    A non-finite utility does not raise: +inf maps to bucket 2**62, and
    -inf / NaN map to bucket -(2**62).

    Raises:
        ValueError: if ``model`` is not one of the three names above.
    """
    members = S_idx.tolist()
    focal = members[i_pos_in_R]
    others = [x for x in members if x != focal]
    if not others:
        return torch.tensor([], dtype=torch.long), 0

    # One gather for the whole row slice; float64 holds every half/float value exactly.
    gather = torch.tensor(others, dtype=torch.long, device=phi_ij.device)
    aff = dict(zip(others, phi_ij.detach()[focal, gather].to(torch.float64).tolist()))

    if model == "best_friend":
        j_best = max(others, key=lambda j: (aff[j], -j))
        chosen = [j_best]
        utility = aff[j_best]

    elif model == "topk":
        k = min(max(1, k_top), len(others))
        ranked = sorted(others, key=lambda j: (aff[j], -j), reverse=True)
        chosen = sorted(ranked[:k])
        utility = math.fsum(aff[j] for j in chosen) / len(chosen)

    elif model == "synergy2":
        best_tuple = None
        best_score = -math.inf
        # size 1: on a tie the smaller index wins
        for j in others:
            score = aff[j]
            if score > best_score or (
                math.isclose(score, best_score)
                and (best_tuple is None or j < best_tuple[0])
            ):
                best_score = score
                best_tuple = (j,)
        # size 2: on a tie the lexicographically smaller sorted tuple wins
        if len(others) >= 2:
            syn = None
            if psi_jk is not None:
                g_psi = gather.to(psi_jk.device)
                syn = psi_jk.detach()[g_psi][:, g_psi].to(torch.float64).tolist()
            for a_idx in range(len(others)):
                for b_idx in range(a_idx + 1, len(others)):
                    j, k = others[a_idx], others[b_idx]
                    pair_synergy = syn[a_idx][b_idx] if syn is not None else 0.0
                    score = aff[j] + aff[k] + lam * pair_synergy
                    if score > best_score or (
                        math.isclose(score, best_score)
                        and (best_tuple is None
                             or tuple(sorted((j, k))) < tuple(sorted(best_tuple)))
                    ):
                        best_score = score
                        best_tuple = (j, k)
        if best_tuple is None:
            # Every score was NaN: fall back to the smallest partner index.
            j0 = min(others)
            best_tuple = (j0,)
            best_score = aff[j0]
        chosen = sorted(best_tuple)
        utility = best_score

    else:
        raise ValueError(f"Unknown TR model: {model}")

    bucket_cap = 2 ** 62
    scaled = alpha * utility
    if math.isfinite(scaled):
        bucket = int(math.floor(scaled))
    elif scaled > 0:
        bucket = bucket_cap
    else:  # -inf or NaN
        bucket = -bucket_cap
    return torch.tensor(chosen, dtype=torch.long), bucket


# ─────────────────────  Coalition sampling (per-round)  ─────────────────────

def sample_coalitions_from_R(
    R: torch.Tensor,             # tensor of indices in [0..n-1], can live on CUDA
    m: int,
    g: torch.Generator,
    min_k: int = 2,
    max_k: int = 6,
) -> List[torch.Tensor]:
    """
    Sample ``m`` coalitions from the remaining set ``R``.

    For each coalition a size k is drawn uniformly from [min_k, max_k] (after
    clamping) and then k distinct members of R are taken from a random
    permutation. Sizes are uniform; individual subsets are therefore not
    uniformly distributed across different sizes.

    Clamping: the size bounds are forced into 1 <= min_k <= max_k <= |R|. The
    lower bound of 1 keeps non-positive arguments from producing empty
    coalitions or negative slice lengths.

    Args:
        R: 1-D tensor of member indices (any device).
        m: number of coalitions to draw; values <= 0 yield an empty list.
        g: CPU generator driving both the size draw and the permutation.
        min_k, max_k: requested inclusive size range.

    Returns:
        List of CPU LongTensors, one per coalition. Empty if R is empty.
    """
    n = R.numel()
    if n == 0:
        return []

    # Move R to the host once instead of copying every sampled coalition back.
    pool = R.detach().cpu()
    hi = max(1, min(max_k, n))
    lo = max(1, min(min_k, hi))

    out: List[torch.Tensor] = []
    for _ in range(m):
        k = int(torch.randint(lo, hi + 1, (1,), generator=g).item())
        picks = torch.randperm(n, generator=g)[:k]  # CPU positions into R
        out.append(pool[picks])
    return out


# ─────────────────────  Choice-set estimation from samples  ─────────────────────

def estimate_choice_sets_from_samples(
    R: torch.Tensor,
    phi: torch.Tensor,                 # (n, n) per i->j affinity
    psi: Optional[torch.Tensor],       # (n, n) partner-partner synergy (for synergy model)
    S_list: List[torch.Tensor],        # sampled coalitions on R (CPU LongTensors)
    tr_model: str,
    k_top: int,
    lam: float,
    alpha: float,
) -> Dict[int, torch.Tensor]:
    """
    Estimate one choice set per player of R from the sampled coalitions.

    For each i in R:
      • compute (Ch_i(S), bucket_i(S)) for every sample S that contains i
      • discard samples whose choice set is empty (i was alone in S)
      • keep the samples that reach the highest remaining bucket
      • intersect their choice sets; if the intersection is empty, fall back to
        the first smallest choice set among them

    Empty choice sets are discarded before the maximum is taken. A coalition
    in which i is alone reports a placeholder bucket of 0, which would
    otherwise outrank every genuine negative-utility bucket and erase the
    estimate for i.

    Returns:
        dict i -> Ch_hat(i, R) as a CPU LongTensor of partner indices; the
        tensor is empty when i never appeared with a partner. Keys follow the
        order of R.
    """
    observed: Dict[int, List[Tuple[int, torch.Tensor]]] = {int(i): [] for i in R.tolist()}

    for S in S_list:
        S_cpu = S.cpu()
        for pos, i in enumerate(S_cpu.tolist()):
            Ch_i_S, bucket = tr_choice_set_and_utility_for_S(
                phi_ij=phi,
                psi_jk=psi,
                S_idx=S_cpu,
                i_pos_in_R=pos,
                model=tr_model,
                k_top=k_top,
                lam=lam,
                alpha=alpha,
            )
            if Ch_i_S.numel() > 0:
                observed[int(i)].append((bucket, Ch_i_S.cpu()))

    Ch_hat: Dict[int, torch.Tensor] = {}
    for i, samples in observed.items():
        if not samples:
            Ch_hat[i] = torch.tensor([], dtype=torch.long)
            continue
        top_bucket = max(b for b, _ in samples)
        candidates = [ch for b, ch in samples if b == top_bucket]

        common: Set[int] = set(candidates[0].tolist())
        for ch in candidates[1:]:
            common &= set(ch.tolist())

        if common:
            Ch_hat[i] = torch.tensor(sorted(common), dtype=torch.long)
        else:
            # min() returns the first of the smallest candidates, i.e. the same
            # element a stable sort by size would put first.
            Ch_hat[i] = min(candidates, key=lambda t: t.numel())

    return Ch_hat


# ─────────────────────  Directed SCC (Tarjan) and closed-set selection  ─────────────────────

def tarjan_scc(adj: List[List[int]]) -> List[List[int]]:
    """
    Strongly connected components of a directed graph (Tarjan's algorithm).

    Args:
        adj: adjacency lists; ``adj[v]`` holds the successors of node v, with
            nodes numbered 0..len(adj)-1.

    Returns:
        One sorted list of node ids per component, in the order the components
        are completed by the depth-first search (a component is emitted only
        after every component reachable from it).

    The search is iterative with an explicit work stack. A recursive version
    needs one Python frame per node on the current DFS path and raises
    RecursionError on long chains, which out-degree-1 graphs with thousands of
    nodes can easily contain.
    """
    n = len(adj)
    order = [-1] * n          # discovery index, -1 = not visited yet
    low = [0] * n             # smallest discovery index reachable
    on_stack = [False] * n
    stack: List[int] = []     # Tarjan's component stack
    sccs: List[List[int]] = []
    counter = 0

    for root in range(n):
        if order[root] != -1:
            continue

        order[root] = low[root] = counter
        counter += 1
        stack.append(root)
        on_stack[root] = True
        work = [(root, iter(adj[root]))]   # DFS frames: (node, remaining successors)

        while work:
            v, successors = work[-1]
            descended = False
            for w in successors:
                if order[w] == -1:
                    # Tree edge: open a frame for w and resume v later.
                    order[w] = low[w] = counter
                    counter += 1
                    stack.append(w)
                    on_stack[w] = True
                    work.append((w, iter(adj[w])))
                    descended = True
                    break
                if on_stack[w]:
                    low[v] = min(low[v], order[w])
            if descended:
                continue

            # All successors of v are processed: close its frame.
            work.pop()
            if low[v] == order[v]:
                comp: List[int] = []
                while True:
                    w = stack.pop()
                    on_stack[w] = False
                    comp.append(w)
                    if w == v:
                        break
                sccs.append(sorted(comp))
            if work:
                parent = work[-1][0]
                low[parent] = min(low[parent], low[v])

    return sccs


def pick_minimal_sink_closed_set(
    R: torch.Tensor,
    Ch_hat: Dict[int, torch.Tensor],
) -> List[int]:
    """
    Choose the next block to peel from the remaining set ``R``.

    A directed graph is built on R with an edge i -> j for every j in
    Ch_hat[i] that is itself in R. Among its strongly connected components,
    those qualify that (a) have no edge leaving the component and (b) contain
    every index listed in the choice sets of their members. The qualifying
    component with the fewest nodes is returned; ties go to the
    lexicographically smallest list of node ids.

    Returns:
        Node ids (values from R) of the chosen component, or ``[min(R)]`` when
        no component qualifies.
    """
    nodes = [int(v) for v in R.tolist()]
    local_of = {node: pos for pos, node in enumerate(nodes)}

    # Raw choice-set targets per local position (a missing entry counts as empty)
    targets = [Ch_hat[node].tolist() if node in Ch_hat else [] for node in nodes]
    # Edges keep only the targets that are still part of R
    adj = [[local_of[t] for t in outs if t in local_of] for outs in targets]

    components = tarjan_scc(adj)

    def qualifies(comp: List[int]) -> bool:
        inside_local = set(comp)
        if any(v not in inside_local for u in comp for v in adj[u]):
            return False          # an edge leaves the component: not a sink
        inside_nodes = {nodes[u] for u in comp}
        # Closed: every listed target, in R or not, lies inside the component
        return all(t in inside_nodes for u in comp for t in targets[u])

    eligible = [comp for comp in components if qualifies(comp)]
    if not eligible:
        return [min(nodes)]

    best = min(eligible, key=lambda comp: (len(comp), [nodes[u] for u in comp]))
    return [nodes[u] for u in best]


# ─────────────────────  PAC Top-Cover main loop  ─────────────────────

def pac_top_cover(
    phi: torch.Tensor,           # (n, n)
    psi: Optional[torch.Tensor], # (n, n) partner-partner synergy (may be None)
    epsilon: float,
    delta: float,
    rng: torch.Generator,
    tr_model: str = "best_friend",
    k_top: int = 2,
    lam: float = 0.0,
    alpha: float = 100.0,
    samples_per_round: int = 2048,
    min_k: int = 2,
    max_k: int = 6,
) -> List[List[int]]:
    """
    Partition the players 0..n-1 by repeated peeling.

    Each round, on the current remaining set R:
      1. sample ``samples_per_round`` coalitions from R,
      2. estimate a choice set for every player of R from those samples,
      3. pick a minimal closed sink component of the induced choice graph,
      4. record it as a block and remove it from R.
    The loop ends when R is empty.

    ``epsilon`` and ``delta`` only feed the printed sample-size guidance; the
    number of samples actually drawn per round is ``samples_per_round``.
    ``rng`` contributes its initial seed to a fresh CPU generator used for all
    sampling (seed 0 if it has no ``initial_seed``).

    The affinity matrices are copied to host memory once up front (half
    precision is widened to float32, which represents every half value
    exactly). The loop below reads individual matrix entries; leaving the
    matrices on an accelerator would make each of those reads a device
    round-trip.

    Returns:
        List of blocks in peeling order, each a sorted list of player indices.
    """
    def to_host(t: torch.Tensor) -> torch.Tensor:
        t = t.detach().cpu()
        if t.dtype in (torch.float16, torch.bfloat16):
            t = t.to(torch.float32)
        return t

    phi = to_host(phi)
    if psi is not None:
        psi = to_host(psi)

    n = phi.size(0)
    R = torch.arange(n)
    parts: List[List[int]] = []

    m_needed = int((2 * n**4 + 2 * n**3) * (1.0 / max(epsilon, 1e-6)) * math.log(max(2 * n**3 / max(delta, 1e-12), 1.0001)))
    print(f"[PAC] Guidance: total samples m_needed ≈ {m_needed:,}. Using {samples_per_round} per round.")

    # CPU generator for sampling coalitions
    g_cpu = torch.Generator().manual_seed(int(rng.initial_seed()) if hasattr(rng, "initial_seed") else 0)

    while R.numel() > 0:
        S_list = sample_coalitions_from_R(R, samples_per_round, g_cpu, min_k=min_k, max_k=max_k)
        print("Coalitions Sampled:", len(S_list), " Remaining neurons:", R.numel())
        Ch_hat = estimate_choice_sets_from_samples(
            R=R, phi=phi, psi=psi, S_list=S_list,
            tr_model=tr_model, k_top=k_top, lam=lam, alpha=alpha
        )
        print("Choice Sets Estimated.")
        remaining = R.tolist()
        empty_rate = sum(1 for i in remaining if Ch_hat[i].numel() == 0) / len(remaining)
        print(f"Empty Choice set rate: {empty_rate:.2%}")

        X_nodes = pick_minimal_sink_closed_set(R, Ch_hat)
        print(f"Picked closed sink set of size {len(X_nodes)}: {X_nodes}")
        parts.append(sorted(X_nodes))

        # Peel the chosen block off the remaining set
        X_set = set(X_nodes)
        R = torch.tensor([i for i in remaining if i not in X_set], dtype=torch.long)

    return parts


# ─────────────────────────────  Main  ──────────────────────────────

def main():
    """
    Command-line entry point.

    Parses the arguments, loads the model and the MS MARCO pairs, builds the
    affinity matrix with the selected valuation (OCA or PAS), runs PAC
    Top-Cover, maps the resulting blocks back to original neuron ids and, on
    rank 0, pickles them to ``--output``.

    Distributed launch: RANK / WORLD_SIZE / LOCAL_RANK are read from the
    environment. The CUDA device is chosen from LOCAL_RANK (falling back to
    RANK when it is unset) so that multi-node launches do not index past the
    GPUs of a node; the process group uses NCCL with CUDA and Gloo without.

    NOTE(review): with WORLD_SIZE > 1 each rank works on its own shard of the
    queries and nothing in this function combines the per-rank results; only
    the partition computed by rank 0 is written.

    Raises:
        ValueError: if ``--fraction`` is outside (0, 1].
    """
    p = argparse.ArgumentParser()
    p.add_argument('--model', default='AnonymousForReview2/watereddown_reranker_mistral_cqtr_mlp_only')
    p.add_argument('--ms_file', default='dataset/top1000.dev')
    p.add_argument('--n_queries', type=int, default=50)
    p.add_argument('--n_docs',    type=int, default=10)

    p.add_argument('--layer',     type=int, default=9)
    p.add_argument('--fraction',  type=float, default=0.2)
    p.add_argument('--valuation', choices=['oca','pas'], default='pas')

    # TR model options
    p.add_argument('--tr_model', choices=['best_friend','topk','synergy2'], default='best_friend')
    p.add_argument('--k_top', type=int, default=2, help='k for top-k TR model')
    p.add_argument('--lambda_synergy', type=float, default=0.0, help='lambda for partner-partner synergy')
    p.add_argument('--alpha_bucket', type=float, default=100.0, help='utility bucket scaling (floor(alpha*u))')

    # PAC parameters
    p.add_argument('--epsilon', type=float, default=0.05)
    p.add_argument('--delta',   type=float, default=0.05)
    p.add_argument('--samples_per_round', type=int, default=4096)
    p.add_argument('--min_k', type=int, default=2)
    p.add_argument('--max_k', type=int, default=6)

    # Tokenization / batching
    p.add_argument('--batch_size', type=int, default=4)
    p.add_argument('--max_len',    type=int, default=256)

    p.add_argument('--seed',       type=int, default=42)
    p.add_argument('--output',     default='partition.pkl')
    args = p.parse_args()

    if not (0. < args.fraction <= 1.):
        raise ValueError('--fraction must be in (0,1]')

    rank  = int(os.getenv('RANK', 0))
    world = int(os.getenv('WORLD_SIZE', 1))
    # The global rank is not a valid GPU index on any node but the first.
    local_rank = int(os.getenv('LOCAL_RANK', rank))
    use_cuda = torch.cuda.is_available()
    device = torch.device(f'cuda:{local_rank}' if use_cuda else 'cpu')
    if use_cuda:
        # Collectives such as barrier() run on the current device.
        torch.cuda.set_device(device)

    if world > 1 and not dist.is_initialized():
        dist.init_process_group(backend='nccl' if use_cuda else 'gloo',
                                rank=rank, world_size=world)

    model, tok = load_rankmistral(args.model, device)
    data       = load_ms_marco_data(args.n_queries, args.n_docs, args.ms_file)

    if world > 1:
        # Round-robin shard of the queries for this rank
        items = list(data.items())
        data = dict(items[rank::world])

    # PAS requires autograd; OCA runs under no_grad. Both share one signature.
    collect = collect_oca_phi if args.valuation == 'oca' else collect_pas_phi
    phi, keep_idx = collect(
        model, tok, data,
        layer=args.layer, device=device,
        fraction=args.fraction, batch=args.batch_size,
        max_len=args.max_len, seed=args.seed
    )
    # The same matrix doubles as partner-partner synergy for the synergy2 model
    psi = phi.clone() if args.tr_model == 'synergy2' else None

    rng = torch.Generator().manual_seed(args.seed + rank)  # CPU generator seed

    parts = pac_top_cover(
        phi=phi,
        psi=psi,
        epsilon=args.epsilon,
        delta=args.delta,
        rng=rng,
        tr_model=args.tr_model,
        k_top=args.k_top,
        lam=args.lambda_synergy,
        alpha=args.alpha_bucket,
        samples_per_round=args.samples_per_round,
        min_k=args.min_k,
        max_k=args.max_k,
    )

    # Map neuron indices back to original neuron ids via keep_idx (CPU tensor)
    original_ids = keep_idx.tolist()
    parts_mapped = [[int(original_ids[i]) for i in comp] for comp in parts]

    if rank == 0:
        Path(args.output).parent.mkdir(parents=True, exist_ok=True)
        with open(args.output, 'wb') as f:
            pickle.dump(parts_mapped, f)
        print(f'✓ Saved {len(parts_mapped)} coalitions → {args.output}')

    if world > 1:
        dist.barrier()
        dist.destroy_process_group()


# Run the CLI only when executed as a script, not when imported.
if "__main__" == __name__:
    main()