import sys
import os, copy
import torch
import matplotlib.pyplot as plt
import numpy as np
import re
from collections import OrderedDict
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


# Checkpoint lookup table, nested as: base model -> task -> value.
#   * Encoder / seq2seq models map each GLUE-style task to a dict of
#     per-method entries ("none", "ia3", "lora"); every entry is currently an
#     empty string (presumably a placeholder path to fill in - TODO confirm).
#   * "bigscience/T0_3B" only has "ia3" and "lora" entries.
#   * LLaMA models map an instruction dataset name directly to a string that
#     looks like a Hugging Face Hub repo id. Note that "llama-30b" points at
#     the "-33b" checkpoints.
TASK2CHECKPOINT = {
    "bert-base-uncased-old": {
        "cola": dict(none="", ia3="", lora=""),
        "mrpc": dict(none="", ia3="", lora=""),
        "qqp": dict(none="", ia3="", lora=""),
        "rte": dict(none="", ia3="", lora=""),
        "wnli": dict(none="", ia3="", lora=""),
        "sst2": dict(none="", ia3="", lora=""),
        "mnli": dict(none="", ia3="", lora=""),
        "qnli": dict(none="", ia3="", lora=""),
    },
    "bert-base-uncased": {
        "mrpc": dict(none="", ia3="", lora=""),
        "qqp": dict(none="", ia3="", lora=""),
        "rte": dict(none="", ia3="", lora=""),
        "wnli": dict(none="", ia3="", lora=""),
        "sst2": dict(none="", ia3="", lora=""),
        "mnli": dict(none="", ia3="", lora=""),
        "qnli": dict(none="", ia3="", lora=""),
    },
    "bert-large-uncased": {
        "mrpc": dict(none="", ia3="", lora=""),
        "qqp": dict(none="", ia3="", lora=""),
        "rte": dict(none="", ia3="", lora=""),
        "wnli": dict(none="", ia3="", lora=""),
        "sst2": dict(none="", ia3="", lora=""),
        "mnli": dict(none="", ia3="", lora=""),
        "qnli": dict(none="", ia3="", lora=""),
    },
    "roberta-base": {
        "mrpc": dict(none="", ia3="", lora=""),
        "qqp": dict(none="", ia3="", lora=""),
        "rte": dict(none="", ia3="", lora=""),
        "wnli": dict(none="", ia3="", lora=""),
        "sst2": dict(none="", ia3="", lora=""),
        "mnli": dict(none="", ia3="", lora=""),
        "qnli": dict(none="", ia3="", lora=""),
    },
    "roberta-large": {
        "mrpc": dict(none="", ia3="", lora=""),
        "qqp": dict(none="", ia3="", lora=""),
        "rte": dict(none="", ia3="", lora=""),
        "wnli": dict(none="", ia3="", lora=""),
        "sst2": dict(none="", ia3="", lora=""),
        "mnli": dict(none="", ia3="", lora=""),
        "qnli": dict(none="", ia3="", lora=""),
    },
    "t5-base": {
        "mrpc": dict(none="", ia3="", lora=""),
        "qqp": dict(none="", ia3="", lora=""),
        "rte": dict(none="", ia3="", lora=""),
        "wnli": dict(none="", ia3="", lora=""),
        "sst2": dict(none="", ia3="", lora=""),
        "mnli": dict(none="", ia3="", lora=""),
        "qnli": dict(none="", ia3="", lora=""),
    },
    "t5-large": {
        "mrpc": dict(none="", ia3="", lora=""),
        "qqp": dict(none="", ia3="", lora=""),
        "rte": dict(none="", ia3="", lora=""),
        "wnli": dict(none="", ia3="", lora=""),
        "sst2": dict(none="", ia3="", lora=""),
        "mnli": dict(none="", ia3="", lora=""),
        "qnli": dict(none="", ia3="", lora=""),
    },
    "google/t5-v1_1-base": {
        "mrpc": dict(none="", ia3="", lora=""),
        "qqp": dict(none="", ia3="", lora=""),
        "rte": dict(none="", ia3="", lora=""),
        "wnli": dict(none="", ia3="", lora=""),
        "sst2": dict(none="", ia3="", lora=""),
        "mnli": dict(none="", ia3="", lora=""),
        "qnli": dict(none="", ia3="", lora=""),
    },
    "google/t5-v1_1-large": {
        "mrpc": dict(none="", ia3="", lora=""),
        "qqp": dict(none="", ia3="", lora=""),
        "rte": dict(none="", ia3="", lora=""),
        "wnli": dict(none="", ia3="", lora=""),
        "sst2": dict(none="", ia3="", lora=""),
        "mnli": dict(none="", ia3="", lora=""),
        "qnli": dict(none="", ia3="", lora=""),
    },
    "bigscience/T0_3B": {
        "mrpc": dict(ia3="", lora=""),
        "qqp": dict(ia3="", lora=""),
        "rte": dict(ia3="", lora=""),
        "wnli": dict(ia3="", lora=""),
        "sst2": dict(ia3="", lora=""),
        "mnli": dict(ia3="", lora=""),
        "qnli": dict(ia3="", lora=""),
    },
    "llama-7b": {
        "guanaco": "timdettmers/guanaco-7b",
        "alpaca": "timdettmers/qlora-alpaca-7b",
        "chip2": "timdettmers/qlora-chip2-7b",
        "flan": "timdettmers/qlora-flan-7b",
        "hh-rlhf": "timdettmers/qlora-hh-rlhf-7b",
        "longform": "timdettmers/qlora-longform-7b",
        "self-instruct": "timdettmers/qlora-self-instruct-7b",
        "unnatural-instructions": "timdettmers/qlora-unnatural-instructions-7b",
    },
    "llama-13b": {
        "guanaco": "timdettmers/guanaco-13b",
        "alpaca": "timdettmers/qlora-alpaca-13b",
        "chip2": "timdettmers/qlora-chip2-13b",
        "flan": "timdettmers/qlora-flan-13b",
        "hh-rlhf": "timdettmers/qlora-hh-rlhf-13b",
        "longform": "timdettmers/qlora-longform-13b",
        "self-instruct": "timdettmers/qlora-self-instruct-13b",
        "unnatural-instructions": "timdettmers/qlora-unnatural-instructions-13b",
    },
    "llama-30b": {
        "guanaco": "timdettmers/guanaco-33b",
        "alpaca": "timdettmers/qlora-alpaca-33b",
        "chip2": "timdettmers/qlora-chip2-33b",
        "flan": "timdettmers/qlora-flan-33b",
        "hh-rlhf": "timdettmers/qlora-hh-rlhf-33b",
        "longform": "timdettmers/qlora-longform-33b",
        "self-instruct": "timdettmers/qlora-self-instruct-33b",
        "unnatural-instructions": "timdettmers/qlora-unnatural-instructions-33b",
    },
    "llama-65b": {
        "guanaco": "timdettmers/guanaco-65b",
        "alpaca": "timdettmers/qlora-alpaca-65b",
        "chip2": "timdettmers/qlora-chip2-65b",
        "flan": "timdettmers/qlora-flan-65b",
        "hh-rlhf": "timdettmers/qlora-hh-rlhf-65b",
        "longform": "timdettmers/qlora-longform-65b",
        "self-instruct": "timdettmers/qlora-self-instruct-65b",
        "unnatural-instructions": "timdettmers/qlora-unnatural-instructions-65b",
    },
}




