import os
import torch
import numpy as np
import gymnasium as gym
from typing import List, Dict, Optional, Tuple
import json
import matplotlib.pyplot as plt
import base64
from PIL import Image
import io
from tabulate import tabulate


def load_agent(training_dir: str, env: gym.Env, agent_type: str = 'ppo') -> object:
    """Build an agent of the requested type and load its saved weights.

    The checkpoint ``<prefix>_model_best.pth`` in ``training_dir`` is preferred.
    If it does not exist, ``<prefix>_model_final.pth`` is used instead.

    Args:
        training_dir: Directory containing the saved checkpoint files.
        env: Environment used only to size the agent. The state dimension comes
            from ``env.observation_space.shape[0]`` and the number of actions
            from ``env.action_space.n``.
        agent_type: ``'ppo'`` (case-insensitive) selects ``PPOAgent``; any other
            value selects ``DQNAgent``.

    Returns:
        The agent with weights loaded via its ``load`` method.

    Raises:
        FileNotFoundError: If neither checkpoint file exists.
    """
    is_ppo = agent_type.lower() == 'ppo'
    prefix = 'ppo' if is_ppo else 'dqn'

    model_path = None
    for suffix in ('best', 'final'):
        candidate = os.path.join(training_dir, f"{prefix}_model_{suffix}.pth")
        if os.path.exists(candidate):
            model_path = candidate
            break
    if model_path is None:
        raise FileNotFoundError(f"No {prefix.upper()} model found in {training_dir}")

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    if is_ppo:
        from PPO_Lunarlander import PPOAgent
        agent = PPOAgent(state_dim, action_dim)
    else:
        from DQN_Lunarlander import DQNAgent
        agent = DQNAgent(state_dim, action_dim)
        # NOTE(review): presumably DQNAgent's exploration rate; a small value
        # keeps action selection close to greedy during testing — confirm.
        agent.epsilon = 0.01

    agent.load(model_path)
    print(f"Loaded {prefix.upper()} model from {model_path}")
    return agent


def get_training_stats(training_dir: str) -> Dict:
    """Summarise ``training_history.json`` from one training run.

    The JSON file must contain ``rewards`` and ``moving_avg_rewards`` lists.
    It may also contain ``evaluation_history``: a list of entries, each holding
    an ``episode`` number and a ``metrics`` dict with ``nws``, ``mean_reward``,
    ``mean_fuel`` and ``success_rate``. The ``*_eval_*`` statistics are added
    only when that list is present and non-empty.

    Args:
        training_dir: Directory containing ``training_history.json``.

    Returns:
        Dict with ``best_reward``, ``avg_reward``, ``best_moving_avg`` and
        ``avg_moving_avg``. If evaluations exist it also contains
        ``best_eval_*`` values, all taken from the evaluation with the highest
        NWS, plus ``avg_eval_*`` means over all evaluations.

    Raises:
        FileNotFoundError: If the history file does not exist.
        KeyError: If a required key is missing.
        ValueError: If ``rewards`` or ``moving_avg_rewards`` is empty.
    """
    history_path = os.path.join(training_dir, "training_history.json")
    with open(history_path, 'r', encoding='utf-8') as f:
        history = json.load(f)

    rewards = history['rewards']
    moving_avg_rewards = history['moving_avg_rewards']
    stats = {
        'best_reward': max(rewards),
        'avg_reward': np.mean(rewards),
        'best_moving_avg': max(moving_avg_rewards),
        'avg_moving_avg': np.mean(moving_avg_rewards),
    }

    # A missing, null or empty evaluation history is skipped:
    # np.argmax raises ValueError on an empty sequence.
    evaluations = history.get('evaluation_history') or []
    if evaluations:
        metrics = [e['metrics'] for e in evaluations]
        eval_nws = [m['nws'] for m in metrics]
        eval_mean_reward = [m['mean_reward'] for m in metrics]
        eval_mean_fuel = [m['mean_fuel'] for m in metrics]
        eval_success_rate = [m['success_rate'] for m in metrics]

        # np.argmax returns the first index on ties. Every best_eval_* value is
        # read from that same evaluation entry so they stay consistent.
        best = int(np.argmax(eval_nws))

        stats.update({
            'best_eval_nws': eval_nws[best],
            'best_eval_episode': evaluations[best]['episode'],
            'best_eval_mean_reward': eval_mean_reward[best],
            'best_eval_mean_fuel': eval_mean_fuel[best],
            'best_eval_success_rate': eval_success_rate[best],
            'avg_eval_nws': np.mean(eval_nws),
            'avg_eval_mean_reward': np.mean(eval_mean_reward),
            'avg_eval_mean_fuel': np.mean(eval_mean_fuel),
            'avg_eval_success_rate': np.mean(eval_success_rate),
        })

    return stats


