from model import PythiaLightningModule
import argparse
import glob
import os
import torch
from pathlib import Path

def load_local_checkpoint(checkpoint_dir, config_path):
    """Build the model from a config and load weights from a local checkpoint.

    The weights are read from
    ``<checkpoint_dir>/checkpoint/mp_rank_00_model_states.pt``. If the loaded
    object has a ``'module'`` entry, that entry is used as the state dict;
    otherwise the loaded object itself is used.

    Args:
        checkpoint_dir: Directory that contains the ``checkpoint``
            sub-directory. A leading ``~`` is expanded to the home directory.
        config_path: Path passed to ``PythiaLightningModule(config_path=...)``.
            A leading ``~`` is expanded to the home directory.

    Returns:
        The wrapper's inner ``model`` attribute, converted to half precision,
        put in eval mode, and moved to CUDA when available (CPU otherwise).

    Raises:
        Exception: Any error raised while reading or loading the checkpoint is
            printed and then re-raised unchanged.
    """
    # A literal "~" is not understood by os.path.join / torch.load, so expand
    # it here; already-expanded paths pass through untouched.
    checkpoint_dir = os.path.expanduser(checkpoint_dir)
    config_path = os.path.expanduser(config_path)

    print("Initializing model from config...")
    print(config_path)
    model = PythiaLightningModule(config_path=config_path)
    try:
        model_state_file = os.path.join(checkpoint_dir, "checkpoint", "mp_rank_00_model_states.pt")
        # NOTE(security): torch.load unpickles the file; only use this on
        # checkpoints from a trusted source.
        checkpoint = torch.load(model_state_file, map_location='cpu')
        state_dict = checkpoint.get('module', checkpoint)

        print("Loading state dict into model...")
        try:
            model.model.load_state_dict(state_dict)
        except RuntimeError:
            # Strict loading into the inner model failed (key or shape
            # mismatch). Fall back to a non-strict load on the wrapper, but
            # report what was skipped so a partial load is never silent.
            print("Strict load into inner model failed; retrying on wrapper with strict=False")
            result = model.load_state_dict(state_dict, strict=False)
            missing = getattr(result, "missing_keys", None) or []
            unexpected = getattr(result, "unexpected_keys", None) or []
            if missing:
                print(f"Warning: {len(missing)} missing keys, e.g. {missing[:5]}")
            if unexpected:
                print(f"Warning: {len(unexpected)} unexpected keys, e.g. {unexpected[:5]}")
        print("Successfully loaded model from checkpoint")

        inner = model.model
        # Prefer the GPU, but do not crash on hosts without one.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        inner.to(device)
        inner.half()
        inner.eval()
        return inner
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        raise


def main():
    """Load a series of training checkpoints and save each with ``save_pretrained``.

    Command-line arguments:
        --model_type: Name used to locate the checkpoint, config and output
            directories under ``~/pythia_replicate`` (required).
        --first_step: First step to process (default 100).
        --last_step: End of the step range, exclusive (default 20000).

    Steps are visited in increments of 100. For each step the checkpoint at
    ``trained_models/pythia_output_<model_type>/step=<step>.ckpt`` is loaded
    via ``load_local_checkpoint`` and saved to
    ``hf_output/<model_type>/step=<step>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, required=True)
    parser.add_argument("--first_step", type=int, default=100)
    parser.add_argument("--last_step", type=int, default=20000)
    args = parser.parse_args()
    model_type = args.model_type

    # "~" is only meaningful to a shell; expand it once so every path below
    # points into the real home directory instead of a literal "./~" folder.
    root = os.path.expanduser("~/pythia_replicate")

    # The config does not depend on the step, so build its path once.
    config_path = os.path.join(
        root, "trained_models", "tensorboard_logs", "pythia", model_type, "hparams.yaml"
    )

    # NOTE(review): range() excludes last_step itself; confirm whether a
    # checkpoint at exactly last_step is expected to be converted.
    for step in range(args.first_step, args.last_step, 100):
        checkpoint_dir = os.path.join(
            root, "trained_models", f"pythia_output_{model_type}", f"step={step}.ckpt"
        )
        model = load_local_checkpoint(checkpoint_dir, config_path)
        model.save_pretrained(
            os.path.join(root, "hf_output", model_type, f"step={step}")
        )


# Call main() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
