# Imports, constants and the tokenizer do not depend on the loop variables, so
# they are set up once here instead of being re-executed on every iteration.
import numpy as np
import torch
from transformers import AutoConfig, AutoTokenizer, set_seed

from src.models import RobertaForPromptFinetuning

# Tasks whose fine-tuned checkpoints are processed; one merged model is saved
# per (sparsity, task) combination.
task_list = ["SST-2", "cr", "mr", "mpqa", "trec", "subj", "QNLI", "SNLI", "MNLI", "RTE", "MRPC", "QQP"]

root = "/data/common/lm-bff"
# NOTE(review): not referenced anywhere in the visible script; presumably a
# parameter count kept for reference -- confirm before removing.
ROBERTA_PARAM = 163941810
model_fn = RobertaForPromptFinetuning
modelname = "roberta-base"
cache_dir = root+"/model_files"
tokenizer = AutoTokenizer.from_pretrained(
    modelname,
    additional_special_tokens=[],
    cache_dir=cache_dir,
)

for sparsity in [1e-1,1e-2,1e-3]:
    for task in task_list:

        # Reseed on every iteration so each (sparsity, task) run starts from
        # the same RNG state.
        set_seed(0)

        # Loaded per iteration on purpose: the config object is handed to every
        # model created below, so each iteration gets a fresh instance.
        config = AutoConfig.from_pretrained(
                    modelname,
                    cache_dir=cache_dir,
                )


        def initialize_model(modelname):
            """Load a ``model_fn`` instance from a model name or checkpoint path.

            The ``config`` and ``cache_dir`` of the enclosing script scope are
            passed through to ``from_pretrained``.
            """
            return model_fn.from_pretrained(modelname, config=config, cache_dir=cache_dir)


        def select_trainable_parameters(model):
            """Return a ``{name: parameter}`` dict of the encoder-layer parameters.

            A parameter is kept when its name, as yielded by
            ``model.named_parameters()``, contains ``'encoder.layer'``.
            Iteration order of ``named_parameters()`` is preserved.
            """
            return {
                name: param
                for name, param in model.named_parameters()
                if 'encoder.layer' in name
            }
            

        def get_average_masks(masks):
            """Convert per-model masks into overlap-normalised averaging weights.

            ``masks`` holds one entry per model; each entry is a sequence of
            per-layer tensors, and all entries are indexed layer by layer in
            the same order. For 0/1 masks the returned weight at a position
            selected by model ``i`` is ``1 / (1 + number of other models that
            also select it)``; positions model ``i`` does not select get ``0``.

            The input tensors and lists are not modified. The overlap count of
            each model with all others is printed as a side effect.

            Returns a list (one entry per model) of lists of per-layer tensors.
            """

            def reciprocal_with_zero(tensor):
                # 1/x elementwise, with the infinities at x == 0 replaced by 0.
                zero_positions = tensor == 0
                return torch.reciprocal(tensor).masked_fill(zero_positions, 0)

            output_masks = []
            for i, own_mask in enumerate(masks):
                n_overlap = 0
                # New list so re-assigning layers below leaves ``masks`` intact.
                counts = list(own_mask)
                # every other mask
                for j, other_mask in enumerate(masks):
                    if i == j:
                        continue
                    # every layer
                    for k, own_layer in enumerate(own_mask):
                        intersect = torch.logical_and(own_layer, other_mask[k])
                        counts[k] = counts[k] + intersect
                        n_overlap += torch.sum(intersect)

                output_masks.append([reciprocal_with_zero(layer) for layer in counts])

                # ``n_overlap`` is still the plain int 0 when there is no other
                # mask (or no layers), so convert with int() rather than .item().
                print("Overlap: ", int(n_overlap))

            return output_masks


        def select_trainable_parameters(model):
            """Collect the parameters whose name contains ``'encoder.layer'``.

            Returns a dict mapping parameter name to parameter, in the order
            produced by ``model.named_parameters()``.
            """
            # This redefinition shadows the earlier function of the same name in
            # this script; both select exactly the same parameters.
            selected = {}
            for name, param in model.named_parameters():
                if 'encoder.layer' not in name:
                    continue
                selected[name] = param
            return selected
            

        def interpolate_models(model, pretrained_model, finetuned_models, masks, device="cpu"):
            """Add masked fine-tuning deltas onto ``model`` in place.

            For every parameter chosen by ``select_trainable_parameters(model)``
            and every ``(finetuned_model, mask)`` pair, the parameter of
            ``model`` is incremented by
            ``mask[i] * (finetuned_param - pretrained_param)``, where ``i`` is
            the position of the parameter in that selection.

            Args:
                model: Module that is updated in place and returned.
                pretrained_model: Module providing the reference weights.
                finetuned_models: Iterable of modules providing fine-tuned weights.
                masks: Iterable aligned with ``finetuned_models``; each item is a
                    sequence of tensors indexed in trainable-parameter order.
                    The sequences themselves are not modified.
                device: Device all modules are moved to before the update.

            Returns:
                ``model`` (the same object that was passed in).

            Raises:
                KeyError: If a trainable parameter of ``model`` is missing from
                    ``pretrained_model`` or from one of ``finetuned_models``.
            """
            pretrained_model.to(device)
            model.to(device)

            trainable_name = list(select_trainable_parameters(model).keys())

            # Resolve names through dicts once instead of rescanning
            # named_parameters() of every model for every parameter.
            pretrained_params = dict(pretrained_model.named_parameters())
            model_params = dict(model.named_parameters())

            for finetuned_model, mask in zip(finetuned_models, masks):
                finetuned_model.to(device)
                finetuned_params = dict(finetuned_model.named_parameters())

                with torch.no_grad():
                    for counter, name in enumerate(trainable_name):
                        # Fail loudly on a missing parameter rather than reusing
                        # a tensor left over from a previous iteration.
                        if name not in pretrained_params:
                            raise KeyError(f"parameter {name!r} not found in pretrained_model")
                        if name not in finetuned_params:
                            raise KeyError(f"parameter {name!r} not found in a finetuned model")

                        pretensor = pretrained_params[name].to(device)
                        finetensor = finetuned_params[name].to(device)
                        # Local copy on ``device``; the caller's mask list is left as is.
                        layer_mask = mask[counter].to(device)

                        p = model_params[name]
                        p += layer_mask * (finetensor - pretensor)

            return model


        def create_binary_masks(pretrained_model, finetuned_model, sparsity_level):
            """Draw a uniformly random 0/1 mask over the trainable parameters.

            Exactly ``int(sparsity_level * num_params)`` entries are set to 1,
            at positions chosen by a random permutation over all trainable
            entries. Parameter values are not used, only their shapes.

            Args:
                pretrained_model: Module whose parameters, as chosen by
                    ``select_trainable_parameters``, define mask order and shapes.
                finetuned_model: Module that must expose the same parameter
                    names with the same shapes.
                sparsity_level: Fraction of entries to select, in ``[0, 1]``.

            Returns:
                A list with one float tensor of zeros and ones per trainable
                parameter, each shaped like that parameter.

            Raises:
                ValueError: If ``sparsity_level`` is outside ``[0, 1]`` or a
                    parameter shape differs between the two models.
                KeyError: If ``finetuned_model`` lacks a trainable parameter.
            """
            if not 0 <= sparsity_level <= 1:
                raise ValueError(f"sparsity_level must be in [0, 1], got {sparsity_level}")

            params = select_trainable_parameters(pretrained_model)
            # One dict lookup per name instead of rescanning named_parameters().
            finetuned_params = dict(finetuned_model.named_parameters())

            layout = []  # (shape, number of elements) per trainable parameter
            for name, pre_p in params.items():
                if name not in finetuned_params:
                    raise KeyError(f"parameter {name!r} not found in finetuned_model")
                if finetuned_params[name].shape != pre_p.shape:
                    raise ValueError(f"shape mismatch for parameter {name!r}")
                # NOTE(review): the drawn values are never used. The draw is kept
                # (and discarded at once) only so the RNG stream, and therefore
                # the mask produced for a given seed, is unchanged. Drop it
                # deliberately if reproducing earlier masks is not required.
                torch.rand_like(pre_p.data, requires_grad=False)
                layout.append((pre_p.shape, pre_p.numel()))

            num_params = sum(num_elements for _, num_elements in layout)

            threshold = int(sparsity_level * num_params)

            # Flat vector with ``threshold`` ones, shuffled, then cut per parameter.
            mask_flat = torch.cat([torch.ones(threshold), torch.zeros(num_params - threshold)])
            mask_flat = mask_flat[torch.randperm(num_params)]

            binary_masks = []
            current_index = 0
            for shape, num_elements in layout:
                binary_mask = mask_flat[current_index:current_index + num_elements].reshape(shape)
                binary_masks.append(binary_mask)
                current_index += num_elements

            print('Total number of parameters: ', num_params)
            # The printed value is the selected fraction (entries are 0/1, so
            # p*p == p), divided by num_params, not an absolute count.
            print ('Total parameters in my stitch: ', sum([ torch.sum(p*p) / (1. * num_params) for p in binary_masks ]))
            return binary_masks


        # sparsity = 0.0028
        # Directory holding the per-task fine-tuned checkpoints.
        ckpt_pth = "/data/common/lm-bff/ckpt_paths/log_noembed_SGD_graft/"

        # Rebinding the name here does not change what the enclosing
        # ``for task in task_list`` loop iterates over; that loop keeps using
        # the list object it started with.
        task_list = ["SST-2", "cr", "mr", "mpqa", "trec", "subj", "QNLI", "SNLI", "MNLI", "RTE", "MRPC", "QQP"]

        # ------------------------------------------------------------------
        # Disabled alternative: merge several tasks into one model using
        # overlap-averaged masks. Kept for reference.
        # ------------------------------------------------------------------
        # # task_list = ["QNLI", "MNLI"]
        # model_path_list = [ckpt_pth+f"{task}-prompt-64-0-roberta-base-2-2e-5" for task in task_list]

        # model = initialize_model(modelname)
        # pretrained_model = initialize_model(modelname)
        # finetuned_models = [initialize_model(model_path) for model_path in model_path_list]
        # masks = [create_binary_masks(pretrained_model, finetuned_model, sparsity) for finetuned_model in finetuned_models]

        # avg_masks = get_average_masks(masks)
        # merged_model = interpolate_models(model, pretrained_model, finetuned_models, avg_masks, device="cpu")

        # total_task = "-".join(task_list)
        # # path = f"/data/common/lm-bff/ckpt_paths/merged_models/dataless_graft_all_{sparsity}-merged-roberta-base-2-2e-5"
        # path = f"/data/common/lm-bff/ckpt_paths/merged_models/case_study/random_QNLI_MNLI_sparsity_{sparsity}-merged-roberta-base-2-2e-5"
        # # path = f"/data/common/lm-bff/ckpt_paths/merged_models/sparsity_case_study/random_all_{sparsity}-merged-roberta-base-2-2e-5"
        # merged_model.save_pretrained(path, safe_serialization=False)
        # tokenizer.save_pretrained(path)

        # ------------------------------------------------------------------
        # Active pipeline: one randomly masked graft per (sparsity, task).
        # ------------------------------------------------------------------
        model_path = ckpt_pth+f"{task}-prompt-64-0-roberta-base-2-2e-5"

        # Construction order is kept as is: ``model`` receives the update,
        # ``pretrained_model`` is the reference, ``finetuned_model`` the source.
        model = initialize_model(modelname)
        pretrained_model = initialize_model(modelname)
        finetuned_model = initialize_model(model_path)

        mask = create_binary_masks(
            pretrained_model,
            finetuned_model,
            sparsity_level=sparsity,
        )
        merged_model = interpolate_models(
            model,
            pretrained_model,
            finetuned_models=[finetuned_model],
            masks=[mask],
            device="cpu",
        )

        # Write the merged weights and the tokenizer files to the same directory.
        path = f"/data/common/lm-bff/ckpt_paths/merged_models/temp/random_{task}-{sparsity}-merged-roberta-base-2-2e-5"

        merged_model.save_pretrained(path, safe_serialization=False)
        tokenizer.save_pretrained(path)