# models.py
"""
Model catalog + snapshot downloader for Hugging Face repos.

Usage (from main.py):
    import os
    from models import get_model_catalog, DownloadOptions, download_models

    opts = DownloadOptions(local_dir="./hf_models", hf_token=os.getenv("HUGGINGFACE_HUB_TOKEN"))
    paths = download_models(models=["llama2-7b-chat", "phi3-medium-4k"], options=opts)
    print(paths)

CLI (optional):
    python models.py --all --local-dir ./hf_models
"""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence

from huggingface_hub import snapshot_download, HfApi, HfHubHTTPError


# ---------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------

def get_model_catalog() -> Dict[str, str]:
    """
    Build the alias -> Hugging Face repo_id table for the paper's main models.

    A fresh dict is created on every call, so callers may mutate the result
    without affecting later calls.
    """
    entries = (
        # Gated: the license must be accepted on HF and a token supplied.
        ("llama2-7b-chat", "meta-llama/Llama-2-7b-chat-hf"),
        # Open-weight.
        ("phi3-medium-4k", "microsoft/Phi-3-medium-4k-instruct"),
        # Gated, 70B.
        ("llama-3.1-70b", "meta-llama/Llama-3.1-70B"),
    )
    return dict(entries)


@dataclass
class DownloadOptions:
    """
    Options controlling snapshot downloads.

    Args:
        local_dir: Root directory where models are stored (each alias becomes a subfolder).
        hf_token:  HF access token (required for gated repos; can be None if already logged in).
                   Excluded from the generated repr so it is not leaked when an
                   options object is printed or logged.
        revision:  Optional branch/tag/commit (e.g., "main").
        allow_patterns: Optional glob patterns to include (e.g., ["*.safetensors", "*.json"]).
        ignore_patterns: Optional glob patterns to exclude.
        max_workers: Max concurrent workers for download.
        use_symlinks: If True, use symlinks inside local_dir to reduce disk usage.
        local_files_only: If True, use only local cache (no network calls).
    """
    local_dir: str = "./hf_models"
    # repr=False keeps the secret out of repr()/str(); init, eq and attribute
    # access are unchanged, so existing callers are unaffected.
    hf_token: Optional[str] = field(default=None, repr=False)
    revision: Optional[str] = None
    allow_patterns: Optional[Sequence[str]] = None
    ignore_patterns: Optional[Sequence[str]] = None
    max_workers: int = 8
    use_symlinks: bool = True
    local_files_only: bool = False


