import torch

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GPT2LMHeadModel
from transformers import set_seed
from src.models import BertForPromptFinetuning, RobertaForPromptFinetuning, resize_token_type_embeddings

set_seed(0)

# Base directory shared by the model cache and the checkpoint/mask paths.
root = "/data/common/lm-bff"
# NOTE(review): not referenced elsewhere in this script; meaning is not documented here.
ROBERTA_PARAM = 163941810
# Model class instantiated by initialize_model().
model_fn = RobertaForPromptFinetuning
# Backbone name; use "roberta-large" for the large backbone (checkpoint paths switch on it).
modelname = "roberta-base"
cache_dir = f"{root}/model_files"

# A single config (from the base backbone) is reused for every model that is loaded.
config = AutoConfig.from_pretrained(modelname, cache_dir=cache_dir)
tokenizer = AutoTokenizer.from_pretrained(
    modelname,
    additional_special_tokens=[],
    cache_dir=cache_dir,
)


def initialize_model(modelname):
    """Instantiate ``model_fn`` from a hub model name or a checkpoint directory.

    The module-level ``config`` is always passed in, so every model is built
    with that configuration. Downloads are cached under the module-level
    ``cache_dir``.

    Args:
        modelname: hub identifier or local checkpoint path handed to
            ``from_pretrained``.

    Returns:
        The loaded model.
    """
    loaded = model_fn.from_pretrained(modelname, config=config, cache_dir=cache_dir)
    return loaded


def select_trainable_parameters(model):
    """Return the parameters of ``model`` that take part in the merge.

    A parameter is selected when its qualified name contains
    ``'encoder.layer'``. The returned dict maps name -> parameter and keeps the
    order yielded by ``model.named_parameters()``. Mask lists are indexed by
    position in this order.
    """
    return {
        name: param
        for name, param in model.named_parameters()
        if 'encoder.layer' in name
    }
       

def get_average_masks(masks):
    """Convert per-task binary masks into per-task averaging weights.

    For task ``i`` and mask index ``k``, the output is ``1 / c`` wherever
    ``masks[i][k]`` is set, where ``c`` is the number of tasks whose mask is
    set at that position (task ``i`` included). It is ``0`` wherever
    ``masks[i][k]`` is not set. Multiplying each task's update by these
    weights and summing over tasks therefore averages the updates that
    overlap at each position.

    Args:
        masks: one list of mask tensors per task. Every task must provide the
            same number of tensors, and tensors at the same index must have
            matching shapes. Values are expected to be 0/1; bool masks are
            counted as 0/1 floats.

    Returns:
        A list with one list of weight tensors per task, in input order. The
        input lists and tensors are not modified.

    Raises:
        ValueError: if tasks provide different numbers of mask tensors.
    """

    def reciprocal_with_zero(tensor):
        # Elementwise 1 / tensor, with 0 (not inf) where tensor == 0.
        mask = tensor == 0
        reciprocal = torch.reciprocal(tensor)
        return reciprocal.masked_fill(mask, 0)

    def as_numeric(tensor):
        # Adding bool tensors saturates at True instead of counting, so
        # promote bool masks to float before summing or multiplying.
        return tensor.float() if tensor.dtype == torch.bool else tensor

    if not masks:
        return []

    num_masks = len(masks[0])
    if any(len(task_masks) != num_masks for task_masks in masks):
        raise ValueError("all tasks must provide the same number of mask tensors")

    # For a binary m_i:  m_i + sum_{j != i} (m_i AND m_j) == m_i * sum_j m_j.
    # So one sum per mask index replaces the pairwise O(T^2) intersections
    # and yields exactly the same weights.
    inverse_counts = [
        reciprocal_with_zero(sum(as_numeric(task_masks[k]) for task_masks in masks))
        for k in range(num_masks)
    ]

    return [
        [as_numeric(m) * inv for m, inv in zip(task_masks, inverse_counts)]
        for task_masks in masks
    ]
    

