import torch
import torch.nn as nn
import scipy.io as scio
import pickle

from transformers import AutoModelForCausalLM, AutoTokenizer
import pdb
from tqdm import tqdm
import pdb
from safetensors.torch import save_file
import json
import os
import shutil
import math

class Config:
    """Read-only accessor around a dictionary of configuration values."""

    def __init__(self, params):
        # Keep a reference to the caller's mapping; no copy is made.
        self.params = params

    def get(self, key, default_value=None):
        """Return the value stored under ``key``.

        :param key: name of the configuration entry to look up
        :param default_value: value returned when ``key`` is absent; a
            ``None`` default is treated the same as "no default supplied"
        :return: the stored value, or ``default_value`` if the key is missing
        :raises Exception: if ``key`` is missing and ``default_value`` is None
        """
        if key in self.params:
            return self.params[key]
        if default_value is None:
            raise Exception(f"config does not have key {key}")
        return default_value

def copy_params_to_model(params: dict, model: nn.Module):
    """
    Overwrite the model's parameters in place with the values in "params".

    Only parameters whose name (as reported by ``model.named_parameters()``)
    is a key of ``params`` are touched; all others keep their current values.
    :param params: dict, maps parameter names to the tensors to copy from
    :param model: nn.Module, model whose parameters are overwritten
    :return:
    """
    # Pair each matching model parameter with its replacement tensor.
    replacements = (
        (target, params[name])
        for name, target in model.named_parameters()
        if name in params
    )
    for target, source in replacements:
        target.data.copy_(source)


def _merge_expert_deltas(base_weight, lora_a_list, lora_b_list, gate_magnitude,
                         scaling, num_shared_experts):
    """Return ``base_weight`` with the low-rank update of every expert added.

    Experts with index below ``num_shared_experts`` contribute
    ``(B @ A) * scaling``. Every other expert has its ``B`` matrix rescaled
    per output row by ``scaling * gate_magnitude / D`` before the product,
    where ``D`` is the row sum of ``|B| @ |A|``.

    :param base_weight: tensor the deltas are added to (not modified in place)
    :param lora_a_list: list of per-expert ``lora_A`` weights
    :param lora_b_list: list of per-expert ``lora_B`` weights
    :param gate_magnitude: 1-D tensor of per-output-row magnitudes
    :param scaling: scalar LoRA scaling factor
    :param num_shared_experts: number of leading experts merged without
        the row-wise normalization
    :return: the merged weight tensor (products are moved to CPU before adding)
    """
    merged = base_weight
    for idx, (lora_a, lora_b) in enumerate(zip(lora_a_list, lora_b_list)):
        if idx < num_shared_experts:
            merged = (lora_b @ lora_a).detach().cpu() * scaling + merged
            continue

        input_dim = lora_a.shape[-1]
        output_dim = lora_b.shape[0]
        ones = torch.ones(input_dim, device=lora_a.device, dtype=lora_a.dtype)
        # Row sums of |B| @ |A|, computed as two matrix-vector products.
        row_norm = lora_b.abs() @ (lora_a.abs() @ ones)
        # The 1e-8 term guards against division by zero for all-zero rows.
        scaled_b = lora_b * scaling * gate_magnitude[:output_dim, None] * (1 / (row_norm[:, None]+1e-8))
        merged = (scaled_b @ lora_a).detach().cpu() + merged
    return merged


