import os
import torch
import numpy as np
import gymnasium as gym
from typing import List, Dict, Optional, Tuple
import json
import matplotlib.pyplot as plt
import base64
from PIL import Image
import io
from tabulate import tabulate
import itertools

def load_agent(training_dir: str, env: gym.Env, agent_type: str = 'ppo') -> object:
    """Load a trained PPO or DQN agent from a training directory.

    The best checkpoint (``<type>_model_best.pth``) is preferred; when it is
    missing, the final checkpoint (``<type>_model_final.pth``) is used instead.

    Args:
        training_dir: Directory containing the saved model checkpoints.
        env: Environment instance passed to the agent constructor.
        agent_type: ``'ppo'`` or ``'dqn'`` (case-insensitive).

    Returns:
        The constructed agent with weights loaded via ``agent.load``.

    Raises:
        ValueError: If ``agent_type`` is neither ``'ppo'`` nor ``'dqn'``.
        FileNotFoundError: If neither checkpoint exists for that agent type.
    """
    prefix = agent_type.lower()
    # Reject unknown types explicitly; previously anything that was not 'ppo'
    # silently fell through to the DQN branch and loaded the wrong model family.
    if prefix not in ('ppo', 'dqn'):
        raise ValueError(f"Unsupported agent_type {agent_type!r}; expected 'ppo' or 'dqn'")

    model_path = None
    for suffix in ('best', 'final'):
        candidate = os.path.join(training_dir, f"{prefix}_model_{suffix}.pth")
        if os.path.exists(candidate):
            model_path = candidate
            break
    if model_path is None:
        raise FileNotFoundError(f"No {prefix.upper()} model found in {training_dir}")

    # Agent classes are imported lazily so only the selected implementation
    # module needs to be importable.
    if prefix == 'ppo':
        from PPO_RacingCar_continue_v2 import PPOCarRacingAgent
        agent = PPOCarRacingAgent(env, frame_skip=1)
    else:
        from DQN_RacingCar_v4 import CarRacingAgent
        agent = CarRacingAgent(env, frame_skip=1)
        # presumably the epsilon-greedy exploration rate, kept small for
        # evaluation — confirm against CarRacingAgent.
        agent.epsilon = 0.01

    agent.load(model_path)
    print(f"Loaded {prefix.upper()} model from {model_path}")
    return agent


def get_training_stats(training_dir: str) -> Dict:
    """Summarise the training history stored in ``training_history.json``.

    Args:
        training_dir: Directory containing ``training_history.json``.

    Returns:
        Dict with ``best_reward``, ``avg_reward``, ``best_moving_avg`` and
        ``avg_moving_avg``. When the history has a non-empty
        ``evaluation_history``, it additionally contains ``best_eval_nws``,
        ``best_eval_episode`` and ``best_eval_mean_reward`` (all taken from the
        evaluation with the highest NWS), plus ``avg_eval_nws`` and
        ``avg_eval_mean_reward``.

    Raises:
        FileNotFoundError: If ``training_history.json`` does not exist.
        KeyError: If ``rewards`` / ``moving_avg_rewards`` (or per-evaluation
            keys) are missing from the history.
    """
    history_path = os.path.join(training_dir, "training_history.json")
    with open(history_path, 'r', encoding='utf-8') as f:
        history = json.load(f)

    rewards = history['rewards']
    moving_avg = history['moving_avg_rewards']
    stats = {
        'best_reward': max(rewards),
        'avg_reward': np.mean(rewards),
        'best_moving_avg': max(moving_avg),
        'avg_moving_avg': np.mean(moving_avg),
    }

    # An empty (or null) evaluation history would make np.argmax()/max()
    # raise, so it is treated the same as a missing one.
    evaluations = history.get('evaluation_history') or []
    if evaluations:
        eval_nws = [e['metrics']['nws'] for e in evaluations]
        eval_episodes = [e['episode'] for e in evaluations]
        eval_mean_reward = [e['metrics']['mean_reward'] for e in evaluations]

        # np.argmax returns the first index on ties; reusing it for the NWS value
        # keeps all three "best" fields describing the same evaluation.
        best_idx = int(np.argmax(eval_nws))

        stats.update({
            'best_eval_nws': eval_nws[best_idx],
            'best_eval_episode': eval_episodes[best_idx],
            'best_eval_mean_reward': eval_mean_reward[best_idx],
            'avg_eval_nws': np.mean(eval_nws),
            'avg_eval_mean_reward': np.mean(eval_mean_reward),
        })

    return stats