def interpolate_models(model, pretrained_model, finetuned_models, masks, device="cpu"):
    """Add masked task vectors from each fine-tuned model into ``model`` in place.

    For every parameter name selected by ``select_trainable_parameters(model)``,
    at position ``idx`` in that selection order, and for every
    ``(finetuned_model, mask)`` pair, this performs::

        p += mask[idx] * (finetuned_p - pretrained_p)

    Parameters that are not selected are left untouched.

    Args:
        model: model receiving the update. It is moved to ``device`` and
            modified in place.
        pretrained_model: reference model whose weights are subtracted. It is
            moved to ``device``.
        finetuned_models: iterable of fine-tuned models, paired with ``masks``
            positionally via ``zip`` (surplus items on either side are
            ignored). Each one is moved to ``device`` when it is reached.
        masks: one list of weight tensors per fine-tuned model, indexed by
            selected-parameter position. The caller's lists are not modified.
        device: device the models and mask tensors are moved to.

    Returns:
        ``model`` itself, after the update.

    Raises:
        KeyError: if a selected parameter name is missing from the pretrained
            or a fine-tuned model.
    """
    pretrained_model.to(device)
    model.to(device)

    trainable_name = list(select_trainable_parameters(model).keys())
    # Build name -> parameter lookups once. The previous approach rescanned
    # every model's named_parameters() for each selected name, which was
    # quadratic in the number of parameters.
    model_params = dict(model.named_parameters())
    pretrained_params = dict(pretrained_model.named_parameters())

    for finetuned_model, mask in zip(finetuned_models, masks):
        finetuned_model.to(device)
        finetuned_params = dict(finetuned_model.named_parameters())
        with torch.no_grad():
            for idx, name in enumerate(trainable_name):
                pretensor = pretrained_params[name].to(device)
                finetensor = finetuned_params[name].to(device)
                # Same arithmetic order as before: mask * (fine - pre), added in place.
                model_params[name].add_(mask[idx].to(device) * (finetensor - pretensor))

    return model


sparsity = "1e-2"
shots = 128

# Fine-tuned checkpoint directory for the selected backbone.
if modelname == "roberta-base":
    ckpt_pth = f"{root}/ckpt_paths/log_noembed_SGD_graft/"
elif modelname == "roberta-large":
    ckpt_pth = f"{root}/ckpt_paths/large_log_noembed_SGD_graft/"
else:
    # Without this branch, ckpt_pth would be undefined and fail later with a NameError.
    raise ValueError(f"Unsupported modelname: {modelname!r}")

task_list = ["SST-2", "cr", "mr", "mpqa", "trec", "subj", "QNLI", "SNLI", "MNLI", "RTE", "MRPC", "QQP"]
# task_list = ["QNLI", "MNLI"]
model_path_list = [f"{ckpt_pth}{task}-prompt-64-0-{modelname}-2-2e-5" for task in task_list]

# Masks for the `shots`-shot runs. The sparsity tag in the file name follows
# `sparsity`, so the masks and the output path always use the same sparsity.
mask_pth = f"{root}/mask_path/{shots}-shot/"
mask_path_list = [
    f"{mask_pth}mask_{task.lower()}_sparsity_{sparsity}_-prompt-64-0-{modelname}-2-2e-5"
    for task in task_list
]
# Alternative mask layouts from earlier experiments:
# mask_path_list = [f"{root}/mask_path/{sparsity}/mask_{task}-prompt-64-0-roberta-base-2-2e-5" for task in task_list]
# mask_path_list = [f"{root}/mask_path/large/{sparsity}/mask_{task}-prompt-64-0-{modelname}-2-2e-5" for task in task_list]
# mask_path_list = ["/data/common/lm-bff/mask_path/case_study/mask_qnli_l1_0.1-prompt-64-0-roberta-base-2-2e-5", "/data/common/lm-bff/mask_path/case_study/mask_mnli_l1_1-prompt-64-0-roberta-base-2-2e-5"]

model = initialize_model(modelname)
pretrained_model = initialize_model(modelname)
# Load fine-tuned models lazily. interpolate_models consumes them one at a time
# via zip(), so only a couple are resident at once instead of all of them.
finetuned_models = (initialize_model(model_path) for model_path in model_path_list)
# map_location="cpu" lets masks saved from GPU runs load on CPU-only hosts.
# interpolate_models moves them to the target device afterwards.
masks = [torch.load(mask_path, map_location="cpu") for mask_path in mask_path_list]

# The averaging in get_average_masks assumes 0/1 masks. Use an explicit raise
# because `assert` is stripped when Python runs with -O.
if not all(torch.all((tensor == 0) | (tensor == 1)) for mask_list in masks for tensor in mask_list):
    raise ValueError("Not all masks are binary.")

avg_masks = get_average_masks(masks)
merged_model = interpolate_models(model, pretrained_model, finetuned_models, avg_masks, device="cpu")

# path = f"{root}/ckpt_paths/merged_models/large/avg_graft_all_{sparsity}-merged-{modelname}-2-2e-5"
# path = f"{root}/ckpt_paths/merged_models/disjoint/reverse_avg_graft_all_{sparsity}-merged-{modelname}-2-2e-5"
# path = f"{root}/ckpt_paths/merged_models/case_study/avg_graft_all_sparsity_0.001-merged-roberta-base-2-2e-5"
path = f"{root}/ckpt_paths/merged_models/shots/avg_graft_all_{sparsity}_{shots}_shot-merged-{modelname}-2-2e-5"
# NOTE(review): safe_serialization=False is presumably required by this custom
# model (e.g. shared tensors); confirm before switching to safetensors.
merged_model.save_pretrained(path, safe_serialization=False)
tokenizer.save_pretrained(path)
