import os
import json
import pickle
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Optional
import warnings
# NOTE(review): this silences *every* warning process-wide (including matplotlib
# deprecation/legend warnings raised by the plotting code below); consider
# narrowing the category/module if diagnostics are needed.
warnings.filterwarnings('ignore')

# ============================================================================
# SETUP PLOTTING STYLE
# ============================================================================

def setup_plot_style():
    """
    Configure global matplotlib settings for clean, paper-style (ICML-like) figures.

    Applies matplotlib's bundled seaborn "paper" style sheet, then overrides
    fonts, figure/savefig DPI, grid appearance, spines, and default line and
    marker sizes via ``plt.rcParams``. Colors are not configured here; the
    per-configuration palette lives in ``PASTEL_COLORS``.

    Side effects:
    -------------
    Mutates the process-wide ``matplotlib.pyplot.rcParams``.
    """
    # matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*'
    # (the old names were later removed), so try the new name first and fall
    # back to the legacy one; if neither exists keep matplotlib's defaults
    # rather than raising OSError.
    for style_name in ('seaborn-v0_8-paper', 'seaborn-paper'):
        try:
            plt.style.use(style_name)
            break
        except OSError:
            continue
    else:
        print("Warning: seaborn paper style not available; using matplotlib defaults.")

    plt.rcParams.update({
        'font.family': 'sans-serif',
        'font.sans-serif': ['DejaVu Sans', 'Arial', 'Helvetica'],
        'font.size': 10,
        'axes.labelsize': 11,
        'axes.titlesize': 12,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10,
        'figure.titlesize': 13,
        'figure.dpi': 150,
        'savefig.dpi': 300,
        'savefig.bbox': 'tight',
        'axes.grid': True,
        'grid.alpha': 0.3,
        'grid.linestyle': '--',
        'axes.spines.top': False,
        'axes.spines.right': False,
        'axes.linewidth': 1.0,
        'lines.linewidth': 2.0,
        'lines.markersize': 8,
    })

setup_plot_style()

# Per-configuration plot styling as (pastel fill color, marker symbol).
# Configurations not listed here fall back to defaults chosen by the callers.
_CONFIG_STYLES = {
    'UCB-Hedge With FTRL': ('#A8DADC', 'o'),          # light blue, circle
    'UCB-Hedge Without FTRL': ('#F4A261', 's'),       # light orange, square
    'Thompson-Hedge With FTRL': ('#E9C46A', '^'),     # light yellow, triangle up
    'Thompson-Hedge Without FTRL': ('#E76F51', 'v'),  # light coral, triangle down
}

# Pastel fill color for each configuration name.
PASTEL_COLORS = {name: color for name, (color, _marker) in _CONFIG_STYLES.items()}

# Scatter marker symbol for each configuration name.
MARKER_STYLES = {name: marker for name, (_color, marker) in _CONFIG_STYLES.items()}

# ============================================================================
# DATA LOADING FUNCTIONS
# ============================================================================

def load_pareto_data(data_file: str) -> Dict:
    """
    Load Pareto front data from a JSON or pickle file.

    Parameters:
    -----------
    data_file : str or os.PathLike
        Path to the data file. The format is selected from the file
        extension, ``.json`` or ``.pkl`` (case-insensitive).

    Returns:
    --------
    data : dict
        The deserialized object (expected to be a dict containing Pareto
        front data; pickle content is returned as-is without validation).

    Raises:
    -------
    ValueError
        If the extension is neither ``.json`` nor ``.pkl``.
    """
    # os.fspath accepts both plain strings and pathlib.Path objects.
    path = os.fspath(data_file)
    extension = os.path.splitext(path)[1].lower()

    if extension == '.json':
        # JSON text is UTF-8 by specification; don't depend on the locale.
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    if extension == '.pkl':
        # SECURITY: unpickling can execute arbitrary code. Only load .pkl
        # files that you produced yourself or otherwise trust.
        with open(path, 'rb') as f:
            return pickle.load(f)

    raise ValueError("Data file must be .json or .pkl")


