import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import math
import collections
from typing import Dict, List, Tuple
from copy import deepcopy
from EAframework import calculate_distances

class MetaReplayBuffer:
    """Bounded FIFO store of transitions used for Meta-DQN experience replay."""

    def __init__(self, capacity=1000):
        # A deque with ``maxlen`` drops the oldest transition once it is full.
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` transition."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw up to ``batch_size`` distinct transitions uniformly at random.

        When fewer than ``batch_size`` transitions are stored, all of them are
        returned (in random order).

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` where the
            states and next states are NumPy arrays and the remaining fields
            are tuples.
        """
        draw_count = min(batch_size, len(self.buffer))
        picked = random.sample(self.buffer, draw_count)
        # Transpose the list of transitions into per-field columns.
        columns = tuple(zip(*picked))
        state, action, reward, next_state, done = columns
        return np.array(state), action, reward, np.array(next_state), done

    def size(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

class MetaQNet(nn.Module):
    """Three-layer fully connected Q-network used by the Meta-DQN agent."""

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        # Layer names and creation order are kept so that state_dict keys and
        # seeded initialisation stay the same.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Map a batch of states to one Q-value per action."""
        first_hidden = F.relu(self.fc1(x))
        second_hidden = F.relu(self.fc2(first_hidden))
        q_values = self.fc3(second_hidden)
        return q_values

