#!/usr/bin/env python
# probe_first_segment.py
# ------------------------------------------------------------------------
#   Probe CAMAE‑style encoders (+ InterTTM) with six linear heads — one per stream (1, 2) for each of:
#   • LOCAL   : top‑K tokens  → column‑id                 (n_cols‑way)
#   • GLOBAL  : top‑K tokens  → own‑stream energy‑bin     (n_bins‑way)
#   • CROSS   : top‑K tokens  → other‑stream energy‑bin   (n_bins‑way)
# ------------------------------------------------------------------------
import argparse, os, torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader
import wandb

# ─────────────────────────── project imports ────────────────────────────
from models.ast import CAMAEEncoder
from models.vittm.laddersym import vittm_base_patch16_224, laddersym
from dataset.dataset_2_random import Dataset

# ------------------------------------------------------------------------
#                           PRESET DICTIONARY
# ------------------------------------------------------------------------
# Preset registry consumed by load_encoder().
#   checkpoint_path : file passed to torch.load().
#   constructor     : encoder factory; load_encoder() calls it with
#                     pretrained=False for "inter", otherwise with
#                     total_depth=12, modality_specific_depth=mod_depth.
#   feat_fn_ctor    : enc -> ((a1, a2) -> (z1, z2)). The inner lambdas look up
#                     get_feats_* only when called, so it is fine that those
#                     helpers are defined further down in this file.
#   mod_depth       : number of modality-specific blocks for CAMAE presets;
#                     not used for "inter" (None).
PRESETS = {
    # name : (checkpoint_path,      constructor,            feat_fn_ctor, mod_depth)
    "poly": (
        "/home/code/outputs/Poly_updated/23-02-29/"
        "polytuneNet_MAESTRO/l2524hbm/checkpoints/best.ckpt",
        CAMAEEncoder,
        lambda e: lambda a1, a2: get_feats_camae(e, a1, a2),
        11,  # 11 modality-specific blocks out of total_depth=12
    ),
    "full": (
        "/home/code/outputs/2025-03-26/12-33-57/"
        "polytuneNet_MAESTRO/lam5ftyv/checkpoints/last.ckpt",
        CAMAEEncoder,
        lambda e: lambda a1, a2: get_feats_camae(e, a1, a2),
        0,  # no modality-specific blocks
    ),
    "inter": (
        "/home/code/outputs/2025-04-23/11-01-01/"
        "interttm_MT3Net_MAESTRO/2tx99a3d/checkpoints/best.ckpt",
        vittm_base_patch16_224,  # returns InterTTM
        lambda e: lambda a1, a2: get_feats_inter(e, a1, a2),
        None,  # unused for InterTTM
    ),
}


# ------------------------------------------------------------------------
#                         LINEAR PROBE HEADS
# ------------------------------------------------------------------------
class ColumnProbe(nn.Module):  # LOCAL
    """Linear read-out that maps each token feature to column-id logits.

    Input ``(B, K, d)`` token features, output ``(B, K, n_cols)`` logits.
    """

    def __init__(self, d, n_cols):
        super(ColumnProbe, self).__init__()
        self.fc = nn.Linear(in_features=d, out_features=n_cols)

    def forward(self, z):
        logits = self.fc(z)
        return logits


class GlobalProbe(nn.Module):  # GLOBAL / CROSS
    """Linear read-out that maps each token feature to energy-bin logits.

    Input ``(B, K, d)`` token features, output ``(B, K, n_bins)`` logits.
    """

    def __init__(self, d, n_bins):
        super(GlobalProbe, self).__init__()
        self.fc = nn.Linear(in_features=d, out_features=n_bins)

    def forward(self, z):
        logits = self.fc(z)
        return logits


