import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from copy import deepcopy
import time

from age_framework import MetaGuidedEA, MetaDQN
from EAframework import EA, Evaluator
from dynaQ import DynaQ
from vd_env import *
from data_ini import *
from utils import setup_seed, moving_average
from visualization_utils import TrainingVisualizer


class MetaDQNTrainer:
    """Meta-DQN trainer, structured like EA.main() in EAframework.py.

    Wraps a MetaGuidedEA. Each iteration, the Meta-DQN encodes the EA state and
    selects a strategy, the strategy is executed on the EA, and a reward computed
    by ``MetaGuidedEA.calculate_meta_reward`` from the previous/current
    hypervolume (HV) is used to update the Meta-DQN. Per-iteration metrics are
    recorded for later analysis and plotting.
    """

    # Number of consecutive tqdm progress bars the training loop is split into.
    NUM_PROGRESS_BARS = 10

    def __init__(self, n_pop=20, iter_max=1000, n_obj=2, custom_env=None,
                 # Meta-DQN hyperparameters
                 dqn_learning_rate=1e-3, dqn_gamma=0.9,
                 # Dense-region parameter
                 radius_factor=1.5,
                 # Sparse-region parameters
                 length_factor=2.0, width_factor=0.8, use_abs_projection_length=False):
        """Initialize the Meta-DQN trainer.

        Args:
            n_pop: population size.
            iter_max: maximum number of iterations.
            n_obj: number of objectives.
            custom_env: optional custom environment; when provided it is used
                instead of the global ``env`` imported from data_ini.
            dqn_learning_rate: Meta-DQN learning rate (forwarded to MetaGuidedEA).
            dqn_gamma: Meta-DQN discount factor (forwarded to MetaGuidedEA).
            radius_factor: dense-region parameter (forwarded to MetaGuidedEA).
            length_factor: sparse-region parameter (forwarded to MetaGuidedEA).
            width_factor: sparse-region parameter (forwarded to MetaGuidedEA).
            use_abs_projection_length: sparse-region flag (forwarded to MetaGuidedEA).
        """
        self.n_pop = n_pop
        self.iter_max = iter_max
        self.n_obj = n_obj

        # Environment and evaluator
        self.env = custom_env if custom_env is not None else env
        self.evaluator = Evaluator(self.env)

        # Fixed DynaQ hyperparameters
        epsilon = 0.95
        epsilon_degrade = 0.9
        alpha = 0.1
        gamma = 0.9
        n_planning = 2

        # Template agent; each EA gets its own deep copy (see create_ea)
        self.base_agent = DynaQ(self.env.n_nodes, epsilon, alpha, gamma, n_planning, epsilon_degrade)

        # Keep the Meta-DQN hyperparameters for reporting
        self.hyperparams = {
            'dqn_learning_rate': dqn_learning_rate,
            'dqn_gamma': dqn_gamma,
            'radius_factor': radius_factor,
            'length_factor': length_factor,
            'width_factor': width_factor,
            'use_abs_projection_length': use_abs_projection_length
        }

        # Factory used by MetaGuidedEA to build the EA it drives
        def create_ea():
            return EA(deepcopy(self.base_agent), self.env, self.evaluator,
                      n_pop=n_pop, n_obj=n_obj, iter_max=iter_max)

        self.meta_guided_ea = MetaGuidedEA(
            create_ea,
            dqn_learning_rate=dqn_learning_rate,
            dqn_gamma=dqn_gamma,
            radius_factor=radius_factor,
            length_factor=length_factor,
            width_factor=width_factor,
            use_abs_projection_length=use_abs_projection_length
        )

        # Per-iteration experiment records
        self.hv_history = []
        self.sp_history = []  # SP (Spacing) metric history
        self.pf_size_history = []
        self.strategy_history = []
        self.reward_history = []
        self.state_history = []
        self.epsilon_history = []

        # Per-strategy usage counts and summed rewards
        self.strategy_counts = {}
        self.strategy_rewards = {}

    @property
    def _ea(self):
        """The EA driven by the Meta-DQN (re-read from meta_guided_ea on every access)."""
        return self.meta_guided_ea.ea

    @property
    def _meta_dqn(self):
        """The Meta-DQN agent (re-read from meta_guided_ea on every access)."""
        return self.meta_guided_ea.meta_dqn

    @staticmethod
    def _iteration_chunks(iter_max, num_chunks=10):
        """Split ``range(iter_max)`` into at most ``num_chunks`` contiguous spans.

        Every index in ``[0, iter_max)`` lands in exactly one ``(start, end)``
        span. The last span absorbs the remainder when ``iter_max`` is not a
        multiple of ``num_chunks``, and empty spans are never produced.

        Args:
            iter_max: total number of iterations.
            num_chunks: desired number of spans (clamped to ``[1, iter_max]``).

        Returns:
            list[tuple[int, int]]: half-open spans; empty when ``iter_max <= 0``.
        """
        if iter_max <= 0:
            return []
        num_chunks = max(1, min(num_chunks, iter_max))
        per_chunk = iter_max // num_chunks
        spans = []
        for idx in range(num_chunks):
            start = idx * per_chunk
            end = iter_max if idx == num_chunks - 1 else start + per_chunk
            spans.append((start, end))
        return spans

    def _latest_sp(self):
        """Return the EA's most recent SP value, or 0.0 if none has been recorded."""
        sp_his = self._ea.sp_his
        return sp_his[-1] if sp_his else 0.0

    def train(self, verbose=True):
        """Train the Meta-DQN, structured like EA.main().

        Runs the initialization phase, then exactly ``self.iter_max`` guided
        iterations, displayed as up to ``NUM_PROGRESS_BARS`` tqdm bars.

        Args:
            verbose: whether to print detailed progress information.
        """
        if verbose:
            print("=" * 60)
            print("🚀 Meta-DQN MORL 训练开始")
            print("=" * 60)
            print(f"训练参数:")
            print(f"  种群大小: {self.n_pop}")
            print(f"  最大迭代次数: {self.iter_max}")
            print(f"  目标数量: {self.n_obj}")
            print(f"  环境规模: {self.env.n_nodes}个城市")
            print()

        # 1. Initialization phase (like EA's initialize step)
        self._initialize_phase(verbose)

        # 2. Main training loop (like EA's main iteration)
        start_time = time.time()

        # Partition the iterations exactly so that none are dropped when
        # iter_max is not a multiple of the bar count.
        spans = self._iteration_chunks(self.iter_max, self.NUM_PROGRESS_BARS)
        n_bars = len(spans)

        for bar_idx, (start_iter, end_iter) in enumerate(spans):
            with tqdm(total=end_iter - start_iter, desc=f'Progress {bar_idx + 1}/{n_bars}') as pbar:
                for iteration in range(start_iter, end_iter):
                    # One Meta-DQN-guided iteration
                    current_hv, strategy_name, reward = self._execute_iteration(iteration)

                    pbar.set_postfix({
                        'HV': f'{current_hv:.4f}',
                        'Strategy': strategy_name[:10],
                        'Reward': f'{reward:.3f}',
                        'ε': f'{self._meta_dqn.epsilon:.3f}'
                    })
                    pbar.update(1)

                    # Periodic detailed report
                    if verbose and (iteration + 1) % 100 == 0:
                        elapsed_time = time.time() - start_time
                        print(f"\n📊 Iteration {iteration + 1}/{self.iter_max} (耗时: {elapsed_time:.1f}s)")
                        print(f"  当前HV: {current_hv:.4f}")
                        print(f"  使用策略: {strategy_name}")
                        print(f"  获得奖励: {reward:.4f}")
                        print(f"  前沿解数: {len(self._ea.pareto_front)}")
                        print(f"  Meta-DQN ε: {self._meta_dqn.epsilon:.3f}")

        total_time = time.time() - start_time

        if verbose:
            print(f"\n⏱️  训练完成，总耗时: {total_time:.1f}秒 ({total_time/60:.1f}分钟)")
            print(f"🎯 最终HV: {self.hv_history[-1]:.4f}")
            print(f"🧠 最终前沿解数: {len(self._ea.pareto_front)}")

            # Strategy usage statistics, most used first
            total_actions = sum(self.strategy_counts.values())
            print(f"\n📈 策略使用统计:")
            for strategy, count in sorted(self.strategy_counts.items(), key=lambda x: x[1], reverse=True):
                percentage = (count / total_actions) * 100 if total_actions > 0 else 0
                avg_reward = self.strategy_rewards.get(strategy, 0) / max(count, 1)
                print(f"  {strategy[:25]:25s}: {count:4d}次 ({percentage:5.1f}%) [平均奖励: {avg_reward:+.3f}]")

    def _initialize_phase(self, verbose=True):
        """Initialize every individual (like EA's initialization) and record baseline metrics."""
        if verbose:
            print("🔧 初始化种群...")

        # Run initial tasks for each individual (as in EA.main's initialize part)
        for i in range(self._ea.n_pop):
            agent = self._ea.pops[i]
            weight = self._ea.weights[i]
            self._ea.exe_task(agent, weight, 10)  # 10 = initialization episodes, following EA

        # Baseline HV / SP / front size / epsilon
        initial_hv = self._ea.update_and_cal_pareto_front()
        self.hv_history.append(initial_hv)
        self.sp_history.append(self._latest_sp())
        self.pf_size_history.append(len(self._ea.pareto_front))
        self.epsilon_history.append(self._meta_dqn.epsilon)

        if verbose:
            print(f"初始化完成，初始HV: {initial_hv:.4f}")
            print(f"初始前沿解数: {len(self._ea.pareto_front)}")
            print()

    def _execute_iteration(self, iteration):
        """Run one Meta-DQN-guided EA iteration (mirrors a single EA iteration).

        Args:
            iteration: zero-based iteration index (not used by the body; kept
                for interface compatibility).

        Returns:
            tuple: (current_hv, strategy_name, reward)
        """
        # HV before this iteration; baseline for the reward
        previous_hv = self.hv_history[-1]

        # 1. Encode the current EA state
        current_state = self._meta_dqn.encode_ea_state(self._ea)
        self.state_history.append(current_state.copy())

        # 2. Let the Meta-DQN choose a strategy
        action = self._meta_dqn.select_action(current_state)
        strategy_name = self._meta_dqn.get_strategy_name(action)

        # 3. Apply the chosen strategy to the EA
        self.meta_guided_ea.execute_meta_strategy(self._ea, action)

        # 4. Cool the EA's simulated-annealing temperature, if it has one
        if hasattr(self._ea, 'sa_ratio') and hasattr(self._ea, 'cooling_ratio'):
            self._ea.sa_ratio *= self._ea.cooling_ratio

        # 5. Decay the weight step (like EA's delta_weight decay, but slightly slower)
        if hasattr(self._ea, 'delta_weight'):
            self._ea.delta_weight *= 0.99

        # 6. Recompute the Pareto front and its HV
        current_hv = self._ea.update_and_cal_pareto_front()

        # 7. Meta reward; hv_history does not yet contain current_hv here
        reward, _reward_components = self.meta_guided_ea.calculate_meta_reward(
            self._ea, self.hv_history, previous_hv, current_hv, action
        )

        # 8. Store the transition (always non-terminal) and update the Meta-DQN
        next_state = self._meta_dqn.encode_ea_state(self._ea)
        self._meta_dqn.store_experience(
            current_state, action, reward, next_state, False
        )
        self._meta_dqn.update()

        # 9. Record metrics
        self.hv_history.append(current_hv)
        self.sp_history.append(self._latest_sp())
        self.pf_size_history.append(len(self._ea.pareto_front))
        self.strategy_history.append(strategy_name)
        self.reward_history.append(reward)
        self.epsilon_history.append(self._meta_dqn.epsilon)

        # 10. Per-strategy statistics
        self.strategy_counts[strategy_name] = self.strategy_counts.get(strategy_name, 0) + 1
        self.strategy_rewards[strategy_name] = self.strategy_rewards.get(strategy_name, 0) + reward

        return current_hv, strategy_name, reward

    def get_results(self):
        """Collect training results (analogous to EA's result data).

        Returns:
            dict: histories, per-strategy counts, per-strategy *average*
            rewards, final HV/SP (0 when no history), the final Pareto front,
            and the Meta-DQN object.
        """
        return {
            'hv_history': self.hv_history,
            'sp_history': self.sp_history,
            'pf_size_history': self.pf_size_history,
            'strategy_history': self.strategy_history,
            'reward_history': self.reward_history,
            'epsilon_history': self.epsilon_history,
            'strategy_counts': self.strategy_counts,
            'objs_history': self._ea.objs_his,
            'strategy_rewards': {k: v/max(self.strategy_counts.get(k, 1), 1)
                                 for k, v in self.strategy_rewards.items()},
            'final_hv': self.hv_history[-1] if self.hv_history else 0,
            'final_sp': self.sp_history[-1] if self.sp_history else 0,
            'final_pareto_front': self._ea.pareto_front,
            'meta_dqn': self._meta_dqn
        }

    def plot_results(self):
        """Build the training dashboard via TrainingVisualizer and return the figure.

        The figure is closed right away so running experiments never block on
        a GUI window.
        """
        results = self.get_results()

        fig = TrainingVisualizer.create_training_dashboard(
            results,
            algorithm_name="Meta-DQN",
            figsize=(15, 12)
        )
        # Do not display the figure, to avoid blocking the experiment
        plt.close(fig)
        return fig

    def plot_pareto_front(self):
        """Plot initial objectives, final objectives and the Pareto front; return the figure.

        Uses TrainingVisualizer.plot_pareto_front_scatter. The current pyplot
        figure is closed afterwards so experiments never block on a GUI window.
        """
        objs_his = self._ea.objs_his
        initial_objs = objs_his[0] if len(objs_his) > 0 else None
        final_objs = self._ea.objs

        pareto_front = self._ea.pareto_front
        pf_objs = None
        if pareto_front:
            pf_objs = [entry['objs'] for entry in pareto_front.values()]

        fig = TrainingVisualizer.plot_pareto_front_scatter(
            initial_objs=initial_objs,
            final_objs=final_objs,
            pareto_front_objs=pf_objs,
            title="Meta-DQN Final Pareto Front",
            show_labels=True
        )
        # Do not display the figure, to avoid blocking the experiment.
        # NOTE(review): closes the *current* figure; the return type of
        # plot_pareto_front_scatter is not visible here, so plt.close(fig) is not used.
        plt.close()
        return fig