class MetaDQN:
    """指导EA搜索的Meta-DQN智能体"""
    
    def __init__(self, state_dim=10, hidden_dim=64, action_dim=5, 
                 learning_rate=1e-3, gamma=0.9, epsilon=0.9, 
                 target_update=10, device='cpu'):
        """Build the online/target Q-networks, optimizer and replay buffer.

        Args:
            state_dim: Length of the observation vector fed to the Q-network.
            hidden_dim: Width of the two hidden layers.
            action_dim: Number of discrete strategies (actions).
            learning_rate: Adam learning rate for the online network.
            gamma: Discount factor used in the TD target.
            epsilon: Initial exploration probability for epsilon-greedy.
            target_update: Number of ``update()`` calls between target syncs.
            device: Torch device on which both networks live.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = device
        
        # Online and target Q-networks
        self.q_net = MetaQNet(state_dim, hidden_dim, action_dim).to(device)
        self.target_q_net = MetaQNet(state_dim, hidden_dim, action_dim).to(device)
        # Start the target network as an exact copy of the online network;
        # otherwise the first ``target_update`` updates bootstrap from an
        # unrelated random initialisation.
        self.target_q_net.load_state_dict(self.q_net.state_dict())
        
        # Optimizer and learning hyper-parameters
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=learning_rate)
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = 0.995   # multiplicative decay applied per update()
        self.min_epsilon = 0.05      # floor for the exploration probability
        self.target_update = target_update
        self.update_count = 0
        
        # Experience replay
        self.replay_buffer = MetaReplayBuffer(capacity=2000)
        self.batch_size = 32
        self.min_buffer_size = 100   # update() is a no-op below this size
        
        # Mapping from action index to strategy name
        self.action_strategies = {
            0: "exploit_dense_regions",      # deep search in dense Pareto-front regions
            1: "explore_sparse_regions",     # explore sparse regions
            2: "balance_diversity",          # balance population diversity
            3: "intensify_convergence",      # strengthen convergence
            4: "adaptive_weights"            # adaptive weight adjustment
        }
        
    def encode_ea_state(self, ea_instance) -> np.ndarray:
        """将EA的当前状态编码为Meta-DQN的观察状态"""
        features = []
        
        # 1. 基本信息
        features.append(len(ea_instance.hv_his))  # 当前迭代次数 (归一化)
        
        # 2. 超体积相关特征
        if len(ea_instance.hv_his) > 0:
            current_hv = ea_instance.hv_his[-1]
            features.append(current_hv)  # 当前HV值
            
            # HV改进趋势
            if len(ea_instance.hv_his) >= 5:
                recent_hv = ea_instance.hv_his[-5:]
                hv_trend = (recent_hv[-1] - recent_hv[0]) / 5
                features.append(hv_trend)
            else:
                features.append(0.0)
        else:
            features.extend([0.0, 0.0])
        
        # 3. 帕累托前沿特征
        if ea_instance.pareto_front:
            pf_size = len(ea_instance.pareto_front)
            features.append(pf_size / ea_instance.n_pop)  # 前沿解占比
            
            # 前沿分布特征
            pf_objs = np.array([ea_instance.pareto_front[i]['objs'] 
                               for i in ea_instance.pareto_front.keys()])
            
            # 目标空间的分布方差 (衡量多样性)
            obj_std = np.std(pf_objs, axis=0).mean()
            features.append(obj_std)
            
            # 前沿的"长度" (衡量覆盖范围)
            obj_range = (np.max(pf_objs, axis=0) - np.min(pf_objs, axis=0)).mean()
            features.append(obj_range)
        else:
            features.extend([0.0, 0.0, 0.0])
        
        # 4. 权重分布特征
        weights = np.array(ea_instance.weights)
        weight_diversity = np.std(weights, axis=0).mean()
        features.append(weight_diversity)
        
        # 5. 收敛性指标
        if len(ea_instance.hv_his) >= 10:
            recent_hv_std = np.std(ea_instance.hv_his[-10:])
            features.append(recent_hv_std)  # 最近HV的方差，衡量收敛稳定性
        else:
            features.append(1.0)  # 初期认为不稳定
        
        # 6. 搜索停滞检测
        if len(ea_instance.hv_his) >= 20:
            early_hv = np.mean(ea_instance.hv_his[-20:-10])
            recent_hv = np.mean(ea_instance.hv_his[-10:])
            stagnation = 1.0 if (recent_hv - early_hv) < 0.001 else 0.0
            features.append(stagnation)
        else:
            features.append(0.0)
        
        # 7. 帕累托前沿区域分析特征
        if ea_instance.pareto_front and len(ea_instance.pareto_front) >= 2:
            try:
                distances = calculate_distances(ea_instance.pareto_front)
                if distances:
                    distance_values = list(distances.values())
                    min_dist = min(distance_values)
                    max_dist = max(distance_values)
                    
                    # 计算密集和稀疏区域
                    dense_threshold = min_dist * 1.5
                    sparse_threshold = max_dist * 0.7
                    
                    dense_count = sum(1 for dist in distance_values if dist <= dense_threshold)
                    sparse_count = sum(1 for dist in distance_values if dist >= sparse_threshold)
                    
                    # 添加密集和稀疏区域比例
                    total_pairs = len(distances)
                    features.append(dense_count / total_pairs if total_pairs > 0 else 0.0)
                    features.append(sparse_count / total_pairs if total_pairs > 0 else 0.0)
                else:
                    features.extend([0.0, 0.0])
            except:
                features.extend([0.0, 0.0])
        else:
            features.extend([0.0, 0.0])
        
        # 确保特征向量长度固定
        while len(features) < self.state_dim:
            features.append(0.0)
        
        return np.array(features[:self.state_dim], dtype=np.float32)
    
    def select_action(self, state):
        """ε-贪婪策略选择动作"""
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        q_values = self.q_net(state_tensor)
        return q_values.argmax().item()
    
    def select_greedy_action(self, state):
        """贪婪策略选择动作（用于评估）"""
        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        q_values = self.q_net(state_tensor)
        return q_values.argmax().item()
    
    def store_experience(self, state, action, reward, next_state, done):
        """Record one transition in the agent's replay buffer."""
        transition = (state, action, reward, next_state, done)
        self.replay_buffer.add(*transition)
    
    def update(self):
        """更新Meta-DQN网络"""
        if self.replay_buffer.size() < self.min_buffer_size:
            return
        
        # 采样经验
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
        
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.BoolTensor(dones).to(self.device)
        
        # 计算当前Q值
        current_q_values = self.q_net(states).gather(1, actions.unsqueeze(1))
        
        # 计算目标Q值
        next_q_values = self.target_q_net(next_states).max(1)[0].detach()
        target_q_values = rewards + (self.gamma * next_q_values * ~dones)
        
        # 计算损失
        loss = F.mse_loss(current_q_values.squeeze(), target_q_values)
        
        # 反向传播
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # 更新目标网络
        self.update_count += 1
        if self.update_count % self.target_update == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())
        
        # 衰减epsilon
        self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)
    
    def get_strategy_name(self, action):
        """Translate an action index into its strategy name.

        Returns ``"unknown"`` for indices that have no registered strategy.
        """
        if action in self.action_strategies:
            return self.action_strategies[action]
        return "unknown"

