import argparse
import os
from ast import literal_eval
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rich

# Publication-quality defaults: seaborn "paper" theme, serif fonts, a faint
# background grid and explicit font sizes for labels, titles, ticks and legends.
plt.style.use("seaborn-v0_8-paper")
plt.rcParams.update(
    {
        "font.family": "serif",
        "axes.grid": True,
        "grid.alpha": 0.3,
        "axes.labelsize": 14,
        "axes.titlesize": 16,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 12,
        "legend.title_fontsize": 14,
    }
)

# Marker and line-style palettes used to tell the plotted strategies apart.
MARKERS = ["o", "s", "D", "^", "v", "<", ">", "p", "*"]
LINE_STYLES = ["-", "--", "-.", ":"]


def load_data(base_dir: str) -> Dict[int, Dict[int, Dict]]:
    """
    Load validation error data from all trainable layer directories.

    Every sub-directory of ``base_dir`` whose name contains ``_trainable`` is
    inspected. The integer before the first underscore is the number of
    trainable layers, and the results are read from
    ``<dir>/summary/full.csv``. The columns read are ``id``, ``cost``,
    ``extra.n_trainable_params`` and ``extra.validation_errors`` (falling back
    to ``extra.validation_losses`` when the former is absent).

    Directories that cannot be loaded are reported on stdout and skipped, so a
    single malformed directory never aborts the whole load. Rows parsed before
    a failure are kept.

    Args:
        base_dir: Directory containing the ``<N>_trainable`` sub-directories.

    Returns a nested dictionary:
    {
        n_trainable_layers: {
            config_id: {
                'val_errors': List[float],
                'epoch_time': float,
                'n_params': int
            }
        }
    }
    """
    all_data: Dict[int, Dict[int, Dict]] = {}

    for item in os.listdir(base_dir):
        # Bound before the try block so the error handler below can always
        # report the directory that failed (e.g. when the name prefix is not
        # an integer).
        trainable_dir = os.path.join(base_dir, item)
        if not (os.path.isdir(trainable_dir) and "_trainable" in item):
            continue

        try:
            # Extract the number of trainable layers from the directory name
            n_trainable = int(item.split("_")[0])
            csv_path = os.path.join(trainable_dir, "summary", "full.csv")

            if not os.path.exists(csv_path):
                print(f"File not found: {csv_path}")
                continue

            frame = pd.read_csv(csv_path)

            # The error column is a property of the file, not of a row, so it
            # is resolved once per CSV.
            error_col = (
                "extra.validation_errors"
                if "extra.validation_errors" in frame.columns
                else "extra.validation_losses"
            )

            layer_entries: Dict[int, Dict] = {}
            all_data[n_trainable] = layer_entries

            for _, row in frame.iterrows():
                config_id = int(row["id"])

                # The per-epoch errors are stored as the repr of a Python list
                val_errors = literal_eval(row[error_col])
                n_epochs = len(val_errors)
                if n_epochs == 0:
                    raise ValueError(f"config {config_id} has no validation errors")

                layer_entries[config_id] = {
                    "val_errors": val_errors,
                    # Approximate time per epoch: total cost spread evenly
                    "epoch_time": float(row["cost"]) / n_epochs,
                    "n_params": int(row["extra.n_trainable_params"]),
                }

        except Exception as e:
            print(f"Error loading data from {trainable_dir}: {e}")
            continue

    return all_data