def main(n_pop=20, iter_max=1000, custom_env=None):
    """Run a standalone Meta-DQN MORL training session (mirrors EAframework.py's main).

    Args:
        n_pop: population size.
        iter_max: maximum number of iterations.
        custom_env: optional environment; when given it replaces the global
            environment imported from data_ini.

    Returns:
        tuple: ``(trainer, results)`` -- the MetaDQNTrainer instance and the
        dict produced by its ``get_results()``.
    """
    print("🚀 启动Meta-DQN MORL独立训练")
    print(f"种群大小: {n_pop}, 最大迭代次数: {iter_max}")

    setup_seed(0)  # fixed seed so runs are reproducible

    meta_trainer = MetaDQNTrainer(n_pop=n_pop, iter_max=iter_max, custom_env=custom_env)
    meta_trainer.train(verbose=True)
    summary = meta_trainer.get_results()

    # Render (and immediately close) the result figures
    print("\n📊 生成训练结果图表...")
    for render in (meta_trainer.plot_results, meta_trainer.plot_pareto_front):
        render()

    print(f"\n🎉 Meta-DQN训练完成!")
    print(f"最终HV: {summary['final_hv']:.4f}")
    print(f"最终SP: {summary['final_sp']:.4f}")
    print(f"前沿解数: {len(summary['final_pareto_front'])}")

    return meta_trainer, summary


