import os
from typing import Optional
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def save_figure(fig: plt.Figure, filename: str) -> None:
    """
    Save the given Figure object as an image file.

    The output format is taken from the extension of ``filename``
    (e.g. "plot.svg" is written as SVG) and the figure is saved to
    ``filename`` itself, extension included.

    Parameters:
    - fig (plt.Figure): The Figure object to be saved.
    - filename (str): Destination path. Its extension selects the format;
      dots in directory names are handled correctly. If there is no
      extension, format=None is passed and matplotlib picks its default
      format.

    Returns:
    - None
    """
    # splitext only inspects the final path component, so "out.v2/plot.svg"
    # yields ".svg" rather than being split on the directory's dot.
    extension = os.path.splitext(filename)[1]
    # Drop the leading "." and normalise case; an empty extension becomes None.
    file_format = extension[1:].lower() or None
    fig.savefig(filename, format=file_format, dpi=300)




def aggregate_list(data_list: list[dict], quantity: str) -> pd.DataFrame:
    """
    Summarise the distribution of one column for each entry of data_list.

    Parameters:
    - data_list (list[dict]): Entries with the keys "Method", "Dataset" and
      "dataframe"; the DataFrame must contain the column named by quantity.
    - quantity (str): Name of the column to summarise.

    Returns:
    - pd.DataFrame: One row per entry with the columns Method, Dataset,
      Percentile_5, Percentile_25, Percentile_75, Percentile_95 and Median.
      An empty data_list yields an empty DataFrame with these columns.
    """
    columns = [
        "Method",
        "Dataset",
        "Percentile_5",
        "Percentile_25",
        "Percentile_75",
        "Percentile_95",
        "Median",
    ]
    aggregated_data = []

    for entry in data_list:
        estimates = entry["dataframe"][quantity]

        aggregated_data.append(
            {
                "Method": entry["Method"],
                "Dataset": entry["Dataset"],
                "Percentile_5": estimates.quantile(0.05),
                "Percentile_25": estimates.quantile(0.25),
                "Percentile_75": estimates.quantile(0.75),
                "Percentile_95": estimates.quantile(0.95),
                # median() must be called; the bare attribute is a bound method.
                "Median": estimates.median(),
            }
        )

    # Passing columns keeps a stable column order, including for empty input.
    return pd.DataFrame(aggregated_data, columns=columns)


def bar_quantile_plot(
    data_list: list[dict],
    title: str = "",
    filename: Optional[str] = None,
    x_label: str = "",
    y_label: str = "",
    quantity: str = "Estimate",
    ground_truth: Optional[float] = None
) -> plt.Figure:
    """
    Draw a box plot of one quantity, with one box per method.

    Each entry's rows are tagged with the entry's "Method" and "Dataset", all
    entries are stacked into one DataFrame, and seaborn draws one box per
    distinct "Method" value. Outlier markers are hidden (fliersize=0).

    Args:
        data_list (list[dict]): Entries with the keys:
            - 'Method': category label on the x-axis
            - 'Dataset': dataset label (attached to the rows, not plotted)
            - 'dataframe': pd.DataFrame containing the column named by `quantity`
            The input DataFrames are not modified.
        title (str): Plot title.
        filename (str, optional): If given, the figure is saved there as SVG (dpi=300).
        x_label (str): Label of the x-axis.
        y_label (str): Label of the y-axis.
        quantity (str): Column plotted on the y-axis.
        ground_truth (float, optional): If given (0 included), drawn as a dashed
            red horizontal line with a "Ground Truth" legend entry.

    Returns:
        plt.Figure: The created figure. It is left open; close it with
        plt.close(fig) when it is no longer needed.

    Raises:
        ValueError: If data_list is empty.
    """
    if not data_list:
        raise ValueError("data_list must contain at least one entry")

    # assign() returns tagged copies, so the caller's DataFrames stay untouched.
    # Concatenating once avoids re-copying a growing DataFrame on every entry.
    df = pd.concat(
        [
            entry["dataframe"].assign(Method=entry["Method"], Dataset=entry["Dataset"])
            for entry in data_list
        ],
        ignore_index=True,
    )

    # 2 inches of width per entry.
    fig, ax = plt.subplots(figsize=(2 * len(data_list), 6))

    sns.boxplot(
        y=quantity,
        x="Method",
        data=df,
        palette="husl",
        width=0.5,
        linewidth=1,
        fliersize=0,
        ax=ax,
    )

    # Compare against None so a ground truth of exactly 0 is still drawn.
    if ground_truth is not None:
        ax.axhline(y=ground_truth, color="red", linestyle="--", label="Ground Truth")
        # Without a legend the "Ground Truth" label would never be shown.
        ax.legend()

    # Customize the plot on this figure's axes rather than pyplot's current axes.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", fontsize=10)
    ax.tick_params(axis="y", labelsize=10)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.grid(axis="y", linestyle="--", alpha=0.7)
    fig.tight_layout()

    if filename:
        fig.savefig(filename, format="svg", dpi=300)

    return fig