def state_dict_to_vector(state_dict, remove_keys=None):
    """Flatten a state dict into a single 1-D CPU tensor.

    Entries are concatenated in sorted key order, which is the same layout
    ``vector_to_state_dict`` expects when given the same ``remove_keys``.

    Args:
        state_dict: Mapping from parameter name to tensor.
        remove_keys: Optional iterable of keys to leave out of the vector.
            Keys that are not present in ``state_dict`` are ignored.

    Returns:
        A 1-D tensor holding every remaining value, flattened and moved to
        CPU. Neither the input mapping nor its tensors are modified.
    """
    excluded = set(remove_keys) if remove_keys is not None else set()
    # Filtering the items is enough: the values are only read here, so there
    # is no need to deep-copy every (possibly multi-GB) weight tensor.
    kept = sorted(
        ((key, value) for key, value in state_dict.items() if key not in excluded),
        key=lambda item: item[0],
    )
    return torch.nn.utils.parameters_to_vector(
        [value.reshape(-1).cpu() for _, value in kept]
    )


def vector_to_state_dict(vector, state_dict, remove_keys=None):
    """Unflatten ``vector`` into a state dict shaped like ``state_dict``.

    This is the inverse of ``state_dict_to_vector``. The keys of
    ``state_dict`` (minus ``remove_keys``) are visited in sorted order, and
    consecutive slices of ``vector`` are reshaped to each reference tensor's
    shape.

    Args:
        vector: 1-D tensor whose length equals the total number of elements
            of the retained reference entries.
        state_dict: Reference mapping used only for key names and shapes; it
            is not modified.
        remove_keys: Optional iterable of keys excluded from the layout.

    Returns:
        An ``OrderedDict`` (sorted keys) of tensors that are views into
        ``vector``. If ``"shared.weight"`` is present, the keys
        ``"encoder.embed_tokens.weight"`` and ``"decoder.embed_tokens.weight"``
        are (re)set to that same tensor object.

    Raises:
        TypeError: If ``vector`` is not a ``torch.Tensor``.
        ValueError: If ``vector`` does not have exactly as many elements as
            the retained reference entries (previously a longer vector was
            silently truncated).
    """
    if not isinstance(vector, torch.Tensor):
        raise TypeError(f"expected torch.Tensor, but got: {torch.typename(vector)}")

    excluded = set(remove_keys) if remove_keys is not None else set()
    reference = sorted(
        ((key, value) for key, value in state_dict.items() if key not in excluded),
        key=lambda item: item[0],
    )

    expected = sum(value.numel() for _, value in reference)
    if vector.numel() != expected:
        raise ValueError(
            f"vector has {vector.numel()} elements but the retained "
            f"state_dict entries need {expected}"
        )

    # Build views directly instead of deep-copying the whole reference dict
    # and overwriting it: only the shapes are needed from the reference.
    sorted_reference_dict = OrderedDict()
    offset = 0
    for key, ref in reference:
        count = ref.numel()
        sorted_reference_dict[key] = vector[offset:offset + count].view_as(ref)
        offset += count

    # Add back the encoder and decoder embedding weights as aliases of the
    # shared embedding (presumably T5-style tied embeddings that callers drop
    # via remove_keys - TODO confirm against callers).
    if "shared.weight" in sorted_reference_dict:
        for key in ("encoder.embed_tokens.weight", "decoder.embed_tokens.weight"):
            sorted_reference_dict[key] = sorted_reference_dict["shared.weight"]
    return sorted_reference_dict


