
import torch

from transformers import AutoConfig, AutoTokenizer, set_seed
from src.models import RobertaForPromptFinetuning
import numpy as np

# Seed the random number generators through transformers' set_seed.
set_seed(0)

# Base directory; only used below to derive the model cache directory.
root = "/data/common/lm-bff"
# NOTE(review): not referenced anywhere else in the visible code; presumably
# a parameter count kept for reference -- confirm before relying on it.
ROBERTA_PARAM = 163941810
# Model class whose from_pretrained() is used by initialize_model().
model_fn = RobertaForPromptFinetuning
# Backbone selection: exactly one of the two assignments below is active.
# modelname = "roberta-base"
modelname = "roberta-large"
cache_dir = root+"/model_files"
# The config and tokenizer are loaded once for `modelname`; initialize_model()
# passes this same config to every checkpoint it loads.
config = AutoConfig.from_pretrained(
            modelname,
            cache_dir=cache_dir,
        )
tokenizer = AutoTokenizer.from_pretrained(
    modelname,
    additional_special_tokens=[],
    cache_dir=cache_dir,
)


def initialize_model(modelname):
    """Load and return a ``model_fn`` instance for ``modelname``.

    ``modelname`` is forwarded unchanged to ``model_fn.from_pretrained``.
    The module-level ``config`` and ``cache_dir`` are reused for every
    load, whichever name or path is requested.
    """
    load_kwargs = {"config": config, "cache_dir": cache_dir}
    return model_fn.from_pretrained(modelname, **load_kwargs)


def select_trainable_parameters(model):
    """Return the model's encoder-layer parameters keyed by name.

    A parameter is kept when the name reported by
    ``model.named_parameters()`` contains the substring ``'encoder.layer'``.
    The result preserves the iteration order of ``named_parameters()``.
    """
    return {
        name: param
        for name, param in model.named_parameters()
        if 'encoder.layer' in name
    }
    

def get_average_masks(masks):
    """Turn per-model masks into averaging weights that account for overlap.

    For mask set ``i`` and layer ``k`` the result is the element-wise
    reciprocal of::

        masks[i][k] + sum(logical_and(masks[i][k], masks[j][k]) for j != i)

    with entries whose sum is zero left at zero. For 0/1 masks this means
    an entry selected by ``m`` of the mask sets gets weight ``1 / m`` in
    each of those sets, and unselected entries stay ``0``.

    Args:
        masks: Sequence of mask sets, one per model. Each mask set is a
            sequence of tensors, and all mask sets must have matching
            tensors at the same positions. The inputs are not modified.

    Returns:
        A list with one list of weight tensors per input mask set.

    Side effects:
        Prints, for every mask set, the total number of entries it shares
        with the other mask sets (summed over those sets).
    """

    def reciprocal_with_zero(tensor):
        # 1 / x everywhere, except that x == 0 maps to 0 instead of inf.
        zero_positions = tensor == 0
        return torch.reciprocal(tensor).masked_fill(zero_positions, 0)

    output_masks = []
    for i, own_mask in enumerate(masks):
        # Kept as a plain int so that it can still be reported when there
        # is nothing to intersect with (a single mask set, or no layers).
        n_overlap = 0
        # Shallow copy: entries are rebound to new tensors below, so the
        # caller's tensors are never written to.
        counts = list(own_mask)
        for j, other_mask in enumerate(masks):
            if i == j:
                continue
            for k, own_layer in enumerate(own_mask):
                intersect = torch.logical_and(own_layer, other_mask[k])
                counts[k] = counts[k] + intersect
                n_overlap += int(torch.sum(intersect))

        output_masks.append([reciprocal_with_zero(layer) for layer in counts])

        print("Overlap: ", n_overlap)

    return output_masks


