from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel, GPT2Tokenizer
from datasets import load_dataset
import argparse
import torch
import time
from llama_graft_trainer import Localizer, Stitcher
from transformers.data.data_collator import DataCollatorWithPadding
from alpaca_ds import get_alpaca_trainer
from math_ds import get_math_trainer
from qa_ds import get_qa_trainer

# datasets = [load_dataset("tatsu-lab/alpaca", split='train'), load_dataset("hendrycks/competition_math", split='train')]

# Base checkpoint shared by every fine-tuned model below; loaded in bfloat16.
base_model = "openai-community/gpt2-xl"
tokenizer = AutoTokenizer.from_pretrained(base_model)
pretrained_model = GPT2LMHeadModel.from_pretrained(base_model, torch_dtype=torch.bfloat16)
pretrained_model.resize_token_embeddings(len(tokenizer))

# Fine-tuned GPT-2 XL checkpoints to merge. Their order is the order in which
# the run below pairs them with trainers (when training masks) or with the
# precomputed mask files (when loading masks).
finetuned_names = [
    "Locutusque/gpt2-xl-conversational",
    "Onlydrinkwater/gpt2xl_language_math_520_10base",
    "Rachneet/gpt2-xl-alpaca",
]  # other candidate checkpoint: Harshkmr/gemma-2b-math
# Author's notes: mmlu, truthfulqa, arc
#                 alpaca, hotpotqa, math

# True: train a mask per model with Localizer; False: load masks from disk.
train_mask = False

# No command-line options are registered, so parse_args() just yields an empty
# Namespace that is then filled with the grafting hyperparameters.
parser = argparse.ArgumentParser()
graft_args = parser.parse_args()
graft_hparams = {
    "learning_rate": 1e6,
    "sigmoid_bias": 2,
    "num_train_epochs": 40,
    "sparsity": 1e-4,
    "gradient_accumulation_steps": 4,
    "l1_strength": 0,
}
for hparam_name, hparam_value in graft_hparams.items():
    setattr(graft_args, hparam_name, hparam_value)

# Presumably per-model learning rates tried previously: [3.5e6, 8e5, 8e5]
print(graft_args.learning_rate)


def select_trainable_parameters(model):
    """Return the parameters of ``model`` that participate in grafting.

    Every named parameter is kept except the four excluded by name below:
    the token embedding, the position embedding, and the final layer norm
    (weight and bias).

    Args:
        model: a ``torch.nn.Module`` whose ``named_parameters()`` are scanned.

    Returns:
        dict mapping each kept parameter name to the parameter object itself
        (not a copy), in ``named_parameters()`` order.
    """
    excluded_names = frozenset((
        'transformer.wte.weight',
        'transformer.wpe.weight',
        'transformer.ln_f.weight',
        'transformer.ln_f.bias',
    ))
    return {
        param_name: param
        for param_name, param in model.named_parameters()
        if param_name not in excluded_names
    }

# --- Localize: obtain one mask per fine-tuned model. ---
start_time = time.time()

mask_dir = "/data/common/mergekit/masks/"

# Mask-training dataset factories, indexed by position in finetuned_names.
# NOTE(review): this order (alpaca, qa, math) does not line up with
# finetuned_names (conversational, math, alpaca). For example, the math model is
# paired with the QA trainer. Confirm the pairing is intended.
mask_trainer_factories = (get_alpaca_trainer, get_qa_trainer, get_math_trainer)
if train_mask and len(finetuned_names) > len(mask_trainer_factories):
    # Without this check, extra models would silently reuse the last trainer.
    raise ValueError(
        f"No mask trainer configured for {finetuned_names[len(mask_trainer_factories):]}"
    )

finetuned_models = []
masks, proportions, tests = [], [], []
for i, finetuned_name in enumerate(finetuned_names):
    short_name = finetuned_name.split('/')[1]
    finetuned_model = GPT2LMHeadModel.from_pretrained(finetuned_name, torch_dtype=torch.bfloat16)
    finetuned_model.resize_token_embeddings(len(tokenizer))
    finetuned_models.append(finetuned_model)
    print(short_name)

    if train_mask:
        trainer = mask_trainer_factories[i](pretrained_model, tokenizer)
        dataloader = trainer.get_train_dataloader()
        # Fresh copy of the base model (default dtype) for the Localizer.
        final_model = GPT2LMHeadModel.from_pretrained(base_model)
        final_model.resize_token_embeddings(len(tokenizer))
        trainable_params = select_trainable_parameters(final_model)
        localizer = Localizer(trainable_params, final_model, pretrained_model, finetuned_model, graft_args, short_name)
        mask, proportion = localizer.train_graft(dataloader)
        torch.save(mask, f"{mask_dir}{short_name}_mask.pt")
        # Keep the trained mask for stitching. Otherwise Stitcher would get
        # an empty masks list whenever masks are trained in this run.
        masks.append(mask)
        proportions.append(proportion.cpu().item())

if not train_mask:
    # Precomputed masks, loaded once. Their order must match finetuned_names.
    mask_files = [
        "mask_gpt2-xl-conversational_epoch_39_0.0001_4000000.0.pt",
        "mask_gpt2xl_language_math_520_10base_epoch_39_0.0001_850000.0.pt",
        "mask_gpt2-xl-alpaca_epoch_39_0.0001_800000.0.pt",
    ]
    masks = [torch.load(mask_dir + mask_file) for mask_file in mask_files]

if len(masks) != len(finetuned_models):
    raise ValueError(
        f"Expected one mask per fine-tuned model, got {len(masks)} masks "
        f"for {len(finetuned_models)} models"
    )

localize_time = time.time() - start_time

# --- Stitch: combine the fine-tuned models under their masks. ---
# NOTE(review): final_model is the very same object as pretrained_model. If
# Stitcher modifies final_model in place, the "pretrained" reference it also
# receives changes too. Confirm Stitcher does not rely on it staying untouched.
final_model = pretrained_model
trainable_params = select_trainable_parameters(final_model)
stitcher = Stitcher(trainable_params, final_model, pretrained_model, finetuned_models, masks)
merged_model = stitcher.interpolate_models()
stitch_time = time.time() - start_time - localize_time

path = "/data/common/mergekit/gpt2-xl-conversational_language_math_520_10base_alpaca_localize_stitch"
merged_model.save_pretrained(path, safe_serialization=False)
tokenizer.save_pretrained(path)

# with open('log.txt', 'a') as f:
#     for i in range(3):
#         f.write(f"{finetuned_names[i]}: {localize_time}, {stitch_time}\n")
#     f.write(f"Proportions: {proportions}\n")
#     f.write(lr_list)
