from __future__ import annotations

import json
import os
import shutil
import subprocess
import time
from pathlib import Path

import modal


APP_NAME = "katago-gptoss-finetune"
# Each of these may be overridden from the launching shell's environment.
MODEL_NAME = os.environ.get("MODEL_NAME", "unsloth/gpt-oss-20b")
DATA_VOLUME_NAME = os.environ.get("DATA_VOLUME_NAME", "katago-large-datasets")
OUTPUT_VOLUME_NAME = os.environ.get("OUTPUT_VOLUME_NAME", "katago-gptoss-finetunes")
DATASET_STEM = os.environ.get("DATASET_STEM", "katago_large_dataset_v1")
# Default run name is a Unix-timestamp suffix computed at import time, so it
# differs between the local process and each remote container import.
RUN_NAME = os.environ.get("RUN_NAME", f"gptoss_20b_katago_{int(time.time())}")

# Mount points / code location inside the remote container.
REMOTE_CODE_DIR = "/root/gptoss"
REMOTE_DATA_DIR = "/data"
REMOTE_OUTPUT_DIR = "/outputs"

# Directory containing this file; its contents are copied into the image so the
# training/generation scripts are available at REMOTE_CODE_DIR.
LOCAL_DIR = Path(__file__).resolve().parent

# Remote containers re-import this module without the launching shell's
# environment, so the overrides above would silently fall back to their defaults
# remotely. Baking the resolved values into the image keeps local and remote
# configuration in agreement. RUN_NAME is intentionally excluded: its timestamp
# default changes on every invocation and would rebuild this layer each run;
# pass run_name explicitly to remote calls to control it.
_PROPAGATED_ENV = {
    "MODEL_NAME": MODEL_NAME,
    "DATA_VOLUME_NAME": DATA_VOLUME_NAME,
    "OUTPUT_VOLUME_NAME": OUTPUT_VOLUME_NAME,
    "DATASET_STEM": DATASET_STEM,
}

image = (
    modal.Image.debian_slim(python_version="3.11")
    .apt_install("git")
    .pip_install(
        "accelerate>=0.34.0",
        "bitsandbytes>=0.43.3",
        "datasets>=2.21.0",
        "peft>=0.12.0",
        "sentencepiece>=0.2.0",
        "torch>=2.8.0",
        "triton>=3.4.0",
        "torchvision",
        "transformers>=4.51.3,<=5.5.0",
        "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo",
        "unsloth[base] @ git+https://github.com/unslothai/unsloth",
    )
    .env(_PROPAGATED_ENV)
    .add_local_dir(LOCAL_DIR, remote_path=REMOTE_CODE_DIR, copy=True)
)

app = modal.App(APP_NAME, image=image)
# The dataset volume must already exist; the output volume is created on demand.
data_volume = modal.Volume.from_name(DATA_VOLUME_NAME)
output_volume = modal.Volume.from_name(OUTPUT_VOLUME_NAME, create_if_missing=True)


def dataset_paths() -> tuple[Path, Path]:
    """Return the ``(train, eval)`` JSONL paths for DATASET_STEM on the data volume."""
    stem_root = Path(REMOTE_DATA_DIR, DATASET_STEM)
    train_file = stem_root / "train.jsonl"
    eval_file = stem_root / "eval.jsonl"
    return train_file, eval_file


def run_command(command: list[str], cwd: Path) -> None:
    """Run ``command`` in ``cwd`` and commit the output volume afterwards.

    The command line is echoed shell-quoted, so it can be copy-pasted to
    reproduce the run. The output volume is committed whether or not the
    command succeeds. Anything the subprocess wrote before failing (checkpoints,
    logs) is therefore persisted for inspection. A non-zero exit still raises
    ``subprocess.CalledProcessError`` once the commit has been attempted.
    """
    import shlex

    print("[CMD] " + shlex.join(command), flush=True)
    try:
        subprocess.run(command, cwd=cwd, check=True)
    finally:
        output_volume.commit()


