import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm
from copy import deepcopy
from vd_env import *
from dynaQ import *
from pymoo.indicators.hv import HV
from utils import *


def generate_angles_and_trig_values(n):
    """Draw ``n`` random two-component weight vectors.

    Each vector is ``[sin^2(a), cos^2(a)]`` for an angle ``a`` sampled
    uniformly from [0, 90) degrees with ``np.random.uniform``.

    Args:
        n: Number of weight vectors to generate.

    Returns:
        A list of ``n`` two-element lists of Python floats.
    """
    theta = np.radians(np.random.uniform(0, 90, n))
    sin_sq = np.sin(theta) ** 2
    cos_sq = np.cos(theta) ** 2
    return [[float(s), float(c)] for s, c in zip(sin_sq, cos_sq)]


def generate_weights_batch_dfs(i, obj_num, min_weight, max_weight, delta_weight, weight, weights_batch):
    """Enumerate evenly spaced weight vectors by depth-first search (from PGMORL).

    Args:
        i: Index of the objective whose weight is chosen at this depth.
        obj_num: Total number of objectives.
        min_weight: Smallest value tried for each free component.
        max_weight: Largest value tried for each free component.
        delta_weight: Step between successive candidate values.
        weight: Prefix holding the first ``i`` components. It is appended to
            in place, so the list handed to the outermost call is modified.
        weights_batch: Output list; a copy of every completed vector is
            appended to it.

    Returns:
        None. Results are delivered through ``weights_batch``.
    """
    if i == obj_num - 1:
        # The last component is whatever remains after the first i components.
        weight.append(1.0 - np.sum(weight[0:i]))
        weights_batch.append(deepcopy(weight))
        return

    half_step = 0.5 * delta_weight  # tolerance so accumulated float drift does not drop the end point
    prefix = weight
    candidate = min_weight
    while candidate < max_weight + half_step and np.sum(prefix[0:i]) + candidate < 1.0 + half_step:
        prefix.append(candidate)
        generate_weights_batch_dfs(i + 1, obj_num, min_weight, max_weight, delta_weight, prefix, weights_batch)
        # The recursion appended deeper components; continue from a fresh i-length prefix.
        prefix = prefix[0:i]
        candidate += delta_weight


def euclidean_distance(point1, point2):
    """Return the Euclidean distance between two points given as sequences."""
    diff = np.array(point1) - np.array(point2)
    squared_sum = np.sum(diff ** 2)
    return np.sqrt(squared_sum)


def calculate_distances(pareto_front):
    """Measure the gaps between neighbouring Pareto-front points.

    The entries are ordered by the first element of their ``'objs'`` vector,
    and the Euclidean distance between each pair of consecutive entries is
    computed.

    Args:
        pareto_front: dict mapping an id to a dict that has an ``'objs'`` key.

    Returns:
        dict mapping ``(id1, id2)`` tuples of consecutive entries (in sorted
        order) to the distance between their objective vectors. Empty when
        fewer than two entries are given.
    """
    ordered = sorted(pareto_front.items(), key=lambda entry: entry[1]['objs'][0])
    return {
        (left_id, right_id): np.linalg.norm(np.array(left['objs']) - np.array(right['objs']))
        for (left_id, left), (right_id, right) in zip(ordered, ordered[1:])
    }


class Evaluator:
    """Roll out an agent for one episode and accumulate its reward vector."""

    def __init__(self, env, n_obj=3, ref=50):
        """Store the environment and build the reference vector.

        Args:
            env: Environment; must expose ``n_customers``.
            n_obj: Number of objective slots accumulated by ``evaluate``.
            ref: Accepted for compatibility but not used; the reference
                values are taken from ``env.n_customers`` instead.
        """
        self.env = env
        self.n_obj = n_obj
        self.ref = [0, env.n_customers, env.n_customers]

    def evaluate(self, agent):
        """Run one episode with ``agent`` and return its rewards and objectives.

        The episode continues until ``self.env.done`` is set. Reward
        components are summed per objective; afterwards objectives 1 and 2
        are replaced by ``ref - total``.

        Args:
            agent: Policy exposing ``take_action_evaluate(state, actions)``.

        Returns:
            ``(rewards, objs)``: the list of per-step reward vectors and the
            list of accumulated (and reference-shifted) objective values.
        """
        rewards = []
        objs = [0] * self.n_obj
        state = self.env.reset()
        while not self.env.done:
            # Only the last element of the state is passed to the agent.
            action = agent.take_action_evaluate(state[-1], self.env.available_actions)
            state, reward, _done, _info = self.env.step(action)
            rewards.append(reward)
            for slot in range(len(reward)):
                objs[slot] += reward[slot]
        for slot in (1, 2):
            objs[slot] = self.ref[slot] - objs[slot]
        return rewards, objs


