"""
Visualization Module

Publication-ready plotting utilities for SP-UCB-OLP experiments.

Key Plots:
1. Competitive ratio comparison (bar chart)
2. Regret vs time curves
3. Budget consumption trajectories
4. Mixture weight evolution
5. Oracle gap visualization
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from typing import Dict, List, Tuple, Optional, Any
from pathlib import Path

from .storage import ExperimentResults, RunTrajectory


# Publication-quality defaults. These are written into matplotlib's global
# rcParams when this module is imported, so they affect every later figure in
# the same process.
_PUBLICATION_RC_PARAMS = {
    # Typography
    'font.size': 12,
    'axes.labelsize': 14,
    'axes.titlesize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 11,
    # Figure geometry and output resolution
    'figure.figsize': (8, 5),
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    # Line and marker sizing
    'lines.linewidth': 2,
    'lines.markersize': 8,
}
plt.rcParams.update(_PUBLICATION_RC_PARAMS)


# Color palette (colorblind-friendly). Each hue is registered under a short
# display name and an alternate spelling (apparently the policy class name),
# so lookups work with either naming scheme.
COLORS = {
    name: hex_code
    for names, hex_code in (
        (('SP-UCB-OLP', 'SPUCBOLP'), '#0072B2'),         # Blue
        (('Greedy', 'SPGreedyOLP'), '#E69F00'),          # Orange
        (('OneHot', 'OneHotSPUCBOLP'), '#009E73'),       # Green
        (('Oracle', 'OraclePolicy'), '#CC79A7'),         # Pink
        (('Random', 'RandomPolicy'), '#999999'),         # Gray
        (('Fixed', 'FixedConfigPolicy'), '#D55E00'),     # Red-orange
    )
    for name in names
}

# Line styles. Keyed by the same display names *and* alternate spellings as
# COLORS, so an algorithm gets a consistent color and dash pattern whichever
# naming scheme the caller uses (previously only the display names had a
# style, and every alias silently fell back to a solid line).
LINESTYLES = {
    'SP-UCB-OLP': '-',
    'SPUCBOLP': '-',
    'Greedy': '--',
    'SPGreedyOLP': '--',
    'OneHot': '-.',
    'OneHotSPUCBOLP': '-.',
    'Oracle': ':',
    'OraclePolicy': ':',
    'Random': ':',
    'RandomPolicy': ':',
    'Fixed': '--',
    'FixedConfigPolicy': '--',
}


def get_color(algorithm: str) -> str:
    """Return the hex color registered in COLORS for *algorithm*; black if unknown."""
    try:
        return COLORS[algorithm]
    except KeyError:
        return '#000000'


def get_linestyle(algorithm: str) -> str:
    """Return the matplotlib line style registered in LINESTYLES for *algorithm*; solid if unknown."""
    if algorithm in LINESTYLES:
        return LINESTYLES[algorithm]
    return '-'


def plot_competitive_ratio_comparison(
    results: ExperimentResults,
    ax: plt.Axes = None,
    show_error_bars: bool = True,
    title: str = None
) -> plt.Axes:
    """
    Plot competitive ratio comparison as bar chart.

    Draws one bar per entry of ``results.algorithms`` at the mean of
    ``results.competitive_ratios[alg]``. It can add standard-error bars. It also
    draws horizontal reference lines at 1.0 (V^mix) and, when V^mix > 0, at
    V*/V^mix.

    Parameters
    ----------
    results : ExperimentResults
        Experiment results
    ax : plt.Axes, optional
        Axes to plot on; a new 10x6 figure is created when omitted
    show_error_bars : bool
        Show standard error bars (population std / sqrt(n_runs))
    title : str, optional
        Plot title; defaults to one built from ``results.T`` and
        ``results.budget_factor``

    Returns
    -------
    ax : plt.Axes
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    algorithms = results.algorithms
    x = np.arange(len(algorithms))
    width = 0.6

    means = []
    sems = []
    colors = []
    for alg in algorithms:
        ratios = np.asarray(results.competitive_ratios[alg], dtype=float)
        if ratios.size == 0:
            # No runs recorded: a NaN height draws no bar and avoids the
            # "mean of empty slice" / divide-by-zero warnings.
            means.append(np.nan)
            sems.append(np.nan)
        else:
            means.append(float(np.mean(ratios)))
            sems.append(float(np.std(ratios) / np.sqrt(ratios.size)))  # Standard error
        colors.append(get_color(alg))

    ax.bar(x, means, width, color=colors, edgecolor='black', linewidth=1)

    if show_error_bars:
        ax.errorbar(x, means, yerr=sems, fmt='none', color='black', capsize=5)

    # Reference lines: ratio 1.0 is V^mix by construction; V* is shown relative
    # to V^mix and is only meaningful when V^mix is positive.
    ax.axhline(y=1.0, color='green', linestyle='--', linewidth=1.5, label='V^mix')
    v_star_ratio = None
    if results.V_mix > 0:
        v_star_ratio = results.V_star / results.V_mix
        ax.axhline(y=v_star_ratio, color='red', linestyle=':', linewidth=1.5, label='V*')

    ax.set_ylabel('Competitive Ratio (vs V^mix)')
    ax.set_xlabel('Algorithm')
    ax.set_xticks(x)
    ax.set_xticklabels(algorithms, rotation=45, ha='right')
    ax.legend(loc='upper right')

    if title:
        ax.set_title(title)
    else:
        ax.set_title(f'Competitive Ratio Comparison (T={results.T}, ρ={results.budget_factor})')

    # Start from the conventional [0, 1.1] window but widen it so bars, error
    # bars and the V* line are never clipped (finite-sample ratios can exceed 1).
    means_arr = np.asarray(means, dtype=float)
    if show_error_bars:
        err_arr = np.nan_to_num(np.asarray(sems, dtype=float))
    else:
        err_arr = np.zeros_like(means_arr)
    finite = np.isfinite(means_arr)
    y_lo, y_hi = 0.0, 1.1
    if finite.any():
        y_hi = max(y_hi, 1.05 * float(np.max(means_arr[finite] + err_arr[finite])))
        y_lo = min(y_lo, 1.05 * float(np.min(means_arr[finite] - err_arr[finite])))
    if v_star_ratio is not None and np.isfinite(v_star_ratio):
        y_hi = max(y_hi, 1.05 * float(v_star_ratio))
        y_lo = min(y_lo, 1.05 * float(v_star_ratio))
    ax.set_ylim(y_lo, y_hi)
    ax.grid(axis='y', alpha=0.3)

    return ax


