from __future__ import annotations

import json
import os
import shlex
import shutil
import subprocess
import time
from collections import Counter
from pathlib import Path

import modal


# Name under which the Modal app is registered.
APP_NAME = "katago-smollm-finetune"
# The settings below may be overridden through environment variables of the
# same name; otherwise the literal default is used.
# NOTE(review): os.environ is read by whichever process imports this module.
# Presumably the remote container imports it again, so overrides exported only
# on the launching machine may not be visible there -- confirm.
MODEL_NAME = os.environ.get("MODEL_NAME", "unsloth/SmolLM3-3B-Base")
DATA_VOLUME_NAME = os.environ.get("DATA_VOLUME_NAME", "katago-large-datasets")
OUTPUT_VOLUME_NAME = os.environ.get("OUTPUT_VOLUME_NAME", "katago-smollm-finetunes")
# Sub-directory of the data volume that holds the dataset's JSONL files.
DATASET_STEM = os.environ.get("DATASET_STEM", "katago_large_dataset_v1")
# Default run name embeds the Unix time (whole seconds) at import.
RUN_NAME = os.environ.get("RUN_NAME", f"smollm3_3b_katago_{int(time.time())}")

# Paths inside the container: where the local sources are placed by the image
# and where the data / output volumes are mounted by each function.
REMOTE_CODE_DIR = "/root/smollm"
REMOTE_DATA_DIR = "/data"
REMOTE_OUTPUT_DIR = "/outputs"

# Directory containing this file; its contents are added to the image.
LOCAL_DIR = Path(__file__).resolve().parent

# Container image: Debian slim with Python 3.11, git and the training stack.
# transformers is pinned to an exact version, unsloth is unpinned, and the
# remaining packages only have lower bounds.
image = (
    modal.Image.debian_slim(python_version="3.11")
    .apt_install("git")
    .pip_install(
        "accelerate>=0.34.0",
        "bitsandbytes>=0.43.3",
        "datasets>=2.21.0",
        "peft>=0.12.0",
        "sentencepiece>=0.2.0",
        "torch>=2.4.0",
        "transformers==4.53.0",
        "unsloth",
    )
    # Ship everything next to this file to REMOTE_CODE_DIR so the training
    # scripts can be launched from there.
    # NOTE(review): copy=True presumably bakes the files into an image layer
    # instead of attaching them at container start -- confirm in Modal docs.
    .add_local_dir(LOCAL_DIR, remote_path=REMOTE_CODE_DIR, copy=True)
)

app = modal.App(APP_NAME, image=image)
# No create_if_missing here, unlike the output volume below; presumably the
# dataset volume has to exist already -- confirm.
data_volume = modal.Volume.from_name(DATA_VOLUME_NAME)
output_volume = modal.Volume.from_name(OUTPUT_VOLUME_NAME, create_if_missing=True)


def dataset_paths() -> tuple[Path, Path]:
    """Return the ``(train, eval)`` JSONL paths of the configured dataset.

    Both files live in ``REMOTE_DATA_DIR / DATASET_STEM``. The paths are only
    constructed here; nothing checks that they exist.
    """
    root = Path(REMOTE_DATA_DIR, DATASET_STEM)
    train_file = root / "train.jsonl"
    eval_file = root / "eval.jsonl"
    return train_file, eval_file


def run_command(command: list[str], cwd: Path) -> None:
    """Run *command* as a subprocess in *cwd* and commit the output volume.

    The command line is echoed first, shell-quoted so that arguments
    containing spaces or special characters are unambiguous and the line can
    be pasted back into a shell.

    Args:
        command: Program and arguments, passed to ``subprocess.run`` as a
            list (no shell involved).
        cwd: Working directory for the subprocess.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status.
    """
    print("[CMD] " + shlex.join(command), flush=True)
    try:
        subprocess.run(command, cwd=cwd, check=True)
    finally:
        # Commit even when the command fails so that whatever it managed to
        # write (logs, partial checkpoints) is persisted for inspection.
        output_volume.commit()


