#!/usr/bin/env python3
# train_deltanet_imm_mod_stepwise.py
"""
DeltaNet-style baseline for IMM-Mod (Prime + Invertible) dataset with STEPWISE TARGETS.

Dataset (from gen.py):
  src: "T|m|qk|mat1|...|matT"
  tgt: "v1|v2|...|vT"   where v_t = (P_t)[qk] mod m, P_t = M1..Mt (mod m)

Training:
  - Causal sequence model
  - Predict v_t at each step t (token for mat_t)
  - Loss: mean cross-entropy over valid (unpadded) steps

Tokenization / features:
  - We treat each step as a MAT token with 9 integer entries in {-1,0,1}
  - We also condition on (m, qk) by concatenating them to every token as two extra scalar features (scaled as m/100 and qk/8).

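NOTE:
  The model is plain causal linear attention evaluated as a recurrence, with an additive
  state update S_t = S_{t-1} + phi(k_t)^T v_t. It is only "DeltaNet-like" in overall
  structure and does not implement the DeltaNet delta-rule update.
  If you have your own FLA DeltaNet module, you can swap the model block easily.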
"""

from __future__ import annotations

import os
import math
import time
import json
import argparse
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


# -------------------------
# utils
# -------------------------
def set_seed(seed: int) -> None:
    """Seed Python's `random`, NumPy, and PyTorch (CPU plus every CUDA device) with `seed`."""
    import random as _random

    # Same order as before: stdlib, numpy, torch CPU, torch CUDA.
    seeders = (_random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


def parse_src_line(line: str) -> Tuple[int, int, int, List[np.ndarray]]:
    """
    Parse one src line of the form "T|m|qk|mat1|...|matT".

    Each mat field is "a1,...,a9": the 9 entries of a 3x3 matrix in row-major
    order. The generator documents the entries as values in {-1,0,1}; that
    range is not checked here.

    Returns:
      (T, m, qk, mats), where mats is a list of T int64 arrays of shape (3, 3).

    Raises:
      ValueError: if the line has fewer than 4 '|'-separated fields, a header
        field is not an int, the number of matrix fields differs from T, or a
        matrix field is not exactly 9 comma-separated ints.
    """
    parts = line.strip().split("|")
    if len(parts) < 4:
        raise ValueError(f"Bad src line: {line[:200]}")
    T = int(parts[0])
    m = int(parts[1])
    qk = int(parts[2])
    mat_parts = parts[3:]
    if len(mat_parts) != T:
        raise ValueError(f"src T={T} but got {len(mat_parts)} matrices")
    mats: List[np.ndarray] = []
    for s in mat_parts:
        # Strict parsing. np.fromstring(sep=",") stops at the first malformed
        # token with only a DeprecationWarning, so a block like "...,0,9x"
        # would be accepted as 9 ints. int() rejects any malformed token.
        try:
            values = [int(tok) for tok in s.split(",")]
        except ValueError as e:
            raise ValueError(f"Bad matrix block (need 9 ints): {s}") from e
        if len(values) != 9:
            raise ValueError(f"Bad matrix block (need 9 ints): {s}")
        mats.append(np.array(values, dtype=np.int64).reshape(3, 3))
    return T, m, qk, mats


def parse_tgt_line(line: str, T: int) -> List[int]:
    """
    Parse a tgt line "v1|...|vT" into a list of T ints.

    A blank (or whitespace-only) line is treated as zero values.

    Raises:
      ValueError: if the number of '|'-separated fields differs from T, or a
        field is not an int.
    """
    body = line.strip()
    fields = body.split("|") if body else []
    if len(fields) != T:
        raise ValueError(f"tgt length {len(fields)} != T={T}")
    return list(map(int, fields))


def infer_max_T_from_src_path(src_path: str) -> int:
    """
    Return the largest T (first '|'-separated field) over the non-blank lines
    of `src_path`, or 0 if there are none.

    Only the leading field is parsed, so malformed matrix data later on a line
    is ignored. The result never goes below 0.
    """
    longest = 0
    with open(src_path, "r", encoding="utf-8") as handle:
        for row in handle:
            if row.strip():
                longest = max(longest, int(row.split("|", 1)[0]))
    return longest


def autocast_dtype_from_str(s: str) -> torch.dtype:
    """
    Map a user-facing AMP dtype name (case- and whitespace-insensitive) to a
    torch dtype: bf16/bfloat16 -> torch.bfloat16, fp16/float16/half -> torch.float16.

    Raises:
      ValueError: for any other name.
    """
    key = s.lower().strip()
    table = {
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
        "fp16": torch.float16,
        "float16": torch.float16,
        "half": torch.float16,
    }
    dtype = table.get(key)
    if dtype is None:
        raise ValueError(f"Unknown amp dtype: {key}")
    return dtype


# -------------------------
# dataset
# -------------------------
@dataclass
class Sample:
    """One parsed example: T flattened 3x3 matrices and the T per-step targets."""
    T: int             # number of steps (= number of matrices) in this example
    m: int             # modulus field from the src line
    qk: int            # qk field from the src line; per the module docstring the target is entry [qk] of P_t
    mats: np.ndarray   # (T, 9) int64: row t is matrix t flattened in row-major order
    ys: np.ndarray     # (T,) int64: target value for each step


class IMMStepwiseDataset(Dataset):
    """
    In-memory dataset of stepwise IMM-Mod examples read from a parallel pair of
    text files: line i of `src_path` pairs with line i of `tgt_path`.

    Everything is parsed eagerly in the constructor. Pairs whose src line is
    blank are skipped. Besides the samples, the dataset records the largest T
    and the sets of m / qk values seen.
    """
    def __init__(self, src_path: str, tgt_path: str):
        self.src_path = src_path
        self.tgt_path = tgt_path

        self.samples: List[Sample] = []
        self.max_T: int = 0
        self.m_values: set[int] = set()
        self.qk_values: set[int] = set()

        self._preload()

    def _preload(self) -> None:
        """
        Parse both files in lockstep into `self.samples`.

        Raises:
          ValueError: if a line fails to parse (the message is prefixed with the
            file pair and 1-based line number), or if one file still has
            non-blank lines after the other one ends (mismatched src/tgt files).
            Trailing blank lines in either file are tolerated.
          RuntimeError: if no samples were loaded.
        """
        from itertools import zip_longest

        with open(self.src_path, "r", encoding="utf-8") as fsrc, open(self.tgt_path, "r", encoding="utf-8") as ftgt:
            # zip_longest (rather than zip) so a src/tgt line-count mismatch is
            # detected instead of silently dropping the unpaired tail.
            pairs = zip_longest(fsrc, ftgt)
            progress = tqdm(pairs, desc=f"Preload {os.path.basename(self.src_path)}", dynamic_ncols=True)
            for lineno, (sline, tline) in enumerate(progress, start=1):
                if sline is None or tline is None:
                    extra = tline if sline is None else sline
                    if extra.strip():
                        raise ValueError(
                            f"{self.src_path} and {self.tgt_path} have different numbers of lines "
                            f"(unpaired line {lineno})"
                        )
                    continue
                if not sline.strip():
                    continue
                try:
                    T, m, qk, mats_list = parse_src_line(sline)
                    ys_list = parse_tgt_line(tline, T=T)
                except ValueError as e:
                    raise ValueError(f"{self.src_path} / {self.tgt_path} line {lineno}: {e}") from e

                mats = np.stack([M.reshape(-1) for M in mats_list], axis=0).astype(np.int64)  # (T, 9)
                ys = np.asarray(ys_list, dtype=np.int64)  # (T,)

                self.samples.append(Sample(T=T, m=m, qk=qk, mats=mats, ys=ys))
                self.max_T = max(self.max_T, T)
                self.m_values.add(m)
                self.qk_values.add(qk)

        if len(self.samples) == 0:
            raise RuntimeError(f"No samples loaded from {self.src_path}")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Sample:
        return self.samples[idx]


def collate_batch(samples: List[Sample]) -> Dict[str, torch.Tensor]:
    """
    Right-pad a list of samples to the longest T in the batch and stack them.

    Returns a dict (keys in this order):
      x_mats:    (B, T_max, 9) float32  matrix entries; 0 on padded steps
      x_meta:    (B, T_max, 2) float32  (m / 100, qk / 8) on valid steps; 0 on padded steps
      y:         (B, T_max) long        per-step targets; -100 (cross_entropy ignore_index) on padded steps
      attn_mask: (B, T_max) bool        True on valid steps
      m:         (B,) long              modulus of each sample
      T:         (B,) long              unpadded length of each sample
    """
    n_rows = len(samples)
    width = max(s.T for s in samples)

    batch = {
        "x_mats": torch.zeros(n_rows, width, 9, dtype=torch.float32),
        "x_meta": torch.zeros(n_rows, width, 2, dtype=torch.float32),
        "y": torch.full((n_rows, width), -100, dtype=torch.long),
        "attn_mask": torch.zeros(n_rows, width, dtype=torch.bool),
        "m": torch.zeros(n_rows, dtype=torch.long),
        "T": torch.zeros(n_rows, dtype=torch.long),
    }

    for row, sample in enumerate(samples):
        n_steps = sample.T
        valid = slice(0, n_steps)
        batch["x_mats"][row, valid] = torch.from_numpy(sample.mats.astype(np.float32))
        # Fixed (data-independent) scalings: m / 100 and qk / 8.
        batch["x_meta"][row, valid, 0] = float(sample.m) / 100.0
        batch["x_meta"][row, valid, 1] = float(sample.qk) / 8.0
        batch["y"][row, valid] = torch.from_numpy(sample.ys.astype(np.int64))
        batch["attn_mask"][row, valid] = True
        batch["m"][row] = sample.m
        batch["T"][row] = n_steps

    return batch


# -------------------------
# DeltaNet-like block (causal recurrent linear attention)
# -------------------------
class RMSNorm(nn.Module):
    """
    Root-mean-square normalisation over the last dimension:
        y = x / sqrt(mean(x^2) + eps) * weight
    `weight` is a learned per-feature gain initialised to ones. There is no
    bias and no mean-centering.
    """
    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., d); statistics are taken over the last axis only.
        mean_sq = torch.mean(x.pow(2), dim=-1, keepdim=True)
        scale = torch.sqrt(mean_sq + self.eps)
        return x / scale * self.weight


def phi_feature(x: torch.Tensor) -> torch.Tensor:
    """
    Elementwise ELU(x) + 1: a non-negative feature map (positive for moderate
    inputs) applied to queries and keys in linear attention.
    """
    shifted = F.elu(x)
    return shifted + 1.0


class CausalLinearAttention(nn.Module):
    """
    Multi-head causal linear attention evaluated as an explicit recurrence.

    With phi = phi_feature applied to queries and keys, every head keeps a
    running state and computes, for t = 0..T-1:
        S_t   = S_{t-1} + phi(k_t)^T v_t          (head_dim x head_dim)
        z_t   = z_{t-1} + phi(k_t)                (head_dim)
        out_t = phi(q_t) S_t / (phi(q_t) . z_t + 1e-6)
    The state update is purely additive (no delta-rule correction). Head
    outputs are concatenated, passed through dropout, then a bias-free output
    projection.

    Cost note: the Python loop runs T iterations. Each step builds a new S of
    shape (B, H, head_dim, head_dim), and autograd keeps these alive for the
    backward pass when gradients are enabled, so training memory grows
    linearly with T.
    """

    def __init__(self, d_model: int, n_heads: int, head_dim: int, dropout: float):
        super().__init__()
        assert n_heads * head_dim == d_model, "For simplicity, require n_heads*head_dim == d_model"
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = head_dim

        # Submodules are created in this order on purpose: it fixes the order
        # of parameter initialisation (RNG consumption) and the state_dict layout.
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)  # fused q/k/v projection
        self.out = nn.Linear(d_model, d_model, bias=False)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
          x: (B, T, D) input sequence.
          attn_mask: optional (B, T) bool, True on real tokens. Masked positions
            are zeroed in q, k and v. They therefore add nothing to the running
            state, and their own pre-projection output is 0.

        Returns:
          (B, T, D) attention output.
        """
        batch, steps, width = x.shape
        heads, dim = self.n_heads, self.head_dim

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, T, D) -> (B, H, T, Hd)
            return t.view(batch, steps, heads, dim).transpose(1, 2)

        q_flat, k_flat, v_flat = self.qkv(x).chunk(3, dim=-1)
        q = phi_feature(to_heads(q_flat))
        k = phi_feature(to_heads(k_flat))
        v = to_heads(v_flat)

        if attn_mask is not None:
            keep = attn_mask[:, None, :, None].to(dtype=q.dtype)  # (B, 1, T, 1)
            q, k, v = q * keep, k * keep, v * keep

        state = x.new_zeros(batch, heads, dim, dim)  # S: running sum of outer products
        key_sum = x.new_zeros(batch, heads, dim)     # z: running sum of phi(k)
        eps = 1e-6

        per_step = []
        for t in range(steps):
            q_t = q[:, :, t, :]
            k_t = k[:, :, t, :]
            v_t = v[:, :, t, :]
            state = state + torch.einsum("bhm,bhn->bhmn", k_t, v_t)
            key_sum = key_sum + k_t
            numer = torch.einsum("bhm,bhmn->bhn", q_t, state)
            denom = torch.einsum("bhm,bhm->bh", q_t, key_sum)[..., None]  # (B, H, 1)
            per_step.append(numer / (denom + eps))

        merged = torch.stack(per_step, dim=2).transpose(1, 2).contiguous().view(batch, steps, width)
        return self.out(self.drop(merged))


class DeltaNetBlock(nn.Module):
    """
    Pre-norm residual block:
        h   = x + Attn(RMSNorm(x))
        out = h + MLP(RMSNorm(h))
    where MLP = Linear(d, mlp_mult*d) -> GELU -> Dropout -> Linear(mlp_mult*d, d) -> Dropout.
    """
    def __init__(self, d_model: int, n_heads: int, dropout: float, mlp_mult: int = 4):
        super().__init__()
        assert d_model % n_heads == 0
        hidden = mlp_mult * d_model

        # Creation order (norm1, attn, norm2, mlp) and the Sequential indices
        # 0..4 are kept stable: they determine seeded init and state_dict keys.
        self.norm1 = RMSNorm(d_model)
        self.attn = CausalLinearAttention(d_model, n_heads=n_heads, head_dim=d_model // n_heads, dropout=dropout)
        self.norm2 = RMSNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        attended = x + self.attn(self.norm1(x), attn_mask=attn_mask)
        return attended + self.mlp(self.norm2(attended))


class DeltaNetStepwiseClassifier(nn.Module):
    """
    Per-step classifier:
      Linear(d_in -> d_model) + dropout, then `n_layers` DeltaNetBlock layers,
      then a final RMSNorm and a Linear head producing `n_classes` logits at
      every position.
    """
    def __init__(self, d_in: int, d_model: int, n_layers: int, n_heads: int, dropout: float, n_classes: int):
        super().__init__()
        # Creation order (in_proj, drop, blocks, norm, head) is kept stable so
        # seeded initialisation and state_dict keys do not change.
        self.in_proj = nn.Linear(d_in, d_model)
        self.drop = nn.Dropout(dropout)
        layers = [DeltaNetBlock(d_model, n_heads=n_heads, dropout=dropout) for _ in range(n_layers)]
        self.blocks = nn.ModuleList(layers)
        self.norm = RMSNorm(d_model)
        self.head = nn.Linear(d_model, n_classes)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        x: (B, T, d_in) per-step features; attn_mask: optional (B, T) bool.
        Returns logits of shape (B, T, n_classes).
        """
        hidden = self.drop(self.in_proj(x))
        for layer in self.blocks:
            hidden = layer(hidden, attn_mask=attn_mask)
        return self.head(self.norm(hidden))


# -------------------------
# train / eval
# -------------------------
@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> Dict[str, float]:
    """
    Compute token-level mean cross-entropy and accuracy of `model` over `loader`.

    Each batch must be a dict with "x_mats", "x_meta", "y" and "attn_mask"
    (as produced by `collate_batch`). Targets equal to -100 are ignored.

    The model runs in eval mode, and its previous train/eval mode is restored
    before returning, even on error. Calling this mid-training therefore does
    not leave dropout disabled for the remaining steps.

    Returns:
      {"loss": summed CE / number of non-ignored targets,
       "acc":  fraction of non-ignored targets whose argmax matches y}.
      Both are 0.0 if the loader yields no non-ignored targets.
    """
    was_training = model.training
    model.eval()

    loss_sum = 0.0
    correct = 0
    n_valid = 0
    try:
        for batch in loader:
            x_mats = batch["x_mats"].to(device)
            x_meta = batch["x_meta"].to(device)
            y = batch["y"].to(device)
            mask = batch["attn_mask"].to(device)

            x = torch.cat([x_mats, x_meta], dim=-1)  # (B,T,11)
            logits = model(x, attn_mask=mask)         # (B,T,C)

            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)), y.reshape(-1), ignore_index=-100, reduction="sum"
            )
            loss_sum += loss.item()

            # Count exactly the targets the summed loss covered (y != -100), so
            # loss and accuracy share one denominator.
            valid = y != -100
            correct += (logits.argmax(dim=-1)[valid] == y[valid]).sum().item()
            n_valid += int(valid.sum().item())
    finally:
        model.train(was_training)

    denom = max(1, n_valid)
    return {"loss": loss_sum / denom, "acc": correct / denom}


def main() -> None:
    """
    Command-line entry point: train the stepwise classifier on an IMM-Mod dataset.

    Reads {train, val_bin0, test_bin0, test_bin1, test_bin2}_{src,tgt}.txt from
    --data_dir and trains for --max_steps optimizer steps, re-iterating the
    shuffled training loader as needed. Evaluates every --eval_every steps and
    at the last step, and saves a checkpoint to --save_path whenever val_bin0
    accuracy improves.
    """
    ap = argparse.ArgumentParser()

    ap.add_argument("--data_dir", type=str, required=True)
    ap.add_argument("--max_mod", type=int, default=64, help="num classes C; must be >= m. (default 64)")

    # model
    ap.add_argument("--d_model", type=int, default=256)
    ap.add_argument("--layers", type=int, default=2)
    ap.add_argument("--heads", type=int, default=4)
    ap.add_argument("--dropout", type=float, default=0.0)

    # training
    ap.add_argument("--batch_size", type=int, default=256)
    ap.add_argument("--lr", type=float, default=1e-4)
    ap.add_argument("--weight_decay", type=float, default=0.01)
    ap.add_argument("--grad_clip", type=float, default=1.0)
    ap.add_argument("--max_steps", type=int, default=30_000)
    ap.add_argument("--eval_every", type=int, default=500)
    ap.add_argument("--save_path", type=str, default="ckpt_deltanet_imm_mod_stepwise.pt")

    # system
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--num_workers", type=int, default=2)
    ap.add_argument("--cuda", action="store_true")

    # amp
    ap.add_argument("--amp", action="store_true")
    ap.add_argument("--amp_dtype", type=str, default="bf16", choices=["bf16", "fp16"])

    args = ap.parse_args()
    # eval_every is used as a modulus below; 0 would raise ZeroDivisionError.
    if args.eval_every <= 0:
        ap.error("--eval_every must be a positive integer")

    set_seed(args.seed)
    device = torch.device("cuda" if (args.cuda and torch.cuda.is_available()) else "cpu")
    amp_dtype = autocast_dtype_from_str(args.amp_dtype)

    # paths
    def p(name: str) -> str:
        return os.path.join(args.data_dir, name)

    train_ds = IMMStepwiseDataset(p("train_src.txt"), p("train_tgt.txt"))
    val_ds   = IMMStepwiseDataset(p("val_bin0_src.txt"), p("val_bin0_tgt.txt"))
    test0_ds = IMMStepwiseDataset(p("test_bin0_src.txt"), p("test_bin0_tgt.txt"))
    test1_ds = IMMStepwiseDataset(p("test_bin1_src.txt"), p("test_bin1_tgt.txt"))
    test2_ds = IMMStepwiseDataset(p("test_bin2_src.txt"), p("test_bin2_tgt.txt"))

    # sanity: m typically fixed
    all_m = sorted(train_ds.m_values)
    all_qk = sorted(train_ds.qk_values)
    print(f"[Data] train: {len(train_ds)} samples, max_T={train_ds.max_T}, m_values={all_m}, qk_values={all_qk}")

    # Targets are residues mod m, so the head needs at least max(m) classes.
    # Otherwise cross_entropy fails on the first out-of-range label (on CUDA as
    # an opaque device-side assert).
    largest_m = max(max(ds.m_values) for ds in (train_ds, val_ds, test0_ds, test1_ds, test2_ds))
    if largest_m > args.max_mod:
        ap.error(f"--max_mod={args.max_mod} is smaller than the largest modulus in the data (m={largest_m})")

    pin = device.type == "cuda"

    def make_loader(ds: Dataset, train: bool = False) -> DataLoader:
        # Training: shuffled, ragged last batch dropped. Eval: fixed order, every sample kept.
        return DataLoader(ds, batch_size=args.batch_size, shuffle=train,
                          num_workers=args.num_workers, pin_memory=pin,
                          collate_fn=collate_batch, drop_last=train)

    train_loader = make_loader(train_ds, train=True)
    if len(train_loader) == 0:
        # With drop_last=True an undersized training set yields no batches,
        # and the loop below would die with a bare StopIteration.
        ap.error(f"--batch_size={args.batch_size} exceeds the number of training samples ({len(train_ds)})")
    val_loader = make_loader(val_ds)
    test0_loader = make_loader(test0_ds)
    test1_loader = make_loader(test1_ds)
    test2_loader = make_loader(test2_ds)

    # model
    d_in = 9 + 2  # mats + (m,qk)
    model = DeltaNetStepwiseClassifier(
        d_in=d_in,
        d_model=args.d_model,
        n_layers=args.layers,
        n_heads=args.heads,
        dropout=args.dropout,
        n_classes=args.max_mod,
    ).to(device)

    n_params = sum(prm.numel() for prm in model.parameters())
    print(f"[Model] params={n_params/1e6:.2f}M, d_in={d_in}, d_model={args.d_model}, layers={args.layers}, heads={args.heads}")

    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Loss scaling is only needed for fp16; bf16 has fp32's exponent range.
    scaler = torch.cuda.amp.GradScaler(enabled=(args.amp and device.type == "cuda" and amp_dtype == torch.float16))

    # training loop
    model.train()
    best_val_acc = -1.0
    step = 0
    t0 = time.time()

    train_iter = iter(train_loader)

    while step < args.max_steps:
        try:
            batch = next(train_iter)
        except StopIteration:
            # Epoch boundary: start a fresh (reshuffled) pass over the training set.
            train_iter = iter(train_loader)
            batch = next(train_iter)

        x_mats = batch["x_mats"].to(device, non_blocking=True)
        x_meta = batch["x_meta"].to(device, non_blocking=True)
        y = batch["y"].to(device, non_blocking=True)
        mask = batch["attn_mask"].to(device, non_blocking=True)

        x = torch.cat([x_mats, x_meta], dim=-1)

        opt.zero_grad(set_to_none=True)

        if args.amp and device.type == "cuda":
            with torch.autocast(device_type="cuda", dtype=amp_dtype):
                logits = model(x, attn_mask=mask)
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-100)
            if scaler.is_enabled():
                scaler.scale(loss).backward()
                if args.grad_clip > 0:
                    # Gradients must be unscaled before clipping to the true norm.
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
                scaler.step(opt)
                scaler.update()
            else:
                loss.backward()
                if args.grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
                opt.step()
        else:
            logits = model(x, attn_mask=mask)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-100)
            loss.backward()
            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            opt.step()

        step += 1

        if step % 50 == 0:
            dt = time.time() - t0
            print(f"[Train] step={step:6d} loss={loss.item():.4f} time={dt:.1f}s")
            t0 = time.time()

        if step % args.eval_every == 0 or step == args.max_steps:
            val = evaluate(model, val_loader, device=device)
            test0 = evaluate(model, test0_loader, device=device)
            test1 = evaluate(model, test1_loader, device=device)
            test2 = evaluate(model, test2_loader, device=device)
            # evaluate() puts the model in eval mode; switch back so dropout is
            # active for the remaining training steps.
            model.train()

            msg = {
                "step": step,
                "val_bin0": val,
                "test_bin0": test0,
                "test_bin1": test1,
                "test_bin2": test2,
            }
            print("[Eval]", json.dumps(msg, indent=2))

            if val["acc"] > best_val_acc:
                best_val_acc = val["acc"]
                ckpt = {
                    "step": step,
                    "args": vars(args),
                    "model_state": model.state_dict(),
                    "opt_state": opt.state_dict(),
                    "best_val_acc": best_val_acc,
                }
                torch.save(ckpt, args.save_path)
                print(f"[CKPT] saved to {args.save_path} (best_val_acc={best_val_acc:.4f})")

    print("[Done]")


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
