import gc
from datasets import load_dataset
import warnings
from accelerate import Accelerator
from tqdm import tqdm
from utils.common_utils import set_random_seed
import math
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.optim as optim
import torch.nn as nn
import torch
from scipy.optimize import linprog, minimize
import torch.nn.functional as F
import random


# Silence every warning for the whole process from this point on.
warnings.filterwarnings("ignore")

# Module-wide default device: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class MINE(nn.Module):
    """Statistics network T(x, z) for the MINE mutual-information estimator.

    ``x`` and ``z`` are concatenated along dim 1 and passed through two hidden
    stages (Linear -> LayerNorm -> GELU -> Dropout), then a linear head that
    emits one score per row. Inputs are presumably 2-D ``(batch, features)``.

    Args:
        input_dim_x: Feature width of ``x``.
        input_dim_z: Feature width of ``z``.
        hidden_dim: Width of both hidden stages.
        dropout: Dropout probability applied after each hidden activation.
    """

    def __init__(self, input_dim_x, input_dim_z, hidden_dim=128, dropout=0.1):
        super().__init__()
        joint_dim = input_dim_x + input_dim_z
        # Keep this registration order: it fixes parameter-initialisation
        # order (RNG consumption) and the state_dict key layout.
        self.fc1 = nn.Linear(joint_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm1 = nn.LayerNorm(hidden_dim)
        self.layer_norm2 = nn.LayerNorm(hidden_dim)
        self.activation = nn.GELU()  # GELU activation

    def _stage(self, h, linear, norm):
        """Apply one Linear -> LayerNorm -> GELU -> Dropout stage."""
        return self.dropout(self.activation(norm(linear(h))))

    def forward(self, x, z):
        """Score each (x[i], z[i]) pair; the head has a single output unit."""
        hidden = torch.cat((x, z), dim=1)
        hidden = self._stage(hidden, self.fc1, self.layer_norm1)
        hidden = self._stage(hidden, self.fc2, self.layer_norm2)
        return self.fc3(hidden)


def estimate_mutual_information(features_x, features_z, mine_hidden_dim=256,
                                mine_epochs=50, mine_lr=1e-4, mine_batch_size=128,
                                verbose=True, early_stopping_patience=30):
    """Estimate I(X; Z) in nats with MINE (Donsker-Varadhan lower bound).

    Training setup:
    - A ``MINE`` network is trained to separate row-paired samples
      ``(features_x[i], features_z[i])`` (joint) from ``features_x`` rows
      paired with a shuffled ``features_z`` row (product of marginals).
    - Training stops early after ``early_stopping_patience`` epochs without
      improvement of the training objective.
    - The weights of the best epoch are restored before final scoring with
      ``perform_robust_final_evaluation``.

    Args:
        features_x: Tensor of shape (num_samples, dim_x); rows pair with ``features_z``.
        features_z: Tensor of shape (num_samples, dim_z).
        mine_hidden_dim: Hidden width of the statistics network.
        mine_epochs: Maximum number of training epochs.
        mine_lr: AdamW learning rate.
        mine_batch_size: Mini-batch size. If there are fewer samples than this,
            it is reduced to ``max(num_samples // 4, 2)``.
        verbose: Print progress information.
        early_stopping_patience: Epochs without improvement before stopping.

    Returns:
        The MI estimate (nats) returned by ``perform_robust_final_evaluation``.
    """
    num_samples = features_x.size(0)
    dim_x = features_x.size(1)
    dim_z = features_z.size(1)

    # Input validation: shrink the batch so small datasets still form batches.
    if num_samples < mine_batch_size:
        mine_batch_size = max(num_samples // 4, 2)
        if verbose:
            print(f"调整批次大小为: {mine_batch_size}")

    mine_net = MINE(dim_x, dim_z, mine_hidden_dim).to(device)

    # Cast features to the network's parameter dtype so half-precision hidden
    # states do not cause a Linear dtype mismatch.
    # L2-normalise rows for training stability.
    net_dtype = next(mine_net.parameters()).dtype
    features_x = F.normalize(features_x.to(net_dtype), p=2, dim=1)
    features_z = F.normalize(features_z.to(net_dtype), p=2, dim=1)

    optimizer = optim.AdamW(mine_net.parameters(), lr=mine_lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.7)

    best_mi = float('-inf')
    patience_counter = 0
    best_model_state = None

    # Global permutation used to draw "marginal" z rows; refreshed every 10 epochs.
    global_marginal_indices = torch.randperm(num_samples)

    for epoch in range(mine_epochs):
        mine_net.train()
        epoch_loss_sum = 0.0
        num_batches = 0

        # Reshuffle the joint pairs every epoch.
        epoch_indices = torch.randperm(num_samples)

        for i in range(0, num_samples, mine_batch_size):
            batch_indices = epoch_indices[i:i + mine_batch_size]
            batch_size_current = len(batch_indices)
            if batch_size_current < 2:  # need at least two rows per batch
                continue

            optimizer.zero_grad()

            batch_x = features_x[batch_indices].to(device)
            batch_z_joint = features_z[batch_indices].to(device)

            # Marginal samples come from the window of the global permutation
            # that lines up with this batch. Since i + batch_size_current never
            # exceeds num_samples, no wrap-around is needed.
            marginal_indices = global_marginal_indices[i:i + batch_size_current]
            batch_z_marginal = features_z[marginal_indices].to(device)

            t_joint = mine_net(batch_x, batch_z_joint)
            t_marginal = mine_net(batch_x, batch_z_marginal)

            # Donsker-Varadhan bound: E_joint[T] - log E_marginal[exp(T)].
            # log-mean-exp is computed stably as logsumexp - log(n).
            log_mean_exp_t_marginal = (
                torch.logsumexp(t_marginal.flatten(), dim=0) - math.log(batch_size_current)
            )
            loss = -(t_joint.mean() - log_mean_exp_t_marginal)

            # Explicit L2 penalty on top of AdamW's decoupled weight decay.
            l2_reg = 1e-5 * sum(p.pow(2.0).sum() for p in mine_net.parameters())
            loss = loss + l2_reg

            loss.backward()
            torch.nn.utils.clip_grad_norm_(mine_net.parameters(), max_norm=1.0)
            optimizer.step()

            epoch_loss_sum += loss.item()
            num_batches += 1

        # Periodically redraw the marginal permutation.
        if epoch % 10 == 0:
            global_marginal_indices = torch.randperm(num_samples)

        if num_batches == 0:
            continue

        avg_epoch_loss = epoch_loss_sum / num_batches
        current_mi = -avg_epoch_loss
        scheduler.step(avg_epoch_loss)

        # Early stopping with best-state snapshot.
        if current_mi > best_mi:
            best_mi = current_mi
            patience_counter = 0
            # state_dict() returns references to the live parameter tensors, so
            # clone them; a shallow dict copy would track later updates.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in mine_net.state_dict().items()
            }
        else:
            patience_counter += 1

        if patience_counter >= early_stopping_patience:
            if verbose:
                print(f"Early stopping at epoch {epoch+1}")
            break

        if verbose and (epoch + 1) % 10 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"MINE Epoch [{epoch+1}/{mine_epochs}], Loss: {avg_epoch_loss:.4f}, "
                  f"MI: {current_mi:.4f} nats, LR: {current_lr:.2e}")

    # Restore the best checkpoint before the final evaluation.
    if best_model_state is not None:
        mine_net.load_state_dict(best_model_state)

    return perform_robust_final_evaluation(
        mine_net, features_x, features_z, mine_batch_size, verbose
    )