def extract_pareto_fronts_from_data(data: Dict, run_index: Optional[int] = None) -> Dict:
    """
    Extract Pareto front objectives from the data structure

    Parameters:
    -----------
    data : dict
        Loaded pareto data dictionary. Objectives are read from
        ``data['pareto_fronts'][config_name]``, a sequence with one entry
        (array-like of objective vectors) per run.
    run_index : int, optional
        If specified, extract only this run (negative values index from the
        end, as in Python). Configurations without that run get an empty
        list. Otherwise, all runs are extracted.

    Returns:
    --------
    pareto_fronts : dict
        Dictionary mapping configuration name to list of objective arrays
        Format: {config_name: [array of objectives for run 1, run 2, ...]}
        Runs with no solutions (zero-size arrays) are omitted. Returns an
        empty dict (after printing a warning) if 'pareto_fronts' is missing.
    """
    if 'pareto_fronts' not in data:
        print("Warning: 'pareto_fronts' key not found in data!")
        print(f"Available keys: {list(data.keys())}")
        return {}

    pareto_fronts = {}

    for config_name, runs_data in data['pareto_fronts'].items():
        if run_index is None:
            selected_runs = runs_data
        elif -len(runs_data) <= run_index < len(runs_data):
            # Bounds-check both ends so out-of-range negative indices are
            # treated as "run missing" instead of raising IndexError.
            selected_runs = [runs_data[run_index]]
        else:
            selected_runs = []

        # Use .size rather than len(): an input like [[]] gives a (1, 0)
        # array whose len is 1 but which holds no objective values.
        arrays = (np.array(run_objectives) for run_objectives in selected_runs)
        pareto_fronts[config_name] = [arr for arr in arrays if arr.size > 0]

    return pareto_fronts


# ============================================================================
# PARETO FRONT PLOTTING FUNCTIONS
# ============================================================================