def base_train_command(output_dir: Path) -> list[str]:
    """Build the argument list shared by all ``finetune_smollm.py`` runs.

    The list covers the interpreter invocation, the model name, the train and
    eval dataset paths from ``dataset_paths()``, the given *output_dir* and
    the ``--bf16`` flag. Callers append their own options to it.
    """
    train_file, eval_file = dataset_paths()
    launcher = ["python", "-u", "finetune_smollm.py"]
    options = [
        "--model-name", MODEL_NAME,
        "--train-path", str(train_file),
        "--eval-path", str(eval_file),
        "--output-dir", str(output_dir),
    ]
    return [*launcher, *options, "--bf16"]


@app.function(
    gpu="A10G",
    cpu=8,
    memory=49152,
    timeout=60 * 60,
    volumes={REMOTE_DATA_DIR: data_volume, REMOTE_OUTPUT_DIR: output_volume},
)
def smoke_test() -> dict:
    """Run ``finetune_smollm.py`` in smoke-test mode with minimal batch sizes.

    Output goes to ``REMOTE_OUTPUT_DIR/smoke/RUN_NAME``, which is wiped first
    if it already exists.

    Returns:
        A dict with the model name, the dataset paths, the output directory
        and the parsed ``metrics.json`` (an empty dict when the file was not
        produced).

    Raises:
        FileNotFoundError: If the train or eval dataset file is missing.
    """
    train_path, eval_path = dataset_paths()
    if not (train_path.exists() and eval_path.exists()):
        raise FileNotFoundError(f"Missing dataset files: {train_path} / {eval_path}")

    # Always start from an empty directory so stale artifacts cannot leak in.
    output_dir = Path(REMOTE_OUTPUT_DIR, "smoke", RUN_NAME)
    if output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    smoke_flags = [
        "--smoke-test",
        "--max-seq-len", "512",
        "--per-device-train-batch-size", "1",
        "--per-device-eval-batch-size", "1",
        "--gradient-accumulation-steps", "1",
    ]
    run_command(base_train_command(output_dir) + smoke_flags, cwd=Path(REMOTE_CODE_DIR))

    metrics_file = output_dir / "metrics.json"
    if metrics_file.exists():
        metrics = json.loads(metrics_file.read_text())
    else:
        metrics = {}
    return {
        "model_name": MODEL_NAME,
        "train_path": str(train_path),
        "eval_path": str(eval_path),
        "output_dir": str(output_dir),
        "metrics": metrics,
    }


@app.function(
    gpu="A10G",
    cpu=8,
    memory=49152,
    timeout=60 * 60 * 10,
    volumes={REMOTE_DATA_DIR: data_volume, REMOTE_OUTPUT_DIR: output_volume},
)
def finetune(
    run_name: str = RUN_NAME,
    max_seq_len: int = 1024,
    num_train_epochs: float = 2.0,
    learning_rate: float = 2e-4,
    gradient_accumulation_steps: int = 16,
    max_steps: int = -1,
) -> dict:
    """Run a full ``finetune_smollm.py`` training job.

    Output goes to ``REMOTE_OUTPUT_DIR/run_name``; an existing directory is
    reused, not cleared.

    Args:
        run_name: Name of the output sub-directory.
        max_seq_len: Forwarded as ``--max-seq-len``.
        num_train_epochs: Forwarded as ``--num-train-epochs``.
        learning_rate: Forwarded as ``--learning-rate``.
        gradient_accumulation_steps: Forwarded as
            ``--gradient-accumulation-steps``.
        max_steps: Forwarded as ``--max-steps``. The default ``-1`` is
            presumably "no step cap" -- confirm in finetune_smollm.py.

    Returns:
        A dict with the model name, run name, output directory and the parsed
        ``metrics.json`` (an empty dict when the file was not produced).

    Raises:
        FileNotFoundError: If the train or eval dataset file is missing.
    """
    # Fail fast, like the other training entry points, instead of letting the
    # subprocess discover the missing files after start-up.
    train_path, eval_path = dataset_paths()
    if not train_path.exists() or not eval_path.exists():
        raise FileNotFoundError(f"Missing dataset files: {train_path} / {eval_path}")

    output_dir = Path(REMOTE_OUTPUT_DIR) / run_name
    output_dir.mkdir(parents=True, exist_ok=True)
    command = base_train_command(output_dir) + [
        "--max-seq-len",
        str(max_seq_len),
        "--num-train-epochs",
        str(num_train_epochs),
        "--learning-rate",
        str(learning_rate),
        "--gradient-accumulation-steps",
        str(gradient_accumulation_steps),
        "--max-steps",
        str(max_steps),
    ]
    run_command(command, cwd=Path(REMOTE_CODE_DIR))
    metrics_path = output_dir / "metrics.json"
    return {
        "model_name": MODEL_NAME,
        "run_name": run_name,
        "output_dir": str(output_dir),
        "metrics": json.loads(metrics_path.read_text()) if metrics_path.exists() else {},
    }


