from __future__ import annotations
import os
import argparse
from pathlib import Path
import math
import torch
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_
from model_src.generation import LmSampler, generation_accuracy
from model_src import (
    JsonlSequenceDataset,
    build_dataloader,
    build_causal_lm,
    load_vocab,
    set_seed,
    write_json,
)

from datagen.languages import LANGUAGES
import json

# ---- Fixed W&B entity (team or username). Only used if --wandb is passed. ----
# Placeholder value: replace "<FILL>" with a real entity before running with
# --wandb; main() forwards this constant unchanged as `entity=` to wandb.init().
WANDB_ENTITY = "<FILL>"

@torch.no_grad()
def evaluate(model, eval_loader, device, num_batches: int) -> float:
    """Return the mean loss of ``model`` over at most ``num_batches`` batches.

    The model is switched to eval mode and is left in that mode on return.
    Each batch must be a mapping; its values are moved to ``device`` and the
    mapping is passed to the model as keyword arguments. The model output must
    expose a scalar ``loss``.

    Returns 0.0 when no batch was processed (empty loader or
    ``num_batches <= 0``).
    """
    model.eval()
    batch_losses = []
    for index, batch in enumerate(eval_loader):
        # The limit is checked after the batch is fetched, so the loader is
        # advanced one item past the last batch that is actually scored.
        if index >= num_batches:
            break
        on_device = {name: tensor.to(device) for name, tensor in batch.items()}
        batch_losses.append(model(**on_device).loss.item())
    return sum(batch_losses, 0.0) / max(len(batch_losses), 1)


@torch.no_grad()
def sample_one_string(
    model: torch.nn.Module,
    start_id: int,      # [BOS] id
    stop_id: int,       # [EOS] id
    id2tok: dict[int, str],
    device: torch.device,
    max_steps: int,
    ):
    """
    Draw one unconditional sample from the model by multinomial sampling.

    The context starts as the single token ``start_id``. At each step the
    model is called with ``input_ids=<context>``, the logits at the last
    position are turned into probabilities with softmax, and one token id is
    drawn from them. Sampling ends after ``max_steps`` draws or right after
    ``stop_id`` is drawn, whichever comes first.

    Returns ``(text, tokens)``: ``tokens`` is the list of sampled token
    strings (the start token is excluded, a drawn stop token is included)
    and ``text`` is their concatenation.
    """
    context = torch.tensor([[start_id]], dtype=torch.long, device=device)
    sampled_ids: list[int] = []

    steps_left = max_steps
    while steps_left > 0:
        steps_left -= 1
        last_logits = model(input_ids=context).logits[:, -1, :]
        distribution = torch.softmax(last_logits, dim=-1)
        drawn = torch.multinomial(distribution, num_samples=1)
        token_id = int(drawn.item())
        sampled_ids.append(token_id)
        # Extend the context with the freshly drawn token for the next step.
        context = torch.cat([context, drawn], dim=1)
        if token_id == stop_id:
            break

    tokens = [id2tok[token_id] for token_id in sampled_ids]
    return "".join(tokens), tokens


