import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GPT2LMHeadModel, Trainer
from src.models import BertForPromptFinetuning, RobertaForPromptFinetuning, resize_token_type_embeddings
from src.tv_utils import *

# Base directory under which cached model files and checkpoints are stored.
root = "/data/common/lm-bff"
# NOTE(review): presumably a RoBERTa parameter count; it is not referenced
# anywhere in the visible part of this file -- confirm before relying on it.
ROBERTA_PARAM = 163941810

# Model class used to instantiate every checkpoint handled by this script.
model_fn = RobertaForPromptFinetuning
# Backbone selection; the alternative handled by this script is "roberta-base".
modelname = "roberta-large"
cache_dir = f"{root}/model_files"
# Configuration shared by all models created through `model_fn`.
config = AutoConfig.from_pretrained(modelname, cache_dir=cache_dir)
        
def get_model_path(task):
    """Return the directory of the fine-tuned checkpoint for ``task``.

    The directory depends on the module-level ``modelname``: each supported
    backbone keeps its checkpoints under its own log directory.

    Args:
        task: Task name used as the prefix of the checkpoint directory
            (e.g. ``"SST-2"``).

    Returns:
        The checkpoint directory path as a string, with a trailing slash.

    Raises:
        ValueError: If ``modelname`` is not one of the supported backbones.
            Previously the function silently returned ``None`` in that case.
    """
    log_dirs = {
        "roberta-base": "log_noembed_SGD_graft",
        "roberta-large": "large_log_noembed_SGD_graft",
    }
    if modelname not in log_dirs:
        raise ValueError(
            f"No checkpoint location is known for model {modelname!r}; "
            f"expected one of {sorted(log_dirs)}."
        )
    return (
        f"/data/common/lm-bff/ckpt_paths/{log_dirs[modelname]}/"
        f"{task}-prompt-64-0-{modelname}-2-2e-5/"
    )

def initialize_model(modelname):
    """Instantiate ``model_fn`` from a pretrained name or checkpoint path.

    Args:
        modelname: Value forwarded to ``model_fn.from_pretrained``. This
            script passes both a model name and fine-tuned checkpoint
            directories. Note that it shadows the module-level ``modelname``.

    Returns:
        The model built with the module-level ``config`` and ``cache_dir``.
    """
    return model_fn.from_pretrained(modelname, config=config, cache_dir=cache_dir)

# Tokenizer of the backbone; it is saved next to the merged model later on.
tokenizer = AutoTokenizer.from_pretrained(
    modelname, additional_special_tokens=[], cache_dir=cache_dir
)

# Tasks whose fine-tuned checkpoints take part in the merge.
task_list = [
    "SST-2", "cr", "mr", "mpqa", "trec", "subj",
    "QNLI", "SNLI", "MNLI", "RTE", "MRPC", "QQP",
]
# State dict of every fine-tuned checkpoint, in `task_list` order.
# For a quick two-task run, use instead:
# ft_checks = [initialize_model(get_model_path("SST-2")).state_dict(), initialize_model(get_model_path("QNLI")).state_dict()]
ft_checks = [
    initialize_model(get_model_path(task)).state_dict() for task in task_list
]
# State dict of the pretrained model the task vectors are measured against.
ptm_check = initialize_model(modelname).state_dict()

# Check that all checkpoints have the same parameters.
check_parameterNamesMatch(ft_checks + [ptm_check])

# Keys passed as `remove_keys` to the state-dict <-> vector helpers when the
# task vectors are created, so they are not involved in global operations such
# as the top-k computation.
# An earlier key list was:
#     "transformer.encoder.embed_tokens.weight",
#     "transformer.decoder.embed_tokens.weight",
remove_keys = [
    "roberta.embeddings.word_embeddings.weight",
    "roberta.embeddings.position_embeddings.weight",
    "roberta.embeddings.token_type_embeddings.weight",
    "roberta.embeddings.LayerNorm.weight",
    "roberta.embeddings.LayerNorm.bias",
    "roberta.pooler.dense.weight",
    "roberta.pooler.dense.bias",
    "lm_head.bias",
    "lm_head.dense.weight",
    "lm_head.dense.bias",
    "lm_head.layer_norm.weight",
    "lm_head.layer_norm.bias",
    "lm_head.decoder.weight",
    "lm_head.decoder.bias",
]

## TIES MERGING UTILS

