import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GPT2LMHeadModel, Trainer
from src.models import BertForPromptFinetuning, RobertaForPromptFinetuning, resize_token_type_embeddings
from src.tv_utils import *


# Base directory under which the model cache and the merged checkpoints live.
root = "/data/common/lm-bff"
# NOTE(review): presumably a RoBERTa parameter count; it is not referenced in
# the visible part of this script -- confirm before relying on it.
ROBERTA_PARAM = 163941810

# Model class used for every checkpoint loaded by this script.
model_fn = RobertaForPromptFinetuning

# Backbone whose checkpoints are merged; "roberta-base" is the other value
# handled by the path helpers in this file.
modelname = "roberta-large"

cache_dir = f"{root}/model_files"
config = AutoConfig.from_pretrained(modelname, cache_dir=cache_dir)
        
def get_model_path(task):
    """Return the fine-tuned checkpoint directory for ``task``.

    The directory is selected by the module-level ``modelname`` and is
    rooted at the module-level ``root``.

    Args:
        task: Task name as it appears in the checkpoint directory name,
            e.g. ``"SST-2"``.

    Returns:
        The checkpoint directory as a string, ending with a slash.

    Raises:
        ValueError: If ``modelname`` is neither ``"roberta-base"`` nor
            ``"roberta-large"``.
    """
    # Sub-directory of ``ckpt_paths`` that holds the checkpoints of each backbone.
    log_dirs = {
        "roberta-base": "log_noembed_SGD_graft",
        "roberta-large": "large_log_noembed_SGD_graft",
    }
    if modelname not in log_dirs:
        # Falling through used to return None, which only surfaced later as an
        # unrelated-looking failure when the path was handed to the loader.
        raise ValueError(f"No checkpoint directory known for model: {modelname}")
    return (
        f"{root}/ckpt_paths/{log_dirs[modelname]}/"
        f"{task}-prompt-64-0-{modelname}-2-2e-5/"
    )

def initialize_model(modelname):
    """Load a ``model_fn`` model via ``from_pretrained``.

    Args:
        modelname: Value handed to ``from_pretrained``; this script passes
            either a model name such as ``"roberta-large"`` or a checkpoint
            directory path.

    Returns:
        The object returned by ``model_fn.from_pretrained``.
    """
    # Every load shares the module-level config and cache directory.
    load_kwargs = {"config": config, "cache_dir": cache_dir}
    return model_fn.from_pretrained(modelname, **load_kwargs)

# Tokenizer of the base model; it is written out next to the merged weights
# at the end of the script.
tokenizer = AutoTokenizer.from_pretrained(
    modelname, additional_special_tokens=[], cache_dir=cache_dir
)

# Tasks whose fine-tuned checkpoints are merged.
task_list = [
    "SST-2", "cr", "mr", "mpqa", "trec", "subj",
    "QNLI", "SNLI", "MNLI", "RTE", "MRPC", "QQP",
]
# task_list = ["MNLI", "QNLI"]

# State dicts of each fine-tuned model (same order as task_list) and of the
# pretrained model.
ft_checks = [
    initialize_model(get_model_path(task)).state_dict()
    for task in task_list
]
ptm_check = initialize_model(modelname).state_dict()

# Flatten every state dict and stack the fine-tuned ones, one row per task.
flat_ft = torch.vstack(list(map(state_dict_to_vector, ft_checks)))
flat_ptm = state_dict_to_vector(ptm_check)

# Task vectors: fine-tuned weights minus pretrained weights.
# Assumes state_dict_to_vector returns a 1-D tensor so the subtraction
# broadcasts across the rows of flat_ft -- TODO confirm.
tv_flat_checks = flat_ft - flat_ptm

## TASK VECTOR MERGING UTILS

def aggregate(T, agg_type, dim=0):
    """Reduce tensor ``T`` along ``dim`` with the requested aggregation.

    Args:
        T: Tensor to reduce.
        agg_type: ``"mean"`` or ``"sum"``.
        dim: Dimension to reduce over (default 0).

    Returns:
        ``torch.mean(T, dim=dim)`` or ``torch.sum(T, dim=dim)``.

    Raises:
        ValueError: If ``agg_type`` is neither ``"mean"`` nor ``"sum"``.
    """
    if agg_type == "sum":
        return torch.sum(T, dim=dim)
    if agg_type == "mean":
        return torch.mean(T, dim=dim)
    raise ValueError("Invalid agg_type: %s" % agg_type)

def tv_merging(tv_flat_checks):
    """Merge task vectors by summing them along dim 0.

    No scaling happens here; the caller applies its own coefficient to the
    returned sum.

    Args:
        tv_flat_checks: Tensor holding the task vectors; it is reduced
            along dim 0 and left unmodified.

    Returns:
        A new tensor containing the sum of ``tv_flat_checks`` over dim 0.
    """
    # The reduction allocates its own output and never writes to its input,
    # so no defensive copy is made: cloning would only duplicate the whole
    # stacked task-vector tensor in memory.
    return aggregate(tv_flat_checks, "sum")

# Task Vector Merging example
# Coefficient applied to the summed task vector before it is added back onto
# the pretrained weights.
lamda = 0.3

merged_tv = tv_merging(tv_flat_checks)
merged_check = flat_ptm + lamda * merged_tv
# Rebuild a state dict from the flat merged vector, using the pretrained
# state dict as the reference.
merged_state_dict = vector_to_state_dict(merged_check, ptm_check)

# Output directory of the merged model, chosen per backbone.
# NOTE(review): the roberta-base directory name says "QNLI-MNLI" although
# task_list above lists twelve tasks -- confirm the intended name before
# running with roberta-base.
if modelname == "roberta-base":
    path = f"{root}/ckpt_paths/merged_models/tv_{lamda}_QNLI-MNLI-merged-roberta-base-2-2e-5"
elif modelname == "roberta-large":
    path = f"{root}/ckpt_paths/merged_models/large/tv_{lamda}_all-merged-roberta-large-2-2e-5"

# Load the merged weights into a freshly initialised model and write the
# model and tokenizer to the same directory.
model_temp = initialize_model(modelname)
model_temp.load_state_dict(merged_state_dict)
model_temp.save_pretrained(path, safe_serialization=False)
tokenizer.save_pretrained(path)