import argparse
import yaml
import os
import torch
import json
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel


def load_and_merge_peft_model(peft_model_id, organization_name, new_model_name, upload_adapters_only=False):
    """Load a PEFT checkpoint and push it, with its tokenizer, to the Hub.

    The base model id is read from the ``base_model_name_or_path`` entry of
    ``adapter_config.json`` inside ``peft_model_id``. This lookup (and its
    validation) happens in both modes.

    Args:
        peft_model_id: Directory holding the adapter checkpoint; the tokenizer
            is loaded from this same directory.
        organization_name: Hub namespace; the target repository is
            ``<organization_name>/<new_model_name>``.
        new_model_name: Repository name, also used as the sub-directory of
            ``save/tmp/`` that the model is saved to before the push.
        upload_adapters_only: When false (default), the base model is loaded,
            the adapter is applied with ``PeftModel.from_pretrained`` and
            folded in via ``merge_and_unload()``. When true, the checkpoint
            directory itself is loaded through ``AutoModelForCausalLM`` in
            bfloat16 with ``device_map="auto"`` and no merge is performed.
            NOTE(review): presumably this relies on transformers saving such
            a model in adapter-only format; confirm against the installed
            transformers/peft versions.

    Raises:
        FileNotFoundError: ``adapter_config.json`` is absent from
            ``peft_model_id``.
        ValueError: The adapter config has no (or an empty)
            ``base_model_name_or_path``.
    """
    # Locate and read the adapter config.
    adapter_config_path = os.path.join(peft_model_id, "adapter_config.json")
    if not os.path.exists(adapter_config_path):
        raise FileNotFoundError(f"Config file not found in the directory: {peft_model_id}")

    with open(adapter_config_path, 'r') as handle:
        adapter_config = json.load(handle)

    base_model_name = adapter_config.get('base_model_name_or_path')
    if not base_model_name:
        raise ValueError("Base model name not found in the adapter config.")

    # Build the model object that will be saved and pushed.
    if upload_adapters_only:
        model = AutoModelForCausalLM.from_pretrained(
            peft_model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
    else:
        base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
        model = PeftModel.from_pretrained(base_model, peft_model_id).merge_and_unload()

    # The tokenizer always comes from the checkpoint directory, after the model.
    tokenizer = AutoTokenizer.from_pretrained(peft_model_id, truncation=True, padding=True)

    # Save locally under save/tmp/ and push both artifacts to the same repo.
    local_dir = os.path.join("save/tmp/", new_model_name)
    hub_repo_id = f"{organization_name}/{new_model_name}"

    model.save_pretrained(local_dir, push_to_hub=True, repo_id=hub_repo_id)
    tokenizer.push_to_hub(repo_id=hub_repo_id)

    print(f"Successfully uploaded {new_model_name} to Hugging Face Hub under organization {organization_name}")


# Step 1: Define and parse the command-line interface
parser = argparse.ArgumentParser(description='Upload model')
parser.add_argument(
    '--config',
    type=str,
    required=True,
    help='Path to training config file',
)
parser.add_argument(
    '--iter',
    type=int,
    help='Iteration ID to replace {iter} placeholders',
)
parser.add_argument(
    '--upload_adapters_only',
    action='store_true',
    help='Only upload the PEFT adapters without merging the model',
)
args = parser.parse_args()

# Step 2: Read the YAML file and keep only the section this script uses
with open(args.config, 'r') as f:
    config = yaml.safe_load(f)
config = config['upload_models_to_hf']

# Step 3: Pull the settings out of the config section.
# model_id_path and new_model_name fall back to None when absent;
# organization_name is mandatory (KeyError if missing).
model_id_path = config.get('model_id_path')
new_model_name = config.get('new_model_name')
organization_name = config['organization_name']
iter_value = args.iter

# Step 4: Check for placeholders and ensure --iter is provided if needed
def check_and_replace_placeholder(value, placeholder, replacement):
    """Return ``value`` with every ``placeholder`` swapped for ``replacement``.

    Args:
        value: String that may contain the placeholder.
        placeholder: Substring to look for (e.g. ``"{iter}"``).
        replacement: Substitute value; converted with ``str()``. ``None``
            means "not provided".

    Returns:
        ``value`` unchanged when it does not contain ``placeholder``,
        otherwise the substituted string.

    Raises:
        ValueError: ``placeholder`` occurs in ``value`` but ``replacement``
            is ``None``.
    """
    # Nothing to substitute: hand the input back untouched.
    if placeholder not in value:
        return value

    # Only None counts as missing, so a replacement of 0 is still substituted.
    if replacement is not None:
        substituted = value.replace(placeholder, str(replacement))
        return substituted

    raise ValueError(f"{placeholder} placeholder found in the config, but no --iter argument was passed.")

# Step 5: Expand "{iter}" in model_id_path and new_model_name.
# Unset or empty values are left exactly as they are.
model_id_path = (
    check_and_replace_placeholder(model_id_path, "{iter}", iter_value)
    if model_id_path
    else model_id_path
)

new_model_name = (
    check_and_replace_placeholder(new_model_name, "{iter}", iter_value)
    if new_model_name
    else new_model_name
)

# Step 6: Function to identify the latest checkpoint from the directory
def get_latest_checkpoint(directory):
    """Return the path of the highest-numbered checkpoint directory.

    Only sub-directories of ``directory`` named exactly
    ``checkpoint-<digits>`` are considered. Other entries that merely start
    with "checkpoint" (plain files, ``checkpoint-best``, ``checkpoints``,
    ``checkpoint-10.tmp``, ...) are skipped rather than aborting the search
    with an ``int()`` parsing error.

    Args:
        directory: Directory to scan for checkpoints.

    Returns:
        ``os.path.join(directory, name)`` for the checkpoint with the largest
        numeric suffix.

    Raises:
        ValueError: No ``checkpoint-<digits>`` sub-directory exists.
    """
    prefix = "checkpoint-"
    candidates = []
    for entry in os.listdir(directory):
        if not entry.startswith(prefix):
            continue
        step_text = entry[len(prefix):]
        # ASCII digits only, so int() below cannot fail on exotic characters.
        if not (step_text.isascii() and step_text.isdigit()):
            continue
        entry_path = os.path.join(directory, entry)
        # The result is later used as a directory (its adapter_config.json is
        # read), so a plain file with a matching name is not a checkpoint.
        if os.path.isdir(entry_path):
            candidates.append((int(step_text), entry_path))

    if not candidates:
        raise ValueError(f"No checkpoints found in directory: {directory}")

    # Compare numerically so that checkpoint-10 beats checkpoint-2.
    return max(candidates)[1]

# Step 7: Resolve the checkpoint to upload (latest checkpoint under model_id_path).
# model_id_path and new_model_name are read from optional config keys, so
# validate them first: os.listdir(None) would silently scan the current
# working directory, and a missing name would target "<organization>/None".
if not model_id_path:
    raise ValueError("model_id_path is missing from the 'upload_models_to_hf' config section.")
if not new_model_name:
    raise ValueError("new_model_name is missing from the 'upload_models_to_hf' config section.")

peft_model_id = get_latest_checkpoint(model_id_path)

# Ask user for confirmation; anything other than "y" (case-insensitive) cancels.
confirm = input(f"Do you want to upload {peft_model_id} to {organization_name}/{new_model_name}? [y/n] ").strip().lower()
if confirm != 'y':
    print("Upload cancelled by user.")
    # Raise SystemExit directly instead of calling exit(), which is a helper
    # injected by the site module and is not guaranteed to exist. The exit
    # status stays 0.
    raise SystemExit(0)

# Step 8: Load the model and tokenizer, then push them to the Hub
load_and_merge_peft_model(peft_model_id, organization_name, new_model_name, upload_adapters_only=args.upload_adapters_only)