import os
import torch
from transformers import AutoModelForCausalLM


# Checkpoint handed to from_pretrained below.
model_name = "models/qwen3-30b-a3b"
# Alternative checkpoint (disabled): "/deepseek-moe-16b-base"

# Load the causal LM on the CPU with bfloat16 weights. trust_remote_code=True
# allows transformers to run model code shipped with the checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="cpu",
)


# Maps each logical expert projection to the attribute names it may be stored
# under. get_attr_with_alias tries the candidates in the order listed here.
DEEPSEEK_PROJ_ALIASES = dict(
    gate_proj=["gate_proj", "w1", "gate"],
    up_proj=["up_proj", "w3", "up"],
    down_proj=["down_proj", "w2", "down"],
)

def get_attr_with_alias(obj, names):
    """Return the first attribute of ``obj`` found among ``names``.

    Args:
        obj: Object to inspect.
        names: Candidate attribute names, tried in order.

    Returns:
        A ``(value, name)`` tuple for the first candidate present on ``obj``.

    Raises:
        AttributeError: If ``obj`` has none of the candidate attributes.
    """
    # Private sentinel so that attributes whose value is None still count.
    missing = object()
    for candidate in names:
        value = getattr(obj, candidate, missing)
        if value is not missing:
            return value, candidate
    raise AttributeError(f"None of aliases {names} exist in {type(obj)}")

def get_expert_hub(layer):
    """Locate the MoE expert container of a decoder layer.

    Two layouts are recognised, checked in this order:

    * ``layer.block_sparse_moe.experts`` -- projections are read under the
      literal attribute names ``w1`` / ``w2`` / ``w3``.
    * ``layer.mlp.experts`` -- projections use the logical names
      ``gate_proj`` / ``up_proj`` / ``down_proj``, resolved through
      ``DEEPSEEK_PROJ_ALIASES``.

    Args:
        layer: Decoder layer object to inspect.

    Returns:
        ``(experts, proj_keys, getter)`` where ``getter(expert, key)`` returns
        ``(module, attribute_name)``. ``(None, None, None)`` is returned when
        the layer matches neither layout.
    """
    absent = object()  # sentinel distinguishing "no attribute" from None

    sparse_moe = getattr(layer, "block_sparse_moe", absent)
    if sparse_moe is not absent and hasattr(sparse_moe, "experts"):

        def _direct_lookup(expert, logical_key):
            # Here the logical key is the attribute name itself.
            found = getattr(expert, logical_key, absent)
            if found is absent:
                raise AttributeError(f"Expert has no attr '{logical_key}'")
            return found, logical_key

        return sparse_moe.experts, ["w1", "w2", "w3"], _direct_lookup

    mlp = getattr(layer, "mlp", absent)
    if mlp is not absent and hasattr(mlp, "experts"):

        def _alias_lookup(expert, logical_key):
            # The attribute name may vary; try every known alias.
            return get_attr_with_alias(expert, DEEPSEEK_PROJ_ALIASES[logical_key])

        return mlp.experts, ["gate_proj", "up_proj", "down_proj"], _alias_lookup

    return None, None, None

# Directory that receives one averaged-expert file per processed layer.
save_dir = "/merged_experts_qwen3_avg"
os.makedirs(save_dir, exist_ok=True)

# For every layer that exposes experts, average each projection's weight over
# all experts and save the result as a dict keyed by logical projection name.
with torch.no_grad():
    for layer_idx, layer in enumerate(model.model.layers):
        # Layer 0 is skipped unconditionally.
        # NOTE(review): presumably because the first layer is dense in some
        # checkpoints -- confirm this is intended for the loaded model.
        if layer_idx == 0:
            continue

        experts, proj_keys, getter = get_expert_hub(layer)
        # Nothing to merge when the layer has no expert container or it is empty.
        if experts is None or len(experts) == 0:
            continue
        expert_num = len(experts)

        merged_expert = {}
        for logical_proj in proj_keys:
            # getter returns (module, attribute_name); only the module is used.
            # Weights are copied to CPU float32 before averaging.
            per_expert = [
                getter(experts[idx], logical_proj)[0].weight.detach().cpu().float()
                for idx in range(expert_num)
            ]
            # Element-wise mean over the stacked expert axis.
            merged_expert[logical_proj] = torch.stack(per_expert, dim=0).mean(dim=0)

        out_path = os.path.join(save_dir, f"merged_expert_layer_{layer_idx}.pt")
        torch.save(merged_expert, out_path)
        print(f"[AVG] Layer {layer_idx} merged -> {list(merged_expert.keys())}")



# ######################################### merge (average) ##########################################


# Second averaging pass: merge the experts found under ``layer.mlp.experts``
# using the logical gate/up/down projection names.
# NOTE(review): this reloads the same ``model_name`` that is already loaded at
# the top of the file; the previous instance stays referenced until the new
# one is bound, so peak memory is roughly doubled -- confirm the reload is
# needed.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
)

# Directory that receives one averaged-expert file per processed layer.
save_dir = "/deepseek/merged_experts_mixtral_avg"
os.makedirs(save_dir, exist_ok=True)