def run_test(
        training_dir: str,
        seeds: List[int],
        max_steps: int = 200,
        render: bool = False,
        agent_type: str = 'ppo'
) -> Dict:
    """Evaluate one trained agent on CarRacing-v3 and print a summary.

    Args:
        training_dir: Directory holding the checkpoint and ``training_history.json``.
        seeds: Seeds forwarded to ``evaluate_agent``.
        max_steps: Step limit forwarded to ``evaluate_agent``.
        render: Use ``render_mode='human'`` instead of ``'rgb_array'``.
        agent_type: ``'ppo'`` or ``'dqn'``.

    Returns:
        Dict with keys ``training_stats``, ``test_results`` (whatever
        ``evaluate_agent`` returns), ``model_path`` and ``agent_type``
        (upper-cased).
    """
    # NOTE(review): continuous=True is also used for the DQN agent; confirm the
    # DQN implementation handles a continuous action space.
    env = gym.make('CarRacing-v3', render_mode='human' if render else 'rgb_array', domain_randomize=False, continuous=True)

    # Always release the environment (and any render window), even if loading,
    # evaluation or printing fails: test_multiple_agents catches the exception
    # and moves on to the next model, so a leak here would accumulate.
    try:
        agent = load_agent(training_dir, env, agent_type)
        training_stats = get_training_stats(training_dir)

        # Import the evaluation routine that matches the agent implementation.
        if agent_type.lower() == 'ppo':
            from PPO_RacingCar_continue_v2 import evaluate_agent
        else:
            from DQN_RacingCar_v4 import evaluate_agent

        test_results = evaluate_agent(agent, env, seeds, max_steps)

        enhanced_results = {
            'training_stats': training_stats,
            'test_results': test_results,
            'model_path': training_dir,
            'agent_type': agent_type.upper()
        }

        print_enhanced_results(enhanced_results)
    finally:
        env.close()
    return enhanced_results