def download_models(
    models: Optional[Sequence[str]] = None,
    *,
    all_models: bool = False,
    options: Optional[DownloadOptions] = None,
) -> Dict[str, str]:
    """
    Snapshot-download one or more HF model repos.

    All requested aliases are validated before any directory is created or any
    network call is made, so a typo in a later alias cannot waste a large
    download of an earlier one.

    Args:
        models:     Model aliases to download (keys from get_model_catalog()).
                    A single alias passed as a plain string is accepted and
                    treated as a one-element list. Duplicates are downloaded once.
        all_models: If True, download all models in the catalog (overrides 'models').
        options:    DownloadOptions; if None, defaults are used.

    Returns:
        Dict alias -> local_path for successfully downloaded models.

    Raises:
        ValueError: If neither 'models' nor 'all_models' is provided.
        RuntimeError: If an alias is not in the catalog, or on download failure
            (with a helpful message; the original exception is chained).
    """
    if not all_models and not models:
        raise ValueError("Provide 'models' or set all_models=True.")

    catalog = get_model_catalog()

    if all_models:
        targets: List[str] = sorted(catalog)
    elif isinstance(models, str):
        # A bare string is itself a sequence of characters; without this guard
        # "phi3-medium-4k" would be split into one-letter "aliases".
        targets = [models]
    else:
        # dict.fromkeys de-duplicates while preserving the caller's order.
        targets = list(dict.fromkeys(models or []))

    # Fail fast: reject unknown aliases before touching disk or network.
    for alias in targets:
        if alias not in catalog:
            raise RuntimeError(f"Unknown model alias '{alias}'. Valid: {list(catalog.keys())}")

    opts = options or DownloadOptions()

    # Best-effort token check (useful for gated repos). Skipped in
    # local_files_only mode, which promises no network calls.
    if opts.hf_token and not opts.local_files_only:
        api = HfApi()
        try:
            who = api.whoami(token=opts.hf_token)
            print(f"[models] Authenticated as: {who.get('name') or who.get('email')}")
        except Exception as e:
            # Deliberately non-fatal: the download itself will surface a real
            # authentication problem with a clearer error.
            print(f"[models] Warning: could not verify token ({e}). Continuing…")

    os.makedirs(opts.local_dir, exist_ok=True)
    results: Dict[str, str] = {}

    for alias in targets:
        repo_id = catalog[alias]
        out_dir = os.path.join(opts.local_dir, alias)
        os.makedirs(out_dir, exist_ok=True)

        print(f"[models] Downloading '{alias}' -> {repo_id}")
        try:
            # NOTE(review): recent huggingface_hub releases appear to deprecate
            # local_dir_use_symlinks — confirm against the pinned version.
            local_path = snapshot_download(
                repo_id=repo_id,
                repo_type="model",
                local_dir=out_dir,
                local_dir_use_symlinks=opts.use_symlinks,
                revision=opts.revision,
                allow_patterns=opts.allow_patterns,
                ignore_patterns=opts.ignore_patterns,
                token=opts.hf_token,
                max_workers=opts.max_workers,
                local_files_only=opts.local_files_only,
            )
        except HfHubHTTPError as e:
            msg = (
                f"HTTP error while downloading '{repo_id}': {e}. "
                f"If this is a gated repo, ensure you've accepted the license and provided a valid token."
            )
            raise RuntimeError(msg) from e
        except Exception as e:
            raise RuntimeError(f"Failed to download '{repo_id}': {e}") from e

        results[alias] = local_path
        print(f"[models] OK: {alias} -> {local_path}")

    return results


# ---------------------------------------------------------------------
# Optional CLI entry point (no side effects on import)
# ---------------------------------------------------------------------

def _cli():
    """
    Command-line front end: parse flags, build DownloadOptions, run
    download_models, and print one "alias: path" line per downloaded model.
    """
    import argparse

    known_aliases = sorted(get_model_catalog().keys())

    ap = argparse.ArgumentParser(description="Snapshot-download HF model repos.")

    # Exactly one of --all / --models must be supplied.
    selection = ap.add_mutually_exclusive_group(required=True)
    selection.add_argument("--all", action="store_true", help="Download all models in the catalog.")
    selection.add_argument(
        "--models",
        nargs="+",
        choices=known_aliases,
        help="One or more model aliases to download.",
    )

    # Remaining flags map one-to-one onto DownloadOptions fields.
    ap.add_argument("--local-dir", default="./hf_models")
    ap.add_argument("--hf-token", default=os.environ.get("HUGGINGFACE_HUB_TOKEN"))
    ap.add_argument("--revision", default=None)
    ap.add_argument("--allow-patterns", nargs="*", default=None)
    ap.add_argument("--ignore-patterns", nargs="*", default=None)
    ap.add_argument("--max-workers", type=int, default=8)
    ap.add_argument("--no-symlinks", action="store_true")
    ap.add_argument("--local-files-only", action="store_true")

    ns = ap.parse_args()

    # The CLI exposes the negative flag (--no-symlinks); the option is positive.
    symlinks_enabled = not ns.no_symlinks
    download_opts = DownloadOptions(
        local_dir=ns.local_dir,
        hf_token=ns.hf_token,
        revision=ns.revision,
        allow_patterns=ns.allow_patterns,
        ignore_patterns=ns.ignore_patterns,
        max_workers=ns.max_workers,
        use_symlinks=symlinks_enabled,
        local_files_only=ns.local_files_only,
    )

    downloaded = download_models(models=ns.models, all_models=ns.all, options=download_opts)
    for name, location in downloaded.items():
        print(f"{name}: {location}")


# Run the CLI only when this file is executed as a script; importing the
# module does not call _cli().
if __name__ == "__main__":
    _cli()