def run_test(
        training_dir: str,
        seeds: List[int],
        max_steps: int = 200,
        render: bool = False,
        agent_type: str = 'ppo',
        gravity=-10,
        enable_wind=False,
) -> Dict:
    """Evaluate one trained agent on a fixed list of seeds and print a report.

    A ``LunarLander-v3`` environment is created with the requested physics
    settings and passed to :func:`load_agent`, which reads the state and action
    sizes from it. The evaluation itself is delegated to the ``evaluate_agent``
    function of the matching training module.

    Args:
        training_dir: Directory with the saved checkpoint and
            ``training_history.json``.
        seeds: Environment seeds forwarded to ``evaluate_agent``.
        max_steps: Per-episode step limit forwarded to ``evaluate_agent``.
        render: If True, the environment uses ``render_mode='human'``. If
            ``evaluate_agent`` also returns a ``worst_case_image``, that image
            is displayed with matplotlib.
        agent_type: ``'ppo'`` (case-insensitive) selects PPO; anything else
            selects DQN.
        gravity: Forwarded to ``gym.make`` and ``evaluate_agent``.
        enable_wind: Forwarded to ``gym.make`` and ``evaluate_agent``.

    Returns:
        Dict with ``training_stats``, ``test_results``, ``model_path`` (the
        training directory) and ``agent_type`` (upper-cased).
    """
    print('gravity', gravity)
    env = gym.make("LunarLander-v3", render_mode='human' if render else 'rgb_array',
                   gravity=gravity, enable_wind=enable_wind)
    # try/finally closes the environment even when loading the checkpoint,
    # reading the history or evaluating raises. The batch driver catches those
    # errors and moves on to the next directory.
    try:
        agent = load_agent(training_dir, env, agent_type)
        training_stats = get_training_stats(training_dir)

        if agent_type.lower() == 'ppo':
            from PPO_Lunarlander import evaluate_agent
        else:
            from DQN_Lunarlander import evaluate_agent

        test_results = evaluate_agent(agent, seeds, max_steps, gravity=gravity,
                                      enable_wind=enable_wind)

        enhanced_results = {
            'training_stats': training_stats,
            'test_results': test_results,
            'model_path': training_dir,
            'agent_type': agent_type.upper()
        }

        print_enhanced_results(enhanced_results)

        if render and 'worst_case_image' in test_results:
            # The image arrives as a base64 string; decode it and open it with PIL.
            img_data = base64.b64decode(test_results['worst_case_image'])
            img = Image.open(io.BytesIO(img_data))
            recorder = test_results['episodes_recorder']
            worst_reward = min(recorder[ep]['episode_reward'] for ep in recorder)
            plt.figure(figsize=(8, 6))
            plt.imshow(img)
            plt.title(f"Worst Case (Reward: {worst_reward:.4f})")
            plt.axis('off')
            plt.show()
    finally:
        env.close()

    return enhanced_results