def plot_single_run_pareto_fronts(
    data_file: str,
    run_index: int = 0,
    output_dir: str = 'pareto_plots',
    configs_to_plot: Optional[List[str]] = None,
    objective_names: Optional[List[str]] = None,
    figsize: tuple = (10, 8),
    dpi: int = 300
):
    """
    Plot Pareto fronts for a single run comparing different configurations

    Each configuration's solutions for ``run_index`` are drawn as a scatter
    joined by a dashed line (sorted by the first objective). Only the first
    two objective columns are plotted. The figure is saved as PNG and PDF in
    ``output_dir``; nothing is saved if no selected configuration has
    solutions for the requested run.

    Parameters:
    -----------
    data_file : str
        Path to pareto data file
    run_index : int
        Which run to plot (0-indexed)
    output_dir : str
        Directory to save plots
    configs_to_plot : list, optional
        List of configuration names to plot. If None, plots all.
    objective_names : list, optional
        Names for the objectives (e.g., ['Cost', 'Quality'])
    figsize : tuple
        Figure size
    dpi : int
        Resolution of the PNG output
    """
    print("\n" + "="*80)
    print(f"PLOTTING PARETO FRONTS FOR RUN {run_index}")
    print("="*80)

    # Load data
    data = load_pareto_data(data_file)
    pareto_fronts = extract_pareto_fronts_from_data(data, run_index=run_index)

    if not pareto_fronts:
        print("No Pareto front data found!")
        return

    metadata = data.get('metadata', {})
    problem_type = metadata.get('problem_type', 'Unknown')
    problem_size = metadata.get('problem_size', 'Unknown')
    actual_size = metadata.get('actual_size', 'Unknown')

    print(f"Problem: {problem_type}{actual_size} ({problem_size})")

    os.makedirs(output_dir, exist_ok=True)

    # Filter configurations
    if configs_to_plot is not None:
        pareto_fronts = {k: v for k, v in pareto_fronts.items() if k in configs_to_plot}

    # Keep only configurations that actually have solutions for this run;
    # the requested run index may not exist for every configuration.
    run_fronts = {}
    for config_name, runs in pareto_fronts.items():
        if runs:
            run_fronts[config_name] = runs[0]  # single run requested
        else:
            print(f"Warning: No solutions for {config_name}")

    if not run_fronts:
        print(f"No solutions found for run {run_index}; skipping plot.")
        return

    # Fallback colors for configurations missing from PASTEL_COLORS, taken from
    # the active style's color cycle. Resolving the color once per config keeps
    # the scatter and its connecting line the same color.
    cycle_colors = plt.rcParams['axes.prop_cycle'].by_key().get('color', ['#888888'])

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    try:
        for idx, (config_name, objectives) in enumerate(run_fronts.items()):
            color = PASTEL_COLORS.get(config_name, cycle_colors[idx % len(cycle_colors)])
            marker = MARKER_STYLES.get(config_name, 'o')

            # Plot Pareto front
            ax.scatter(
                objectives[:, 0],
                objectives[:, 1],
                c=color,
                marker=marker,
                s=100,
                alpha=0.7,
                edgecolors='black',
                linewidth=1.5,
                label=config_name
            )

            # Sort and connect points to show Pareto front curve
            sorted_objectives = objectives[np.argsort(objectives[:, 0])]
            ax.plot(
                sorted_objectives[:, 0],
                sorted_objectives[:, 1],
                c=color,
                alpha=0.3,
                linewidth=2.0,
                linestyle='--'
            )

        # Styling
        obj_names = objective_names or ['Objective 1', 'Objective 2']
        ax.set_xlabel(obj_names[0], fontsize=12, fontweight='bold')
        ax.set_ylabel(obj_names[1], fontsize=12, fontweight='bold')
        ax.set_title(
            f'Pareto Fronts Comparison: {problem_type}{actual_size} (Run {run_index})',
            fontsize=13,
            fontweight='bold',
            pad=15
        )
        ax.legend(frameon=True, loc='best', fontsize=10)
        ax.grid(True, alpha=0.3, linestyle='--')
        ax.set_axisbelow(True)

        fig.tight_layout()

        plot_file = os.path.join(
            output_dir,
            f'pareto_front_run{run_index}_{problem_type}_{problem_size}.png'
        )
        fig.savefig(plot_file, dpi=dpi, bbox_inches='tight')
        # splitext only touches the extension, unlike str.replace, which
        # would also rewrite a '.png' appearing in a directory name.
        fig.savefig(os.path.splitext(plot_file)[0] + '.pdf', bbox_inches='tight')
        print(f"✓ Saved: {plot_file}")
    finally:
        # Release the figure even if plotting or saving fails part-way.
        plt.close(fig)


