# TASK_SPEC: ptb_word_v1
import argparse
from collections import Counter

import numpy as np

from sequence_utils import run_lm_task, try_load_ptb_text


def tokenize(text: str) -> list[str]:
    """Split raw text into whitespace-delimited tokens with ``<eos>`` markers.

    Windows ``\\r\\n`` line endings are first normalised to ``\\n``; every
    newline character then becomes a standalone ``<eos>`` token. A final line
    that is not newline-terminated therefore gets no trailing ``<eos>``, while
    blank lines still contribute one.
    """
    unix_text = text.replace("\r\n", "\n")
    marked = " <eos> ".join(unix_text.split("\n"))
    # str.split() with no argument already discards leading/trailing whitespace.
    return marked.split()


def build_vocab(
    tokens: list[str],
    min_freq: int,
    max_vocab: int | None,
) -> tuple[dict[str, int], dict[int, str]]:
    counter = Counter(tokens)
    vocab = [tok for tok, freq in counter.items() if freq >= min_freq]
    vocab.sort(key=lambda t: (-counter[t], t))
    if max_vocab is not None:
        vocab = vocab[: max(1, max_vocab)]
    if "<unk>" in vocab:
        vocab.remove("<unk>")
    vocab.insert(0, "<unk>")
    stoi = {tok: i for i, tok in enumerate(vocab)}
    itos = {i: tok for tok, i in stoi.items()}
    return stoi, itos


def encode_tokens(tokens: list[str], stoi: dict[str, int]) -> np.ndarray:
    """Map tokens to integer ids, sending out-of-vocabulary tokens to ``<unk>``.

    The fallback id is ``stoi["<unk>"]`` when present, otherwise 0. Returns a
    1-D ``int64`` array with one id per input token.
    """
    fallback = stoi["<unk>"] if "<unk>" in stoi else 0
    ids = [stoi.get(word, fallback) for word in tokens]
    return np.array(ids, dtype=np.int64)


def make_dummy_tokens(
    total_tokens: int,
    vocab_size: int,
    eos_every: int,
    seed: int,
) -> list[str]:
    """Generate a reproducible synthetic token stream for debugging.

    Draws ``total_tokens`` labels with replacement from ``tok0`` ...
    ``tok{vocab_size - 1}`` using ``np.random.default_rng(seed)``. When
    ``eos_every`` is positive, every ``eos_every``-th position (counting from
    1) is overwritten with ``<eos>``; otherwise no ``<eos>`` is inserted.
    """
    labels = [f"tok{k}" for k in range(vocab_size)]
    generator = np.random.default_rng(seed)
    stream = generator.choice(labels, size=total_tokens, replace=True).tolist()
    if eos_every <= 0:
        return stream
    # Positions eos_every-1, 2*eos_every-1, ... up to the end of the stream.
    marks = slice(eos_every - 1, None, eos_every)
    stream[marks] = ["<eos>"] * len(stream[marks])
    return stream


