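"""
Small-scale training script to compare two Mamba variants on WikiText-2:
- SimpleMambaModel (selective Mamba block)
- SSDMambaModelExp (diagonal SSD variant)

Each run trains the model selected with --model, prints per-epoch loss and
accuracy, optionally saves training curves (--save-plots), and appends a
summary record to train_runs.jsonl inside --plot-dir.
"""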

from __future__ import annotations

import argparse
import time
from pathlib import Path
from typing import Tuple

import json
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from .data import load_wikitext2
from .mamba_simple_wrapper import SimpleMambaModel
from .mamba_SSD_diag_exp import SSDMambaModelExp
try:
    from mamba_experiments.mamba_SSD import ModelArgs  # type: ignore
except ImportError:
    ModelArgs = None  # type: ignore


# Human-readable labels for the --model choices; looked up via _display_name
# when building plot titles and the plot-data JSON metadata.
MODEL_DISPLAY_NAMES = {
    "mamba_simple": "Original Mamba",
    "mamba_SSD_diag_exp": "Diagonal SSD Mamba",
}

# Human-readable labels for the --dataset choices; used the same way as
# MODEL_DISPLAY_NAMES.
DATASET_DISPLAY_NAMES = {
    "wikitext2": "WikiText-2",
}


def _display_name(mapping: dict[str, str], key: str, fallback: str) -> str:
    """Return a human-readable label for ``key``.

    Args:
        mapping: Known keys and their display labels.
        key: Identifier to look up (e.g. a CLI choice).
        fallback: Label returned when ``key`` is empty/falsy.

    Returns:
        ``fallback`` for a falsy key, the mapped label when ``key`` is in
        ``mapping``, otherwise ``key`` with underscores turned into spaces
        and title-cased.
    """
    if key:
        prettified = key.replace("_", " ").title()
        return mapping[key] if key in mapping else prettified
    return fallback


def _write_plot_data(
    args,
    plot_dir: Path,
    epochs: list[int],
    train_losses: list[float],
    val_losses: list[float],
    train_accs: list[float],
    val_accs: list[float],
    model_name: str,
    dataset_name: str,
) -> Path:
    """Dump the per-epoch curves plus run metadata to a timestamped JSON file.

    Args:
        args: Namespace-like object; ``model``, ``dataset`` and the
            hyperparameter attributes are read with ``getattr`` so missing
            ones fall back to ``"unknown"`` / ``None``.
        plot_dir: Output directory; created (with parents) if absent.
        epochs: Epoch indices matching the metric lists.
        train_losses: Training loss per epoch.
        val_losses: Validation loss per epoch.
        train_accs: Training accuracy per epoch.
        val_accs: Validation accuracy per epoch.
        model_name: Display label of the model, stored in the metadata.
        dataset_name: Display label of the dataset, stored in the metadata.

    Returns:
        Path of the written ``<model>_<dataset>_<timestamp>_plot_data.json``.
    """
    stamp = time.strftime("%Y%m%d-%H%M%S")
    dataset_key = getattr(args, "dataset", "unknown")
    model_key = getattr(args, "model", "unknown")

    metadata = {
        "model": model_key,
        "model_display_name": model_name,
        "dataset": dataset_key,
        "dataset_display_name": dataset_name,
    }
    # Hyperparameters are optional on ``args``; absent ones serialise as null.
    for field in ("seed", "seq_len", "n_layers", "n_state", "d_model", "batch_size"):
        metadata[field] = getattr(args, field, None)
    metadata["epochs"] = len(epochs)

    payload = {
        "metadata": metadata,
        "epochs": epochs,
        "train_loss": train_losses,
        "val_loss": val_losses,
        "train_acc": train_accs,
        "val_acc": val_accs,
    }

    plot_dir.mkdir(parents=True, exist_ok=True)
    destination = plot_dir / f"{model_key}_{dataset_key}_{stamp}_plot_data.json"
    with destination.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2)
    return destination


def make_model(kind: str, vocab_size: int, d_model: int, n_state: int, n_layers: int) -> nn.Module:
    """Build the model selected by ``kind``.

    Args:
        kind: ``"mamba_simple"`` or ``"mamba_SSD_diag_exp"``.
        vocab_size: Vocabulary size passed to the model.
        d_model: Model width passed to the model.
        n_state: State dimension (forwarded as ``d_state``).
        n_layers: Number of layers.

    Returns:
        A ``SimpleMambaModel`` or an ``SSDMambaModelExp`` instance.

    Raises:
        ValueError: ``kind`` is not one of the supported names.
        ImportError: the SSD variant was requested but ``ModelArgs`` could
            not be imported at module load time.
    """
    if kind == "mamba_simple":
        return SimpleMambaModel(
            vocab_size=vocab_size,
            d_model=d_model,
            n_layers=n_layers,
            d_state=n_state,
        )

    if kind != "mamba_SSD_diag_exp":
        raise ValueError(f"Unknown model kind: {kind}")

    # ModelArgs is set to None when its optional import failed.
    if ModelArgs is None:
        raise ImportError("mamba_SSD.py not found; place it at repo root.")

    ssd_config = ModelArgs(
        d_model=d_model,
        n_layer=n_layers,
        vocab_size=vocab_size,
        d_state=n_state,
    )
    return SSDMambaModelExp(ssd_config)


def train_one_epoch(
    model: nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer, device: torch.device
) -> Tuple[float, float]:
    """Run one optimisation pass over ``loader``.

    The model is put in training mode; for every ``(x, target)`` batch the
    cross-entropy loss is back-propagated and the optimizer stepped.

    Args:
        model: Network mapping an input batch to class logits.
        loader: Iterable of ``(x, target)`` batches.
        optimizer: Optimizer holding ``model``'s parameters.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` where the loss is weighted by batch size
        and accuracy is the number of argmax matches divided by the number
        of samples seen (``x.size(0)`` summed over batches).

    Raises:
        ValueError: ``loader`` yielded no samples, so no average can be
            computed (previously surfaced as a bare ``ZeroDivisionError``).
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_correct = 0
    total = 0
    for x, target in loader:
        x = x.to(device)
        target = target.to(device)
        logits = model(x)
        loss = criterion(logits, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # NOTE(review): the per-sample normalisation assumes one target class
        # per sample (target shape (batch,)) — confirm against the dataset.
        batch_size = x.size(0)
        # criterion returns the batch mean; re-weight so the epoch average is
        # correct even when the last batch is smaller.
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(dim=-1) == target).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_one_epoch received no samples; the training loader is empty.")
    return total_loss / total, total_correct / total


def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> Tuple[float, float]:
    """Compute mean cross-entropy loss and accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and left there; gradients are
    disabled for the whole pass.

    Args:
        model: Network mapping an input batch to class logits.
        loader: Iterable of ``(x, target)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` where the loss is weighted by batch size
        and accuracy is the number of argmax matches divided by the number
        of samples seen (``x.size(0)`` summed over batches).

    Raises:
        ValueError: ``loader`` yielded no samples, so no average can be
            computed (previously surfaced as a bare ``ZeroDivisionError``).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_correct = 0
    total = 0
    with torch.no_grad():
        for x, target in loader:
            x = x.to(device)
            target = target.to(device)
            logits = model(x)
            loss = criterion(logits, target)

            # NOTE(review): the per-sample normalisation assumes one target
            # class per sample (target shape (batch,)) — confirm against the
            # dataset.
            batch_size = x.size(0)
            # criterion returns the batch mean; re-weight so the overall
            # average is correct even when the last batch is smaller.
            total_loss += loss.item() * batch_size
            total_correct += (logits.argmax(dim=-1) == target).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("evaluate received no samples; the evaluation loader is empty.")
    return total_loss / total, total_correct / total


def save_plots(args, train_losses, val_losses, train_accs, val_accs) -> None:
    """Save loss and accuracy curves as PNGs and dump the raw curve data.

    Writes ``<model>_loss.png`` and ``<model>_acc.png`` into
    ``args.plot_dir`` (created if needed). Between the two figures the raw
    per-epoch values are written to JSON via ``_write_plot_data`` and the
    resulting path is printed.

    If matplotlib cannot be imported a message is printed and nothing is
    written — neither the figures nor the JSON file.

    Args:
        args: Parsed CLI namespace; ``plot_dir`` and ``model`` are required,
            ``dataset`` is read if present.
        train_losses: Training loss per epoch.
        val_losses: Validation loss per epoch.
        train_accs: Training accuracy per epoch.
        val_accs: Validation accuracy per epoch.
    """
    try:
        import matplotlib.pyplot as plt  # type: ignore
    except ImportError:  # pragma: no cover
        print("matplotlib not available; skipping plot generation.")
        return

    out_dir = Path(args.plot_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    epochs = list(range(1, len(train_losses) + 1))
    model_name = _display_name(MODEL_DISPLAY_NAMES, getattr(args, "model", ""), "Model")
    dataset_name = _display_name(DATASET_DISPLAY_NAMES, getattr(args, "dataset", ""), "Dataset")
    if dataset_name:
        title_context = f"{model_name} on {dataset_name}"
    else:
        title_context = model_name

    def _render(train_values, val_values, train_label, val_label, ylabel, title, filename) -> None:
        # One figure: train + validation curve, shared styling, saved and closed.
        plt.figure(figsize=(8, 6))
        curves = (
            (train_values, "o", "darkgrey", train_label),
            (val_values, "s", "darkred", val_label),
        )
        for values, marker, color, label in curves:
            plt.plot(
                epochs,
                values,
                marker=marker,
                linewidth=4,
                alpha=0.9,
                color=color,
                label=label,
            )
        plt.xlabel("Epoch", fontsize=24)
        plt.ylabel(ylabel, fontsize=24)
        plt.title(title, fontsize=20, fontweight="bold")
        plt.tick_params(axis="both", which="major", labelsize=20)
        plt.grid(True, alpha=0.3)
        plt.legend(title="Data split", fontsize=16, title_fontsize=16)
        plt.tight_layout()
        plt.savefig(out_dir / filename, dpi=225, bbox_inches="tight")
        plt.close()

    _render(
        train_losses,
        val_losses,
        "Training Loss",
        "Validation Loss",
        "Loss",
        f"{title_context}: Loss per Epoch",
        f"{args.model}_loss.png",
    )

    # Raw numbers are persisted next to the figures so plots can be rebuilt later.
    data_path = _write_plot_data(
        args,
        out_dir,
        epochs,
        train_losses,
        val_losses,
        train_accs,
        val_accs,
        model_name,
        dataset_name,
    )
    print(f"Saved plot data to {data_path}")

    _render(
        train_accs,
        val_accs,
        "Training Accuracy",
        "Validation Accuracy",
        "Accuracy",
        f"{title_context}: Accuracy per Epoch",
        f"{args.model}_acc.png",
    )


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the previous behaviour.

    Returns:
        Namespace with ``model``, ``dataset``, ``seq_len``, ``max_vocab``,
        ``train_max_samples``, ``val_max_samples``, ``batch_size``,
        ``epochs``, ``d_model``, ``n_state``, ``n_layers``, ``lr``, ``seed``,
        ``save_plots`` and ``plot_dir``.
    """
    p = argparse.ArgumentParser(description="Small-scale Mamba vs baseline comparison.")
    p.add_argument(
        "--model",
        choices=["mamba_simple", "mamba_SSD_diag_exp"],
        default="mamba_simple",
    )
    p.add_argument("--dataset", choices=["wikitext2"], default="wikitext2")
    p.add_argument("--seq-len", type=int, default=64)
    p.add_argument("--max-vocab", type=int, default=20000, help="Max vocab for WikiText-2.")
    # The sample caps are ints, so "no cap" is expressed as a value <= 0
    # (main() converts that to None before loading the dataset).
    p.add_argument(
        "--train-max-samples",
        type=int,
        default=5000,
        help="Cap train samples for WikiText-2 (<= 0 for all).",
    )
    p.add_argument(
        "--val-max-samples",
        type=int,
        default=1000,
        help="Cap val samples for WikiText-2 (<= 0 for all).",
    )
    p.add_argument("--batch-size", type=int, default=64)
    p.add_argument("--epochs", type=int, default=5)
    p.add_argument("--d-model", type=int, default=64)
    p.add_argument("--n-state", type=int, default=8, help="State dim for Mamba (ignored by others).")
    p.add_argument("--n-layers", type=int, default=2)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--save-plots", action="store_true", help="Save training curves (loss/acc) to PNG.")
    p.add_argument("--plot-dir", type=str, default="experiments/logs", help="Directory to save plots.")
    return p.parse_args(argv)


def main() -> None:
    """Train the selected model and log the run.

    Steps: parse CLI options, seed torch, load WikiText-2, train for
    ``--epochs`` epochs while printing per-epoch train/validation metrics,
    optionally save plots (``--save-plots``), then append one JSON record
    describing the run to ``train_runs.jsonl`` inside ``--plot-dir``.

    Raises:
        ValueError: the dataset name is not recognised.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.dataset == "wikitext2":
        # A non-positive cap on the command line means "use every sample".
        train_ds, val_ds, vocab_size = load_wikitext2(
            seq_len=args.seq_len,
            max_vocab=args.max_vocab,
            train_max_samples=None if args.train_max_samples <= 0 else args.train_max_samples,
            val_max_samples=None if args.val_max_samples <= 0 else args.val_max_samples,
        )
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")

    # Dataset sizes for the run log. These are properties of the loaded data,
    # not CLI options, so they cannot be read from ``args``.
    # Assumes both datasets are sized (the shuffling train loader needs len()).
    train_size = len(train_ds)
    val_size = len(val_ds)

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False)

    model = make_model(args.model, vocab_size, args.d_model, args.n_state, args.n_layers).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

    print(f"Device: {device}, Model: {args.model}")
    start_time = time.time()
    log_history = []
    train_losses, val_losses, train_accs, val_accs = [], [], [], []
    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, device)
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        log_history.append(
            {
                "epoch": epoch,
                "train_loss": train_loss,
                "train_acc": train_acc,
                "val_loss": val_loss,
                "val_acc": val_acc,
            }
        )
        print(
            f"Epoch {epoch:02d} | train_loss={train_loss:.4f} acc={train_acc:.3f} | "
            f"val_loss={val_loss:.4f} acc={val_acc:.3f}"
        )
    # Wall-clock time of the epoch loop (training plus per-epoch validation).
    total_time = time.time() - start_time
    print(f"Total training time: {total_time:.2f}s")

    if args.save_plots:
        save_plots(args, train_losses, val_losses, train_accs, val_accs)

    # Log metrics to JSONL (one appended line per run, alongside the plots).
    log_dir = Path(args.plot_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    record = {
        "model": args.model,
        "dataset": args.dataset,
        "seq_len": args.seq_len,
        "vocab_size": vocab_size,
        "train_size": train_size,
        "val_size": val_size,
        "max_vocab": args.max_vocab,
        "train_max_samples": args.train_max_samples,
        "val_max_samples": args.val_max_samples,
        "batch_size": args.batch_size,
        "epochs": args.epochs,
        "d_model": args.d_model,
        "n_state": args.n_state,
        "n_layers": args.n_layers,
        "lr": args.lr,
        "seed": args.seed,
        "total_time": total_time,
        "history": log_history,
    }
    with (log_dir / "train_runs.jsonl").open("a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")


# Run the training CLI only when executed as a script, not on import.
# The module uses relative imports, so it has to be launched as part of its
# package (python -m <package>.<module>) rather than as a bare file path.
if __name__ == "__main__":
    main()
