import json
import argparse

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--sft_model_path", type=str, required=True)
    parser.add_argument("--alpha_model_path", type=str, required=True)
    parser.add_argument("--gradient_path", type=str, required=True)
    args = parser.parse_args()

    with open("./gradient_alpha_history.txt", "r",encoding="utf-8") as f:
        for line in tqdm(f):
            existed = json.loads(line)
            if existed["sft_model_path"] == args.sft_model_path and existed["alpha_model_path"] == args.alpha_model_path and existed["gradient_path"] == args.gradient_path:
                print("Already existed")
                exit(0)
    sft_model = AutoModelForCausalLM.from_pretrained(args.sft_model_path)
    alpha_model = AutoModelForCausalLM.from_pretrained(args.alpha_model_path)

    sft_param_dict = {param_name: param_value for param_name, param_value in sft_model.named_parameters()}
    alpha_param_dict = {param_name: param_value for param_name, param_value in alpha_model.named_parameters()}
    alpha = {param_name: alpha_param_dict[param_name] - sft_param_dict[param_name] for param_name in sft_param_dict}
    gradient = torch.load(args.gradient_path)

    product = {}
    for param_name in sft_param_dict:
        if param_name not in gradient:
            gparam = "module." + param_name
            if gparam not in gradient:
                print(f"param_name: {param_name} not in gradient")
                exit(1)
            else:
                gradient[param_name] = gradient[gparam]
        product[param_name] = torch.sum(alpha[param_name] * gradient[param_name]).item()
    
    with open("./gradient_alpha_history.txt", "a") as f:
        f.write(json.dumps({
            "sft_model_path": args.sft_model_path,
            "alpha_model_path": args.alpha_model_path,
            "gradient_path": args.gradient_path,
            "gradient*alpha": product
        }) + "\n")