@app.function(
    gpu="A10G",
    cpu=8,
    memory=49152,
    timeout=60 * 60 * 2,
    volumes={REMOTE_DATA_DIR: data_volume, REMOTE_OUTPUT_DIR: output_volume},
)
def generate_random_explanations(
    run_name: str = RUN_NAME,
    split: str = "eval",
    num_samples: int = 100,
    seed: int = 42,
    adapter_run_name: str = "",
) -> dict:
    """Run ``generate_explanations.py`` on one dataset split.

    Output goes to ``REMOTE_OUTPUT_DIR/generations/run_name``.

    Args:
        run_name: Name of the output sub-directory.
        split: ``"eval"`` or ``"train"``; selects which dataset file is used.
        num_samples: Forwarded as ``--num-samples``.
        seed: Forwarded as ``--seed``.
        adapter_run_name: When non-empty,
            ``REMOTE_OUTPUT_DIR/adapter_run_name/adapter`` is passed as
            ``--adapter-path``; otherwise no adapter option is given.

    Returns:
        The parsed ``manifest.json`` from the output directory, or
        ``{"output_dir": ...}`` when no manifest was produced.

    Raises:
        ValueError: If *split* is neither ``"train"`` nor ``"eval"``.
        FileNotFoundError: If the selected dataset file is missing.
    """
    # Reject unknown splits explicitly; silently falling back to the train
    # file would hide typos such as "evla".
    if split not in ("train", "eval"):
        raise ValueError(f"split must be 'train' or 'eval', got {split!r}")

    train_path, eval_path = dataset_paths()
    data_path = eval_path if split == "eval" else train_path
    if not data_path.exists():
        raise FileNotFoundError(f"Missing dataset file: {data_path}")

    output_dir = Path(REMOTE_OUTPUT_DIR) / "generations" / run_name
    output_dir.mkdir(parents=True, exist_ok=True)

    command = [
        "python",
        "-u",
        "generate_explanations.py",
        "--model-name",
        MODEL_NAME,
        "--data-path",
        str(data_path),
        "--output-dir",
        str(output_dir),
        "--num-samples",
        str(num_samples),
        "--seed",
        str(seed),
    ]
    if adapter_run_name:
        command.extend(["--adapter-path", str(Path(REMOTE_OUTPUT_DIR) / adapter_run_name / "adapter")])
    run_command(command, cwd=Path(REMOTE_CODE_DIR))
    manifest_path = output_dir / "manifest.json"
    return json.loads(manifest_path.read_text()) if manifest_path.exists() else {"output_dir": str(output_dir)}


