import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from copy import deepcopy
import time

from EAframework import EA, Evaluator
from dynaQ import DynaQ
from vd_env import *
from data_ini import *
from utils import setup_seed, moving_average
from visualization_utils import TrainingVisualizer


class TraditionalEATrainer:
    """Traditional EA trainer supporting several selection methods.

    ``selection_method`` values 'moead' and 'pfa' each have a dedicated branch
    in :meth:`train`. Every other value (e.g. 'adaptive', 'pa2d_adaptive')
    runs the adaptive Pareto-ascent + fine-tuning branch. The per-iteration
    ``delta_weight`` decay is applied only when
    ``selection_method == 'adaptive'``.
    """

    def __init__(self, n_pop=20, iter_max=1000, n_obj=2, selection_method='adaptive',
                 min_weight=0.0, max_weight=1.0, delta_weight=0.2, custom_env=None):
        """Build the environment, the base DynaQ agent and the EA population.

        Args:
            n_pop: population size.
            iter_max: maximum number of training iterations.
            n_obj: number of objectives.
            selection_method: selection method ('adaptive', 'moead', 'pfa',
                'pa2d_adaptive'); see the class docstring for how it dispatches.
            min_weight: minimum weight, forwarded to ``EA``.
            max_weight: maximum weight, forwarded to ``EA``.
            delta_weight: weight step size, forwarded to ``EA``.
            custom_env: custom environment; when ``None`` the global ``env``
                (from data_ini) is used instead.
        """
        self.n_pop = n_pop
        self.iter_max = iter_max
        self.n_obj = n_obj
        self.selection_method = selection_method

        # Create the environment and the evaluator bound to it.
        self.env = custom_env if custom_env is not None else env
        self.evaluator = Evaluator(self.env)

        # DynaQ hyper-parameters (taken from train.py).
        epsilon = 0.95
        epsilon_degrade = 0.9
        alpha = 0.1
        gamma = 0.9
        n_planning = 2

        # Base agent; the EA receives a deep copy so this instance stays pristine.
        self.base_agent = DynaQ(self.env.n_nodes, epsilon, alpha, gamma, n_planning, epsilon_degrade)

        self.ea = EA(deepcopy(self.base_agent), self.env, self.evaluator,
                     n_pop=n_pop, n_obj=n_obj, iter_max=iter_max,
                     selection_method=selection_method, min_weight=min_weight,
                     max_weight=max_weight, delta_weight=delta_weight)

        # Extra experiment records kept on top of what EA itself stores.
        self.iteration_history = []
        self.strategy_history = []  # per-iteration list of strategy tags used

    @staticmethod
    def _split_iterations(total, n_bars=10):
        """Split ``total`` iterations into at most ``n_bars`` non-empty chunks.

        The remainder of ``total / n_bars`` is spread over the first chunks, so
        the chunk sizes always sum to ``total``. Empty chunks are dropped.
        """
        total = max(0, int(total))
        base, extra = divmod(total, n_bars)
        sizes = [base + (1 if k < extra else 0) for k in range(n_bars)]
        return [size for size in sizes if size > 0]

    @staticmethod
    def _dominates(new_objs, old_objs):
        """Return True when ``new_objs`` is no larger than ``old_objs`` everywhere
        and strictly smaller in at least one objective (smaller is better here).
        """
        pairs = list(zip(old_objs, new_objs))
        return all(old >= new for old, new in pairs) and any(old > new for old, new in pairs)

    def _accept_if_improved(self, idx, agent, use_sa):
        """Replace ``self.ea.pops[idx]`` with ``agent`` if it improves on ``self.ea.objs[idx]``.

        The agent is accepted when its evaluated objectives dominate the stored
        ones. When ``use_sa`` is true, a non-dominating agent is still accepted
        with probability ``self.ea.sa_ratio`` (simulated-annealing acceptance).
        Indices without a stored objective vector are left untouched.
        """
        if idx >= len(self.ea.objs):
            return
        origin_fit_obj = self.ea.objs[idx]
        iter_after_obj = self.ea.evaluator.evaluate(agent)[1][1:]
        if self._dominates(iter_after_obj, origin_fit_obj):
            self.ea.pops[idx] = agent
        elif use_sa and np.random.random() <= self.ea.sa_ratio:
            self.ea.pops[idx] = agent

    def _run_iteration(self):
        """Run one selection + training step and return the strategy tags used."""
        strategy_used = []

        if self.selection_method == 'moead':
            # MOEA/D: train the selected agents; replace only on strict dominance.
            selected_agents, selected_weights = self.ea.moead_weight_selection()
            strategy_used.append("moead")
            for agent_id, agent in selected_agents.items():
                self.ea.exe_task(agent, selected_weights[agent_id], 10)
                self._accept_if_improved(agent_id, agent, use_sa=False)

        elif self.selection_method == 'pfa':
            # PFA: train copies of the selected policies under adjusted weights.
            pa_select_index = self.ea.pfa_weight_adjustment()
            strategy_used.append("pfa")
            for i in pa_select_index:
                agent = deepcopy(self.ea.pops[i])
                self.ea.exe_task(agent, self.ea.weights[i], 10)
                self._accept_if_improved(i, agent, use_sa=True)

        else:
            # Adaptive: Pareto-ascent step on selected policies ...
            pa_select_index = self.ea.select_policies()
            self.ea.policy_pareto_ascent_direction_weight_adjust(pa_select_index)
            if pa_select_index:
                strategy_used.append("pareto_ascent")
            for i in pa_select_index:
                agent = deepcopy(self.ea.pops[i])
                self.ea.exe_task(agent, self.ea.weights[i], 10)
                self._accept_if_improved(i, agent, use_sa=True)

            # ... followed by in-place fine-tuning of the returned policies.
            pb_select_index = self.ea.pareto_adaptive_fine_tuning_weights_adjust()
            if pb_select_index:
                strategy_used.append("paft")
            for i in pb_select_index:
                self.ea.exe_task(self.ea.pops[i], self.ea.weights[i], 5)

        return strategy_used

    def _count_strategies(self):
        """Return a ``{strategy_tag: count}`` tally over ``self.strategy_history``."""
        counts = {}
        for strategies in self.strategy_history:
            for strategy in strategies:
                counts[strategy] = counts.get(strategy, 0) + 1
        return counts

    def train(self, verbose=True):
        """Run the EA training loop for the configured selection method.

        The population is first trained for 10 episodes per policy and the
        initial hypervolume (HV) is recorded. Then exactly ``iter_max``
        iterations run, spread over up to 10 tqdm progress bars. After each
        iteration the simulated-annealing ratio is cooled, the Pareto front and
        HV are recomputed, and the strategies used are recorded.

        Args:
            verbose: whether to print detailed progress information.
        """
        if verbose:
            print("=" * 60)
            print(f"🚀 {self.selection_method.upper()} EA训练开始")
            print("=" * 60)
            print("训练参数:")
            print(f"  选择方法: {self.selection_method}")
            print(f"  种群大小: {self.n_pop}")
            print(f"  最大迭代次数: {self.iter_max}")
            print(f"  目标数量: {self.n_obj}")
            print(f"  环境规模: {self.env.n_nodes}个城市")
            print()

        start_time = time.time()

        # 1. Initialisation phase.
        if verbose:
            print("🔧 初始化种群...")

        for i in range(self.ea.n_pop):
            self.ea.exe_task(self.ea.pops[i], self.ea.weights[i], 10)

        initial_hv = self.ea.update_and_cal_pareto_front()
        self.ea.hv_his.append(initial_hv)

        if verbose:
            print(f'初始化完成，初始HV: {initial_hv:.4f}')
            print(f'初始前沿解数: {len(self.ea.pareto_front)}')
            print()

        # 2. Main loop. Chunk sizes sum to iter_max, so no iteration is dropped
        #    when iter_max is not a multiple of the bar count.
        method_tag = self.selection_method.upper()
        bar_sizes = self._split_iterations(self.iter_max)
        n_bars = len(bar_sizes)
        current_iteration = 0
        for bar_idx, bar_size in enumerate(bar_sizes):
            with tqdm(total=bar_size, desc=f'{method_tag} Iteration {bar_idx + 1}') as pbar:
                for _ in range(bar_size):
                    self.ea.iteration_count += 1
                    self.iteration_history.append(current_iteration)

                    strategy_used = self._run_iteration()

                    # Cool the simulated-annealing acceptance ratio.
                    self.ea.sa_ratio = self.ea.sa_ratio * self.ea.cooling_ratio
                    if self.selection_method == 'adaptive':
                        self.ea.delta_weight *= 0.95

                    hv = self.ea.update_and_cal_pareto_front()
                    self.ea.hv_his.append(hv)
                    self.strategy_history.append(strategy_used)

                    strategies_str = '+'.join(strategy_used) if strategy_used else 'none'
                    pbar.set_postfix({
                        'HV': f'{hv:.4f}',
                        'SA ratio': f'{self.ea.sa_ratio:.4f}',
                        'Method': self.selection_method,
                        'Strategies': strategies_str
                    })
                    pbar.set_description(f"{method_tag} Iter {bar_idx + 1}/{n_bars} - HV: {hv:.4f}")
                    pbar.update(1)

                    # Periodic detailed report.
                    if verbose and (current_iteration + 1) % 100 == 0:
                        elapsed_time = time.time() - start_time
                        print(f"\n📊 Iteration {current_iteration + 1}/{self.iter_max} (耗时: {elapsed_time:.1f}s)")
                        print(f"  当前HV: {hv:.4f}")
                        print(f"  选择方法: {self.selection_method}")
                        print(f"  使用策略: {strategies_str}")
                        print(f"  前沿解数: {len(self.ea.pareto_front)}")
                        print(f"  SA比率: {self.ea.sa_ratio:.4f}")
                        if self.selection_method == 'adaptive':
                            print(f"  权重步长: {self.ea.delta_weight:.4f}")

                    current_iteration += 1

        total_time = time.time() - start_time

        if verbose:
            print(f"\n⏱️  {self.selection_method.upper()}训练完成，总耗时: {total_time:.1f}秒 ({total_time/60:.1f}分钟)")
            print(f"🎯 最终HV: {self.ea.hv_his[-1]:.4f}")
            print(f"🧠 最终前沿解数: {len(self.ea.pareto_front)}")

            strategy_counts = self._count_strategies()
            if strategy_counts:
                total_actions = sum(strategy_counts.values())
                print("\n📈 策略使用统计:")
                for strategy, count in sorted(strategy_counts.items(), key=lambda x: x[1], reverse=True):
                    percentage = (count / total_actions) * 100 if total_actions > 0 else 0
                    print(f"  {strategy[:25]:25s}: {count:4d}次 ({percentage:5.1f}%)")

    def get_results(self):
        """Collect the training results (histories, final front/objectives, parameters) into a dict."""
        return {
            'hv_history': self.ea.hv_his,
            'sp_history': self.ea.sp_his,
            'objs_history': self.ea.objs_his,
            'strategy_history': self.strategy_history,
            'strategy_counts': self._count_strategies(),
            'final_hv': self.ea.hv_his[-1] if self.ea.hv_his else 0,
            'final_sp': self.ea.sp_his[-1] if self.ea.sp_his else 0,
            'final_pareto_front': self.ea.pareto_front,
            'final_weights': self.ea.weights,
            'final_objs': self.ea.objs,
            'sa_ratio': self.ea.sa_ratio,
            'delta_weight': self.ea.delta_weight,
            'selection_method': self.selection_method
        }

    def plot_results(self):
        """Build the training dashboard with the shared visualizer and return the figure."""
        results = self.get_results()

        fig = TrainingVisualizer.create_training_dashboard(
            results,
            algorithm_name=f"{self.selection_method.upper()} EA",
            figsize=(15, 12)
        )
        # Close instead of showing so unattended experiments do not block.
        plt.close()
        return fig

    def plot_pareto_front(self):
        """Plot initial/final objectives and the Pareto front with the shared visualizer."""
        initial_objs = self.ea.objs_his[0] if len(self.ea.objs_his) > 0 else None
        final_objs = self.ea.objs

        pf_objs = None
        if self.ea.pareto_front:
            pf_objs = [entry['objs'] for entry in self.ea.pareto_front.values()]

        fig = TrainingVisualizer.plot_pareto_front_scatter(
            initial_objs=initial_objs,
            final_objs=final_objs,
            pareto_front_objs=pf_objs,
            title=f"{self.selection_method.upper()} EA Final Pareto Front",
            show_labels=True
        )
        # Close instead of showing so unattended experiments do not block.
        plt.close()
        return fig