def do_merge(adapter_path, save_to_path):
    """Merge a LoRA-MoE adapter into its base model and save the result.

    Reads ``adapter_config.json`` and ``adapter_model.bin`` from
    ``adapter_path``, loads the base model named by the config's
    ``base_model_name_or_path`` on CPU, adds every expert's low-rank update
    to the parameters selected by ``target_modules``, and writes the merged
    model to ``save_to_path``. Auxiliary ``.json`` / ``.text`` files that sit
    next to a local base model are copied over when not already present.

    :param adapter_path: directory holding the adapter config and weights
    :param save_to_path: directory the merged model is written to
    :return: None
    :raises Exception: if a required key is missing from the adapter config
    :raises KeyError: if an expected adapter tensor is missing
    """
    # The ``with`` block closes the file; no explicit close() is needed.
    with open(os.path.join(adapter_path, "adapter_config.json"), "r", encoding="utf-8") as f:
        adapter_config = json.load(f)

    cfg = Config(adapter_config)

    Ins_model_path = cfg.get("base_model_name_or_path")
    rslora = cfg.get("use_rslora", False)
    rank = cfg.get("r")
    alpha = cfg.get("lora_alpha")
    target_modules = cfg.get("target_modules")
    num_experts = cfg.get("num_experts")
    num_shared_experts = 1

    # rsLoRA scales by alpha / sqrt(r); plain LoRA by alpha / r.
    scaling = alpha / math.sqrt(rank) if rslora else alpha / rank

    print(f"****We are merging the {Ins_model_path}****")
    print(f"****The adapter model is {adapter_path}****")
    print(f"****The model is saved to {save_to_path}****")

    Ins_model = AutoModelForCausalLM.from_pretrained(
        Ins_model_path, device_map="cpu"
    )
    finetuned_tokenizer = AutoTokenizer.from_pretrained(Ins_model_path)

    # A checkpoint holding CUDA tensors cannot be deserialized on a host
    # without CUDA unless it is remapped; keep the default placement otherwise.
    map_location = None if torch.cuda.is_available() else "cpu"
    adapter_weight = torch.load(
        os.path.join(adapter_path, "adapter_model.bin"), map_location=map_location
    )

    # Merging is pure arithmetic on weights: disable autograd so adding to a
    # parameter that requires grad does not build a computation graph.
    with torch.no_grad():
        for param_name, param_value in tqdm(Ins_model.named_parameters()):
            if not any(tm in param_name for tm in target_modules):
                continue
            # The low-rank product is a 2-D matrix, so it can only be folded
            # into 2-D parameters; 1-D parameters whose name matches a target
            # module (e.g. a bias) are left untouched.
            if param_value.dim() != 2:
                continue

            # Adapter keys drop the first and last component of the
            # parameter name.
            layer_name = ".".join(param_name.split(".")[1:-1])

            # 1. get the shared magnitude
            shared_magnitude = adapter_weight[f"loramoe.{layer_name}.moe_gate.weight"].view(-1)

            # 2. get the adapter weights
            loraA = [adapter_weight[f"loramoe.{layer_name}.experts.{i}.lora_A.weight"] for i in range(num_experts)]
            loraB = [adapter_weight[f"loramoe.{layer_name}.experts.{i}.lora_B.weight"] for i in range(num_experts)]

            merged = _merge_expert_deltas(
                param_value, loraA, loraB, shared_magnitude, scaling, num_shared_experts
            )
            param_value.data.copy_(merged)

    print("starting saving the model")
    Ins_model.save_pretrained(save_directory=save_to_path)
    print("saving completed")

    if os.path.isdir(Ins_model_path):
        # Copy top-level auxiliary files that save_pretrained did not write.
        # NOTE(review): ".text" is matched literally, so files ending in
        # ".txt" are not copied -- confirm this is intended.
        base_files = next(os.walk(Ins_model_path), (None, [], []))[2]
        json_files = [jf for jf in base_files if (".json" in jf or ".text" in jf)]
        ready_files = next(os.walk(save_to_path), (None, [], []))[2]
        for jf in json_files:
            if jf != "model.safetensors.index.json" and jf not in ready_files:
                shutil.copy(os.path.join(Ins_model_path, jf), os.path.join(save_to_path, jf))
    else:
        # The base model is not a local directory (e.g. a hub id), so there
        # is nothing to copy from; save the loaded tokenizer instead.
        finetuned_tokenizer.save_pretrained(save_to_path)

# Script entry: merge the adapter at the first path into its base model and
# write the merged model to the second path.
do_merge(
    adapter_path="/path/to/the/adapter",
    save_to_path="/path/to/the/save",
)