class MetaGuidedEA:
    """由Meta-DQN指导的EA框架"""
    
    def __init__(self, ea_creator_func, 
                 # DQN hyper-parameters
                 dqn_learning_rate=1e-3,
                 dqn_gamma=0.9,
                 # dense-region parameter
                 radius_factor=1.5,
                 # sparse-region parameters
                 length_factor=2.0,
                 width_factor=0.8,
                 use_abs_projection_length=False):
        """Create the guided EA and the Meta-DQN agent that steers it.

        Args:
            ea_creator_func: Zero-argument callable returning the EA instance.
            dqn_learning_rate: Learning rate handed to the Meta-DQN.
            dqn_gamma: Discount factor handed to the Meta-DQN.
            radius_factor: Multiplier on the mean Pareto-front distance that
                gives the dense-region search radius.
            length_factor: Multiplier on a sparse sector's length that gives
                the maximum accepted projection along the sector direction.
            width_factor: Multiplier on a sparse sector's length that gives the
                maximum accepted projection across the sector.
            use_abs_projection_length: When true, the projection along the
                sector direction is taken as an absolute value.
        """
        # The EA that the meta-agent will guide.
        self.ea = ea_creator_func()

        # Keep the hyper-parameters on the instance.
        self.dqn_learning_rate = dqn_learning_rate
        self.dqn_gamma = dqn_gamma
        self.radius_factor = radius_factor
        self.length_factor = length_factor
        self.width_factor = width_factor
        self.use_abs_projection_length = use_abs_projection_length

        # Run the Meta-DQN on the first individual's device when it exposes
        # one, otherwise on the CPU.
        first_individual = self.ea.pops[0]
        agent_device = getattr(first_individual, 'device', 'cpu')
        self.meta_dqn = MetaDQN(
            learning_rate=self.dqn_learning_rate,
            gamma=self.dqn_gamma,
            device=agent_device,
        )
        
    def calculate_meta_reward(self, ea_instance, hv_his, previous_hv, current_hv, action):
        """计算Meta-DQN的奖励"""
        # 基础奖励：HV改进
        hv_improvement = current_hv - previous_hv
        # base_reward = hv_improvement * 100  # 放大奖励信号 work!
        base_reward = hv_improvement 
        
        # 奖励塑形
        reward_components = {
            'hv_improvement': base_reward,
            'exploration_bonus': 0.0,
            'convergence_penalty': 0.0,
            'diversity_bonus': 0.0
        }
        
        # 探索奖励：鼓励在停滞时探索
        if len(hv_his) >= 10:
            recent_improvement = hv_his[-1] - hv_his[-10]
            if recent_improvement < 0.001 and action in [1, 2]:  # 探索类策略
                reward_components['exploration_bonus'] = 0.5
        
        # 多样性奖励
        if ea_instance.pareto_front and len(ea_instance.pareto_front) > 3:
            reward_components['diversity_bonus'] = 0.2
        
        # 收敛惩罚：避免过早收敛
        if len(hv_his) < 50 and action == 3:  # 过早使用收敛策略
            reward_components['convergence_penalty'] = -0.3
        
        # 基于当前episode的系数（使用log函数降低训练速度）
        current_episode = len(hv_his)
        
        # 方案1: 线性+log组合，确保有最小基础值
        # episode_coefficient = 0.1 + 0.9 * np.log(current_episode + 1) / np.log(100 + 1)
        
        # 方案2
        # episode_coefficient = max(1, np.log(current_episode + 1))
        # episode_coefficient = np.log(current_episode + 1) / np.log(100 + 1)
        # episode_coefficient = np.log(2 * current_episode + 1) / np.log(100 + 1)  # work!
        episode_coefficient = 1
        
        # total_reward = sum(reward_components.values()) * episode_coefficient
        # total_reward = reward_components['hv_improvement'] * episode_coefficient # todo 0705 修改
        total_reward = max(0, reward_components['hv_improvement']) * np.log(current_episode + 1) # todo 0706 修改
        return total_reward, reward_components
    
    def execute_meta_strategy(self, ea_instance, strategy_action):
        """Run the search strategy that the Meta-DQN selected.

        The action index is translated to a strategy name and the matching
        private handler is invoked on ``ea_instance``. Names without a handler
        are ignored.
        """
        handlers = {
            "exploit_dense_regions": self._exploit_dense_regions,    # intensify search in dense front regions
            "explore_sparse_regions": self._explore_sparse_regions,  # explore sparse regions
            "balance_diversity": self._balance_diversity,            # increase population diversity
            "intensify_convergence": self._intensify_convergence,    # strengthen convergence
            "adaptive_weights": self._adaptive_weights,              # adaptive weight adjustment
        }
        strategy_name = self.meta_dqn.get_strategy_name(strategy_action)
        handler = handlers.get(strategy_name)
        if handler is not None:
            handler(ea_instance)
    
    def _exploit_dense_regions(self, ea_instance):
        """Intensify training on individuals near dense parts of the Pareto front.

        Fallback (front missing or fewer than two entries): train half-ratio
        selected individuals in place for 15 episodes and return.

        Main path: pairs whose distance is within 1.5x the minimum pair
        distance are treated as dense; the midpoint of each such pair is a
        centre, and every population member found near a centre is trained on
        a deep copy for 18 episodes. The trained copy replaces the original
        only if ``_simulated_annealing_accept`` approves it.

        NOTE(review): ``calculate_distances`` is defined outside this file; it
        is used here as a mapping from ``(id1, id2)`` pairs to distances.
        """
        if not ea_instance.pareto_front or len(ea_instance.pareto_front) < 2:
            # Too few Pareto-front points: fall back to the plain strategy
            # (in-place training, no acceptance test).
            select_indices = ea_instance.select_policies(ratio=0.5)
            for idx in select_indices:
                ea_instance.exe_task(ea_instance.pops[idx], ea_instance.weights[idx], n_ep=15)
            return
        
        # 1. Locate the centres of dense regions on the Pareto front.
        distances = calculate_distances(ea_instance.pareto_front)
        if not distances:
            return
        
        min_distance = min(distances.values())
        threshold = min_distance * 1.5
        dense_pairs = [(pair, dist) for pair, dist in distances.items() if dist <= threshold]
        
        if not dense_pairs:
            return
        
        # Centre of each dense pair = per-objective midpoint of its two points.
        dense_centers = []
        for (id1, id2), _ in dense_pairs:
            center_obj = [(ea_instance.pareto_front[id1]['objs'][i] + ea_instance.pareto_front[id2]['objs'][i]) / 2 
                         for i in range(ea_instance.n_obj)]
            dense_centers.append(center_obj)
        
        # 2. Collect every population member that lies near a dense centre.
        dense_region_indices = self._find_individuals_in_dense_regions(ea_instance, dense_centers)
        
        # 3. "Blink": with probability blink_prob, drop the leading individuals
        #    after ordering them by their first objective (ascending).
        blink_prob = 0.3
        if dense_region_indices and random.random() < blink_prob:
            # Re-evaluate the candidates and sort by the first objective.
            all_objs = [ea_instance.evaluator.evaluate(ea_instance.pops[idx])[1][1:] for idx in dense_region_indices]
            sorted_pairs = sorted(zip(dense_region_indices, all_objs), key=lambda x: x[1][0])
            sorted_indices = [pair[0] for pair in sorted_pairs]
            
            # With a single candidate the original list is kept unchanged.
            if len(sorted_indices) > 1:
                skip_count = min(2, len(sorted_indices) // 3)
                dense_region_indices = sorted_indices[skip_count:]
        
        # 4. Deep search in the dense region, gated by simulated-annealing
        #    acceptance.
        for idx in dense_region_indices:
            # Evaluate a deep copy of the current individual as the baseline.
            original_agent = deepcopy(ea_instance.pops[idx])
            original_obj = ea_instance.evaluator.evaluate(original_agent)[1][1:]
            
            # Train a second deep copy so the population entry stays untouched
            # until the candidate is accepted.
            trained_agent = deepcopy(ea_instance.pops[idx])
            ea_instance.exe_task(trained_agent, ea_instance.weights[idx], n_ep=18)
            
            # Evaluate the trained candidate.
            trained_obj = ea_instance.evaluator.evaluate(trained_agent)[1][1:]
            
            # Replace the individual only when the acceptance test passes;
            # otherwise the original stays in the population.
            if self._simulated_annealing_accept(original_obj, trained_obj, ea_instance, individual_idx=idx):
                ea_instance.pops[idx] = trained_agent
    
    def _explore_sparse_regions(self, ea_instance):
        """Train individuals that fall inside sparse regions of the Pareto front.

        Fallback (front missing or fewer than two entries): adjust weights with
        ``pb=8`` and train 0.3-ratio selected individuals in place for 10
        episodes.

        Main path: pairs whose distance is at least 0.7x the maximum pair
        distance are treated as sparse; a sector is built for each pair, and
        every population member inside a sector is trained on a deep copy for
        12 episodes. The trained copy replaces the original only if
        ``_simulated_annealing_accept`` (in exploration mode) approves it.

        NOTE(review): ``calculate_distances`` is defined outside this file; it
        is used here as a mapping from ``(id1, id2)`` pairs to distances.
        """
        if not ea_instance.pareto_front or len(ea_instance.pareto_front) < 2:
            # Fall back to the plain strategy (in-place training, no
            # acceptance test).
            ea_instance.pareto_adaptive_fine_tuning_weights_adjust(pb=8)
            select_indices = ea_instance.select_policies(ratio=0.3)
            for idx in select_indices:
                ea_instance.exe_task(ea_instance.pops[idx], ea_instance.weights[idx], n_ep=10)
            return
        
        # 1. Locate sparse regions on the Pareto front.
        distances = calculate_distances(ea_instance.pareto_front)
        if not distances:
            return
        
        max_distance = max(distances.values())
        threshold = max_distance * 0.7
        sparse_pairs = [(pair, dist) for pair, dist in distances.items() if dist >= threshold]
        
        if not sparse_pairs:
            return
        
        # Build the sector description for each sparse pair.
        sparse_sectors = self._define_sparse_sectors(ea_instance, sparse_pairs)
        
        # 2. Collect every population member that lies inside a sparse sector.
        sparse_region_indices = self._find_individuals_in_sparse_regions(ea_instance, sparse_sectors)
        
        # 3. "Blink": with probability blink_prob, drop the leading individual
        #    after ordering candidates by their first objective (ascending).
        blink_prob = 0.4  # higher than the dense-region value (0.3)
        if sparse_region_indices and random.random() < blink_prob:
            # Re-evaluate the candidates and sort by the first objective.
            all_objs = [ea_instance.evaluator.evaluate(ea_instance.pops[idx])[1][1:] for idx in sparse_region_indices]
            sorted_pairs = sorted(zip(sparse_region_indices, all_objs), key=lambda x: x[1][0])
            sorted_indices = [pair[0] for pair in sorted_pairs]
            
            # With fewer than four candidates skip_count is 0, so the list is
            # only reordered.
            if len(sorted_indices) > 1:
                skip_count = min(1, len(sorted_indices) // 4)
                sparse_region_indices = sorted_indices[skip_count:]
        
        # 4. Explore the sparse region, gated by simulated-annealing acceptance.
        
        # Adjust weights before training.
        # NOTE(review): pb is 0 when no individual was found -- confirm the
        # weight-adjust routine accepts that.
        ea_instance.pareto_adaptive_fine_tuning_weights_adjust(pb=min(8, len(sparse_region_indices)))
        
        # Train each individual found in a sparse sector.
        for idx in sparse_region_indices:
            # Evaluate a deep copy of the current individual as the baseline.
            original_agent = deepcopy(ea_instance.pops[idx])
            original_obj = ea_instance.evaluator.evaluate(original_agent)[1][1:]
            
            # Train a second deep copy so the population entry stays untouched
            # until the candidate is accepted.
            trained_agent = deepcopy(ea_instance.pops[idx])
            ea_instance.exe_task(trained_agent, ea_instance.weights[idx], n_ep=12)
            
            # Evaluate the trained candidate.
            trained_obj = ea_instance.evaluator.evaluate(trained_agent)[1][1:]
            
            # Exploration mode doubles the acceptance temperature, so worse
            # candidates are accepted more readily here.
            if self._simulated_annealing_accept(original_obj, trained_obj, ea_instance, exploration_mode=True, individual_idx=idx):
                ea_instance.pops[idx] = trained_agent
            # Otherwise the original individual stays in the population.
    
    def _find_individuals_in_dense_regions(self, ea_instance, dense_centers):
        """在密集区域附近找到所有个体"""
        dense_indices = []
        
        # 评估所有个体的目标值
        all_objs = []
        for i in range(ea_instance.n_pop):
            obj = ea_instance.evaluator.evaluate(ea_instance.pops[i])[1][1:]
            all_objs.append(obj)
        
        # 计算密集区域的平均半径
        if dense_centers:
            # 计算帕累托前沿的平均距离作为基准半径
            pf_objs = [ea_instance.pareto_front[key]['objs'] for key in ea_instance.pareto_front.keys()]
            if len(pf_objs) > 1:
                avg_pf_distance = np.mean([np.linalg.norm(np.array(pf_objs[i]) - np.array(pf_objs[j])) 
                                         for i in range(len(pf_objs)) for j in range(i+1, len(pf_objs))])
                search_radius = avg_pf_distance * self.radius_factor
            else:
                search_radius = 5.0  # 默认半径
            
            # 在每个密集中心附近搜索个体
            for center in dense_centers:
                for i, obj in enumerate(all_objs):
                    distance = np.linalg.norm(np.array(obj) - np.array(center))
                    if distance <= search_radius and i not in dense_indices:
                        dense_indices.append(i)
        
        return dense_indices
    
    def _define_sparse_sectors(self, ea_instance, sparse_pairs):
        """定义稀疏区域的扇形区域"""
        sectors = []
        
        for (id1, id2), _ in sparse_pairs:
            obj1 = ea_instance.pareto_front[id1]['objs']
            obj2 = ea_instance.pareto_front[id2]['objs']
            
            # 计算两点间的中点和方向向量
            midpoint = [(obj1[i] + obj2[i]) / 2 for i in range(len(obj1))]
            direction = [obj2[i] - obj1[i] for i in range(len(obj1))]
            
            # 计算垂直方向（用于定义扇形的宽度）
            if len(direction) == 2:  # 二维情况
                perpendicular = [-direction[1], direction[0]]
            else:  # 高维情况的简化处理
                perpendicular = [1.0 if i == 0 else 0.0 for i in range(len(direction))]
            
            sectors.append({
                'center': midpoint,
                'direction': direction,
                'perpendicular': perpendicular,
                'length': np.linalg.norm(direction)
            })
        
        return sectors
    
    def _find_individuals_in_sparse_regions(self, ea_instance, sparse_sectors):
        """在稀疏区域的扇形内找到所有个体"""
        sparse_indices = []
        
        # 评估所有个体的目标值
        all_objs = []
        for i in range(ea_instance.n_pop):
            obj = ea_instance.evaluator.evaluate(ea_instance.pops[i])[1][1:]
            all_objs.append(obj)
        
        # 在每个稀疏扇形区域内搜索个体
        for sector in sparse_sectors:
            center = np.array(sector['center'])
            direction = np.array(sector['direction'])
            perpendicular = np.array(sector['perpendicular'])
            sector_length = sector['length']
            
            # 扇形参数 - 使用可配置参数
            max_distance = sector_length * self.length_factor  # 扇形长度
            max_width = sector_length * self.width_factor      # 扇形宽度
            
            for i, obj in enumerate(all_objs):
                obj_vec = np.array(obj)
                to_obj = obj_vec - center
                
                # 检查是否在扇形内
                if len(direction) > 0 and np.linalg.norm(direction) > 0:
                    # 投影到主方向
                    direction_norm = direction / np.linalg.norm(direction)
                    projection_length = np.dot(to_obj, direction_norm)
                    if self.use_abs_projection_length:
                        projection_length = abs(projection_length)
                    
                    # 投影到垂直方向
                    if len(perpendicular) > 0 and np.linalg.norm(perpendicular) > 0:
                        perp_norm = perpendicular / np.linalg.norm(perpendicular)
                        projection_width = abs(np.dot(to_obj, perp_norm))
                    else:
                        projection_width = 0
                    
                    # 检查是否在扇形范围内
                    if (0 <= projection_length <= max_distance and 
                        projection_width <= max_width and 
                        i not in sparse_indices):
                        sparse_indices.append(i)
        
        return sparse_indices
    
    def _simulated_annealing_accept(self, original_obj, trained_obj, ea_instance, exploration_mode=False, individual_idx=None):
        """
        模拟退火接受机制
        
        Args:
            original_obj: 原始个体的目标值
            trained_obj: 训练后个体的目标值
            ea_instance: EA实例
            exploration_mode: 是否为探索模式（稀疏区域搜索）
            individual_idx: 个体索引，用于获取对应权重
        
        Returns:
            bool: 是否接受训练后的个体
        """
        # 1. 如果训练后的个体支配原个体，直接接受
        if all(x >= y for x, y in zip(original_obj, trained_obj)) and any(x > y for x, y in zip(original_obj, trained_obj)):
            return True
        
        # 2. 计算目标改进程度（使用加权和）
        # 尝试获取个体对应的权重
        current_weights = [0.5, 0.5]  # 默认权重
        if hasattr(ea_instance, 'weights') and len(ea_instance.weights) > 0:
            if individual_idx is not None and individual_idx < len(ea_instance.weights):
                # 使用个体对应的权重
                current_weights = ea_instance.weights[individual_idx]
            else:
                # 使用第一个权重作为默认
                current_weights = ea_instance.weights[0] if ea_instance.weights else [0.5, 0.5]
        
        # 计算加权目标值
        original_weighted = sum(w * obj for w, obj in zip(current_weights, original_obj))
        trained_weighted = sum(w * obj for w, obj in zip(current_weights, trained_obj))
        
        # 目标改进（正值表示改进）
        improvement = original_weighted - trained_weighted
        
        # 3. 如果有改进，直接接受
        if improvement > 0:
            return True
        
        # 4. 如果没有改进，使用模拟退火概率接受
        # 计算当前"温度" - 基于迭代次数和SA参数
        current_iteration = len(ea_instance.hv_his) if hasattr(ea_instance, 'hv_his') else 1
        
        # 使用EA实例的SA参数
        if hasattr(ea_instance, 'sa_ratio'):
            temperature = ea_instance.sa_ratio
        else:
            # 如果没有SA参数，自己计算温度
            initial_temp = 1.0
            cooling_rate = 0.995
            temperature = initial_temp * (cooling_rate ** current_iteration)
        
        # 探索模式下温度更高（更容易接受差解）
        if exploration_mode:
            temperature *= 2.0
        
        # 避免温度过低导致数值问题
        temperature = max(temperature, 1e-6)
        
        # 计算接受概率
        try:
            delta_energy = abs(improvement)  # 能量差（正值）
            accept_prob = math.exp(-delta_energy / temperature)
        except OverflowError:
            accept_prob = 0.0
        
        # 概率接受
        return random.random() < accept_prob
    
    def _balance_diversity(self, ea_instance):
        """Randomise weights, then train selected individuals with SA gating.

        Each selected individual is trained on a deep copy for 8 episodes; the
        copy replaces the population entry only when
        ``_simulated_annealing_accept`` approves it.
        """
        # Random weight adjustment first, then extra training.
        ea_instance.pareto_random_weights_adjust(ratio=0.4)

        for idx in ea_instance.select_policies(ratio=0.6):
            # Baseline objectives, measured on a deep copy of the individual.
            baseline_obj = ea_instance.evaluator.evaluate(deepcopy(ea_instance.pops[idx]))[1][1:]

            # Train a separate deep copy so the population entry is untouched
            # unless the candidate is accepted.
            candidate = deepcopy(ea_instance.pops[idx])
            ea_instance.exe_task(candidate, ea_instance.weights[idx], n_ep=8)
            candidate_obj = ea_instance.evaluator.evaluate(candidate)[1][1:]

            accepted = self._simulated_annealing_accept(
                baseline_obj, candidate_obj, ea_instance, individual_idx=idx)
            if accepted:
                ea_instance.pops[idx] = candidate
    
    def _intensify_convergence(self, ea_instance):
        """Push towards convergence: adjust weights, then train front members.

        After the ascent-direction weight adjustment, the individuals indexed
        by the first three Pareto-front keys are each trained on a deep copy
        for 20 episodes, and the copy replaces the population entry only when
        ``_simulated_annealing_accept`` approves it.

        NOTE(review): Pareto-front keys are used directly as population
        indices -- confirm against the EA implementation.
        """
        # Less exploration, more exploitation.
        chosen = ea_instance.select_policies(ratio=0.2)
        ea_instance.policy_pareto_ascent_direction_weight_adjust(chosen)

        if not ea_instance.pareto_front:
            return

        # Intensive training on the first three front entries.
        for idx in list(ea_instance.pareto_front)[:3]:
            # Baseline objectives, measured on a deep copy of the individual.
            baseline_obj = ea_instance.evaluator.evaluate(deepcopy(ea_instance.pops[idx]))[1][1:]

            # More episodes than the other strategies use.
            candidate = deepcopy(ea_instance.pops[idx])
            ea_instance.exe_task(candidate, ea_instance.weights[idx], n_ep=20)
            candidate_obj = ea_instance.evaluator.evaluate(candidate)[1][1:]

            # Non-exploration mode keeps the acceptance temperature undoubled.
            accepted = self._simulated_annealing_accept(
                baseline_obj, candidate_obj, ea_instance,
                exploration_mode=False, individual_idx=idx)
            if accepted:
                ea_instance.pops[idx] = candidate
    
    def _adaptive_weights(self, ea_instance):
        """Combine two weight adjustments, then train selected individuals.

        Each selected individual is trained on a deep copy for 12 episodes; the
        copy replaces the population entry only when
        ``_simulated_annealing_accept`` approves it.
        """
        # Two weight-adjustment strategies in sequence.
        ascent_targets = ea_instance.select_policies(ratio=0.3)
        ea_instance.policy_pareto_ascent_direction_weight_adjust(ascent_targets)
        ea_instance.pareto_adaptive_fine_tuning_weights_adjust(pb=5)

        for idx in ea_instance.select_policies(ratio=0.4):
            # Baseline objectives, measured on a deep copy of the individual.
            baseline_obj = ea_instance.evaluator.evaluate(deepcopy(ea_instance.pops[idx]))[1][1:]

            # Train a separate deep copy so the population entry is untouched
            # unless the candidate is accepted.
            candidate = deepcopy(ea_instance.pops[idx])
            ea_instance.exe_task(candidate, ea_instance.weights[idx], n_ep=12)
            candidate_obj = ea_instance.evaluator.evaluate(candidate)[1][1:]

            accepted = self._simulated_annealing_accept(
                baseline_obj, candidate_obj, ea_instance, individual_idx=idx)
            if accepted:
                ea_instance.pops[idx] = candidate
    
 