
    
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
from safetensors import safe_open
from safetensors.torch import save_file
import os
import copy
import argparse
import gc
import math
import sys





class HybridMoBE_VariableRank_Blending_SingleP(nn.Module):
    """Hybrid MoBE reconstruction of a bank of expert weight matrices.

    Expert ``e`` belongs to group ``e // experts_per_group`` and is rebuilt as
    ``C_e @ (act(sum_g A[e, g] * B_g) * mask) + residual``:
      - Variable Rank (per group K_i) + Hard Mask: group ``g`` only uses its
        first ``allocated_ks[g]`` rank slots; the padded slots up to ``K_max``
        are zeroed by fixed 0/1 masks on both B and C.
      - Basis Blending: every expert mixes all group bases with its own row
        of ``A_params`` (initialised to a uniform ``1 / num_groups``).
      - Single-Head Residual: each group owns one vector of ``r_dim`` values;
        every weight position of an expert reads one of them through a fixed
        random slot table (``p_indices``) and a fixed scale (``p_scale``).
    """

    def __init__(
        self,
        initial_B_padded,      # [G, K_max, C]
        initial_C_padded,      # [N, R, K_max]
        allocated_ks,
        experts_per_group,
        r_dim,
        rows,
        cols,
        num_groups,
        activation='silu'
    ):
        """Build the trainable factors, the rank masks and the residual tables.

        Args:
            initial_B_padded: Initial group bases, ``[G, K_max, C]``.
            initial_C_padded: Initial per-expert coefficients, ``[N, R, K_max]``.
            allocated_ks: Rank ``K_i`` used by each group (indexable by group id).
            experts_per_group: Number of consecutive experts that form one group.
            r_dim: Number of residual values stored per group.
            rows: Row count of one reconstructed expert matrix.
            cols: Column count of one reconstructed expert matrix.
            num_groups: Number of groups ``G``.
            activation: One of ``'silu'``, ``'gelu'``, ``'tanh'`` or ``'none'``.

        Raises:
            ValueError: If ``activation`` is not one of the supported names.
        """
        super().__init__()

        self.rows = rows
        self.cols = cols
        self.num_groups = num_groups
        self.num_total_experts = initial_C_padded.shape[0]
        self.experts_per_group = experts_per_group

        # Despite its name, B_fixed is trainable; the attribute name is kept
        # because it is part of the state-dict key set.
        self.B_fixed = nn.Parameter(initial_B_padded, requires_grad=True)
        self.C_params = nn.Parameter(initial_C_padded)

        # Basis Blending: [N, G]
        self.A_params = nn.Parameter(
            torch.ones(self.num_total_experts, num_groups) / num_groups
        )

        K_max = initial_B_padded.shape[1]

        # B Mask: [G, K_max, 1] -- ones on the first K_i rank slots of group i.
        rank_mask_B = torch.zeros(num_groups, K_max, 1)
        for g, k_val in enumerate(allocated_ks):
            rank_mask_B[g, :k_val, 0] = 1.0
        self.register_buffer("rank_mask_B", rank_mask_B)

        # C Mask: [N, 1, K_max] -- each expert inherits the rank of its group.
        C_rank_mask = torch.zeros(self.num_total_experts, 1, K_max)
        for e in range(self.num_total_experts):
            k_val = allocated_ks[e // self.experts_per_group]
            C_rank_mask[e, 0, :k_val] = 1.0
        self.register_buffer("C_rank_mask", C_rank_mask)

        if activation == 'silu':
            self.act = nn.functional.silu
        elif activation == 'none':
            self.act = nn.Identity()
        elif activation == 'gelu':
            self.act = nn.GELU()
        elif activation == 'tanh':
            self.act = torch.tanh
        else:
            # A misspelled name used to fall through to tanh silently.
            raise ValueError(
                f"Unsupported activation {activation!r}; "
                "expected one of 'silu', 'gelu', 'tanh', 'none'."
            )

        # Residual values, one vector of r_dim entries per group (zero init,
        # so the residual contributes nothing until it is trained).
        self.a_params = nn.Parameter(torch.zeros(num_groups, r_dim))

        D = rows * cols
        G_total = D * self.experts_per_group

        # Every weight position of every in-group expert slot is hashed to one
        # of the r_dim residual values.
        p_indices = torch.randint(0, r_dim, (G_total,), dtype=torch.long)

        # Scale each position by 1 / sqrt(number of positions sharing its slot).
        unique, counts = torch.unique(p_indices, return_counts=True)
        bucket_size = torch.zeros(r_dim, dtype=torch.float32)
        bucket_size[unique] = counts.float()
        scale = 1.0 / torch.sqrt(bucket_size + 1e-8)
        p_scale = scale[p_indices]

        self.register_buffer("p_indices", p_indices)
        self.register_buffer("p_scale", p_scale)

        self.D = D

    def forward(self, batch_indices):
        """Reconstruct the expert matrices selected by ``batch_indices``.

        Args:
            batch_indices: 1-D integer tensor of expert ids, length ``B``.

        Returns:
            Tensor of shape ``[B, rows, cols]``: masked low-rank product plus
            the hashed residual.
        """
        group_ids = batch_indices // self.experts_per_group   # [B]
        sub_ids = batch_indices % self.experts_per_group      # [B]

        # ===== Part A: Basis Blending + Rank Mask =====
        batch_A = self.A_params[batch_indices]  # [B, G]
        B_blended = torch.einsum('bg, gkc -> bkc', batch_A, self.B_fixed)

        # The mask is applied after the activation so unused rank slots are
        # exactly zero whatever the activation returns for them.
        B_final = self.act(B_blended) * self.rank_mask_B[group_ids]  # [B, K_max, C]

        C_masked = self.C_params[batch_indices] * self.C_rank_mask[batch_indices]  # [B, R, K_max]

        W_main = torch.matmul(C_masked, B_final)

        # ===== Part B: Single-Head Residual =====
        batch_a = self.a_params[group_ids]  # [B, r]

        # The flat tables hold experts_per_group consecutive runs of D entries,
        # so a [E, D] view indexed by sub_ids selects the same entries as
        # offsets sub_ids * D + arange(D) without building a [B, D] index.
        curr_p_idx = self.p_indices.view(self.experts_per_group, self.D)[sub_ids]  # [B, D]
        curr_p_scl = self.p_scale.view(self.experts_per_group, self.D)[sub_ids]    # [B, D]

        # Gather: R_flat = P * a
        R_flat = batch_a.gather(1, curr_p_idx) * curr_p_scl

        W_res = R_flat.view(batch_indices.shape[0], self.rows, self.cols)

        return W_main + W_res


def get_layer_proj_dict(index_path, base_dir, layer_i, matrix_type):
    """Load the expert projection weights of one layer from sharded safetensors.

    Args:
        index_path: Path to the JSON index whose ``weight_map`` maps weight
            names to shard file names.
        base_dir: Directory that contains the shard files.
        layer_i: Layer number; names must contain
            ``model.layers.{layer_i}.mlp.experts``.
        matrix_type: Projection name; names must end with
            ``{matrix_type}.weight``.

    Returns:
        Dict mapping each matching weight name to its tensor. Weights that
        cannot be loaded are reported on stderr and left out; a missing index
        file yields an empty dict.
    """
    try:
        with open(index_path, 'r') as index_file:
            index_data = json.load(index_file)
    except FileNotFoundError:
        print(f"Error: Index file not found at {index_path}", file=sys.stderr)
        return {}

    layer_marker = f"model.layers.{layer_i}.mlp.experts"
    name_suffix = f"{matrix_type}.weight"

    # Group the wanted weights by shard so each shard file is opened once
    # instead of once per expert.
    names_by_shard = {}
    for weight_name, safetensor_file in index_data.get('weight_map', {}).items():
        if layer_marker in weight_name and weight_name.endswith(name_suffix):
            names_by_shard.setdefault(safetensor_file, []).append(weight_name)

    layer_dict = {}
    for safetensor_file, weight_names in names_by_shard.items():
        safetensor_path = os.path.join(base_dir, safetensor_file)
        try:
            with safe_open(safetensor_path, framework="pt") as shard:
                available = set(shard.keys())
                for weight_name in weight_names:
                    if weight_name not in available:
                        continue
                    try:
                        layer_dict[weight_name] = shard.get_tensor(weight_name)
                    except Exception as e:
                        # Best effort: one unreadable tensor must not stop the rest.
                        print(f"Warning: Could not load {weight_name} from {safetensor_path}. Error: {e}", file=sys.stderr)
        except Exception as e:
            # The shard itself could not be opened or listed: report every
            # weight that was expected from it and has not been loaded.
            for weight_name in weight_names:
                if weight_name not in layer_dict:
                    print(f"Warning: Could not load {weight_name} from {safetensor_path}. Error: {e}", file=sys.stderr)
    return layer_dict


def parse_args(argv=None):
    """Parse the command-line options of the Hybrid MoBE compression script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Hybrid MoBE: Activation-Based Sorting + Variable Rank")

    parser.add_argument("--index_path", type=str, required=True, help="Path to the model index JSON file.")
    parser.add_argument("--base_dir", type=str, required=True, help="Base directory containing the safetensors files.")
    parser.add_argument("--save_path", type=str, required=True, help="Directory to save the compressed models.")

    # The option always has a value because of its default, so the SVD
    # fallback is triggered by a missing file or an empty string.
    parser.add_argument("--activation_counts_path", type=str, default='/eagle/lc-mpi/Zhendong/MoBE/DeepSeek-V2-Lite-Chat_expert_counts.pt',
                        help="Path to expert_counts.pt for activation-based sorting. "
                             "If the file does not exist (or an empty string is passed), falls back to SVD sorting.")

    parser.add_argument("--num_matrices", type=int, default=128)
    parser.add_argument("--rows_per_matrix", type=int, default=768)
    parser.add_argument("--cols", type=int, default=2048)
    parser.add_argument("--num_hidden_layers", type=int, default=94)
    parser.add_argument("--matrix_type", type=str, choices=["gate_proj", "up_proj", "down_proj"], default="gate_proj")

    parser.add_argument("--num_epochs", type=int, default=5000)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_batches", type=int, default=4, help="Gradient accumulation steps")
    parser.add_argument("--learning_rate", type=float, default=0.07)
    parser.add_argument("--start_layer", type=int, default=0)
    parser.add_argument("--end_layer", type=int, default=48)
    parser.add_argument("--activation", type=str, choices=["silu", "tanh", "none", 'gelu'], default="silu")

    parser.add_argument("--num_B", type=int, default=32, help="Number of groups (G)")

    parser.add_argument("--target_compression_ratio", type=float, default=0.59,
                        help="Target backbone parameter ratio (excluding residual)")
    parser.add_argument("--residual_ratio", type=float, default=0.03,
                        help="Target ratio for residual component r_dim size.")

    return parser.parse_args(argv)


def _effective_rank(singular_values):
    """Return exp(spectral entropy) of the normalised squared singular values."""
    energy = singular_values.pow(2)
    p = energy / energy.sum()
    p_nz = p[p > 1e-10]
    if p_nz.numel() == 0:
        # Degenerate spectrum (e.g. an all-zero matrix): entropy 0 -> rank 1.
        return 1.0
    entropy = -torch.sum(p_nz * torch.log(p_nz)).item()
    return math.exp(entropy)


def _load_activation_counts(path):
    """Load the activation-count file, or return None to request SVD sorting."""
    if not path:
        return None
    if not os.path.exists(path):
        print(f"Warning: Activation file {path} not found. Will fall back to SVD sorting.")
        return None
    print(f"Loading activation counts from {path}...")
    # NOTE(security): torch.load unpickles the file; only use trusted files.
    return torch.load(path, map_location='cpu')


def _load_expert_targets(args, layer_idx, device):
    """Stack the float16 expert weights of one layer into a [N, rows, cols] tensor."""
    state_dict = get_layer_proj_dict(args.index_path, args.base_dir, layer_i=layer_idx, matrix_type=args.matrix_type)
    keys = [
        f"model.layers.{layer_idx}.mlp.experts.{i}.{args.matrix_type}.weight"
        for i in range(args.num_matrices)
    ]
    missing = [key for key in keys if key not in state_dict]
    if missing:
        raise KeyError(
            f"Layer {layer_idx}: {len(missing)} of {args.num_matrices} expert "
            f"'{args.matrix_type}' weights could not be loaded (first missing: {missing[0]}). "
            "Check --index_path, --base_dir, --num_matrices and the layer range."
        )
    # The per-expert tensors are only referenced inside this call, so they are
    # released as soon as the stacked copy is returned.
    return torch.stack([state_dict[key].to(torch.float16).to(device) for key in keys])


def _score_experts(raw_target, all_activation_counts, layer_idx, num_matrices):
    """Return (expert_id, score) pairs from activation counts or effective rank."""
    if all_activation_counts is not None:
        if layer_idx < all_activation_counts.shape[0]:
            layer_counts = all_activation_counts[layer_idx]
            return [(i, layer_counts[i].item()) for i in range(num_matrices)]
        print(f"  [Warning] Layer {layer_idx} out of bounds in counts file. Using default order.")
        return [(i, 0.0) for i in range(num_matrices)]

    print("  [Sorting Strategy] Effective Rank (Spectral Entropy)")
    scores = []
    for i in range(num_matrices):
        _, S, _ = torch.linalg.svd(raw_target[i].float(), full_matrices=False)
        scores.append((i, _effective_rank(S)))
    return scores


def _analyze_groups(target, num_group, experts_per_group):
    """SVD every group's row-stacked experts; return (effective ranks, SVD factors)."""
    group_scores = []
    group_svd_cache = []
    for g_idx in range(num_group):
        g_target = target[g_idx * experts_per_group:(g_idx + 1) * experts_per_group]
        # Stack the experts of the group along the row axis: [E * rows, cols].
        g_flat = g_target.reshape(-1, g_target.shape[-1]).float()
        U, S, Vt = torch.linalg.svd(g_flat, full_matrices=False)
        del g_flat
        torch.cuda.empty_cache()
        group_scores.append(_effective_rank(S))
        group_svd_cache.append((U, S, Vt))
    return group_scores, group_svd_cache


def _group_frequency_shares(layer_counts, sorted_indices, num_group, experts_per_group):
    """Return each group's share of the layer's activation counts (CPU float32)."""
    counts_sorted = layer_counts.to(sorted_indices.device)[sorted_indices]
    per_group = counts_sorted.reshape(num_group, experts_per_group).sum(dim=1)
    group_activation = per_group.to(device='cpu', dtype=torch.float32)
    return group_activation / (group_activation.sum() + 1e-8)


def _allocate_ranks(group_fused_score, num_group, total_K_budget, k_upper_limit):
    """Split the rank budget across groups proportionally, capped at k_upper_limit."""
    total_score = group_fused_score.sum().item()
    allocated_Ks = [
        max(1, int(total_K_budget * (group_fused_score[g_idx].item() / total_score)))
        for g_idx in range(num_group)
    ]

    # Repeatedly clip groups above the cap and hand the overflow out evenly to
    # the groups that still have room.
    iteration = 0
    while True:
        iteration += 1
        current_excess = 0
        not_full_indices = []

        for i in range(num_group):
            if allocated_Ks[i] > k_upper_limit:
                current_excess += allocated_Ks[i] - k_upper_limit
                allocated_Ks[i] = k_upper_limit
            elif allocated_Ks[i] < k_upper_limit:
                not_full_indices.append(i)

        if current_excess == 0:
            break

        if not not_full_indices:
            print(f"Warning: All groups reached limit. Discarding {current_excess} remaining ranks.")
            break

        add_per_group, remainder = divmod(current_excess, len(not_full_indices))
        for idx_in_list, group_idx in enumerate(not_full_indices):
            allocated_Ks[group_idx] += add_per_group
            if idx_in_list < remainder:
                allocated_Ks[group_idx] += 1

        if iteration > 100:
            print("Warning: Rank allocation loop exceeded 100 iterations. Breaking.")
            break

    # After the iteration-limit bail-out the last hand-out has not been
    # clipped yet; no group may exceed the number of available SVD components.
    return [min(k, k_upper_limit) for k in allocated_Ks]


def _build_padded_init(group_svd_cache, allocated_Ks, K_max, experts_per_group, cols, device):
    """Build zero-padded B [G, K_max, cols] and C [N, rows, K_max] from the group SVDs."""
    init_B_padded = torch.zeros(len(allocated_Ks), K_max, cols, device=device)
    init_C_padded_list = []
    for g_idx, (K_i, (U, S, Vt)) in enumerate(zip(allocated_Ks, group_svd_cache)):
        init_B_padded[g_idx, :K_i, :] = Vt[:K_i, :]

        # U_k @ diag(S_k) written as a column scaling (no K x K matrix needed).
        C_padded_group = torch.zeros(U.shape[0], K_max, device=device)
        C_padded_group[:, :K_i] = U[:, :K_i] * S[:K_i]

        # Undo the row stacking: one [rows, K_max] block per expert.
        C_split = torch.chunk(C_padded_group, experts_per_group, dim=0)
        init_C_padded_list.append(torch.stack(C_split, dim=0))
    return init_B_padded, torch.cat(init_C_padded_list, dim=0)


def _snapshot_state(model):
    """Return an independent CPU copy of the model's state dict."""
    # copy=True matters on CPU runs, where .cpu() would alias live parameters.
    return {k: v.detach().to('cpu', copy=True) for k, v in model.state_dict().items()}


def _train_layer(model, optimizer, target, global_target_std, args, device):
    """Fit the model to the std-scaled targets; return (best scaled loss, best state)."""
    best_real_loss = float('inf')
    best_model_state = None
    std_f32 = global_target_std.float()
    # Cover every expert exactly once per epoch, whatever --num_batches says.
    batch_starts = range(0, args.num_matrices, args.batch_size)

    for epoch in range(args.num_epochs):
        epoch_loss = 0.0
        epoch_real_loss = 0.0
        epoch_real_mae = 0.0

        optimizer.zero_grad()
        for start_idx in batch_starts:
            end_idx = min(start_idx + args.batch_size, args.num_matrices)
            indices = torch.arange(start_idx, end_idx, device=device)

            outputs = model(indices)
            # Scale in float32 so the target is not rounded to float16 again.
            batch_target = target[start_idx:end_idx].float()
            loss = F.mse_loss(outputs, batch_target / std_f32)

            with torch.no_grad():
                Z_hat = outputs * std_f32
                real_mse = F.mse_loss(Z_hat, batch_target)
                real_mae = F.l1_loss(Z_hat, batch_target)

            bs = end_idx - start_idx
            # Weight by batch share so the accumulated gradient is that of the
            # mean loss over all experts (equals loss / num_batches when the
            # batches are equal-sized).
            (loss * (bs / args.num_matrices)).backward()

            epoch_loss += loss.item() * bs
            epoch_real_loss += real_mse.item() * bs
            epoch_real_mae += real_mae.item() * bs

        epoch_loss /= args.num_matrices
        epoch_real_loss /= args.num_matrices
        epoch_real_mae /= args.num_matrices

        # The losses above were measured with the current parameters, so the
        # snapshot has to be taken before the optimizer changes them.
        if epoch_loss < best_real_loss:
            best_real_loss = epoch_loss
            best_model_state = _snapshot_state(model)

        # Mask: keep the padded rank slots frozen.
        with torch.no_grad():
            if model.B_fixed.grad is not None:
                model.B_fixed.grad *= model.rank_mask_B
            if model.C_params.grad is not None:
                model.C_params.grad *= model.C_rank_mask

        optimizer.step()

        if (epoch + 1) % 200 == 0:
            print(f"Epoch {epoch+1}, Scaled: {epoch_loss:.9f}, MSE: {epoch_real_loss:.9f}, MAE: {epoch_real_mae:.9f}")

    if best_model_state is None:
        # No epoch ran, or the loss never compared below +inf (e.g. NaN).
        print("Warning: no improving epoch was recorded; saving the current model state.")
        best_model_state = _snapshot_state(model)

    return best_real_loss, best_model_state


def main():
    """Compress the selected expert projection of each layer and save one file per layer.

    For every layer in ``[start_layer, end_layer)``: load the expert weights,
    order the experts by activation count (or effective rank), SVD each group,
    distribute the rank budget across groups, train the Hybrid MoBE model and
    write the best state plus metadata to ``save_path``.

    Raises:
        ValueError: If ``num_matrices`` is not divisible by ``num_B`` or
            ``batch_size`` is not positive.
        KeyError: If expert weights of a layer could not be loaded.
    """
    args = parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    os.makedirs(args.save_path, exist_ok=True)

    # Validate before dividing so a bad --num_B gives this message.
    if args.num_B <= 0 or args.num_matrices % args.num_B != 0:
        raise ValueError("num_matrices must be divisible by num_B (experts must be evenly distributed across groups).")
    if args.batch_size <= 0:
        raise ValueError("batch_size must be a positive integer.")

    num_group = args.num_B
    experts_per_group = args.num_matrices // num_group

    if args.num_batches * args.batch_size != args.num_matrices:
        print(f"Warning: num_batches * batch_size ({args.num_batches} * {args.batch_size}) != "
              f"num_matrices ({args.num_matrices}); iterating over all experts in chunks of batch_size.")

    print(f"Config: Activation-Based Sorting. Target Backbone Ratio {args.target_compression_ratio*100:.1f}%")

    all_activation_counts = _load_activation_counts(args.activation_counts_path)

    for n in range(args.start_layer, args.end_layer):
        print(f"\n========== Processing Layer: {n} ==========")

        raw_target = _load_expert_targets(args, n, device)
        global_target_std = raw_target.std()

        expert_scores = _score_experts(raw_target, all_activation_counts, n, args.num_matrices)
        expert_scores.sort(key=lambda x: x[1], reverse=True)

        print(f"  Top-1 Expert: ID {expert_scores[0][0]} (Score: {expert_scores[0][1]:.2f})")
        print(f"  Bottom-1 Expert: ID {expert_scores[-1][0]} (Score: {expert_scores[-1][1]:.2f})")

        sorted_indices = torch.tensor([x[0] for x in expert_scores], dtype=torch.long, device=device)

        # Experts are regrouped in score order; the permutation is saved so
        # the original expert ids can be recovered.
        target = raw_target[sorted_indices]
        del raw_target

        print("Step 1: Analyzing SVD complexity per group...")
        group_scores, group_svd_cache = _analyze_groups(target, num_group, experts_per_group)

        # ===== Step 1.5: group activation score (normalize) =====
        if all_activation_counts is not None and n < all_activation_counts.shape[0]:
            freq_norm = _group_frequency_shares(
                all_activation_counts[n], sorted_indices, num_group, experts_per_group
            )
        else:
            freq_norm = torch.ones(num_group, dtype=torch.float32) / num_group

        svd_scores = torch.tensor(group_scores, dtype=torch.float32)
        svd_norm = svd_scores / (svd_scores.sum() + 1e-8)

        # Fixed blend of spectral complexity and activation frequency.
        alpha = 0.7
        group_fused_score = alpha * svd_norm + (1 - alpha) * freq_norm

        print(f"  [Group 0] SVD: {svd_norm[0]:.4f}, Freq: {freq_norm[0]:.4f} -> Fused: {group_fused_score[0]:.4f}")
        print(f"  [Group -1] SVD: {svd_norm[-1]:.4f}, Freq: {freq_norm[-1]:.4f} -> Fused: {group_fused_score[-1]:.4f}")

        print(f"  SVD norm:   {svd_norm.tolist()}")
        print(f"  Freq norm:  {freq_norm.tolist()}")
        print(f"  Fused norm: {group_fused_score.tolist()}")

        total_original_params = args.num_matrices * args.rows_per_matrix * args.cols
        target_backbone_params = total_original_params * args.target_compression_ratio

        # One extra rank in a group costs one B row plus one C column per expert.
        params_per_rank_unit = args.cols + experts_per_group * args.rows_per_matrix
        total_K_budget = int(target_backbone_params / params_per_rank_unit)

        # A group's SVD cannot provide more components than this.
        k_upper_limit = min(args.rows_per_matrix * experts_per_group, args.cols)

        allocated_Ks = _allocate_ranks(group_fused_score, num_group, total_K_budget, k_upper_limit)

        K_max = max(allocated_Ks)
        print(f"Step 2: Allocated Ks (Average Redistribute Mode): {allocated_Ks}")
        print(f"        Max K: {K_max}")

        # NOTE(review): the SVD factors come from the unscaled weights while
        # the training target is divided by global_target_std -- confirm that
        # this initial scale is intended.
        init_B_padded, final_C_padded = _build_padded_init(
            group_svd_cache, allocated_Ks, K_max, experts_per_group, args.cols, device
        )
        # The full SVD factors are not needed during training.
        del group_svd_cache

        r = int(args.rows_per_matrix * args.cols * experts_per_group * args.residual_ratio)
        r = max(r, 1)

        model = HybridMoBE_VariableRank_Blending_SingleP(
            initial_B_padded=init_B_padded,
            initial_C_padded=final_C_padded,
            allocated_ks=allocated_Ks,
            experts_per_group=experts_per_group,
            r_dim=r,
            rows=args.rows_per_matrix,
            cols=args.cols,
            num_groups=num_group,
            activation=args.activation
        ).to(device)

        optimizer = torch.optim.Adam([
            {'params': model.C_params, 'lr': args.learning_rate},
            {'params': model.B_fixed, 'lr': args.learning_rate},
            {'params': model.A_params, 'lr': args.learning_rate},
            {'params': model.a_params, 'lr': args.learning_rate * 2},
        ])

        best_real_loss, best_model_state = _train_layer(
            model, optimizer, target, global_target_std, args, device
        )

        print(f"Finished Layer {n}. Best Scaled Loss: {best_real_loss:.8f}")

        state_to_save = best_model_state
        # Stored on CPU like the rest of the state so the file also loads on
        # machines without the training device.
        state_to_save['_global_std'] = global_target_std.detach().cpu()
        state_to_save['_allocated_ks'] = allocated_Ks
        state_to_save['_expert_permutation'] = sorted_indices.cpu()

        state_to_save['_config'] = {
            'K_max': K_max,
            'r': r,
            'activation': args.activation,
            'num_groups': num_group,
            'experts_per_group': experts_per_group,
            'type': 'activation_sorted_variable_rank'
        }

        out_file = f'{args.save_path}/model_layers_{n}_mlp_{args.matrix_type}_Hybrid.pth'
        torch.save(state_to_save, out_file)
        print(f"Saved to {out_file}")

        backbone_params = sum(allocated_Ks) * params_per_rank_unit
        residual_params = num_group * r
        total_compressed = backbone_params + residual_params
        actual_ratio = total_compressed / total_original_params

        print("-" * 40)
        print(f"Final Statistics for Layer {n}:")
        print(f"  Actual Compression Ratio: {actual_ratio:.2%} (Target: {args.target_compression_ratio:.0%})")
        print("-" * 40)

        # Release the per-layer tensors before the next layer is loaded.
        del target, model, optimizer, init_B_padded, final_C_padded
        del best_model_state, state_to_save
        gc.collect()
        torch.cuda.empty_cache()


# Script entry point: run the pipeline only when the file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()