import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.ticker import MaxNLocator

def plot_training_history(agent):
    """Plot the loss/statistic curves stored in ``agent.train_info``.

    The panels drawn depend on ``agent.algorithm``:

    * ``"SAC"`` -- critic loss, actor loss and alpha, in a 1x3 grid;
    * ``"TD3"`` -- critic loss and actor loss, in a 1x2 grid;
    * ``"PPO"`` -- total, value, action and entropy losses, in a 2x2 grid.

    Any other algorithm name produces an empty figure. In every case the
    figure is saved to ``f"{agent.algorithm}_training_history.png"`` and
    then shown.

    Args:
        agent: object exposing ``algorithm`` (str) and ``train_info``
            (mapping from the keys listed below to sequences of values).
    """
    loss_label = '损失'
    # algorithm -> ((rows, cols), [(train_info key, panel title, y label), ...])
    layouts = {
        "SAC": ((1, 3), [
            ('critic_losses', 'Critic损失', loss_label),
            ('actor_losses', 'Actor损失', loss_label),
            ('alphas', 'Alpha变化', 'Alpha值'),
        ]),
        "TD3": ((1, 2), [
            ('critic_losses', 'Critic损失', loss_label),
            ('actor_losses', 'Actor损失', loss_label),
        ]),
        "PPO": ((2, 2), [
            ('total_losses', '总损失', loss_label),
            ('value_losses', '值损失', loss_label),
            ('action_losses', '动作损失', loss_label),
            ('entropy_losses', '熵损失', loss_label),
        ]),
    }

    plt.figure(figsize=(15, 5))

    layout = layouts.get(agent.algorithm)
    if layout is not None:
        (n_rows, n_cols), panels = layout
        for position, (key, panel_title, y_label) in enumerate(panels, start=1):
            plt.subplot(n_rows, n_cols, position)
            plt.plot(agent.train_info[key])
            plt.title(panel_title)
            plt.xlabel('更新次数')
            plt.ylabel(y_label)
            plt.grid(True)

    plt.tight_layout()
    plt.savefig(f"{agent.algorithm}_training_history.png")
    plt.show()


def plot_evaluation_results(eval_results, algorithm_name):
    """Plot how evaluation metrics evolve over training iterations.

    Args:
        eval_results: iterable of entries ``res`` where ``res[0]`` is the
            training iteration and ``res[1]`` is a mapping containing at
            least ``'avg_rmse'`` and ``'success_rate'``. Any iterable is
            accepted, including one-shot iterators such as generators.
        algorithm_name: used in the panel titles and in the output filename.

    Draws average RMSE (left panel) and success rate (right panel, green)
    against the iteration number, saves the figure to
    ``f"{algorithm_name}_eval_results.png"`` and shows it.
    """
    # Materialise once: each series below needs its own full pass over the
    # results, and a generator would otherwise be exhausted after the first,
    # leaving the other series silently empty.
    eval_results = list(eval_results)
    iterations = [res[0] for res in eval_results]
    rmse_values = [res[1]['avg_rmse'] for res in eval_results]
    success_rates = [res[1]['success_rate'] for res in eval_results]

    # (y values, panel title, y label, extra line style kwargs)
    panels = [
        (rmse_values, f'{algorithm_name} - RMSE变化趋势', '平均RMSE', {}),
        (success_rates, f'{algorithm_name} - 成功率变化趋势', '成功率', {'color': 'green'}),
    ]

    plt.figure(figsize=(15, 5))
    for position, (values, panel_title, y_label, line_style) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(iterations, values, marker='o', linestyle='-', **line_style)
        plt.title(panel_title)
        plt.xlabel('训练迭代次数')
        plt.ylabel(y_label)
        plt.grid(True)
        # Iteration numbers are whole numbers; avoid fractional x ticks.
        plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))

    plt.tight_layout()
    plt.savefig(f"{algorithm_name}_eval_results.png")
    plt.show()


def compare_algorithms(results_dict):
    """Draw a 2x2 grid of bar charts comparing algorithms on four metrics.

    Args:
        results_dict: maps algorithm name -> mapping of final evaluation
            metrics. Each mapping must provide ``'avg_rmse'``,
            ``'success_rate'``, ``'avg_treatment_similarity'`` and
            ``'avg_steps_used'``.

    ``success_rate`` is multiplied by 100 and displayed as a percentage;
    every bar is annotated with its height. The figure is saved to
    ``"algorithms_comparison.png"`` and shown.
    """
    algorithms = list(results_dict)
    # (metric key, display name), in panel order.
    metric_specs = [
        ('avg_rmse', '平均RMSE'),
        ('success_rate', '成功率'),
        ('avg_treatment_similarity', '治疗相似性'),
        ('avg_steps_used', '平均步数'),
    ]
    plt.figure(figsize=(12, 10))

    for panel_idx, (metric, display_name) in enumerate(metric_specs, start=1):
        plt.subplot(2, 2, panel_idx)
        is_rate = metric == 'success_rate'

        values = [results_dict[alg][metric] for alg in algorithms]
        if is_rate:
            # Fractions -> percentages for display.
            values = [v * 100 for v in values]

        bars = plt.bar(algorithms, values, color=['skyblue', 'lightgreen', 'salmon'])
        label_fmt = '{:.1f}%' if is_rate else '{:.3f}'
        for bar in bars:
            bar_height = bar.get_height()
            plt.text(
                bar.get_x() + bar.get_width() / 2.,
                bar_height + 0.01,
                label_fmt.format(bar_height),
                ha='center',
                va='bottom',
                fontsize=9,
            )

        plt.title(display_name)
        plt.xlabel('算法')
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        plt.ylabel('成功率 (%)' if is_rate else display_name)

    plt.tight_layout()
    plt.savefig("algorithms_comparison.png")
    plt.show()


