##
## (c) Anonymous authors (2026)
##
## > Script to visualize learning curves for benchmark POMDP tasks
##
##

import os
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.lines import Line2D
from matplotlib.ticker import ScalarFormatter, FuncFormatter, MaxNLocator


def create_plot(input_dir: str, outdir: str = "figures") -> None:
    """Plot learning curves for six benchmark POMDP tasks and save them as a PDF.

    One subplot is drawn per task in a 2x3 grid. For each task, the CSV file is
    grouped by ``timestep`` and ``algorithm``. The mean of ``return`` is drawn
    as a line, and the band mean +/- one standard deviation is shaded. A single
    shared legend is placed above the grid.

    Args:
        input_dir: Directory containing the task-specific CSV files listed
            below. Each file must provide ``timestep``, ``algorithm`` and
            ``return`` columns (presumably one row per seed and time step;
            confirm against the preprocessing step).
        outdir: Directory the PDF is written to. It is created if missing.

    Side effects:
        Changes global seaborn/matplotlib style state, including loading the
        style sheet ``figures/icml.mplstyle`` relative to the current working
        directory. Writes ``<outdir>/experiments_figure-1_learning_curves.pdf``.
    """
    # Plot style: full text width at a golden-ratio aspect.
    textwidth_in = 7.00137
    aspect_ratio = 0.618
    fig_height_in = textwidth_in * aspect_ratio
    sns.set(style="darkgrid")
    # NOTE: resolved relative to the current working directory.
    plt.style.use('figures/icml.mplstyle')
    plt.rcParams.update({
        'axes.labelsize': 7,
        'xtick.labelsize': 7,
        'ytick.labelsize': 7,
        'legend.fontsize': 9,
        'axes.titlesize': 9
    })

    # Per-task settings, all in subplot order: result file, title, y-limits.
    csv_files = [
        "POMDP-heavenhell_3-episodic-v0_final_0-19.csv",
        "POMDP-shopping_5-episodic-v1_final_0-19.csv",
        "extra-car-flag-v0_final_0-19.csv",
        "extra-cleaner-v0_final_0-19.csv",
        "gv_memory_four_rooms-7x7_final_0-19.csv",
        "gv_memory_four_rooms-9x9_final_0-19.csv"
    ]
    titles = [
        "Heaven-Hell-3",
        "Shopping-5",
        "Car-Flag",
        "Cleaner",
        "Memory-Four-Rooms-7x7",
        "Memory-Four-Rooms-9x9"
    ]
    # A None entry keeps matplotlib's automatic y-limits for that subplot.
    ylims = [
        (-0.5, 1.4),  # Heaven-Hell-3
        (-150, 30),  # Shopping-5
        (-0.2, 0.59),  # Car-Flag
        (0, 85),  # Cleaner
        (-5, 1.5),  # Memory-Four-Rooms-7x7
        (-5, 0.5),  # Memory-Four-Rooms-9x9
    ]

    # Maps algorithm identifiers in the CSV 'algorithm' column to display names.
    method_name_mapping = {
        "asym-a2c": "asym-A2C-hs",
        "asym-a2c-state": "asym-A2C-s",
        "informed-asym-a2c": "informed-asym-A2C",
        "a2c": "A2C",
    }

    # Fixed method order, so each method gets the same color in every subplot.
    methods = ["asym-a2c", "asym-a2c-state", "a2c", "informed-asym-a2c"]
    colors = sns.color_palette("tab10", n_colors=len(methods))

    fig, axes = plt.subplots(2, 3, figsize=(textwidth_in, fig_height_in))

    # Zip every per-task list, so a length mismatch cannot index out of range.
    tasks = zip(axes.flatten(), csv_files, titles, ylims)
    for i, (ax, csv_file, title, ylim) in enumerate(tasks):
        df = pd.read_csv(os.path.join(input_dir, csv_file))

        # Mean and sample std of the return per (timestep, algorithm) pair.
        grouped = (
            df.groupby(['timestep', 'algorithm'])['return']
            .agg(['mean', 'std'])
            .reset_index()
        )

        for method, color in zip(methods, colors):
            method_data = grouped[grouped['algorithm'] == method]
            if method_data.empty:
                # Method not present in this task's results: skip its curve.
                continue
            ax.plot(method_data['timestep'], method_data['mean'],
                    label=method_name_mapping[method], color=color,
                    linewidth=0.4)
            ax.fill_between(method_data['timestep'],
                            method_data['mean'] - method_data['std'],
                            method_data['mean'] + method_data['std'],
                            alpha=0.15, color=color, linewidth=0.0)

        # Subplot title "(a) Task", "(b) Task", ... in bold. Requires a
        # LaTeX-enabled style (\textbf), presumably set by the style sheet.
        ax.set_title(rf"\textbf{{({chr(97 + i)}) {title}}}")

        ax.set_xlabel("Time step", labelpad=2)
        ax.xaxis.set_major_formatter(ScalarFormatter(useMathText=True))
        # scilimits=(5, 5) fixes the x-axis offset exponent at 10^5.
        ax.ticklabel_format(style='sci', axis='x', scilimits=(5, 5))
        ax.xaxis.set_major_locator(MaxNLocator(nbins=5))

        if ylim is not None:
            ax.set_ylim(ylim)

    # One shared legend above the grid, built from proxy artists, so every
    # method appears even if it is missing from some task.
    legend_handles = [
        Line2D([0], [0], color=color, linewidth=1.5,
               label=method_name_mapping[method])
        for method, color in zip(methods, colors)
    ]

    fig.legend(
        handles=legend_handles,
        loc='upper center',
        ncol=len(methods),
        frameon=True,
        fancybox=True,
        shadow=False,
        framealpha=0.8,
        edgecolor='gray'
    )

    # Reserve the top 5% of the figure for the legend.
    fig.tight_layout(rect=[0, 0, 1, 0.95], pad=1.1)
    fig.subplots_adjust(wspace=0.2, hspace=0.42)

    # savefig does not create missing directories.
    os.makedirs(outdir, exist_ok=True)
    fig.savefig(os.path.join(outdir, "experiments_figure-1_learning_curves.pdf"),
                dpi=600, bbox_inches='tight', format='pdf')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


if __name__ == "__main__":
    # Script entry point: read the preprocessed results and write the figure
    # to the default output directory.
    create_plot("results/preprocessed")