def perform_robust_final_evaluation(mine_net, features_x, features_z, batch_size, verbose=True,
                                    num_evaluations=5):
    """Score a trained MINE network over the full dataset, averaged across reshuffles.

    Procedure, repeated ``num_evaluations`` times:
    - Rows are permuted for joint pairs and independently permuted for
      marginal pairs.
    - The network runs batch-wise in eval mode without gradients.
    - The Donsker-Varadhan bound
      ``mean(T_joint) - log(mean(exp(T_marginal)))`` is computed over all
      collected scores.

    Args:
        mine_net: Module called as ``mine_net(x_batch, z_batch)``; put in eval mode.
        features_x: Tensor whose rows pair with ``features_z`` rows.
        features_z: Tensor with the same number of rows as ``features_x``.
        batch_size: Rows per forward pass; trailing batches of < 2 rows are skipped.
        verbose: Print a mean ± std summary.
        num_evaluations: Number of independent reshuffled evaluations.

    Returns:
        Mean MI estimate in nats over the successful runs, or 0.0 if no batch
        could be scored.
    """
    mine_net.eval()
    num_samples = features_x.size(0)

    # Send inputs to the network's own device and dtype. This supports
    # CPU/GPU nets and half-precision features.
    first_param = next(mine_net.parameters(), None)
    net_device = first_param.device if first_param is not None else features_x.device
    net_dtype = first_param.dtype if first_param is not None else features_x.dtype

    if verbose:
        print("开始稳健的最终评估...")

    mi_estimates = []
    with torch.no_grad():
        for _ in range(num_evaluations):
            # Fresh randomisation for each run.
            joint_indices = torch.randperm(num_samples)
            marginal_indices = torch.randperm(num_samples)

            t_joint_list = []
            t_marginal_list = []

            for start in range(0, num_samples, batch_size):
                end = min(start + batch_size, num_samples)
                if end - start < 2:
                    continue

                batch_joint_idx = joint_indices[start:end]
                batch_x = features_x[batch_joint_idx].to(device=net_device, dtype=net_dtype)
                batch_z_joint = features_z[batch_joint_idx].to(device=net_device, dtype=net_dtype)
                batch_z_marginal = features_z[marginal_indices[start:end]].to(
                    device=net_device, dtype=net_dtype)

                t_joint_list.append(mine_net(batch_x, batch_z_joint).cpu())
                t_marginal_list.append(mine_net(batch_x, batch_z_marginal).cpu())

            if not t_joint_list:
                continue

            all_t_joint = torch.cat(t_joint_list, dim=0).float()
            all_t_marginal = torch.cat(t_marginal_list, dim=0).flatten().float()

            # Stable log-mean-exp: logsumexp - log(n), with no epsilon bias.
            log_mean_exp_t_marginal = (
                torch.logsumexp(all_t_marginal, dim=0) - math.log(all_t_marginal.numel())
            )
            mi_estimates.append((all_t_joint.mean() - log_mean_exp_t_marginal).item())

    if not mi_estimates:
        if verbose:
            print("评估失败，返回0")
        return 0.0

    final_mi = np.mean(mi_estimates)
    mi_std = np.std(mi_estimates)

    if verbose:
        # Report the number of runs that actually produced an estimate.
        print(
            f"最终MI估计: {final_mi:.4f} ± {mi_std:.4f} nats (基于{len(mi_estimates)}次评估)")

    return final_mi


