import torch
import torch.nn as nn
import logging
from tqdm import tqdm
import numpy as np
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM

# Module-level logger used for progress messages and per-layer pruning reports.
logger = logging.getLogger(__name__)

def frequency_pruning(model, calib_loader, args, layer_experts=None):
    """Prune every MoE layer down to its most frequently selected experts.

    A forward hook on each layer's ``block_sparse_moe.gate`` applies softmax to
    the gate output, takes the top-k experts per row
    (k = ``model.config.num_experts_per_tok``) and adds them to a per-layer
    histogram. For layer ``i`` the ``layer_experts[i]`` most-selected experts
    are kept (ties broken by lower expert index). The others are removed from
    ``experts``, and their rows are dropped from the gate weight.

    Args:
        model: Mixtral-style model whose ``model.model.layers[i].block_sparse_moe``
            has ``gate`` (treated as bias-free: only its weight is carried over)
            and ``experts``. Modified in place.
        calib_loader: Iterable of dict batches; each is fed as ``model(**batch)``.
        args: Namespace; ``args.r`` is the per-layer budget used when
            ``layer_experts`` is None.
        layer_experts: Optional list with the number of experts to keep per layer.

    Returns:
        ``(model, info)`` where ``info`` holds ``pruning_method``,
        ``layer_experts`` and the per-layer ``expert_counts``.

    Raises:
        ValueError: If neither ``layer_experts`` nor ``args.r`` is provided.
    """
    logger.info("Starting frequency-based expert pruning")

    model.eval()

    num_layers = model.config.num_hidden_layers
    num_experts = model.config.num_local_experts
    top_k = model.config.num_experts_per_tok

    if layer_experts is None:
        if getattr(args, "r", None) is None:
            raise ValueError("Neither layer_experts nor args.r is specified")
        layer_experts = [args.r] * num_layers

    # Per-layer histograms of top-k selections. They are accumulated online,
    # so memory stays O(num_layers * num_experts) instead of growing with the
    # number of calibration tokens.
    expert_counts = [torch.zeros(num_experts, dtype=torch.long) for _ in range(num_layers)]
    # Set once a layer's gate hook has fired; unobserved layers are not pruned.
    layer_seen = [False] * num_layers

    def collect_expert_counts(layer_idx):
        def hook(module, input, output):
            # Softmax before top-k keeps the selection identical to ranking the
            # probabilities, including ties introduced by low precision.
            router_probs = torch.nn.functional.softmax(output, dim=-1)
            _, expert_indices = torch.topk(router_probs, top_k, dim=-1)
            batch_counts = torch.bincount(expert_indices.reshape(-1), minlength=num_experts)
            expert_counts[layer_idx] += batch_counts.cpu()
            layer_seen[layer_idx] = True
        return hook

    use_cuda = torch.cuda.is_available()

    def to_device(batch):
        # Tensors go to the default CUDA device when one exists (as before);
        # otherwise they stay where they are. Non-tensor values pass through.
        return {k: (v.cuda() if use_cuda and torch.is_tensor(v) else v) for k, v in batch.items()}

    hooks = [
        model.model.layers[i].block_sparse_moe.gate.register_forward_hook(collect_expert_counts(i))
        for i in range(num_layers)
    ]

    logger.info("Collecting expert activation frequencies on calibration data")
    try:
        with torch.no_grad():
            for batch in tqdm(calib_loader):
                model(**to_device(batch))
    finally:
        # Detach the hooks even if a forward pass fails, so the model is not
        # left instrumented.
        for hook in hooks:
            hook.remove()

    for i in range(num_layers):
        if not layer_seen[i]:
            logger.warning(f"Layer {i}: no router activity observed on calibration data; leaving it unpruned")
            continue

        r = layer_experts[i]

        # stable=True: experts with equal counts (e.g. never selected) are
        # ordered by index, so the kept set is deterministic.
        _, indices = torch.sort(expert_counts[i], descending=True, stable=True)
        experts_to_keep = indices[:r].tolist()
        experts_to_prune = indices[r:].tolist()

        logger.info(f"Layer {i} expert activation counts: {expert_counts[i].tolist()}")
        logger.info(f"Layer {i} keeping experts {experts_to_keep} with counts {[expert_counts[i][j].item() for j in experts_to_keep]}")
        logger.info(f"Layer {i} pruning experts {experts_to_prune} with counts {[expert_counts[i][j].item() for j in experts_to_prune]}")

        moe_layer = model.model.layers[i].block_sparse_moe

        moe_layer.experts = nn.ModuleList([moe_layer.experts[j] for j in experts_to_keep])
        moe_layer.num_experts = len(moe_layer.experts)

        layer_top_k = getattr(moe_layer, "top_k", None)
        if layer_top_k is not None and moe_layer.num_experts < layer_top_k:
            logger.warning(
                f"Layer {i}: keeping {moe_layer.num_experts} experts but top_k={layer_top_k}; "
                "routing may fail unless top_k is reduced"
            )

        # Rebuild the router with only the kept rows, preserving dtype and device.
        old_router_weights = moe_layer.gate.weight.data
        new_gate = nn.Linear(
            old_router_weights.size(1),
            len(experts_to_keep),
            bias=False,
            dtype=old_router_weights.dtype,
            device=old_router_weights.device,
        )
        with torch.no_grad():
            new_gate.weight.copy_(old_router_weights[experts_to_keep, :])
        moe_layer.gate = new_gate

    info = {
        "pruning_method": "frequency",
        "layer_experts": layer_experts,
        "expert_counts": [counts.tolist() for counts in expert_counts]
    }

    logger.info("Frequency-based expert pruning completed successfully")
    return model, info