def print_enhanced_results(results: Dict):
    """Print the training and test statistics tables for a single model.

    ``results`` must contain ``agent_type``, ``model_path``, a
    ``training_stats`` dict and a ``test_results`` dict. The test dict needs
    ``mean_reward``, ``mean_fuel``, ``success_rate``, ``nws`` and an
    ``episodes_recorder`` mapping whose values carry ``episode_reward``.
    The best-evaluation rows are printed only when ``training_stats``
    contains ``best_eval_nws``.
    """
    rule = '=' * 60
    print('\n' + rule)
    print(f"{results['agent_type']} Model Performance Summary")
    print(f"Model Path: {results['model_path']}")
    print(rule)

    train = results['training_stats']
    base_fields = (
        ("Best Episode Reward", 'best_reward'),
        ("Average Episode Reward", 'avg_reward'),
        ("Best Moving Avg (100 ep)", 'best_moving_avg'),
        ("Average Moving Avg", 'avg_moving_avg'),
    )
    train_rows = [[label, f"{train[key]:.4f}"] for label, key in base_fields]

    if 'best_eval_nws' in train:
        # The avg_eval_* values are also in training_stats but are
        # intentionally left out of this table.
        train_rows.append(["Best Evaluation NWS", f"{train['best_eval_nws']:.4f}"])
        train_rows.append(["Best Eval Episode", f"{train['best_eval_episode']}"])
        eval_fields = (
            ("Best Eval Mean Reward", 'best_eval_mean_reward'),
            ("Best Eval Mean Fuel", 'best_eval_mean_fuel'),
            ("Best Eval Success Rate", 'best_eval_success_rate'),
        )
        train_rows.extend([label, f"{train[key]:.4f}"] for label, key in eval_fields)

    print("\nTraining Statistics:")
    print(tabulate(train_rows, headers=["Metric", "Value"], tablefmt="grid"))

    outcome = results['test_results']
    mean_reward_text = f"{outcome['mean_reward']:.4f}"
    recorder = outcome['episodes_recorder']
    episode_rewards = [recorder[ep]['episode_reward'] for ep in recorder]

    test_rows = [
        ["Mean Test Reward", mean_reward_text],
        ["Best Test Reward", f"{max(episode_rewards):.4f}"],
        ["Worst Test Reward", f"{min(episode_rewards):.4f}"],
        ["Mean Fuel", f"{outcome['mean_fuel']:.4f}"],
        ["Success Rate", f"{outcome['success_rate']:.4f}"],
        ["Normalized Weighted Score (NWS)", f"{outcome['nws']:.4f}"],
    ]

    print("\nTest Statistics:")
    print(tabulate(test_rows, headers=["Metric", "Value"], tablefmt="grid"))
    print(rule + "\n")


def test_multiple_agents(
        training_dirs: List[str],
        seeds: List[int],
        max_steps: int = 200,
        render: bool = False,
        plot_results: bool = True,
        agent_type: str = 'ppo',
        gravity=-10,
        enable_wind=False,
) -> Tuple[Dict[str, Dict], Dict]:
    """Run :func:`run_test` on every directory and summarise the results.

    A failure in one directory is printed and stored as ``{"error": message}``
    for that directory. The remaining directories are still tested.
    ``plot_results`` is accepted but not used by this function.

    Returns:
        Tuple ``(individual_results, summary_stats)``.
        ``individual_results`` maps each directory to its ``run_test`` result
        or error dict. ``summary_stats`` holds ``training_summary``,
        ``test_summary``, ``agent_type`` and ``num_models`` (the number of
        successful runs). When no run succeeded it also holds an ``error``
        message.
    """
    label = agent_type.upper()
    individual_results = {}

    for directory in training_dirs:
        print(f"\n{'#' * 80}")
        print(f"Testing {label} model in directory: {directory}")
        print('#' * 80)

        try:
            individual_results[directory] = run_test(
                directory, seeds, max_steps, render, agent_type,
                gravity=gravity, enable_wind=enable_wind)
        except Exception as exc:  # boundary: report and continue with the next model
            print(f"Error testing model in {directory}: {str(exc)}")
            individual_results[directory] = {"error": str(exc)}

    successful = [outcome for outcome in individual_results.values() if "error" not in outcome]

    summary_stats = {
        'training_summary': {},
        'test_summary': {},
        'agent_type': label,
        'num_models': len(successful),
    }

    if not successful:
        summary_stats['error'] = f"No valid {label} test results"
    else:
        def best_episode_reward(outcome):
            # Highest single-episode reward recorded during this model's test.
            recorder = outcome['test_results']['episodes_recorder']
            return max(recorder[ep]['episode_reward'] for ep in recorder)

        nws_bests = [outcome['training_stats']['best_eval_nws'] for outcome in successful
                     if 'best_eval_nws' in outcome['training_stats']]
        mean_rewards = [outcome['test_results']['mean_reward'] for outcome in successful]
        peak_rewards = [best_episode_reward(outcome) for outcome in successful]
        test_nws = [outcome['test_results']['nws'] for outcome in successful]

        summary_stats['training_summary'] = {
            'best_of_best_eval_nws': max(nws_bests) if nws_bests else None,
            'avg_of_best_eval_nws': np.mean(nws_bests) if nws_bests else None,
        }
        summary_stats['test_summary'] = {
            'avg_mean_reward': np.mean(mean_rewards),
            'avg_best_test_reward': np.mean(peak_rewards),
            'best_mean_reward': max(mean_rewards),
            'best_of_best_test_reward': max(peak_rewards),
            'avg_nws': np.mean(test_nws),
            'best_nws': max(test_nws),
        }

    print_summary_stats(summary_stats)

    return individual_results, summary_stats


