from sklearn.metrics import silhouette_score, calinski_harabasz_score, davies_bouldin_score
import torch
import torch.nn.functional as F
from functools import partial
from tqdm import tqdm
from torch.utils.data import DataLoader
import argparse
import pandas as pd
from accelerate import Accelerator
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans, DBSCAN
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils.common_utils import setup_logging, set_random_seed
import numpy as np


# Module-wide logger obtained from the project's setup_logging() helper; every
# function in this file logs through it.
logger = setup_logging()


def _moe_model_config(model_name, moe_module, num_experts_attr):
    """Build one MODEL_CONFIGS entry for a decoder whose MoE block lives at
    ``model.layers.<i>.<moe_module>`` with ``gate`` and ``experts`` children.

    The returned path strings keep their ``{layer_idx}``, ``{expert_idx}`` and
    ``{param_name}`` placeholders; callers fill them in with ``str.format``.
    """
    # Plain (non f-) string concatenation so the braces survive as placeholders.
    moe_prefix = 'model.layers.{layer_idx}.' + moe_module
    return {
        'model_name': model_name,
        'gating_path': moe_prefix + '.gate',
        'experts_path': moe_prefix + '.experts',
        'expert_param_name_template': moe_prefix + '.experts.{expert_idx}.{param_name}',
        'num_experts_attr': num_experts_attr,
        'num_layers_attr': 'num_hidden_layers',
    }


# Per-architecture module paths and config attribute names. The keys are
# matched as substrings of the lower-cased model name by get_model_config.
MODEL_CONFIGS = {
    'deepseek': _moe_model_config('deepseek-ai/deepseek-moe-16b-base', 'mlp', 'n_routed_experts'),
    'phi': _moe_model_config('microsoft/Phi-tiny-MoE-instruct', 'block_sparse_moe', 'num_local_experts'),
}


def get_model_config(model_name):
    """Return the MODEL_CONFIGS entry whose key occurs in *model_name*.

    Matching is a case-insensitive substring test tried in MODEL_CONFIGS
    order; the first hit wins. If nothing matches, a warning is logged and
    the 'deepseek' entry is returned as the fallback.
    """
    lowered_name = model_name.lower()
    matched = next(
        (cfg for key, cfg in MODEL_CONFIGS.items() if key in lowered_name), None)
    if matched is not None:
        return matched
    # Fallback: default to the deepseek configuration.
    logger.warning(f"未找到模型 {model_name} 的配置，使用默认配置")
    return MODEL_CONFIGS['deepseek']


def get_nested_attr(obj, attr_path):
    """Resolve a dotted attribute path such as ``'model.layers.3.mlp'``.

    Each dot-separated component is looked up with ``getattr`` starting from
    *obj* (numeric components work for ``nn.ModuleList`` children, since
    those are registered under their string index).

    Returns:
        The final attribute, or None as soon as any lookup raises
        AttributeError.
    """
    current = obj
    try:
        for component in attr_path.split('.'):
            current = getattr(current, component)
    except AttributeError:
        return None
    return current


def register_gating_hook(model, layer_index, hook_fn, model_config):
    """Attach *hook_fn* as a forward hook on the gating module of one layer.

    The gate is located by filling ``{layer_idx}`` in
    ``model_config['gating_path']`` and resolving it with get_nested_attr.

    Returns:
        The handle returned by ``register_forward_hook``, or None (after a
        logged warning) when the path does not resolve or any step raises.
    """
    try:
        resolved_path = model_config['gating_path'].format(layer_idx=layer_index)
        gate_module = get_nested_attr(model, resolved_path)
        if gate_module is not None:
            return gate_module.register_forward_hook(hook_fn)
        logger.warning(f"无法找到层 {layer_index} 的门控层，路径: {resolved_path}")
        return None
    except Exception as err:
        logger.warning(f"无法为层 {layer_index} 注册钩子: {err}")
        return None


def get_experts_layer(model, layer_index, model_config):
    """Return the expert container of layer *layer_index*, or None.

    The location is ``model_config['experts_path']`` with ``{layer_idx}``
    filled in, resolved by get_nested_attr (which itself yields None for a
    missing attribute). Any other error is logged as a warning and also
    results in None.
    """
    try:
        resolved_path = model_config['experts_path'].format(layer_idx=layer_index)
        experts = get_nested_attr(model, resolved_path)
    except Exception as err:
        logger.warning(f"无法获取层 {layer_index} 的专家层: {err}")
        return None
    return experts


