import argparse
import os
import torch
from safetensors.torch import load_file as torch_load_file
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import re


def calc_diff(tensor_a, tensor_b):
    """Measure how far ``tensor_b`` is from the reference ``tensor_a``.

    ``tensor_b`` is moved to ``tensor_a``'s device, and both are cast to
    float32 before comparison. Shapes are presumably expected to match; no
    shape check is done here, so broadcasting rules apply.

    Returns:
        A ``(mse, rel_error)`` tuple of Python floats:

        * ``mse`` is the mean of the squared element-wise difference.
        * ``rel_error`` is ``||a - b|| / (||a|| + 1e-8)``, using the 2-norm
          over all elements. The epsilon guards against a zero reference.
    """
    reference = tensor_a.float()
    residual = reference - tensor_b.to(reference.device).float()

    mean_squared = residual.pow(2).mean().item()
    relative = (torch.norm(residual) / (torch.norm(reference) + 1e-8)).item()
    return mean_squared, relative

def parse_args():
    """Parse the command line (``sys.argv``) into the merge configuration.

    Required: ``--base_model``, ``--mobe_dir`` and ``--save_dir``.

    Optional:

    * the layer range ``--start_layer`` / ``--end_layer``, which ``main``
      uses as ``range(start, end)``;
    * ``--num_experts``;
    * ``--dtype``;
    * the ``--grouped_experts`` switch.
    """
    cli = argparse.ArgumentParser(
        description="Merge MoBE experts into a MoE model with Error Check."
    )

    required_paths = (
        ("--base_model", "Path to base model."),
        ("--mobe_dir", "Directory containing MoBE weights."),
        ("--save_dir", "Where to save the merged model."),
    )
    for flag, text in required_paths:
        cli.add_argument(flag, required=True, help=text)

    int_options = (("--start_layer", 0), ("--end_layer", 94), ("--num_experts", 128))
    for flag, fallback in int_options:
        cli.add_argument(flag, type=int, default=fallback)

    cli.add_argument(
        "--dtype", default="bfloat16", choices=["float16", "bfloat16", "float32"]
    )
    cli.add_argument(
        "--grouped_experts", action="store_true", help="Load grouped safetensors."
    )
    return cli.parse_args()

def load_safetensors(file_path: str) -> dict:
    """Load every tensor from a ``.safetensors`` file into a name -> tensor dict.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist. This is checked
            up front so the message names the path.
    """
    if os.path.exists(file_path):
        return torch_load_file(file_path)
    raise FileNotFoundError(f"File not found: {file_path}")

def load_grouped_tensors(mobe_dir: str, layer_idx: int, proj_name: str, num_experts: int) -> dict:
    """Load the grouped MoBE shards for one layer/projection and merge them.

    Shard selection: files in ``mobe_dir`` named exactly
    ``model_layers_{layer_idx}_mlp_{proj_name}_group{N}.safetensors`` are
    loaded in ascending ``N`` order.

    Ordering and renaming: within each shard, tensors are ordered by the
    integer in the second ``_``-separated field of their key. Every tensor
    is then renamed sequentially across all shards to
    ``experts_{k}_{proj_name}_weight``.

    Args:
        mobe_dir: Directory holding the shard files.
        layer_idx: Decoder layer index embedded in the file names.
        proj_name: Projection name embedded in the file names and output keys
            (e.g. ``"gate_proj"``).
        num_experts: Number of tensors expected in total across all shards.

    Returns:
        Dict mapping ``experts_{k}_{proj_name}_weight`` to its tensor, for
        ``k`` in ``range(num_experts)``.

    Raises:
        FileNotFoundError: if no matching shard exists; also raised by
            ``os.listdir`` if ``mobe_dir`` is missing.
        ValueError: if the total tensor count differs from ``num_experts``,
            or a shard key's second field is not an integer.
    """
    # Escape the literal parts: an unescaped "." would match any character,
    # and proj_name must be matched verbatim. fullmatch anchors both ends.
    pattern = re.compile(
        rf"model_layers_{layer_idx}_mlp_{re.escape(proj_name)}_group(\d+)\.safetensors"
    )

    group_files = []
    for name in os.listdir(mobe_dir):
        match = pattern.fullmatch(name)
        if match:
            group_files.append((int(match.group(1)), os.path.join(mobe_dir, name)))
    group_files.sort()  # ascending group number

    if not group_files:
        raise FileNotFoundError(f"No grouped files found for layer {layer_idx}, proj {proj_name}")

    merged_dict = {}
    expert_idx = 0
    for _, file_path in group_files:
        tensors = load_safetensors(file_path)
        # Shard keys are assumed to look like "experts_<n>_..."; order by <n>.
        for _, tensor in sorted(tensors.items(), key=lambda kv: int(kv[0].split("_")[1])):
            merged_dict[f"experts_{expert_idx}_{proj_name}_weight"] = tensor
            expert_idx += 1

    if expert_idx != num_experts:
        raise ValueError(f"Expected {num_experts} experts but loaded {expert_idx}")

    return merged_dict