# 2. Extract per-layer features with a Transformers model
def get_transformer_features(model, tokenizer, texts, layers_to_extract, max_length=128, batch_size=32):
    """Mean-pool the hidden states of selected layers over a list of texts.

    Processing:
    - The model is wrapped with ``Accelerator.prepare`` and put in eval mode.
    - Blank / whitespace-only texts are dropped, so row k of every output
      tensor corresponds to the k-th non-blank text.
    - Each remaining batch is tokenised (padding, truncation to ``max_length``)
      and run once with ``output_hidden_states=True``.
    - Every requested layer is pooled by an attention-mask-weighted mean over
      tokens.

    Args:
        model: Hugging Face model supporting ``output_hidden_states``.
        tokenizer: Matching tokenizer; must have a pad token for batching.
        texts: Sequence of input strings.
        layers_to_extract: Indices into ``outputs.hidden_states``. Indices not
            smaller than ``len(hidden_states)`` are ignored.
        max_length: Truncation length in tokens.
        batch_size: Texts per forward pass.

    Returns:
        Dict mapping each extracted layer index to a CPU tensor produced by
        concatenating the pooled batches along dim 0. Layers that yielded
        nothing are absent.
    """
    accelerator = Accelerator()
    model = accelerator.prepare(model)
    model.eval()

    extracted_features_batched = {layer_idx: [] for layer_idx in layers_to_extract}

    with tqdm(total=len(texts), desc="提取特征") as pbar:
        for i in range(0, len(texts), batch_size):
            raw_batch = texts[i:i + batch_size]
            batch_texts = [text for text in raw_batch if text.strip()]
            if not batch_texts:
                pbar.update(len(raw_batch))
                continue

            inputs = tokenizer(batch_texts, return_tensors="pt", padding=True,
                               truncation=True, max_length=max_length)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            # Inference only; no autocast is applied here.
            with torch.no_grad():
                outputs = model(**inputs, output_hidden_states=True)

            all_hidden_states = outputs.hidden_states
            attention_mask = inputs['attention_mask'].unsqueeze(-1).float()

            for layer_idx in layers_to_extract:
                if layer_idx < len(all_hidden_states):
                    # NOTE(review): in HF models hidden_states[0] is usually the
                    # embedding output, so this index is presumably the *input*
                    # of decoder layer `layer_idx`. Confirm that is intended.
                    layer_states = all_hidden_states[layer_idx]
                    # Under device_map sharding, hidden states may live on
                    # different devices, so move the mask to match.
                    mask = attention_mask.to(layer_states.device)
                    # Clamp guards against division by zero for an all-padding row.
                    token_counts = mask.sum(dim=1).clamp(min=1.0)
                    pooled = (layer_states * mask).sum(dim=1) / token_counts
                    extracted_features_batched[layer_idx].append(pooled.cpu())

            # Free GPU memory held by this batch's activations.
            del outputs, all_hidden_states
            torch.cuda.empty_cache()
            # Advance by the whole slice (including skipped blanks) so the bar reaches its total.
            pbar.update(len(raw_batch))

    # Concatenate per-layer batches, then release the temporary lists.
    extracted_features = {
        layer_idx: torch.cat(feature_list, dim=0)
        for layer_idx, feature_list in extracted_features_batched.items()
        if feature_list
    }
    del extracted_features_batched
    gc.collect()

    return extracted_features