def base_train_command(output_dir: Path) -> list[str]:
    """Build the shared argv prefix for ``finetune_gptoss.py``.

    The prefix wires in the model name, the train/eval dataset paths, the
    output directory and ``--bf16``. Callers append run-specific flags.
    """
    train_file, eval_file = dataset_paths()
    # Insertion order is preserved, so flags appear in exactly this order.
    options = {
        "--model-name": MODEL_NAME,
        "--train-path": str(train_file),
        "--eval-path": str(eval_file),
        "--output-dir": str(output_dir),
    }
    argv = ["python", "-u", "finetune_gptoss.py"]
    for flag, value in options.items():
        argv += [flag, value]
    argv.append("--bf16")
    return argv


@app.function(
    gpu="A10G",
    cpu=8,
    memory=49152,
    timeout=60 * 60,
    volumes={REMOTE_DATA_DIR: data_volume, REMOTE_OUTPUT_DIR: output_volume},
)
def smoke_test(run_name: str = RUN_NAME) -> dict:
    """Run a tiny ``finetune_gptoss.py --smoke-test`` job to validate the setup.

    Args:
        run_name: Subdirectory under ``<outputs>/smoke`` for this run. The
            default is the module-level RUN_NAME as resolved inside the
            container; pass it explicitly to control the output location.

    Returns:
        A dict with the model name, dataset paths, run name, output dir and
        the parsed ``metrics.json`` (``{}`` if the script wrote none).

    Raises:
        FileNotFoundError: if the train or eval dataset file is missing.
    """
    train_path, eval_path = dataset_paths()
    if not train_path.exists() or not eval_path.exists():
        raise FileNotFoundError(f"Missing dataset files: {train_path} / {eval_path}")

    output_dir = Path(REMOTE_OUTPUT_DIR) / "smoke" / run_name
    # Start from a clean directory so stale metrics from a previous smoke run
    # with the same name cannot be mistaken for this run's results.
    if output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    command = base_train_command(output_dir) + [
        "--smoke-test",
        "--max-seq-len",
        "2048",
        "--per-device-train-batch-size",
        "1",
        "--per-device-eval-batch-size",
        "1",
        "--gradient-accumulation-steps",
        "1",
    ]
    run_command(command, cwd=Path(REMOTE_CODE_DIR))
    metrics_path = output_dir / "metrics.json"
    return {
        "model_name": MODEL_NAME,
        "train_path": str(train_path),
        "eval_path": str(eval_path),
        "run_name": run_name,
        "output_dir": str(output_dir),
        "metrics": json.loads(metrics_path.read_text()) if metrics_path.exists() else {},
    }


@app.function(
    gpu="A10G",
    cpu=8,
    memory=49152,
    timeout=60 * 60 * 10,
    volumes={REMOTE_DATA_DIR: data_volume, REMOTE_OUTPUT_DIR: output_volume},
)
def finetune(
    run_name: str = RUN_NAME,
    max_seq_len: int = 2048,
    num_train_epochs: float = 2.0,
    learning_rate: float = 2e-4,
    gradient_accumulation_steps: int = 16,
    max_steps: int = -1,
) -> dict:
    """Run a full ``finetune_gptoss.py`` training job on the dataset volume.

    Args:
        run_name: Subdirectory of the output volume that receives this run's
            artifacts.
        max_seq_len: Forwarded as ``--max-seq-len``.
        num_train_epochs: Forwarded as ``--num-train-epochs``.
        learning_rate: Forwarded as ``--learning-rate``.
        gradient_accumulation_steps: Forwarded as ``--gradient-accumulation-steps``.
        max_steps: Forwarded as ``--max-steps``. Its interpretation (e.g. -1)
            is up to the training script.

    Returns:
        A dict with the model name, dataset paths, run name, output dir and
        the parsed ``metrics.json`` (``{}`` if the script wrote none).

    Raises:
        FileNotFoundError: if the train or eval dataset file is missing.
    """
    train_path, eval_path = dataset_paths()
    # Fail fast before paying for a GPU model load; matches the sibling jobs.
    if not train_path.exists() or not eval_path.exists():
        raise FileNotFoundError(f"Missing dataset files: {train_path} / {eval_path}")

    output_dir = Path(REMOTE_OUTPUT_DIR) / run_name
    output_dir.mkdir(parents=True, exist_ok=True)
    command = base_train_command(output_dir) + [
        "--max-seq-len",
        str(max_seq_len),
        "--num-train-epochs",
        str(num_train_epochs),
        "--learning-rate",
        str(learning_rate),
        "--gradient-accumulation-steps",
        str(gradient_accumulation_steps),
        "--max-steps",
        str(max_steps),
    ]
    run_command(command, cwd=Path(REMOTE_CODE_DIR))
    metrics_path = output_dir / "metrics.json"
    return {
        "model_name": MODEL_NAME,
        "train_path": str(train_path),
        "eval_path": str(eval_path),
        "run_name": run_name,
        "output_dir": str(output_dir),
        "metrics": json.loads(metrics_path.read_text()) if metrics_path.exists() else {},
    }