def add_ptm_to_tv(tv_dict, ptm_dict):
    """Add pre-trained-model (ptm) weights onto a task-vector (tv) state dict.

    Args:
        tv_dict: Mapping of parameter name -> task-vector tensor.
        ptm_dict: Mapping of parameter name -> pre-trained tensor. It must
            have exactly the same keys as ``tv_dict``.

    Returns:
        A shallow copy of ``tv_dict`` (same mapping type and key order) whose
        values are ``tv_dict[k] + ptm_dict[k]``. Neither input is modified.

    Raises:
        ValueError: If the two mappings have different keys. This is an
            explicit raise rather than ``assert``, so the check survives
            ``python -O``.
    """
    if set(tv_dict.keys()) != set(ptm_dict.keys()):
        raise ValueError("Differing parameter names in models.")
    # Every value is replaced below, so deep-copying the task-vector tensors
    # would only waste memory; a shallow copy keeps the container type/order.
    final_dict = copy.copy(tv_dict)
    for k, v in ptm_dict.items():
        final_dict[k] = tv_dict[k] + v
    return final_dict


def check_parameterNamesMatch(checkpoints):
    """Verify that every checkpoint has exactly the same parameter names.

    Args:
        checkpoints: Iterable of mappings (e.g. state dicts). An empty
            iterable or a single checkpoint trivially passes; previously an
            empty list crashed with a bare ``IndexError``.

    Raises:
        ValueError: If any checkpoint's key set differs from the first
            checkpoint's. The message lists the symmetric difference.
    """
    iterator = iter(checkpoints)
    sentinel = object()
    first = next(iterator, sentinel)
    if first is sentinel:
        return
    parameter_names = set(first.keys())

    for checkpoint in iterator:
        current_parameterNames = set(checkpoint.keys())
        if current_parameterNames != parameter_names:
            raise ValueError(
                "Differing parameter names in models. "
                f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
            )