def top_k_hook(stats_dict, layer_idx, module, input, output, top_k=5):
    """Forward hook on a MoE gating module that accumulates routing statistics.

    Intended to be bound with ``functools.partial(top_k_hook, stats_dict, layer_idx)``
    and registered through ``register_forward_hook``; PyTorch then supplies
    ``module``, ``input`` and ``output`` (the first two are unused).

    Two gate output formats are recognised; anything else is ignored:

    * A 2-D tensor treated as router logits, one row per token (the format
      expected from the Phi gate). Softmax is taken in float32 and the
      ``top_k`` largest probabilities per token are kept.
    * A tuple whose first two items are same-shaped 2-D ``(indices, weights)``
      tensors (DeepSeek-style gate); further tuple items are ignored.

    In both cases only the selected top-k weights are scattered into a dense
    ``[num_tokens, num_experts]`` matrix, so the statistics are computed the
    same way for both architectures.

    ``stats_dict[layer_idx]`` is updated in place:
      activation_counts     -- times each expert was selected
      soft_activations      -- sum of the routing weights each expert received
      sum_of_entropies      -- per-token routing entropy, credited to every
                               expert that token was routed to
      sum_of_probs          -- per-expert sum of the (sparse) routing weights
      sum_of_squared_probs  -- per-expert sum of their squares
      total_tokens          -- number of tokens seen
    The tensors are moved to the device of the gate output.

    Args:
        stats_dict: layer index -> accumulator dict (see above).
        layer_idx: index of the layer this hook is attached to.
        module, input: supplied by PyTorch; unused.
        output: the gating module's forward output.
        top_k: experts kept per token for the 2-D logits format; clamped to
            the number of experts. The default 5 keeps the previously
            hard-coded value (an older comment mentioned top-2 -- TODO
            confirm against the model's ``num_experts_per_tok``).
    """
    with torch.no_grad():
        if isinstance(output, torch.Tensor) and output.dim() == 2:
            # Softmax in float32: with bf16/fp16 weights a half-precision
            # result would reach the float32 scatter_/scatter_add_ below and
            # raise a dtype-mismatch error.
            probs = F.softmax(output.detach().float(), dim=-1)
            # torch.topk raises if k exceeds the number of experts.
            k = min(top_k, probs.shape[-1])
            topk_weights, topk_indices = torch.topk(probs, k, dim=-1)
        elif isinstance(output, tuple) and len(output) >= 2:
            raw_indices, raw_weights = output[0], output[1]
            if (not isinstance(raw_indices, torch.Tensor)
                    or not isinstance(raw_weights, torch.Tensor)
                    or raw_indices.dim() != 2
                    or raw_indices.shape != raw_weights.shape):
                return
            topk_indices = raw_indices.detach().long()
            topk_weights = raw_weights.detach().float()
            k = topk_indices.shape[1]
        else:
            return

        layer_stats = stats_dict[layer_idx]
        num_experts = layer_stats['sum_of_probs'].shape[0]
        num_tokens = topk_indices.shape[0]
        device = topk_weights.device

        # Keep the accumulators on the same device as the gate output.
        for key in ('sum_of_probs', 'sum_of_squared_probs', 'soft_activations',
                    'sum_of_entropies', 'activation_counts'):
            layer_stats[key] = layer_stats[key].to(device)

        # Dense routing matrix holding only the selected weights (zeros elsewhere).
        full_probs = torch.zeros(num_tokens, num_experts, device=device)
        full_probs.scatter_(1, topk_indices, topk_weights)

        layer_stats['sum_of_probs'].add_(full_probs.sum(dim=0))
        layer_stats['sum_of_squared_probs'].add_(full_probs.pow(2).sum(dim=0))
        layer_stats['total_tokens'] += num_tokens

        # Per-token entropy of the sparse routing vector; 1e-9 keeps log() finite on zeros.
        entropies = -torch.sum(full_probs * torch.log(full_probs + 1e-9), dim=-1)

        flat_indices = topk_indices.flatten()
        layer_stats['activation_counts'].scatter_add_(
            0, flat_indices, torch.ones_like(flat_indices))
        layer_stats['soft_activations'].scatter_add_(
            0, flat_indices, topk_weights.flatten())
        # Each token's entropy is added once for every expert it was routed to.
        layer_stats['sum_of_entropies'].scatter_add_(
            0, flat_indices, entropies.repeat_interleave(k))