# ------------------------------------------------------------------------
#                    ENCODER → TOKEN‑FEATURE HELPERS
# ------------------------------------------------------------------------
@torch.no_grad()
def get_feats_camae(enc: "CAMAEEncoder", a1, a2):
    """Run the whole CAMAE encoder and return per-stream token features.

    Returns ``(z1, z2)``, each ``(B, P, d)`` (768 for the presets here), taken
    after the final ``enc.norm`` and BEFORE ``enc.proj``.

    Pipeline:
      1. patch-embed + positional + modality embeddings per stream
      2. the modality-specific stacks ``blocks_a1`` / ``blocks_a2`` (may be empty)
      3. the joint stack ``blocks_u`` over the concatenated sequence
      4. ``enc.norm``
      5. a split back into the two streams

    a1, a2 : spectrogram batches, presumably (B, T, F) as produced by
             ``FirstSeg`` — TODO confirm against ``enc.patch_embed_*``.
    """
    # (B, T, F) -> (B, 1, F, T): single-channel "image" for the patch embedder
    a1 = enc.patch_embed_a1(a1.unsqueeze(1).transpose(2, 3))
    a2 = enc.patch_embed_a2(a2.unsqueeze(1).transpose(2, 3))
    a1 = a1 + enc.pos_embed_a1 + enc.modality_a1
    a2 = a2 + enc.pos_embed_a2 + enc.modality_a2

    # Modality-specific stacks (empty when modality_specific_depth == 0).
    for blk in enc.blocks_a1:
        a1 = blk(a1)
    for blk in enc.blocks_a2:
        a2 = blk(a2)

    # The joint stack and final norm run in every configuration, so the probed
    # tokens always come after all modality-specific AND joint blocks — this
    # is where cross-stream information can enter (relevant for CROSS heads).
    n1 = a1.size(1)  # split point: number of stream-1 tokens
    x = torch.cat([a1, a2], dim=1)
    for blk in enc.blocks_u:
        x = blk(x)
    x = enc.norm(x)
    return x[:, :n1], x[:, n1:]


@torch.no_grad()
def get_feats_inter(enc: nn.Module, a1, a2):
    """Return InterTTM token features for both streams.

    ``enc.forward_features`` already returns the two feature tensors
    (memory, process), documented here as (B, P, 768) each. This wrapper only
    disables cross-attention visualisation and runs without autograd.
    """
    # ``InterTTM`` is not imported in this module; annotating with it would
    # raise NameError when the module is loaded, so use the base class.
    return enc.forward_features(a1, a2, visualize_ca=False)


# ------------------------------------------------------------------------
#                     LOAD ENCODER  (skip decoder proj.*)
# ------------------------------------------------------------------------
def load_encoder(preset: str, device="cpu"):
    """Build the encoder for ``preset``, load its checkpoint and freeze it.

    Parameters
    ----------
    preset : key of ``PRESETS`` ("poly", "full" or "inter").
    device : device the frozen, eval-mode encoder is moved to.

    Returns
    -------
    (feat_fn, n_cols)
        feat_fn : callable ``(a1, a2) -> (z1, z2)`` producing token features.
        n_cols  : number of patch columns, used for the LOCAL column labels.

    Raises
    ------
    ValueError   : ``preset`` is not in ``PRESETS``.
    RuntimeError : the checkpoint contains no ``model.encoder.*`` weights.
    """
    if preset not in PRESETS:
        raise ValueError(f"Unknown preset {preset!r}. Choose from {list(PRESETS)}.")

    ckpt_path, ctor, feat_ctor, mod_depth = PRESETS[preset]
    enc = (
        ctor(pretrained=False)
        if preset == "inter"
        else ctor(total_depth=12, modality_specific_depth=mod_depth)
    )

    # NOTE(review): on torch >= 2.6 torch.load defaults to weights_only=True,
    # which may reject checkpoints containing non-tensor objects — confirm.
    sd_raw = torch.load(ckpt_path, map_location="cpu")
    # Checkpoints that wrap weights under "state_dict" are unwrapped; plain
    # state dicts are used as-is.
    sd_raw = sd_raw.get("state_dict", sd_raw)

    prefix = "model.encoder."
    enc_sd = {}
    for key, value in sd_raw.items():
        if not key.startswith(prefix):
            continue
        name = key[len(prefix):]  # strip only the leading prefix
        # Skip the encoder's top-level ``proj.*``; features are read before it.
        if name.split(".", 1)[0] == "proj":
            continue
        enc_sd[name] = value

    if not enc_sd:
        raise RuntimeError(
            f"No '{prefix}*' weights found in {ckpt_path!r}; "
            "refusing to probe a randomly initialised encoder."
        )

    missing, unexpected = enc.load_state_dict(enc_sd, strict=False)
    # proj.* was dropped on purpose above, so it is expected to be missing.
    missing = [k for k in missing if k.split(".", 1)[0] != "proj"]
    if missing or unexpected:
        print(
            f"[load_encoder] {preset}: {len(missing)} missing / "
            f"{len(unexpected)} unexpected keys "
            f"(e.g. missing={missing[:3]}, unexpected={unexpected[:3]})"
        )

    enc.to(device).eval()
    for p in enc.parameters():
        p.requires_grad_(False)

    feat_fn = feat_ctor(enc)

    if isinstance(enc, CAMAEEncoder):
        # Patches form an (n_rows x n_cols) grid with rows along the mel axis.
        n_rows = enc.mel_bins // enc.patch_embed_a1.patch_size[0]
        n_cols = enc.patch_embed_a1.num_patches // n_rows
    else:  # InterTTM: 14×14 patch grid (224 / 16)
        n_cols = 14
    return feat_fn, n_cols


