import argparse
import os
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import sys
import gc


def get_activation(name):
    """Map an activation name to the callable that implements it.

    Recognised names are ``"silu"``, ``"gelu"`` and ``"none"`` (identity).
    Every other value, including unknown strings, falls back to
    ``torch.tanh``.

    Args:
        name: Activation identifier.

    Returns:
        A one-argument callable.
    """
    known = (
        ("silu", F.silu),
        ("none", lambda x: x),
        ("gelu", F.gelu),
    )
    for key, fn in known:
        if name == key:
            return fn
    # Anything unrecognised is treated as tanh.
    return torch.tanh


def calc_diff(tensor_a, tensor_b):
    """Measure how far ``tensor_b`` is from the reference ``tensor_a``.

    Both inputs are compared in float32 on ``tensor_a``'s device.

    Args:
        tensor_a: Reference tensor; its norm is the denominator of the
            relative error.
        tensor_b: Tensor to compare; moved to ``tensor_a``'s device.

    Returns:
        A ``(mse, rel_error)`` tuple of Python floats, where ``mse`` is the
        mean squared element-wise difference and ``rel_error`` is
        ``||a - b|| / (||a|| + 1e-8)``. If the shapes differ, a warning is
        printed and ``(0.0, 0.0)`` is returned, so a zero result does not by
        itself prove that the tensors match.
    """
    ref = tensor_a.float()
    other = tensor_b.to(ref.device).float()

    if ref.shape == other.shape:
        delta = ref - other
        mean_sq_err = delta.pow(2).mean().item()
        # The epsilon keeps the ratio finite when the reference is all zeros.
        relative = torch.norm(delta) / (torch.norm(ref) + 1e-8)
        return mean_sq_err, relative.item()

    print(f"[Warning] Shape mismatch: {ref.shape} vs {other.shape}")
    return 0.0, 0.0


@torch.no_grad()
def reconstruct_layer_weights(
    layer_idx,
    pth_path,
    num_total_experts,
    original_rows,
    original_cols,
    device,
    target_dtype=torch.float32
):
    """Rebuild every expert weight matrix stored in one compressed ``.pth`` file.

    For each expert ``e`` (in the file's sorted order) the matrix is

        ``W = (C_e[:, :K] @ act(A_e-blended B_fixed)[:K, :]) * global_std
              + residual_e * global_std``

    where ``K`` is the rank allocated to the expert's group and the residual
    is gathered from the group's ``a_params`` row through ``p_indices`` and
    scaled by ``p_scale``. The matrices are then re-ordered with the stored
    ``_expert_permutation``.

    The whole function runs under ``torch.no_grad()``: it is pure inference,
    and without it tensors saved with ``requires_grad=True`` would keep an
    autograd graph alive for every reconstructed expert.

    Args:
        layer_idx: Index of the layer being rebuilt. Not used by the
            computation; kept so existing callers keep working.
        pth_path: Path of the compressed file. It must contain the keys
            ``_config`` (with ``K_max``, ``experts_per_group`` and optionally
            ``activation``, default ``"silu"``), ``_allocated_ks``,
            ``_global_std``, ``B_fixed``, ``C_params``, ``A_params``,
            ``a_params``, ``p_indices``, ``p_scale`` and
            ``_expert_permutation``.
        num_total_experts: Number of experts to reconstruct.
        original_rows: Row count of each reconstructed matrix.
        original_cols: Column count of each reconstructed matrix.
        device: Device on which the file is loaded and the math is done.
        target_dtype: dtype of the returned matrices.

    Returns:
        A list of ``num_total_experts`` tensors of shape
        ``(original_rows, original_cols)``; position ``permutation[s]`` holds
        the ``s``-th matrix in the file's sorted order.

    Raises:
        ValueError: If ``_expert_permutation`` is not a permutation of
            ``range(num_total_experts)``.
    """
    # NOTE: torch.load unpickles the file, so pth_path must come from a
    # trusted source.
    loaded_state = torch.load(pth_path, map_location=device)

    config = loaded_state["_config"]
    allocated_ks = loaded_state["_allocated_ks"]

    K_max = config["K_max"]
    experts_per_group = config["experts_per_group"]
    act = get_activation(config.get("activation", "silu"))

    # Fail fast on a malformed permutation instead of returning a list that
    # still contains None (or raising an IndexError) after all the work.
    permutation = loaded_state["_expert_permutation"].tolist()
    if sorted(permutation) != list(range(num_total_experts)):
        raise ValueError(
            f"_expert_permutation in {pth_path} is not a permutation of "
            f"range({num_total_experts})"
        )

    # Global Std: stored either as a plain number or as a tensor.
    global_std = loaded_state["_global_std"]
    if not torch.is_tensor(global_std):
        global_std = torch.tensor(global_std, dtype=torch.float32, device=device)
    else:
        global_std = global_std.to(device).float()

    B_fixed = loaded_state["B_fixed"].to(device).float()
    C_params = loaded_state["C_params"].to(device).float()
    A_params = loaded_state["A_params"].to(device).float()
    a_params = loaded_state["a_params"].to(device).float()

    p_indices = loaded_state["p_indices"].to(device).long()
    p_scale = loaded_state["p_scale"].to(device).float()

    # Number of entries in one flattened expert matrix.
    D = original_rows * original_cols
    reconstructed_sorted = []

    for e in range(num_total_experts):
        g_idx = e // experts_per_group
        sub_id = e % experts_per_group
        # int() accepts both plain ints and 0-dim tensors.
        K_i = int(allocated_ks[g_idx])

        # Part A: Backbone
        # Blend the shared bases with this expert's coefficients, apply the
        # activation, then keep only the first K_i rows.
        A_e = A_params[e].unsqueeze(0)
        B_blend = torch.einsum("bg, gkc -> bkc", A_e, B_fixed)
        B_act = act(B_blend)[0]
        B_final = B_act[:K_i, :].contiguous()

        C_e = C_params[e].view(original_rows, K_max)
        C_final = C_e[:, :K_i]

        W_main = (C_final @ B_final) * global_std

        # Part B: Residual
        # The expert's slice of p_indices / p_scale is selected by its
        # position inside the group; values come from the group's a_params row.
        a_e = a_params[g_idx].unsqueeze(0)
        offset = sub_id * D
        curr_idx = p_indices[offset : offset + D]
        curr_scl = p_scale[offset : offset + D]

        R_flat = a_e.gather(1, curr_idx.unsqueeze(0)) * curr_scl.unsqueeze(0)
        W_res = R_flat.view(original_rows, original_cols) * global_std

        W_final = W_main + W_res
        reconstructed_sorted.append(W_final.to(dtype=target_dtype))

    # Undo the sorting: sorted position s belongs at index permutation[s].
    reconstructed_raw_order = [None] * num_total_experts
    for original_idx, weight in zip(permutation, reconstructed_sorted):
        reconstructed_raw_order[original_idx] = weight

    return reconstructed_raw_order