def plot_aggregated_pareto_fronts(
    data_file: str,
    output_dir: str = 'pareto_plots',
    configs_to_plot: Optional[List[str]] = None,
    objective_names: Optional[List[str]] = None,
    show_individual_runs: bool = False,
    figsize: tuple = (10, 8),
    dpi: int = 300,
    maximize: bool = False
):
    """
    Plot, per configuration, the Pareto front of all runs pooled together

    For each configuration the solutions of every run are stacked and the
    non-dominated subset is recomputed with ``compute_pareto_front``. That
    pooled front is drawn as large markers joined by a solid line (sorted by
    the first objective). No mean/median across runs is computed. Only the
    first two objective columns are plotted. The figure is saved as PNG and
    PDF in ``output_dir``.

    Parameters:
    -----------
    data_file : str
        Path to pareto data file
    output_dir : str
        Directory to save plots
    configs_to_plot : list, optional
        List of configuration names to plot
    objective_names : list, optional
        Names for the objectives
    show_individual_runs : bool
        Whether to show individual runs' solutions faintly in the background
    figsize : tuple
        Figure size
    dpi : int
        Resolution of the PNG output
    maximize : bool
        Passed to ``compute_pareto_front`` when pooling runs: True treats the
        objectives as maximized; the default False treats them as minimized.
    """
    print("\n" + "="*80)
    print("PLOTTING AGGREGATED PARETO FRONTS")
    print("="*80)

    # Load data
    data = load_pareto_data(data_file)
    pareto_fronts = extract_pareto_fronts_from_data(data, run_index=None)

    if not pareto_fronts:
        print("No Pareto front data found!")
        return

    metadata = data.get('metadata', {})
    problem_type = metadata.get('problem_type', 'Unknown')
    problem_size = metadata.get('problem_size', 'Unknown')
    actual_size = metadata.get('actual_size', 'Unknown')
    # Prefer the recorded run count; when metadata lacks it, fall back to the
    # largest number of non-empty runs present for any configuration rather
    # than reporting 0.
    num_runs = metadata.get('num_runs') or max(
        (len(runs) for runs in pareto_fronts.values()), default=0
    )

    print(f"Problem: {problem_type}{actual_size} ({problem_size})")
    print(f"Number of runs: {num_runs}")

    os.makedirs(output_dir, exist_ok=True)

    # Filter configurations
    if configs_to_plot is not None:
        pareto_fronts = {k: v for k, v in pareto_fronts.items() if k in configs_to_plot}

    if not any(pareto_fronts.values()):
        print("No solutions to plot for the selected configurations!")
        return

    # Fallback colors for configurations missing from PASTEL_COLORS, taken from
    # the active style's color cycle.
    cycle_colors = plt.rcParams['axes.prop_cycle'].by_key().get('color', ['#888888'])

    fig, ax = plt.subplots(1, 1, figsize=figsize)
    try:
        for idx, (config_name, runs) in enumerate(pareto_fronts.items()):
            if not runs:
                continue

            # One explicit color per configuration so the background runs, the
            # pooled-front markers and the connecting line all match. Passing
            # c=None would let scatter and plot draw from separate cyclers.
            color = PASTEL_COLORS.get(config_name, cycle_colors[idx % len(cycle_colors)])
            marker = MARKER_STYLES.get(config_name, 'o')

            # Plot individual runs with transparency
            if show_individual_runs:
                for run_objectives in runs:
                    if len(run_objectives) > 0:
                        ax.scatter(
                            run_objectives[:, 0],
                            run_objectives[:, 1],
                            c=color,
                            marker=marker,
                            s=20,
                            alpha=0.1,
                            edgecolors='none'
                        )

            # Pool all solutions from all runs and keep the non-dominated ones
            all_objectives = np.vstack(runs)
            pareto_mask = compute_pareto_front(all_objectives, maximize=maximize)
            pareto_objectives = all_objectives[pareto_mask]

            # Sort by the first objective so the connecting line is monotone in x
            sorted_pareto = pareto_objectives[np.argsort(pareto_objectives[:, 0])]

            # Plot aggregated Pareto front
            ax.scatter(
                sorted_pareto[:, 0],
                sorted_pareto[:, 1],
                c=color,
                marker=marker,
                s=150,
                alpha=0.8,
                edgecolors='black',
                linewidth=2.0,
                label=config_name,
                zorder=10
            )

            # Connect Pareto front
            ax.plot(
                sorted_pareto[:, 0],
                sorted_pareto[:, 1],
                c=color,
                alpha=0.5,
                linewidth=3.0,
                linestyle='-',
                zorder=9
            )

        # Styling
        obj_names = objective_names or ['Objective 1', 'Objective 2']
        ax.set_xlabel(obj_names[0], fontsize=12, fontweight='bold')
        ax.set_ylabel(obj_names[1], fontsize=12, fontweight='bold')
        title_suffix = f' ({num_runs} runs aggregated)' if show_individual_runs else ''
        ax.set_title(
            f'Pareto Fronts Comparison: {problem_type}{actual_size}{title_suffix}',
            fontsize=13,
            fontweight='bold',
            pad=15
        )
        ax.legend(frameon=True, loc='best', fontsize=10)
        ax.grid(True, alpha=0.3, linestyle='--')
        ax.set_axisbelow(True)

        fig.tight_layout()

        plot_file = os.path.join(
            output_dir,
            f'pareto_front_aggregated_{problem_type}_{problem_size}.png'
        )
        fig.savefig(plot_file, dpi=dpi, bbox_inches='tight')
        # splitext only touches the extension, unlike str.replace, which
        # would also rewrite a '.png' appearing in a directory name.
        fig.savefig(os.path.splitext(plot_file)[0] + '.pdf', bbox_inches='tight')
        print(f"✓ Saved: {plot_file}")
    finally:
        # Release the figure even if plotting or saving fails part-way.
        plt.close(fig)


