from __future__ import annotations

import json
import os
import shutil
import subprocess
import time
from pathlib import Path

import modal


# Name under which the Modal app is registered.
APP_NAME = "katago-smollm-finetune"

# Run settings. Each one is taken from the environment variable of the same
# name when it is set, otherwise from the literal default given here. The
# lookups happen once, when this module is imported.
MODEL_NAME = os.environ.get("MODEL_NAME", "unsloth/SmolLM3-3B-Base")
DATA_VOLUME_NAME = os.environ.get("DATA_VOLUME_NAME", "katago-large-datasets")
OUTPUT_VOLUME_NAME = os.environ.get("OUTPUT_VOLUME_NAME", "katago-smollm-finetunes")
DATASET_STEM = os.environ.get("DATASET_STEM", "katago_large_dataset_v1")
# The default run name embeds the integer Unix time at import, so it changes
# every time the module is imported unless RUN_NAME is set explicitly.
RUN_NAME = os.environ.get("RUN_NAME", f"smollm3_3b_katago_{int(time.time())}")

# Paths inside the container: where this script's directory is copied, and
# where the data and output volumes are mounted by the functions below.
REMOTE_CODE_DIR = "/root/smollm"
REMOTE_DATA_DIR = "/data"
REMOTE_OUTPUT_DIR = "/outputs"

# Directory containing this file; it is added to the image as REMOTE_CODE_DIR.
LOCAL_DIR = Path(__file__).resolve().parent

# Container image shared by every function in this app: Debian slim with
# Python 3.11, git, the training dependencies, and a copy of this script's
# directory placed at REMOTE_CODE_DIR.
image = (
    modal.Image.debian_slim(python_version="3.11")
    .apt_install("git")
    .pip_install(
        "accelerate>=0.34.0",
        "bitsandbytes>=0.43.3",
        "datasets>=2.21.0",
        "peft>=0.12.0",
        "sentencepiece>=0.2.0",
        "torch>=2.4.0",
        "transformers==4.53.0",
    )
    .add_local_dir(LOCAL_DIR, remote_path=REMOTE_CODE_DIR, copy=True)
    # This module is imported again inside the container, where the launching
    # shell's environment variables are not present. Bake the resolved
    # settings into the image so the module-level os.environ.get(...) lookups
    # yield the same values remotely instead of silently reverting to their
    # defaults. Kept as the last layer so a settings change does not
    # invalidate the dependency or code layers.
    .env(
        {
            "MODEL_NAME": MODEL_NAME,
            "DATA_VOLUME_NAME": DATA_VOLUME_NAME,
            "OUTPUT_VOLUME_NAME": OUTPUT_VOLUME_NAME,
            "DATASET_STEM": DATASET_STEM,
            # RUN_NAME's default is time-based; forward it only when it was
            # set explicitly so the image definition stays stable across
            # launches that rely on the default.
            **({"RUN_NAME": os.environ["RUN_NAME"]} if "RUN_NAME" in os.environ else {}),
        }
    )
)

# The Modal app; `image` is passed as the app-level image.
app = modal.App(APP_NAME, image=image)
# Volume holding the dataset; looked up by name only (no create_if_missing).
data_volume = modal.Volume.from_name(DATA_VOLUME_NAME)
# Volume receiving run outputs; created on first use if it does not exist.
output_volume = modal.Volume.from_name(OUTPUT_VOLUME_NAME, create_if_missing=True)


def dataset_paths() -> tuple[Path, Path]:
    """Return the ``(train, eval)`` JSONL paths of the configured dataset.

    Both files live in ``REMOTE_DATA_DIR/DATASET_STEM``. The paths are only
    constructed here; nothing checks that they exist.
    """
    root = Path(REMOTE_DATA_DIR, DATASET_STEM)
    train_file = root.joinpath("train.jsonl")
    eval_file = root.joinpath("eval.jsonl")
    return train_file, eval_file


def run_command(command: list[str], cwd: Path) -> None:
    """Echo *command*, run it in *cwd*, then commit the output volume.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero. The
            commit is skipped in that case because the exception propagates
            first.
    """
    rendered = " ".join(command)
    print(f"[CMD] {rendered}", flush=True)
    # check_call raises on a non-zero exit status, like run(..., check=True).
    subprocess.check_call(command, cwd=cwd)
    output_volume.commit()