def router_logits_pruning(model, calib_loader, args, layer_experts=None):
    """Prune MoE experts by a blend of router confidence and output magnitude.

    One calibration pass records, per layer and expert:
      * the summed softmax probability the gate assigned whenever the expert
        was in the top-k (k = ``model.config.num_experts_per_tok``), and
      * the summed L2 norm (over the last dim) of the expert module's output.
    Both are min-max normalized within the layer and combined as
    ``alpha * router + (1 - alpha) * norm``. The ``layer_experts[i]``
    highest-scoring experts are kept (ties broken by lower expert index), and
    the gate weight is reduced to the matching rows.

    Args:
        model: Mixtral-style model whose ``model.model.layers[i].block_sparse_moe``
            has ``gate`` (treated as bias-free: only its weight is carried over)
            and ``experts``. Modified in place.
        calib_loader: Iterable of dict batches; each is fed as ``model(**batch)``.
        args: Namespace; ``args.r`` is the per-layer budget used when
            ``layer_experts`` is None. The optional ``args.router_weight`` sets
            alpha; it defaults to 0.5 when absent or None.
        layer_experts: Optional list with the number of experts to keep per layer.

    Returns:
        ``(model, info)`` where ``info`` holds ``pruning_method``,
        ``layer_experts``, raw ``router_scores``, raw ``expert_output_norms``
        and ``alpha``.

    Raises:
        ValueError: If neither ``layer_experts`` nor ``args.r`` is provided.
    """
    logger.info("Starting router_logits-based expert pruning")

    model.eval()

    num_layers = model.config.num_hidden_layers
    num_experts = model.config.num_local_experts
    top_k = model.config.num_experts_per_tok

    if layer_experts is None:
        if getattr(args, "r", None) is None:
            raise ValueError("Neither layer_experts nor args.r is specified")
        layer_experts = [args.r] * num_layers

    router_scores = [torch.zeros(num_experts, dtype=torch.float32) for _ in range(num_layers)]
    expert_output_norms = [torch.zeros(num_experts, dtype=torch.float32) for _ in range(num_layers)]

    def collect_router_scores(layer_idx):
        def hook(module, input, output):
            probs = torch.nn.functional.softmax(output, dim=-1).to(torch.float32)
            topk_vals, topk_indices = torch.topk(probs, top_k, dim=-1)
            # Scatter-add each selected expert's probability into its slot.
            router_scores[layer_idx].index_add_(
                0,
                topk_indices.reshape(-1).detach().cpu(),
                topk_vals.reshape(-1).detach().cpu(),
            )
        return hook

    def collect_expert_output_norm(layer_idx, expert_idx):
        def hook(module, input, output):
            out = output[0] if isinstance(output, tuple) else output
            norm_sum = torch.linalg.vector_norm(out.to(torch.float32), dim=-1).sum().detach().cpu()
            expert_output_norms[layer_idx][expert_idx] += norm_sum
        return hook

    use_cuda = torch.cuda.is_available()

    def to_device(batch):
        # Tensors go to the default CUDA device when one exists (as before);
        # otherwise they stay where they are. Non-tensor values pass through.
        return {k: (v.cuda() if use_cuda and torch.is_tensor(v) else v) for k, v in batch.items()}

    hooks = []
    for i in range(num_layers):
        moe_layer = model.model.layers[i].block_sparse_moe
        hooks.append(moe_layer.gate.register_forward_hook(collect_router_scores(i)))
        for j in range(num_experts):
            hooks.append(moe_layer.experts[j].register_forward_hook(collect_expert_output_norm(i, j)))

    logger.info("Collecting router logits and expert output norms on calibration data")
    try:
        with torch.no_grad():
            for batch in tqdm(calib_loader):
                model(**to_device(batch))
    finally:
        # Detach the hooks even if a forward pass fails, so the model is not
        # left instrumented.
        for h in hooks:
            h.remove()

    # An explicit None is treated like a missing attribute.
    alpha = getattr(args, "router_weight", None)
    alpha = 0.5 if alpha is None else float(alpha)

    def minmax_norm(x: torch.Tensor) -> torch.Tensor:
        # Rescale to [0, 1]; constant or non-finite ranges map to all zeros.
        x_min = x.min()
        x_max = x.max()
        denom = x_max - x_min
        if torch.isfinite(denom) and denom > 0:
            return (x - x_min) / denom
        return torch.zeros_like(x)

    for i in range(num_layers):
        r = layer_experts[i]
        r_scores = router_scores[i]
        o_scores = expert_output_norms[i]
        combined = alpha * minmax_norm(r_scores) + (1.0 - alpha) * minmax_norm(o_scores)

        # stable=True: equal combined scores (e.g. all zeros when a signal is
        # constant) are ordered by expert index, so the kept set is deterministic.
        _, indices = torch.sort(combined, descending=True, stable=True)
        experts_to_keep = indices[:r].tolist()
        experts_to_prune = indices[r:].tolist()

        logger.info(f"Layer {i} router_scores: {r_scores.tolist()}")
        logger.info(f"Layer {i} output_norms: {o_scores.tolist()}")
        logger.info(f"Layer {i} combined_scores: {combined.tolist()}")
        logger.info(f"Layer {i} keeping experts {experts_to_keep}")
        logger.info(f"Layer {i} pruning experts {experts_to_prune}")

        moe_layer = model.model.layers[i].block_sparse_moe
        moe_layer.experts = nn.ModuleList([moe_layer.experts[j] for j in experts_to_keep])
        moe_layer.num_experts = len(moe_layer.experts)

        layer_top_k = getattr(moe_layer, "top_k", None)
        if layer_top_k is not None and moe_layer.num_experts < layer_top_k:
            logger.warning(
                f"Layer {i}: keeping {moe_layer.num_experts} experts but top_k={layer_top_k}; "
                "routing may fail unless top_k is reduced"
            )

        # Rebuild the router with only the kept rows, preserving dtype and device.
        old_router_weights = moe_layer.gate.weight.data
        new_gate = nn.Linear(
            old_router_weights.size(1),
            len(experts_to_keep),
            bias=False,
            dtype=old_router_weights.dtype,
            device=old_router_weights.device,
        )
        with torch.no_grad():
            new_gate.weight.copy_(old_router_weights[experts_to_keep, :])
        moe_layer.gate = new_gate

    info = {
        "pruning_method": "router_logits",
        "layer_experts": layer_experts,
        "router_scores": [scores.tolist() for scores in router_scores],
        "expert_output_norms": [scores.tolist() for scores in expert_output_norms],
        "alpha": float(alpha),
    }

    logger.info("Router logits-based expert pruning completed successfully")
    return model, info