@app.function(
    gpu="A10G",
    cpu=8,
    memory=49152,
    timeout=60 * 60 * 2,
    volumes={REMOTE_DATA_DIR: data_volume, REMOTE_OUTPUT_DIR: output_volume},
)
def generate_random_explanations(
    run_name: str = RUN_NAME,
    split: str = "eval",
    num_samples: int = 100,
    seed: int = 42,
    adapter_run_name: str = "",
) -> dict:
    """Run ``generate_explanations.py`` on random samples from one dataset split.

    Args:
        run_name: Subdirectory under ``<outputs>/generations`` for the results.
        split: ``"eval"`` or ``"train"``.
        num_samples: Forwarded as ``--num-samples``.
        seed: Forwarded as ``--seed``.
        adapter_run_name: If non-empty, the adapter at
            ``<outputs>/<adapter_run_name>/adapter`` is passed as ``--adapter-path``.

    Returns:
        The parsed ``manifest.json`` written by the script. If it wrote none,
        returns ``{"output_dir": ...}``.

    Raises:
        ValueError: if ``split`` is not ``"eval"`` or ``"train"``.
        FileNotFoundError: if the dataset file or requested adapter is missing.
    """
    # Previously any value other than "eval" (e.g. a typo like "test") silently
    # fell back to the training split.
    if split not in ("eval", "train"):
        raise ValueError(f"split must be 'eval' or 'train', got {split!r}")
    train_path, eval_path = dataset_paths()
    data_path = eval_path if split == "eval" else train_path
    if not data_path.exists():
        raise FileNotFoundError(f"Missing dataset file: {data_path}")

    adapter_path = None
    if adapter_run_name:
        adapter_path = Path(REMOTE_OUTPUT_DIR) / adapter_run_name / "adapter"
        if not adapter_path.exists():
            raise FileNotFoundError(f"Missing adapter directory: {adapter_path}")

    output_dir = Path(REMOTE_OUTPUT_DIR) / "generations" / run_name
    output_dir.mkdir(parents=True, exist_ok=True)

    command = [
        "python",
        "-u",
        "generate_explanations.py",
        "--model-name",
        MODEL_NAME,
        "--data-path",
        str(data_path),
        "--output-dir",
        str(output_dir),
        "--num-samples",
        str(num_samples),
        "--seed",
        str(seed),
    ]
    if adapter_path is not None:
        command.extend(["--adapter-path", str(adapter_path)])
    run_command(command, cwd=Path(REMOTE_CODE_DIR))
    manifest_path = output_dir / "manifest.json"
    return json.loads(manifest_path.read_text()) if manifest_path.exists() else {"output_dir": str(output_dir)}