def plot_pareto_fronts_separate(
    data_file: str,
    output_dir: str = 'pareto_plots',
    configs_to_plot: Optional[List[str]] = None,
    objective_names: Optional[List[str]] = None,
    figsize: tuple = (8, 6),
    dpi: int = 300,
    maximize: bool = False
):
    """
    Plot separate Pareto fronts for each configuration

    One figure per configuration: every run's solutions are drawn as small
    translucent markers, and the Pareto front of all runs pooled together
    (recomputed with ``compute_pareto_front``) is drawn on top. Only the first
    two objective columns are plotted. Each figure is saved as PNG and PDF in
    ``output_dir``.

    Parameters:
    -----------
    data_file : str
        Path to pareto data file
    output_dir : str
        Directory to save plots
    configs_to_plot : list, optional
        List of configuration names to plot
    objective_names : list, optional
        Names for the objectives
    figsize : tuple
        Figure size
    dpi : int
        Resolution of the PNG output
    maximize : bool
        Passed to ``compute_pareto_front`` when pooling runs: True treats the
        objectives as maximized; the default False treats them as minimized.
    """
    print("\n" + "="*80)
    print("PLOTTING SEPARATE PARETO FRONTS")
    print("="*80)

    # Load data
    data = load_pareto_data(data_file)
    pareto_fronts = extract_pareto_fronts_from_data(data, run_index=None)

    if not pareto_fronts:
        print("No Pareto front data found!")
        return

    metadata = data.get('metadata', {})
    problem_type = metadata.get('problem_type', 'Unknown')
    problem_size = metadata.get('problem_size', 'Unknown')
    actual_size = metadata.get('actual_size', 'Unknown')

    print(f"Problem: {problem_type}{actual_size} ({problem_size})")

    os.makedirs(output_dir, exist_ok=True)

    # Filter configurations
    if configs_to_plot is not None:
        pareto_fronts = {k: v for k, v in pareto_fronts.items() if k in configs_to_plot}

    obj_names = objective_names or ['Objective 1', 'Objective 2']

    for config_name, runs in pareto_fronts.items():
        if len(runs) == 0:
            continue

        color = PASTEL_COLORS.get(config_name, '#888888')
        marker = MARKER_STYLES.get(config_name, 'o')

        fig, ax = plt.subplots(1, 1, figsize=figsize)
        try:
            # Plot individual runs
            for run_objectives in runs:
                if len(run_objectives) > 0:
                    ax.scatter(
                        run_objectives[:, 0],
                        run_objectives[:, 1],
                        c=color,
                        marker=marker,
                        s=30,
                        alpha=0.3,
                        edgecolors='none'
                    )

            # Pool all runs and plot the overall Pareto front
            all_objectives = np.vstack(runs)
            pareto_mask = compute_pareto_front(all_objectives, maximize=maximize)
            pareto_objectives = all_objectives[pareto_mask]
            sorted_pareto = pareto_objectives[np.argsort(pareto_objectives[:, 0])]

            ax.scatter(
                sorted_pareto[:, 0],
                sorted_pareto[:, 1],
                c=color,
                marker=marker,
                s=150,
                alpha=0.9,
                edgecolors='black',
                linewidth=2.0,
                label='Aggregated Pareto Front',
                zorder=10
            )

            ax.plot(
                sorted_pareto[:, 0],
                sorted_pareto[:, 1],
                c=color,
                alpha=0.6,
                linewidth=3.0,
                linestyle='-',
                zorder=9
            )

            # Styling. The run count in the title is the number of runs
            # actually drawn, so it is correct even when metadata lacks 'num_runs'.
            ax.set_xlabel(obj_names[0], fontsize=12, fontweight='bold')
            ax.set_ylabel(obj_names[1], fontsize=12, fontweight='bold')
            ax.set_title(
                f'{config_name}\n{problem_type}{actual_size} ({len(runs)} runs)',
                fontsize=13,
                fontweight='bold',
                pad=15
            )
            ax.legend(frameon=True, loc='best', fontsize=10)
            ax.grid(True, alpha=0.3, linestyle='--')
            ax.set_axisbelow(True)

            fig.tight_layout()

            # Replace every non-alphanumeric character (spaces, hyphens, and also
            # path separators or other unsafe characters) with '_' for the file name.
            safe_config_name = ''.join(ch if ch.isalnum() else '_' for ch in config_name)
            plot_file = os.path.join(
                output_dir,
                f'pareto_front_{safe_config_name}_{problem_type}_{problem_size}.png'
            )
            fig.savefig(plot_file, dpi=dpi, bbox_inches='tight')
            # splitext only touches the extension, unlike str.replace, which
            # would also rewrite a '.png' appearing in a directory name.
            fig.savefig(os.path.splitext(plot_file)[0] + '.pdf', bbox_inches='tight')
            print(f"✓ Saved: {plot_file}")
        finally:
            # Release each figure even if plotting or saving fails part-way.
            plt.close(fig)