def parse_args(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, so existing ``parse_args()``
            callers are unaffected.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: On any argparse error, including ``--eval-every`` or
            ``--log-every`` not being positive. Both are used as modulo
            divisors in the training loop, so a value of 0 would otherwise
            fail with ``ZeroDivisionError`` only after training has started.
    """
    p = argparse.ArgumentParser(description="Train a Transformer LM on JSONL data.")
    p.add_argument("--run-name", type=str, required=True,
                   help="Run name; used for models/<run_name> and (if --wandb) W&B run name.")
    # Data
    p.add_argument("--data-dir", type=str, required=True,
                   help="Directory with train.jsonl, test.jsonl, vocab.json. Format 'data/<language>/<run_id>'")
    p.add_argument("--max-seq-len", type=int, default=256, help="Max sequence length (≤ context).")
    p.add_argument("--batch-size", type=int, default=64)
    # Model
    p.add_argument("--n-layer", type=int, default=8)
    p.add_argument("--n-head", type=int, default=8)
    p.add_argument("--n-embd", type=int, default=512)
    p.add_argument("--dropout", type=float, default=0.0)
    # Optim
    p.add_argument("--lr", type=float, default=2e-5)
    p.add_argument("--weight-decay", type=float, default=0.01)
    p.add_argument("--grad-clip", type=float, default=1.0, help="Clip global grad-norm to this value; 0 disables.")

    # Loop
    p.add_argument("--max-steps", type=int, default=20000, help="Total optimization steps.")
    p.add_argument("--eval-every", type=int, default=250, help="Eval every N steps.")
    p.add_argument("--log-every", type=int, default=100, help="Log every N steps.")
    p.add_argument("--eval-batches", type=int, default=5, help="Batches per eval run.")
    p.add_argument("--early-stop-patience", type=int, default=20000, help="Stop if no eval improvement for this many steps.")
    p.add_argument("--min-delta", type=float, default=0.0, help="Minimum loss decrease to count as improvement.")
    p.add_argument("--num-workers", type=int, default=0)

    # Sampling display
    p.add_argument("--gen-samples", type=int, default=5, help="How many samples to print after each eval.")
    p.add_argument("--eval-gen-strategy", type=str, choices=["natural", "top-p", "min-p"], default="natural",
                   help="Sampling rule used to compute eval generation accuracy.")
    p.add_argument("--eval-gen-param", type=float, default=0.0,
                   help="Parameter for top-p (e.g., 0.9) or min-p (e.g., 1e-3). Ignored for 'natural'.")
    p.add_argument("--eval-gen-samples", type=int, default=100, help="How many samples to score for eval accuracy.")

    # Misc
    p.add_argument("--seed", type=int, default=123)
    p.add_argument("--device", type=int, default=4, help="CUDA device index (int).")

    # ---- Minimal W&B / run identity ----
    p.add_argument("--wandb", action="store_true",
                   help="Enable Weights & Biases logging.")
    p.add_argument("--wandb-project", type=str, default="NSP-lm-models",
                   help="W&B project name (used only if --wandb).")

    args = p.parse_args(argv)

    # Fail fast on intervals that would later be used as `step % interval`.
    for flag, value in (("--eval-every", args.eval_every), ("--log-every", args.log_every)):
        if value <= 0:
            p.error(f"{flag} must be a positive integer (got {value}).")

    return args


class _LmShim:
    """Plain attribute bundle handed to ``LmSampler`` and ``generation_accuracy``.

    Holds the model, the special-token ids, both vocab maps and the device,
    plus the alphabet (every vocab entry except ``[BOS]`` / ``[EOS]``) as ids
    in ascending order and as the matching token strings.
    """

    def __init__(self, model, bos_id, eos_id, tok2id, id2tok, device):
        self.model = model
        self.bos_id = bos_id
        self.eos_id = eos_id
        self.tok2id = tok2id
        self.id2tok = id2tok
        self.device = device
        self.sigma_ids = [i for i in sorted(id2tok) if id2tok[i] not in ("[BOS]", "[EOS]")]
        self.sigma_tokens = [id2tok[i] for i in self.sigma_ids]


def _load_language(data_dir):
    """Read ``meta.json`` in ``data_dir`` and return ``(name, language)`` from ``LANGUAGES``.

    Raises:
        SystemExit: If ``meta.json`` is missing, has no non-empty ``language``
            entry, or names a language that is not in the registry.
    """
    meta_path = data_dir / "meta.json"
    if not meta_path.exists():
        raise SystemExit(f"Expected file missing: {meta_path}")

    with open(meta_path, "r", encoding="utf-8") as f:
        meta = json.load(f)
    # Use .get so a missing key reaches the explicit check below instead of
    # surfacing as a bare KeyError traceback.
    lang_name = meta.get("language") if isinstance(meta, dict) else None
    if not lang_name:
        raise SystemExit(f"{meta_path} must include a non-empty 'language' field.")
    if lang_name not in LANGUAGES:
        raise SystemExit(f"Language '{lang_name}' not found in registry.")
    return lang_name, LANGUAGES[lang_name]


def _sample_rows(model, language, bos_id, eos_id, id2tok, device, num_samples, max_steps):
    """Draw, label and print ``num_samples`` generations; return ``[idx, text, label]`` rows.

    A sample is labelled 1 only if it ended with ``[EOS]`` and
    ``language.is_positive`` accepts the tokens before it; otherwise 0.
    """
    rows = []
    for idx in range(1, num_samples + 1):
        concat, toks = sample_one_string(model, bos_id, eos_id, id2tok, device, max_steps)
        terminated = bool(toks) and toks[-1] == "[EOS]"
        sigma_only = toks[:-1] if terminated else toks
        label = 0
        if terminated:
            try:
                if language.is_positive(sigma_only):
                    label = 1
            except Exception:
                # Deliberate best-effort: a membership check that raises on a
                # malformed sample counts as a negative, not a training crash.
                label = 0
        print(f"   - #{idx}: {concat}    label={label}")
        rows.append([idx, concat, label])
    return rows


def main():
    """Train a causal LM on a JSONL dataset, evaluating periodically, then save it.

    Steps: parse CLI args, resolve the data directory, load vocab and language
    metadata, build loaders / model / optimizer, write ``run_config.json``,
    optionally initialise W&B, run the training loop with periodic eval loss,
    sample printing, generation accuracy and early stopping, run a final
    1000-sample generation-accuracy evaluation, and save the model together
    with ``id2tok.json`` under ``models/<run_name>``.

    Raises:
        SystemExit: On missing input files, missing special tokens, invalid
            language metadata, or ``--wandb`` without the package installed.
    """
    args = parse_args()

    # Derive save_dir from run-name (always models/<run_name>)
    args.save_dir = str(Path("models") / args.run_name)
    os.makedirs(args.save_dir, exist_ok=True)

    set_seed(args.seed)
    device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")

    args.gen_max_steps = args.max_seq_len
    args.betas, args.eps = (0.9, 0.999), 1e-8

    # ----- Data & vocab -----
    dd = Path(args.data_dir)
    # A relative path that does not exist as given is retried under data/.
    if not dd.is_absolute() and not (dd.exists() and dd.is_dir()):
        dd = Path("data") / dd
    train_path = dd / "train.jsonl"
    eval_path = dd / "test.jsonl"
    vocab_path = dd / "vocab.json"
    for pth in (train_path, eval_path, vocab_path):
        if not pth.exists():
            raise SystemExit(f"Expected file missing: {pth}")

    tok2id = load_vocab(str(vocab_path))
    id2tok = {i: t for t, i in tok2id.items()}
    eos_id = tok2id.get("[EOS]")
    bos_id = tok2id.get("[BOS]")
    if eos_id is None or bos_id is None:
        raise SystemExit("Vocab must contain [EOS] and [BOS].")

    # Language name comes from meta.json in the data directory.
    lang_name, language = _load_language(dd)

    train_ds = JsonlSequenceDataset(str(train_path), tok2id=tok2id, max_seq_len=args.max_seq_len)
    eval_ds = JsonlSequenceDataset(str(eval_path), tok2id=tok2id, max_seq_len=args.max_seq_len)

    train_loader = build_dataloader(
        train_ds,
        batch_size=args.batch_size,
        max_steps=args.max_steps + 2,
        replacement=True,  # true random batches
        num_workers=args.num_workers,
        pad_id=eos_id,
        max_len=args.max_seq_len,
    )
    eval_loader = build_dataloader(
        eval_ds,
        batch_size=args.batch_size,
        max_steps=None,
        replacement=True,
        num_workers=args.num_workers,
        pad_id=eos_id,
        max_len=args.max_seq_len,
    )

    # ----- Model -----
    model, hf_cfg = build_causal_lm(
        vocab_size=len(tok2id),
        max_seq_len=args.max_seq_len,
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_embd=args.n_embd,
        dropout=args.dropout,
        eos_token_id=eos_id,
        bos_token_id=bos_id,
        pad_token_id=eos_id,
    )

    model.to(device)
    optim = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=tuple(args.betas), eps=args.eps)

    # Save run config
    run_cfg = {
        "data": {
            "data_dir": str(dd),
            "train": str(train_path),
            "eval": str(eval_path),
            "vocab": str(vocab_path),
            "language": lang_name,
            "max_seq_len": args.max_seq_len,
            "batch_size": args.batch_size,
        },
        "model": {
            "n_layer": args.n_layer,
            "n_head": args.n_head,
            "n_embd": args.n_embd,
            "dropout": args.dropout,
            "vocab_size": len(tok2id),
        },
        "optim": {
            "lr": args.lr, "weight_decay": args.weight_decay,
            "betas": list(args.betas), "eps": args.eps, "grad_clip": args.grad_clip,
        },
        "loop": {
            "max_steps": args.max_steps,
            "eval_every": args.eval_every,
            "log_every": args.log_every,
            "eval_batches": args.eval_batches,
            "early_stop_patience": args.early_stop_patience,
            "min_delta": args.min_delta,
        },
        "seed": args.seed,
        "identity": {
            "run_name": args.run_name,
            "save_dir": args.save_dir,
        },
    }
    write_json(os.path.join(args.save_dir, "run_config.json"), run_cfg)

    # ---- Optional W&B init (only if --wandb) ----
    # The module is bound once here; `wandb is None` means logging is disabled.
    wandb = None
    if args.wandb:
        try:
            import wandb  # local import so script works without wandb installed
        except ImportError as err:
            raise SystemExit(
                "You passed --wandb but the 'wandb' package is not installed. Try: pip install wandb"
            ) from err
        wandb.init(
            project=args.wandb_project,
            entity=WANDB_ENTITY,
            name=args.run_name,
            config=run_cfg,
        )
    use_wandb = wandb is not None

    # Built once, before the loop, so the final evaluation below still has a
    # sampler when training stops before the first periodic eval.
    lm_loaded = _LmShim(model, bos_id, eos_id, tok2id, id2tok, device)
    sampler = LmSampler(lm_loaded, strategy=args.eval_gen_strategy, param=args.eval_gen_param)

    # ----- Training loop with early stopping -----
    best_val = float("inf")
    last_improve_step = 0
    running_loss = 0.0
    step = 0  # stays 0 if the loader yields no batch at all

    model.train()
    for step, batch in enumerate(train_loader, start=1):
        inputs = {k: v.to(device) for k, v in batch.items()}
        out = model(**inputs)  # GPT2LMHeadModel returns loss when labels provided
        loss = out.loss

        optim.zero_grad(set_to_none=True)
        loss.backward()
        if args.grad_clip and args.grad_clip > 0:
            clip_grad_norm_(model.parameters(), args.grad_clip)
        optim.step()

        running_loss += loss.item()

        # Logging
        if step % args.log_every == 0:
            avg_train = running_loss / args.log_every
            print(f"[step {step:>6}] train_loss: {avg_train:.4f}")
            if use_wandb:
                wandb.log({"train/loss": avg_train}, step=step)
            running_loss = 0.0

        # Eval + sampling
        if step % args.eval_every == 0:
            val_loss = evaluate(model, eval_loader, device=device, num_batches=args.eval_batches)
            # Guard against overflow in exp() for very large losses.
            ppl = math.exp(val_loss) if val_loss < 50 else float("inf")
            print(f"[step {step:>6}] eval_loss: {val_loss:.4f}  (ppl ~ {ppl:.2f})")
            if use_wandb:
                wandb.log({"eval/loss": val_loss}, step=step)

            # Keep the model in eval mode for ALL generation below (printed
            # samples and accuracy scoring); train mode is restored afterwards.
            model.eval()

            # Print a few unconditional generations (multinomial)
            print("  Samples:")
            samples_rows = _sample_rows(
                model, language, bos_id, eos_id, id2tok, device,
                args.gen_samples, args.gen_max_steps,
            )
            if use_wandb and samples_rows:
                samples_table = wandb.Table(columns=["sample_idx", "text", "label"], data=samples_rows)
                wandb.log({"eval/samples": samples_table}, step=step)

            # Eval generation accuracy with requested strategy
            acc, correct, total = generation_accuracy(lm_loaded, sampler, language, args.eval_gen_samples, args.max_seq_len)
            print(f"  Eval gen acc ({args.eval_gen_strategy}, param={args.eval_gen_param}): "
                  f"{correct}/{total} = {acc*100:.2f}%")
            if use_wandb:
                wandb.log({"eval/gen_acc": acc}, step=step)

            model.train()

            # Early stopping check
            improved = (best_val - val_loss) > args.min_delta
            if improved:
                best_val = val_loss
                last_improve_step = step
            elif (step - last_improve_step) >= args.early_stop_patience:
                print(f"Early stopping at step {step} (no improvement for {args.early_stop_patience} steps).")
                break

        if step >= args.max_steps:
            break

    # Final generation accuracy eval at end of training (in eval mode)
    model.eval()
    acc, correct, total = generation_accuracy(lm_loaded, sampler, language, 1000, args.max_seq_len)
    print(f"Final eval gen acc: {correct}/{total} = {acc*100:.2f}%")
    if use_wandb:
        wandb.log({"eval/final_gen_acc": acc}, step=step)

    # ----- Save -----
    model.save_pretrained(args.save_dir)
    write_json(os.path.join(args.save_dir, "id2tok.json"), id2tok)
    print(f"Saved model to {args.save_dir}")

    if use_wandb:
        wandb.finish()


# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