@app.function(
    gpu="A10G",
    cpu=8,
    memory=49152,
    timeout=60 * 60 * 2,
    volumes={REMOTE_DATA_DIR: data_volume, REMOTE_OUTPUT_DIR: output_volume},
)
def generate_matrix_explanations(
    run_name: str = RUN_NAME,
    split: str = "eval",
    sample_index: int = 0,
    num_samples: int = 1,
    max_new_tokens: int = 180,
    fast_inference: bool = False,
) -> dict:
    """Run ``generate_matrix_explanations.py`` and write a merged manifest.

    Args:
        run_name: Subdirectory under ``<outputs>/matrix_generations``.
        split: ``"eval"`` or ``"train"``.
        sample_index: Forwarded as ``--sample-index``.
        num_samples: Forwarded as ``--num-samples``.
        max_new_tokens: Forwarded as ``--max-new-tokens``.
        fast_inference: When False, the Unsloth compile / fast-generation
            environment switches are set to disable those paths. When True,
            ``--fast-inference`` is passed to the script instead.

    Returns:
        The manifest dict, also written to ``manifest.json``. It combines the
        run parameters with the contents of ``matrix_inference_results.json``
        if present; keys from that file override the parameter keys.

    Raises:
        ValueError: if ``split`` is not ``"eval"`` or ``"train"``.
        FileNotFoundError: if the selected dataset file is missing.
    """
    # Previously any value other than "eval" silently fell back to train data.
    if split not in ("eval", "train"):
        raise ValueError(f"split must be 'eval' or 'train', got {split!r}")
    train_path, eval_path = dataset_paths()
    data_path = eval_path if split == "eval" else train_path
    if not data_path.exists():
        raise FileNotFoundError(f"Missing dataset file: {data_path}")

    output_root = Path(REMOTE_OUTPUT_DIR) / "matrix_generations" / run_name
    output_root.mkdir(parents=True, exist_ok=True)
    env = os.environ.copy()
    if not fast_inference:
        env["UNSLOTH_COMPILE_DISABLE"] = "1"
        env["UNSLOTH_DISABLE_FAST_GENERATION"] = "1"

    command = [
        "python",
        "-u",
        "generate_matrix_explanations.py",
        "--model-name",
        MODEL_NAME,
        "--data-path",
        str(data_path),
        "--output-dir",
        str(output_root),
        "--sample-index",
        str(sample_index),
        "--num-samples",
        str(num_samples),
        "--max-new-tokens",
        str(max_new_tokens),
    ]
    if fast_inference:
        command.append("--fast-inference")
    print("[CMD] " + " ".join(command), flush=True)
    subprocess.run(command, cwd=Path(REMOTE_CODE_DIR), check=True, env=env)
    results_path = output_root / "matrix_inference_results.json"
    results_payload = json.loads(results_path.read_text()) if results_path.exists() else {}
    manifest = {
        "output_root": str(output_root),
        "split": split,
        "sample_index": sample_index,
        "num_samples": num_samples,
        "max_new_tokens": max_new_tokens,
        **results_payload,
    }
    (output_root / "manifest.json").write_text(json.dumps(manifest, indent=2))
    output_volume.commit()
    return manifest


@app.function(
    gpu="A10G",
    cpu=8,
    memory=49152,
    timeout=60 * 60 * 2,
    volumes={REMOTE_DATA_DIR: data_volume, REMOTE_OUTPUT_DIR: output_volume},
)
def consistency_smoke(run_name: str = RUN_NAME) -> dict:
    """Run a minimal ``train_go_consistency.py`` job on ``full.jsonl``.

    The job uses 2 train games, 1 eval game and 1 epoch.

    Args:
        run_name: Subdirectory under ``<outputs>/causal_consistency_smoke``.
            Any existing directory with that name is removed first.

    Returns:
        A dict with the model name, data path, output dir and the parsed
        ``metrics.json`` (``{}`` if the script wrote none).

    Raises:
        FileNotFoundError: if ``full.jsonl`` is missing.
    """
    data_path = Path(REMOTE_DATA_DIR) / DATASET_STEM / "full.jsonl"
    # Check before deleting the previous output so a misconfigured dataset
    # neither wipes earlier results nor spends GPU time loading the model.
    if not data_path.exists():
        raise FileNotFoundError(f"Missing dataset file: {data_path}")

    output_dir = Path(REMOTE_OUTPUT_DIR) / "causal_consistency_smoke" / run_name
    if output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    command = [
        "python",
        "-u",
        "train_go_consistency.py",
        "--data-path",
        str(data_path),
        "--output-dir",
        str(output_dir),
        "--model-name",
        MODEL_NAME,
        "--train-games",
        "2",
        "--eval-games",
        "1",
        "--num-train-epochs",
        "1",
        "--per-device-train-batch-size",
        "1",
        "--per-device-eval-batch-size",
        "1",
        "--gradient-accumulation-steps",
        "1",
        "--max-seq-length",
        "768",
        "--lambda-lm",
        "0.0",
        "--lambda-claim",
        "1.0",
        "--generate-samples",
        "0",
        "--max-new-tokens",
        "32",
        "--bf16",
        "--log-every",
        "1",
    ]
    run_command(command, cwd=Path(REMOTE_CODE_DIR))
    metrics_path = output_dir / "metrics.json"
    return {
        "model_name": MODEL_NAME,
        "data_path": str(data_path),
        "output_dir": str(output_dir),
        "metrics": json.loads(metrics_path.read_text()) if metrics_path.exists() else {},
    }