def topk_values_mask(M, K=0.7, return_mask=False):
    """Zero out all but the largest-magnitude entries in each row of ``M``.

    Args:
        M: 1-D or 2-D tensor. A 1-D tensor is treated as a single row.
        K: Share of entries to keep per row. Values in ``[0, 1]`` are read as
            a fraction; values above 1 are read as a percentage and divided
            by 100.
        return_mask: When true, the boolean keep-mask is returned as well.

    Returns:
        ``(masked, kept_fraction)`` or, with ``return_mask=True``,
        ``(masked, kept_fraction, mask)``. ``masked`` and ``mask`` have the
        shape of ``M``; ``kept_fraction`` holds one value per row (a single
        value for 1-D input).

    Raises:
        ValueError: If ``K`` does not describe a share between 0 and 100%.

    Notes:
        Entries tied with the threshold magnitude are all kept, so a row can
        retain more than ``K`` of its entries, and the largest-magnitude
        entry of a row always survives (even for ``K == 0``).
    """
    if K > 1:
        K /= 100
    if not 0 <= K <= 1:
        raise ValueError(
            f"K must be a fraction in [0, 1] or a percentage in (1, 100]; "
            f"got an effective fraction of {K}."
        )

    original_shape = M.shape
    rows = M.unsqueeze(0) if M.dim() == 1 else M

    _, d = rows.shape
    # Rank (1-based, ascending by magnitude) of the smallest entry to keep.
    # `kthvalue` rejects rank 0, which is what K == 1 would produce, so the
    # rank is clamped to 1: the threshold becomes the row minimum and every
    # entry is kept.
    rank = max(d - int(d * K), 1)

    magnitudes = rows.abs()
    # Per-row magnitude threshold; everything at or above it is kept.
    kth_values, _ = magnitudes.kthvalue(rank, dim=1, keepdim=True)
    row_mask = magnitudes >= kth_values

    # Computed on the 2-D mask so that 1-D input works as well.
    kept_fraction = row_mask.float().mean(dim=1)
    # Hand results back in the caller's shape (1-D stays 1-D).
    final_mask = row_mask.reshape(original_shape)
    masked = M * final_mask

    if return_mask:
        return masked, kept_fraction, final_mask
    return masked, kept_fraction


def resolve_zero_signs(sign_to_mult, method="majority"):
    """Replace zero entries of a sign tensor with a global fallback sign.

    The fallback is derived from the sign of the sum over all entries of
    ``sign_to_mult``. The tensor is modified in place and also returned.

    Args:
        sign_to_mult: Tensor of signs; entries equal to 0 are overwritten.
        method: ``"majority"`` fills zeros with the overall sign,
            ``"minority"`` with its negation. Any other value leaves the
            tensor untouched.

    Returns:
        The same tensor object that was passed in.
    """
    overall_sign = torch.sign(sign_to_mult.sum())

    if method not in ("majority", "minority"):
        return sign_to_mult

    fill_value = overall_sign if method == "majority" else -1 * overall_sign
    # If the overall sign is itself 0, zeros are overwritten with 0.
    sign_to_mult[sign_to_mult == 0] = fill_value
    return sign_to_mult


def resolve_sign(Tensor):
    """Elect one sign per position from the entries stacked along dim 0.

    The elected sign is the sign of the sum over dim 0. Positions whose sum
    is exactly zero are filled by ``resolve_zero_signs`` using the
    ``"majority"`` rule.

    Args:
        Tensor: Tensor whose first dimension is summed over.

    Returns:
        Tensor of elected signs with the first dimension of ``Tensor`` removed.
    """
    summed_signs = torch.sign(Tensor.sum(dim=0))
    return resolve_zero_signs(summed_signs, "majority")


def disjoint_merge(Tensor, merge_func, sign_to_mult):
    """Aggregate the selected entries of ``Tensor`` along dim 0.

    Args:
        Tensor: Tensor whose first dimension is aggregated away.
        merge_func: Aggregation name; only the part after the last ``"-"`` is
            used (e.g. ``"dis-mean"`` -> ``"mean"``). Supported values are
            ``"mean"``, ``"sum"`` and ``"max"``.
        sign_to_mult: Elected sign per position, or ``None``. When given,
            positive entries are selected where the sign is > 0 and negative
            entries everywhere else. When ``None``, all non-zero entries are
            selected.

    Returns:
        The aggregated tensor. ``"mean"`` averages over the selected non-zero
        entries only; ``"max"`` returns the largest selected magnitude,
        carrying the elected sign (or, without an elected sign, the sign of
        the entry that has that magnitude).

    Raises:
        ValueError: If ``merge_func`` does not name a supported aggregation.
    """
    merge_func = merge_func.split("-")[-1]
    # Validate before doing any tensor work.
    if merge_func not in ("mean", "sum", "max"):
        raise ValueError(f"Merge method {merge_func} is not defined.")

    if sign_to_mult is not None:
        # Keep only entries that agree with the elected sign.
        rows_to_keep = torch.where(
            sign_to_mult.unsqueeze(0) > 0, Tensor > 0, Tensor < 0
        )
    else:
        # No elected sign: keep every non-zero entry.
        rows_to_keep = Tensor != 0
    selected_entries = Tensor * rows_to_keep

    if merge_func == "mean":
        non_zero_counts = (selected_entries != 0).sum(dim=0).float()
        # Clamp so positions with nothing selected give 0 instead of 0/0.
        return torch.sum(selected_entries, dim=0) / torch.clamp(
            non_zero_counts, min=1
        )

    if merge_func == "sum":
        return torch.sum(selected_entries, dim=0)

    # merge_func == "max"
    if sign_to_mult is None:
        # Previously this path multiplied by None and raised TypeError.
        # Without an elected sign, return the entry of largest magnitude
        # itself so that its own sign is preserved.
        strongest = selected_entries.abs().argmax(dim=0, keepdim=True)
        return selected_entries.gather(0, strongest).squeeze(0)

    disjoint_aggs = selected_entries.abs().max(dim=0)[0]
    disjoint_aggs *= sign_to_mult
    return disjoint_aggs


