import os
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Merge several fine-tuned sequence-classification checkpoints into one model
# by uniformly averaging their weights, then save the result (bf16) together
# with the first checkpoint's tokenizer under $TAMPERING_HOME/models/.

TAMPERING_HOME = os.getenv("TAMPERING_HOME")

model_names = [
    # model names
]

# Directory name (under $TAMPERING_HOME/models) for the merged checkpoint.
OUTPUT_NAME = "AT-qwen2.5-7b-hhrlhf-5120-warm-ai"


def average_state_dicts(state_dicts):
    """Return the uniform element-wise mean of a sequence of state dicts.

    Floating-point tensors are averaged over *all* inputs (not just the first
    two) and returned in the dtype of the first state dict's tensor.
    Non-floating tensors (e.g. integer buffers) are copied from the first
    state dict, because averaging them would silently turn them into floats.

    Args:
        state_dicts: non-empty sequence of mappings from parameter name to
            tensor. All mappings must have the same keys.

    Returns:
        A new dict with the same keys as ``state_dicts[0]``.

    Raises:
        ValueError: if ``state_dicts`` is empty or the key sets differ.
    """
    if not state_dicts:
        raise ValueError("need at least one state dict to average")

    reference_keys = state_dicts[0].keys()
    for index, sd in enumerate(state_dicts[1:], start=1):
        if sd.keys() != reference_keys:
            raise ValueError(
                f"state dict {index} has different keys from state dict 0"
            )

    count = len(state_dicts)
    averaged = {}
    for key, first in state_dicts[0].items():
        if not torch.is_floating_point(first):
            averaged[key] = first.clone()
            continue
        # Sum in float32 so half-precision inputs do not lose precision while
        # accumulating; cast back to the source dtype at the end.
        total = sum(sd[key].to(torch.float32) for sd in state_dicts)
        averaged[key] = (total / count).to(first.dtype)
    return averaged


def main():
    """Average the checkpoints listed in ``model_names`` and save them in bf16.

    Raises:
        RuntimeError: if the TAMPERING_HOME environment variable is unset.
        ValueError: if ``model_names`` is empty or the checkpoints' state
            dicts have different keys.
    """
    # Validate configuration before loading any (potentially huge) model.
    if not TAMPERING_HOME:
        raise RuntimeError("TAMPERING_HOME environment variable is not set")
    if not model_names:
        raise ValueError("model_names is empty; list the checkpoints to average")

    # Load each source model at from_pretrained's default dtype so the mean is
    # computed before any down-cast to bfloat16.
    state_dicts = [
        AutoModelForSequenceClassification.from_pretrained(name).state_dict()
        for name in model_names
    ]
    avg_state_dict = average_state_dicts(state_dicts)
    # Release the source weights before materialising the output model so
    # their memory can be reclaimed.
    del state_dicts

    # The first checkpoint supplies the architecture/config; load_state_dict
    # copies (and casts) the averaged tensors into the bf16 parameters.
    avg_model = AutoModelForSequenceClassification.from_pretrained(
        model_names[0], torch_dtype=torch.bfloat16
    )
    avg_model.load_state_dict(avg_state_dict)

    save_path = os.path.join(TAMPERING_HOME, "models", OUTPUT_NAME)
    # The model is already bf16, so its weights are written in bf16;
    # save_pretrained itself takes no dtype argument.
    avg_model.save_pretrained(save_path)
    tokenizer = AutoTokenizer.from_pretrained(model_names[0])
    tokenizer.save_pretrained(save_path)


if __name__ == "__main__":
    main()