@app.function(
    gpu="A10G",
    cpu=8,
    memory=49152,
    timeout=60 * 60 * 12,
    volumes={REMOTE_DATA_DIR: data_volume, REMOTE_OUTPUT_DIR: output_volume},
)
def consistency_train(
    run_name: str = RUN_NAME,
    max_seq_length: int = 1024,
    num_train_epochs: int = 3,
    learning_rate: float = 2e-4,
    per_device_train_batch_size: int = 1,
    per_device_eval_batch_size: int = 1,
    gradient_accumulation_steps: int = 16,
    lambda_claim: float = 1.0,
    generate_samples: int = 0,
    max_new_tokens: int = 160,
) -> dict:
    """Run a configurable ``train_go_consistency.py`` job on the train/eval splits.

    Every keyword argument is forwarded as the matching ``--kebab-case`` flag.
    ``--lambda-lm`` is pinned to 0.0, and the job logs every 10 steps in bf16.

    Returns:
        A dict with the model name, dataset paths, run name, output dir and
        the parsed ``metrics.json`` (``{}`` if the script wrote none).

    Raises:
        FileNotFoundError: if the train or eval dataset file is missing.
    """
    train_path, eval_path = dataset_paths()
    if not (train_path.exists() and eval_path.exists()):
        raise FileNotFoundError(f"Missing dataset files: {train_path} / {eval_path}")

    output_dir = Path(REMOTE_OUTPUT_DIR) / "causal_consistency_train" / run_name
    output_dir.mkdir(parents=True, exist_ok=True)

    # Flag/value pairs in the exact order the script receives them.
    flag_values = (
        ("--train-data-path", str(train_path)),
        ("--eval-data-path", str(eval_path)),
        ("--output-dir", str(output_dir)),
        ("--model-name", MODEL_NAME),
        ("--num-train-epochs", str(num_train_epochs)),
        ("--per-device-train-batch-size", str(per_device_train_batch_size)),
        ("--per-device-eval-batch-size", str(per_device_eval_batch_size)),
        ("--gradient-accumulation-steps", str(gradient_accumulation_steps)),
        ("--max-seq-length", str(max_seq_length)),
        ("--learning-rate", str(learning_rate)),
        ("--lambda-lm", "0.0"),
        ("--lambda-claim", str(lambda_claim)),
        ("--generate-samples", str(generate_samples)),
        ("--max-new-tokens", str(max_new_tokens)),
    )
    argv = ["python", "-u", "train_go_consistency.py"]
    for flag, value in flag_values:
        argv.extend((flag, value))
    argv.extend(("--bf16", "--log-every", "10"))

    run_command(argv, cwd=Path(REMOTE_CODE_DIR))

    metrics_file = output_dir / "metrics.json"
    metrics = json.loads(metrics_file.read_text()) if metrics_file.exists() else {}
    return {
        "model_name": MODEL_NAME,
        "train_path": str(train_path),
        "eval_path": str(eval_path),
        "run_name": run_name,
        "output_dir": str(output_dir),
        "metrics": metrics,
    }