def _load_layer_dicts(args, layer_idx):
    """Return ``(gate_dict, up_dict)`` of MoBE weights for one layer, or None to skip it."""
    try:
        if args.grouped_experts:
            gate_dict = load_grouped_tensors(args.mobe_dir, layer_idx, "gate_proj", args.num_experts)
            up_dict = load_grouped_tensors(args.mobe_dir, layer_idx, "up_proj", args.num_experts)
        else:
            gate_path = os.path.join(args.mobe_dir, f"model_layers_{layer_idx}_mlp_gate_proj_weight.safetensors")
            up_path = os.path.join(args.mobe_dir, f"model_layers_{layer_idx}_mlp_up_proj_weight.safetensors")
            missing = [p for p in (gate_path, up_path) if not os.path.exists(p)]
            if missing:
                # Report the skip, as the grouped path does, instead of skipping silently.
                tqdm.write(f"[Warning] Layer {layer_idx} skipped: missing {', '.join(missing)}")
                return None
            gate_dict = load_safetensors(gate_path)
            up_dict = load_safetensors(up_path)
    except Exception as e:  # deliberate best-effort: report and skip this layer
        tqdm.write(f"[Warning] Layer {layer_idx} skipped: {e}")
        return None
    return gate_dict, up_dict


def _merge_projection(state_dict, proj_dict, layer_idx, proj_name, num_experts, torch_dtype):
    """Write one projection's MoBE expert weights into ``state_dict`` and measure the error.

    An expert is processed only if its key exists in both ``state_dict``
    (``model.layers.{layer}.mlp.experts.{j}.{proj}.weight``) and
    ``proj_dict`` (``experts_{j}_{proj}_weight``).

    Returns:
        ``(mse_sum, rel_sum, count)`` accumulated over the processed experts.

    Raises:
        ValueError: if a replacement weight's shape differs from the base weight.
    """
    mse_sum = 0.0
    rel_sum = 0.0
    count = 0
    for j in range(num_experts):
        hf_key = f"model.layers.{layer_idx}.mlp.experts.{j}.{proj_name}.weight"
        file_key = f"experts_{j}_{proj_name}_weight"
        if hf_key not in state_dict or file_key not in proj_dict:
            continue

        new_w = proj_dict[file_key]
        orig_w = state_dict[hf_key]
        # Fail fast with a clear message. Otherwise calc_diff would broadcast or
        # raise an opaque error, and load_state_dict would reject the mismatch
        # only after every layer had been processed.
        if new_w.shape != orig_w.shape:
            raise ValueError(
                f"Shape mismatch for {hf_key}: base {tuple(orig_w.shape)} vs MoBE {tuple(new_w.shape)}"
            )

        mse, rel = calc_diff(orig_w, new_w)
        mse_sum += mse
        rel_sum += rel
        count += 1
        state_dict[hf_key] = new_w.to(dtype=torch_dtype)
    return mse_sum, rel_sum, count


def main():
    """Merge MoBE gate/up expert weights into the base model, report error, and save.

    For each layer in ``[start_layer, end_layer)``:

    * load the MoBE weights;
    * replace the matching expert ``gate_proj`` / ``up_proj`` weights in the
      model's state dict;
    * log the per-layer average MSE and relative error versus the originals.

    A summary is printed at the end. The merged model and tokenizer are then
    saved to ``save_dir``.
    """
    args = parse_args()

    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    torch_dtype = dtype_map[args.dtype]

    print(f"Loading base model from {args.base_model} (CPU)...")
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch_dtype,
        device_map="cpu",
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    state_dict = model.state_dict()

    # Sums of per-layer average relative error, with a separate layer count per
    # projection. A layer may contribute gate weights but not up weights (or
    # vice versa), so a shared divisor would skew the averages.
    gate_rel_sum = 0.0
    up_rel_sum = 0.0
    gate_layers = 0
    up_layers = 0
    total_layers_count = 0

    print(f"Starting Merge & Error Check ({args.start_layer} -> {args.end_layer})...")

    for i in tqdm(range(args.start_layer, args.end_layer), desc="Processing Layers"):
        dicts = _load_layer_dicts(args, i)
        if dicts is None:
            continue
        gate_dict, up_dict = dicts

        gate_mse, gate_rel, gate_count = _merge_projection(
            state_dict, gate_dict, i, "gate_proj", args.num_experts, torch_dtype
        )
        up_mse, up_rel, up_count = _merge_projection(
            state_dict, up_dict, i, "up_proj", args.num_experts, torch_dtype
        )

        msg_parts = []
        if gate_count > 0:
            avg_gate_rel = gate_rel / gate_count
            msg_parts.append(f"Gate [MSE: {gate_mse / gate_count:.6f} | Rel: {avg_gate_rel:.4%}]")
            gate_rel_sum += avg_gate_rel
            gate_layers += 1

        if up_count > 0:
            avg_up_rel = up_rel / up_count
            msg_parts.append(f"Up   [MSE: {up_mse / up_count:.6f} | Rel: {avg_up_rel:.4%}]")
            up_rel_sum += avg_up_rel
            up_layers += 1

        if msg_parts:
            total_layers_count += 1
            tqdm.write(f"  [Layer {i}] " + "  ".join(msg_parts))

    print("-" * 40)
    if total_layers_count > 0:
        gate_avg = f"{gate_rel_sum / gate_layers:.4%}" if gate_layers else "n/a"
        up_avg = f"{up_rel_sum / up_layers:.4%}" if up_layers else "n/a"
        print(f"Summary over {total_layers_count} layers:")
        print(f"  Avg Gate Relative Error: {gate_avg}")
        print(f"  Avg Up   Relative Error: {up_avg}")
    else:
        print("[Warning] No MoBE weights were merged; the saved model equals the base model.")
    print("-" * 40)

    print("Loading merged state dict ...")
    model.load_state_dict(state_dict)
    print(f"Saving merged model to {args.save_dir}")
    model.save_pretrained(args.save_dir)
    tokenizer.save_pretrained(args.save_dir)
    print("Done.")

# Run the merge only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    main()