def simulate_sh_layers(
    data: Dict[int, Dict[int, Dict]],
    eta: int = 3,
    min_layers: int = 1,
    max_epochs: int = 20,
    fixed_epochs: Optional[int] = None,
) -> Tuple[List[float], List[float]]:
    """
    Simulate Successive Halving with trainable layers as the fidelity.

    Each round evaluates the surviving configurations at the current layer
    count, keeps the best ``1/eta`` of them and moves on to the next available
    layer count (staying on the last one once the list is exhausted).

    The reported error after a round is the final-epoch error of that round's
    top-ranked configuration at the *highest* available layer count, combined
    with the running minimum.

    Args:
        data: Nested dictionary with validation error data
        eta: Reduction factor (default: 3); must be at least 2
        min_layers: Minimum number of trainable layers to start with; when it
            is not an available layer count the smallest one is used
        max_epochs: Maximum number of epochs to train
        fixed_epochs: If provided, use this fixed number of epochs for all evaluations

    Returns:
        Tuple of (wall_clock_times, best_validation_errors). Both lists start
        with an initial point (0.0 and 1.0) followed by one entry per round.

    Raises:
        ValueError: If ``data`` is empty, contains no configurations, or
            ``eta`` is smaller than 2.
    """
    if not data:
        raise ValueError("No data provided")
    if eta < 2:
        raise ValueError(f"eta must be at least 2, got {eta}")

    # Get sorted list of available layer counts
    available_layers = sorted(data)

    # The highest layer count is the reference fidelity. Taking it from the
    # sorted list avoids depending on the insertion order of ``data``.
    reference_layer = available_layers[-1]

    # Get all config IDs (assuming they're consistent across layer counts)
    config_ids = list(data[available_layers[0]])
    if not config_ids:
        raise ValueError("No configurations available")

    # Start with all configs at the minimum layer count
    starting_layer_idx = (
        available_layers.index(min_layers) if min_layers in available_layers else 0
    )
    current_layer = available_layers[starting_layer_idx]

    remaining_configs = list(config_ids)
    wall_clock_time = 0.0
    wall_clock_times = [0.0]
    best_val_errors = [1.0]  # Start with worst possible error (assumes errors lie in [0, 1])

    # Largest k with eta**k <= number of configs. Integer arithmetic is used
    # because floor(log(n) / log(eta)) in floating point can undershoot exact
    # powers (e.g. n=243, eta=3).
    n_brackets = 0
    while eta ** (n_brackets + 1) <= len(config_ids):
        n_brackets += 1

    for i in range(n_brackets + 1):
        # Number of epochs to train in this round
        if fixed_epochs is not None:
            n_epochs = fixed_epochs
        else:
            n_epochs = min(max_epochs, int((max_epochs * eta**i) // eta**n_brackets))
        # A budget that rounds down to zero still trains for one epoch;
        # otherwise the lookup below would silently read the *last* epoch.
        n_epochs = max(1, n_epochs)

        # Evaluate all remaining configs
        results = []
        for config_id in remaining_configs:
            entry = data[current_layer][config_id]
            val_errors = entry["val_errors"]

            # Ensure we don't exceed available epochs
            actual_epochs = min(n_epochs, len(val_errors))
            results.append((config_id, val_errors[actual_epochs - 1]))

            # Track time spent
            wall_clock_time += actual_epochs * entry["epoch_time"]

        # Rank by validation error, best first
        results.sort(key=lambda pair: pair[1])

        # Score the round's winner at the reference fidelity
        winner_id = results[0][0]
        reference_error = data[reference_layer][winner_id]["val_errors"][-1]
        wall_clock_times.append(wall_clock_time)
        best_val_errors.append(min(best_val_errors[-1], reference_error))

        # Keep top 1/eta configurations
        n_keep = max(1, len(remaining_configs) // eta)
        remaining_configs = [config_id for config_id, _ in results[:n_keep]]

        # Increase trainable layers for next round if possible
        if i < n_brackets:
            layer_idx = starting_layer_idx + i + 1
            if layer_idx < len(available_layers):
                current_layer = available_layers[layer_idx]

    return wall_clock_times, best_val_errors


def simulate_sh_epochs(
    data: Dict[int, Dict[int, Dict]], n_layers: int, eta: int = 3, max_epochs: int = 20
) -> Tuple[List[float], List[float]]:
    """
    Simulate Successive Halving with epochs as the fidelity.
    This implementation assumes training from scratch for each configuration (no continuation).

    Each round trains every surviving configuration for that round's epoch
    budget, ranks them by the error reached, and keeps the best ``1/eta``.
    The reported error after a round is the final-epoch error of the round's
    top-ranked configuration, combined with the running minimum.

    Args:
        data: Nested dictionary with validation error data
        n_layers: Fixed number of trainable layers
        eta: Reduction factor (default: 3); must be at least 2
        max_epochs: Maximum number of epochs to train

    Returns:
        Tuple of (wall_clock_times, best_validation_errors). Both lists start
        with an initial point (0.0 and 1.0) followed by one entry per round.

    Raises:
        ValueError: If there is no data for ``n_layers``, the layer has no
            configurations, or ``eta`` is smaller than 2.
    """
    if not data or n_layers not in data:
        raise ValueError(f"No data available for {n_layers} trainable layers")
    if eta < 2:
        raise ValueError(f"eta must be at least 2, got {eta}")

    # Get configs for the specified layer count
    layer_data = data[n_layers]
    remaining_configs = list(layer_data)
    if not remaining_configs:
        raise ValueError(f"No configurations available for {n_layers} trainable layers")

    wall_clock_time = 0.0
    wall_clock_times = [0.0]
    best_val_errors = [1.0]  # Start with worst possible error (assumes errors lie in [0, 1])

    # Largest k with eta**k <= number of configs. Integer arithmetic is used
    # because floor(log(n) / log(eta)) in floating point can undershoot exact
    # powers (e.g. n=243, eta=3).
    n_configs = len(remaining_configs)
    n_brackets = 0
    while eta ** (n_brackets + 1) <= n_configs:
        n_brackets += 1

    for i in range(n_brackets + 1):
        # Epoch budget for this round; a budget that rounds down to zero is
        # treated as a single epoch.
        n_epochs = min(max_epochs, int((max_epochs * eta**i) // eta**n_brackets))
        n_epochs = max(1, n_epochs)

        # Evaluate all remaining configs
        results = []
        for config_id in remaining_configs:
            entry = layer_data[config_id]
            val_errors = entry["val_errors"]

            # Use the last logged epoch when the budget exceeds the log
            actual_epochs = min(n_epochs, len(val_errors))
            results.append((config_id, val_errors[actual_epochs - 1]))

            # Track time spent (always training from scratch)
            wall_clock_time += actual_epochs * entry["epoch_time"]

        # Rank by validation error, best first
        results.sort(key=lambda pair: pair[1])

        # Score the round's winner by its fully trained error
        winner_id = results[0][0]
        final_error = layer_data[winner_id]["val_errors"][-1]
        wall_clock_times.append(wall_clock_time)
        best_val_errors.append(min(best_val_errors[-1], final_error))

        # Keep top 1/eta configurations
        n_keep = max(1, len(remaining_configs) // eta)
        remaining_configs = [config_id for config_id, _ in results[:n_keep]]

    return wall_clock_times, best_val_errors


def simulate_sh_combined(
    data: Dict[int, Dict[int, Dict]],
    eta_layers: int = 3,
    eta_epochs: int = 3,
    min_layers: int = 1,
    max_epochs: int = 20,
) -> Tuple[List[float], List[float]]:
    """
    Simulate Successive Halving with both trainable layers and epochs as fidelities.

    Strategy: Number of brackets is determined solely by epochs, while layer progression
    is scaled accordingly within these brackets.

    The reported error after a round is the final-epoch error of that round's
    top-ranked configuration at the *highest* available layer count, combined
    with the running minimum.

    Args:
        data: Nested dictionary with validation error data
        eta_layers: Reduction factor for layers (default: 3). Currently unused:
            the layer schedule is derived from the number of rounds instead.
        eta_epochs: Reduction factor for epochs (default: 3); must be at least 2
        min_layers: Minimum number of trainable layers to start with; when it
            is not an available layer count the smallest one is used
        max_epochs: Maximum number of epochs to train

    Returns:
        Tuple of (wall_clock_times, best_validation_errors). Both lists start
        with an initial point (0.0 and 1.0) followed by one entry per round.

    Raises:
        ValueError: If ``data`` is empty, contains no configurations, or
            ``eta_epochs`` is smaller than 2.
    """
    if not data:
        raise ValueError("No data provided")
    if eta_epochs < 2:
        raise ValueError(f"eta_epochs must be at least 2, got {eta_epochs}")

    # Get sorted list of available layer counts
    available_layers = sorted(data)

    # The highest layer count is the reference fidelity. Taking it from the
    # sorted list avoids depending on the insertion order of ``data``.
    reference_layer = available_layers[-1]

    # Get all config IDs (assuming they're consistent across layer counts)
    remaining_configs = list(data[available_layers[0]])
    if not remaining_configs:
        raise ValueError("No configurations available")

    # Start with minimum layer count
    starting_layer_idx = (
        available_layers.index(min_layers) if min_layers in available_layers else 0
    )
    # Number of layer steps available above the starting layer
    max_layer_progress = len(available_layers) - starting_layer_idx - 1

    wall_clock_time = 0.0
    wall_clock_times = [0.0]
    best_val_errors = [1.0]  # Start with worst possible error (assumes errors lie in [0, 1])

    # Largest k with eta_epochs**k <= number of configs. Integer arithmetic is
    # used because floor(log(n) / log(eta)) in floating point can undershoot
    # exact powers (e.g. n=243, eta=3).
    n_configs = len(remaining_configs)
    n_brackets = 0
    while eta_epochs ** (n_brackets + 1) <= n_configs:
        n_brackets += 1

    for i in range(n_brackets + 1):
        # Epoch budget follows the standard SH schedule; a budget that rounds
        # down to zero is treated as a single epoch.
        n_epochs = min(
            max_epochs, int((max_epochs * eta_epochs**i) // eta_epochs**n_brackets)
        )
        n_epochs = max(1, n_epochs)

        # Spread the layer progression evenly over the rounds so the last
        # round runs at the highest layer count (integer floor, no float
        # rounding).
        layer_progress = (max_layer_progress * i) // n_brackets if n_brackets > 0 else 0
        current_layer = available_layers[starting_layer_idx + layer_progress]

        # Evaluate all remaining configs
        results = []
        for config_id in remaining_configs:
            # Skip if config doesn't exist for this layer
            if config_id not in data[current_layer]:
                rich.print(f"Config {config_id} not found for layer {current_layer}")
                continue

            entry = data[current_layer][config_id]
            val_errors = entry["val_errors"]

            # Ensure we don't exceed available epochs
            actual_epochs = min(n_epochs, len(val_errors))
            results.append((config_id, val_errors[actual_epochs - 1]))

            # Track time spent
            wall_clock_time += actual_epochs * entry["epoch_time"]

        if not results:
            # No valid results in this round, use previous best error
            wall_clock_times.append(wall_clock_time)
            best_val_errors.append(best_val_errors[-1])
            break

        # Rank by validation error, best first
        results.sort(key=lambda pair: pair[1])

        # Score the round's winner at the reference fidelity
        winner_id = results[0][0]
        reference_error = data[reference_layer][winner_id]["val_errors"][-1]
        wall_clock_times.append(wall_clock_time)
        best_val_errors.append(min(best_val_errors[-1], reference_error))

        # Keep top 1/eta_epochs configurations
        n_keep = max(1, len(remaining_configs) // eta_epochs)
        remaining_configs = [config_id for config_id, _ in results[:n_keep]]

    return wall_clock_times, best_val_errors


def _annotate_final_points(ax, curves, fontsize, box_alpha):
    """Label the last point of every curve that has more than its initial sample."""
    for i, (times, errors, label) in enumerate(curves):
        if len(times) > 1 and len(errors) > 1:
            ax.annotate(
                f"{label}: {errors[-1]:.4f}",
                xy=(times[-1], errors[-1]),
                # Stack the labels vertically so they do not overlap
                xytext=(10, 10 + i * (-20)),
                textcoords="offset points",
                fontsize=fontsize,
                bbox={"boxstyle": "round,pad=0.3", "fc": "yellow", "alpha": box_alpha},
                arrowprops={"arrowstyle": "->", "connectionstyle": "arc3,rad=.2"},
            )


def _save_figure(fig, stem):
    """Write ``fig`` to ``<stem>.pdf`` and ``<stem>.png`` at 300 dpi."""
    fig.tight_layout()
    for extension in ("pdf", "png"):
        fig.savefig(f"{stem}.{extension}", dpi=300, bbox_inches="tight")


def plot_sh_comparison(
    base_dir: str,
    eta: int = 3,
    eta_epochs: Optional[int] = None,
    eta_layers: Optional[int] = None,
    max_epochs: int = 20,
    selected_layers: Optional[List[int]] = None,
    fixed_epochs: Optional[int] = None,
):
    """
    Plot comparison of different Successive Halving strategies.

    Two 8x8 inch figures are written to the ``plots`` directory (created in
    the current working directory if needed), each as PDF and PNG:
    ``sh_comparison<suffix>`` with all strategies and
    ``sh_comparison_best<suffix>`` with only the layer and combined strategies.
    Nothing is plotted when ``load_data`` returns no data.

    Args:
        base_dir: Base directory containing the grid search results
        eta: Default reduction factor for SH
        eta_epochs: Specific reduction factor for epochs fidelity (defaults to eta)
        eta_layers: Specific reduction factor for layers fidelity (defaults to eta)
        max_epochs: Maximum number of epochs
        selected_layers: Optional list of layer counts. NOTE: the selection is
            only checked against the available layers (a warning is printed
            when none match); it does not currently change what is plotted.
        fixed_epochs: Optional fixed number of epochs for layer fidelity SH
    """
    # Set default eta values if not specified
    if eta_epochs is None:
        eta_epochs = eta
    if eta_layers is None:
        eta_layers = eta

    # Load all data
    data = load_data(base_dir)
    if not data:
        print("No data found")
        return

    # Get available layer counts
    available_layers = sorted(data.keys())
    print(f"Available layer counts: {available_layers}")

    # Determine the fully trainable model (maximum layer count)
    fully_trainable = max(available_layers)
    print(f"Using fully trainable model with {fully_trainable} layers for epoch fidelity")

    # Warn when a layer selection was given but none of it is available
    if selected_layers and not any(layer in available_layers for layer in selected_layers):
        print(
            f"None of the selected layers {selected_layers} are available. Using all available layers."
        )

    # Ensure output directory exists
    os.makedirs("plots", exist_ok=True)

    # Configure colors using tab10 colormap for better distinction
    tab10 = plt.get_cmap("tab10")

    # Use one test for "fixed epochs requested" everywhere (legend, title and
    # file name) so that they always agree.
    has_fixed_epochs = fixed_epochs is not None
    epoch_info = f" ({fixed_epochs} epochs)" if has_fixed_epochs else ""

    # Simulate SH with trainable layers as fidelity
    times_layers, errors_layers = simulate_sh_layers(
        data, eta=eta_layers, max_epochs=max_epochs, fixed_epochs=fixed_epochs
    )
    layers_style = {
        "marker": MARKERS[0],
        "linestyle": LINE_STYLES[0],
        "label": f"SH with Trainable Layers{epoch_info} (η={eta_layers})",
        "color": tab10(0),
        "linewidth": 3,
        "markersize": 10,
    }

    # Simulate combined approach
    times_combined, errors_combined = simulate_sh_combined(
        data, eta_layers=eta_layers, eta_epochs=eta_epochs, max_epochs=max_epochs
    )
    combined_style = {
        "marker": "*",
        "linestyle": "-",
        "label": f"SH with Combined Fidelities (η_layers={eta_layers}, η_epochs={eta_epochs})",
        "color": "black",
        "linewidth": 3,
        "markersize": 12,
    }

    # Curves that get their final value annotated in both figures
    annotated_curves = [
        (times_layers, errors_layers, "Layers"),
        (times_combined, errors_combined, "Combined"),
    ]

    # File name suffix shared by both figures
    suffix = f"_etaL{eta_layers}_etaE{eta_epochs}"
    if has_fixed_epochs:
        suffix += f"_fixed{fixed_epochs}"
    suffix += "_fullepoch"  # Add suffix to indicate we're using fully trainable for epoch fidelity

    # ---- Figure 1: all strategies -------------------------------------
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.plot(times_layers, errors_layers, **layers_style)

    # Simulate SH with epochs as fidelity for the fully trainable model only
    if len(data[fully_trainable]) >= eta_epochs:
        times_epochs, errors_epochs = simulate_sh_epochs(
            data, n_layers=fully_trainable, eta=eta_epochs, max_epochs=max_epochs
        )
        ax.plot(
            times_epochs,
            errors_epochs,
            marker=MARKERS[1],
            linestyle=LINE_STYLES[1],
            label=f"SH with Epochs (fully trainable: {fully_trainable} layers, η={eta_epochs})",
            color=tab10(1),
            linewidth=2,
            markersize=8,
        )
    else:
        print(
            f"Skipping epoch fidelity - not enough data for fully trainable model ({fully_trainable} layers)"
        )

    ax.plot(times_combined, errors_combined, **combined_style)

    # Set up plot
    ax.set_xlabel("Wall-Clock Time (seconds)", fontweight="bold")
    ax.set_ylabel("Best Validation Error", fontweight="bold")
    title_str = "Successive Halving Comparison"
    if has_fixed_epochs:
        title_str += f", Fixed {fixed_epochs} epochs for layer strategy"
    ax.set_title(title_str, fontweight="bold", fontsize=18)
    ax.grid(True, linestyle="--", alpha=0.7)

    # Add legend with better placement
    legend = ax.legend(
        loc="upper right",
        frameon=True,
        fancybox=True,
        shadow=True,
        ncol=1,
        fontsize=10,
        title="Strategy",
        title_fontsize=12,
    )
    legend.get_frame().set_alpha(0.9)

    _annotate_final_points(ax, annotated_curves, fontsize=10, box_alpha=0.5)

    # Draw the grid behind the curves
    ax.set_axisbelow(True)

    _save_figure(fig, f"plots/sh_comparison{suffix}")
    print(f"Plot saved to plots/sh_comparison{suffix}.pdf")

    # ---- Figure 2: layer and combined strategies only -----------------
    fig2, ax2 = plt.subplots(figsize=(8, 8))
    ax2.plot(times_layers, errors_layers, **layers_style)
    ax2.plot(times_combined, errors_combined, **combined_style)

    # Set up plot
    ax2.set_xlabel("Wall-Clock Time (seconds)", fontweight="bold")
    ax2.set_ylabel("Best Validation Error", fontweight="bold")
    ax2.set_title(f"Best Successive Halving Strategies{epoch_info}", fontweight="bold", fontsize=18)
    ax2.grid(True, linestyle="--", alpha=0.7)
    ax2.legend(loc="upper right", frameon=True, fancybox=True, shadow=True, fontsize=10)

    _annotate_final_points(ax2, annotated_curves, fontsize=12, box_alpha=0.7)

    _save_figure(fig2, f"plots/sh_comparison_best{suffix}")
    print(f"Focused plot saved to plots/sh_comparison_best{suffix}.pdf")

    plt.close(fig)
    plt.close(fig2)


def main():
    """Parse the command-line options and produce the comparison plots."""
    cli = argparse.ArgumentParser(
        description="Simulate Successive Halving with multiple fidelities"
    )

    # (flag, add_argument keyword arguments), registered in this order
    option_table = [
        (
            "--base_dir",
            {"type": str, "required": True, "help": "Base directory for grid search results"},
        ),
        (
            "--eta",
            {"type": int, "default": 3, "help": "Default reduction factor for SH (default: 3)"},
        ),
        (
            "--eta_layers",
            {
                "type": int,
                "default": None,
                "help": "Reduction factor for layers fidelity (default: same as --eta)",
            },
        ),
        (
            "--eta_epochs",
            {
                "type": int,
                "default": None,
                "help": "Reduction factor for epochs fidelity (default: same as --eta)",
            },
        ),
        (
            "--selected_layers",
            {
                "type": str,
                "default": None,
                "help": "Comma-separated list of layer counts to include (default: all)",
            },
        ),
        (
            "--figsize",
            {
                "type": str,
                "default": "12,8",
                "help": "Figure size in inches, formatted as width,height (default: 12,8)",
            },
        ),
        (
            "--fixed_epochs",
            {
                "type": int,
                "default": None,
                "help": "Fixed number of epochs for SH with layers fidelity (default: None, varies by bracket)",
            },
        ),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    opts = cli.parse_args()

    # Turn "1,2,3" into [1, 2, 3]; an unparsable list falls back to "all layers"
    layer_filter = None
    if opts.selected_layers:
        try:
            layer_filter = [int(token) for token in opts.selected_layers.split(",")]
        except ValueError:
            print(f"Invalid layer selection: {opts.selected_layers}, using all available layers")

    # Apply the requested default figure size; a malformed value keeps the
    # current default.
    # NOTE(review): plot_sh_comparison creates its figures with an explicit
    # figsize, so this setting does not affect them.
    try:
        fig_width, fig_height = (float(token) for token in opts.figsize.split(","))
        plt.rcParams["figure.figsize"] = (fig_width, fig_height)
    except ValueError:
        print(f"Invalid figure size: {opts.figsize}, using default")

    plot_sh_comparison(
        opts.base_dir,
        eta=opts.eta,
        eta_layers=opts.eta_layers,
        eta_epochs=opts.eta_epochs,
        selected_layers=layer_filter,
        fixed_epochs=opts.fixed_epochs,
    )


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