def select_trainable_parameters(model):
    """Collect the parameters whose name contains ``'encoder.layer'``.

    Returns a dict mapping parameter name to parameter, in the order
    produced by ``model.named_parameters()``.
    """
    # NOTE(review): a function with this same name and body is already
    # defined earlier in this file; this later definition rebinds the name.
    selected = {}
    for name, param in model.named_parameters():
        if 'encoder.layer' not in name:
            continue
        selected[name] = param
    return selected
    

def interpolate_models(model, pretrained_model, finetuned_models, masks, device="cuda"):
    """Add the masked task vectors of several fine-tuned models onto ``model``.

    For every parameter selected by ``select_trainable_parameters(model)``
    and every ``(finetuned_model, mask)`` pair, the following update is
    applied in place, without gradient tracking::

        model_param += mask_entry * (finetuned_param - pretrained_param)

    Args:
        model: Model whose parameters are updated in place. It is moved to
            ``device``.
        pretrained_model: Reference model the differences are taken
            against. It is moved to ``device``.
        finetuned_models: Iterable of fine-tuned models. Each one is moved
            to ``device`` when it is processed and is left there.
        masks: One sequence of tensors per fine-tuned model. Entry ``i`` is
            multiplied with the difference of the ``i``-th selected
            parameter. The sequences are only read, never modified.
        device: Device on which the update is computed.

    Returns:
        ``model``, the same object that was passed in.

    Raises:
        KeyError: If a selected parameter name does not exist in
            ``pretrained_model`` or in one of the fine-tuned models.
    """
    pretrained_model.to(device)
    model.to(device)

    trainable_names = list(select_trainable_parameters(model))

    # Name -> parameter tables are built once per model, so each name is a
    # dictionary lookup rather than a scan over all named parameters.
    target_params = dict(model.named_parameters())
    pretrained_params = dict(pretrained_model.named_parameters())

    for finetuned_model, mask in zip(finetuned_models, masks):
        finetuned_model.to(device)
        finetuned_params = dict(finetuned_model.named_parameters())

        with torch.no_grad():
            for index, name in enumerate(trainable_names):
                pretensor = pretrained_params[name].to(device)
                finetensor = finetuned_params[name].to(device)
                # Local device copy: the caller's mask list stays untouched.
                layer_mask = mask[index].to(device)

                target = target_params[name]
                target += layer_mask * (finetensor - pretensor)

    return model


def create_binary_masks(pretrained_model, finetuned_model, sparsity_level):
    """Build 0/1 masks marking the largest-magnitude fine-tuning changes.

    The difference ``finetuned - pretrained`` is taken for every parameter
    selected by ``select_trainable_parameters(pretrained_model)``. Over all
    of those differences together, the ``k = int(sparsity_level * N)``
    largest absolute values are located (``N`` being the total number of
    elements) and the smallest of them becomes the threshold. A mask entry
    is set to 1 where the absolute difference is strictly greater than the
    threshold, so the threshold element itself and any values tied with it
    are excluded and at most ``k - 1`` entries are set.

    Args:
        pretrained_model: Model providing the selected parameter names and
            the reference values.
        finetuned_model: Model providing the fine-tuned values. It must
            expose the same parameter names with matching shapes.
        sparsity_level: Fraction of all selected elements used to derive
            ``k``. It is not range-checked here; when ``k`` works out to 0
            every mask is left all-zero.

    Returns:
        A list with one mask tensor per selected parameter, in the order
        of ``select_trainable_parameters(pretrained_model)``. Each mask
        has the shape, dtype and device of the pre-trained parameter.

    Raises:
        KeyError: If a selected parameter name is missing from
            ``finetuned_model``.

    Side effects:
        Prints the total number of selected elements and the fraction of
        them that ended up set in the masks.
    """
    params = select_trainable_parameters(pretrained_model)
    num_params = sum(p.numel() for p in params.values())

    # One lookup table instead of rescanning named_parameters() per name.
    finetuned_params = dict(finetuned_model.named_parameters())
    grad_directions = [
        (finetuned_params[name] - pre_p).detach()
        for name, pre_p in params.items()
    ]

    # Only the shape, dtype and device of each parameter are needed here.
    basepatch = [torch.zeros_like(p.data, requires_grad=False) for p in params.values()]

    abs_tv = [torch.abs(direction).view(-1) for direction in grad_directions]
    # Number of largest-magnitude entries used to locate the threshold.
    k = int(sparsity_level * sum(t.numel() for t in abs_tv))

    # k == 0 (tiny sparsity level or no selected parameters) has no
    # threshold to take; the masks then simply stay all-zero.
    if k != 0:
        values, _ = torch.topk(torch.cat(abs_tv), k)
        threshold = values.min()

        for direction, patch in zip(grad_directions, basepatch):
            # Strict comparison: the k-th largest value itself is left out.
            patch[torch.absolute(direction) > threshold] = 1.

    print('Total number of parameters: ', num_params)
    print ('Total parameters in my stitch: ', sum([ torch.sum(p*p) / (1. * num_params) for p in basepatch ]))
    return basepatch