def parse_args(argv=None):
    """Parse the command-line options of the merge script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is what callers passing no argument get.

    Returns:
        argparse.Namespace with ``base_model``, ``comp_dir``, ``save_dir``,
        ``start_layer``, ``end_layer``, ``num_experts``, ``rows``, ``cols``
        and ``dtype``.
    """
    parser = argparse.ArgumentParser(description="Merge HybridMoBE .pth weights into a Base MoE model (Memory Optimized).")
    parser.add_argument("--base_model", required=True, help="Path to the base model.")
    parser.add_argument("--comp_dir", required=True, help="Directory containing .pth files.")
    parser.add_argument("--save_dir", required=True, help="Directory to save the merged model.")

    parser.add_argument("--start_layer", type=int, default=0,
                        help="First layer index to process (inclusive).")
    parser.add_argument("--end_layer", type=int, default=94,
                        help="Layer index to stop at (exclusive).")
    parser.add_argument("--num_experts", type=int, default=128,
                        help="Number of experts reconstructed from each .pth file.")

    parser.add_argument("--rows", type=int, default=768, help="Rows per matrix (output dim).")
    parser.add_argument("--cols", type=int, default=2048, help="Cols per matrix (input dim).")
    parser.add_argument("--dtype", default="bfloat16", choices=["float16", "bfloat16", "float32"],
                        help="dtype used to load the base model and store merged weights.")

    return parser.parse_args(argv)

