import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestRegressor
from sklearn.feature_selection import mutual_info_regression
import os
import sys
import warnings

def auto_determine_feature_weights_pca(feature_df, method='pca'):
    """
    Derive one weight per expert-feature column using the chosen strategy.

    Despite the ``_pca`` suffix, this function dispatches to every supported
    strategy, not only PCA.

    Args:
        feature_df: DataFrame of per-expert features. Every column other than
            'layer', 'expert', 'group_id' and 'cluster_role' is treated as a
            feature; no dtype filtering is applied here.
        method: One of 'pca', 'variance', 'mutual_info', 'correlation'.

    Returns:
        dict: Mapping of feature column name to its weight.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    print(f"\n=== 使用{method}方法自动确定特征权重 ===")

    # Bookkeeping columns that identify a row rather than describe it.
    non_feature_columns = ('layer', 'expert', 'group_id', 'cluster_role')
    feature_columns = [
        name for name in feature_df.columns if name not in non_feature_columns
    ]

    # Missing values become 0 for the strategies that take the cleaned matrix.
    filled = feature_df[feature_columns].fillna(0)

    # Handlers are lambdas so that nothing runs until the method is validated.
    handlers = {
        'pca': lambda: _pca_weights(filled, feature_columns),
        'variance': lambda: _variance_weights(filled, feature_columns),
        # The mutual-information strategy receives the raw frame and fills
        # missing values itself.
        'mutual_info': lambda: _mutual_info_weights(feature_df, feature_columns),
        'correlation': lambda: _correlation_weights(filled, feature_columns),
    }

    try:
        handler = handlers[method]
    except (KeyError, TypeError):
        raise ValueError(f"未知方法: {method}") from None
    return handler()


def _pca_weights(feature_data, feature_columns):
    """Weight features by their absolute loading on the first principal component.

    Args:
        feature_data: Numeric feature matrix (one column per feature) without
            missing values.
        feature_columns: Feature names, in the same order as the columns of
            ``feature_data``.

    Returns:
        dict: Feature name -> weight; the weights sum to 1.
    """
    # Standardise first so that no feature dominates purely through its scale.
    standardized = StandardScaler().fit_transform(feature_data)

    pca_model = PCA()
    pca_model.fit(standardized)

    # Only the magnitude of the loading matters, not its sign.
    leading_loadings = np.abs(pca_model.components_[0])
    normalized = leading_loadings / leading_loadings.sum()

    weight_dict = dict(zip(feature_columns, normalized))

    print("PCA权重 (基于第一主成分):")
    for feature in sorted(weight_dict, key=weight_dict.get, reverse=True):
        print(f"  {feature}: {weight_dict[feature]:.4f}")

    return weight_dict


def _variance_weights(feature_data, feature_columns):
    """Weight features in proportion to their variance after standardisation.

    NOTE(review): StandardScaler rescales every non-constant column to unit
    variance, so the variances computed below are 1 for such columns and 0 for
    constant ones. The resulting weights are therefore uniform over the
    non-constant columns. Confirm whether the raw (unscaled) variance was the
    intended signal.

    Args:
        feature_data: Numeric feature matrix (one column per feature) without
            missing values.
        feature_columns: Feature names, in the same order as the columns of
            ``feature_data``.

    Returns:
        dict: Feature name -> weight. The weights sum to 1; if the total
        variance is zero or not finite, every feature gets the same weight
        instead of NaN.
    """
    normalized_data = StandardScaler().fit_transform(feature_data)

    # Population variance of each standardised column.
    variances = np.var(normalized_data, axis=0)
    total_variance = np.sum(variances)

    if np.isfinite(total_variance) and total_variance > 0:
        # Larger variance -> larger weight.
        weights = variances / total_variance
    else:
        # All columns constant (0/0 would give NaN): fall back to uniform.
        weights = np.full(variances.shape, 1.0 / max(len(variances), 1))

    weight_dict = dict(zip(feature_columns, weights))

    print("方差权重:")
    for feature, weight in sorted(weight_dict.items(), key=lambda x: x[1], reverse=True):
        print(f"  {feature}: {weight:.4f}")

    return weight_dict


def _mutual_info_weights(feature_df, feature_columns):
    """Weight features by mutual information with a synthetic target.

    The target is the element-wise mean of up to three indicators taken from
    ``feature_df``: the row mean of all columns whose name contains
    'activation_freq', the row mean of all columns whose name contains
    'confidence', and the 'param_norm' column. When none of these exist the
    function falls back to ``_pca_weights``.

    Args:
        feature_df: Raw per-expert feature DataFrame (may contain NaN).
        feature_columns: Names of the feature columns to weight.

    Returns:
        dict: Feature name -> weight. The weights sum to 1; if all mutual
        information scores are zero, every feature gets the same weight
        instead of NaN.
    """
    target_indicators = []

    # Activation-frequency columns.
    activation_cols = [col for col in feature_columns if 'activation_freq' in col]
    if activation_cols:
        target_indicators.append(feature_df[activation_cols].mean(axis=1))

    # Confidence columns.
    confidence_cols = [col for col in feature_columns if 'confidence' in col]
    if confidence_cols:
        target_indicators.append(feature_df[confidence_cols].mean(axis=1))

    # Parameter norm.
    if 'param_norm' in feature_columns:
        target_indicators.append(feature_df['param_norm'])

    if not target_indicators:
        # No usable target: fall back to the first-principal-component weights.
        return _pca_weights(feature_df[feature_columns].fillna(0), feature_columns)

    target = np.mean(target_indicators, axis=0)
    # The indicators come from the raw frame, so a missing value can survive
    # as NaN here even though the features below are zero-filled. Treat it the
    # same way (0) rather than handing NaN to the estimator. Only NaN is
    # touched; infinities are left alone.
    target = np.where(np.isnan(target), 0.0, target)

    feature_data = feature_df[feature_columns].fillna(0)
    mi_scores = mutual_info_regression(feature_data, target, random_state=42)

    total_mi = np.sum(mi_scores)
    if np.isfinite(total_mi) and total_mi > 0:
        weights = mi_scores / total_mi
    else:
        # Every score is zero (0/0 would give NaN): fall back to uniform.
        weights = np.full(len(mi_scores), 1.0 / max(len(mi_scores), 1))

    weight_dict = dict(zip(feature_columns, weights))

    print("互信息权重:")
    for feature, weight in sorted(weight_dict.items(), key=lambda x: x[1], reverse=True):
        print(f"  {feature}: {weight:.4f}")

    return weight_dict


def _correlation_weights(feature_data, feature_columns):
    """Weight features by how weakly they correlate with the feature set.

    For each feature the mean absolute Pearson correlation over its row of the
    correlation matrix is computed (the self-correlation is included; after
    normalisation this gives the same weights as excluding it). A feature's
    weight is proportional to ``1 - mean``: less correlated means more weight.

    Degenerate inputs are handled explicitly instead of yielding NaN:

    * Undefined correlations (NaN, e.g. for a constant column) are ignored
      when averaging, and a feature with no defined correlation at all gets
      weight 0.
    * If no feature carries any uniqueness (a single feature, or all features
      perfectly correlated) the weights are uniform.

    Args:
        feature_data: Numeric feature DataFrame without missing values.
        feature_columns: Feature names, in the same order as the columns of
            ``feature_data``.

    Returns:
        dict: Feature name -> weight; the weights sum to 1.
    """
    corr_matrix = np.abs(feature_data.corr().values)

    # Average each row over the defined entries only.
    valid = ~np.isnan(corr_matrix)
    valid_counts = valid.sum(axis=1)
    corr_sums = np.where(valid, corr_matrix, 0.0).sum(axis=1)

    # Rows without any defined entry keep an average of 1 -> uniqueness 0.
    avg_correlations = np.ones(len(corr_sums), dtype=float)
    np.divide(corr_sums, valid_counts, out=avg_correlations, where=valid_counts > 0)

    # Clip guards against a correlation marginally above 1 from rounding.
    uniqueness = np.clip(1.0 - avg_correlations, 0.0, None)
    total_uniqueness = uniqueness.sum()

    if total_uniqueness > 1e-12:
        weights = uniqueness / total_uniqueness
    elif len(uniqueness):
        weights = np.full(len(uniqueness), 1.0 / len(uniqueness))
    else:
        weights = uniqueness

    weight_dict = dict(zip(feature_columns, weights))

    print("相关性权重 (低相关性=高权重):")
    for feature, weight in sorted(weight_dict.items(), key=lambda x: x[1], reverse=True):
        print(f"  {feature}: {weight:.4f}")

    return weight_dict


def compute_composite_features(layer_df, datasets=None):
    """Add cross-dataset aggregates and derived indicators to ``layer_df``.

    For every base feature that has at least one ``"{dataset}_{feature}"``
    column, two columns are added: ``avg_{feature}`` (row mean) and
    ``std_{feature}`` (row standard deviation, pandas default ``ddof=1``, so
    it is NaN when only one dataset column is present).

    Two derived columns are added when their inputs exist:

    * ``specialization_strength`` = avg_activation_freq * (1 - avg_avg_entropy)
    * ``fisher_stability`` = avg_fisher_mean / (avg_fisher_std + 1e-8),
      i.e. mean over std (the inverse of a coefficient of variation).

    Args:
        layer_df: Per-expert feature DataFrame. It is modified in place.
        datasets: Iterable of dataset name prefixes to aggregate over.
            Defaults to ('Wikitext', 'PIQA', 'PTB').

    Returns:
        The same ``layer_df`` object, with the new columns added.
    """
    if datasets is None:
        datasets = ('Wikitext', 'PIQA', 'PTB')

    base_features = (
        'activation_freq', 'avg_confidence', 'avg_entropy',
        'specialization_var', 'fisher_mean', 'fisher_std', 'fisher_var',
    )

    # Cross-dataset mean / std of every base feature that is present.
    for feature in base_features:
        dataset_cols = [
            f"{dataset}_{feature}" for dataset in datasets
            if f"{dataset}_{feature}" in layer_df.columns
        ]
        if dataset_cols:
            layer_df[f'avg_{feature}'] = layer_df[dataset_cols].mean(axis=1)
            layer_df[f'std_{feature}'] = layer_df[dataset_cols].std(axis=1)

    # Specialisation strength: activation frequency * (1 - mean entropy).
    # 'avg_avg_entropy' is the cross-dataset mean of the 'avg_entropy' feature.
    if 'avg_activation_freq' in layer_df.columns and 'avg_avg_entropy' in layer_df.columns:
        layer_df['specialization_strength'] = (
            layer_df['avg_activation_freq'] * (1 - layer_df['avg_avg_entropy'])
        )

    # Stability indicator; the epsilon avoids division by zero.
    if 'avg_fisher_mean' in layer_df.columns and 'avg_fisher_std' in layer_df.columns:
        layer_df['fisher_stability'] = (
            layer_df['avg_fisher_mean'] / (layer_df['avg_fisher_std'] + 1e-8)
        )

    return layer_df


def aggregate_experts_by_group(layer_df, expert_to_group):
    """Aggregate per-expert features into per-group statistics.

    A 'group_id' column is written into ``layer_df`` (in place) by mapping the
    'expert' column through ``expert_to_group``. Experts without a mapping get
    a missing group id and are ignored.

    Args:
        layer_df: Per-expert feature DataFrame with an 'expert' column.
        expert_to_group: Mapping of expert id to group id.

    Returns:
        dict: ``{int(group_id): {"<feature>_<stat>": value, ...,
        "group_size": n}}``. Only features present in ``layer_df`` are
        aggregated. 'std' is the pandas default (sample) standard deviation,
        so it is NaN for a group with a single expert.
    """
    layer_df['group_id'] = layer_df['expert'].map(expert_to_group)

    # Which statistics to compute for which feature.
    stat_plan = {
        'param_norm': ['mean', 'std', 'sum'],
        'avg_activation_freq': ['mean', 'std'],
        'avg_avg_confidence': ['mean', 'std'],
        'avg_avg_entropy': ['mean', 'std'],
        'avg_specialization_var': ['mean'],
        'avg_fisher_mean': ['mean', 'std'],
        'avg_fisher_std': ['mean'],
        'avg_fisher_var': ['mean'],
        'specialization_strength': ['mean', 'std'],
        'fisher_stability': ['mean', 'std'],
    }

    # Flatten the plan once, keeping only the features this frame has.
    active_stats = [
        (column, stat)
        for column, stat_names in stat_plan.items()
        if column in layer_df.columns
        for stat in stat_names
    ]

    per_group = {}
    for gid in layer_df['group_id'].unique():
        if pd.isna(gid):
            continue

        members = layer_df[layer_df['group_id'] == gid]

        # The stat name is also the name of the pandas Series method.
        stats = {
            f'{column}_{stat}': getattr(members[column], stat)()
            for column, stat in active_stats
        }
        stats['group_size'] = len(members)

        per_group[int(gid)] = stats

    return per_group


def calculate_group_scores(group_features, importance_weights):
    """Compute one importance score per group from aggregated group features.

    Args:
        group_features: ``{group_id: {stat_name: value}}`` as produced by
            ``aggregate_experts_by_group``.
        importance_weights: ``None`` (use the default group weights), a dict
            of expert-level weights (detected by the presence of a
            'param_norm', 'activation_freq' or 'avg_confidence' key and
            converted to group-level weights), or a dict of group-level
            weights keyed by the feature names used below.

    Returns:
        dict: ``{group_id: score}``; empty when ``group_features`` is empty.
        The score is the weighted sum of the standardised feature values.
    """
    if not group_features:
        return {}

    # Resolve the weights without printing anything here.
    expert_level_keys = ('param_norm', 'activation_freq', 'avg_confidence')
    if importance_weights is None:
        group_weights = get_default_group_weights()
    elif isinstance(importance_weights, dict) and any(
            key in importance_weights for key in expert_level_keys):
        group_weights = convert_expert_weights_to_group_weights(
            importance_weights)
    else:
        group_weights = importance_weights

    # (score feature name, how to read it from a group's statistics).
    # NOTE(review): the two "stability" features are 1 / (std + 1e-8); a std
    # of exactly 0 yields 1e8, while a NaN std (single-expert group) ends up
    # as 0 after the finiteness clean-up below.
    feature_extractors = [
        ('param_norm_sum', lambda f: f.get('param_norm_sum', 0)),
        ('avg_activation_freq_mean', lambda f: f.get('avg_activation_freq_mean', 0)),
        ('avg_avg_confidence_mean', lambda f: f.get('avg_avg_confidence_mean', 0)),
        ('specialization', lambda f: 1 - f.get('avg_avg_entropy_mean', 1)),
        ('avg_specialization_var_mean', lambda f: f.get(
            'avg_specialization_var_mean', 0)),
        ('avg_fisher_mean_mean', lambda f: f.get('avg_fisher_mean_mean', 0)),
        ('specialization_strength_mean', lambda f: f.get(
            'specialization_strength_mean', 0)),
        ('fisher_stability_mean', lambda f: f.get('fisher_stability_mean', 0)),
        ('activation_freq_stability', lambda f: 1 /
         (f.get('avg_activation_freq_std', 0) + 1e-8)),
        ('confidence_stability', lambda f: 1 /
         (f.get('avg_avg_confidence_std', 0) + 1e-8)),
        ('group_size', lambda f: np.log(f.get('group_size', 1) + 1))
    ]
    feature_names = [name for name, _ in feature_extractors]

    group_ids = []
    feature_matrix = []
    for group_id, features in group_features.items():
        row = [extractor(features) for _, extractor in feature_extractors]
        # NaN / inf (e.g. the std of a single-expert group) count as 0.
        feature_matrix.append([val if np.isfinite(val) else 0 for val in row])
        group_ids.append(group_id)

    feature_matrix = np.array(feature_matrix)

    # Standardise each feature; if the scaler fails, fall back to min-max.
    try:
        normalized_features = StandardScaler().fit_transform(feature_matrix)
    except Exception:
        # Not a bare ``except`` so that KeyboardInterrupt / SystemExit pass.
        feature_min = np.min(feature_matrix, axis=0)
        feature_range = np.max(feature_matrix, axis=0) - feature_min
        feature_range[feature_range == 0] = 1
        normalized_features = (feature_matrix - feature_min) / feature_range

    # Features without an explicit weight get an equal share.
    weights = np.array(
        [group_weights.get(name, 1.0 / len(feature_names))
         for name in feature_names],
        dtype=float,
    )
    total_weight = np.sum(weights)
    if np.isfinite(total_weight) and total_weight != 0:
        weights = weights / total_weight
    else:
        # All-zero (or invalid) weights would make every score NaN.
        weights = np.full(len(feature_names), 1.0 / len(feature_names))

    importance_scores = np.dot(normalized_features, weights)

    return dict(zip(group_ids, importance_scores))


def convert_expert_weights_to_group_weights(expert_weights):
    """
    Convert expert-level feature weights into group-level feature weights.

    Each group-level feature takes the weight of its expert-level source
    feature. The lookup for a source feature ``name`` is:

    1. the exact key ``name`` if present;
    2. otherwise the sum of all string keys ending in ``"_" + name``. This
       covers dataset-prefixed columns such as ``"Wikitext_activation_freq"``,
       which is how automatically derived weights are keyed;
    3. otherwise a built-in default.

    Group-only features (specialisation strength, the two stability features,
    group size) always use fixed constants.

    Args:
        expert_weights: dict of expert-level feature name -> weight.

    Returns:
        dict: Group-level feature name -> weight, normalised to sum to 1.
    """
    print("转换专家级权重为组级权重...")

    def _expert_weight(name, default):
        """Return the weight for ``name``, pooling dataset-prefixed keys."""
        if name in expert_weights:
            return expert_weights[name]
        suffix = f'_{name}'
        pooled = [
            value for key, value in expert_weights.items()
            if isinstance(key, str) and key.endswith(suffix)
        ]
        return sum(pooled) if pooled else default

    # Expert feature -> aggregated group feature.
    group_weights = {
        'param_norm_sum': _expert_weight('param_norm', 0.15),
        'avg_activation_freq_mean': _expert_weight('activation_freq', 0.20),
        'avg_avg_confidence_mean': _expert_weight('avg_confidence', 0.15),
        # The score feature itself is computed as 1 - entropy.
        'specialization': _expert_weight('avg_entropy', 0.10),
        'avg_specialization_var_mean': _expert_weight('specialization_var', 0.15),
        'avg_fisher_mean_mean': _expert_weight('fisher_mean', 0.15),
        'specialization_strength_mean': 0.08,
        'fisher_stability_mean': _expert_weight('fisher_std', 0.05),
        'activation_freq_stability': 0.03,
        'confidence_stability': 0.03,
        'group_size': 0.04,
    }

    # Normalise so the weights sum to 1.
    total = sum(group_weights.values())
    normalized_weights = {k: v/total for k, v in group_weights.items()}

    print("转换后的组级权重:")
    for feature, weight in sorted(normalized_weights.items(), key=lambda x: x[1], reverse=True):
        print(f"  {feature}: {weight:.4f}")

    return normalized_weights


def get_default_group_weights():
    """
    Return the built-in group-level feature weights.

    A new dict is created on every call, so callers may modify the result.
    """
    defaults = dict(
        param_norm_sum=0.18,
        avg_activation_freq_mean=0.18,
        avg_avg_confidence_mean=0.12,
        specialization=0.15,
        avg_specialization_var_mean=0.10,
        avg_fisher_mean_mean=0.12,
        specialization_strength_mean=0.08,
        fisher_stability_mean=0.05,
        activation_freq_stability=0.01,
        confidence_stability=0.01,
        group_size=0.00,
    )
    return defaults


# Main weight-computation pipeline (updated)
def calculate_group_importance_scores(feature_df, layer2group2expert, importance_weights=None, use_auto_weights=True, auto_method='pca'):
    """
    Compute an importance score for every expert group of every layer.

    Weight resolution order: explicitly passed ``importance_weights`` first,
    then automatically derived weights (when ``use_auto_weights`` is true),
    otherwise the defaults chosen by ``calculate_group_scores``.

    Args:
        feature_df: Per-expert feature DataFrame with a 'layer' column.
        layer2group2expert: ``{layer_id: {expert_id: group_id}}`` mapping.
        importance_weights: Optional feature-weight dict (highest priority).
        use_auto_weights: Whether to derive weights automatically when no
            explicit weights are given.
        auto_method: Strategy name passed to the automatic weight derivation
            ('pca', 'variance', ...).

    Returns:
        dict: ``{layer_id: {group_id: importance_score}}``. Layers without
        rows in ``feature_df`` are omitted.
    """
    final_weights = importance_weights
    if final_weights is None and use_auto_weights:
        print(f"使用{auto_method}自动权重...")
        auto_weights = auto_determine_feature_weights_pca(
            feature_df, method=auto_method)
        final_weights = convert_expert_weights_to_group_weights(auto_weights)

    layer_group_importance = {}

    for layer_id, expert_to_group in layer2group2expert.items():
        print(f"\nHandle with Layer {layer_id}...")

        # Work on a copy: the helpers below add columns in place.
        layer_rows = feature_df[feature_df['layer'] == layer_id].copy()
        if layer_rows.empty:
            print(f"Layer {layer_id} 没有数据，跳过")
            continue

        # Composite features -> per-group statistics -> per-group scores.
        enriched = compute_composite_features(layer_rows)
        group_features = aggregate_experts_by_group(enriched, expert_to_group)
        scores = calculate_group_scores(group_features, final_weights)

        layer_group_importance[layer_id] = scores

        # Report the scores of this layer, ordered by group id.
        print(f"Layer {layer_id} 组重要性得分:")
        for group_id in sorted(scores):
            print(f"  组 {group_id}: {scores[group_id]:.4f}")

    return layer_group_importance


def allocate_group_compression_params(
    layer_group_importance,
    target_layer_moe_params,
    layer2group2expert,
    min_ratio=0.4,
    max_ratio=0.6,
    smoothness_factor=0.1
):
    """
    Split each layer's parameter budget across its expert groups.

    Per layer, a group's share blends its size share with its min-max
    normalised importance, is clamped to
    ``[min_ratio, max_ratio] * size share`` and is renormalised to sum to 1.
    The inputs are not modified: groups that lack an importance score are
    given 0.0 in a local copy only.

    Args:
        layer_group_importance: ``{layer_id: {group_id: importance_score}}``.
        target_layer_moe_params: ``{layer_id: target_params}``, the MoE
            parameter budget of each layer. Layers with a missing or zero
            budget are skipped.
        layer2group2expert: ``{layer_id: {expert_id: group_id}}``, used to
            count the experts in each group.
        min_ratio: Lower clamp on a group's weight, as a multiple of its
            size share.
        max_ratio: Upper clamp on a group's weight, as a multiple of its
            size share.
        smoothness_factor: Blend factor in [0, 1]; the weight given to the
            size share (the rest goes to the normalised importance).

    Returns:
        dict: ``{layer_id: {group_id: group_params}}``. With more than one
        group the values are truncated to ``int``, so their sum can be below
        the target by less than one parameter per group.
    """
    print("\n=== 基于参数量预算分配组内参数 ===")
    layer_group_params = {}

    for layer_id, layer_scores in layer_group_importance.items():
        if not layer_scores:
            continue

        # Budgets may be keyed by plain ints while layer ids are numpy ints.
        layer_key = int(layer_id) if isinstance(layer_id, np.integer) else layer_id

        target_params = target_layer_moe_params.get(layer_key, 0)
        if target_params == 0:
            print(f"警告: 层 {layer_id} 没有目标参数量，跳过")
            continue

        print(f"\n层 {layer_id} (目标参数量: {target_params:,}):")

        # Number of experts in each group.
        group_sizes = {}
        for group_id in layer2group2expert.get(layer_id, {}).values():
            group_sizes[group_id] = group_sizes.get(group_id, 0) + 1

        # Copy before filling gaps so the caller's scores stay untouched.
        group_importance = dict(layer_scores)
        for group_id in group_sizes:
            if group_id not in group_importance:
                group_importance[group_id] = 0.0
                print(f"  警告: 组 {group_id} 没有重要性分数，设为 0.0")

        groups = list(set(group_importance.keys()) | set(group_sizes.keys()))

        if len(groups) == 1:
            # A single group receives the whole budget.
            group_params = {groups[0]: target_params}
        else:
            # Min-max normalise the importance scores to [0, 1].
            importance_values = np.array(
                [group_importance.get(g, 0) for g in groups])
            lowest = np.min(importance_values)
            highest = np.max(importance_values)
            if highest == lowest:
                importance_normalized = np.ones_like(
                    importance_values) / len(importance_values)
            else:
                importance_normalized = (importance_values - lowest) / (
                    highest - lowest + 1e-8)

            # Size share: more experts -> larger share.
            size_values = np.array([group_sizes.get(g, 1) for g in groups])
            size_weights = size_values / np.sum(size_values)

            # Blend size share and importance.
            combined_weights = smoothness_factor * size_weights + \
                (1 - smoothness_factor) * importance_normalized

            # Clamp to bounds that scale with the group's size share.
            min_weight = size_weights * min_ratio
            max_weight = max_ratio * size_weights
            combined_weights = np.minimum(np.maximum(
                combined_weights, min_weight), max_weight)
            combined_weights = combined_weights / np.sum(combined_weights)

            raw_params = combined_weights * target_params

            # Rescale so the (float) allocation sums to the target.
            adjustment_factor = target_params / np.sum(raw_params)
            adjusted_params = raw_params * adjustment_factor

            # Truncation to int may leave the total slightly below target.
            group_params = {g: int(p) for g, p in zip(groups, adjusted_params)}

        layer_group_params[layer_id] = group_params

        # Report the allocation.
        print("  组参数量分配:")
        total_allocated = 0
        for group_id in sorted(group_params.keys()):
            params = group_params[group_id]
            importance = group_importance.get(group_id, 0)
            experts_count = group_sizes.get(group_id, 0)
            total_allocated += params
            print(
                f"    组 {group_id}: 重要性={importance:.3f}, 专家数={experts_count}, 分配参数量={params:,}")

        print(f"  总分配参数量: {total_allocated:,} / 目标: {target_params:,}")
        # A shortfall of up to one parameter per group is expected.
        if abs(total_allocated - target_params) > len(group_params):
            print(f"  警告: 分配参数量与目标参数量相差 {total_allocated - target_params:,}")

    return layer_group_params


def convert_to_compressor_format(layer2group2expert):
    """
    Invert the expert -> group mapping of every layer into group -> experts.

    Args:
        layer2group2expert: ``{layer_id: {expert_id: group_id}}``.

    Returns:
        dict: ``{layer_id: {group_id: [expert_ids]}}``. Layers are inserted
        in ascending ``layer_id`` order; within a layer, groups and experts
        keep the iteration order of the input mapping.
    """
    layers_expert_groups = {}

    for layer_id in sorted(layer2group2expert.keys()):
        members_by_group = {}
        for expert_id, group_id in layer2group2expert[layer_id].items():
            # Create the list on first sight of a group, then append.
            members_by_group.setdefault(group_id, []).append(expert_id)

        layers_expert_groups[layer_id] = members_by_group

    return layers_expert_groups


def convert_importance_to_compressor_format(
    layer_group_importance,
    layer2group2expert,
    layer_compression_ratios,
    min_group_ratio=0.1,
    max_group_ratio=0.9,
    smoothness_factor=0.3
):
    """
    Go from importance scores to the structures the compressor consumes.

    NOTE(review): ``allocate_group_compression_ratios`` is not defined in the
    part of the file visible here (only ``allocate_group_compression_params``
    is) -- confirm it still exists before relying on this function.

    Args:
        layer_group_importance: ``{layer_id: {group_id: importance_score}}``.
        layer2group2expert: ``{layer_id: {expert_id: group_id}}``.
        layer_compression_ratios: Per-layer compression ratios, passed through
            to ``allocate_group_compression_ratios``.
        min_group_ratio: Passed through to the ratio allocation.
        max_group_ratio: Passed through to the ratio allocation.
        smoothness_factor: Passed through to the ratio allocation.

    Returns:
        tuple: ``(layers_expert_groups, layer_group_ratios)`` where
        ``layers_expert_groups`` is ``{layer_id: {group_id: [expert_ids]}}``
        and ``layer_group_ratios`` is whatever the ratio allocation returned.
    """
    # Step 1: allocate the per-group compression ratios.
    layer_group_ratios = allocate_group_compression_ratios(
        layer_group_importance, layer_compression_ratios,
        min_group_ratio, max_group_ratio, smoothness_factor
    )

    # Step 2: format conversion. ``convert_to_compressor_format`` accepts
    # only the grouping; the ratios are returned alongside it.
    layers_expert_groups = convert_to_compressor_format(layer2group2expert)

    return layers_expert_groups, layer_group_ratios
