import argparse
import os
import shutil
import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM

def copy_extra_files(source_dir, dest_dir, files):
    """Copy optional files or directories from ``source_dir`` into ``dest_dir``.

    Args:
        source_dir: Directory that may contain the entries listed in ``files``.
        dest_dir: Directory that receives the copies. It is created on demand
            when a regular file has to be copied into it.
        files: Iterable of entry names relative to ``source_dir``. Each name
            may refer to a regular file or to a directory.

    Entries that do not exist in ``source_dir`` are skipped silently. A
    directory entry replaces any same-named entry already present in
    ``dest_dir`` (the old tree is removed before copying).
    """
    for name in files:
        src_path = os.path.join(source_dir, name)
        if not os.path.exists(src_path):
            continue

        if os.path.isdir(src_path):
            dest_subdir = os.path.join(dest_dir, name)
            # copytree refuses to write into an existing directory, so drop
            # any stale copy first.
            if os.path.exists(dest_subdir):
                shutil.rmtree(dest_subdir)
            shutil.copytree(src_path, dest_subdir)
            continue

        # shutil.copy treats a non-existent destination as a *file* path, so
        # without this every file would be written to (and overwrite) a file
        # named dest_dir instead of landing inside the directory.
        os.makedirs(dest_dir, exist_ok=True)
        shutil.copy(src_path, dest_dir)

def _build_parser():
    """Build the command-line parser for the weight-blending script."""
    parser = argparse.ArgumentParser(
        description="Blend model weights using the formula: TARGET_MODEL + alpha * (FORGET_MODEL - BASE_MODEL) + beta * (RETAIN_MODEL - BASE_MODEL) "
    )
    parser.add_argument("--base_model", type=str, required=True,
                        help="Path or model id of the BASE_MODEL")
    parser.add_argument("--forget_model", type=str, required=True,
                        help="Path or model id of the FORGET_MODEL")
    parser.add_argument("--retain_model", type=str, required=True,
                        help="Path or model id of the RETAIN_MODEL")
    parser.add_argument("--target_model", type=str, required=True,
                        help="Path or model id of the TARGET_MODEL")
    parser.add_argument("--alpha", type=float, required=True,
                        help="Alpha scaling parameter (must be non-positive)")
    parser.add_argument("--beta", type=float, required=True,
                        help="Beta scaling parameter (must be non-negative)")
    parser.add_argument("--save_path", type=str, required=True,
                        help="Path where the final model will be saved")
    parser.add_argument("--base_revision", type=str, required=False, default=None,
                        help="Optional revision of the BASE_MODEL to load")
    return parser


def _apply_task_vectors(target_model, base_state, forget_state, retain_state, alpha, beta):
    """Add the blended weight delta to every parameter of ``target_model`` in place.

    For each parameter the delta is
    ``alpha * (forget - base) + beta * (retain - base)``.

    Returns:
        The mean of the per-parameter delta norms (0.0 if the model has no
        parameters).

    Raises:
        ValueError: If a parameter of ``target_model`` is missing from any of
            the three state dicts. Parameters visited before the missing one
            have already been updated at that point.
    """
    total_norm = 0.0
    num_updated = 0
    with torch.no_grad():
        for name, param in target_model.named_parameters():
            if not (name in forget_state and name in retain_state and name in base_state):
                raise ValueError(f"Parameter '{name}' not found in all models.")

            base = base_state[name]
            # Compute the delta once and reuse it for the update and the norm.
            delta = alpha * (forget_state[name] - base) + beta * (retain_state[name] - base)
            param.add_(delta)
            total_norm += delta.norm().item()
            num_updated += 1

    # Average over the parameters actually summed. state_dict() can also hold
    # buffers and tied-weight aliases that named_parameters() does not yield,
    # so its length is the wrong denominator.
    return total_norm / num_updated if num_updated else 0.0


def main():
    """Blend four causal-LM checkpoints and save the result.

    Computes ``TARGET + alpha * (FORGET - BASE) + beta * (RETAIN - BASE)`` on
    the target model's parameters, saves the blended model and the base
    model's tokenizer to ``--save_path``, and, when the target model is a
    local directory, copies a fixed set of auxiliary files from it.

    Raises:
        ValueError: If ``alpha > 0`` or ``beta < 0``, or if a target parameter
            is missing from one of the other models.
    """
    args = _build_parser().parse_args()

    # Validate before loading anything: loading four checkpoints is expensive
    # and pointless if the coefficients are going to be rejected.
    if args.alpha > 0 or args.beta < 0:
        raise ValueError("alpha and beta values not suitable", args.alpha, args.beta)

    model_dtype = torch.bfloat16

    print(args.base_model, args.base_revision)
    # Load models with low CPU memory usage.
    load_kwargs = {"low_cpu_mem_usage": True, "torch_dtype": model_dtype}
    base_model = AutoModelForCausalLM.from_pretrained(
        args.base_model, revision=args.base_revision, **load_kwargs
    )
    forget_model = AutoModelForCausalLM.from_pretrained(args.forget_model, **load_kwargs)
    retain_model = AutoModelForCausalLM.from_pretrained(args.retain_model, **load_kwargs)
    target_model = AutoModelForCausalLM.from_pretrained(args.target_model, **load_kwargs)

    # Update the parameters of target_model in-place.
    avg_norm = _apply_task_vectors(
        target_model,
        base_model.state_dict(),
        forget_model.state_dict(),
        retain_model.state_dict(),
        args.alpha,
        args.beta,
    )
    print(f"Average norm of the MSA: {avg_norm}")

    os.makedirs(args.save_path, exist_ok=True)
    target_model.save_pretrained(args.save_path)
    print(f"Model saved to {args.save_path}")

    # The tokenizer is taken from the base model (not the target model), at
    # the same revision the base weights were loaded from.
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, revision=args.base_revision)
    tokenizer.save_pretrained(args.save_path)
    print("Tokenizer saved.")

    # Auxiliary files can only be copied when the target model is a local
    # directory rather than a hub model id.
    if os.path.isdir(args.target_model):
        extra_files = [
            "generation_config.json",
            "trainer_state.json",
            "FinetuneTrainer.log",
            "training_args.bin",
            "logs",  # This can be a directory.
        ]
        copy_extra_files(args.target_model, args.save_path, extra_files)
        print("Extra files copied from the target model directory, if they existed.")


# Run the blend only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