def main():
    """Merge reconstructed expert weights into the base model and save the result.

    Steps:
      1. Load the base model and tokenizer on the CPU in the requested dtype.
      2. For every layer in ``[start_layer, end_layer)`` and each of
         ``gate_proj`` / ``up_proj``, look for
         ``model_layers_{i}_mlp_{proj}_Hybrid.pth`` in ``--comp_dir``; missing
         files are skipped.
      3. Reconstruct all expert matrices from the file, report the MSE and
         relative error against the weights currently in the model, and
         replace those weights.
      4. Save the model and tokenizer to ``--save_dir``.

    Raises:
        ValueError: If a reconstructed matrix does not have the shape of the
            weight it is meant to replace.
    """
    args = parse_args()

    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32
    }
    torch_dtype = dtype_map[args.dtype]

    compute_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using compute device for reconstruction: {compute_device}")

    print(f"Loading Base Model from {args.base_model} ...")

    # The model itself stays on the CPU; only the per-file reconstruction
    # runs on compute_device.
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch_dtype,
        device_map="cpu",
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)

    target_matrices = ["gate_proj", "up_proj"]

    # The layer stack is exposed either as model.model.layers or model.layers.
    try:
        layers = model.model.layers
    except AttributeError:
        layers = model.layers

    # Clamp the range so that an --end_layer larger than the model (the
    # default is 94) does not raise IndexError after all the work is done.
    num_layers = len(layers)
    end_layer = min(args.end_layer, num_layers)
    if end_layer < args.end_layer:
        print(
            f"[Warning] --end_layer {args.end_layer} exceeds the model's "
            f"{num_layers} layers; stopping at {end_layer}."
        )

    print(f"Starting Reconstruction (Layers {args.start_layer}-{end_layer})...")
    print("Optimization: Using In-Place updates to prevent RAM explosion.")

    global_mse_sum = 0
    global_rel_sum = 0
    total_matrices_count = 0

    for i in tqdm(range(args.start_layer, end_layer), desc="Processing Layers"):
        current_layer_module = layers[i]
        layer_processed = False

        for m_type in target_matrices:
            pth_file = f"model_layers_{i}_mlp_{m_type}_Hybrid.pth"
            pth_path = os.path.join(args.comp_dir, pth_file)

            # Layers / projections without a compressed file keep their
            # original weights.
            if not os.path.exists(pth_path):
                continue

            layer_processed = True

            restored_experts = reconstruct_layer_weights(
                layer_idx=i,
                pth_path=pth_path,
                num_total_experts=args.num_experts,
                original_rows=args.rows,
                original_cols=args.cols,
                device=compute_device,
                target_dtype=torch.float32
            )

            layer_mse = 0.0
            layer_rel = 0.0

            for expert_idx, rec_weight in enumerate(restored_experts):
                expert_module = current_layer_module.mlp.experts[expert_idx]
                target_linear = getattr(expert_module, m_type)
                old_weight = target_linear.weight.data

                # Never overwrite a weight with a differently shaped tensor:
                # that would silently produce a broken checkpoint.
                if old_weight.shape != rec_weight.shape:
                    raise ValueError(
                        f"Layer {i} {m_type} expert {expert_idx}: reconstructed shape "
                        f"{tuple(rec_weight.shape)} does not match model weight shape "
                        f"{tuple(old_weight.shape)}; check --rows/--cols."
                    )

                # Compare against the weight being replaced before replacing it.
                mse, rel = calc_diff(old_weight, rec_weight)
                layer_mse += mse
                layer_rel += rel

                target_linear.weight.data = rec_weight.to(dtype=torch_dtype, device="cpu")

            avg_mse = layer_mse / args.num_experts
            avg_rel = layer_rel / args.num_experts

            global_mse_sum += avg_mse
            global_rel_sum += avg_rel
            total_matrices_count += 1

            tqdm.write(f"  [L{i} {m_type}] MSE: {avg_mse:.6f} | RelErr: {avg_rel:.4%}")

            # Drop the float32 reconstructions before loading the next file.
            del restored_experts
            if compute_device.type == "cuda":
                torch.cuda.empty_cache()

        if layer_processed:
            gc.collect()

    print("-" * 40)
    if total_matrices_count > 0:
        final_mse = global_mse_sum / total_matrices_count
        final_rel = global_rel_sum / total_matrices_count
        print(f"Final Average MSE: {final_mse:.6f}")
        print(f"Final Average Relative Error: {final_rel:.4%}")
    else:
        print(
            f"[Warning] No compressed files found in {args.comp_dir} for layers "
            f"{args.start_layer}-{end_layer}; no weights were replaced."
        )
    print("-" * 40)

    print(f"Saving Final Merged Model to {args.save_dir} ...")

    model.save_pretrained(args.save_dir)
    tokenizer.save_pretrained(args.save_dir)
    print("All Done.")

# Run the merge only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()