def plot_regret_vs_time(
    trajectories: Dict[str, List[RunTrajectory]],
    V_mix: float,
    ax: plt.Axes = None,
    show_ci: bool = True,
    title: str = None,
    ref_scale: Optional[float] = 10.0,
) -> plt.Axes:
    """
    Plot regret vs time curves.

    For each algorithm, the regret of a run at step t is
    ``t * V_mix - cumulative_reward``. The plot shows this averaged over runs,
    optionally with a 95% normal-approximation confidence band.

    Parameters
    ----------
    trajectories : Dict[str, List[RunTrajectory]]
        Trajectories per algorithm; algorithms with no runs are skipped.
        Assumes each run's ``cumulative_reward`` has length ``T`` — TODO confirm.
    V_mix : float
        Oracle value per period
    ax : plt.Axes, optional
        Axes to plot on
    show_ci : bool
        Show 95% confidence interval (mean ± 1.96 standard errors)
    title : str, optional
        Plot title
    ref_scale : float or None
        Multiplier c of the dashed ``c * sqrt(t)`` reference curve (default
        10, the previously hard-coded value). Pass None or a non-positive
        value to omit the reference curve.

    Returns
    -------
    ax : plt.Axes
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    # Longest horizon among plotted algorithms. It sizes the sqrt(T) reference
    # curve and stays 0 when nothing was plotted, instead of the reference code
    # reading an unbound `T`.
    horizon = 0

    for alg, trajs in trajectories.items():
        if len(trajs) == 0:
            continue

        T = trajs[0].T
        horizon = max(horizon, T)
        t_values = np.arange(1, T + 1)

        # Regret = t * V^mix - cumulative reward, one row per run
        regrets = np.array([t_values * V_mix - traj.cumulative_reward for traj in trajs])
        mean_regret = np.mean(regrets, axis=0)

        color = get_color(alg)
        linestyle = get_linestyle(alg)

        ax.plot(t_values, mean_regret, color=color, linestyle=linestyle, label=alg)

        if show_ci and len(trajs) > 1:
            se_regret = np.std(regrets, axis=0) / np.sqrt(len(trajs))
            ax.fill_between(
                t_values,
                mean_regret - 1.96 * se_regret,
                mean_regret + 1.96 * se_regret,
                color=color, alpha=0.2
            )

    # sqrt(T) reference line
    if horizon > 0 and ref_scale is not None and ref_scale > 0:
        t_ref = np.arange(1, horizon + 1)
        ax.plot(t_ref, np.sqrt(t_ref) * ref_scale, 'k--', alpha=0.5, label='O(√T)')

    ax.set_xlabel('Time t')
    ax.set_ylabel('Regret vs V^mix')
    ax.legend(loc='upper left')
    ax.grid(alpha=0.3)

    if title:
        ax.set_title(title)
    else:
        ax.set_title('Regret vs Time')

    return ax


def plot_budget_utilization(
    trajectories: Dict[str, List[RunTrajectory]],
    B: np.ndarray,
    ax: plt.Axes = None,
    resource_idx: int = 0,
    title: str = None
) -> plt.Axes:
    """
    Plot budget utilization over time for a specific resource.

    Utilization at step t is ``1 - mean_remaining[t] / B[resource_idx]``. Here
    ``mean_remaining`` averages column ``resource_idx`` of each run's 2-D
    ``budget_remaining`` array. A dashed line shows the linear pace ``t / T``.

    Parameters
    ----------
    trajectories : Dict[str, List[RunTrajectory]]
        Trajectories per algorithm; algorithms with no runs are skipped
    B : np.ndarray
        Total budget vector
    ax : plt.Axes, optional
        Axes to plot on
    resource_idx : int
        Which resource to plot
    title : str, optional
        Plot title

    Returns
    -------
    ax : plt.Axes
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    total_budget = B[resource_idx]

    # Longest horizon among plotted algorithms. It sizes the ideal-pace line
    # and stays 0 when nothing was plotted, instead of the ideal-line code
    # reading unbound `t_values`/`T`.
    horizon = 0

    for alg, trajs in trajectories.items():
        if len(trajs) == 0:
            continue

        T = trajs[0].T
        horizon = max(horizon, T)
        t_values = np.arange(1, T + 1)

        # Remaining budget of this resource, one row per run
        budget_remaining = np.array([traj.budget_remaining[:, resource_idx] for traj in trajs])
        mean_remaining = np.mean(budget_remaining, axis=0)

        # Plot as utilization (1 - remaining/B)
        utilization = 1 - mean_remaining / total_budget
        ax.plot(t_values, utilization, color=get_color(alg),
                linestyle=get_linestyle(alg), label=alg)

    # Add ideal utilization line (linear)
    if horizon > 0:
        t_ideal = np.arange(1, horizon + 1)
        ax.plot(t_ideal, t_ideal / horizon, 'k--', alpha=0.5, label='Ideal (linear)')

    ax.set_xlabel('Time t')
    ax.set_ylabel(f'Budget Utilization (Resource {resource_idx})')
    ax.legend(loc='upper left')
    ax.grid(alpha=0.3)
    ax.set_ylim(0, 1.1)

    if title:
        ax.set_title(title)
    else:
        ax.set_title(f'Budget Utilization Over Time (Resource {resource_idx})')

    return ax