def _project_ratios_to_mean(ratios, target_mean, lower, upper, iterations=100):
    """Shift ``ratios`` by one constant and clip to [lower, upper] so their mean is ``target_mean``.

    The clipped mean is non-decreasing in the shift, so bisection finds the
    shift. If ``target_mean`` lies outside [lower, upper], every entry
    saturates at the nearest bound.
    """
    ratios = np.asarray(ratios, dtype=float)
    if target_mean <= lower:
        return np.full_like(ratios, lower)
    if target_mean >= upper:
        return np.full_like(ratios, upper)

    lo_shift = lower - np.max(ratios)  # at this shift every entry clips to `lower`
    hi_shift = upper - np.min(ratios)  # at this shift every entry clips to `upper`
    for _ in range(iterations):
        mid = 0.5 * (lo_shift + hi_shift)
        if np.mean(np.clip(ratios + mid, lower, upper)) < target_mean:
            lo_shift = mid
        else:
            hi_shift = mid
    return np.clip(ratios + 0.5 * (lo_shift + hi_shift), lower, upper)


def allocate_save_ratio_with_qp_smoothness(
    layers_mi: list[float],
    target_overall_save_ratio: float,
    layer_indices_to_analyze,
    min_layer_save_ratio: float = 0.1,
    max_layer_save_ratio: float = 0.9,
    delta_ratio: float = 0.15,
    smoothness_weight: float = 0.15
) -> dict[int, float]:
    """Allocate per-layer save ratios from layer MI scores with a smoothed optimisation.

    Ideal ratios:
    - MI scores are min-max normalised to [0, 1]. All equal scores map to 0.5.
    - The normalised scores map linearly onto an ideal ratio in
      ``[lower_bound, upper_bound]``, where the highest-MI layer gets the
      lowest ratio. Here
      ``lower_bound = max(min_layer_save_ratio, target - delta_ratio)`` and
      ``upper_bound = min(max_layer_save_ratio, target + delta_ratio)``.

    Optimisation: SLSQP minimises a weighted sum of
    - the squared deviation from the ideal ratios,
    - squared differences between consecutive layers (smoothness), and
    - a small quartic penalty on the distance from the target,

    subject to the per-layer bounds and to the mean ratio equalling
    ``target_overall_save_ratio``.

    Fallback: if SLSQP fails, or its answer violates the mean constraint, the
    ideal ratios are shifted and clipped to meet the mean instead. When the
    target is infeasible, the ratios saturate at a bound.

    Args:
        layers_mi: One MI-based score per analysed layer.
        target_overall_save_ratio: Required mean save ratio across layers.
        layer_indices_to_analyze: Layer identifiers, in the same order as ``layers_mi``.
        min_layer_save_ratio: Global lower limit for any layer.
        max_layer_save_ratio: Global upper limit for any layer.
        delta_ratio: Maximum deviation of any layer from the target.
        smoothness_weight: Weight of the smoothness term; ``1 - smoothness_weight``
            weights the deviation term.

    Returns:
        Dict mapping each layer identifier to its save ratio. Empty if
        ``layers_mi`` is empty.

    Raises:
        ValueError: if ``layers_mi`` and ``layer_indices_to_analyze`` differ in length.
    """
    if not layers_mi:
        print("警告: 'layers_mi' 为空。")
        return {}

    num_layers = len(layers_mi)
    if len(layer_indices_to_analyze) != num_layers:
        raise ValueError(
            f"layers_mi has {num_layers} entries but layer_indices_to_analyze has "
            f"{len(layer_indices_to_analyze)}")

    p = np.array(layers_mi, dtype=float)

    # Normalise the scores to the [0, 1] range.
    p_min, p_max = np.min(p), np.max(p)
    if p_max > p_min:
        p_normalized = (p - p_min) / (p_max - p_min)
    else:
        p_normalized = np.ones_like(p) * 0.5

    print(f"原始重要性得分范围: [{p_min:.4f}, {p_max:.4f}]")
    print(f"标准化后范围: [{np.min(p_normalized):.4f}, {np.max(p_normalized):.4f}]")

    lower_bound = max(min_layer_save_ratio, target_overall_save_ratio - delta_ratio)
    upper_bound = min(max_layer_save_ratio, target_overall_save_ratio + delta_ratio)

    # Higher score -> lower ideal save ratio.
    ideal_ratios = upper_bound - p_normalized * (upper_bound - lower_bound)

    def objective_function(s):
        """Weighted deviation-from-ideal + smoothness + quartic concentration penalty."""
        importance_deviation = np.sum((s - ideal_ratios) ** 2)
        smoothness_term = np.sum(np.diff(s) ** 2)  # consecutive-layer differences
        # The quartic term discourages extreme values far from the target mean.
        concentration_penalty = np.sum((s - target_overall_save_ratio) ** 4)
        return (
            (1 - smoothness_weight) * importance_deviation
            + smoothness_weight * smoothness_term
            + 0.1 * concentration_penalty
        )

    target_sum = num_layers * target_overall_save_ratio
    constraints = {'type': 'eq', 'fun': lambda s: np.sum(s) - target_sum}
    bounds = tuple((lower_bound, upper_bound) for _ in range(num_layers))

    # Start from the ideal ratios, or from the target when all scores are equal.
    # Shift them to the target mean, then clip into the bounds.
    if p_max > p_min:
        initial_ratios = ideal_ratios.copy()
    else:
        initial_ratios = np.full(num_layers, target_overall_save_ratio)
    adjustment = target_overall_save_ratio - np.mean(initial_ratios)
    initial_guess = np.clip(initial_ratios + adjustment, lower_bound, upper_bound)

    print(
        f"初始猜测值范围: [{np.min(initial_guess):.4f}, {np.max(initial_guess):.4f}]")

    result = minimize(
        fun=objective_function,
        x0=initial_guess,
        method='SLSQP',
        bounds=bounds,
        constraints=constraints,
        options={'ftol': 1e-9, 'maxiter': 1000}
    )

    optimal_ratios = np.clip(result.x, lower_bound, upper_bound)
    constraint_gap = abs(np.sum(optimal_ratios) - target_sum)
    if not result.success or constraint_gap > 1e-6 * num_layers:
        if not result.success:
            print(f"优化警告: {result.message}")
        else:
            print(f"优化警告: mean constraint violated by {constraint_gap:.2e}")
        # Deterministic fallback: shift and clip the ideal ratios to hit the mean.
        print("Falling back to shifted/clipped ideal ratios.")
        optimal_ratios = _project_ratios_to_mean(
            ideal_ratios, target_overall_save_ratio, lower_bound, upper_bound)

    # Allocation quality metrics (np.diff is empty for a single layer).
    ratio_std = np.std(optimal_ratios)
    max_diff = np.max(np.abs(np.diff(optimal_ratios))) if num_layers > 1 else 0.0
    actual_mean = np.mean(optimal_ratios)

    print("分配结果统计:")
    print(f"- 实际平均压缩率: {actual_mean:.4f}")
    print(f"- 压缩率标准差: {ratio_std:.4f}")
    print(f"- 最大相邻差异: {max_diff:.4f}")

    return {
        layer_idx: float(ratio)
        for layer_idx, ratio in zip(layer_indices_to_analyze, optimal_ratios)
    }