def _as_numpy(value):
    """Return ``value`` as a NumPy array, moving tensor-like objects to host first.

    Objects exposing ``.cpu()`` (e.g. torch tensors) are detached when they
    support it, copied to CPU and converted with ``.numpy()``; anything else
    goes through ``np.asarray``.
    """
    if hasattr(value, 'cpu'):
        if hasattr(value, 'detach'):
            # .numpy() refuses tensors that require grad.
            value = value.detach()
        return value.cpu().numpy()
    return np.asarray(value)


def plot_treatment_plan(agent, dataset_collection, history_dict, goal, title=None):
    """Generate a treatment plan with ``agent`` and plot it after the history.

    Args:
        agent: trained model providing ``generate_treatment_plan``,
            ``future_length`` and ``algorithm``.
        dataset_collection: passed through to ``agent.generate_treatment_plan``.
        history_dict: history-state dict. ``'current_treatments'`` is always
            read; ``'outputs'`` (preferred) or ``'prev_outputs'`` is read if
            present. Only batch index 0 is plotted, and only output
            channel 0 of the history outputs.
        goal: target state (scalar, array-like, or a tensor with ``.cpu()``).
        title: optional title for the state-trajectory panel.

    Returns:
        ``(treatments, outputs, mse, steps_used)``: ``treatments``, ``outputs``
        and ``steps_used`` exactly as returned by
        ``agent.generate_treatment_plan``; ``mse`` is the mean squared error
        between the final predicted output ``outputs[-1]`` and ``goal``.

    Side effects: saves ``f"{agent.algorithm}_treatment_plan.png"``, shows
    the figure and prints a short summary (goal, final prediction, MSE,
    RMSE, steps used).
    """
    treatments, outputs, steps_used = agent.generate_treatment_plan(
        history_dict,
        goal,
        dataset_collection,
        future_length=agent.future_length,
        early_stop=True
    )
    goal_value = _as_numpy(goal)
    # The state panel plots only the first component of each prediction
    # (``o[0]`` below), so the target line uses the matching first goal
    # component; handing a multi-element array to ``axhline`` would raise.
    goal_line = float(np.ravel(goal_value)[0])

    if 'outputs' in history_dict:
        history_outputs = _as_numpy(history_dict['outputs'][0, :, 0])
    elif 'prev_outputs' in history_dict:
        history_outputs = _as_numpy(history_dict['prev_outputs'][0, :, 0])
    else:
        history_outputs = np.array([])
    history_treatments = _as_numpy(history_dict['current_treatments'][0])
    planned_treatments = _as_numpy(treatments)

    # Outputs and treatments each get a time axis matching their own length,
    # so a missing output history no longer breaks the treatment panel; the
    # forecast starts right after the longer of the two histories (when both
    # are present they normally have equal length).
    output_time = np.arange(len(history_outputs))
    treatment_time = np.arange(len(history_treatments))
    history_len = max(len(history_outputs), len(history_treatments))
    future_time = np.arange(history_len, history_len + len(outputs))

    plt.figure(figsize=(15, 8))
    plt.subplot(2, 1, 1)
    plt.plot(output_time, history_outputs, 'o-', label='历史输出')
    plt.plot(future_time, [o[0] for o in outputs], 'o-', label='预测输出')
    plt.axhline(y=goal_line, color='r', linestyle='--', label='目标')
    plt.xlabel('时间步')
    plt.ylabel('状态值')
    plt.title('状态变化轨迹' if title is None else title)
    plt.legend()
    plt.grid(True)

    plt.subplot(2, 1, 2)
    for i in range(history_treatments.shape[1]):
        plt.plot(treatment_time, history_treatments[:, i], 'o-', label=f'历史治疗{i+1}')
    for i in range(planned_treatments.shape[1]):
        plt.plot(future_time, planned_treatments[:, i], 'o-', label=f'预测治疗{i+1}')
    plt.xlabel('时间步')
    plt.ylabel('治疗强度')
    plt.title('治疗计划')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(f"{agent.algorithm}_treatment_plan.png")
    plt.show()

    # Error of the final predicted state against the full goal vector.
    mse = ((outputs[-1] - goal_value) ** 2).mean()
    rmse = np.sqrt(mse)

    print(f"目标值: {goal_value}")
    print(f"最终预测: {outputs[-1]}")
    print(f"MSE: {mse:.6f}")
    print(f"RMSE: {rmse:.6f}")
    print(f"步数: {steps_used}/{agent.future_length}")

    return treatments, outputs, mse, steps_used