def generate_comparison_bar_plot(
    data: pd.DataFrame,
    x_label: str = "",
    y_label: str = "",
    filename: Optional[str] = None,
    groupby: str = "Dataset",
    title: str = "Quality Comparison by Provider (Preliminary Results)",
) -> plt.Figure:
    """Generate a comparison plot: one bar chart per metric, one bar per group.

    Args:
        data (pd.DataFrame): One row per group. The column named by `groupby`
            holds the group labels; every other column is a metric and gets its
            own subplot. Bar values are annotated with a "%" suffix.
        x_label (str, optional): Label for the x-axis of every subplot. Defaults to an empty string.
        y_label (str, optional): Label for the y-axis of every subplot. Defaults to an empty string.
        filename (str, optional): If provided, the plot is saved there as an SVG file.
        groupby (str, optional): Column holding the group labels. Defaults to "Dataset".
        title (str, optional): Figure-level title.

    Returns:
        plt.Figure: The generated comparison plot as a matplotlib Figure object.

    Raises:
        ValueError: If `data` has no metric columns besides `groupby`.
    """
    groupings = data[groupby].tolist()
    metrics = [column for column in data.columns if column != groupby]
    if not metrics:
        raise ValueError(f"data has no metric columns besides {groupby!r}")

    colors = [
        "#E74C3C",
        "#3498DB",
        "#1ABC9C",
        "#9B59B6",
        "#34495E",
        "#F39C12",
        "#8E44AD",
        "#2ECC71",
    ]
    # One color per group, cycling when there are more groups than colors.
    bar_colors = [colors[i % len(colors)] for i in range(len(groupings))]
    # Numeric positions keep duplicate labels as separate bars and pair
    # set_xticks with set_xticklabels.
    positions = list(range(len(groupings)))

    # Two columns of subplots (one if there is a single metric), as many rows as needed.
    ncols = 2 if len(metrics) > 1 else 1
    nrows = -(-len(metrics) // ncols)  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(7 * ncols, 5 * nrows), squeeze=False)
    fig.suptitle(title, fontsize=16)

    flat_axes = axes.flatten()
    for ax, metric in zip(flat_axes, metrics):
        heights = data[metric]
        bars = ax.bar(positions, heights, color=bar_colors)

        # Adding data labels just above each bar
        for bar in bars:
            yval = bar.get_height()
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                yval + 1,
                f"{yval}%",
                ha="center",
                va="bottom",
                fontsize=10,
            )

        ax.set_title(metric)
        # Keep the 0-80 range, but extend it when a bar would otherwise be clipped.
        ax.set_ylim(0, max(80, heights.max() * 1.1))
        ax.set_xticks(positions)
        ax.set_xticklabels(groupings, rotation=45, ha="right")
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.grid(True, axis="y", linestyle="--", linewidth=0.7)

    # Hide the empty slot when the number of metrics does not fill the grid.
    for ax in flat_axes[len(metrics):]:
        ax.set_visible(False)

    # Labels are set before layout so tight_layout accounts for them.
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    if filename:
        fig.savefig(filename, format="svg", dpi=300)

    return fig


def draw_plots(
    metrics_list: list[dict],
    results_list: list[dict],
    runtime_list: list[dict],
    experiments_folder: str,
) -> None:
    """
    Draw and save the summary box plots for an experiment run.

    Writes estimate_values.svg, runtime_values.svg, mse.svg and mae.svg into
    the "figures" subfolder of experiments_folder, creating it if needed.

    Parameters:
    - metrics_list (list[dict]): Entries with keys "Method", "Dataset" and
      "dataframe"; the DataFrames must contain the "SE" and "MAE" columns.
    - results_list (list[dict]): Entries of the same shape whose DataFrames
      contain the "Estimate" column.
    - runtime_list (list[dict]): Entries of the same shape whose DataFrames
      contain the "Estimation Time" column.
    - experiments_folder (str): The path to the folder containing the experiment results.

    Returns:
    - None
    """
    # Create figures folder to store the plots; exist_ok avoids the
    # check-then-create race of os.path.exists followed by os.makedirs.
    figures_folder = os.path.join(experiments_folder, "figures")
    os.makedirs(figures_folder, exist_ok=True)

    # (entries, title, output file name, column to plot)
    # NOTE(review): the "MSE" figure plots the distribution of the "SE" column.
    plot_specs = [
        (results_list, "Estimate Values", "estimate_values.svg", "Estimate"),
        (runtime_list, "Estimation Time", "runtime_values.svg", "Estimation Time"),
        (metrics_list, "MSE", "mse.svg", "SE"),
        (metrics_list, "MAE", "mae.svg", "MAE"),
    ]

    for entries, title, file_name, quantity in plot_specs:
        fig = bar_quantile_plot(
            entries,
            title=title,
            filename=os.path.join(figures_folder, file_name),
            quantity=quantity,
        )
        # The figure is already saved; close it so pyplot does not keep every
        # figure alive across calls.
        plt.close(fig)