def print_enhanced_results(results: Dict):
    """Print the training and test statistics of a single model as tables.

    Args:
        results: Dict shaped like ``run_test``'s return value, with keys
            ``agent_type``, ``model_path``, ``training_stats`` and
            ``test_results``.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"{results['agent_type']} Model Performance Summary")
    print(f"Model Path: {results['model_path']}")
    print(banner)

    train_stats = results['training_stats']
    train_rows = [
        [label, f"{train_stats[key]:.4f}"]
        for label, key in (
            ("Best Episode Reward", 'best_reward'),
            ("Average Episode Reward", 'avg_reward'),
            ("Best Moving Avg (100 ep)", 'best_moving_avg'),
            ("Average Moving Avg", 'avg_moving_avg'),
        )
    ]

    # Evaluation-history stats exist only if training recorded evaluations.
    if 'best_eval_nws' in train_stats:
        train_rows += [
            ["Best Evaluation NWS", f"{train_stats['best_eval_nws']:.4f}"],
            ["Best Eval Episode", f"{train_stats['best_eval_episode']}"],
            ["Best Eval Mean Reward", f"{train_stats['best_eval_mean_reward']:.4f}"],
        ]

    print("\nTraining Statistics:")
    print(tabulate(train_rows, headers=["Metric", "Value"], tablefmt="grid"))

    # Test statistics
    test_stats = results['test_results']
    mean_row = ["Mean Test Reward", f"{test_stats['mean_reward']:.4f}"]
    recorder = test_stats['episodes_recorder']
    episode_rewards = [recorder[ep]['episode_reward'] for ep in recorder]
    test_rows = [
        mean_row,
        ["Best Test Reward", f"{max(episode_rewards):.4f}"],
        ["Worst Test Reward", f"{min(episode_rewards):.4f}"],
        ["Normalized Weighted Score (NWS)", f"{test_stats['nws']:.4f}"],
    ]

    print("\nTest Statistics:")
    print(tabulate(test_rows, headers=["Metric", "Value"], tablefmt="grid"))
    print(f"{banner}\n")


def test_multiple_agents(
        training_dirs: List[str],
        seeds: List[int],
        max_steps: int = 200,
        render: bool = False,
        plot_results: bool = True,
        agent_type: str = 'ppo'
) -> Tuple[Dict[str, Dict], Dict]:
    """Test every model directory in ``training_dirs`` and summarise them.

    Each directory is evaluated with ``run_test``. Any exception is printed and
    stored as ``{"error": <message>}`` so that one broken model does not stop
    the remaining ones.

    Args:
        training_dirs: Training output directories to evaluate.
        seeds: Seeds forwarded to ``run_test``.
        max_steps: Step limit forwarded to ``run_test``.
        render: Forwarded to ``run_test``.
        plot_results: Accepted for API compatibility; not used here.
        agent_type: ``'ppo'`` or ``'dqn'``.

    Returns:
        ``(individual_results, summary_stats)``: ``individual_results`` maps
        each directory to its ``run_test`` result (or error dict), and
        ``summary_stats`` holds averages and bests across all successful models.
    """
    upper_type = agent_type.upper()
    hashes = '#' * 80
    individual_results = {}

    for dir_path in training_dirs:
        print(f"\n{hashes}")
        print(f"Testing {upper_type} model in directory: {dir_path}")
        print(hashes)

        try:
            individual_results[dir_path] = run_test(dir_path, seeds, max_steps, render, agent_type)
        except Exception as e:
            print(f"Error testing model in {dir_path}: {str(e)}")
            individual_results[dir_path] = {"error": str(e)}

    # Aggregate statistics over the models that were tested successfully.
    valid = [res for res in individual_results.values() if "error" not in res]

    summary_stats = {
        'training_summary': {},
        'test_summary': {},
        'agent_type': upper_type,
        'num_models': len(valid),  # always present, even when nothing succeeded
    }

    if not valid:
        summary_stats['error'] = f"No valid {upper_type} test results"
    else:
        # Only models whose training history contained evaluations contribute.
        training_nws = [res['training_stats']['best_eval_nws'] for res in valid
                        if 'best_eval_nws' in res['training_stats']]
        mean_rewards = [res['test_results']['mean_reward'] for res in valid]
        best_rewards = []
        for res in valid:
            recorder = res['test_results']['episodes_recorder']
            best_rewards.append(max(recorder[ep]['episode_reward'] for ep in recorder))
        test_nws = [res['test_results']['nws'] for res in valid]

        summary_stats['training_summary'] = {
            'best_of_best_eval_nws': max(training_nws) if training_nws else None,
            'avg_of_best_eval_nws': np.mean(training_nws) if training_nws else None,
        }
        summary_stats['test_summary'] = {
            'avg_mean_reward': np.mean(mean_rewards),
            'avg_best_test_reward': np.mean(best_rewards),
            'best_mean_reward': max(mean_rewards),
            'best_of_best_test_reward': max(best_rewards),
            'avg_nws': np.mean(test_nws),
            'best_nws': max(test_nws),
        }

    print_summary_stats(summary_stats)

    return individual_results, summary_stats


def print_summary_stats(summary: Dict):
    """Print the cross-model summary tables built by ``test_multiple_agents``.

    When ``summary`` has an ``'error'`` entry, only the header and that
    message are printed.
    """
    rule = '#' * 80
    print(f"\n{rule}")
    print(f"{summary['agent_type']} Models Summary Statistics ({summary['num_models']} models)")
    print(rule)

    if 'error' in summary:
        print(summary['error'])
        return

    def _fmt_or_na(value):
        # Missing evaluation data is stored as None and displayed as "N/A".
        return "N/A" if value is None else f"{value:.4f}"

    training = summary['training_summary']
    train_rows = []
    if 'best_of_best_eval_nws' in training:
        train_rows = [
            ["Best of Best Evaluation NWS", _fmt_or_na(training['best_of_best_eval_nws'])],
            ["Average of Best Evaluation NWS", _fmt_or_na(training['avg_of_best_eval_nws'])],
        ]

    print("\nTraining Summary Across Models:")
    print(tabulate(train_rows, headers=["Metric", "Value"], tablefmt="grid"))

    testing = summary['test_summary']
    test_rows = [
        [label, f"{testing.get(key, 0):.4f}"]
        for label, key in (
            ("Average Mean Test Reward", 'avg_mean_reward'),
            ("Best Mean Test Reward", 'best_mean_reward'),
            ("Average of Best Test Rewards", 'avg_best_test_reward'),
            ("Best of Best Test Rewards", 'best_of_best_test_reward'),
            ("Average NWS", 'avg_nws'),
            ("Best NWS", 'best_nws'),
        )
    ]

    print("\nTest Summary Across Models:")
    print(tabulate(test_rows, headers=["Metric", "Value"], tablefmt="grid"))
    print(f"{rule}\n")


if __name__ == "__main__":
    # Model directories to evaluate, grouped by agent type. Paths are built with
    # os.path.join so they resolve on both Windows and POSIX; the previous raw
    # 'results\...' literals only worked as paths on Windows.
    RESULTS_ROOT = 'results'
    saving_root = {
        'dqn': [
            os.path.join(RESULTS_ROOT, name) for name in (
                'dqn_carracing_20250605_115728_simple',
                'dqn_carracing_20250608_115712_simple',
                'dqn_carracing_20250610_144911_simple',
                'dqn_carracing_20250613_122801_simple',
                'dqn_carracing_20250615_053724_simple',
            )
        ],
        'ppo': [
            os.path.join(RESULTS_ROOT, name) for name in (
                'ppo_carracing_continuous_20250531_002542_simple',
                'ppo_carracing_continuous_20250605_011022',
                'ppo_carracing_continuous_20250606_223600',
                'ppo_carracing_continuous_20250617_010440_simple',
                'ppo_carracing_continuous_20250618_052952_simple',
            )
        ],
    }

    # Seeds for the test episodes.
    # NOTE(review): these were originally annotated as "Training" seeds; if they
    # were also used during training, the results are not a held-out test.
    DEFAULT_SEEDS = [40, 1231, 516, 413]

    DEFAULT_MAX_STEPS = 1200
    DEFAULT_RENDER = False
    AGENT_TYPE = 'dqn'  # 'ppo' or 'dqn'

    TRAINING_DIRS = saving_root[AGENT_TYPE]

    individual_results, summary_stats = test_multiple_agents(
        training_dirs=TRAINING_DIRS,
        seeds=DEFAULT_SEEDS,
        max_steps=DEFAULT_MAX_STEPS,
        render=DEFAULT_RENDER,
        agent_type=AGENT_TYPE
    )