def plot_mixture_evolution(
    trajectory: RunTrajectory,
    ax: plt.Axes = None,
    title: str = None
) -> plt.Axes:
    """
    Draw the per-configuration mixture weights of one run against time.

    Column ``k`` of ``trajectory.w_t`` becomes the curve labelled
    "Config k". When ``T // 100`` exceeds 1 (i.e. T >= 200), each curve is
    smoothed by a trailing moving average of that many steps. Each smoothed
    value is plotted at the last time index of its window.

    Parameters
    ----------
    trajectory : RunTrajectory
        The run whose weights are drawn (reads ``T``, ``K`` and ``w_t``)
    ax : plt.Axes, optional
        Target axes; a new 12x5 figure is created when omitted
    title : str, optional
        Custom title; 'Mixture Weight Evolution' is used otherwise

    Returns
    -------
    ax : plt.Axes
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(12, 5))

    horizon = trajectory.T
    n_configs = trajectory.K
    steps = np.arange(horizon)

    # Smoothing span for readability on long horizons
    span = max(1, horizon // 100)
    kernel = np.ones(span) / span if span > 1 else None

    for k in range(n_configs):
        column = trajectory.w_t[:, k]
        if kernel is None:
            xs, ys = steps, column
        else:
            xs = steps[span - 1:]
            ys = np.convolve(column, kernel, mode='valid')
        ax.plot(xs, ys, label=f'Config {k}', linewidth=1.5)

    ax.set_xlabel('Time t')
    ax.set_ylabel('Mixture Weight')
    ax.legend(loc='upper right', ncol=min(n_configs, 4))
    ax.grid(alpha=0.3)
    ax.set_ylim(-0.05, 1.05)
    ax.set_title(title if title else 'Mixture Weight Evolution')

    return ax


def plot_gap_vs_rho(
    results_by_rho: Dict[float, ExperimentResults],
    ax: plt.Axes = None,
    title: str = None
) -> plt.Axes:
    """
    Draw the oracle gap ``V_mix - V_star`` against the budget factor rho.

    Points are ordered by increasing rho and joined with a line in the
    SP-UCB-OLP color.

    Parameters
    ----------
    results_by_rho : Dict[float, ExperimentResults]
        One result set per budget factor
    ax : plt.Axes, optional
        Target axes; a new 8x5 figure is created when omitted
    title : str, optional
        Custom title; 'Oracle Gap vs Budget Tightness' is used otherwise

    Returns
    -------
    ax : plt.Axes
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 5))

    ordered_rhos = sorted(results_by_rho)
    gap_values = []
    for rho in ordered_rhos:
        res = results_by_rho[rho]
        gap_values.append(res.V_mix - res.V_star)

    ax.plot(ordered_rhos, gap_values, 'o-', color=COLORS['SP-UCB-OLP'], linewidth=2, markersize=8)

    ax.set_xlabel('Budget Factor ρ')
    ax.set_ylabel('Gap (V^mix - V*)')
    ax.grid(alpha=0.3)
    ax.set_title(title if title else 'Oracle Gap vs Budget Tightness')

    return ax