if __name__ == '__main__':
    trainer, results = main(n_pop=10, iter_max=100)

    # Extra post-training analysis
    print("\n" + "=" * 50)
    print("🔍 训练结果分析:")

    if len(trainer.hv_history) > 1:
        initial_hv = trainer.hv_history[0]
        hv_improvement = trainer.hv_history[-1] - initial_hv
        print(f"HV总改进: {hv_improvement:.4f}")
        # The relative improvement is undefined when the initial HV is 0
        # (e.g. a degenerate initial front); report N/A instead of dividing by zero.
        if initial_hv != 0:
            print(f"改进率: {(hv_improvement / initial_hv) * 100:.2f}%")
        else:
            print("改进率: N/A")

    # Most frequently selected strategy
    if trainer.strategy_counts:
        best_strategy = max(trainer.strategy_counts.items(), key=lambda x: x[1])
        print(f"最常用策略: {best_strategy[0]} ({best_strategy[1]}次)")

    # Mean reward and share of strictly positive rewards
    if trainer.reward_history:
        avg_reward = np.mean(trainer.reward_history)
        positive_ratio = sum(1 for r in trainer.reward_history if r > 0) / len(trainer.reward_history)
        print(f"平均奖励: {avg_reward:.4f}")
        print(f"正奖励比例: {positive_ratio*100:.1f}%")
