import sys
from pathlib import Path

# Point this at your SMDM repo root, either by editing the default below or by
# setting the SMDM_REPO_ROOT environment variable. The default Path("") renders
# as "." (the current working directory), so `lit_gpt` is importable when the
# script is launched from the repo root.
import os

REPO_ROOT = Path(os.environ.get("SMDM_REPO_ROOT", ""))
sys.path.insert(0, str(REPO_ROOT))

# ------------------------------
# Optional flash_attn stub (robust version)
# ------------------------------
import types, importlib.util, importlib.abc
import math
import torch
import torch.nn.functional as F

def _fa_flash_stub(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, **kwargs):
    if hasattr(F, "scaled_dot_product_attention"):
        return F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=dropout_p if torch.is_grad_enabled() else 0.0,
            is_causal=causal,
        )
    d = q.size(-1)
    scale = (1.0 / math.sqrt(d)) if softmax_scale is None else softmax_scale
    scores = (q @ k.transpose(-2, -1)) * scale
    if causal:
        mask = torch.ones_like(scores, dtype=torch.bool).tril()
        scores = scores.masked_fill(~mask, float("-inf"))
    P = scores.softmax(dim=-1)
    if torch.is_grad_enabled() and dropout_p > 0:
        P = F.dropout(P, p=dropout_p)
    return P @ v

def _flash_attn_importable() -> bool:
    """Return True if the real flash-attn, including its loss module, imports."""
    try:
        import flash_attn  # noqa: F401
        import flash_attn.losses.cross_entropy  # noqa: F401
    except (ImportError, OSError):
        # Not installed, or installed but failing to import.
        return False
    return True


# Install the stub only when the real package cannot be used. Checking
# sys.modules alone is not enough: an installed flash-attn that simply has not
# been imported yet would otherwise be shadowed by the stub.
if not _flash_attn_importable():
    # Forget any partially imported real modules so the stubs below are used.
    for _mod_name in [m for m in sys.modules if m == "flash_attn" or m.startswith("flash_attn.")]:
        del sys.modules[_mod_name]

    class _MinimalLoader(importlib.abc.Loader):
        """No-op loader so the stub modules can carry a valid ``__spec__``."""
        def create_module(self, spec): return None
        def exec_module(self, module): pass

    def _make_stub_module(name, is_package):
        """Create an empty module with a spec (and ``__path__`` for packages)."""
        module = types.ModuleType(name)
        loader = _MinimalLoader()
        module.__spec__ = importlib.util.spec_from_loader(
            name, loader, origin="<stub>", is_package=is_package
        )
        module.__loader__ = loader
        if is_package:
            module.__path__ = []
        return module

    flash_attn_module = _make_stub_module("flash_attn", is_package=True)
    flash_attn_module.flash_attn_func = _fa_flash_stub

    losses_module = _make_stub_module("flash_attn.losses", is_package=True)
    cross_entropy_module = _make_stub_module("flash_attn.losses.cross_entropy", is_package=False)

    class CrossEntropyLossStub(torch.nn.Module):
        """``torch.nn.CrossEntropyLoss`` behind flash-attn's constructor signature.

        flash-attn's CrossEntropyLoss takes ``(ignore_index, reduction,
        label_smoothing, ...)`` while torch starts with ``weight``, so arguments
        are mapped by name. ``inplace_backward`` is accepted and ignored.
        ``logit_scale`` multiplies the logits before the loss. z-loss and
        ``process_group`` have no torch equivalent and raise
        NotImplementedError. Extra keyword arguments (e.g. ``weight``) are passed
        straight to torch.
        """

        def __init__(self, ignore_index=-100, reduction="mean", label_smoothing=0.0,
                     logit_scale=1.0, lse_square_scale=0.0, inplace_backward=False,
                     process_group=None, return_z_loss=False, **kwargs):
            super().__init__()
            if lse_square_scale != 0.0 or return_z_loss:
                raise NotImplementedError("flash_attn stub: z-loss is not supported")
            if process_group is not None:
                raise NotImplementedError("flash_attn stub: process_group is not supported")
            self.logit_scale = logit_scale
            self.ce_loss = torch.nn.CrossEntropyLoss(
                ignore_index=ignore_index,
                reduction=reduction,
                label_smoothing=label_smoothing,
                **kwargs,
            )

        def forward(self, logits, target):
            if self.logit_scale != 1.0:
                logits = logits * self.logit_scale
            return self.ce_loss(logits, target)

    cross_entropy_module.CrossEntropyLoss = CrossEntropyLossStub
    losses_module.cross_entropy = cross_entropy_module
    flash_attn_module.losses = losses_module

    sys.modules["flash_attn"] = flash_attn_module
    sys.modules["flash_attn.losses"] = losses_module
    sys.modules["flash_attn.losses.cross_entropy"] = cross_entropy_module