@app.function(
    gpu="A10G",
    cpu=8,
    memory=49152,
    timeout=60 * 60 * 4,
    volumes={REMOTE_DATA_DIR: data_volume, REMOTE_OUTPUT_DIR: output_volume},
)
def consistency_claim_10x10(run_name: str = RUN_NAME) -> dict:
    """Overfit-style check: 10 train / 10 eval positions, 10 epochs, claim loss only.

    Returns:
        A dict with the model name, dataset paths, run name, output dir and
        the parsed ``metrics.json`` (``{}`` if the script wrote none).

    Raises:
        FileNotFoundError: if the train or eval dataset file is missing.
    """
    train_path, eval_path = dataset_paths()
    if not (train_path.exists() and eval_path.exists()):
        raise FileNotFoundError(f"Missing dataset files: {train_path} / {eval_path}")

    output_dir = Path(REMOTE_OUTPUT_DIR) / "causal_consistency_claim_10x10" / run_name
    output_dir.mkdir(parents=True, exist_ok=True)

    # Fixed hyperparameters, in the order they are passed to the script.
    hyperparams = {
        "--max-train-positions": "10",
        "--max-eval-positions": "10",
        "--num-train-epochs": "10",
        "--per-device-train-batch-size": "1",
        "--per-device-eval-batch-size": "1",
        "--gradient-accumulation-steps": "1",
        "--max-seq-length": "1024",
        "--learning-rate": "2e-4",
        "--lambda-lm": "0.0",
        "--lambda-claim": "1.0",
        "--generate-samples": "0",
    }
    argv = [
        "python",
        "-u",
        "train_go_consistency.py",
        "--train-data-path",
        str(train_path),
        "--eval-data-path",
        str(eval_path),
        "--output-dir",
        str(output_dir),
        "--model-name",
        MODEL_NAME,
    ]
    for flag, value in hyperparams.items():
        argv.extend([flag, value])
    argv.extend(["--bf16", "--log-every", "1"])

    run_command(argv, cwd=Path(REMOTE_CODE_DIR))

    metrics_file = output_dir / "metrics.json"
    metrics = json.loads(metrics_file.read_text()) if metrics_file.exists() else {}
    return {
        "model_name": MODEL_NAME,
        "train_path": str(train_path),
        "eval_path": str(eval_path),
        "run_name": run_name,
        "output_dir": str(output_dir),
        "metrics": metrics,
    }


@app.function(
    gpu="A10G",
    cpu=8,
    memory=49152,
    timeout=60 * 60 * 14,
    volumes={REMOTE_DATA_DIR: data_volume, REMOTE_OUTPUT_DIR: output_volume},
)
def consistency_claim_1k_200(run_name: str = RUN_NAME) -> dict:
    """Mid-size claim-only run: 1000 train / 200 eval positions for 5 epochs.

    Returns:
        A dict with the model name, dataset paths, run name, output dir and
        the parsed ``metrics.json`` (``{}`` if the script wrote none).

    Raises:
        FileNotFoundError: if the train or eval dataset file is missing.
    """
    train_path, eval_path = dataset_paths()
    if not (train_path.exists() and eval_path.exists()):
        raise FileNotFoundError(f"Missing dataset files: {train_path} / {eval_path}")

    output_dir = Path(REMOTE_OUTPUT_DIR) / "causal_consistency_claim_1k_200" / run_name
    output_dir.mkdir(parents=True, exist_ok=True)

    # Fixed hyperparameters, in the order they are passed to the script.
    hyperparams = {
        "--max-train-positions": "1000",
        "--max-eval-positions": "200",
        "--num-train-epochs": "5",
        "--per-device-train-batch-size": "2",
        "--per-device-eval-batch-size": "2",
        "--gradient-accumulation-steps": "4",
        "--max-seq-length": "1024",
        "--learning-rate": "2e-4",
        "--lambda-lm": "0.0",
        "--lambda-claim": "1.0",
        "--generate-samples": "0",
    }
    argv = [
        "python",
        "-u",
        "train_go_consistency.py",
        "--train-data-path",
        str(train_path),
        "--eval-data-path",
        str(eval_path),
        "--output-dir",
        str(output_dir),
        "--model-name",
        MODEL_NAME,
    ]
    for flag, value in hyperparams.items():
        argv.extend([flag, value])
    argv.extend(["--bf16", "--log-every", "25"])

    run_command(argv, cwd=Path(REMOTE_CODE_DIR))

    metrics_file = output_dir / "metrics.json"
    metrics = json.loads(metrics_file.read_text()) if metrics_file.exists() else {}
    return {
        "model_name": MODEL_NAME,
        "train_path": str(train_path),
        "eval_path": str(eval_path),
        "run_name": run_name,
        "output_dir": str(output_dir),
        "metrics": metrics,
    }