def identify_moe_layers(model):
    """Return the indices of ``model.model.layers`` whose ``mlp`` has an ``experts`` attribute.

    Layers with no ``mlp`` attribute, or whose ``mlp`` lacks ``experts``, are
    skipped. Indices are returned in ascending order.
    """
    return [
        idx
        for idx, block in enumerate(model.model.layers)
        if hasattr(getattr(block, 'mlp', None), 'experts')
    ]


def _average_adjacent_mi(layer_indices, adjacent_mutual_information, verbose):
    """Score each layer by the mean MI of the adjacent-layer pairs it appears in (0.5 if none)."""
    neighbor_mis = {layer_idx: [] for layer_idx in layer_indices}
    for (l1, l2), mi_val in adjacent_mutual_information.items():
        for member in {l1, l2}:
            if member in neighbor_mis:
                neighbor_mis[member].append(mi_val)

    scores = []
    for layer_idx in layer_indices:
        relevant_mis = neighbor_mis[layer_idx]
        if relevant_mis:
            # The mean MI with neighbouring MoE layers is the layer's score.
            avg_mi = np.mean(relevant_mis)
            scores.append(avg_mi)
            if verbose:
                print(
                    f"Layer {layer_idx}: 平均相邻层互信息 = {avg_mi:.4f} nats")
        else:
            # No neighbours (e.g. a single MoE layer): use a neutral default.
            scores.append(0.5)
            if verbose:
                print(
                    f"Layer {layer_idx}: 无相邻层互信息，使用默认值 0.5")
    return scores