def compute_pareto_front(objectives: np.ndarray, maximize: bool = False) -> np.ndarray:
    """
    Flag the non-dominated rows of an objective matrix.

    A row is dominated when some other row is at least as good in every
    objective and strictly better in at least one. Exact duplicates do not
    dominate each other, so every copy of a non-dominated point is kept.

    Parameters:
    -----------
    objectives : np.ndarray
        Array of shape (n_points, n_objectives)
    maximize : bool
        True -> larger values are better; False (default) -> smaller values
        are better.

    Returns:
    --------
    pareto_mask : np.ndarray
        Boolean array of length n_points; True marks rows on the Pareto front.
    """
    # Choose the "at least as good" / "strictly better" comparisons once,
    # according to the optimisation direction.
    if maximize:
        at_least_as_good, strictly_better = np.greater_equal, np.greater
    else:
        at_least_as_good, strictly_better = np.less_equal, np.less

    on_front = np.empty(objectives.shape[0], dtype=bool)
    for row_idx, row in enumerate(objectives):
        # Rows that beat-or-tie `row` everywhere and strictly beat it somewhere.
        dominators = (
            np.all(at_least_as_good(objectives, row), axis=1)
            & np.any(strictly_better(objectives, row), axis=1)
        )
        on_front[row_idx] = not dominators.any()
    return on_front


# ============================================================================
# CONVENIENCE FUNCTION: GENERATE ALL PARETO PLOTS
# ============================================================================

