import argparse
import glob
import os
import shutil
import warnings

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

def _merge_state_dicts(base_state, other_state, alpha):
    """Return a new state dict holding ``alpha * base + (1 - alpha) * other``.

    The result has exactly the keys of ``base_state`` so it can be loaded
    back into the model that produced ``base_state`` with strict key checking.

    Args:
        base_state: State dict of the first model (weighted by ``alpha``).
        other_state: State dict of the second model (weighted by ``1 - alpha``).
        alpha: Interpolation weight applied to ``base_state`` tensors.

    Returns:
        A dict mapping every key of ``base_state`` to its merged tensor.

    Raises:
        RuntimeError: If a tensor shared by both dicts has different shapes.
    """
    # Keys that exist only in the second model cannot be loaded into the first
    # model (strict loading rejects unexpected keys), so they are reported and
    # left out instead of aborting after all the averaging work is done.
    extra_keys = [key for key in other_state if key not in base_state]
    if extra_keys:
        warnings.warn(
            f"Skipping {len(extra_keys)} tensor(s) present only in the second "
            f"model: {extra_keys[:5]}{' ...' if len(extra_keys) > 5 else ''}",
            stacklevel=3,
        )

    merged = {}
    for key, base_weight in tqdm(base_state.items(), desc="Averaging model weights"):
        other_weight = other_state.get(key)

        # Tensors missing from the second model are kept as-is.  Integer/bool
        # tensors are kept as well: scaling them by a float and casting back
        # on load would silently truncate their values.
        if other_weight is None or not base_weight.is_floating_point():
            merged[key] = base_weight
            continue

        if base_weight.shape != other_weight.shape:
            raise RuntimeError(
                f"Shape mismatch for '{key}': {tuple(base_weight.shape)} vs "
                f"{tuple(other_weight.shape)}; the two models cannot be averaged."
            )

        # The two models may have been placed on different devices.
        other_weight = other_weight.to(base_weight.device)
        merged[key] = alpha * base_weight + (1 - alpha) * other_weight

    return merged


def _copy_auxiliary_files(source_dir, output_path):
    """Copy non-weight files from ``source_dir`` that ``output_path`` lacks.

    Only regular files directly inside ``source_dir`` are considered; files
    that look like weight shards and files already present in ``output_path``
    are left alone.
    """
    weight_prefixes = ('pytorch_model', 'model.safetensors')
    weight_suffixes = ('.bin', '.safetensors')

    # Escape the directory so glob metacharacters in the path are literal.
    pattern = os.path.join(glob.escape(source_dir), '*')
    for filename in glob.glob(pattern):
        # shutil.copy cannot copy directories, so skip anything but files.
        if not os.path.isfile(filename):
            continue

        base_name = os.path.basename(filename)
        if base_name.startswith(weight_prefixes) or base_name.endswith(weight_suffixes):
            continue

        dest_file = os.path.join(output_path, base_name)
        if not os.path.exists(dest_file):
            shutil.copy(filename, dest_file)


def average_models(model_paths, output_path, alpha=0.5):
    """
    Average the weights of two causal-LM checkpoints and save the result.

    Both models are loaded in bfloat16 with ``device_map="auto"``.  Every
    floating-point tensor found in both models becomes
    ``alpha * first + (1 - alpha) * second``; the merged weights are loaded
    into the first model, which is then saved with safetensors serialization.
    The tokenizer and any remaining non-weight files are taken from the
    second model's path.

    Args:
        model_paths: Sequence whose first two entries are the model paths;
            index 0 is weighted by ``alpha``, index 1 by ``1 - alpha``.
        output_path: Directory that receives the merged model, the tokenizer
            and the auxiliary files.
        alpha: Weight of the first model. Defaults to 0.5 (plain average).

    Raises:
        IndexError: If ``model_paths`` has fewer than two entries.
        RuntimeError: If a shared tensor has different shapes in the two models.
    """
    base_path, tuned_path = model_paths[0], model_paths[1]

    model1 = AutoModelForCausalLM.from_pretrained(
        base_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        low_cpu_mem_usage=True
    )

    model2 = AutoModelForCausalLM.from_pretrained(
        tuned_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        low_cpu_mem_usage=True
    )

    state_dict2 = model2.state_dict()
    merged_state = _merge_state_dicts(model1.state_dict(), state_dict2, alpha)

    # The second model is no longer needed; release it before loading/saving.
    del model2
    del state_dict2
    torch.cuda.empty_cache()

    model1.load_state_dict(merged_state)

    model1.save_pretrained(output_path, safe_serialization=True)

    del model1
    del merged_state
    torch.cuda.empty_cache()

    tokenizer = AutoTokenizer.from_pretrained(tuned_path)
    tokenizer.save_pretrained(output_path)

    # Fill in whatever save_pretrained did not write (only missing files).
    _copy_auxiliary_files(tuned_path, output_path)

if __name__ == "__main__":
    # Command-line entry point: parse the two model paths, the output
    # directory and the interpolation weight, then run the averaging.
    parser = argparse.ArgumentParser(description="Weighted average of two Hugging Face model checkpoints.")
    parser.add_argument(
        'model_paths',
        nargs=2,
        type=str,
        help="Paths to the two models to average: base model and fine-tuned model."
    )
    parser.add_argument(
        '--output_path',
        type=str,
        required=True,
        help="Path to save the averaged model."
    )
    parser.add_argument(
        '--alpha',
        type=float,
        default=0.5,
        help="Weight for the first model (e.g., 0.5 means 50/50 average). The second model will be weighted as (1 - alpha)."
    )

    args = parser.parse_args()

    # exist_ok=True avoids the check-then-create race of testing for the
    # directory first, and fails early if the path exists but is not a directory.
    os.makedirs(args.output_path, exist_ok=True)

    average_models(args.model_paths, args.output_path, args.alpha)