def check_state_dicts_equal(state_dict1, state_dict2):
    """Return True iff both state dicts hold the same keys and identical tensors.

    Key sets are compared without regard to order. Tensors are compared
    pairwise with ``torch.equal`` (same shape and elements), walking the keys
    of ``state_dict1`` and stopping at the first mismatch.
    """
    same_keys = set(state_dict1.keys()) == set(state_dict2.keys())
    return same_keys and all(
        torch.equal(state_dict1[name], state_dict2[name])
        for name in state_dict1.keys()
    )


def topk_values_mask(M, K=0.7, return_mask=False):
    """Zero out all but the largest-magnitude entries in each row of ``M``.

    Each row of length ``d`` keeps its ``int(d * K)`` largest-magnitude
    entries; entries tied with the smallest kept magnitude are kept too.

    Args:
        M: 1-D or 2-D tensor. A 2-D tensor is processed row by row; a 1-D
            tensor is treated as a single row.
        K: Fraction of entries to keep per row. Values greater than 1 are
            read as a percentage (e.g. ``20`` -> ``0.2``). ``K == 1`` keeps
            everything; a result that rounds down to zero keeps nothing.
        return_mask: If True, also return the boolean keep-mask.

    Returns:
        ``(masked, kept_fraction)``, or ``(masked, kept_fraction, mask)`` when
        ``return_mask`` is True. ``masked`` and ``mask`` have the same shape
        as ``M``. ``kept_fraction`` holds the fraction of kept entries for
        each row, with shape ``(rows,)``, which is ``(1,)`` for a 1-D input.

    Fixes over the previous version:
        - It kept one extra entry per row.
        - It crashed for ``K == 1``, for ``K > 100``, and for any 1-D input.
    """
    if K > 1:
        K /= 100

    original_shape = M.shape
    if M.dim() == 1:
        M = M.unsqueeze(0)

    _, d = M.shape
    num_keep = min(int(d * K), d)
    magnitudes = M.abs()

    if num_keep <= 0:
        mask = torch.zeros_like(M, dtype=torch.bool)
    else:
        # The (d - num_keep + 1)-th smallest magnitude is the smallest value
        # still inside the top ``num_keep``; num_keep == d gives k == 1, i.e.
        # the row minimum, so everything is kept.
        kth_values, _ = magnitudes.kthvalue(d - num_keep + 1, dim=1, keepdim=True)
        mask = magnitudes >= kth_values

    kept_fraction = mask.float().mean(dim=1)
    # Restore the caller's shape (undoes the unsqueeze for 1-D inputs).
    masked = (M * mask).reshape(original_shape)
    mask = mask.reshape(original_shape)

    if return_mask:
        return masked, kept_fraction, mask
    return masked, kept_fraction


def resolve_zero_signs(sign_to_mult, method="majority"):
    """Replace zero entries of ``sign_to_mult`` in place and return it.

    The fill value comes from the sign of the sum of all entries:
    ``"majority"`` fills zeros with that sign and ``"minority"`` with its
    negation. Any other ``method`` leaves the tensor untouched. If the sum is
    itself zero, the fill value is zero, so the zeros remain.
    """
    overall_sign = torch.sign(sign_to_mult.sum())

    if method == "majority":
        fill_value = overall_sign
    elif method == "minority":
        fill_value = -1 * overall_sign
    else:
        return sign_to_mult

    zero_positions = sign_to_mult == 0
    sign_to_mult[zero_positions] = fill_value
    return sign_to_mult


def resolve_sign(Tensor):
    """Elect one sign per column of ``Tensor``.

    Each column gets the sign of its sum over dim 0. Columns whose sum is
    exactly zero get the sign of the sum of all column signs
    (``resolve_zero_signs`` with ``"majority"``).
    """
    column_signs = torch.sign(Tensor.sum(dim=0))
    return resolve_zero_signs(column_signs, "majority")