def calculate_static_metrics(model, model_config):
    """Compute weight-only (data-independent) features for every expert.

    For each layer whose expert container resolves via get_experts_layer,
    and each expert index in ``range(num_experts)``, record the L2 norm of
    all of that expert's parameters taken together.

    Args:
        model: the loaded causal LM.
        model_config: MODEL_CONFIGS entry naming the config attributes for the
            expert/layer counts and the experts module path.

    Returns:
        pandas.DataFrame with columns 'layer', 'expert', 'param_norm'.
        Experts that cannot be indexed (AttributeError/IndexError) or that
        have no parameters get param_norm 0.0.
    """
    logger.info(f"\n{'='*20}\n开始一次性计算所有静态指标...\n{'='*20}")

    num_experts = getattr(model.config, model_config['num_experts_attr'], 0)
    num_layers = getattr(model.config, model_config['num_layers_attr'], 0)

    static_results = []

    for i in tqdm(range(num_layers), desc="计算静态指标"):
        experts_layer = get_experts_layer(model, i, model_config)
        if experts_layer is None:
            # Layers without a resolvable experts path (presumably dense layers).
            continue

        for j in range(num_experts):
            param_norm = 0.0
            try:
                expert = experts_layer[j]
                # ||concat(p_1, ..., p_n)||_2 == sqrt(sum_i ||p_i||_2^2). Reducing
                # each parameter separately in float32:
                #  * keeps full precision (a norm of bf16/fp16 tensors is itself
                #    rounded to bf16/fp16 precision),
                #  * avoids materialising a concatenated copy of the expert and
                #    does not require all parameters to share a device,
                #  * yields 0.0 for a parameterless expert instead of torch.cat
                #    raising on an empty list.
                squared_norm = sum(
                    torch.linalg.vector_norm(p.detach(), dtype=torch.float32).item() ** 2
                    for p in expert.parameters())
                param_norm = squared_norm ** 0.5
            except (AttributeError, IndexError):
                pass

            static_results.append({
                'layer': i,
                'expert': j,
                'param_norm': param_norm
            })

    logger.info("--- 静态指标计算完成 ---")
    return pd.DataFrame(static_results)


def load_and_process_dataset(dataset_name, dataset_config_name, text_column):
    """Stream a small sample of a dataset's 'train' split and join its text.

    Takes the first 256 records (128 for 'codeparrot/codeparrot-clean-valid')
    from a streaming load. For each record, the string values of the requested
    column(s) are joined with spaces. The non-empty results are then joined
    with blank lines.

    Args:
        dataset_name: dataset identifier passed to ``load_dataset``.
        dataset_config_name: dataset configuration name.
        text_column: a column name, or a list of column names; values that are
            missing, empty or not strings are skipped.

    Returns:
        The joined text, or None when loading fails or no text is found.
    """
    num_samples = 256 if dataset_name != 'codeparrot/codeparrot-clean-valid' else 128
    try:
        raw_dataset = load_dataset(
            dataset_name, dataset_config_name, split='train', streaming=True)
        # Materialise only the first num_samples records of the stream. They are
        # iterated directly; an Arrow round-trip via Dataset.from_list is unnecessary.
        sampled_data = list(raw_dataset.take(num_samples))
    except Exception as e:
        logger.error(f"错误：无法加载数据集 {dataset_name}。跳过。错误: {e}", exc_info=True)
        return None
    # Report the number actually obtained; a split may hold fewer records than requested.
    logger.info(f"成功加载并流式采样 {len(sampled_data)} 条样本进行分析。")

    text_columns = text_column if isinstance(text_column, list) else [text_column]

    all_texts = []
    for sample in sampled_data:
        values = (sample.get(col) for col in text_columns)
        combined_text = " ".join(v for v in values if v and isinstance(v, str)).strip()
        if combined_text:
            all_texts.append(combined_text)

    full_text = "\n\n".join(all_texts)
    if not full_text:
        logger.warning(f"警告: 数据集 {dataset_name} 的指定列中未找到有效文本。跳过。")
        return None

    return full_text