@app.function(
    gpu="A10G",
    cpu=8,
    memory=49152,
    timeout=60 * 60 * 2,
    volumes={REMOTE_DATA_DIR: data_volume, REMOTE_OUTPUT_DIR: output_volume},
)
def consistency_smoke(run_name: str = RUN_NAME) -> dict:
    """Run a minimal ``train_go_consistency.py`` job on ``full.jsonl``.

    The job uses fixed, small settings (2 train games, 1 eval game, 1 epoch,
    batch size 1). Output goes to
    ``REMOTE_OUTPUT_DIR/causal_consistency_smoke/run_name``, which is wiped
    first if it already exists.

    Args:
        run_name: Name of the output sub-directory.

    Returns:
        A dict with the model name, data path, output directory and the parsed
        ``metrics.json`` (an empty dict when the file was not produced).

    Raises:
        FileNotFoundError: If ``full.jsonl`` is missing from the dataset
            directory.
    """
    data_path = Path(REMOTE_DATA_DIR) / DATASET_STEM / "full.jsonl"
    # Check the input before deleting anything: a missing dataset should not
    # cost us the results of a previous run under the same name.
    if not data_path.exists():
        raise FileNotFoundError(f"Missing dataset file: {data_path}")

    output_dir = Path(REMOTE_OUTPUT_DIR) / "causal_consistency_smoke" / run_name
    if output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    command = [
        "python",
        "-u",
        "train_go_consistency.py",
        "--data-path",
        str(data_path),
        "--output-dir",
        str(output_dir),
        "--model-name",
        MODEL_NAME,
        "--train-games",
        "2",
        "--eval-games",
        "1",
        "--num-train-epochs",
        "1",
        "--per-device-train-batch-size",
        "1",
        "--per-device-eval-batch-size",
        "1",
        "--gradient-accumulation-steps",
        "1",
        "--max-seq-length",
        "768",
        "--generate-samples",
        "1",
        "--max-new-tokens",
        "32",
        "--bf16",
        "--log-every",
        "1",
    ]
    run_command(command, cwd=Path(REMOTE_CODE_DIR))
    metrics_path = output_dir / "metrics.json"
    return {
        "model_name": MODEL_NAME,
        "data_path": str(data_path),
        "output_dir": str(output_dir),
        "metrics": json.loads(metrics_path.read_text()) if metrics_path.exists() else {},
    }


@app.function(
    gpu="A10G",
    cpu=8,
    memory=49152,
    timeout=60 * 60 * 12,
    volumes={REMOTE_DATA_DIR: data_volume, REMOTE_OUTPUT_DIR: output_volume},
)
def consistency_train(
    run_name: str = RUN_NAME,
    max_seq_length: int = 1024,
    num_train_epochs: int = 3,
    learning_rate: float = 2e-4,
    per_device_train_batch_size: int = 2,
    per_device_eval_batch_size: int = 2,
    gradient_accumulation_steps: int = 8,
    lambda_claim: float = 1.0,
    generate_samples: int = 5,
    max_new_tokens: int = 160,
) -> dict:
    """Run a full ``train_go_consistency.py`` job on the train/eval split.

    Output goes to ``REMOTE_OUTPUT_DIR/causal_consistency_train/run_name``;
    an existing directory is reused, not cleared. Every keyword argument other
    than *run_name* is stringified and forwarded as the command-line option of
    the same name (underscores become dashes).

    Returns:
        A dict with the model name, dataset paths, run name, output directory
        and the parsed ``metrics.json`` (an empty dict when the file was not
        produced).

    Raises:
        FileNotFoundError: If the train or eval dataset file is missing.
    """
    train_path, eval_path = dataset_paths()
    if not (train_path.exists() and eval_path.exists()):
        raise FileNotFoundError(f"Missing dataset files: {train_path} / {eval_path}")

    output_dir = Path(REMOTE_OUTPUT_DIR, "causal_consistency_train", run_name)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Insertion order of this mapping is the order of the options on the
    # command line.
    tunables = {
        "--num-train-epochs": num_train_epochs,
        "--per-device-train-batch-size": per_device_train_batch_size,
        "--per-device-eval-batch-size": per_device_eval_batch_size,
        "--gradient-accumulation-steps": gradient_accumulation_steps,
        "--max-seq-length": max_seq_length,
        "--learning-rate": learning_rate,
        "--lambda-claim": lambda_claim,
        "--generate-samples": generate_samples,
        "--max-new-tokens": max_new_tokens,
    }
    command = [
        "python", "-u", "train_go_consistency.py",
        "--train-data-path", str(train_path),
        "--eval-data-path", str(eval_path),
        "--output-dir", str(output_dir),
        "--model-name", MODEL_NAME,
    ]
    for flag, setting in tunables.items():
        command += [flag, str(setting)]
    command += ["--bf16", "--log-every", "10"]

    run_command(command, cwd=Path(REMOTE_CODE_DIR))

    metrics_file = output_dir / "metrics.json"
    if metrics_file.exists():
        metrics = json.loads(metrics_file.read_text())
    else:
        metrics = {}
    return {
        "model_name": MODEL_NAME,
        "train_path": str(train_path),
        "eval_path": str(eval_path),
        "run_name": run_name,
        "output_dir": str(output_dir),
        "metrics": metrics,
    }