with torch.no_grad():
    for layer_idx, layer in enumerate(model.model.layers):
        # Layer 0 is skipped unconditionally.
        # NOTE(review): presumably because the first layer is dense -- confirm.
        if layer_idx == 0:
            continue

        # Layers whose MLP has no expert container are skipped.
        if not hasattr(layer.mlp, 'experts'):
            continue

        experts = layer.mlp.experts
        expert_num = len(experts)
        if expert_num == 0:
            continue

        # Logical projection name -> weight averaged over all experts.
        merged_expert = {}

        for logical_proj in ['gate_proj', 'up_proj', 'down_proj']:
            expert_weights = []
            for expert_idx in range(expert_num):
                # The alias table defined in this file is DEEPSEEK_PROJ_ALIASES;
                # the name PROJ_ALIASES does not exist and raised NameError.
                proj_module, _ = get_attr_with_alias(
                    experts[expert_idx], DEEPSEEK_PROJ_ALIASES[logical_proj]
                )
                # Copy to CPU float32 before averaging.
                expert_weights.append(proj_module.weight.detach().cpu().float())

            # Element-wise mean over the stacked expert axis.
            merged_expert[logical_proj] = torch.stack(expert_weights, dim=0).mean(dim=0)

        torch.save(merged_expert, os.path.join(save_dir, f"merged_expert_layer_{layer_idx}.pt"))
        print(f"[AVG] Layer {layer_idx} merged expert saved with keys: {list(merged_expert.keys())}")





####################################### tucker mistral ############################################


import os
import torch
import tensorly as tl
from tensorly.decomposition import tucker

# Select the PyTorch backend for tensorly before calling tucker().
tl.set_backend("pytorch")

# Input directory with per-layer merged-expert files, and the output directory
# for the decomposition artifacts.
merged_dir = "/mixtral/merged_experts_mixtral"
decomposed_dir = "/mixtral/tucker_rank128"
os.makedirs(decomposed_dir, exist_ok=True)

# Keys looked up in every merged-expert dict.
projections = ['w1', 'w2', 'w3']

# Tucker rank requested for the two weight-matrix modes.
rank = 128

# Collect layer ids from file names shaped like merged_expert_layer_<id>.pt:
# take the last "_"-separated piece and keep the text before its first ".".
layer_ids = []
for fn in os.listdir(merged_dir):
    if not (fn.startswith("merged_expert_layer_") and fn.endswith(".pt")):
        continue
    id_part = fn.split("_")[-1]
    layer_ids.append(int(id_part.split(".")[0]))
layer_ids.sort()

# Fail early when there is nothing to decompose.
if not layer_ids:
    raise RuntimeError(f"No merged layer files found in: {merged_dir}")

# For each projection: stack its merged weight from all layers into a 3-way
# tensor, Tucker-decompose it, save the core and factors, and save a per-layer
# coefficient matrix obtained by projecting each weight onto U_out / U_in.
for proj in projections:
    proj_dir = os.path.join(decomposed_dir, proj)
    os.makedirs(proj_dir, exist_ok=True)

    # Load this projection's weight from every layer file as float32.
    merged_weights = []
    for layer_id in layer_ids:
        path = os.path.join(merged_dir, f"merged_expert_layer_{layer_id}.pt")
        merged_expert = torch.load(path, map_location="cpu")
        if proj not in merged_expert:
            raise KeyError(f"'{proj}' not found in {path}. Keys: {list(merged_expert.keys())}")
        merged_weights.append(merged_expert[proj].float())  # [out_dim, in_dim]

    # All layers must share the first layer's 2-D shape so they can be stacked.
    out_dim, in_dim = merged_weights[0].shape
    expected_shape = (out_dim, in_dim)
    for checked_id, weight in zip(layer_ids, merged_weights):
        if tuple(weight.shape) != expected_shape:
            raise ValueError(f"Shape mismatch at layer index {checked_id} for {proj}: "
                             f"{weight.shape} vs {expected_shape}")

    # Stacked tensor is [num_layers, out_dim, in_dim]. The layer mode keeps
    # rank len(layer_ids); the other two modes use `rank`.
    core, factors = tucker(
        torch.stack(merged_weights, dim=0),
        rank=[len(layer_ids), rank, rank],
    )

    # One factor per mode, in mode order: layer, out, in.
    U_layer, U_out, U_in = factors

    print(f"[{proj}] Tucker done. "
          f"U_layer: {U_layer.shape}, U_out: {U_out.shape}, U_in: {U_in.shape}, core: {core.shape}")

    # Persist the core and the factor matrices.
    artifacts = (
        ("core.pt", core),
        ("U_in.pt", U_in),        # [in_dim, rank]
        ("U_out.pt", U_out),      # [out_dim, rank]
        ("U_layer.pt", U_layer),
    )
    for filename, tensor in artifacts:
        torch.save(tensor, os.path.join(proj_dir, filename))

    # Per-layer coefficients: C = U_out^T @ W @ U_in, shape [rank, rank].
    U_out_T = U_out.t()  # [rank, out_dim]
    for layer_id, weight in zip(layer_ids, merged_weights):
        coeff = U_out_T @ weight @ U_in
        torch.save(coeff, os.path.join(proj_dir, f"layer{layer_id}_coeff.pt"))

    print(f"[{proj}] Decomposition artifacts saved to {proj_dir}")