def base_train_command(output_dir: Path) -> list[str]:
    """Build the argv prefix shared by every ``finetune_smollm.py`` launch.

    Args:
        output_dir: Directory passed to the script as ``--output-dir``.

    Returns:
        The interpreter, script name, model, dataset paths, output directory
        and the ``--bf16`` flag; callers append mode-specific flags.
    """
    train_file, eval_file = dataset_paths()
    argv = ["python", "-u", "finetune_smollm.py"]
    argv += ["--model-name", MODEL_NAME]
    argv += ["--train-path", str(train_file)]
    argv += ["--eval-path", str(eval_file)]
    argv += ["--output-dir", str(output_dir)]
    argv.append("--bf16")
    return argv


@app.function(
    gpu="A10G",
    cpu=8,
    memory=49152,
    timeout=60 * 60,
    volumes={REMOTE_DATA_DIR: data_volume, REMOTE_OUTPUT_DIR: output_volume},
)
def smoke_test() -> dict:
    """Run the training script in ``--smoke-test`` mode and report the result.

    Returns:
        A dict with the model name, dataset paths, output directory, and the
        parsed ``metrics.json`` (an empty dict if the script wrote none).

    Raises:
        FileNotFoundError: if either dataset file is missing.
    """
    train_path, eval_path = dataset_paths()
    if not (train_path.exists() and eval_path.exists()):
        raise FileNotFoundError(f"Missing dataset files: {train_path} / {eval_path}")

    # Start from an empty directory so leftovers from an earlier run with the
    # same name cannot be mistaken for this run's output.
    output_dir = Path(REMOTE_OUTPUT_DIR, "smoke", RUN_NAME)
    if output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    smoke_flags = [
        "--smoke-test",
        "--max-seq-len",
        "512",
        "--per-device-train-batch-size",
        "1",
        "--per-device-eval-batch-size",
        "1",
        "--gradient-accumulation-steps",
        "1",
    ]
    run_command([*base_train_command(output_dir), *smoke_flags], cwd=Path(REMOTE_CODE_DIR))

    metrics_file = output_dir / "metrics.json"
    metrics = json.loads(metrics_file.read_text()) if metrics_file.exists() else {}
    return {
        "model_name": MODEL_NAME,
        "train_path": str(train_path),
        "eval_path": str(eval_path),
        "output_dir": str(output_dir),
        "metrics": metrics,
    }


@app.function(
    gpu="A10G",
    cpu=8,
    memory=49152,
    timeout=60 * 60 * 10,
    volumes={REMOTE_DATA_DIR: data_volume, REMOTE_OUTPUT_DIR: output_volume},
)
def finetune(
    run_name: str = RUN_NAME,
    max_seq_len: int = 1024,
    num_train_epochs: float = 2.0,
    learning_rate: float = 2e-4,
    gradient_accumulation_steps: int = 16,
    max_steps: int = -1,
) -> dict:
    """Launch a full fine-tuning run and report its metrics.

    Args:
        run_name: Sub-directory of ``REMOTE_OUTPUT_DIR`` that receives the
            run's outputs.
        max_seq_len: Forwarded to the script as ``--max-seq-len``.
        num_train_epochs: Forwarded as ``--num-train-epochs``.
        learning_rate: Forwarded as ``--learning-rate``.
        gradient_accumulation_steps: Forwarded as
            ``--gradient-accumulation-steps``.
        max_steps: Forwarded as ``--max-steps``.

    Returns:
        A dict with the model name, run name, output directory, and the
        parsed ``metrics.json`` (an empty dict if the script wrote none).

    Raises:
        FileNotFoundError: if either dataset file is missing.
        subprocess.CalledProcessError: if the training script exits non-zero.
    """
    # Fail fast, as smoke_test does, so a missing dataset is reported clearly
    # before an output directory is created or the training script starts.
    train_path, eval_path = dataset_paths()
    if not train_path.exists() or not eval_path.exists():
        raise FileNotFoundError(f"Missing dataset files: {train_path} / {eval_path}")

    output_dir = Path(REMOTE_OUTPUT_DIR) / run_name
    output_dir.mkdir(parents=True, exist_ok=True)
    command = base_train_command(output_dir) + [
        "--max-seq-len",
        str(max_seq_len),
        "--num-train-epochs",
        str(num_train_epochs),
        "--learning-rate",
        str(learning_rate),
        "--gradient-accumulation-steps",
        str(gradient_accumulation_steps),
        "--max-steps",
        str(max_steps),
    ]
    run_command(command, cwd=Path(REMOTE_CODE_DIR))
    metrics_path = output_dir / "metrics.json"
    return {
        "model_name": MODEL_NAME,
        "run_name": run_name,
        "output_dir": str(output_dir),
        "metrics": json.loads(metrics_path.read_text()) if metrics_path.exists() else {},
    }