def main(n_pop=20, iter_max=1000, selection_method='adaptive', custom_env=None):
    """Train one EA end to end and print its headline metrics.

    Seeds the RNGs via ``setup_seed(0)`` and builds a ``TraditionalEATrainer``
    with weights in [0.0, 1.0] and step 0.2. It then trains verbosely, renders
    the result plots, and prints the final HV, SP and Pareto-front size.

    Args:
        n_pop: population size.
        iter_max: maximum number of iterations.
        selection_method: selection method ('adaptive', 'moead', 'pfa').
        custom_env: custom environment overriding the global one from data_ini.

    Returns:
        ``(trainer, results)``, where ``results`` is ``trainer.get_results()``.
    """
    tag = selection_method.upper()
    print(f"🚀 启动{tag} EA独立训练")
    print(f"种群大小: {n_pop}, 最大迭代次数: {iter_max}")

    # Fixed seed so repeated runs are comparable.
    setup_seed(0)

    trainer_config = dict(
        n_pop=n_pop,
        iter_max=iter_max,
        selection_method=selection_method,
        min_weight=0.0,
        max_weight=1.0,
        delta_weight=0.2,
        custom_env=custom_env,
    )
    trainer = TraditionalEATrainer(**trainer_config)
    trainer.train(verbose=True)
    results = trainer.get_results()

    print("\n📊 生成训练结果图表...")
    for render in (trainer.plot_results, trainer.plot_pareto_front):
        render()

    print(f"\n🎉 {tag} EA训练完成!")
    print(f"最终HV: {results['final_hv']:.4f}")
    print(f"最终SP: {results['final_sp']:.4f}")
    print(f"前沿解数: {len(results['final_pareto_front'])}")

    return trainer, results