def get_layer_params(model_name='deepseek-ai/deepseek-moe-16b-base',
                     target_compression_ratio=0.5,
                     min_save_ratio=0.1,
                     max_save_ratio=0.9,
                     delta_ratio=0.1,
                     smoothness_weight=0.05,
                     num_samples=256,
                     mine_epochs=150,
                     verbose=True):
    """Compute a per-layer MoE parameter budget driven by adjacent-layer mutual information.

    Pipeline:
    1. Sample up to ``num_samples`` non-blank WikiText-2 test lines.
    2. Load the model and tokenizer, and find the MoE layers.
    3. Mean-pool features for those layers.
    4. Estimate MI between each pair of consecutive MoE layers with MINE, and
       score each layer by its mean neighbour MI.
    5. Convert the whole-model compression target into a mean MoE save ratio
       and allocate per-layer save ratios from the scores.
    6. Turn each ratio into a parameter count.

    Args:
        model_name (str): Hugging Face model id or path.
        target_compression_ratio (float): Fraction of *all* model parameters to remove.
        min_save_ratio (float): Minimum per-layer save ratio.
        max_save_ratio (float): Maximum per-layer save ratio.
        delta_ratio (float): Maximum deviation of a layer's ratio from the mean.
        smoothness_weight (float): Weight of the cross-layer smoothness term.
        num_samples (int): Texts used for feature extraction (capped at the dataset size).
        mine_epochs (int): Maximum MINE training epochs per layer pair.
        verbose (bool): Print progress information.

    Returns:
        tuple: ``(layer_moe_params, layer_indices_to_analyze)``.
        ``layer_moe_params`` maps each MoE layer index to
        ``int(average MoE params per layer * layer save ratio)``.

    Raises:
        ValueError: if the model has no layers with ``mlp.experts``.
    """
    # Load the dataset; random.sample raises if asked for more items than exist.
    dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    all_valid_texts = [text for text in dataset['text'] if text.strip()]
    sample_texts = random.sample(all_valid_texts, min(num_samples, len(all_valid_texts)))

    if verbose:
        print(
            f"Total valid samples for feature extraction: {len(sample_texts)}")

    # Load the model and tokenizer.
    compressed_model = AutoModelForCausalLM.from_pretrained(
        model_name, trust_remote_code=True, device_map='auto', torch_dtype='auto')

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Identify the MoE layers; fail early rather than after feature extraction.
    layer_indices_to_analyze = identify_moe_layers(compressed_model)
    if verbose:
        print(f"Identified MoE layers: {layer_indices_to_analyze}")
    if not layer_indices_to_analyze:
        raise ValueError(f"No MoE layers (mlp.experts) found in model '{model_name}'.")

    # Extract features.
    all_features_from_model = get_transformer_features(
        compressed_model,
        tokenizer,
        sample_texts,
        layer_indices_to_analyze,
        batch_size=16
    )

    if verbose:
        for layer_idx, features in all_features_from_model.items():
            print(f"Layer {layer_idx} features shape: {features.shape}")

    if verbose:
        print("\nStarting mutual information estimation for adjacent layers...")

    # Only consecutive MoE layers are compared.
    adjacent_mutual_information = {}
    for l1, l2 in zip(layer_indices_to_analyze, layer_indices_to_analyze[1:]):
        if l1 in all_features_from_model and l2 in all_features_from_model:
            mi_estimate = estimate_mutual_information(
                all_features_from_model[l1],
                all_features_from_model[l2],
                mine_hidden_dim=128,
                mine_epochs=mine_epochs,
                mine_lr=1e-4,
                mine_batch_size=64,
                verbose=verbose
            )
            adjacent_mutual_information[(l1, l2)] = mi_estimate
            if verbose:
                print(
                    f"估计的 MI(Layer {l1}, Layer {l2}): {mi_estimate:.4f} nats")

    layers_mi = _average_adjacent_mi(
        layer_indices_to_analyze, adjacent_mutual_information, verbose)

    total_params, moe_params = count_moe_params(
        compressed_model, layer_indices_to_analyze)

    moe_save_ratio = calculate_moe_save_ratio(
        total_params,
        moe_params,
        target_compression_ratio
    )

    # Allocate save ratios (smoothness_weight is now forwarded rather than ignored).
    layer_save_ratios = allocate_save_ratio_with_qp_smoothness(
        layers_mi,
        target_overall_save_ratio=moe_save_ratio,
        layer_indices_to_analyze=layer_indices_to_analyze,
        min_layer_save_ratio=min_save_ratio,
        max_layer_save_ratio=max_save_ratio,
        delta_ratio=delta_ratio,
        smoothness_weight=smoothness_weight,
    )

    if verbose:
        for layer_idx, save_ratio in layer_save_ratios.items():
            print(f"{layer_idx}: Save Ratio = {save_ratio:.4f}")

    # Per-layer budget: this assumes every MoE layer holds the same number of
    # expert parameters.
    moe_params_per_layer = moe_params / len(layer_indices_to_analyze)
    layer_moe_params = {
        l: int(moe_params_per_layer * layer_save_ratios[l]) for l in layer_save_ratios
    }

    # Total MoE parameter budget summed over layers.
    if verbose:
        print(sum(layer_moe_params.values()))

    # Release the large objects.
    del compressed_model, all_features_from_model, tokenizer, sample_texts, dataset, all_valid_texts
    torch.cuda.empty_cache()
    gc.collect()

    return layer_moe_params, layer_indices_to_analyze