def outlier_pruning(model, calib_loader, args, layer_experts=None):
    """Prune MoE experts, keeping those that most often emit large-norm outputs.

    Two calibration passes over ``calib_loader``:
      1. Hooks on every expert accumulate the L2 norms (over the last dim) of
         all expert outputs in a layer. The layer threshold is
         ``outlier_gamma * mean_norm``, where ``outlier_gamma`` comes from
         ``args.outlier_gamma`` and defaults to 3.0 when absent or None.
      2. Hooks count, per expert, how many output vectors exceed that threshold.
    In each layer, the ``layer_experts[i]`` experts with the highest outlier
    ratio (outliers / vectors produced) are kept, with ties broken by lower
    expert index. The gate weight is reduced to the matching rows.

    Because the data is iterated twice, a one-shot iterator (e.g. a generator)
    is first materialized into a list.

    Args:
        model: Mixtral-style model whose ``model.model.layers[i].block_sparse_moe``
            has ``gate`` (treated as bias-free: only its weight is carried over)
            and ``experts``. Modified in place.
        calib_loader: Iterable of dict batches; each is fed as ``model(**batch)``.
        args: Namespace; ``args.r`` is the per-layer budget used when
            ``layer_experts`` is None. ``args.outlier_gamma`` is optional.
        layer_experts: Optional list with the number of experts to keep per layer.

    Returns:
        ``(model, info)`` where ``info`` holds ``pruning_method``,
        ``layer_experts``, ``layer_thresholds``, ``layer_norm_stds``,
        ``expert_outlier_counts``, ``expert_total_counts`` and ``outlier_gamma``.

    Raises:
        ValueError: If neither ``layer_experts`` nor ``args.r`` is provided.
    """
    import collections.abc

    logger.info("Starting outlier-based expert pruning")

    model.eval()

    num_layers = model.config.num_hidden_layers
    num_experts = model.config.num_local_experts

    if layer_experts is None:
        if getattr(args, "r", None) is None:
            raise ValueError("Neither layer_experts nor args.r is specified")
        layer_experts = [args.r] * num_layers

    # An explicit None is treated like a missing attribute.
    gamma = getattr(args, "outlier_gamma", None)
    gamma = 3.0 if gamma is None else float(gamma)

    if isinstance(calib_loader, collections.abc.Iterator):
        # Two passes are required. A one-shot iterator would be exhausted by
        # pass 1 and leave pass 2 with no data, making every ratio zero.
        logger.warning("calib_loader is a one-shot iterator; materializing it for the two calibration passes")
        calib_loader = list(calib_loader)

    use_cuda = torch.cuda.is_available()

    def run_calibration(hooks):
        # Run the calibration data through the model once, then detach `hooks`
        # (also on failure, so the model is not left instrumented).
        try:
            with torch.no_grad():
                for batch in tqdm(calib_loader):
                    inputs = {k: (v.cuda() if use_cuda and torch.is_tensor(v) else v) for k, v in batch.items()}
                    model(**inputs)
        finally:
            for h in hooks:
                h.remove()

    # ---- Pass 1: per-layer first and second moments of expert output norms.
    layer_norm_sum = [0.0] * num_layers
    layer_norm_sq_sum = [0.0] * num_layers
    layer_token_counts = [0] * num_layers

    def collect_layer_norm_stats(layer_idx):
        def hook(module, input, output):
            out = output[0] if isinstance(output, tuple) else output
            norms = torch.linalg.vector_norm(out.to(torch.float32), dim=-1)
            layer_norm_sum[layer_idx] += norms.sum().item()
            layer_norm_sq_sum[layer_idx] += norms.square().sum().item()
            layer_token_counts[layer_idx] += int(norms.numel())
        return hook

    stats_hooks = [
        model.model.layers[i].block_sparse_moe.experts[j].register_forward_hook(collect_layer_norm_stats(i))
        for i in range(num_layers)
        for j in range(num_experts)
    ]
    logger.info("Pass 1/2: Collecting per-layer norm stats on calibration data")
    run_calibration(stats_hooks)

    layer_thresholds = []
    layer_norm_stds = []
    for i in range(num_layers):
        count = max(layer_token_counts[i], 1)
        mean = layer_norm_sum[i] / count
        mean_sq = layer_norm_sq_sum[i] / count
        # Clamp guards against small negative variance from float cancellation.
        std = max(mean_sq - mean * mean, 0.0) ** 0.5
        # NOTE(review): the threshold scales the mean (gamma * mean). Since the
        # second moment is collected, a mean + gamma * std rule may have been
        # intended; std is reported in `info` so this can be checked.
        th = gamma * mean
        layer_thresholds.append(th)
        layer_norm_stds.append(std)
        logger.info(f"Layer {i} norm stats: count={count}, mean={mean:.6f}, threshold={th:.6f}")

    # ---- Pass 2: per-expert counts of output vectors above the layer threshold.
    expert_outlier_counts = [torch.zeros(num_experts, dtype=torch.long) for _ in range(num_layers)]
    expert_total_counts = [torch.zeros(num_experts, dtype=torch.long) for _ in range(num_layers)]

    def collect_expert_outliers(layer_idx, expert_idx):
        threshold = layer_thresholds[layer_idx]
        def hook(module, input, output):
            out = output[0] if isinstance(output, tuple) else output
            norms = torch.linalg.vector_norm(out.to(torch.float32), dim=-1)
            expert_outlier_counts[layer_idx][expert_idx] += int((norms > threshold).sum().item())
            expert_total_counts[layer_idx][expert_idx] += int(norms.numel())
        return hook

    outlier_hooks = [
        model.model.layers[i].block_sparse_moe.experts[j].register_forward_hook(collect_expert_outliers(i, j))
        for i in range(num_layers)
        for j in range(num_experts)
    ]
    logger.info("Pass 2/2: Collecting per-expert outlier ratios on calibration data")
    run_calibration(outlier_hooks)

    for i in range(num_layers):
        r = layer_experts[i]
        totals = expert_total_counts[i].to(torch.float32)
        counts = expert_outlier_counts[i].to(torch.float32)
        # Experts that saw no tokens also have zero outliers, so dividing by
        # max(total, 1) yields ratio 0 for them.
        ratios = counts / totals.clamp(min=1.0)

        # stable=True: equal ratios (e.g. unused experts) are ordered by index,
        # so the kept set is deterministic.
        _, indices = torch.sort(ratios, descending=True, stable=True)
        experts_to_keep = indices[:r].tolist()
        experts_to_prune = indices[r:].tolist()

        logger.info(f"Layer {i} outlier_counts: {counts.tolist()}")
        logger.info(f"Layer {i} total_counts: {totals.tolist()}")
        logger.info(f"Layer {i} outlier_ratios: {ratios.tolist()}")
        logger.info(f"Layer {i} keeping experts {experts_to_keep}")
        logger.info(f"Layer {i} pruning experts {experts_to_prune}")

        moe_layer = model.model.layers[i].block_sparse_moe
        moe_layer.experts = nn.ModuleList([moe_layer.experts[j] for j in experts_to_keep])
        moe_layer.num_experts = len(moe_layer.experts)

        layer_top_k = getattr(moe_layer, "top_k", None)
        if layer_top_k is not None and moe_layer.num_experts < layer_top_k:
            logger.warning(
                f"Layer {i}: keeping {moe_layer.num_experts} experts but top_k={layer_top_k}; "
                "routing may fail unless top_k is reduced"
            )

        # Rebuild the router with only the kept rows, preserving dtype and device.
        old_router_weights = moe_layer.gate.weight.data
        new_gate = nn.Linear(
            old_router_weights.size(1),
            len(experts_to_keep),
            bias=False,
            dtype=old_router_weights.dtype,
            device=old_router_weights.device,
        )
        with torch.no_grad():
            new_gate.weight.copy_(old_router_weights[experts_to_keep, :])
        moe_layer.gate = new_gate

    info = {
        "pruning_method": "outlier",
        "layer_experts": layer_experts,
        "layer_thresholds": [float(t) for t in layer_thresholds],
        "layer_norm_stds": [float(s) for s in layer_norm_stds],
        "expert_outlier_counts": [cnt.tolist() for cnt in expert_outlier_counts],
        "expert_total_counts": [tot.tolist() for tot in expert_total_counts],
        "outlier_gamma": gamma,
    }

    logger.info("Outlier-based expert pruning completed successfully")
    return model, info