@app.function(
    gpu="H100",
    cpu=8,
    memory=49152,
    timeout=60 * 60 * 24,
    volumes={REMOTE_DATA_DIR: data_volume, REMOTE_OUTPUT_DIR: output_volume},
)
def consistency_claim_full_10ep(run_name: str = RUN_NAME) -> dict:
    """Claim-only run on the full train/eval splits for 10 epochs, on an H100.

    Returns:
        A dict with the model name, dataset paths, run name, output dir and
        the parsed ``metrics.json`` (``{}`` if the script wrote none).

    Raises:
        FileNotFoundError: if the train or eval dataset file is missing.
    """
    train_path, eval_path = dataset_paths()
    if not (train_path.exists() and eval_path.exists()):
        raise FileNotFoundError(f"Missing dataset files: {train_path} / {eval_path}")

    output_dir = Path(REMOTE_OUTPUT_DIR) / "causal_consistency_claim_full_10ep" / run_name
    output_dir.mkdir(parents=True, exist_ok=True)

    # Fixed hyperparameters, in the order they are passed to the script.
    # No position caps: the whole dataset is used.
    hyperparams = {
        "--num-train-epochs": "10",
        "--per-device-train-batch-size": "2",
        "--per-device-eval-batch-size": "2",
        "--gradient-accumulation-steps": "4",
        "--max-seq-length": "1024",
        "--learning-rate": "2e-4",
        "--lambda-lm": "0.0",
        "--lambda-claim": "1.0",
        "--generate-samples": "0",
    }
    argv = [
        "python",
        "-u",
        "train_go_consistency.py",
        "--train-data-path",
        str(train_path),
        "--eval-data-path",
        str(eval_path),
        "--output-dir",
        str(output_dir),
        "--model-name",
        MODEL_NAME,
    ]
    for flag, value in hyperparams.items():
        argv.extend([flag, value])
    argv.extend(["--bf16", "--log-every", "25"])

    run_command(argv, cwd=Path(REMOTE_CODE_DIR))

    metrics_file = output_dir / "metrics.json"
    metrics = json.loads(metrics_file.read_text()) if metrics_file.exists() else {}
    return {
        "model_name": MODEL_NAME,
        "train_path": str(train_path),
        "eval_path": str(eval_path),
        "run_name": run_name,
        "output_dir": str(output_dir),
        "metrics": metrics,
    }


@app.function(
    cpu=2,
    memory=8192,
    timeout=60 * 10,
    volumes={REMOTE_DATA_DIR: data_volume},
)
def claim_value_stats() -> dict:
    """Tally the distinct values of each claim field over 19x19 rows of ``full.jsonl``.

    Blank lines are skipped. Rows whose ``board_size`` is set and not 19 count
    toward ``rows`` but are excluded from the tallies; a falsy or absent
    ``board_size`` is treated as 19. Missing/None field values are tallied
    under ``"<MISSING>"``; every other value is tallied by its ``str()`` form.

    Returns:
        ``data_path``, total ``rows``, ``rows_19x19``, the sorted distinct
        ``values`` per field, and the per-value ``counts`` (sorted by value)
        per field.
    """
    data_path = Path(REMOTE_DATA_DIR) / DATASET_STEM / "full.jsonl"
    fields = [
        "win_prob_bin",
        "score_lead_bin",
        "phase_estimate",
        "main_control_region",
        "main_contested_region",
        "global_contestedness",
        "best_move_region",
        "move_urgency",
        "search_surprise",
    ]
    tallies: dict[str, dict[str, int]] = {name: {} for name in fields}
    total_rows = 0
    board19_rows = 0
    with data_path.open() as handle:
        for raw in handle:
            if not raw.strip():
                continue
            record = json.loads(raw)
            total_rows += 1
            if int(record.get("board_size") or 19) != 19:
                continue
            board19_rows += 1
            for name in fields:
                raw_value = record.get(name)
                key = "<MISSING>" if raw_value is None else str(raw_value)
                bucket = tallies[name]
                bucket[key] = bucket.get(key, 0) + 1
    return {
        "data_path": str(data_path),
        "rows": total_rows,
        "rows_19x19": board19_rows,
        "values": {name: sorted(tallies[name]) for name in fields},
        "counts": {name: dict(sorted(tallies[name].items())) for name in fields},
    }


