import json
import matplotlib.pyplot as plt
import os
import numpy as np


def load_training_history(json_path):
    """Load a training-trajectory JSON file.

    Args:
        json_path: Path to the ``training_history.json`` file.

    Returns:
        The decoded JSON document (whatever ``json.load`` produces).
    """
    # JSON is UTF-8 by specification; don't depend on the platform's locale
    # encoding (e.g. cp1252 on Windows) when reading it back.
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def plot_training_rewards(data, save_dir=None):
    """Draw the per-episode reward curve, its moving average and eval points.

    The curves are drawn against 1-based episode numbers so that the
    evaluation markers line up with them: each marker's ``episode`` value is
    used as a 1-based index into ``moving_avg_rewards`` (hence the ``- 1``).

    Args:
        data: Training-history dict. Reads ``'rewards'`` and
            ``'moving_avg_rewards'`` (sequences of numbers) and, if present,
            ``'evaluation_history'`` (a list of dicts with an ``'episode'`` key).
        save_dir: If truthy, the figure is saved there as
            ``reconstructed_training_rewards.png`` before being shown.
    """
    rewards = data['rewards']
    moving_avg = data['moving_avg_rewards']

    fig = plt.figure(figsize=(12, 6))
    # 1-based x-axis, matching the 'episode' numbering of evaluation points.
    plt.plot(range(1, len(rewards) + 1), rewards, alpha=0.6, label='Episode Reward')
    plt.plot(range(1, len(moving_avg) + 1), moving_avg, 'r-', linewidth=2,
             label='Moving Avg (100 episodes)')

    # Only keep evaluation episodes that map onto an existing moving-average
    # entry: episode 0 would otherwise silently read moving_avg[-1], and an
    # episode past the end would raise IndexError.
    eval_episodes = [e['episode'] for e in data.get('evaluation_history', [])
                     if 1 <= e['episode'] <= len(moving_avg)]
    if eval_episodes:
        plt.scatter(eval_episodes, [moving_avg[ep - 1] for ep in eval_episodes],
                    c='green', marker='o', label='Evaluation Points')

    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.title('Training Rewards')
    plt.legend()
    # Explicitly enable the grid (a bare plt.grid() toggles it).
    plt.grid(True)

    if save_dir:
        plt.savefig(os.path.join(save_dir, 'reconstructed_training_rewards.png'))
    plt.show()
    # Release the figure; with non-interactive backends show() is a no-op and
    # figures would otherwise accumulate.
    plt.close(fig)


def plot_evaluation_metrics(data, save_dir=None):
    """Draw the evaluation metrics as a 2x2 grid of curves over episodes.

    Args:
        data: Training-history dict. Reads ``'evaluation_history'``, a list of
            dicts each with an ``'episode'`` number and a ``'metrics'`` dict
            containing ``'mean_reward'``, ``'mean_fuel'``, ``'success_rate'``
            and ``'nws'``.
        save_dir: If truthy, the figure is saved there as
            ``reconstructed_evaluation_metrics.png`` before being shown.
    """
    eval_history = data['evaluation_history']
    episodes = [e['episode'] for e in eval_history]

    # (metrics key, y-axis label, subplot title) for each panel, in grid order.
    panels = [
        ('mean_reward', 'Mean Reward', 'Evaluation Mean Reward'),
        ('mean_fuel', 'Mean Fuel', 'Evaluation Mean Fuel'),
        ('success_rate', 'Success Rate', 'Evaluation Success Rate'),
        ('nws', 'NWS', 'Normalized Weighted Score'),
    ]

    fig = plt.figure(figsize=(12, 8))
    for position, (key, ylabel, title) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.plot(episodes, [e['metrics'][key] for e in eval_history], 'o-')
        plt.xlabel('Episode')
        plt.ylabel(ylabel)
        plt.title(title)
        # Explicitly enable the grid (a bare plt.grid() toggles it).
        plt.grid(True)

    plt.tight_layout()

    if save_dir:
        plt.savefig(os.path.join(save_dir, 'reconstructed_evaluation_metrics.png'))
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)


def plot_best_nws_progress(data, save_dir=None):
    """Plot the running best NWS across evaluation points and annotate its peak.

    Args:
        data: Training-history dict. Reads ``'evaluation_history'``, a list of
            dicts each with an ``'episode'`` number and ``'metrics'['nws']``.
        save_dir: If truthy, the figure is saved there as
            ``reconstructed_best_nws.png`` before being shown.

    If the evaluation history is empty, a message is printed and nothing is
    plotted.
    """
    eval_history = data['evaluation_history']
    if not eval_history:
        # np.argmax below raises ValueError on an empty sequence.
        print('No evaluation history found; skipping best-NWS plot.')
        return

    best_nws = -float('inf')
    best_nws_history = []
    for eval_point in eval_history:
        current_nws = eval_point['metrics']['nws']
        # Running maximum; a NaN score never compares greater, so it is skipped.
        if current_nws > best_nws:
            best_nws = current_nws
        best_nws_history.append(best_nws)
    episodes = [eval_point['episode'] for eval_point in eval_history]

    fig = plt.figure(figsize=(10, 6))
    plt.plot(episodes, best_nws_history, 'g-o', linewidth=2, markersize=8)

    # The running best is non-decreasing, so argmax is the first evaluation
    # at which the final best value was reached.
    max_idx = np.argmax(best_nws_history)
    plt.annotate(f'Best NWS: {best_nws_history[max_idx]:.3f}',
                 xy=(episodes[max_idx], best_nws_history[max_idx]),
                 xytext=(5, -30), textcoords='offset points',
                 bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.5),
                 arrowprops=dict(arrowstyle='->'))

    plt.xlabel('Episode')
    plt.ylabel('Best NWS')
    plt.title('Progression of Best Normalized Weighted Score')
    plt.grid(True)

    if save_dir:
        plt.savefig(os.path.join(save_dir, 'reconstructed_best_nws.png'))
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)


def visualize_training_history(json_path, save_dir=None):
    """Load a training-history JSON file and draw all reconstructed plots.

    Args:
        json_path: Path to the ``training_history.json`` file.
        save_dir: Optional directory for the PNG outputs. It is created
            (including parents) if it does not exist yet.
    """
    # Load first so that a bad path doesn't leave an empty output dir behind.
    data = load_training_history(json_path)

    if save_dir:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(save_dir, exist_ok=True)

    plot_training_rewards(data, save_dir)
    plot_evaluation_metrics(data, save_dir)
    plot_best_nws_progress(data, save_dir)


if __name__ == "__main__":
    import argparse

    # Experiment plotted when no path is given on the command line
    # (preserves the previous hard-coded behaviour).
    default_experiment = 'dqn_lunarlander_20250513_221900'
    default_json_path = os.path.join('results', default_experiment, 'training_history.json')

    parser = argparse.ArgumentParser(description='Visualize DQN training history from JSON file')
    parser.add_argument('json_path', type=str, nargs='?', default=default_json_path,
                        help='Path to the training_history.json file (default: %(default)s)')
    parser.add_argument('--save_dir', type=str, default=None,
                        help='Directory to save reconstructed plots (optional)')
    args = parser.parse_args()

    visualize_training_history(args.json_path, args.save_dir)