if __name__ == '__main__':
    # Smoke-test each selection method on a small configuration.
    methods = ['adaptive', 'moead', 'pfa']

    for method in methods:
        print(f"\n{'='*60}")
        print(f"🧪 测试 {method.upper()} 方法")
        print(f"{'='*60}")

        trainer, results = main(n_pop=10, iter_max=100, selection_method=method)

        # Additional analysis of the finished run.
        print(f"\n🔍 {method.upper()} EA训练结果分析:")

        if len(trainer.ea.hv_his) > 1:
            initial_hv = trainer.ea.hv_his[0]
            hv_improvement = trainer.ea.hv_his[-1] - initial_hv
            print(f"HV总改进: {hv_improvement:.4f}")
            # A zero initial HV would make the relative rate divide by zero.
            if initial_hv:
                print(f"改进率: {(hv_improvement/initial_hv)*100:.2f}%")
            else:
                print("改进率: N/A (初始HV为0)")

        # Most frequently used strategy.
        if results['strategy_counts']:
            most_used_strategy = max(results['strategy_counts'].items(), key=lambda x: x[1])
            print(f"最常用策略: {most_used_strategy[0]} ({most_used_strategy[1]}次)")

        # Final parameter state.
        print(f"最终SA比率: {results['sa_ratio']:.6f}")
        if method == 'adaptive':
            print(f"最终权重步长: {results['delta_weight']:.6f}")

        print(f"选择方法: {results['selection_method']}")