@app.function(cpu=2, memory=4096, timeout=60 * 10)
def dependency_versions() -> dict:
    """Report installed versions of key packages and where the unsloth modules live.

    Returns:
        ``{"versions": {dist: version-or-"not installed"},
        "module_paths": {module: spec-origin-or-"not found"}}``.
    """
    import importlib.metadata
    import importlib.util

    def installed_version(dist_name: str) -> str:
        try:
            return importlib.metadata.version(dist_name)
        except importlib.metadata.PackageNotFoundError:
            return "not installed"

    def module_origin(module_name: str) -> str:
        spec = importlib.util.find_spec(module_name)
        if spec and spec.origin:
            return spec.origin
        return "not found"

    tracked = ("unsloth", "unsloth_zoo", "transformers", "torch", "triton", "accelerate")
    versions = {dist: installed_version(dist) for dist in tracked}
    module_paths = {name: module_origin(name) for name in ("unsloth", "unsloth_zoo")}
    return {"versions": versions, "module_paths": module_paths}


@app.local_entrypoint()
def main(
    mode: str = "smoke",
    split: str = "eval",
    sample_index: int = 0,
    num_samples: int = 1,
    max_new_tokens: int = 180,
    fast_inference: bool = False,
) -> None:
    """Dispatch ``mode`` to the matching remote Modal function and print its result.

    Args:
        mode: One of the keys of the dispatch table below.
        split: Dataset split for the generation modes ("generate",
            "matrix-generate").
        sample_index, num_samples, max_new_tokens, fast_inference: Forwarded
            only in "matrix-generate" mode.

    Raises:
        ValueError: if ``mode`` is not recognised.
    """
    # This entrypoint runs locally, where RUN_NAME reflects the caller's
    # environment (or a local timestamp). Remote containers re-import the
    # module without that environment and would compute a different default,
    # so the local value is forwarded to every function accepting run_name.
    run_name = RUN_NAME
    dispatch = {
        "smoke": lambda: smoke_test.remote(),
        "train": lambda: finetune.remote(run_name=run_name),
        "generate": lambda: generate_random_explanations.remote(run_name=run_name, split=split),
        "matrix-generate": lambda: generate_matrix_explanations.remote(
            run_name=run_name,
            split=split,
            sample_index=sample_index,
            num_samples=num_samples,
            max_new_tokens=max_new_tokens,
            fast_inference=fast_inference,
        ),
        "consistency-smoke": lambda: consistency_smoke.remote(run_name=run_name),
        "consistency-train": lambda: consistency_train.remote(run_name=run_name),
        "consistency-claim-10x10": lambda: consistency_claim_10x10.remote(run_name=run_name),
        "consistency-claim-1k-200": lambda: consistency_claim_1k_200.remote(run_name=run_name),
        "consistency-claim-full-10ep": lambda: consistency_claim_full_10ep.remote(run_name=run_name),
        "claim-stats": lambda: claim_value_stats.remote(),
        "deps": lambda: dependency_versions.remote(),
    }
    runner = dispatch.get(mode)
    if runner is None:
        raise ValueError(
            "mode must be 'smoke', 'train', 'generate', 'matrix-generate', 'consistency-smoke', "
            "'consistency-train', 'consistency-claim-10x10', 'consistency-claim-1k-200', "
            "'consistency-claim-full-10ep', 'claim-stats', or 'deps'"
        )
    result = runner()
    print(json.dumps(result, indent=2))