def generate_all_pareto_plots(
    data_file: str,
    output_dir: str = 'pareto_plots',
    configs_to_plot: Optional[List[str]] = None,
    objective_names: Optional[List[str]] = None,
    plot_first_n_runs: int = 3,
    dpi: int = 300
):
    """
    Generate all Pareto front plots for one data file

    Produces (1) a comparison plot for each of the first ``plot_first_n_runs``
    individual runs, (2) the pooled/aggregated comparison plot with the
    individual runs shown in the background, and (3) one plot per
    configuration.

    Parameters:
    -----------
    data_file : str
        Path to pareto data file
    output_dir : str
        Directory to save plots
    configs_to_plot : list, optional
        List of configuration names to plot
    objective_names : list, optional
        Names for the objectives
    plot_first_n_runs : int
        Maximum number of individual runs to plot. Capped at the number of
        runs stored for the selected configurations, so no empty figures are
        produced for run indices that do not exist.
    dpi : int
        Resolution
    """
    print("\n" + "="*80)
    print("GENERATING ALL PARETO FRONT PLOTS")
    print("="*80)

    os.makedirs(output_dir, exist_ok=True)

    # Count the runs actually stored (longest per-configuration run list among
    # the selected configurations) so the per-run loop below never asks for a
    # run index past the end of the data.
    data = load_pareto_data(data_file)
    raw_fronts = data.get('pareto_fronts', {}) if isinstance(data, dict) else {}
    if configs_to_plot is not None:
        raw_fronts = {k: v for k, v in raw_fronts.items() if k in configs_to_plot}
    available_runs = max((len(runs) for runs in raw_fronts.values()), default=0)
    n_single_runs = min(plot_first_n_runs, available_runs)
    if n_single_runs < plot_first_n_runs:
        print(f"Note: only {available_runs} run(s) available; "
              f"plotting {n_single_runs} individual run(s).")

    # 1. Plot first N individual runs
    for run_idx in range(n_single_runs):
        plot_single_run_pareto_fronts(
            data_file=data_file,
            run_index=run_idx,
            output_dir=output_dir,
            configs_to_plot=configs_to_plot,
            objective_names=objective_names,
            dpi=dpi
        )

    # 2. Plot aggregated Pareto fronts
    plot_aggregated_pareto_fronts(
        data_file=data_file,
        output_dir=output_dir,
        configs_to_plot=configs_to_plot,
        objective_names=objective_names,
        show_individual_runs=True,
        dpi=dpi
    )

    # 3. Plot separate Pareto fronts for each configuration
    plot_pareto_fronts_separate(
        data_file=data_file,
        output_dir=output_dir,
        configs_to_plot=configs_to_plot,
        objective_names=objective_names,
        dpi=dpi
    )

    print("\n" + "="*80)
    print("✓ ALL PARETO FRONT PLOTS COMPLETED!")
    print("="*80)
    print(f"Plots saved to: {output_dir}/")


# ============================================================================
# USAGE EXAMPLES
# ============================================================================

if __name__ == "__main__":
    import sys

    # Example data file - UPDATE THIS PATH TO YOUR ACTUAL FILE
    data_file = 'ablation_results_experts/pareto_fronts/pareto_data_BiTSP_large_20251114-051437.json'

    # Check if file exists
    if not os.path.exists(data_file):
        print(f"ERROR: Data file not found: {data_file}")
        print("\nPlease update the 'data_file' path in this script to point to your pareto data file.")
        # sys.exit is always available; the bare exit() builtin is injected by
        # the `site` module and may be missing (e.g. `python -S`, frozen apps).
        sys.exit(1)

    # The plotting functions need per-run objective vectors stored under the
    # 'pareto_fronts' key. Check once up front so the note below is shown only
    # when that data is actually missing, instead of after every run.
    if 'pareto_fronts' in load_pareto_data(data_file):
        # Example 1: Generate all plots for all configurations
        print("\n" + "="*80)
        print("EXAMPLE 1: All configurations, all plot types")
        print("="*80)

        generate_all_pareto_plots(
            data_file=data_file,
            output_dir='pareto_plots_all',
            objective_names=['Objective 1', 'Objective 2'],  # Update with your objective names
            plot_first_n_runs=3,
            dpi=300
        )

        # Example 2: Only Thompson configurations
        # generate_all_pareto_plots(
        #     data_file=data_file,
        #     output_dir='pareto_plots_thompson',
        #     configs_to_plot=[
        #         'Thompson-Hedge With FTRL',
        #         'Thompson-Hedge Without FTRL'
        #     ],
        #     objective_names=['Objective 1', 'Objective 2'],  # Update as needed
        #     plot_first_n_runs=5,
        #     dpi=300
        # )
    else:
        print("\n" + "="*80)
        print("NOTE: Current data file does not contain actual Pareto fronts!")
        print("You need to modify the ablation study to save the 'objectives' field.")
        print("="*80)