def count_moe_params(model, target_layers):
    """Count all model parameters and the expert MLP weights in the given layers.

    Args:
        model: Object exposing ``parameters()`` and ``model.layers[i].mlp.experts``.
        target_layers: Indices into ``model.model.layers`` whose experts are counted.

    Returns:
        ``(total_params, moe_params)``. ``moe_params`` counts only the ``weight``
        tensors of each expert's ``gate_proj``, ``up_proj`` and ``down_proj``;
        biases are not included.
    """
    total_params = sum(param.numel() for param in model.parameters())
    decoder_layers = model.model.layers
    moe_params = sum(
        getattr(expert, proj_name).weight.numel()
        for idx in target_layers
        for expert in decoder_layers[idx].mlp.experts
        for proj_name in ("gate_proj", "up_proj", "down_proj")
    )
    return total_params, moe_params


def calculate_moe_save_ratio(total_params, moe_params, target_compression_ratio):
    """Convert a whole-model compression target into the fraction of MoE weights to keep.

    Assumes only MoE expert weights are compressed and all other parameters
    are kept. Solves for ``keep``:

        (non_moe_params + moe_params * keep) / total_params = 1 - target_compression_ratio

    Args:
        total_params: Total parameter count of the model.
        moe_params: Number of compressible MoE expert parameters.
        target_compression_ratio: Fraction of all parameters to remove.

    Returns:
        The MoE keep ratio, unclamped. A value outside [0, 1] means the target
        cannot be reached by compressing MoE experts alone; a warning is printed.

    Raises:
        ValueError: if ``moe_params`` is not positive.
    """
    if moe_params <= 0:
        raise ValueError(f"moe_params must be positive, got {moe_params}")

    non_moe_params = total_params - moe_params

    moe_keep_ratio = (total_params * (1 - target_compression_ratio) -
                      non_moe_params) / moe_params

    print(
        f"Model compression_ratio: {target_compression_ratio:.3f} -> MoE save ratio: {moe_keep_ratio:.3f}")
    print(
        f"Total params: {total_params}, MoE params: {moe_params}, Non-MoE params: {non_moe_params}")

    if not 0.0 <= moe_keep_ratio <= 1.0:
        print(f"Warning: MoE save ratio {moe_keep_ratio:.3f} is outside [0, 1]; "
              f"the target is not reachable by compressing MoE experts alone.")

    return moe_keep_ratio
