#!/usr/bin/env python
"""
Instantiate a base Col-model from an existing checkpoint and upload it to the Hub.

Example
-------
python upload_base_colmodel.py some-org/qwen_model-base \
    --model-class ColQwen2_5 \
    --repo-name your-org/colqwen-base
"""

import argparse

from colpali_engine import models
from transformers import AutoModelForMaskedLM
from peft import PeftModel, AutoPeftModelForFeatureExtraction


def main(base_model_checkpoint: str, model_class: str, repo_name: str = None, merge: bool = False):
    """Instantiate a base Col-model, save it locally and optionally upload it to the Hub.

    The model and its processor are both loaded from ``base_model_checkpoint`` and
    written to a local directory named after the checkpoint with a ``__li`` suffix.
    When ``repo_name`` is given, both are additionally pushed to that Hub repository.

    Args:
        base_model_checkpoint: Hub id or local path of the checkpoint to load from.
        model_class: Name of a model class exposed by ``colpali_engine.models``. The
            processor class is looked up under ``model_class + "Processor"``.
        repo_name: Destination Hub repository id. When ``None`` (the default),
            nothing is uploaded and only the local copy is written.
        merge: Accepted for CLI compatibility; currently has no effect.

    Raises:
        AttributeError: If the model class or its processor class does not exist in
            ``colpali_engine.models``.
    """
    # Strip trailing path separators so that a local checkpoint given as "ckpt/"
    # is saved next to the checkpoint ("ckpt__li") rather than inside it ("ckpt/__li").
    checkpoint_stem = base_model_checkpoint.rstrip("/\\")
    local_save_path = f"{checkpoint_stem}__li"

    # Resolve both classes up front so an unknown name fails before any weights are loaded.
    model_class_ = getattr(models, model_class)
    processor_class_ = getattr(models, model_class + "Processor")

    # Load the model through the resolved Col-model class and keep a local copy.
    model = model_class_.from_pretrained(base_model_checkpoint)
    model.save_pretrained(local_save_path)
    if repo_name is not None:
        model.push_to_hub(
            repo_id=repo_name,
            commit_message="Add base Col-model with randomly-initialized linear projection head",
        )
        print(f"✅  Model pushed to https://huggingface.co/{repo_name}")

    # Load the matching processor and store it alongside the model.
    processor = processor_class_.from_pretrained(base_model_checkpoint)
    processor.save_pretrained(local_save_path)
    if repo_name is not None:
        processor.push_to_hub(
            repo_id=repo_name,
            commit_message="Add base Col-processor with randomly-initialized linear projection head",
        )
        print(f"✅  Processor pushed to https://huggingface.co/{repo_name}")

if __name__ == "__main__":
    # Command-line entry point: parse the arguments and forward them to main().
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "checkpoint",
        help="Existing base model checkpoint on the Hub or local path",
    )
    arg_parser.add_argument(
        "--model-class",
        required=True,
        help="Model class to instantiate, e.g. `ColQwen2_5`",
    )
    arg_parser.add_argument(
        "--repo-name",
        default=None,
        help="Destination repo name, e.g. `vidore/colvqwen2_5-base`",
    )
    arg_parser.add_argument("--merge", action="store_true")

    cli_args = arg_parser.parse_args()
    main(
        base_model_checkpoint=cli_args.checkpoint,
        model_class=cli_args.model_class,
        repo_name=cli_args.repo_name,
        merge=cli_args.merge,
    )