def run_analysis_for_dataset(model, tokenizer, dataset_config, max_length, stride, batch_size, model_config):
    """Profile routing behaviour and Fisher information of every expert on one dataset.

    Steps:
      1. Load a text sample (load_and_process_dataset) and tokenize it as one
         sequence. Cut it into windows of ``max_length`` tokens starting every
         ``stride`` tokens; a trailing partial window is dropped.
      2. Register top_k_hook on every layer's gate. Run forward + backward
         (LM loss with labels == inputs) over all windows, summing the squared
         gradients of every parameter whose name contains 'experts'.
      3. Reduce the per-layer routing statistics and the per-expert Fisher sums
         into one feature row per (layer, expert).

    Args:
        model: the loaded causal LM.
        tokenizer: its tokenizer.
        dataset_config: (dataset_name, dataset_config_name, text_column) tuple.
        max_length: window length in tokens.
        stride: distance between consecutive window starts.
        batch_size: windows per forward/backward pass.
        model_config: MODEL_CONFIGS entry for the architecture.

    Returns:
        pandas.DataFrame with columns layer, expert, activation_freq,
        avg_confidence, avg_entropy, specialization_var, fisher_mean,
        fisher_std, fisher_var. Experts without any Fisher sums are omitted.
        Returns None when the dataset could not be loaded.

    Note: the model is left in training mode (``model.train()``).
    """
    dataset_name, dataset_config_name, text_column = dataset_config
    logger.info(
        f"{'='*20}\n开始为数据集 {dataset_name} ({dataset_config_name}) 生成行为特征\n{'='*20}")

    full_text = load_and_process_dataset(
        dataset_name, dataset_config_name, text_column)
    if full_text is None:
        return None

    test_encoding = tokenizer(full_text, return_tensors='pt', truncation=False)
    single_row_dataset = Dataset.from_dict(
        {'input_ids': test_encoding['input_ids']})

    def create_sliding_window_simple(examples):
        # Split the single long token sequence into full max_length windows.
        full_input_ids = examples['input_ids'][0]
        return {'input_ids': [full_input_ids[i:i + max_length]
                              for i in range(0, len(full_input_ids) - max_length + 1, stride)]}

    processed_dataset = single_row_dataset.map(create_sliding_window_simple,
                                               batched=True,
                                               remove_columns=single_row_dataset.column_names)
    processed_dataset.set_format(type='torch', columns=['input_ids'])

    accelerator = Accelerator()
    dataloader = DataLoader(processed_dataset, batch_size=batch_size)
    model, dataloader = accelerator.prepare(model, dataloader)
    unwrapped_model = accelerator.unwrap_model(model)

    num_experts = getattr(unwrapped_model.config,
                          model_config['num_experts_attr'], 0)
    num_layers = getattr(unwrapped_model.config,
                         model_config['num_layers_attr'], 0)

    stats_dict = {
        i: {'activation_counts': torch.zeros(num_experts, dtype=torch.long),
            'soft_activations': torch.zeros(num_experts, dtype=torch.float32),
            'sum_of_entropies': torch.zeros(num_experts, dtype=torch.float32),
            'sum_of_probs': torch.zeros(num_experts, dtype=torch.float32),
            'sum_of_squared_probs': torch.zeros(num_experts, dtype=torch.float32),
            'total_tokens': 0}
        for i in range(num_layers)
    }

    # Fisher accumulators are keyed by the *unwrapped* parameter names, because
    # those are what expert_param_name_template reproduces below. A wrapper
    # created by prepare() (e.g. DDP) would prefix names with 'module.'.
    fisher_params = [(name, param) for name, param in unwrapped_model.named_parameters()
                     if 'experts' in name]
    fisher_sum_sq = {name: torch.zeros_like(param.data) for name, param in fisher_params}

    handles = []
    num_samples = 0
    try:
        for i in range(num_layers):
            handle = register_gating_hook(
                unwrapped_model, i, partial(top_k_hook, stats_dict, i), model_config)
            if handle is not None:
                handles.append(handle)

        logger.info(f"为 {len(handles)} 个MoE层注册了钩子。开始数据收集...")
        model.train()
        total_batches = len(dataloader)
        log_interval = max(1, total_batches // 10)

        for i, batch in enumerate(tqdm(dataloader, desc=f"分析 {dataset_name}")):
            outputs = model(
                input_ids=batch["input_ids"], labels=batch["input_ids"])
            accelerator.backward(outputs.loss)

            # Empirical Fisher: accumulate squared gradients of expert parameters.
            for name, param in fisher_params:
                if param.grad is not None:
                    fisher_sum_sq[name] += param.grad.detach() ** 2

            model.zero_grad()
            num_samples += batch["input_ids"].shape[0] * accelerator.num_processes

            if (i + 1) % log_interval == 0 or (i + 1) == total_batches:
                logger.info(
                    f"数据集 [{dataset_name}] 进度: Batch {i + 1}/{total_batches} ({(i + 1) / total_batches:.1%})")
    finally:
        # Always detach the hooks, even if a batch fails, so they cannot keep
        # mutating stats on later forward passes of the same model.
        for handle in handles:
            handle.remove()

    logger.info("数据收集完成，开始计算最终特征...")
    results = []
    epsilon = 1e-9

    for layer_id, layer_stats in stats_dict.items():
        total_tokens = layer_stats['total_tokens']
        # Layers whose gate never fired (no hook, or no data) are skipped.
        if total_tokens == 0:
            continue

        act_counts = layer_stats['activation_counts'].float().cpu()
        soft_act_total = layer_stats['soft_activations'].cpu()
        sum_entropies = layer_stats['sum_of_entropies'].cpu()
        sum_probs = layer_stats['sum_of_probs'].cpu()
        sum_sq_probs = layer_stats['sum_of_squared_probs'].cpu()

        activation_freq = act_counts / total_tokens
        # epsilon avoids 0/0 for experts that were never selected.
        avg_confidence = soft_act_total / (act_counts + epsilon)
        avg_entropy = sum_entropies / (act_counts + epsilon)
        mean_probs = sum_probs / total_tokens
        mean_sq_probs = sum_sq_probs / total_tokens
        # Var[p] = E[p^2] - E[p]^2 per expert.
        variance = mean_sq_probs - torch.pow(mean_probs, 2)

        experts_layer = get_experts_layer(
            unwrapped_model, layer_id, model_config)
        if experts_layer is None:
            continue

        for j in range(num_experts):
            try:
                expert = experts_layer[j]
                expert_fisher_values = []
                for name, _ in expert.named_parameters():
                    full_name = model_config['expert_param_name_template'].format(
                        layer_idx=layer_id, expert_idx=j, param_name=name)
                    if full_name in fisher_sum_sq:
                        # Move to CPU and upcast before the reduction, so the
                        # statistics are float32 rather than bf16-rounded, and
                        # tensors from different devices can be concatenated.
                        fisher_info = fisher_sum_sq[full_name].flatten().cpu().float() / num_samples
                        expert_fisher_values.append(fisher_info)

                if expert_fisher_values:
                    all_fisher_for_expert = torch.cat(expert_fisher_values)
                    results.append({
                        'layer': layer_id,
                        'expert': j,
                        'activation_freq': activation_freq[j].item(),
                        'avg_confidence': avg_confidence[j].item(),
                        'avg_entropy': avg_entropy[j].item(),
                        'specialization_var': variance[j].item(),
                        'fisher_mean': all_fisher_for_expert.mean().item(),
                        'fisher_std': all_fisher_for_expert.std().item(),
                        'fisher_var': all_fisher_for_expert.var().item()
                    })
            except (AttributeError, IndexError):
                logger.warning(
                    f"无法处理专家 {j} 在层 {layer_id} 的数据，可能是因为该专家不存在或没有参数。")
                continue

    return pd.DataFrame(results)


def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with seed, model_name, max_length, stride,
        batch_size, max_clusters and cluster_metric.
    """
    cli = argparse.ArgumentParser(
        description="Run cross-domain soft activation profiling on MoE experts.")
    # (flag, keyword options), in the order they are listed by --help.
    option_specs = (
        ('--seed', dict(type=int, default=42)),
        ('--model_name', dict(type=str, default='deepseek-ai/deepseek-moe-16b-base',
                              help='模型名称，支持：deepseek-ai/DeepSeek-V2-Lite, microsoft/Phi-tiny-MoE-instruct')),
        ('--max_length', dict(type=int, default=1024)),
        ('--stride', dict(type=int, default=512)),
        ('--batch_size', dict(type=int, default=4)),
        ('--max_clusters', dict(type=int, default=2,
                                help="聚类分析时的最大簇数")),
        ('--cluster_metric', dict(type=str, default='silhouette',
                                  choices=['silhouette', 'calinski_harabasz', 'davies_bouldin'],
                                  help="用于聚类分析的指标")),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def expert_grouping(model_name, max_length=1024, stride=512, batch_size=8, max_cluster=2, cluster_metric='calinski_harabasz'):
    """Profile every MoE expert of *model_name* and cluster the experts of each layer.

    Pipeline:
      1. Load tokenizer and model (``device_map='auto'``, ``torch_dtype='auto'``).
      2. Compute static features (calculate_static_metrics). Compute behavioural
         features with run_analysis_for_dataset for each dataset in
         DATASET_CONFIGS. Dynamic columns are prefixed with the dataset key and
         left-joined on (layer, expert).
      3. Save the merged table to ``results/deepseek/``.
      4. For each layer, standardize the features and run KMeans for K in
         ``2..min(max_cluster, n_experts - 1)``. Pick the K with the best
         ``cluster_metric`` score (highest for silhouette/calinski_harabasz,
         lowest for davies_bouldin), refit with that K and store the labels in
         'cluster_role'. Layers that are not clustered keep -2.
      5. Save the table with cluster labels.

    Args:
        model_name: Hugging Face model id; also selects the MODEL_CONFIGS entry.
        max_length, stride, batch_size: forwarded to run_analysis_for_dataset.
        max_cluster: largest K tried per layer.
        cluster_metric: 'silhouette', 'calinski_harabasz' or 'davies_bouldin'
            (unknown names fall back to silhouette).

    Returns:
        (final_df, layer2group2expert), where layer2group2expert maps
        layer -> {expert: cluster_role}. Returns None when no dataset
        produced behavioural features.
    """
    import gc
    import os

    # Architecture-specific module paths / config attribute names.
    model_config = get_model_config(model_name)
    logger.info(f"Using model config: {model_config}")

    # display name -> (dataset name, config name, text column(s))
    DATASET_CONFIGS = {
        'Wikitext': ('wikitext', 'wikitext-2-raw-v1', 'text'),
        'PIQA': ('piqa', 'plain_text', ['goal', 'sol1', 'sol2']),
        # 'C4': ('c4', 'en', 'text'),
        'PTB': ('ptb_text_only', 'penn_treebank', 'sentence'),
        # 'Code': ('codeparrot/codeparrot-clean-valid', 'default', 'content'),
    }

    logger.info(f"Loading model: {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map='auto', trust_remote_code=True, torch_dtype='auto')

    df_static_features = calculate_static_metrics(model, model_config)

    all_dynamic_dfs = {}
    for name, config in DATASET_CONFIGS.items():
        df_behavioral = run_analysis_for_dataset(
            model, tokenizer, config, max_length, stride, batch_size, model_config)
        # An empty frame has no 'layer'/'expert' columns and would make the merge raise.
        if df_behavioral is not None and not df_behavioral.empty:
            all_dynamic_dfs[name] = df_behavioral

    logger.info("\n--- Merging all dynamic features into a single DataFrame ---")
    if not all_dynamic_dfs:
        logger.error("未能生成任何行为特征，无法合并。程序终止。")
        return None

    final_df = df_static_features.copy()

    for name, df_dyn in all_dynamic_dfs.items():
        # Prefix feature columns with the dataset name; join keys stay unchanged.
        df_dyn_renamed = df_dyn.rename(columns=lambda c: f"{name}_{c}" if c not in [
                                       'layer', 'expert'] else c)
        final_df = pd.merge(final_df, df_dyn_renamed, on=[
                            'layer', 'expert'], how='left')

    # Model-specific output file name (same location as before).
    model_type = 'deepseek' if 'deepseek' in model_name.lower() else 'phi'
    output_filename = f"results/deepseek/expert_features_combined_{model_type}.csv"
    # Create the directory actually written to; creating only 'results' made
    # to_csv fail when 'results/deepseek' did not exist yet.
    os.makedirs(os.path.dirname(output_filename), exist_ok=True)
    final_df.to_csv(output_filename, index=False)

    logger.info(f"\n终极版特征表已保存到: {output_filename}")
    logger.info("部分结果展示:")
    logger.info(final_df.head().to_string())

    # Clustering; -2 marks layers that end up not clustered.
    final_df['cluster_role'] = -2

    logger.info(f"\n--- 开始对最终画像进行聚类 ---")

    cluster_metrics = {
        'silhouette': silhouette_score,
        'calinski_harabasz': calinski_harabasz_score,
        'davies_bouldin': davies_bouldin_score
    }
    metric_func = cluster_metrics.get(cluster_metric, silhouette_score)
    # Davies-Bouldin: lower is better; silhouette and Calinski-Harabasz: higher is better.
    pick_best = min if metric_func is davies_bouldin_score else max

    for layer_id in sorted(final_df['layer'].unique()):
        logger.info(f"\n{'='*15} 正在分析 Layer {layer_id} {'='*15}")
        layer_df = final_df[final_df['layer'] == layer_id].copy()

        feature_columns = [col for col in layer_df.columns if col not in [
            'layer', 'expert', 'cluster_role']]
        features = layer_df[feature_columns].values

        scaler = StandardScaler()
        scaled_features = scaler.fit_transform(features)
        # Missing values (e.g. experts absent from a dataset's table after the
        # left join) stay NaN through StandardScaler and would make KMeans
        # raise; map them to 0, i.e. the column mean in standardized units.
        scaled_features = np.nan_to_num(scaled_features, nan=0.0)

        num_experts_in_layer = len(layer_df)
        max_k = min(max_cluster, num_experts_in_layer - 1)
        k_range = range(2, max_k + 1)

        scores = {}
        logger.info(f"Layer {layer_id}: 正在为K值范围 {list(k_range)} 测试系数...")

        for k in k_range:
            kmeans = KMeans(n_clusters=k, random_state=42,
                            n_init='auto', init='k-means++', max_iter=300)
            labels = kmeans.fit_predict(scaled_features)

            if len(np.unique(labels)) > 1:
                scores[k] = metric_func(scaled_features, labels)
                logger.info(
                    f"  - 对于 K={k}, 系数: {scores[k]:.4f}")
            else:
                logger.warning(f"  - 对于 K={k}, 只形成了一个簇，无法计算系数。")

        if not scores:
            logger.error(f"Layer {layer_id}: 无法为任何K值计算有效的系数。跳过此层。")
            continue

        best_k = pick_best(scores, key=scores.get)
        best_score = scores[best_k]
        logger.info(
            f"--- Layer {layer_id}: 找到最佳分组数 K = {best_k}, 系数: {best_score:.4f} ---")

        # Refit with the selected K (previously hard-coded to 1 cluster).
        final_kmeans = KMeans(n_clusters=best_k, random_state=42,
                              n_init='auto', init='k-means++', max_iter=300)
        final_labels = final_kmeans.fit_predict(scaled_features)

        # layer_df rows are in the same order as the masked final_df rows.
        final_df.loc[final_df['layer'] == layer_id,
                     'cluster_role'] = final_labels

        layer_df['cluster_role'] = final_labels
        cluster_centers = layer_df.groupby('cluster_role')[
            feature_columns].mean()
        logger.info(
            f"\nLayer {layer_id} 的最终角色划分结果 (平均特征值):\n{cluster_centers.to_string()}")
        logger.info("\n每个角色的专家成员列表:")
        for i in range(best_k):
            members = layer_df[layer_df['cluster_role']
                               == i]['expert'].tolist()
            logger.info(f"  角色 {i} (专家数量: {len(members)}): {sorted(members)}")

    logger.info("\n分析脚本运行结束。")
    cluster_output_filename = f"results/deepseek/expert_features_with_auto_clusters_{model_type}.csv"
    final_df.to_csv(cluster_output_filename, index=False)
    logger.info(f"\n带有自动分组角色的终极版特征表已保存到: {cluster_output_filename}")

    # layer -> {expert index: cluster_role}
    layer2group2expert = {
        layer: final_df.loc[final_df['layer'] == layer].set_index('expert')['cluster_role'].to_dict()
        for layer in final_df['layer'].unique()
    }

    # Drop the model before returning. Collect garbage first so freed CUDA
    # tensors are returned to the caching allocator, then release the cache.
    del model, tokenizer, df_static_features, all_dynamic_dfs
    gc.collect()
    torch.cuda.empty_cache()

    return final_df, layer2group2expert


def main():
    """Command-line entry point.

    Parses the CLI arguments, calls set_random_seed(args.seed), then runs
    expert_grouping with the parsed settings. (Previously ``main`` was called
    here but never defined, so running the script raised NameError.)
    """
    args = parse_args()
    # NOTE(review): set_random_seed comes from utils.common_utils and is
    # assumed to take the integer seed -- confirm its signature.
    set_random_seed(args.seed)
    expert_grouping(
        args.model_name,
        max_length=args.max_length,
        stride=args.stride,
        batch_size=args.batch_size,
        max_cluster=args.max_clusters,
        cluster_metric=args.cluster_metric,
    )


if __name__ == "__main__":
    main()