# ------------------------------------------------------------------------
#                DATASET WRAPPER – only first spectrogram slice
# ------------------------------------------------------------------------
class FirstSeg(Dataset):
    """``Dataset`` view that keeps only the first slice of each of two streams.

    The parent ``__getitem__`` must return at least two indexable items; any
    further items are discarded.
    """

    def __getitem__(self, idx):
        stream1, stream2, *_extra = super().__getitem__(idx)
        # first spectrogram slice of each stream, (T, F) per the original note
        first1 = stream1[0]
        first2 = stream2[0]
        return first1, first2


def collate_firstseg(batch):
    """Stack a list of ``(a1, a2)`` pairs into two batched tensors ``(A1, A2)``."""
    first_stream, second_stream = zip(*batch)
    stacked = (torch.stack(first_stream), torch.stack(second_stream))
    return stacked


# ------------------------------------------------------------------------
#                 PER‑CLIP PEAK‑ENERGY BIN  (fine‑grained)
# ------------------------------------------------------------------------
def peak_bin(x: torch.Tensor, n_bins: int):
    """Assign every clip in a batch to one of up to ``n_bins`` energy bins.

    x : (B, T, F)   → returns (B,) int64 energy‑bin labels in [0, n_bins)

    Score per clip is the max over the last axis, averaged over axis 1. Bin
    edges are the quantiles of the scores *within this batch*, so labels are
    batch-relative. Duplicate edges are merged, so fewer than ``n_bins``
    distinct labels can occur. If every score is equal, all clips get label 0.
    """
    peak = x.max(dim=-1).values  # (B, T)
    # torch.quantile only supports float32/float64 and requires ``q`` to share
    # the input dtype; promote anything else (e.g. float16 / integer) first.
    if peak.dtype not in (torch.float32, torch.float64):
        peak = peak.float()
    score = peak.mean(dim=1)  # (B,)
    q = torch.linspace(0, 1, n_bins + 1, device=score.device, dtype=score.dtype)
    bounds = torch.quantile(score, q).unique()
    if bounds.numel() <= 2:  # degenerate
        return torch.zeros_like(score, dtype=torch.long)
    # interior edges only: labels run from 0 to len(bounds) - 2
    return torch.bucketize(score, bounds[1:-1])  # (B,)


# ------------------------------------------------------------------------
#                        TESTER  (top‑K tokens)
# ------------------------------------------------------------------------
class Tester(nn.Module):
    """Six linear probe heads applied to the top-K highest-norm tokens per stream.

    Heads
    -----
      loc1 / loc2   : column id of *own* stream           (LOCAL)
      glob1 / glob2 : energy bin of *own* stream          (GLOBAL)
      cross1 /2     : energy bin of *other* stream        (CROSS)

    ``forward`` only extracts tokens and their column labels; callers apply
    the heads via ``local_logits`` / ``own_logits`` / ``cross_logits``.

    Parameters
    ----------
    feat_fn : callable ``(a1, a2) -> (z1, z2)``, each ``(B, P, d_model)``.
    n_cols  : number of patch columns; token index ``p`` is labelled
              ``p % n_cols`` (assumes row-major patch order — TODO confirm
              for every encoder).
    n_bins  : number of energy-bin classes for the GLOBAL / CROSS heads.
    top_k   : tokens kept per stream, ranked by L2 norm; must be >= 1.
    d_model : token feature width (768 matches the features of the presets).

    Raises
    ------
    ValueError : ``top_k`` < 1.
    """

    def __init__(self, feat_fn, n_cols, n_bins, top_k=1, d_model=768):
        super().__init__()
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        d = d_model
        self.feat_fn, self.n_cols, self.top_k = feat_fn, n_cols, top_k

        # heads
        self.loc1 = ColumnProbe(d, n_cols)
        self.loc2 = ColumnProbe(d, n_cols)
        self.glob1 = GlobalProbe(d, n_bins)
        self.glob2 = GlobalProbe(d, n_bins)
        self.cross1 = GlobalProbe(d, n_bins)
        self.cross2 = GlobalProbe(d, n_bins)

    def forward(self, a1, a2):
        """Pick the top-K tokens per stream and their column labels.

        Returns a dict with ``tok1`` / ``tok2`` of shape (B, K, d) and
        ``col_lbl1`` / ``col_lbl2`` of shape (B, K).
        """
        z1, z2 = self.feat_fn(a1, a2)  # (B,P,d)
        # pick top‑K by L2‑norm
        idx1 = z1.norm(dim=-1).topk(self.top_k, dim=1).indices  # (B,K)
        idx2 = z2.norm(dim=-1).topk(self.top_k, dim=1).indices
        tok1 = z1.gather(1, idx1[..., None].expand(-1, -1, z1.size(-1)))  # (B,K,d)
        tok2 = z2.gather(1, idx2[..., None].expand(-1, -1, z2.size(-1)))
        # flat patch index -> column index within the patch grid
        col_lbl1 = idx1 % self.n_cols  # (B,K)
        col_lbl2 = idx2 % self.n_cols
        return dict(tok1=tok1, tok2=tok2, col_lbl1=col_lbl1, col_lbl2=col_lbl2)

    # split out heads for clarity
    def local_logits(self, t1, t2):
        """Column-id logits for each stream's own tokens."""
        return self.loc1(t1), self.loc2(t2)

    def own_logits(self, t1, t2):
        """Own-stream energy-bin logits."""
        return self.glob1(t1), self.glob2(t2)

    def cross_logits(self, t1, t2):
        """Energy-bin logits that are trained against the *other* stream's label."""
        return self.cross1(t1), self.cross2(t2)