def disjoint_merge(Tensor, merge_func, sign_to_mult):
    """Aggregate the rows of ``Tensor`` column-wise over selected entries.

    Args:
        Tensor: 2-D tensor; rows are aggregated along dim 0.
        merge_func: ``"mean"``, ``"sum"`` or ``"max"``. Only the part after
            the last ``"-"`` is used, so e.g. ``"dis-mean"`` means ``"mean"``.
        sign_to_mult: Optional per-column sign tensor. When given, only the
            entries whose sign matches their column's sign are aggregated.
            When ``None``, all non-zero entries are aggregated.

    Returns:
        1-D tensor with one aggregated value per column:
        - ``"mean"`` averages the selected non-zero entries (0 if there are none).
        - ``"sum"`` adds the selected entries.
        - ``"max"`` returns the largest-magnitude selected entry, signed by
          ``sign_to_mult``, or by the entry's own sign when
          ``sign_to_mult`` is ``None``. This case previously raised
          ``TypeError``.

    Raises:
        ValueError: If ``merge_func`` is not one of the supported methods.
    """
    merge_func = merge_func.split("-")[-1]

    if sign_to_mult is not None:
        # Keep only the entries that agree with their column's sign.
        keep = torch.where(sign_to_mult.unsqueeze(0) > 0, Tensor > 0, Tensor < 0)
    else:
        # Without signs, keep every non-zero entry.
        keep = Tensor != 0
    selected_entries = Tensor * keep

    if merge_func == "mean":
        non_zero_counts = (selected_entries != 0).sum(dim=0).float()
        disjoint_aggs = torch.sum(selected_entries, dim=0) / torch.clamp(
            non_zero_counts, min=1
        )
    elif merge_func == "sum":
        disjoint_aggs = torch.sum(selected_entries, dim=0)
    elif merge_func == "max":
        if sign_to_mult is not None:
            disjoint_aggs = selected_entries.abs().max(dim=0)[0]
            disjoint_aggs *= sign_to_mult
        else:
            # No elected sign to apply: pick the largest-magnitude entry in
            # each column and keep its own sign.
            top_index = selected_entries.abs().argmax(dim=0, keepdim=True)
            disjoint_aggs = selected_entries.gather(0, top_index).squeeze(0)
    else:
        raise ValueError(f"Merge method {merge_func} is not defined.")

    return disjoint_aggs

def ties_merging(flat_task_checks, reset_thresh=None, merge_func=""):
    """Merge flattened task checkpoints with a trim / elect-sign / merge pipeline.

    The steps are:
    1. If ``reset_thresh`` is given, each row of ``flat_task_checks`` is
       first sparsified with ``topk_values_mask(K=reset_thresh)``.
    2. A sign per column is elected with ``resolve_sign``.
    3. The rows are combined with ``disjoint_merge`` using ``merge_func``.
       With the default ``""``, ``disjoint_merge`` rejects the method, so
       callers must pass e.g. ``"dis-mean"``, ``"sum"`` or ``"max"``.

    Progress messages are printed to stdout.
    """
    if reset_thresh is None:
        trimmed = flat_task_checks
    else:
        trimmed = topk_values_mask(
            flat_task_checks, K=reset_thresh, return_mask=False
        )[0]

    print("RESOLVING SIGN")
    elected_signs = resolve_sign(trimmed)
    assert elected_signs is not None

    print(f"Disjoint AGGREGATION: {merge_func}")
    return disjoint_merge(trimmed, merge_func, elected_signs)

def tv_merging(tv_flat_checks):
    """Merge flattened task vectors by summing them along dim 0.

    No scaling is applied here; any scaling coefficient has to be applied by
    the caller.
    """
    return tv_flat_checks.sum(dim=0)

def basic_merging(flat_checks, sd_check, remove_keys):
    """Average flattened checkpoints and unflatten the result into a state dict.

    Args:
        flat_checks: Tensor whose rows are flattened checkpoints (presumably
            deltas/task vectors - TODO confirm against callers); they are
            averaged along dim 0.
        sd_check: Reference state dict giving the key order and shapes.
        remove_keys: Keys excluded from the flattened layout.

    Returns:
        The state dict produced by ``vector_to_state_dict`` for the mean vector.
    """
    averaged = flat_checks.mean(dim=0)
    return vector_to_state_dict(averaged, sd_check, remove_keys=remove_keys)