def create_summary_figure(
    results: ExperimentResults,
    trajectories: Dict[str, List[RunTrajectory]] = None,
    save_path: str = None
) -> plt.Figure:
    """
    Create a summary figure with multiple subplots.

    With trajectory data the figure is a 2x2 grid with these panels:
    - competitive ratios;
    - regret vs time;
    - budget utilization of resource 0;
    - mixture weights of the main algorithm's first run.

    The main algorithm is SP-UCB-OLP (or its alias SPUCBOLP) when present, and
    otherwise the first key. Without trajectory data, only the
    competitive-ratio panel is drawn.

    Parameters
    ----------
    results : ExperimentResults
        Experiment results
    trajectories : Dict[str, List[RunTrajectory]], optional
        Full trajectories for detailed plots. None or an empty mapping gives
        the single-panel figure.
    save_path : str, optional
        Path to save figure (300 dpi, tight bounding box)

    Returns
    -------
    fig : plt.Figure
    """
    # An empty mapping carries no trajectory data; treat it like None rather
    # than raising IndexError when picking the first algorithm below.
    if trajectories:
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        # Competitive ratio comparison
        plot_competitive_ratio_comparison(results, ax=axes[0, 0])

        # Regret vs time
        plot_regret_vs_time(trajectories, results.V_mix, ax=axes[0, 1])

        # Budget utilization. Recover the total budget from the first run of
        # *any* algorithm (not just the first key, which may have no runs).
        # NOTE(review): assumes budget_remaining[0] + cumulative_consumption[0]
        # equals the initial budget vector — confirm against RunTrajectory.
        first_run = next((trajs[0] for trajs in trajectories.values() if len(trajs) > 0), None)
        if first_run is not None:
            B = first_run.budget_remaining[0] + first_run.cumulative_consumption[0]
        else:
            B = np.array([results.T])  # Placeholder: no run data available
        plot_budget_utilization(trajectories, B, ax=axes[1, 0])

        # Mixture evolution (first trajectory of main algorithm); accept the
        # class-name alias used elsewhere in this module's color table.
        main_alg = next(
            (name for name in ('SP-UCB-OLP', 'SPUCBOLP') if name in trajectories),
            next(iter(trajectories)),
        )
        if len(trajectories[main_alg]) > 0:
            plot_mixture_evolution(trajectories[main_alg][0], ax=axes[1, 1])
        else:
            axes[1, 1].axis('off')  # Hide the otherwise-empty panel
    else:
        fig, ax = plt.subplots(figsize=(10, 6))
        plot_competitive_ratio_comparison(results, ax=ax)

    fig.suptitle(
        f'{results.family}: T={results.T}, K={results.K}, d={results.d}, ρ={results.budget_factor}',
        fontsize=14, fontweight='bold'
    )
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig


