from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel, GPT2Tokenizer
from datasets import load_dataset
import argparse
import torch
import time
from llama_graft_trainer import Localizer, Stitcher
from transformers.data.data_collator import DataCollatorWithPadding
from alpaca_ds import get_alpaca_trainer
from math_ds import get_math_trainer
from qa_ds import get_qa_trainer

# datasets = [load_dataset("tatsu-lab/alpaca", split='train'), load_dataset("hendrycks/competition_math", split='train')]

# Base model that the fine-tuned weight deltas are grafted onto.
base_model = "openai-community/gpt2-xl"

# Candidate fine-tuned GPT-2 XL checkpoints; `i` below selects which one to graft.
finetuned_names = ["Locutusque/gpt2-xl-conversational", "Onlydrinkwater/gpt2xl_language_math_520_10base", "Rachneet/gpt2-xl-alpaca"] #Harshkmr/gemma-2b-math
# mmlu, truthfulqa, arc
# alpaca, hotpotqa, math

i = 1

# Tag of the saved mask file for each fine-tuned model. It is keyed by model name
# rather than by position. That way a model added to `finetuned_names` without a
# matching mask fails loudly here instead of leaving `m_path` unbound.
# NOTE(review): the fields look like epoch / learning rate / a size budget from
# the localization run. Confirm against the script that wrote the masks.
mask_tags = {
    "Locutusque/gpt2-xl-conversational": "epoch_9_0.0001_4000000.0",
    "Onlydrinkwater/gpt2xl_language_math_520_10base": "epoch_9_0.0001_1000000.0",
    "Rachneet/gpt2-xl-alpaca": "epoch_29_0.0001_5000000.0",
}

finetuned_name = finetuned_names[i]
if finetuned_name not in mask_tags:
    raise ValueError(f"no mask registered for fine-tuned model {finetuned_name!r}")

# "org/name" -> "name"; used to build both the mask file name and the output path.
model_name = finetuned_name.split('/')[1]
m_path = f"{model_name}_{mask_tags[finetuned_name]}.pt"

# Load the mask before downloading the large models so a missing or bad mask file
# fails fast. map_location="cpu" keeps the mask on the same device as the models
# loaded below (they get no device placement), even if it was saved from GPU tensors.
mask = torch.load(f"/data/common/mergekit/masks/mask_{m_path}", map_location="cpu")
# mask = torch.load(f'/data/common/mergekit/masks/gpt2xl_language_math_520_10base_mask.pt')

tokenizer = AutoTokenizer.from_pretrained(base_model)
pretrained_model = GPT2LMHeadModel.from_pretrained(base_model, torch_dtype=torch.bfloat16)
pretrained_model.resize_token_embeddings(len(tokenizer))

finetuned_model = GPT2LMHeadModel.from_pretrained(finetuned_name, torch_dtype=torch.bfloat16)
# Match the base model's embedding size so parameter shapes line up for grafting.
finetuned_model.resize_token_embeddings(len(tokenizer))

def select_trainable_parameters(
    model,
    *,
    frozen=(
        "transformer.wte.weight",
        "transformer.wpe.weight",
        "transformer.ln_f.weight",
        "transformer.ln_f.bias",
    ),
):
    """Return the named parameters of *model* that are eligible for grafting.

    Every ``(name, parameter)`` pair yielded by ``model.named_parameters()`` is
    kept unless its name appears in *frozen*. By default this excludes the
    ``wte``/``wpe`` embedding weights and the final ``ln_f`` layer norm.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, e.g. a ``torch.nn.Module``.
        frozen: Keyword-only collection of parameter names to leave out.
            Defaults to the four names listed in the signature.

    Returns:
        dict mapping parameter name to parameter object, in the order
        ``named_parameters()`` produced them.
    """
    skip = frozenset(frozen)  # O(1) membership tests
    return {
        name: param
        for name, param in model.named_parameters()
        if name not in skip
    }


# Output directory for the grafted checkpoint, named after the mask file it used.
path = f"/data/common/mergekit/grafted/{m_path}"

# Graft the masked fine-tuned weights onto the pretrained model.
# NOTE(review): `final_model` is the very same object as `pretrained_model`, so
# Stitcher is handed one model as both the target and the reference. Whether that
# is safe depends on how Stitcher.interpolate_models reads the reference weights
# while it updates the target. Confirm in llama_graft_trainer before relying on it.
final_model = pretrained_model
trainable_params = select_trainable_parameters(final_model)
stitcher = Stitcher(
    trainable_params,
    final_model,
    pretrained_model,
    [finetuned_model],
    [mask],
)
merged_model = stitcher.interpolate_models()

# Write model weights (non-safetensors format) and tokenizer files side by side
# so the directory can be reloaded with from_pretrained.
merged_model.save_pretrained(path, safe_serialization=False)
tokenizer.save_pretrained(path)