def load_tokens(
    ptb_path: str | None,
    use_dummy: bool,
    dummy_tokens: int,
    dummy_vocab: int,
    dummy_eos_every: int,
    seed: int,
) -> tuple[list[str], list[str], list[str]]:
    """Return ``(train, valid, test)`` token lists from PTB or synthetic data.

    With ``use_dummy`` set, a single stream from ``make_dummy_tokens`` is cut
    into the first 80% (train), the next 10% (valid) and the remainder (test),
    with boundaries truncated to whole tokens. Otherwise the three text splits
    returned by ``try_load_ptb_text(ptb_path)`` are each passed to
    ``tokenize``.

    Raises:
        RuntimeError: if fetching or unpacking the PTB text raises any
            ``Exception``; the original error is chained as ``__cause__``.
    """
    if not use_dummy:
        try:
            # Unpacking stays inside the try so a malformed return is also
            # reported as "PTB data unavailable".
            train_text, valid_text, test_text = try_load_ptb_text(ptb_path)
        except Exception as exc:
            raise RuntimeError(
                "PTB data unavailable. Provide --ptb-path or pass --dummy for debug data."
            ) from exc
        return tokenize(train_text), tokenize(valid_text), tokenize(test_text)

    stream = make_dummy_tokens(dummy_tokens, dummy_vocab, dummy_eos_every, seed)
    total = len(stream)
    train_end = int(total * 0.8)
    valid_end = train_end + int(total * 0.1)
    return stream[:train_end], stream[train_end:valid_end], stream[valid_end:]


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the PTB word-level LM experiment."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="ptb_word")
    parser.add_argument("--epochs", type=int, default=8)
    parser.add_argument("--scan-epochs", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=48)
    parser.add_argument("--hidden", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--ptb-path", type=str, default=None)
    parser.add_argument("--ptb-block-size", type=int, default=35)
    parser.add_argument("--ptb-steps-per-epoch", type=int, default=300)
    parser.add_argument("--ptb-val-steps", type=int, default=60)
    parser.add_argument("--ptb-max-tokens", type=int, default=None)
    parser.add_argument("--min-freq", type=int, default=1)
    parser.add_argument("--max-vocab", type=int, default=10000)
    parser.add_argument("--dummy", action="store_true")
    parser.add_argument("--dummy-tokens", type=int, default=50000)
    parser.add_argument("--dummy-vocab", type=int, default=100)
    parser.add_argument("--dummy-eos-every", type=int, default=20)
    parser.add_argument("--gains", type=str, default=None)
    parser.add_argument("--step-labels", type=str, choices=["final", "fptt"], default="final")
    parser.add_argument("--no-plot", action="store_true")
    parser.add_argument("--no-eprop", action="store_true")
    parser.add_argument("--plot-path", type=str, default=None)
    return parser


def _parse_gains(spec: str | None, parser: argparse.ArgumentParser) -> np.ndarray:
    """Turn the ``--gains`` option into an array of gain values.

    ``spec`` is a comma-separated list of numbers; blank entries are ignored.
    If ``spec`` is empty/``None`` or contains no entries, the default sweep
    ``np.linspace(0.5, 1.6, 8, endpoint=False)`` is returned. Parsed values use
    the default sweep's dtype, so downstream code sees the same dtype whether
    the gains were supplied or defaulted.

    Exits through ``parser.error`` (usage message, status 2) if an entry is
    not a number.
    """
    gains_default = np.linspace(0.5, 1.6, 8, endpoint=False)
    if not spec:
        return gains_default
    try:
        values = [float(x) for x in spec.split(",") if x.strip()]
    except ValueError as exc:
        parser.error(f"--gains must be a comma-separated list of numbers ({exc})")
    if not values:
        return gains_default
    return np.asarray(values, dtype=gains_default.dtype)


def main() -> None:
    """Parse CLI options, prepare PTB (or dummy) token ids and run the LM task.

    The function loads the train/valid/test tokens and, if requested, truncates
    each split to ``--ptb-max-tokens``. It builds the vocabulary from the
    training split only, encodes every split with ``encode_tokens``, and passes
    the resulting task description, the parsed args and the gain sweep to
    ``run_lm_task``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    args.plot = not args.no_plot
    # A negative bound would make the slices below drop tokens from the *end*
    # of each split instead of capping its length; 0 would leave no data.
    if args.ptb_max_tokens is not None and args.ptb_max_tokens < 1:
        parser.error("--ptb-max-tokens must be a positive integer")

    gains = _parse_gains(args.gains, parser)

    train_tokens, valid_tokens, test_tokens = load_tokens(
        args.ptb_path,
        args.dummy,
        args.dummy_tokens,
        args.dummy_vocab,
        args.dummy_eos_every,
        args.seed,
    )

    if args.ptb_max_tokens is not None:
        limit = args.ptb_max_tokens
        train_tokens = train_tokens[:limit]
        valid_tokens = valid_tokens[:limit]
        test_tokens = test_tokens[:limit]

    # The vocabulary comes from the training split only; valid/test words
    # missing from it are encoded with the <unk> id.
    stoi, _ = build_vocab(train_tokens, args.min_freq, args.max_vocab)
    train_ids = encode_tokens(train_tokens, stoi)
    valid_ids = encode_tokens(valid_tokens, stoi)
    test_ids = encode_tokens(test_tokens, stoi)

    vocab_size = len(stoi)
    task_data = {
        "task_type": "lm",
        "task_name": "PTB Word LM",
        "train_data": train_ids,
        "valid_data": valid_ids,
        "test_data": test_ids,
        "input_size": vocab_size,
        "output_size": vocab_size,
    }

    print("STAGE 1: Scanning gains for PTB word LM...")
    print("STAGE 2: Local Rule vs BPTT/E-Prop/FPTT on PTB word LM...")
    run_lm_task(task_data, args, gains)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