# ------------------------------------------------------------------------
#                          TRAIN / LOG FUNCTION
# ------------------------------------------------------------------------
def run_probe(
    preset,
    device="cpu",
    *,
    n_bins=12,
    top_k=5,
    wandb_project=None,
    wandb_entity=None,
    wandb_run=None,
    epochs=25,
    lr=1e-3,
    batch_size=1024,
    num_workers=32,
):
    """Fit the six linear probes on frozen ``preset`` features and report accuracy.

    Each epoch prints the mean training loss and the per-head top-1 accuracy
    over all ``B * top_k`` probed tokens. The same metrics are logged to W&B
    when ``wandb_project`` is set.

    NOTE: accuracies are measured on the same training batches the heads are
    fitted on; there is no held-out split. They therefore indicate how
    linearly decodable each target is, not generalisation. GLOBAL/CROSS
    labels come from ``peak_bin``, whose bin edges are recomputed per batch.

    Parameters
    ----------
    preset       : key of ``PRESETS``.
    device       : torch device for the encoder, heads and batches.
    n_bins       : energy-bin classes for the GLOBAL / CROSS heads.
    top_k        : highest-norm tokens probed per stream.
    wandb_project, wandb_entity, wandb_run : W&B settings; logging is
                   disabled when ``wandb_project`` is falsy.
    epochs       : number of passes over the training loader.
    lr           : Adam learning rate for the probe heads.
    batch_size, num_workers : DataLoader settings.

    Raises
    ------
    RuntimeError : the training loader yields no batches.
    """
    feat_fn, n_cols = load_encoder(preset, device)
    model = Tester(feat_fn, n_cols, n_bins, top_k).to(device)

    opt = optim.Adam(model.parameters(), lr)
    ce = nn.CrossEntropyLoss()

    def flat(x):
        # (B, K, C) -> (B*K, C), the 2-D layout CrossEntropyLoss expects
        return x.reshape(-1, x.size(-1))

    if wandb_project:
        wandb.init(
            project=wandb_project,
            entity=wandb_entity,
            name=wandb_run,
            config=dict(
                preset=preset,
                n_bins=n_bins,
                top_k=top_k,
                n_cols=n_cols,
                epochs=epochs,
                lr=lr,
                batch_size=batch_size,
            ),
        )

    # NOTE(review): root_dir / split_json_path are empty strings — presumably
    # placeholders that must point at the dataset before running.
    train_ds = FirstSeg(
        root_dir="",
        split="train",
        split_json_path="",
        mel_length=256,
        event_length=1024,
        midi_folder="MIDI",
        audio_filename="mix_16k.wav",
        num_rows_per_batch=1,
        split_frame_length=2000,
        is_randomize_tokens=False,
        is_random_alignment_shift_augmentation=False,
        use_prompt=False,
        skip_build=False,
        is_deterministic=True,
    )
    loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_firstseg,
        pin_memory=True,
    )

    # --------------------------------------------------------------------
    for ep in range(epochs):
        counters = dict(loc1=0, loc2=0, glob1=0, glob2=0, cross1=0, cross2=0, tot=0)
        loss_sum, n_steps = 0.0, 0
        for a1, a2 in loader:
            a1, a2 = a1.to(device), a2.to(device)
            outs = model(a1, a2)

            # labels: one bin per clip, repeated for each of its K tokens
            glob1_lbl = peak_bin(a1, n_bins).unsqueeze(1).expand(-1, top_k)  # (B,K)
            glob2_lbl = peak_bin(a2, n_bins).unsqueeze(1).expand(-1, top_k)

            # logits
            ll1, ll2 = model.local_logits(outs["tok1"], outs["tok2"])
            own1, own2 = model.own_logits(outs["tok1"], outs["tok2"])
            cr1, cr2 = model.cross_logits(outs["tok1"], outs["tok2"])

            # loss: mean of the six head losses
            loss = (
                ce(flat(ll1), outs["col_lbl1"].reshape(-1))
                + ce(flat(ll2), outs["col_lbl2"].reshape(-1))
                + ce(flat(own1), glob1_lbl.reshape(-1))
                + ce(flat(own2), glob2_lbl.reshape(-1))
                + ce(flat(cr1), glob2_lbl.reshape(-1))  # cross: other stream's bin
                + ce(flat(cr2), glob1_lbl.reshape(-1))
            ) / 6
            opt.zero_grad()
            loss.backward()
            opt.step()
            loss_sum += loss.item()
            n_steps += 1

            # accuracies ----------------------------------------------------
            counters["loc1"] += (ll1.argmax(-1) == outs["col_lbl1"]).sum().item()
            counters["loc2"] += (ll2.argmax(-1) == outs["col_lbl2"]).sum().item()
            counters["glob1"] += (own1.argmax(-1) == glob1_lbl).sum().item()
            counters["glob2"] += (own2.argmax(-1) == glob2_lbl).sum().item()
            counters["cross1"] += (cr1.argmax(-1) == glob2_lbl).sum().item()
            counters["cross2"] += (cr2.argmax(-1) == glob1_lbl).sum().item()
            counters["tot"] += a1.size(0) * top_k

        if n_steps == 0:
            raise RuntimeError(
                f"[{preset}] training loader yielded no batches; "
                "check the dataset paths / split"
            )

        # aggregate ---------------------------------------------------------
        tot = counters["tot"]
        metrics = {
            "epoch": ep + 1,
            "loss": loss_sum / n_steps,  # mean over the epoch
            "local1_acc": counters["loc1"] / tot,
            "local2_acc": counters["loc2"] / tot,
            "global1_acc": counters["glob1"] / tot,
            "global2_acc": counters["glob2"] / tot,
            "cross1_acc": counters["cross1"] / tot,
            "cross2_acc": counters["cross2"] / tot,
        }

        print(
            f"[{preset}] ep{ep+1:02d} | "
            f"L1 {metrics['local1_acc']:.3f} L2 {metrics['local2_acc']:.3f} | "
            f"G1 {metrics['global1_acc']:.3f} G2 {metrics['global2_acc']:.3f} | "
            f"X1 {metrics['cross1_acc']:.3f} X2 {metrics['cross2_acc']:.3f}"
        )

        if wandb_project:
            wandb.log(metrics)

    if wandb_project:
        wandb.finish()


# ------------------------------------------------------------------------
#                                 CLI
# ------------------------------------------------------------------------
if __name__ == "__main__":
    # Command-line entry point: parse options and run a single probe.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--preset",
        required=True,
        choices=list(PRESETS),
        help="poly | full | inter",
    )
    cli.add_argument("--device", default="cpu")
    cli.add_argument(
        "--n_bins", type=int, default=12, help="energy‑bin granularity"
    )
    cli.add_argument(
        "--top_k",
        type=int,
        default=5,
        help="number of highest‑norm tokens per stream",
    )
    for wandb_flag in ("--wandb_project", "--wandb_entity", "--wandb_run"):
        cli.add_argument(wandb_flag)
    opts = cli.parse_args()

    run_probe(
        opts.preset,
        opts.device,
        n_bins=opts.n_bins,
        top_k=opts.top_k,
        wandb_project=opts.wandb_project,
        wandb_entity=opts.wandb_entity,
        wandb_run=opts.wandb_run,
    )