# Fraction passed to create_binary_masks() as `sparsity_level`.
sparsity = 0.05
# Directory holding the fine-tuned checkpoints for the selected backbone.
# NOTE(review): `ckpt_pth` is only assigned for these two model names; any
# other `modelname` leaves it unbound and the list comprehension below fails.
if modelname == 'roberta-base':
    ckpt_pth = "/data/common/lm-bff/ckpt_paths/log_noembed_SGD_graft/"
elif modelname == 'roberta-large':
    ckpt_pth = "/data/common/lm-bff/ckpt_paths/large_log_noembed_SGD_graft/"

# Tasks whose fine-tuned checkpoints are merged; one checkpoint per task.
task_list = ["SST-2", "cr", "mr", "mpqa", "trec", "subj", "QNLI", "SNLI", "MNLI", "RTE", "MRPC", "QQP"]
# task_list = ["QNLI", "MNLI"]
model_path_list = [ckpt_pth+f"{task}-prompt-64-0-{modelname}-2-2e-5" for task in task_list]

# `model` and `pretrained_model` are two separate loads of the same base
# checkpoint: `model` is updated in place by interpolate_models(), while
# `pretrained_model` serves as the reference for the differences.
model = initialize_model(modelname)
pretrained_model = initialize_model(modelname)
# All fine-tuned checkpoints are loaded before any mask is computed.
finetuned_models = [initialize_model(model_path) for model_path in model_path_list]
masks = [create_binary_masks(pretrained_model, finetuned_model, sparsity) for finetuned_model in finetuned_models]

# Convert the binary masks to overlap-aware weights, then apply the merge.
avg_masks = get_average_masks(masks)
merged_model = interpolate_models(model, pretrained_model, finetuned_models, avg_masks, device="cuda")

# NOTE(review): `total_task` is not used anywhere in the visible code.
total_task = "-".join(task_list)
# Output directory for the merged model.
# NOTE(review): as with `ckpt_pth`, `path` is only assigned for the two
# model names handled here.
if modelname == 'roberta-base':
    path = f"/data/common/lm-bff/ckpt_paths/merged_models/dataless_graft_all_{sparsity}-merged-{modelname}-2-2e-5"
elif modelname == 'roberta-large':
    path = f"/data/common/lm-bff/ckpt_paths/merged_models/large/dataless_graft_all_{sparsity}-merged-{modelname}-2-2e-5"

# Alternative output locations, currently disabled:
# path = f"/data/common/lm-bff/ckpt_paths/merged_models/case_study/dataless_QNLI_MNLI_sparsity_{sparsity}-merged-roberta-base-2-2e-5"
# path = f"/data/common/lm-bff/ckpt_paths/merged_models/sparsity_case_study/dataless_all_sparsity_{sparsity}-merged-roberta-base-2-2e-5"
# The merged weights and the tokenizer are written to the same directory.
merged_model.save_pretrained(path, safe_serialization=False)
tokenizer.save_pretrained(path)