class EA:
    def __init__(self, agent, env, evaluator, n_pop=10, n_obj=2, iter_max=1000,
                 selection_method='adaptive', min_weight=0.0, max_weight=1.0, delta_weight=0.2):
        """Set up the population, its weight vectors and the bookkeeping state.

        Args:
            agent: Template agent; the population consists of deep copies of it.
            env: Environment; must expose ``ref``.
            evaluator: Object used to evaluate agents.
            n_pop: Population size.
            n_obj: Number of objectives.
            iter_max: Total number of main-loop iterations.
            selection_method: One of 'adaptive', 'moead', 'pfa', 'pa2d_adaptive'.
            min_weight: Lower bound for generated weights.
            max_weight: Upper bound for generated weights.
            delta_weight: Grid step for MOEA/D / PFA weight generation
                (stored as ``delta_weight_moea``).
        """
        # Collaborators and problem configuration.
        self.base_agent = agent
        self.env = env
        self.evaluator = evaluator
        self.n_pop = n_pop
        self.n_obj = n_obj
        self.ref = self.env.ref
        self.iter_max = iter_max
        self.selection_method = selection_method
        self.min_weight = min_weight
        self.max_weight = max_weight
        self.delta_weight_moea = delta_weight  # grid step used for MOEA/D weight generation

        # Population: independent copies of the template agent.
        self.pops = [deepcopy(self.base_agent) for _ in range(self.n_pop)]

        # Weight vectors: a uniform grid for MOEA/D, random sin^2/cos^2 pairs otherwise.
        self.weights = (
            self._generate_moead_weights()
            if self.selection_method == 'moead'
            else generate_angles_and_trig_values(self.n_pop)
        )

        # Histories and the current Pareto front.
        self.objs = []
        self.objs_his = []
        self.hv_his = []
        self.sp_his = []  # spacing metric history
        self.pareto_front = {}

        # Adaptive weight step and simulated-annealing style acceptance settings.
        self.delta_weight = 0.01
        self.cooling_ratio = 0.995
        self.sa_ratio = 0.1

        # PFA bookkeeping.
        self.iteration_count = 0
        self.warmup_iter = 0  # no warm-up in this framework

        # PA2D: iteration index at which the second phase starts (30% of the run).
        self.pa2d_phase_transition = int(iter_max * 0.3)

    def _generate_moead_weights(self):
        """Build ``n_pop`` weight vectors on a uniform grid for MOEA/D.

        The grid is produced by ``generate_weights_batch_dfs``. If it holds
        more vectors than the population needs, an evenly spaced subset is
        kept; if it holds fewer, randomly chosen vectors are duplicated.

        Every returned vector is an independent list. Previously the padding
        step appended the very object returned by ``random.choice``, so two
        population slots could share one list and an in-place adjustment of
        one weight silently changed the other.

        Returns:
            A list of ``n_pop`` weight vectors (lists).

        Raises:
            IndexError: If the grid is empty while ``n_pop`` is positive
                (raised by ``random.choice``).
        """
        weights_batch = []
        generate_weights_batch_dfs(0, self.n_obj, self.min_weight, self.max_weight,
                                   self.delta_weight_moea, [], weights_batch)

        # Too many grid points: keep an evenly spaced subset.
        if len(weights_batch) > self.n_pop:
            indices = np.linspace(0, len(weights_batch) - 1, self.n_pop, dtype=int)
            weights_batch = [weights_batch[i] for i in indices]

        # Too few grid points: pad with copies of randomly chosen vectors.
        while len(weights_batch) < self.n_pop:
            weights_batch.append(list(random.choice(weights_batch)))

        # Normalise every entry to a fresh list (never a tuple, never shared).
        return [list(w) for w in weights_batch[:self.n_pop]]

    def _evaluate_weighted_sum(self, obj, weight):
        """Return the weighted-sum scalarisation of ``obj`` under ``weight``.

        Components are paired positionally; extra entries of the longer
        sequence are ignored.
        """
        total = 0
        for component, factor in zip(obj, weight):
            total = total + component * factor
        return total

    def select_policies(self, ratio=0.3):
        """Pick each population index independently with probability ``ratio``.

        One ``random.random()`` draw is made per index, in index order.

        Returns:
            The list of selected indices, in ascending order.
        """
        return [idx for idx in range(self.n_pop) if random.random() <= ratio]

    def moead_weight_selection(self):
        """For every weight vector, pick the individual with the largest weighted sum.

        Ties go to the individual with the lowest index. A weight is skipped
        when no individual scores above negative infinity (for example when
        ``self.objs`` is empty).

        NOTE(review): the dominance test in ``update_and_cal_pareto_front``
        treats smaller objective values as better, whereas this method keeps
        the largest weighted sum - confirm the intended direction.

        Returns:
            ``(agents, weights)``: two dicts keyed by weight index, holding a
            deep copy of the winning agent and the weight vector itself.
        """
        chosen_agents = {}
        chosen_weights = {}

        for slot, weight in enumerate(self.weights):
            winner = None
            top_score = float('-inf')
            for candidate, obj in enumerate(self.objs):
                score = self._evaluate_weighted_sum(obj, weight)
                if score > top_score:
                    winner, top_score = candidate, score

            if winner is None:
                continue
            chosen_agents[slot] = deepcopy(self.pops[winner])
            chosen_weights[slot] = weight

        return chosen_agents, chosen_weights

    def pfa_weight_adjustment(self):
        """Shift the weight grid progressively (PFA style, two objectives only).

        Each base weight on the grid ``min_weight .. max_weight`` (step
        ``delta_weight_moea``) is moved forward by
        ``progress * delta_weight_moea``, where ``progress`` is
        ``iteration_count / iter_max`` clipped to [0, 1]. ``self.weights`` is
        then replaced by ``n_pop`` vectors ``[w, 1 - w]``: an evenly spaced
        subset of the grid when it is larger than the population, otherwise
        the grid padded with random vectors.

        Padding vectors are now ``[r, 1 - r]``. Previously the second
        component was taken from the preceding vector
        (``1 - adjusted_weights[-1][0]``), so padded weights did not sum to 1.

        For more than two objectives a warning is printed and the result of
        ``select_policies()`` is returned without touching the weights.

        Returns:
            List of population indices to train (all of them in the
            two-objective case).
        """
        if self.n_obj > 2:
            print("Warning: PFA only supports 2-objective problems, falling back to adaptive method")
            return self.select_policies()

        # Fraction of the run completed so far.
        delta_ratio = np.clip(self.iteration_count / self.iter_max, 0.0, 1.0)

        step = self.delta_weight_moea
        base_weights = np.arange(self.min_weight, self.max_weight + 0.5 * step, step)

        adjusted_weights = []
        for base in base_weights:
            w = float(np.clip(base + delta_ratio * step, self.min_weight, self.max_weight))
            adjusted_weights.append([abs(w), abs(1.0 - w)])

        if len(adjusted_weights) > self.n_pop:
            # More grid points than policies: keep an evenly spaced subset.
            indices = np.linspace(0, len(adjusted_weights) - 1, self.n_pop, dtype=int)
            self.weights = [adjusted_weights[i] for i in indices]
        else:
            # Fewer grid points than policies: pad with random normalised weights.
            while len(adjusted_weights) < self.n_pop:
                first = random.random()
                adjusted_weights.append([first, 1.0 - first])
            self.weights = adjusted_weights[:self.n_pop]

        # PFA trains the whole current population.
        return list(range(self.n_pop))

    def policy_pareto_ascent_direction_weight_adjust(self, select_policies_index):
        """Nudge the weights of the selected policies along their recent progress.

        Requires at least two entries in ``self.objs_his``; otherwise nothing
        happens. For each selected policy whose largest weight would stay
        below 1 after adding ``delta_weight``, the change in objectives
        between the last two history entries decides which of the first two
        weight components gains ``delta_weight`` (the other loses it).
        The weights are modified in place.
        """
        if len(self.objs_his) < 2:
            return

        for i in select_policies_index:
            pair = self.weights[i]
            # Skip policies whose weight is already too close to the upper end.
            if not (max(pair) + self.delta_weight < 1):
                continue

            delta_obj = [self.objs_his[-1][i][j] - self.objs_his[-2][i][j] for j in range(self.n_obj)]
            # Cross-multiplied comparison of objective change against the current weights.
            if delta_obj[0] * pair[1] > delta_obj[1] * pair[0]:
                pair[0] += self.delta_weight
                pair[1] -= self.delta_weight
            else:
                pair[0] -= self.delta_weight
                pair[1] += self.delta_weight

    def pareto_adaptive_fine_tuning_weights_adjust(self, pb=5):
        """Shift weights of policies bordering the widest Pareto-front gaps.

        Gaps between neighbouring front points come from
        ``calculate_distances`` and are visited from widest to narrowest. For
        each visited pair whose weights would stay below 1 after adding
        ``delta_weight``, one of the two policies is chosen at random and its
        first weight component is moved by half of ``delta_weight`` towards
        the other policy's first component (the second component moves the
        opposite way). Weights are modified in place.

        NOTE(review): the narrowest gap is never visited, so a front with
        exactly two points is never adjusted - confirm this is intended.

        Args:
            pb: Maximum number of adjustments to make.

        Returns:
            List of adjusted policy ids (an id may appear more than once).
        """
        gaps = calculate_distances(self.pareto_front)
        ranked_pairs = [pair for pair, _ in sorted(gaps.items(), key=lambda kv: kv[1], reverse=True)]

        adjusted = []
        for id_1, id_2 in ranked_pairs[:-1]:
            if len(adjusted) >= pb:
                break
            if not (max([max(self.weights[id_1]), max(self.weights[id_2])]) + self.delta_weight < 1):
                continue

            step = self.delta_weight * 0.5
            first_is_larger = self.weights[id_1][0] > self.weights[id_2][0]
            pick_first = random.random() <= 0.5
            target = id_1 if pick_first else id_2
            # The chosen policy moves towards its neighbour: the one with the
            # larger first component decreases it, the other one increases it.
            shift = -step if first_is_larger == pick_first else step
            self.weights[target][0] += shift
            self.weights[target][1] -= shift
            adjusted.append(target)

        return adjusted

    def pareto_random_weights_adjust(self, ratio=0.1):
        """Randomly perturb the weights of a random subset of policies.

        Each policy is selected with probability ``ratio``; a selected policy
        has ``delta_weight`` moved from one of its first two weight
        components to the other, the direction being chosen by a fair coin.
        Weights are modified in place.

        Returns:
            List of the selected policy indices.
        """
        chosen = []
        for index in range(self.n_pop):
            if not (random.random() <= ratio):
                continue
            chosen.append(index)
            # A negative shift moves weight from component 0 to component 1.
            shift = -self.delta_weight if random.random() <= 0.5 else self.delta_weight
            self.weights[index][0] += shift
            self.weights[index][1] -= shift
        return chosen

    def pa2d_pareto_ascent_direction_weight_adjust(self, select_policies_index):
        """PA2D-inspired weight adjustment driven by recent objective changes.

        Needs at least two entries in ``self.objs_his``; indices outside
        ``self.objs`` are ignored. Weights are modified in place.

        Two-objective case, with ``gain`` the weighted sum of the objective
        changes between the last two history entries:
          * ``gain > 0``: add ``0.3 * delta_weight`` to the weight of the
            objective that changed less (capped at 1) and renormalise.
          * ``gain`` below -0.01 or above... i.e. ``abs(gain) > 0.01`` while
            not positive: reset the first weight to about 0.5 (both
            objectives decreased) or move it by 0.1 towards the objective
            with the larger change, bounded to [0.2, 0.8].
          * otherwise: apply a small random perturbation, clipped to [0.1, 0.9].

        Other objective counts: add ``delta_weight`` to whichever of the
        first two components the cross-multiplied comparison favours.

        NOTE(review): the branch logic treats positive changes as
        improvement, whereas the dominance test in
        ``update_and_cal_pareto_front`` treats smaller as better - confirm.
        """
        if len(self.objs_his) < 2:
            return

        latest, previous = self.objs_his[-1], self.objs_his[-2]
        for i in select_policies_index:
            if i >= len(self.objs):
                continue

            w = self.weights[i]
            delta_obj = [latest[i][j] - previous[i][j] for j in range(self.n_obj)]

            if self.n_obj != 2:
                # Generic rule: only the first two components are touched.
                if delta_obj[0] * w[1] > delta_obj[1] * w[0]:
                    w[0] = min(1.0, w[0] + self.delta_weight)
                    w[1] = max(0.0, 1.0 - w[0])
                else:
                    w[1] = min(1.0, w[1] + self.delta_weight)
                    w[0] = max(0.0, 1.0 - w[1])
                continue

            gain = delta_obj[0] * w[0] + delta_obj[1] * w[1]

            if gain > 0:
                # Current direction pays off: lean slightly towards the
                # objective that changed less.
                step = self.delta_weight * 0.3
                if abs(delta_obj[0]) > abs(delta_obj[1]):
                    w[1] = min(1.0, w[1] + step)
                    w[0] = 1.0 - w[1]
                else:
                    w[0] = min(1.0, w[0] + step)
                    w[1] = 1.0 - w[0]
            elif abs(gain) > 0.01:
                # Noticeable loss: re-balance more decisively.
                if delta_obj[0] < 0 and delta_obj[1] < 0:
                    balance_factor = 0.5 + 0.05 * (np.random.random() - 0.5)
                elif delta_obj[0] > delta_obj[1]:
                    balance_factor = min(0.8, w[0] + 0.1)
                else:
                    balance_factor = max(0.2, w[0] - 0.1)
                w[0] = balance_factor
                w[1] = 1.0 - balance_factor
            else:
                # Negligible change: small random exploration.
                perturbation = 0.05 * (np.random.random() - 0.5)
                w[0] = np.clip(w[0] + perturbation, 0.1, 0.9)
                w[1] = 1.0 - w[0]

    def pa2d_policy_selection(self, ratio=0.3):
        """Select policies by angular partition of objective space (PA2D-inspired).

        Falls back to ``select_policies(ratio)`` when no objectives have been
        recorded yet or when the problem is not two-objective.

        In the two-objective case the plane is split into four equal angular
        sectors around the origin. From each non-empty sector the policy with
        the largest single objective value is taken (ties go to the higher
        index). The selection is then padded with random unselected policies
        and truncated so that it has ``max(1, int(n_policies * ratio))``
        entries.

        Returns:
            List of selected policy indices.
        """
        if not self.objs:
            return self.select_policies(ratio)
        if self.n_obj != 2:
            return self.select_policies(ratio)

        objectives = np.array(self.objs)
        n_policies = len(objectives)

        # Polar angle of every objective vector, mapped from (-pi, pi] into [0, 2*pi).
        angles = np.arctan2(objectives[:, 1], objectives[:, 0])
        angles = np.where(angles < 0, angles + 2*np.pi, angles)

        n_partitions = 4
        partition_size = 2 * np.pi / n_partitions
        picked = []
        for partition_id in range(n_partitions):
            lower = partition_id * partition_size
            upper = (partition_id + 1) * partition_size
            candidates = [
                (max(objectives[idx]), idx)
                for idx in range(n_policies)
                if lower <= angles[idx] < upper
            ]
            if candidates:
                # Largest (score, index) tuple: best score, higher index on ties.
                picked.append(max(candidates)[1])

        # Pad with random, not yet selected policies until the target size is reached.
        target_num = max(1, int(n_policies * ratio))
        while len(picked) < target_num:
            leftovers = [i for i in range(n_policies) if i not in picked]
            if not leftovers:
                break
            picked.append(random.choice(leftovers))

        return picked[:target_num]

    def pa2d_pareto_adaptive_fine_tuning_weights_adjust(self, pb=5):
        """Steer weights towards the widest gaps of the Pareto front (PA2D-inspired).

        Returns an empty list when the front has fewer than two points. The
        ``pb`` widest gaps reported by ``calculate_distances`` are processed
        from widest to narrowest, and weights are modified in place.

        Two objectives: a target first-component weight is derived from the
        midpoint of the two neighbouring objective vectors (clipped to
        [0.1, 0.9]); each neighbour's weight is averaged with that target,
        the second neighbour using a target pulled 10% towards 0.5. Pairs
        whose midpoint components do not sum to a positive value are skipped.

        Other objective counts: one of the two neighbours is chosen at
        random and its first weight component is moved by half of
        ``delta_weight`` in a random direction, clipped to [0, 1].

        Args:
            pb: Maximum number of gaps processed and of ids returned.

        Returns:
            List of adjusted policy ids, at most ``pb`` long.
        """
        if not self.pareto_front or len(self.pareto_front) < 2:
            return []

        gaps = calculate_distances(self.pareto_front)
        if not gaps:
            return []

        widest_first = sorted(gaps.items(), key=lambda item: item[1], reverse=True)
        picked = []
        seen_pairs = set()

        for (id1, id2), _gap in widest_first[:pb]:
            if len(picked) >= pb:
                break

            pair_key = tuple(sorted([id1, id2]))
            if pair_key in seen_pairs:
                continue
            seen_pairs.add(pair_key)

            if self.n_obj != 2:
                # Simplified rule: jitter the first component of one random neighbour.
                if id1 < len(self.weights) and id2 < len(self.weights):
                    half_step = self.delta_weight * 0.5
                    target = id1 if random.random() <= 0.5 else id2
                    self.weights[target][0] += half_step * (1 if random.random() > 0.5 else -1)
                    self.weights[target][0] = np.clip(self.weights[target][0], 0.0, 1.0)
                    self.weights[target][1] = 1.0 - self.weights[target][0]
                    picked.append(target)
                continue

            # Midpoint of the two objective vectors that border the gap.
            obj1 = np.array(self.pareto_front[id1]['objs'])
            obj2 = np.array(self.pareto_front[id2]['objs'])
            mid_point = (obj1 + obj2) / 2
            if not (mid_point[0] + mid_point[1] > 0):
                continue

            weight_factor = mid_point[0] / (mid_point[0] + mid_point[1])
            weight_factor = np.clip(weight_factor, 0.1, 0.9)

            if id1 < len(self.weights):
                # Average the current weight with the midpoint-derived target.
                self.weights[id1][0] = 0.5 * self.weights[id1][0] + 0.5 * weight_factor
                self.weights[id1][1] = 1.0 - self.weights[id1][0]
                picked.append(id1)

            if id2 < len(self.weights) and len(picked) < pb:
                # Slightly different target (pulled towards 0.5) for diversity.
                adjusted_factor = weight_factor + 0.1 * (0.5 - weight_factor)
                self.weights[id2][0] = 0.5 * self.weights[id2][0] + 0.5 * adjusted_factor
                self.weights[id2][1] = 1.0 - self.weights[id2][0]
                picked.append(id2)

        return picked[:pb]

    def compute_spacing(self, objectives_list):
        """Compute the SP (Spacing) metric of a set of objective vectors.

        For every solution the Euclidean distance to its nearest other
        solution is taken; the result is the sample standard deviation
        (``n - 1`` denominator) of those nearest-neighbour distances. Lower
        values indicate a more even spread.

        Args:
            objectives_list: List of objective vectors.

        Returns:
            float: The spacing value, or 0.0 when fewer than two solutions
            are given.
        """
        if len(objectives_list) < 2:
            return 0.0

        points = np.array(objectives_list)
        count = len(points)

        # Nearest-neighbour distance of every solution.
        nearest = []
        for i in range(count):
            gaps = [
                np.sqrt(np.sum((points[i] - points[j]) ** 2))
                for j in range(count)
                if j != i
            ]
            if gaps:
                nearest.append(min(gaps))

        # A standard deviation needs at least two samples.
        if len(nearest) < 2:
            return 0.0

        centre = np.mean(nearest)
        variance = np.sum([(d - centre) ** 2 for d in nearest]) / (len(nearest) - 1)
        return np.sqrt(variance)

    def update_and_cal_pareto_front(self):
        """Re-evaluate the population, rebuild the Pareto front and return the hypervolume.

        Side effects:
          * ``self.objs`` is replaced by the evaluated objectives of every
            agent (the first objective returned by the evaluator is dropped).
          * ``self.objs`` is appended to ``self.objs_his``.
          * ``self.pareto_front`` is rebuilt with the non-dominated
            individuals, where smaller objective values are better.
          * The spacing of the front (0.0 for an empty front) is appended to
            ``self.sp_his``.

        Returns:
            Hypervolume of all current objective vectors with respect to the
            reference point ``[self.ref, self.ref]``, computed with pymoo's
            ``HV``.
        """
        def dominates(first, second):
            # first is no worse in every objective and strictly better in at least one.
            return (all(x <= y for x, y in zip(first, second))
                    and any(x < y for x, y in zip(first, second)))

        self.objs = [self.evaluator.evaluate(agent)[1][1:] for agent in self.pops]
        self.objs_his.append(self.objs)

        self.pareto_front = {}
        for i, candidate in enumerate(self.objs):
            # An individual is never compared with itself.
            beaten = any(
                dominates(other, candidate)
                for j, other in enumerate(self.objs)
                if j != i
            )
            if beaten:
                continue
            self.pareto_front[i] = {
                'id': i,
                'weight': self.weights[i],
                'objs': self.objs[i],
                'agent': self.pops[i],
            }

        # Spacing metric of the current front.
        if self.pareto_front:
            front_objs = [entry['objs'] for entry in self.pareto_front.values()]
            self.sp_his.append(self.compute_spacing(front_objs))
        else:
            self.sp_his.append(0.0)

        objs_np = np.array(self.objs)
        # The reference point uses the same value for both objectives.
        reference_point = np.array([self.ref, self.ref])
        indicator = HV(reference_point)
        return indicator(objs_np)

    def exe_task_qlearning(self, agent, weight, n_ep=5):
        """Train a tabular Q-learning agent for ``n_ep`` episodes on the scalarised reward.

        The scalar reward of a step is
        ``rewardv[1] * weight[0] + rewardv[2] * weight[1]``; the first reward
        component is not used. Only the last element of the environment
        state is passed to the agent. ``agent.epsilon`` is multiplied by
        ``agent.epsilon_degrade`` after every step.

        After each episode the undiscounted return is appended to
        ``agent.return_list`` and a snapshot of ``weight`` to
        ``agent.weight_list``. A copy is stored because other methods of
        this class adjust weight lists in place; storing the live list made
        earlier history entries change retroactively.

        Args:
            agent: Agent exposing ``take_action``, ``update``, ``epsilon``,
                ``epsilon_degrade``, ``return_list`` and ``weight_list``.
            weight: Two-component weight vector.
            n_ep: Number of training episodes.
        """
        for _ in range(n_ep):
            episode_return = 0
            state = self.env.reset()[-1]
            done = False
            while not done:
                action = agent.take_action(state, self.env.available_actions)
                observation, rewardv, done, _ = self.env.step(action)
                reward = rewardv[1] * weight[0] + rewardv[2] * weight[1]
                next_state = observation[-1]
                episode_return += reward  # the return is accumulated without discounting
                agent.update(state, action, reward, next_state)
                agent.epsilon = agent.epsilon * agent.epsilon_degrade
                state = next_state
            agent.return_list.append(episode_return)
            agent.weight_list.append(list(weight))

    def exe_task_dqn(self, agent, weight, n_ep=5):
        """Train a DQN agent for ``n_ep`` episodes on the scalarised reward.

        The scalar reward of a step is
        ``rewardv[1] * weight[0] + rewardv[2] * weight[1]``. Every transition
        is stored in ``agent.replay_buffer``; once the buffer holds more than
        ``agent.minimal_size`` transitions, each step also samples a batch of
        ``agent.batch_size``, calls ``agent.update`` with it and multiplies
        ``agent.epsilon`` by ``agent.epsilon_degrade``.

        After each episode the return is appended to ``agent.return_list``
        and a snapshot of ``weight`` to ``agent.weight_list``. A copy is
        stored because other methods of this class adjust weight lists in
        place; storing the live list made earlier history entries change
        retroactively.

        Args:
            agent: Agent exposing ``take_action``, ``update``,
                ``replay_buffer``, ``minimal_size``, ``batch_size``,
                ``epsilon``, ``epsilon_degrade``, ``return_list`` and
                ``weight_list``.
            weight: Two-component weight vector.
            n_ep: Number of training episodes.
        """
        for _ in range(n_ep):
            episode_return = 0
            state = self.env.reset()
            done = False
            while not done:
                action = agent.take_action(state, self.env.available_actions)
                next_state, rewardv, done, _ = self.env.step(action)
                reward = rewardv[1] * weight[0] + rewardv[2] * weight[1]
                agent.replay_buffer.add(state, action, reward, next_state, done)
                state = next_state
                episode_return += reward
                # Train the Q-network only once the buffer holds enough transitions.
                if agent.replay_buffer.size() > agent.minimal_size:
                    b_s, b_a, b_r, b_ns, b_d = agent.replay_buffer.sample(agent.batch_size)
                    transition_dict = {
                        'states': b_s,
                        'actions': b_a,
                        'next_states': b_ns,
                        'rewards': b_r,
                        'dones': b_d
                    }
                    agent.update(transition_dict)
                    agent.epsilon = agent.epsilon * agent.epsilon_degrade
            agent.return_list.append(episode_return)
            agent.weight_list.append(list(weight))

    def exe_task(self, agent, weight, n_ep=5):
        """Dispatch training to the routine matching ``agent.who``.

        ``'q_learning'`` and ``'dqn'`` are supported; any other value is
        silently ignored.
        """
        trainers = {
            'q_learning': self.exe_task_qlearning,
            'dqn': self.exe_task_dqn,
        }
        trainer = trainers.get(agent.who)
        if trainer is not None:
            trainer(agent, weight, n_ep)

    def main(self):
        # initialize all policies
        for i in range(self.n_pop):
            agent = self.pops[i]
            weight = self.weights[i]
            self.exe_task(agent, weight, 10)
        hv = self.update_and_cal_pareto_front()
        self.hv_his.append(hv)
        print(f'initialize finish, initial hv: {hv}')
        print(f'Using selection method: {self.selection_method}')
        
        # main iteration
        for i_iter in range(10):  # 显示10个进度条
            # tqdm的进度条功能
            with tqdm(total=int(self.iter_max / 10), desc='Iteration %d' % i_iter) as pbar:
                for i_episode in range(int(self.iter_max / 10)):  # 每个进度条的序列数
                    self.iteration_count += 1
                    
                    if self.selection_method == 'moead':
                        # MOEA/D selection method
                        selected_agents, selected_weights = self.moead_weight_selection()
                        
                        # Train selected agents
                        for agent_id, agent in selected_agents.items():
                            weight = selected_weights[agent_id]
                            self.exe_task(agent, weight, 10)
                            
                            # Evaluate and potentially update population
                            origin_fit_obj = self.objs[agent_id]
                            iter_after_obj = self.evaluator.evaluate(agent)[1][1:]
                            if all(x >= y for x, y in zip(origin_fit_obj, iter_after_obj)) and any(x > y for x, y in zip(origin_fit_obj, iter_after_obj)):
                                self.pops[agent_id] = agent
                    
                    elif self.selection_method == 'pfa':
                        # PFA selection method
                        pa_select_index = self.pfa_weight_adjustment()
                        
                        # Train with adjusted weights
                        for i in pa_select_index:
                            agent = deepcopy(self.pops[i])
                            weight = self.weights[i]
                            self.exe_task(agent, weight, 10)
                            
                            origin_fit_obj = self.objs[i]
                            iter_after_obj = self.evaluator.evaluate(agent)[1][1:]
                            if all(x >= y for x, y in zip(origin_fit_obj, iter_after_obj)) and any(x > y for x, y in zip(origin_fit_obj, iter_after_obj)):
                                self.pops[i] = agent
                            else:
                                if random.random() <= self.sa_ratio:
                                    self.pops[i] = agent
                    
                    elif self.selection_method == 'pa2d_adaptive':
                        # PA2D-inspired adaptive method
                        current_iteration = i_iter * int(self.iter_max / 10) + i_episode
                        
                        if current_iteration < self.pa2d_phase_transition:
                            # Phase 1: PA2D-inspired Pareto ascent only
                            pa_select_index = self.pa2d_policy_selection()
                            self.pa2d_pareto_ascent_direction_weight_adjust(pa_select_index)
                            pb_select_index = []
                        else:
                            # Phase 2: Combined PA2D ascent and fine-tuning
                            total_policies = max(6, int(len(self.weights) * 0.3))
                            half_policies = total_policies // 2
                            
                            pa_select_index = self.pa2d_policy_selection(ratio=half_policies/len(self.weights))
                            self.pa2d_pareto_ascent_direction_weight_adjust(pa_select_index)
                            
                            pb_select_index = self.pa2d_pareto_adaptive_fine_tuning_weights_adjust(half_policies)
                        
                        # Train PA selected agents
                        for i in pa_select_index:
                            agent = deepcopy(self.pops[i])
                            weight = self.weights[i]
                            self.exe_task(agent, weight, 10)
                            
                            origin_fit_obj = self.objs[i]
                            iter_after_obj = self.evaluator.evaluate(agent)[1][1:]
                            if all(x >= y for x, y in zip(origin_fit_obj, iter_after_obj)) and any(x > y for x, y in zip(origin_fit_obj, iter_after_obj)):
                                self.pops[i] = agent
                            else:
                                if random.random() <= self.sa_ratio:
                                    self.pops[i] = agent
                        
                        # Train PB selected agents (fine-tuning)
                        for i in pb_select_index:
                            if i < len(self.pops):
                                agent = self.pops[i]
                                weight = self.weights[i]
                                self.exe_task(agent, weight, 5)
                    
                    else:
                        # Original adaptive method
                        pa_select_index = self.select_policies()
                        pa_agents = {}
                        self.policy_pareto_ascent_direction_weight_adjust(pa_select_index)  # 需要理论
                        for i in pa_select_index:
                            agent = deepcopy(self.pops[i])
                            weight = self.weights[i]
                            self.exe_task(agent, weight, 10)
                            pa_agents[i] = agent
                            origin_fit_obj = self.objs[i]
                            iter_after_obj = self.evaluator.evaluate(agent)[1][1:]
                            if all(x >= y for x, y in zip(origin_fit_obj, iter_after_obj)) and any(x > y for x, y in zip(origin_fit_obj, iter_after_obj)):
                                self.pops[i] = agent
                            else:
                                if random.random() <= self.sa_ratio:
                                    self.pops[i] = agent
                                else:
                                    # self.pops[i] = agent
                                    pass
                        
                        # Fine-tuning for adaptive method
                        pb_select_index = self.pareto_adaptive_fine_tuning_weights_adjust()
                        # pb_select_index = self.pareto_random_weights_adjust()
                        for i in pb_select_index:
                            agent = self.pops[i]
                            weight = self.weights[i]
                            self.exe_task(agent, weight, 5)
                    
                    # Update parameters
                    self.sa_ratio = self.sa_ratio * self.cooling_ratio
                    if self.selection_method in ['adaptive', 'pa2d_adaptive']:
                        self.delta_weight *= 0.95
                    
                    hv = self.update_and_cal_pareto_front()
                    self.hv_his.append(hv)
                    pbar.set_postfix({'HV': hv, 'SA ratio': self.sa_ratio, 'Method': self.selection_method})
                    pbar.set_description(f"Iter {i_iter + 1}/10 - HV: {hv:.4f}")
                    pbar.update(1)

    def plt_pf_scatter(self):
        """Scatter the first two objectives of three point sets on one figure.

        Plots the first entry of ``self.objs_his`` (labelled 'Initial Pop'),
        the current ``self.objs`` (labelled 'EA Pop') and every entry of
        ``self.pareto_front`` (labelled 'Pareto Front'). Each Pareto-front
        point is annotated with its key and the first two components of its
        weight. The figure is closed at the end; it is neither shown nor saved
        by this method.
        """
        plt.figure()

        # First recorded objectives, drawn in green.
        first_gen = self.objs_his[0]
        plt.scatter([point[0] for point in first_gen],
                    [point[1] for point in first_gen],
                    label='Initial Pop', color='g')

        # Objectives of the current population (default colour).
        current = self.objs
        plt.scatter([point[0] for point in current],
                    [point[1] for point in current],
                    label='EA Pop')

        # Pareto-front members, drawn in red.
        front = self.pareto_front
        front_x = [entry['objs'][0] for entry in front.values()]
        front_y = [entry['objs'][1] for entry in front.values()]
        plt.scatter(front_x, front_y, label='Pareto Front', color='r')

        # Annotate each front point as "(key, w0, w1)".
        for i, entry in front.items():
            xi, yi = entry['weight'][0], entry['weight'][1]
            plt.text(entry['objs'][0], entry['objs'][1],
                     f'({i}, {xi:.2f}, {yi:.2f})', fontsize=8, color='r')

        plt.legend()
        plt.title(f'Pareto Front - Method: {self.selection_method}')
        # Close instead of showing the figure so the experiment is not blocked.
        plt.close()

    def plt_hv_iter(self):
        """Plot the hypervolume history and a twice-smoothed version of it.

        Draws ``self.hv_his`` against its index, then applies
        ``moving_average`` with window argument 5 twice in a row and draws the
        result under the label 'mv'. The figure is closed at the end; it is
        neither shown nor saved by this method.
        """
        plt.figure()

        # Raw hypervolume values, one per recorded entry.
        raw_hv = self.hv_his
        plt.plot(list(range(len(raw_hv))), raw_hv)

        # Two consecutive smoothing passes over the same history.
        smoothed = moving_average(moving_average(raw_hv, 5), 5)
        plt.plot(list(range(len(smoothed))), smoothed, label='mv')

        plt.title(f'Hypervolume Evolution - Method: {self.selection_method}')
        plt.xlabel('Iteration')
        plt.ylabel('Hypervolume')
        plt.legend()
        # Close instead of showing the figure so the experiment is not blocked.
        plt.close()


