"""CPU-only smoke checks.

This verifies:
- submission package imports
- Stage-4 training runs on the tiny cached dataset and writes a checkpoint

Run from the submission root directory:
    python scripts/verify_cpu.py
"""

from __future__ import annotations

import sys
import tempfile
from pathlib import Path

# Running this file as a script puts its own directory on sys.path rather
# than the directory above it; prepend that parent so the top-level packages
# living there can be imported.
_SCRIPT_DIR = Path(__file__).resolve().parent
_ROOT_DIR = _SCRIPT_DIR.parent
if sys.path.count(str(_ROOT_DIR)) == 0:
    sys.path.insert(0, str(_ROOT_DIR))

import torch


def main() -> None:
    """Run the CPU-only smoke checks and print ``OK`` on success.

    Steps:
      1. Import ``training.train_estimator.main`` and
         ``inference.theta.uniform_theta`` to confirm the submission
         package is importable.
      2. Fit a ``LinearHeadEstimator`` on the first cached batch for 20
         AdamW steps using ``pearson_corr_loss``.
      3. Save the linear weight to a temporary ``theta.pt``, reload it with
         ``weights_only=True`` and check its shape.

    Raises:
        IndexError: If no cached batches are found under
            ``assets/tiny_cached_features``.
        AssertionError: If the reloaded weight does not have shape
            ``(1, num_features)``.
    """
    # Import checks: these names are unused on purpose; a successful import
    # is the check. They are absolute imports that resolve through the
    # sys.path entry added at module level.
    from training.train_estimator import main as train_main  # noqa: F401
    from inference.theta import uniform_theta  # noqa: F401

    # Run training on tiny cached features
    from training.dataset import iter_cached_batches
    from training.loss import pearson_corr_loss
    from training.model import LinearHeadEstimator

    data_dir = _ROOT_DIR / "assets" / "tiny_cached_features"
    batches = list(iter_cached_batches(data_dir))
    if not batches:
        # Same exception type a bare batches[0] would raise, but with a
        # message that says what is actually missing.
        raise IndexError(f"no cached batches found under {data_dir}")

    batch = batches[0]
    num_features = int(batch.features.shape[1])

    # The inputs are identical on every step, so convert them once. The
    # `None` index adds a leading axis of size 1 to the 2-D features and the
    # 1-D outputs.
    feats = torch.from_numpy(batch.features[None, :, :])
    outs = torch.from_numpy(batch.outputs[None, :])

    model = LinearHeadEstimator(num_features=num_features)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-2)
    for _ in range(20):
        loss = pearson_corr_loss(model(feats), outs)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
    # NOTE(review): finalize() is a project-specific hook; its effect is not
    # visible from this script.
    model.finalize()

    # Round-trip the linear weight through a checkpoint file. weights_only=True
    # restricts unpickling to tensors and plain containers.
    with tempfile.TemporaryDirectory() as td:
        out = Path(td) / "theta.pt"
        torch.save(
            {"state_dict": {"linear.weight": model.linear.weight.detach().cpu()}}, out
        )
        ckpt = torch.load(out, map_location="cpu", weights_only=True)
        w = ckpt["state_dict"]["linear.weight"]
        expected_shape = (1, num_features)
        # Raised explicitly rather than via `assert` so the check still runs
        # under `python -O`.
        if tuple(w.shape) != expected_shape:
            raise AssertionError(
                f"checkpoint weight has shape {tuple(w.shape)}, "
                f"expected {expected_shape}"
            )

    print("OK")


# Script entry point: run the smoke checks only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
