import argparse
import io
import os

import numpy as np

from sequence_utils import build_char_vocab, encode_text, run_lm_task


def try_load_wikitext2_raw(data_path: str | None) -> tuple[str, str, str]:
    def read_file(path: str) -> str:
        with io.open(path, "r", encoding="utf-8") as f:
            return f.read()

    def try_local(dirpath: str | None) -> tuple[str, str, str] | None:
        if dirpath is None:
            return None
        train_path = os.path.join(dirpath, "wiki.train.raw")
        valid_path = os.path.join(dirpath, "wiki.valid.raw")
        test_path = os.path.join(dirpath, "wiki.test.raw")
        if os.path.exists(train_path) and os.path.exists(valid_path) and os.path.exists(test_path):
            print(f"[WikiText-2] Loaded raw files from '{dirpath}/'.")
            return read_file(train_path), read_file(valid_path), read_file(test_path)
        return None

    for cand in [
        data_path,
        os.path.join("data", "wikitext2_raw"),
        "wikitext-2-raw",
        "wikitext-2",
        "./data",
        "./",
    ]:
        got = try_local(cand)
        if got is not None:
            return got

    from datasets import load_dataset  # type: ignore

    print("[WikiText-2] Downloading wikitext-2-raw-v1 from HuggingFace datasets ...")
    ds = load_dataset("wikitext", "wikitext-2-raw-v1")
    train_text = "\n".join(ds["train"]["text"])
    valid_text = "\n".join(ds["validation"]["text"])
    test_text = "\n".join(ds["test"]["text"])

    cache_dir = data_path or os.path.join("data", "wikitext2_raw")
    os.makedirs(cache_dir, exist_ok=True)
    with io.open(os.path.join(cache_dir, "wiki.train.raw"), "w", encoding="utf-8") as f:
        f.write(train_text)
    with io.open(os.path.join(cache_dir, "wiki.valid.raw"), "w", encoding="utf-8") as f:
        f.write(valid_text)
    with io.open(os.path.join(cache_dir, "wiki.test.raw"), "w", encoding="utf-8") as f:
        f.write(test_text)
    print(f"[WikiText-2] Cached raw files under '{cache_dir}/'.")
    return train_text, valid_text, test_text


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="wikitext2_char")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--scan-epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--hidden", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--wikitext-path", type=str, default=None)
    parser.add_argument("--ptb-block-size", "--block-size", dest="ptb_block_size", type=int, default=80)
    parser.add_argument(
        "--ptb-steps-per-epoch",
        "--steps-per-epoch",
        dest="ptb_steps_per_epoch",
        type=int,
        default=200,
    )
    parser.add_argument("--ptb-val-steps", "--val-steps", dest="ptb_val_steps", type=int, default=60)
    parser.add_argument("--ptb-max-chars", "--max-chars", dest="ptb_max_chars", type=int, default=None)
    parser.add_argument("--eval-every", type=int, default=5)
    parser.add_argument("--gains", type=str, default=None)
    parser.add_argument("--step-labels", type=str, choices=["final", "fptt"], default="final")
    parser.add_argument("--no-plot", action="store_true")
    parser.add_argument("--no-eprop", action="store_true")
    parser.add_argument("--plot-path", type=str, default=None)
    args = parser.parse_args()
    args.plot = not args.no_plot

    gains_default = np.linspace(0.1, 2.2, 15, endpoint=False)
    if args.gains:
        gains = np.array([float(x) for x in args.gains.split(",") if x.strip()], dtype=np.float32)
        if gains.size == 0:
            gains = gains_default
    else:
        gains = gains_default

    train_text, valid_text, test_text = try_load_wikitext2_raw(args.wikitext_path)
    stoi, _ = build_char_vocab(train_text)
    train_ids = encode_text(train_text, stoi)
    valid_ids = encode_text(valid_text, stoi)
    test_ids = encode_text(test_text, stoi)

    if args.ptb_max_chars is not None:
        train_ids = train_ids[: args.ptb_max_chars]
        valid_ids = valid_ids[: args.ptb_max_chars]
        test_ids = test_ids[: args.ptb_max_chars]

    vocab_size = len(stoi)
    task_data = {
        "task_type": "lm",
        "task_name": "WikiText-2 Char LM",
        "train_data": train_ids,
        "valid_data": valid_ids,
        "test_data": test_ids,
        "input_size": vocab_size,
        "output_size": vocab_size,
    }

    print("STAGE 1: Scanning gains for WikiText-2 char LM...")
    print("STAGE 2: Local Rule vs BPTT/E-Prop/FPTT on WikiText-2 char LM...")
    run_lm_task(task_data, args, gains)


if __name__ == "__main__":
    # Only launch the experiment when executed as a script; importing this
    # module (e.g. to reuse try_load_wikitext2_raw) has no side effects.
    main()