def ties_merging(
    flat_task_checks,
    reset_thresh=None,
    merge_func="",
):
    """Merge stacked flat task vectors: trim, elect signs, then aggregate.

    Args:
        flat_task_checks: Tensor of flattened task vectors stacked along
            dim 0. It is not modified.
        reset_thresh: Share of largest-magnitude entries kept per task
            vector, forwarded as ``K`` to ``topk_values_mask``. Required,
            despite the ``None`` default.
        merge_func: Aggregation forwarded to ``disjoint_merge``
            (e.g. ``"dis-mean"``, ``"dis-sum"``, ``"dis-max"``).

    Returns:
        The merged flat task vector.

    Raises:
        ValueError: If ``reset_thresh`` is ``None``, or (from
            ``disjoint_merge``) if ``merge_func`` is not supported.
    """
    if reset_thresh is None:
        # Previously this surfaced as an opaque TypeError on `None > 1`.
        raise ValueError(
            "reset_thresh must be given: the fraction (or percentage) of "
            "entries to keep in each task vector."
        )

    # No defensive clone: the trimming step builds a new tensor and none of
    # the following steps writes to the input, so copying the whole stack
    # would only raise peak memory.
    updated_checks, *_ = topk_values_mask(
        flat_task_checks, K=reset_thresh, return_mask=False
    )

    print("RESOLVING SIGN")
    final_signs = resolve_sign(updated_checks)

    print(f"Disjoint AGGREGATION: {merge_func}")
    return disjoint_merge(updated_checks, merge_func, final_signs)


print("Flattening out Checkpoints")
# One flat vector per fine-tuned checkpoint, stacked along dim 0.
flat_ft = torch.vstack(
    [state_dict_to_vector(check, remove_keys) for check in ft_checks]
)
flat_ptm = state_dict_to_vector(ptm_check, remove_keys)

# Task vectors: difference between each fine-tuned model and the pretrained one.
tv_flat_checks = flat_ft - flat_ptm

# Sanity check: converting a flat vector back to a state dict must reproduce
# the state dict it was created from.
assert check_state_dicts_equal(
    vector_to_state_dict(flat_ptm, ptm_check, remove_keys), ptm_check
)
assert all(
    check_state_dicts_equal(
        vector_to_state_dict(flat_vec, check, remove_keys), check
    )
    for flat_vec, check in zip(flat_ft, ft_checks)
)

# TIES merging hyper-parameters.
K = 0.8  # share of largest-magnitude entries kept in each task vector
merge_func = "dis-sum"  # aggregation applied to the sign-consistent entries
lamda = 0.3  # scale applied to the merged task vector before adding it back

# Merged flat task vector.
merged_tv = ties_merging(
    tv_flat_checks,
    reset_thresh=K,
    merge_func=merge_func,
)

# Add the scaled merged task vector back onto the pretrained weights.
merged_check = flat_ptm + lamda * merged_tv

# Convert the flat merged checkpoint to a state dict.
merged_state_dict = vector_to_state_dict(
    merged_check, ptm_check, remove_keys=remove_keys
)

if modelname == 'roberta-base':
    path = root+f"/ckpt_paths/merged_models/TIES_all_{lamda}_{K}-merged-{modelname}-2-2e-5"
elif modelname == 'roberta-large':
    path = root+f"/ckpt_paths/merged_models/large/TIES_all_{lamda}_{K}-merged-{modelname}-2-2e-5"
else:
    # Without this branch `path` would be unbound and saving would fail with
    # a NameError that hides the real cause.
    raise ValueError(f"No output location is defined for model {modelname!r}.")

# Load the merged weights into a fresh model and save it with the tokenizer.
model_temp = initialize_model(modelname)
model_temp.load_state_dict(merged_state_dict)
model_temp.save_pretrained(path, safe_serialization=False)
tokenizer.save_pretrained(path)