#!/usr/bin/env python3
"""
Upload Best Checkpoints to HuggingFace Hub

CLI tool to upload best checkpoint directories from a VERL training run
to HuggingFace Hub. This is meant to be called after training completes.

Two modes:
  1. Default: Scan for existing best_checkpoint_* directories and upload them
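  2. --from_wandb: Query WandB for eval metrics, build best_checkpoint_* dirs
     from VERL checkpoints, then upload every best_checkpoint_* dir found in
     --output_dir (including any left over from earlier invocations)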

Usage:
    # Default mode (scan for existing dirs)
    python upload_best_checkpoints.py \\
        --output_dir /path/to/checkpoints/experiment \\
        --repo_id anon-neurips26/my-model \\
        --token hf_xxx

    # WandB mode (query metrics, build dirs, upload)
    python upload_best_checkpoints.py \\
        --output_dir /path/to/checkpoints/experiment \\
        --repo_id anon-neurips26/my-model \\
        --token hf_xxx \\
        --from_wandb \\
        --wandb_entity ${WANDB_ENTITY:-anonymous} \\
        --wandb_project qwen2.5-puzzle-grpo \\
        --wandb_run_name my_run_name

Environment variables:
    HF_TOKEN or HUGGING_FACE_HUB_TOKEN: HuggingFace API token (alternative to --token)
"""

import os
import sys
import argparse
from typing import Optional, Dict


def _slug_to_dirname(dataset_slug: str) -> str:
    """Collapse *dataset_slug* into a single, flat path component.

    The eval-key regex captures everything between ``val-core/`` and
    ``/acc/mean@1``, so a slug may itself contain ``/``. Joining such a slug
    verbatim would create nested directories (and let different datasets
    share one parent), so path separators are replaced with ``_``.
    """
    name = dataset_slug
    for sep in (os.sep, os.altsep, "/"):
        if sep:
            name = name.replace(sep, "_")
    return name


def _scan_best_eval_metrics(run) -> Dict[str, dict]:
    """Find the best finite 'val-core/*/acc/mean@1' value per dataset in a run.

    Args:
        run: WandB run object; only ``run.scan_history()`` is used, and each
            yielded row is expected to be a dict of metric name -> value.

    Returns:
        Dict mapping dataset_slug -> {"acc": float, "step": int}. On ties the
        earliest row seen wins (strict ``>`` comparison).
    """
    import math
    import re

    eval_pattern = re.compile(r'^val-core/(.+)/acc/mean@1$')
    best_per_dataset: Dict[str, dict] = {}

    print("Scanning WandB history for eval metrics...")
    row_count = 0
    for row in run.scan_history():
        row_count += 1
        # Prefer the trainer's step counter; fall back to WandB's _step when the
        # key is missing *or* present with a None value.
        step = row.get("trainer/global_step")
        if step is None:
            step = row.get("_step")
        if step is None:
            continue
        try:
            step = int(step)
        except (TypeError, ValueError, OverflowError):
            # NaN/inf or non-numeric step: cannot map to a checkpoint dir.
            continue

        for key, value in row.items():
            match = eval_pattern.match(key)
            if match is None or value is None:
                continue
            try:
                acc = float(value)
            except (TypeError, ValueError):
                continue
            # NaN compares False against everything, so a NaN seen first would
            # otherwise become an unbeatable "best"; skip non-finite values.
            if not math.isfinite(acc):
                continue
            dataset_slug = match.group(1)
            current = best_per_dataset.get(dataset_slug)
            if current is None or acc > current["acc"]:
                best_per_dataset[dataset_slug] = {"acc": acc, "step": step}

    print(f"  Scanned {row_count} history rows")
    return best_per_dataset


def _populate_best_checkpoint_dir(best_dir: str, lora_dir: str, hf_dir: str) -> None:
    """Recreate *best_dir* and copy checkpoint files into it (top-level files only).

    All files from *lora_dir* are copied first. Files from *hf_dir* (if it
    exists) are copied only when no file of the same name is already present,
    so adapter files are never overwritten. Subdirectories are not copied.
    """
    import shutil

    # Start from a clean directory so files from a previous build cannot linger.
    if os.path.exists(best_dir):
        shutil.rmtree(best_dir)
    os.makedirs(best_dir)

    for filename in os.listdir(lora_dir):
        src = os.path.join(lora_dir, filename)
        if os.path.isfile(src):
            shutil.copy2(src, os.path.join(best_dir, filename))

    if os.path.exists(hf_dir):
        for filename in os.listdir(hf_dir):
            src = os.path.join(hf_dir, filename)
            dst = os.path.join(best_dir, filename)
            if os.path.isfile(src) and not os.path.exists(dst):
                shutil.copy2(src, dst)


def build_best_checkpoints_from_wandb(
    output_dir: str,
    wandb_entity: str,
    wandb_project: str,
    wandb_run_name: str,
    dry_run: bool = False,
    api=None,
) -> Dict[str, dict]:
    """
    Query WandB for eval metrics and build best_checkpoint_* directories.

    Scans the WandB run history for keys matching 'val-core/*/acc/mean@1',
    finds the step with the highest finite accuracy per eval dataset, then
    copies the LoRA adapter + tokenizer/config files from that step's VERL
    checkpoint (global_step_{step}/actor/{lora_adapter,huggingface}) into a
    flat best_checkpoint_{dataset}/ directory. Path separators in the dataset
    slug are replaced with '_' so each directory is a single path component.

    Args:
        output_dir: Directory containing global_step_* checkpoint folders
        wandb_entity: WandB entity/username
        wandb_project: WandB project name
        wandb_run_name: WandB run display name
        dry_run: If True, only print what would be done
        api: Optional pre-built ``wandb.Api``-compatible object (anything with
            a ``runs(path, filters=...)`` method). When None, ``wandb.Api()``
            is created; if wandb is not installed an error is printed and {}
            is returned.

    Returns:
        Dict mapping dataset_slug -> {"acc": float, "step": int, "dir": str}.
        Datasets whose checkpoint has no LoRA adapter dir are omitted. Returns
        {} when the run or any matching metric cannot be found.
    """
    from datetime import datetime

    if api is None:
        try:
            import wandb
        except ImportError:
            print("ERROR: wandb package not installed. Install with: pip install wandb")
            return {}
        api = wandb.Api()

    # Find the run by display name
    run_path = f"{wandb_entity}/{wandb_project}"
    print(f"Searching for WandB run '{wandb_run_name}' in {run_path}...")
    runs_list = list(api.runs(run_path, filters={"display_name": wandb_run_name}))

    if not runs_list:
        print(f"ERROR: No WandB run found with name '{wandb_run_name}' in {run_path}")
        return {}

    run = runs_list[0]
    if len(runs_list) > 1:
        # Display names are not unique in WandB; make the arbitrary pick visible.
        print(f"WARNING: {len(runs_list)} runs named '{wandb_run_name}' found; "
              f"using the first one returned")
    print(f"Found WandB run: {run.name} (id: {run.id})")

    best_per_dataset = _scan_best_eval_metrics(run)
    if not best_per_dataset:
        print("ERROR: No eval metrics found matching 'val-core/*/acc/mean@1'")
        return {}

    print("\nBest checkpoints per eval dataset:")
    for dataset, info in sorted(best_per_dataset.items()):
        print(f"  {dataset}: acc={info['acc']:.4f} at step {info['step']}")

    # Build best_checkpoint_* dirs
    results = {}
    for dataset, info in sorted(best_per_dataset.items()):
        step = info["step"]
        acc = info["acc"]

        ckpt_dir = os.path.join(output_dir, f"global_step_{step}")
        lora_dir = os.path.join(ckpt_dir, "actor", "lora_adapter")
        hf_dir = os.path.join(ckpt_dir, "actor", "huggingface")

        if not os.path.exists(lora_dir):
            print(f"  WARNING: LoRA adapter not found at {lora_dir}, skipping {dataset}")
            continue

        best_dir_name = f"best_checkpoint_{_slug_to_dirname(dataset)}"
        best_dir = os.path.join(output_dir, best_dir_name)

        if dry_run:
            print(f"  [Dry run] Would create {best_dir_name} from global_step_{step}")
            results[dataset] = {"acc": acc, "step": step, "dir": best_dir}
            continue

        _populate_best_checkpoint_dir(best_dir, lora_dir, hf_dir)

        # Provenance file so the uploaded folder records where it came from.
        info_path = os.path.join(best_dir, "best_checkpoint_info.txt")
        with open(info_path, "w", encoding="utf-8") as f:
            f.write(f"dataset: {dataset}\n")
            f.write(f"accuracy: {acc:.4f}\n")
            f.write(f"step: {step}\n")
            f.write(f"wandb_run: {wandb_entity}/{wandb_project}/{run.name}\n")
            f.write(f"wandb_run_id: {run.id}\n")
            f.write(f"created: {datetime.now().isoformat()}\n")

        print(f"  Created {best_dir_name} (from global_step_{step}, acc={acc:.4f})")
        results[dataset] = {"acc": acc, "step": step, "dir": best_dir}

    return results


def upload_best_checkpoints(
    output_dir: str,
    repo_id: str,
    token: Optional[str] = None,
    private: bool = True,
    dry_run: bool = False
) -> int:
    """
    Push every best_checkpoint_* subdirectory of output_dir to a HuggingFace repo.

    Each directory is uploaded under a path in the repo equal to its own name.
    A failure on one directory is reported and recorded, and the remaining
    directories are still attempted.

    Args:
        output_dir: Directory containing best_checkpoint_* folders
        repo_id: HuggingFace repo ID (e.g., "anon-neurips26/my-model")
        token: HuggingFace API token
        private: Whether to create a private repository
        dry_run: If True, only list the directories found; nothing is uploaded

    Returns:
        Count of directories uploaded without an exception (0 on dry run or
        when no best_checkpoint_* directory exists).
    """
    from hf_upload import upload_checkpoint_to_hf

    # Collect (name, absolute path) for each matching subdirectory, sorted by name.
    checkpoints = []
    for entry in sorted(os.listdir(output_dir)):
        if not entry.startswith("best_checkpoint_"):
            continue
        candidate = os.path.join(output_dir, entry)
        if os.path.isdir(candidate):
            checkpoints.append((entry, candidate))

    if not checkpoints:
        print(f"No best_checkpoint_* directories found in {output_dir}")
        return 0

    # Files whose presence is reported next to each directory, with their labels.
    markers = (
        ("adapter_model.safetensors", "adapter"),
        ("best_checkpoint_info.txt", "info"),
    )
    print(f"Found {len(checkpoints)} best checkpoint(s) to upload:")
    for entry, ckpt_path in checkpoints:
        present = [
            label
            for marker_file, label in markers
            if os.path.exists(os.path.join(ckpt_path, marker_file))
        ]
        summary = ", ".join(present) if present else "empty"
        print(f"  - {entry} ({summary})")

    if dry_run:
        print("\n[Dry run] No uploads performed")
        return 0

    divider = "=" * 80
    print(f"\nUploading to: {repo_id}")
    print(divider)

    failures = []
    for entry, ckpt_path in checkpoints:
        try:
            upload_checkpoint_to_hf(
                checkpoint_dir=ckpt_path,
                repo_id=repo_id,
                token=token,
                path_in_repo=entry,
                private=private
            )
        except Exception as exc:
            # Per-directory boundary: report, remember, and keep going.
            print(f"[Error] Failed to upload {entry}: {exc}")
            failures.append((entry, str(exc)))

    succeeded = len(checkpoints) - len(failures)

    print(divider)
    print("\nUpload summary:")
    print(f"  Successful: {succeeded}/{len(checkpoints)}")
    if failures:
        print(f"  Failed: {len(failures)}")
        for entry, message in failures:
            print(f"    - {entry}: {message}")

    print(f"\nRepository: https://huggingface.co/{repo_id}")

    return succeeded


def main():
    """
    CLI entry point: parse arguments, optionally build best checkpoints from
    WandB, then upload all best_checkpoint_* directories.

    Never returns normally; always ends in ``sys.exit``. Exits with 1 when
    --output_dir is not a directory, when --from_wandb builds nothing, or when
    no checkpoint was uploaded outside a dry run. Exits with 0 otherwise.
    Missing --wandb_* options in --from_wandb mode are reported through
    ``parser.error``.
    """
    usage_examples = """
Examples:
    # Upload all best checkpoints from a training run
    python upload_best_checkpoints.py \\
        --output_dir checkpoints/qwen2.5-puzzle-grpo/5x5dm_intformat_exact_t1.5 \\
        --repo_id anon-neurips26/qwen2.5-puzzle-grpo-5x5dm

    # Dry run to see what would be uploaded
    python upload_best_checkpoints.py \\
        --output_dir checkpoints/experiment \\
        --repo_id anon-neurips26/my-model \\
        --dry_run

    # WandB mode: query metrics, build best dirs, upload
    python upload_best_checkpoints.py \\
        --output_dir checkpoints/qwen2.5-puzzle-grpo/my_run \\
        --repo_id anon-neurips26/my-model \\
        --from_wandb \\
        --wandb_entity ${WANDB_ENTITY:-anonymous} \\
        --wandb_project qwen2.5-puzzle-grpo \\
        --wandb_run_name my_run_name

    # WandB dry run
    python upload_best_checkpoints.py \\
        --output_dir checkpoints/experiment \\
        --repo_id anon-neurips26/my-model \\
        --from_wandb \\
        --wandb_entity ${WANDB_ENTITY:-anonymous} \\
        --wandb_project qwen2.5-puzzle-grpo \\
        --wandb_run_name my_run \\
        --dry_run
"""
    parser = argparse.ArgumentParser(
        description="Upload best checkpoints to HuggingFace Hub",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )

    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory containing checkpoints (global_step_* or best_checkpoint_*)")
    parser.add_argument("--repo_id", type=str, required=True,
                        help="HuggingFace repo ID (e.g., anon-neurips26/my-model)")
    parser.add_argument("--token", type=str,
                        help="HuggingFace API token (uses HF_TOKEN env var if not set)")

    # Boolean switches, registered in the same order the help output lists them.
    switches = (
        ("--public", "Create public repository (default: private)"),
        ("--dry_run", "Only print what would be uploaded, don't actually upload"),
        ("--from_wandb", "Query WandB for eval metrics to find best checkpoints"),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action="store_true", help=help_text)

    # WandB-mode options; all three become mandatory once --from_wandb is set.
    wandb_options = (
        ("--wandb_entity", "WandB entity/username (required with --from_wandb)"),
        ("--wandb_project", "WandB project name (required with --from_wandb)"),
        ("--wandb_run_name", "WandB run display name (required with --from_wandb)"),
    )
    for flag, help_text in wandb_options:
        parser.add_argument(flag, type=str, help=help_text)

    args = parser.parse_args()

    if not os.path.isdir(args.output_dir):
        print(f"Error: Output directory does not exist: {args.output_dir}")
        sys.exit(1)

    if args.from_wandb:
        # "--wandb_entity"[2:] -> "wandb_entity", the argparse attribute name.
        missing = [flag for flag, _ in wandb_options if not getattr(args, flag[2:])]
        if missing:
            parser.error(f"--from_wandb requires: {', '.join(missing)}")

    # CLI flag wins; otherwise fall back to the two supported env vars in order.
    token = args.token
    if not token:
        token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")

    if not (token or args.dry_run):
        print("Warning: No HuggingFace token provided. Upload may fail.")
        print("Set HF_TOKEN environment variable or use --token argument.")

    if args.from_wandb:
        built = build_best_checkpoints_from_wandb(
            output_dir=args.output_dir,
            wandb_entity=args.wandb_entity,
            wandb_project=args.wandb_project,
            wandb_run_name=args.wandb_run_name,
            dry_run=args.dry_run
        )
        if not built:
            print("No best checkpoints built from WandB metrics")
            sys.exit(1)
        # In dry-run mode no directories were written, so there is nothing to list.
        if args.dry_run:
            sys.exit(0)

    # Both modes converge here: upload whatever best_checkpoint_* dirs exist.
    uploaded = upload_best_checkpoints(
        output_dir=args.output_dir,
        repo_id=args.repo_id,
        token=token,
        private=not args.public,
        dry_run=args.dry_run
    )

    exit_code = 0 if (uploaded > 0 or args.dry_run) else 1
    sys.exit(exit_code)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
