import os

import pandas as pd

from utils import sample_checkpoints
from huggingface_hub import list_repo_refs

# Every sampled checkpoint of every model is written to this single JSONL file.
OUTPUT_PATH = "olmo2_sampled_branches.jsonl"

# Column layout of the per-branch table. Passing it explicitly to the
# DataFrame constructor keeps the columns present even when a repo has no
# checkpoint branches, so the sort below cannot raise KeyError.
BRANCH_COLUMNS = [
    "model_id",
    "stage",
    "train_stage",
    "step",
    "ingredient",
    "tokens",
    "model_size",
    "phase",
]


def _parse_branch_name(branch):
    """Extract the numeric fields encoded in a checkpoint branch name.

    The name is split on "-" and each part is recognised by its prefix, e.g.
    "stage1-step1000-tokens5B" or "stage2-ingredient1-step1000-tokens5B".

    Args:
        branch: Branch name as returned by the Hub.

    Returns:
        A dict with the keys "stage", "ingredient", "step" and "tokens".
        A field that does not occur in the name is None.

    Raises:
        ValueError: If a recognised part does not carry an integer value.
    """
    fields = {"stage": None, "ingredient": None, "step": None, "tokens": None}
    for part in branch.split("-"):
        for key in fields:
            if not part.startswith(key):
                continue
            value = part[len(key):]
            if key == "tokens":
                # Token counts carry a "B" suffix, e.g. "tokens5B".
                value = value.replace("B", "")
            fields[key] = int(value)
            break
    return fields


print("Sampling checkpoints ...")

models = [
    "OLMo-2-0425-1B",
    "OLMo-2-1124-7B",
    "OLMo-2-1124-13B",
    "OLMo-2-0325-32B",
]

# Tracks whether this run has already written to OUTPUT_PATH. The first write
# truncates the file so that re-running the script does not append a second
# copy of every record to the output of a previous run.
wrote_output = False

for model_id in models:
    print(f"Processing model: {model_id}")
    # Size is the trailing name component without its "B", e.g. "13".
    model_size = model_id.split("-")[-1].replace("B", "")

    refs = list_repo_refs(f"allenai/{model_id}")

    branches = []
    for ref in refs.branches:
        branch_name = ref.name
        if branch_name == "main":
            # skip the main checkpoints, these will be added after sampling.
            continue

        parsed = _parse_branch_name(branch_name)
        if parsed["stage"] is None or parsed["step"] is None:
            # Not an intermediate-checkpoint branch; it cannot be ordered.
            print(f"Skipping branch with unexpected name: {branch_name}")
            continue

        branches.append({
            "model_id": model_id,
            "stage": branch_name,
            "train_stage": parsed["stage"],
            "step": parsed["step"],
            "ingredient": parsed["ingredient"],
            "tokens": parsed["tokens"],
            "model_size": model_size,
            "phase": "Pre",
        })

    model_branches_df = pd.DataFrame(branches, columns=BRANCH_COLUMNS)
    # Sort on the parsed numeric columns. The "stage" column holds the full
    # branch name, and sorting on it orders steps lexicographically
    # ("step10000" before "step2000"). Including "ingredient" keeps the rows
    # of one ingredient together, as the branch-name ordering did.
    model_branches_df = model_branches_df.sort_values(
        ["train_stage", "ingredient", "step", "tokens"]
    )

    # NOTE(review): sample_checkpoints is a project helper not visible here;
    # presumably it picks up to the given number of rows per training stage.
    sample_chkpts = sample_checkpoints(
        model_branches_df, stage1_samples=50, stage2_samples=6
    )

    # Add the final "main" checkpoint of the base model and of each
    # post-training variant.
    print("Adding main checkpoints ...")
    main_checkpoints = pd.DataFrame([
        {"model_id": model_id, "stage": "main", "model_size": model_size, "phase": "Base"},
        {"model_id": f"{model_id}-SFT", "stage": "main", "model_size": model_size, "phase": "SFT"},
        {"model_id": f"{model_id}-DPO", "stage": "main", "model_size": model_size, "phase": "DPO"},
        {"model_id": f"{model_id}-Instruct", "stage": "main", "model_size": model_size, "phase": "Instruct"},
    ])

    chkpts_df = pd.concat([sample_chkpts, main_checkpoints], ignore_index=True)

    print("Saving sampled checkpoints to JSONL ...")
    chkpts_df[["model_id", "stage", "model_size", "phase"]].to_json(
        OUTPUT_PATH,
        orient="records",
        lines=True,
        # Truncate on the first model of this run, append for the rest.
        mode="a" if wrote_output else "w",
    )
    wrote_output = True

print("Done.\nTime for coffee!")