from argparse import Namespace
from collections.abc import Mapping
import logging

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeForCausalLM
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeForCausalLM
from transformers.models.deepseek_v2.modeling_deepseek_v2 import DeepseekV2ForCausalLM

from model import (
    PrunableMixtralSparseMoeBlockWrapper,
    PrunableQwen2MoeSparseMoeBlockWrapper,
    PrunableQwen3MoeSparseMoeBlockWrapper,
    PrunableDeepSeekV2MoeWrapper
)


# Module-level logger used for per-layer expert-count and pruning progress messages.
logger = logging.getLogger(__name__)


def _move_to_device(data, device: torch.device):
    """Recursively move every tensor inside ``data`` to ``device``.

    Tensors are copied with ``non_blocking=True``. Mappings (including
    non-``dict`` mappings such as ``UserDict`` subclasses) are rebuilt as plain
    ``dict`` objects with their values moved. Lists, tuples and namedtuples are
    rebuilt with the same type. Any other object is returned unchanged.

    Args:
        data: a tensor, or an arbitrarily nested container of tensors.
        device: destination device.

    Returns:
        A structure mirroring ``data`` whose tensors live on ``device``.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device, non_blocking=True)
    if isinstance(data, Mapping):
        # Checking for ``Mapping`` rather than ``dict`` also catches
        # ``UserDict``-based batches, which would otherwise be left on the
        # original device.
        return {key: _move_to_device(value, device) for key, value in data.items()}
    if isinstance(data, tuple) and hasattr(data, '_fields'):
        # A namedtuple constructor takes its fields positionally, not as a
        # single iterable, so the generator must be unpacked.
        return type(data)(*(_move_to_device(x, device) for x in data))
    if isinstance(data, (list, tuple)):
        return type(data)(_move_to_device(x, device) for x in data)
    return data


def layerwise_pruning_mixtral(model: MixtralForCausalLM, calib_loader: DataLoader, args: Namespace, layer_experts=None):
    """Prune experts in every Mixtral MoE layer from one calibration pass.

    Steps:
      1. Wrap each layer's ``block_sparse_moe`` in
         ``PrunableMixtralSparseMoeBlockWrapper`` targeting ``r`` experts,
         with ``cache_X`` and ``cache_Z`` enabled.
      2. Forward the whole calibration set once so the wrappers can cache.
      3. For every wrapper that has a ``cache_space`` attribute, call
         ``enumerate()`` (its result is kept as that layer's loss history),
         then ``prune()``.
      4. Unwrap every layer. Record the smallest requested expert count in
         ``model.num_experts`` and ``model.config.num_local_experts``.

    Args:
        model: the Mixtral model; modified in place.
        calib_loader: calibration batches. Each batch must be a mapping,
            because it is unpacked with ``**`` into
            ``prepare_inputs_for_generation``.
        args: must provide ``r``, the expert count for layers not covered by
            ``layer_experts``.
        layer_experts: optional sequence of per-layer expert counts, indexed
            by layer position. Layers past its end fall back to ``args.r``.

    Returns:
        ``(model, (global_loss_history, final_layer_experts))``.
        ``global_loss_history`` maps layer index to the wrapper's
        ``enumerate()`` result. ``final_layer_experts`` maps layer index to
        the expert count requested for that layer.

    Raises:
        AssertionError: if ``model`` is not a ``MixtralForCausalLM``.
    """
    assert isinstance(
        model, MixtralForCausalLM), 'Currently only `Mixtral` is supported'

    default_r = args.r
    final_layer_experts = {}

    for l, layer in enumerate(model.model.layers):
        if layer_experts is not None and l < len(layer_experts):
            r = layer_experts[l]
            logger.info(f"Layer {l}: Using {r} experts from config")
        else:
            r = default_r
            logger.info(f"Layer {l}: Using default {r} experts")

        final_layer_experts[l] = r

        layer.block_sparse_moe = PrunableMixtralSparseMoeBlockWrapper(
            layer.block_sparse_moe, r=r)
        # Cache both block inputs (X) and outputs (Z) during calibration.
        layer.block_sparse_moe.cache_X = True
        layer.block_sparse_moe.cache_Z = True

    # The embedding table is not moved by forwarding, so look its device up once.
    embed_device = model.get_input_embeddings().weight.device
    with torch.inference_mode():
        for batch in tqdm(calib_loader, desc='Model forwarding on sample set...'):
            batch = _move_to_device(batch, embed_device)
            model_inputs = model.prepare_inputs_for_generation(**batch)
            outputs = model(**model_inputs)
            assert outputs is not None

    torch.cuda.empty_cache()

    # Layers 0-15 go to cuda:0 and the rest to cuda:1. This is presumably a
    # two-GPU split of Mixtral's 32 layers (TODO confirm). With a single GPU
    # everything goes to cuda:0; on CPU-only hosts the block is left where it
    # is, instead of failing on a device that does not exist.
    n_gpus = torch.cuda.device_count()

    global_loss_history = dict()
    for l, layer in tqdm(list(enumerate(model.model.layers)), desc='Enumerating loss on sample set...'):
        b = layer.block_sparse_moe
        if not hasattr(b, 'cache_space'):
            continue
        if n_gpus >= 2:
            b.to('cuda:0' if l < 16 else 'cuda:1')
        elif n_gpus == 1:
            b.to('cuda:0')
        loss_history = b.enumerate()
        global_loss_history[l] = loss_history
        b.prune()

    logger.info('Merging & saving...')
    for layer in model.model.layers:
        # Unwrap: put back the inner block held by the wrapper.
        layer.block_sparse_moe = layer.block_sparse_moe.model

    # The config holds a single expert count. With per-layer counts, the
    # smallest requested one is what gets recorded.
    if final_layer_experts:
        min_experts = min(final_layer_experts.values())
        model.num_experts = min_experts
        model.config.num_local_experts = min_experts

    return model, (global_loss_history, final_layer_experts)

def layerwise_pruning_qwen(model: Qwen2MoeForCausalLM, calib_loader: DataLoader, args: Namespace, layer_experts=None):
    """Prune experts in every Qwen2-MoE / Qwen3-MoE layer from one calibration pass.

    Each layer's ``mlp`` is wrapped in the matching prunable wrapper
    (``PrunableQwen2MoeSparseMoeBlockWrapper`` or
    ``PrunableQwen3MoeSparseMoeBlockWrapper``) with ``cache_X`` and
    ``cache_Z`` enabled. The calibration set is forwarded once. Then every
    wrapper that has a ``cache_space`` attribute runs ``enumerate()`` and
    ``prune()``, and all layers are unwrapped.

    Args:
        model: a ``Qwen2MoeForCausalLM`` or ``Qwen3MoeForCausalLM``;
            modified in place.
        calib_loader: calibration batches. Each batch must be a mapping,
            because it is unpacked with ``**`` into
            ``prepare_inputs_for_generation``.
        args: must provide ``r``, the expert count for layers not covered by
            ``layer_experts``.
        layer_experts: optional sequence of per-layer expert counts, indexed
            by layer position. Layers past its end fall back to ``args.r``.

    Returns:
        ``(model, (global_loss_history, final_layer_experts))``, both keyed
        by layer index.

    Raises:
        AssertionError: if ``model`` is neither supported Qwen MoE class.
    """
    assert isinstance(
        model, (Qwen2MoeForCausalLM, Qwen3MoeForCausalLM)), 'Currently only `Qwen2,3MoE` is supported'

    # The model type is fixed for the whole run, so pick the wrapper once.
    if isinstance(model, Qwen2MoeForCausalLM):
        wrapper_cls = PrunableQwen2MoeSparseMoeBlockWrapper
    else:
        wrapper_cls = PrunableQwen3MoeSparseMoeBlockWrapper

    default_r = args.r
    final_layer_experts = {}

    for l, layer in enumerate(model.model.layers):
        if layer_experts is not None and l < len(layer_experts):
            r = layer_experts[l]
            logger.info(f"Layer {l}: Using {r} experts from config")
        else:
            r = default_r
            logger.info(f"Layer {l}: Using default {r} experts")

        final_layer_experts[l] = r
        layer.mlp = wrapper_cls(layer.mlp, r=r)
        # Cache both block inputs (X) and outputs (Z) during calibration.
        layer.mlp.cache_X = True
        layer.mlp.cache_Z = True

    # The embedding table is not moved by forwarding, so look its device up once.
    embed_device = model.get_input_embeddings().weight.device
    with torch.inference_mode():
        for batch in tqdm(calib_loader, desc='Model forwarding on sample set...'):
            batch = _move_to_device(batch, embed_device)
            model_inputs = model.prepare_inputs_for_generation(**batch)
            outputs = model(**model_inputs)
            assert outputs is not None

    torch.cuda.empty_cache()

    global_loss_history = dict()
    for l, layer in tqdm(list(enumerate(model.model.layers)), desc='Enumerating loss on sample set...'):
        b = layer.mlp
        if not hasattr(b, 'cache_space'):
            continue
        loss_history = b.enumerate()
        global_loss_history[l] = loss_history
        b.prune()
        logger.info(f"Expert pruned: {b.experts_to_drop}, loss: {loss_history[b.experts_to_drop]}")

    logger.info('Merging & saving...')
    for layer in model.model.layers:
        # Unwrap: put back the inner block held by the wrapper.
        layer.mlp = layer.mlp.model

    # The config holds a single expert count. With per-layer counts, the
    # smallest requested one is what gets recorded.
    if final_layer_experts:
        min_experts = min(final_layer_experts.values())
        model.num_experts = min_experts
        # Qwen2-MoE / Qwen3-MoE configs store the expert count in
        # `num_experts`, not Mixtral's `num_local_experts`. Update it so a
        # saved config matches the pruned weights. `num_local_experts` is
        # still set for any downstream code that reads it.
        model.config.num_experts = min_experts
        model.config.num_local_experts = min_experts

    return model, (global_loss_history, final_layer_experts)

def layerwise_pruning_deepseek(model: DeepseekV2ForCausalLM, calib_loader: DataLoader, args: Namespace, layer_experts=None):
    """Prune routed experts in every DeepSeek-V2 MoE layer from one calibration pass.

    Layer 0 is skipped everywhere; every other layer's ``mlp`` is wrapped in
    ``PrunableDeepSeekV2MoeWrapper``. The calibration set is forwarded once.
    Then every wrapper that has a ``cache_space`` attribute runs
    ``enumerate()`` and ``prune()``, and the layers are unwrapped.

    Args:
        model: the DeepSeek-V2 model; modified in place.
        calib_loader: calibration batches. Each batch must be a mapping,
            because it is unpacked with ``**`` into
            ``prepare_inputs_for_generation``.
        args: accepted for interface parity. Unlike the other entry points,
            ``args.r`` is NOT consulted: layers not covered by
            ``layer_experts`` keep ``config.n_routed_experts`` experts.
        layer_experts: optional sequence of expert counts indexed by MoE layer
            (position 0 corresponds to model layer 1). Entries past its end
            fall back to the default.

    Returns:
        ``(model, (global_loss_history, final_layer_experts))``.
        Note the differing keys: ``global_loss_history`` is keyed by the
        absolute layer index ``l``, while ``final_layer_experts`` is keyed
        by ``l - 1``.

    Raises:
        AssertionError: if ``model`` is not a ``DeepseekV2ForCausalLM``.
    """
    assert isinstance(
        model, (DeepseekV2ForCausalLM)), 'Currently only `DeepSeek-V2` is supported'

    default_r = model.config.n_routed_experts
    final_layer_experts = {}

    for l, layer in enumerate(model.model.layers):
        # Layer 0 is never wrapped. Presumably it is DeepSeek-V2's dense (non-MoE)
        # layer (config.first_k_dense_replace == 1); TODO confirm.
        if l == 0:
            continue
        moe_idx = l - 1
        # Bounds-check like the sibling entry points instead of raising IndexError.
        if layer_experts is not None and moe_idx < len(layer_experts):
            r = layer_experts[moe_idx]
            logger.info(f"Layer {l}: Using {r} experts from config")
        else:
            r = default_r
            logger.info(f"Layer {l}: Using default {r} experts")

        # Clamp the request to the number of experts the gate currently routes to.
        try:
            num_experts_current = layer.mlp.gate.weight.shape[0]
        except (AttributeError, IndexError):
            num_experts_current = getattr(model.config, 'n_routed_experts', r)
        if r is None:
            r = num_experts_current
        if r < 0:
            logger.warning(f"Layer {l}: r < 0 ({r}); clamping to 0")
            r = 0
        if r > num_experts_current:
            logger.warning(f"Layer {l}: r ({r}) > num_experts ({num_experts_current}); clamping to {num_experts_current}")
            r = num_experts_current

        final_layer_experts[moe_idx] = r
        layer.mlp = PrunableDeepSeekV2MoeWrapper(layer.mlp, r=r)
        # Cache both block inputs (X) and outputs (Z) during calibration.
        layer.mlp.cache_X = True
        layer.mlp.cache_Z = True

    # The embedding table is not moved by forwarding, so look its device up once.
    embed_device = model.get_input_embeddings().weight.device
    with torch.inference_mode():
        for batch in tqdm(calib_loader, desc='Model forwarding on sample set...'):
            batch = _move_to_device(batch, embed_device)
            model_inputs = model.prepare_inputs_for_generation(**batch)
            outputs = model(**model_inputs)
            assert outputs is not None

    torch.cuda.empty_cache()

    global_loss_history = dict()
    for l, layer in tqdm(list(enumerate(model.model.layers)), desc='Enumerating loss on sample set...'):
        if l == 0:
            continue
        b = layer.mlp
        if not hasattr(b, 'cache_space'):
            continue
        loss_history = b.enumerate()
        global_loss_history[l] = loss_history
        b.prune()
        logger.info(f"Expert pruned: {b.experts_to_drop}, loss: {loss_history[b.experts_to_drop]}")

    logger.info('Merging & saving...')
    for l, layer in enumerate(model.model.layers):
        if l == 0:
            continue
        # Unwrap: put back the inner block held by the wrapper.
        layer.mlp = layer.mlp.model

    # The config holds a single expert count. With per-layer counts, the
    # smallest one is recorded. The guard avoids min() on an empty dict when
    # the model has no layer past index 0.
    if final_layer_experts:
        min_experts = min(final_layer_experts.values())
        model.n_routed_experts = min_experts
        model.config.n_routed_experts = min_experts

    return model, (global_loss_history, final_layer_experts)

def layerwise_pruning_search(model: MixtralForCausalLM, calib_loader: DataLoader, args: Namespace, layer_experts=None):
    """Prune experts of a single Mixtral layer, ``args.layer_index``.

    Only that layer's ``block_sparse_moe`` is wrapped in
    ``PrunableMixtralSparseMoeBlockWrapper``, with ``cache_X`` and ``cache_Z``
    enabled. The calibration set is forwarded once. If the wrapper has a
    ``cache_space`` attribute, it runs ``enumerate()`` and ``prune()``, and
    the layer is then unwrapped. The model config is not modified.

    Args:
        model: the Mixtral model; modified in place.
        calib_loader: calibration batches. Each batch must be a mapping,
            because it is unpacked with ``**`` into
            ``prepare_inputs_for_generation``.
        args: must provide ``layer_index`` (the layer to prune) and ``r``
            (expert count used when ``layer_experts`` does not cover it).
        layer_experts: optional sequence of per-layer expert counts, indexed
            by layer position.

    Returns:
        ``(model, (global_loss_history, final_layer_experts))``, each holding
        at most the single entry for ``args.layer_index``.

    Raises:
        AssertionError: if ``model`` is not a ``MixtralForCausalLM``.
    """
    assert isinstance(
        model, MixtralForCausalLM), 'Currently only `Mixtral` is supported'

    layer_idx = args.layer_index
    if layer_experts is not None and layer_idx < len(layer_experts):
        r = layer_experts[layer_idx]
        logger.info(f"Layer {layer_idx}: Using {r} experts from config")
    else:
        r = args.r
        logger.info(f"Layer {layer_idx}: Using default {r} experts")

    layer = model.model.layers[layer_idx]
    layer.block_sparse_moe = PrunableMixtralSparseMoeBlockWrapper(
        layer.block_sparse_moe, r=r)
    # Cache both block inputs (X) and outputs (Z) during calibration.
    layer.block_sparse_moe.cache_X = True
    layer.block_sparse_moe.cache_Z = True

    # The embedding table is not moved by forwarding, so look its device up once.
    embed_device = model.get_input_embeddings().weight.device
    with torch.inference_mode():
        for batch in tqdm(calib_loader, desc='Model forwarding on sample set...'):
            batch = _move_to_device(batch, embed_device)
            model_inputs = model.prepare_inputs_for_generation(**batch)
            outputs = model(**model_inputs)
            assert outputs is not None

    torch.cuda.empty_cache()

    global_loss_history = dict()
    b = layer.block_sparse_moe
    if hasattr(b, 'cache_space'):
        # Layers 0-15 go to cuda:0 and the rest to cuda:1. This is presumably
        # a two-GPU split of Mixtral's layers (TODO confirm). With a single GPU
        # the block goes to cuda:0; on CPU-only hosts it is left where it is,
        # instead of failing on a device that does not exist.
        n_gpus = torch.cuda.device_count()
        if n_gpus >= 2:
            b.to('cuda:0' if layer_idx < 16 else 'cuda:1')
        elif n_gpus == 1:
            b.to('cuda:0')
        loss_history = b.enumerate()
        global_loss_history[layer_idx] = loss_history
        b.prune()

    logger.info('Merging & saving...')
    # Unwrap: put back the inner block held by the wrapper.
    layer.block_sparse_moe = layer.block_sparse_moe.model

    final_layer_experts = {layer_idx: r}

    return model, (global_loss_history, final_layer_experts)

def progressive_pruning(model: MixtralForCausalLM, calib_loader: DataLoader, args: Namespace, layer_experts=None):
    """Prune Mixtral MoE layers one at a time, re-running calibration per layer.

    First, every layer is wrapped with only ``cache_Z`` enabled, and the
    calibration set is forwarded once to cache block outputs. Then, for each
    layer in order:
      - enable ``cache_X`` on that layer's wrapper;
      - forward the calibration set again, so earlier layers are already
        pruned at this point;
      - call ``enumerate()`` and ``prune()``, then unwrap the layer.

    Args:
        model: the Mixtral model; modified in place.
        calib_loader: calibration batches. Each batch must be a mapping,
            because it is unpacked with ``**`` into
            ``prepare_inputs_for_generation``.
        args: must provide ``r``, the expert count for layers not covered by
            ``layer_experts``.
        layer_experts: optional sequence of per-layer expert counts, indexed
            by layer position. Layers past its end fall back to ``args.r``.

    Returns:
        ``(model, (global_loss_history, final_layer_experts))``, both keyed
        by layer index.

    Raises:
        AssertionError: if ``model`` is not a ``MixtralForCausalLM``.
    """
    assert isinstance(
        model, MixtralForCausalLM), 'Currently only `Mixtral` is supported'

    default_r = args.r
    final_layer_experts = {}

    for l, layer in enumerate(model.model.layers):
        if layer_experts is not None and l < len(layer_experts):
            r = layer_experts[l]
            logger.info(f"Layer {l}: Using {r} experts from config")
        else:
            r = default_r
            logger.info(f"Layer {l}: Using default {r} experts")

        final_layer_experts[l] = r

        layer.block_sparse_moe = PrunableMixtralSparseMoeBlockWrapper(
            layer.block_sparse_moe, r=r)
        layer.block_sparse_moe.cache_Z = True

    # The embedding table is not moved by forwarding, so look its device up once.
    embed_device = model.get_input_embeddings().weight.device
    # Pre-bind so the references can be released even when the loader is empty
    # (a bare `del` would raise UnboundLocalError in that case).
    model_inputs = outputs = None

    with torch.inference_mode():
        for batch in tqdm(calib_loader, desc='Computing Z activations on sample set...'):
            batch = _move_to_device(batch, embed_device)
            model_inputs = model.prepare_inputs_for_generation(**batch)
            outputs = model(**model_inputs)
            assert outputs is not None

    # Drop the last batch's references before releasing cached GPU memory.
    model_inputs = outputs = None
    torch.cuda.empty_cache()

    for layer in model.model.layers:
        layer.block_sparse_moe.cache_Z = False

    global_loss_history = dict()

    for l, layer in tqdm(list(enumerate(model.model.layers)), desc='Dropping layers...'):
        b = layer.block_sparse_moe

        # Only the current layer caches inputs during this pass.
        b.cache_X = True
        with torch.inference_mode():
            for batch in calib_loader:
                batch = _move_to_device(batch, embed_device)
                model_inputs = model.prepare_inputs_for_generation(**batch)
                outputs = model(**model_inputs)
                assert outputs is not None

        model_inputs = outputs = None
        torch.cuda.empty_cache()
        b.cache_X = False

        loss_history = b.enumerate()
        global_loss_history[l] = loss_history

        b.prune()
        # Unwrap immediately, so later passes run through this pruned layer.
        layer.block_sparse_moe = b.model

    # The config holds a single expert count. With per-layer counts, the
    # smallest requested one is what gets recorded.
    if final_layer_experts:
        min_experts = min(final_layer_experts.values())
        model.num_experts = min_experts
        model.config.num_local_experts = min_experts

    return model, (global_loss_history, final_layer_experts)