def print_summary_stats(summary: Dict):
    """Print the cross-model summary tables.

    If ``summary`` contains an ``error`` key, only the header and the error
    message are printed. Otherwise two tables are printed: one from
    ``training_summary``, where a ``None`` value is shown as "N/A", and one
    from ``test_summary``, where a missing key is shown as ``0.0000``.
    """
    rule = '#' * 80
    print('\n' + rule)
    print(f"{summary['agent_type']} Models Summary Statistics ({summary['num_models']} models)")
    print(rule)

    if 'error' in summary:
        print(summary['error'])
        return

    def optional_metric(value):
        # Training summaries use None when no evaluation data was available.
        return "N/A" if value is None else f"{value:.4f}"

    training = summary['training_summary']
    train_rows = []
    if 'best_of_best_eval_nws' in training:
        train_rows = [
            ["Best of Best Evaluation NWS", optional_metric(training['best_of_best_eval_nws'])],
            ["Average of Best Evaluation NWS", optional_metric(training['avg_of_best_eval_nws'])],
        ]

    print("\nTraining Summary Across Models:")
    print(tabulate(train_rows, headers=["Metric", "Value"], tablefmt="grid"))

    testing = summary['test_summary']
    test_fields = (
        ("Average Mean Test Reward", 'avg_mean_reward'),
        ("Best Mean Test Reward", 'best_mean_reward'),
        ("Average of Best Test Rewards", 'avg_best_test_reward'),
        ("Best of Best Test Rewards", 'best_of_best_test_reward'),
        ("Average NWS", 'avg_nws'),
        ("Best NWS", 'best_nws'),
    )
    test_rows = [[label, f"{testing.get(key, 0):.4f}"] for label, key in test_fields]

    print("\nTest Summary Across Models:")
    print(tabulate(test_rows, headers=["Metric", "Value"], tablefmt="grid"))
    print(rule + "\n")


if __name__ == "__main__":
    # Checkpoint directories of individual training runs, grouped by agent type.
    # The paths are joined with os.path.join so the separator matches the host
    # OS; hard-coded backslashes only resolve on Windows.
    RESULTS_ROOT = 'results'
    saving_root = {
        'ppo': [os.path.join(RESULTS_ROOT, run) for run in (
            'ppo_lunarlander_20250603_094751',
            'ppo_lunarlander_20250603_095146',
            'ppo_lunarlander_20250603_112742',
            'ppo_lunarlander_20250603_112848',
            'ppo_lunarlander_20250603_130517',
        )],
        'dqn': [os.path.join(RESULTS_ROOT, run) for run in (
            'dqn_lunarlander_20250603_003755',
            'dqn_lunarlander_20250603_003756',
            'dqn_lunarlander_20250603_022211',
            'dqn_lunarlander_20250603_024024',
            'dqn_lunarlander_20250603_152545',
        )],
    }

    DEFAULT_SEEDS = list(range(10))
    # Alternative: the seeds used during training.
    # DEFAULT_SEEDS = [42, 520, 1231, 114, 886]
    DEFAULT_MAX_STEPS = 200
    DEFAULT_RENDER = False
    AGENT_TYPE = 'dqn'  # 'ppo' or 'dqn'

    TRAINING_DIRS = saving_root[AGENT_TYPE]

    individual_results, summary_stats = test_multiple_agents(
        training_dirs=TRAINING_DIRS,
        seeds=DEFAULT_SEEDS,
        max_steps=DEFAULT_MAX_STEPS,
        render=DEFAULT_RENDER,
        agent_type=AGENT_TYPE, gravity=-10, enable_wind=False
    )