@app.function(
    gpu="A10G",
    cpu=8,
    memory=49152,
    timeout=60 * 60 * 2,
    volumes={REMOTE_DATA_DIR: data_volume, REMOTE_OUTPUT_DIR: output_volume},
)
def generate_random_explanations(
    run_name: str = RUN_NAME,
    split: str = "eval",
    num_samples: int = 100,
    seed: int = 42,
    adapter_run_name: str = "",
) -> dict:
    """Run ``generate_explanations.py`` on one dataset split.

    Args:
        run_name: Sub-directory of ``REMOTE_OUTPUT_DIR/generations`` that
            receives the output.
        split: Which dataset file to read, ``"eval"`` or ``"train"``.
        num_samples: Forwarded to the script as ``--num-samples``.
        seed: Forwarded as ``--seed``.
        adapter_run_name: When non-empty, the ``adapter`` directory of that
            run under ``REMOTE_OUTPUT_DIR`` is passed as ``--adapter-path``.

    Returns:
        The parsed ``manifest.json`` written by the script, or a dict holding
        only ``output_dir`` when no manifest exists.

    Raises:
        ValueError: if ``split`` is neither ``"eval"`` nor ``"train"``.
        FileNotFoundError: if the selected dataset file is missing.
        subprocess.CalledProcessError: if the script exits non-zero.
    """
    train_path, eval_path = dataset_paths()
    split_paths = {"eval": eval_path, "train": train_path}
    # Reject unknown splits instead of silently falling back to the training
    # file, which would make a typo look like a successful eval run.
    if split not in split_paths:
        raise ValueError(f"split must be 'eval' or 'train', got {split!r}")
    data_path = split_paths[split]
    if not data_path.exists():
        raise FileNotFoundError(f"Missing dataset file: {data_path}")

    output_dir = Path(REMOTE_OUTPUT_DIR) / "generations" / run_name
    output_dir.mkdir(parents=True, exist_ok=True)

    command = [
        "python",
        "-u",
        "generate_explanations.py",
        "--model-name",
        MODEL_NAME,
        "--data-path",
        str(data_path),
        "--output-dir",
        str(output_dir),
        "--num-samples",
        str(num_samples),
        "--seed",
        str(seed),
    ]
    if adapter_run_name:
        command.extend(["--adapter-path", str(Path(REMOTE_OUTPUT_DIR) / adapter_run_name / "adapter")])
    run_command(command, cwd=Path(REMOTE_CODE_DIR))
    manifest_path = output_dir / "manifest.json"
    return json.loads(manifest_path.read_text()) if manifest_path.exists() else {"output_dir": str(output_dir)}


@app.local_entrypoint()
def main(mode: str = "smoke") -> None:
    """Invoke the remote function selected by *mode* and print its result.

    Args:
        mode: ``"smoke"``, ``"train"`` or ``"generate"``.

    Raises:
        ValueError: if *mode* is not one of the three supported values.
    """
    remote_functions = {
        "smoke": smoke_test,
        "train": finetune,
        "generate": generate_random_explanations,
    }
    if mode not in remote_functions:
        raise ValueError("mode must be 'smoke', 'train', or 'generate'")
    result = remote_functions[mode].remote()
    print(json.dumps(result, indent=2))