# Script entry point: seeds the RNGs, builds one CVRP instance with a
# vehicle/drone fleet, and runs the EA once per selection method listed below.
if __name__ == '__main__':
    # Seed first so instance generation and training below are repeatable
    # (presumably seeds random/numpy; setup_seed is a project helper — confirm).
    setup_seed(0)
    n_customers = 20
    print(f"=== {n_customers}个客户CVRP测试 ===")
    # Generate the CVRP instance: vehicle distance matrix, per-customer
    # demands (a dict, see the .values() calls below) and node coordinates.
    distance_matrix_v, customer_demands, coordinates = generate_cvrp_instance(n_customers)
    # Drone distances are the vehicle distances scaled by 0.9.
    distance_matrix_d = distance_matrix_v * 0.9

    print(f"节点配置: 1个depot + {n_customers}个客户")
    print(f"客户需求范围: {min(customer_demands.values())}-{max(customer_demands.values())}")
    print(f"总需求量: {sum(customer_demands.values())}")

    # Fleet configuration (original note: adjusted to the number of customers;
    # the vehicle and drone counts themselves are fixed here).
    n_vehicles = 1
    n_drones = 5
    # At least 50, otherwise 1.2x the floor-divided per-vehicle demand, i.e.
    # slightly above the average demand per vehicle. The 1.2 factor makes this
    # a float, hence the int() cast when it is passed to the environment.
    vehicle_capacity = max(50, sum(customer_demands.values()) // n_vehicles * 1.2)  # capacity slightly above average demand
    drone_capacity = 20
    print(f"智能车队配置: {n_vehicles}辆车，每辆容量{vehicle_capacity:.0f}")
    print(f"总容量: {n_vehicles * vehicle_capacity:.0f}")
    print()

    # Create the environment.
    env = VDEnvironment(
        distance_matrix_v=distance_matrix_v,
        distance_matrix_d=distance_matrix_d,
        customer_demands=customer_demands,
        n_vehicles=n_vehicles,
        n_drones=n_drones,
        vehicle_capacity=int(vehicle_capacity),
        drone_capacity=int(drone_capacity),
        vehicle_speed=1.0,
        drone_speed=1.5
    )
    # Dyna-Q hyper-parameters, passed positionally to DynaQ below.
    n_planning = 2
    print('Q-planning步数为：%d' % n_planning)
    epsilon = 0.95
    epsilon_degrade = 0.9
    alpha = 0.1
    gamma = 0.9
    dynaq = DynaQ(env.n_nodes, epsilon, alpha, gamma, n_planning, epsilon_degrade)
    # NOTE(review): the meaning of the literal 3 is defined by Evaluator,
    # which is not visible here — confirm against its signature.
    evaluator = Evaluator(env, 3, env.n_nodes)

    # Test different methods
    methods = ['adaptive', 'moead', 'pfa']
    
    for method in methods:
        print(f"\n=== Testing {method.upper()} method ===")
        # A fresh EA per method; dynaq, env and evaluator are the same objects
        # for every run.
        ea = EA(dynaq, env, evaluator, n_pop=20, n_obj=2, iter_max=1000, 
                selection_method=method, min_weight=0.0, max_weight=1.0, delta_weight=0.2)
        ea.main()
        # Both plotting methods close their figure without showing or saving it.
        ea.plt_pf_scatter()
        ea.plt_hv_iter()
        print(f'Final HV for {method}: {ea.hv_his[-1]}')