import json, math, re, sys, time
from pathlib import Path
import lightning as L
import torch
from torch.utils.data import DataLoader
from functools import partial
from transformers import AutoTokenizer
from lit_gpt.diffmodel import TransEncoder, Block, Config
from lit_gpt.utils import get_default_supported_precision, num_parameters, step_csv_logger
from lightning.fabric.strategies import FSDPStrategy
from safetensors.torch import load_file as load_safetensors
from pytorch_lightning.loggers import WandbLogger
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from lightning.fabric.strategies import FSDPStrategy
from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor, estimate_flops
import wandb
import argparse


# -------------------- Sudoku constants --------------------
# Token id used for blank (to-be-predicted) cells. The tokenizer's pad id is
# also set to this value.
MASK_ID: int = 32000
# One token per cell of a 9x9 grid.
SEQ_LEN: int = 9 * 9

# -------------------- Tokenizer & digit maps --------------------
def build_tokenizer_and_digit_maps():
    """Load the TinyLlama tokenizer and map the digits 0-9 to token ids.

    Returns:
        ``(tokenizer, digit2id, id2digit)``, where ``digit2id[d]`` is the token
        id used for digit ``d`` and ``id2digit`` is the inverse mapping.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
        padding_side="right",
        use_fast=True,
    )
    # Mirror SMDM: register the PAD token before any digit lookups, and pin its
    # id to MASK_ID.
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    tokenizer.pad_token_id = MASK_ID

    digit2id = {}
    for d in range(10):
        # The LLaMA tokenizer is whitespace-aware: encode the digit with a
        # leading space and keep the last id produced.
        pieces = tokenizer.encode(f" {d}", add_special_tokens=False)
        assert len(pieces) >= 1, f"could not encode digit {d}"
        digit2id[d] = pieces[-1]
    id2digit = {token: digit for digit, token in digit2id.items()}

    # Every digit needs its own token, and none may collide with MASK_ID.
    assert len(set(digit2id.values())) == 10, f"digits not unique: {digit2id}"
    assert all(t < MASK_ID for t in digit2id.values()), "digit ids must be < 32000"

    return tokenizer, digit2id, id2digit

# -------------------- Dataset --------------------
from torch.utils.data import Dataset
class SudokuDataset(Dataset):
    """Sudoku puzzles read from a JSONL file, one JSON object per line.

    Each line holds ``{"puzzle": [...], "solution": [...]}`` with ``SEQ_LEN``
    cells each. Cells may be ints or digit strings. In ``puzzle``, ``0`` (or
    ``"."``) marks a blank. Blank lines are skipped, and rows whose puzzle or
    solution does not have exactly ``SEQ_LEN`` cells are dropped.

    Args:
        jsonl_path: path to the JSONL file.
        digit2id: maps each digit (int) to its token id.
    """

    def __init__(self, jsonl_path: str, digit2id: dict):
        self.d2i = digit2id
        self.rows = []
        with open(jsonl_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                ex = json.loads(line)
                p, s = ex["puzzle"], ex["solution"]
                if len(p) == SEQ_LEN and len(s) == SEQ_LEN:
                    # Normalize once, so the blank test and the digit lookups in
                    # __getitem__ agree on the cell type.
                    self.rows.append((self._cells(p), self._cells(s)))

    @staticmethod
    def _cells(seq):
        """Convert a sequence of cells to ints; ``"."`` denotes a blank (0)."""
        return [0 if c == "." else int(c) for c in seq]

    def __len__(self): return len(self.rows)

    def __getitem__(self, idx):
        """Return ``(x, y, mask)``, each a 1-D tensor of length ``SEQ_LEN``.

        x: given cells as digit token ids, with blanks set to MASK_ID (long).
        y: the solution's digit token id at every position (long).
        mask: True where the puzzle cell was blank, i.e. the supervised positions.
        """
        puzzle, solution = self.rows[idx]
        y = torch.tensor([self.d2i[d] for d in solution], dtype=torch.long)
        mask = torch.tensor([d == 0 for d in puzzle], dtype=torch.bool)
        # Givens keep their digit token; blanks become MASK_ID.
        x = torch.tensor(
            [MASK_ID if d == 0 else self.d2i[d] for d in puzzle], dtype=torch.long
        )
        return x, y, mask
# ------------------------------------

# -------------------- Args --------------------
def parse_args():
    """Parse the command-line options for Sudoku fine-tuning of an SMDM checkpoint."""
    parser = argparse.ArgumentParser()
    # (flag, add_argument kwargs), registered in this order so --help is unchanged.
    option_specs = (
        ("--model", dict(type=int, required=True, help="e.g., 1028 for Diff_LLaMA_1028M")),
        ("--pretrain_path", dict(type=str, required=True, help="SMDM checkpoint (.safetensors)")),
        ("--jsonl", dict(type=str, required=True, help="train jsonl path")),
        ("--devices", dict(type=int, default=1, help="GPUs to use")),
        ("--bs", dict(type=int, default=256, help="global batch size")),
        ("--epoch", dict(type=int, default=3)),
        ("--micro_bs", dict(type=int, default=None, help="per-GPU micro batch (auto if None)")),
        ("--num_workers", dict(type=int, default=2)),
        ("--save_dir", dict(type=str, default="workdir/finetune/sudoku")),
        ("--log_every", dict(type=int, default=100)),
        ("--lr", dict(type=float, default=6e-4)),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()

args = parse_args()
model_name = f"Diff_LLaMA_{args.model}M"

# Run-level hyperparameters, derived from the CLI flags so these globals agree
# with what setup_and_run() actually uses (previously --lr, --devices and the
# micro-batch choice were ignored here).
num_of_devices = max(1, args.devices)
global_batch_size = args.bs
learning_rate = args.lr
# Same per-GPU micro-batch heuristic as setup_and_run(). A hard-coded 256 made
# the assert below reject any --bs smaller than 256.
if args.micro_bs is not None:
    micro_batch_size = args.micro_bs
else:
    micro_batch_size = 128 if args.model <= 1000 else 16
# 55612 is presumably the training-set size — TODO confirm / derive from --jsonl.
max_step = int(55612 * args.epoch / global_batch_size)
save_step_interval = 5000
batch_size = global_batch_size // num_of_devices
gradient_accumulation_steps = batch_size // micro_batch_size
assert gradient_accumulation_steps > 0, (
    f"--bs ({global_batch_size}) must be at least micro_bs*devices "
    f"({micro_batch_size * num_of_devices})"
)

logger = step_csv_logger("out", model_name)

# -------------------- Setup & Train --------------------
def setup_and_run():
    """Build Fabric, data, model and optimizer from the CLI flags, then train.

    Re-parses the flags with parse_args(), loads the SMDM ``.safetensors``
    checkpoint into a ``TransEncoder`` and hands everything to ``train``.
    """
    args = parse_args()
    model_name = f"Diff_LLaMA_{args.model}M"
    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # precision & strategy: FSDP (wrapping each transformer Block) for multi-GPU
    precision = get_default_supported_precision(training=True)
    num_devices = max(1, args.devices)
    if num_devices > 1:
        strategy = FSDPStrategy(
            auto_wrap_policy={Block},
            activation_checkpointing_policy=None,
            state_dict_type="full",
            limit_all_gathers=True,
            cpu_offload=False,
        )
    else:
        strategy = "auto"

    # Pin the device count to --devices. Fabric's default ("auto") would use
    # every visible GPU and silently change the effective global batch size.
    fabric = L.Fabric(devices=num_devices, strategy=strategy, precision=precision)
    # Needed before setup() whenever more than one process is used; effectively
    # a no-op for a single device.
    fabric.launch()
    fabric.print(f"Saving to: {save_dir}")

    # tokenizer & digit maps
    tokenizer, digit2id, id2digit = build_tokenizer_and_digit_maps()
    fabric.print(f"digit2id: {digit2id}")

    # dataset
    train_set = SudokuDataset(args.jsonl, digit2id)
    N = len(train_set)
    fabric.print(f"Dataset size: {N}")

    # micro batch heuristic (per process)
    if args.micro_bs is None:
        micro_bs = 128 if args.model <= 1000 else 16
    else:
        micro_bs = args.micro_bs

    # Grad accumulation from the desired global batch, using the number of
    # processes Fabric actually launched.
    world_size = fabric.world_size
    global_bs = args.bs
    assert global_bs % (micro_bs * world_size) == 0, \
        f"--bs ({global_bs}) must be divisible by micro_bs*devices ({micro_bs*world_size})"
    grad_accum = global_bs // (micro_bs * world_size)

    train_loader = DataLoader(
        train_set,
        batch_size=micro_bs,
        shuffle=True,
        drop_last=True,
        num_workers=args.num_workers,
        pin_memory=True,
        persistent_workers=False,
    )
    train_loader = fabric.setup_dataloaders(train_loader)

    # model
    config = Config.from_name(model_name)
    with fabric.init_module():
        model = TransEncoder(config)
        # load SMDM pretrain weights (safetensors expected)
        sd = load_safetensors(args.pretrain_path)
        model.load_state_dict(sd)
    del sd  # release the loaded copy once it has been copied into the model
    fabric.print(f"Total params: {num_parameters(model):,}")
    model = fabric.setup(model)

    # optimizer (created after fabric.setup(model), as FSDP requires)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=1e-1)
    optimizer = fabric.setup_optimizers(optimizer)

    train(fabric, model, optimizer, train_loader, grad_accum, save_dir, args)

def train(fabric, model, optimizer, train_loader, grad_accum, save_dir: Path, args):
    """Fine-tune ``model`` to fill Sudoku blanks with a masked cross-entropy loss.

    Per-token CE is averaged over blank cells (``mask``) only. Gradients are
    accumulated over ``grad_accum`` micro-batches per optimizer step, and the
    cross-rank gradient sync is skipped on the accumulating micro-steps.

    Checkpoints are written via ``fabric.save`` every 20000 micro-batches and at
    the end of each epoch. A flat CPU ``state_dict`` is then written to
    ``save_dir / "final_model.pt"``. Only global rank 0 talks to Weights &
    Biases and writes ``final_model.pt``.
    """
    is_rank_zero = fabric.global_rank == 0
    if is_rank_zero:
        # One W&B run per job (not per rank). "lr" is the value the optimizer
        # was built with (--lr).
        wandb.init(project="sudoku-mdm", name=model_name, config={
            "bs": args.bs, "epochs": args.epoch, "lr": args.lr})
    ce = torch.nn.CrossEntropyLoss(reduction='none')
    steps_per_epoch = len(train_loader)
    fabric.print(f"steps/epoch: {steps_per_epoch}, epochs: {args.epoch}, "
                 f"global_bs: {args.bs}, micro_bs: {train_loader.batch_size}, "
                 f"grad_accum: {grad_accum}")

    t0 = time.time()
    sup_tokens = 0  # supervised (blank) tokens processed by this rank
    iter_num = 0    # micro-batches processed
    step_count = 0  # optimizer steps taken

    for ep in range(args.epoch):
        for x, y, mask in train_loader:
            # x: givens or MASK_ID; y: target token ids; mask: True at blanks
            is_accum = ((iter_num + 1) % grad_accum) != 0

            # Skip the gradient all-reduce while accumulating. It runs once, on
            # the micro-step right before optimizer.step().
            with fabric.no_backward_sync(model, enabled=is_accum):
                logits = model(x)
                B, T, V = logits.shape

                # CE at every position, then average over blank cells only.
                loss_tok = ce(logits.reshape(B * T, V), y.reshape(B * T)).view(B, T)
                mask_f = mask.float()
                loss = (loss_tok * mask_f).sum() / mask_f.sum().clamp_min(1.0)

                fabric.backward(loss / grad_accum)

            if not is_accum:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                step_count += 1

            iter_num += 1
            sup_tokens += int(mask.sum().item())

            # logs
            if iter_num % args.log_every == 0:
                dt = time.time() - t0
                ips = sup_tokens / dt if dt > 0 else 0.0
                loss_val = loss.item()
                fabric.print(f"[epoch {ep+1}/{args.epoch} iter {iter_num}] "
                             f"loss={loss_val:.6f}  supervised_toks/s={ips:,.0f}")
                if is_rank_zero:
                    wandb.log({"loss": loss_val, "supervised_toks/s": ips})

            # save checkpoint every 20000 iterations (micro-batches)
            if iter_num % 20000 == 0:
                ckpt_path = save_dir / f"iter-{iter_num:06d}-ckpt.pth"
                fabric.print(f"Saving checkpoint at iteration {iter_num} to {ckpt_path}")
                state = {"model": model, "optimizer": optimizer,
                         "iter_num": iter_num, "step_count": step_count}
                fabric.save(ckpt_path, state)

        # save a checkpoint each epoch
        ckpt_path = save_dir / f"epoch-{ep+1:02d}-ckpt.pth"
        fabric.print(f"Saving checkpoint to {ckpt_path}")
        state = {"model": model, "optimizer": optimizer,
                 "iter_num": iter_num, "step_count": step_count}
        fabric.save(ckpt_path, state)

    # final pure weights (easy for inference)
    final_pt = save_dir / "final_model.pt"
    fabric.print(f"Saving final weights to {final_pt}")
    # state_dict() is called on every rank (it may be a collective under FSDP).
    # Only rank 0 writes the file, so ranks do not race on the same path.
    cpu_sd = {k: v.detach().to("cpu") for k, v in model.state_dict().items()}
    if is_rank_zero:
        torch.save(cpu_sd, final_pt)
    fabric.barrier()

    if is_rank_zero:
        wandb.finish()

# -------------------- main --------------------
if __name__ == "__main__":
    # "high" lets PyTorch use faster reduced-precision internals (e.g. TF32)
    # for float32 matrix multiplications.
    matmul_precision = "high"
    torch.set_float32_matmul_precision(matmul_precision)
    setup_and_run()