def save_figure(fig: "plt.Figure", path: str, formats: Optional[List[str]] = None):
    """
    Save ``fig`` once per requested format, all sharing one file stem.

    Parameters
    ----------
    fig : plt.Figure
        Figure to save (anything exposing a matplotlib-style ``savefig``).
    path : str
        Output path stem. A trailing figure extension such as ``.png`` is
        replaced by each format. Any other dotted part is kept, so
        ``"regret_rho0.5"`` produces ``regret_rho0.5.pdf``. The previous
        ``with_suffix`` approach truncated it to ``regret_rho0.pdf``.
    formats : list of str, optional
        Formats/extensions to write, with or without a leading dot.
        Defaults to ``['pdf', 'png']``.

    Notes
    -----
    Missing parent directories are created. Every file is written with
    ``dpi=300`` and ``bbox_inches='tight'``.
    """
    # None default instead of a shared mutable list literal.
    if formats is None:
        formats = ['pdf', 'png']

    # Suffixes treated as "already a figure format" and therefore replaced.
    figure_suffixes = {
        '.pdf', '.png', '.svg', '.svgz', '.eps', '.ps', '.pgf',
        '.jpg', '.jpeg', '.tif', '.tiff', '.webp',
    }
    target = Path(path)
    if target.suffix.lower() in figure_suffixes:
        target = target.with_suffix('')

    target.parent.mkdir(parents=True, exist_ok=True)
    for fmt in formats:
        ext = fmt.lstrip('.')
        fig.savefig(target.with_name(f'{target.name}.{ext}'), dpi=300, bbox_inches='tight')


if __name__ == "__main__":
    # Running this file directly only confirms that the module imports; it
    # draws nothing.
    for _message in (
        "Visualization module loaded successfully.",
        "Use create_summary_figure() to generate publication-ready figures.",
    ):
        print(_message)