@app.function(
    cpu=2,
    memory=8192,
    timeout=60 * 10,
    volumes={REMOTE_DATA_DIR: data_volume},
)
def claim_value_stats() -> dict:
    """Tally the values of selected fields over the 19x19 rows of ``full.jsonl``.

    Blank lines are skipped. Every remaining line is parsed as JSON and
    counted in ``rows``; rows whose ``board_size`` is missing or falsy are
    treated as 19x19, and rows with any other size are excluded from the
    per-field tallies. A field that is absent or ``None`` is tallied under
    ``"<MISSING>"``; all other values are tallied by their ``str()`` form.

    Returns:
        A dict with the data path, the total row count, the 19x19 row count,
        the sorted distinct values per field (``values``) and the per-field
        value counts sorted by value (``counts``).

    Raises:
        FileNotFoundError: If ``full.jsonl`` does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data_path = Path(REMOTE_DATA_DIR) / DATASET_STEM / "full.jsonl"
    fields = [
        "win_prob_bin",
        "score_lead_bin",
        "phase_estimate",
        "main_control_region",
        "main_contested_region",
        "global_contestedness",
        "best_move_region",
        "move_urgency",
        "search_surprise",
    ]
    counts: dict[str, Counter[str]] = {field: Counter() for field in fields}
    rows = 0
    rows_19 = 0
    # Explicit encoding so decoding does not depend on the container locale.
    with data_path.open(encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            row = json.loads(line)
            rows += 1
            if int(row.get("board_size") or 19) != 19:
                continue
            rows_19 += 1
            for field in fields:
                value = row.get(field)
                counts[field]["<MISSING>" if value is None else str(value)] += 1
    return {
        "data_path": str(data_path),
        "rows": rows,
        "rows_19x19": rows_19,
        "values": {field: sorted(counts[field]) for field in fields},
        # Convert back to plain dicts, ordered by value, for the JSON output.
        "counts": {field: dict(sorted(counts[field].items())) for field in fields},
    }


@app.local_entrypoint()
def main(mode: str = "smoke") -> None:
    """Invoke the remote function selected by *mode* and print its result.

    Each remote function is called without arguments, so it runs with its own
    defaults. The returned dict is printed as indented JSON.

    Raises:
        ValueError: If *mode* is not one of the supported names.
    """
    remote_functions = {
        "smoke": smoke_test,
        "train": finetune,
        "generate": generate_random_explanations,
        "consistency-smoke": consistency_smoke,
        "consistency-train": consistency_train,
        "claim-stats": claim_value_stats,
    }
    target = remote_functions.get(mode)
    if target is None:
        raise ValueError(
            "mode must be 'smoke', 'train', 'generate', 'consistency-smoke', "
            "'consistency-train', or 'claim-stats'"
        )
    print(json.dumps(target.remote(), indent=2))
