import itertools
import json
from pathlib import Path

import matplotlib
from pymoo.indicators.hv import Hypervolume

from data.mpec.utils.non_domination_util import non_dominated_mask, duplicated_mask
from plot_utils import HandlerRulerWithHashBelow, HandlerCurvedFlatTopViolin, add_figure_label

matplotlib.use('Agg')

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import numpy as np
import seaborn as sns

# Global font sizes shared by every figure produced by this module.
plt.rcParams.update({
    'axes.labelsize': 14,    # x- and y-axis labels
    'axes.titlesize': 16,    # plot titles
    'xtick.labelsize': 12,   # x-axis tick labels
    'ytick.labelsize': 12,   # y-axis tick labels
    'legend.fontsize': 12,   # legend text
})
fontsize = 12


def scale(return_):
    """Divide a return vector by the negated value of its last component.

    Returns a plain ``list`` whose entries are the NumPy elements of the
    rescaled array.
    """
    values = np.array(return_)
    divisor = -values[-1]
    return list(values / divisor)

def strip_trajectory_length(return_):
    """Return the return vector as a list with its last component removed."""
    return list(np.array(return_)[:-1])

def strip_trajectory_length_vec(return_vec):
    """Drop the last component of every return vector in ``return_vec``.

    The input is converted to a NumPy array first; the result is a list with
    one shortened row per input vector.
    """
    rows = np.array(return_vec)
    return [row[:-1] for row in rows]

def invert(return_):
    """Flip the sign of every entry of ``return_`` and return the result as a list."""
    values = np.array(return_)
    # Multiply by the one-element sequence [-1] (broadcast over every entry).
    return list(np.multiply(values, [-1]))

# 1. Scatter plot of return points
def plot_ablation_figure_1(non_stationary_returns, stationary_returns, naive_returns, figure_label=None):
    """Scatter the returns of three methods and highlight each method's pooled front.

    Each of the first three arguments is a list with one pandas object per
    seed whose values are return points. All points are drawn faintly; the
    points that survive ``non_dominated_mask`` / ``duplicated_mask`` on the
    pooled data of a method are drawn on top with a black edge.

    The figure is saved to ``{task}/figures/ablation_figure_1.png`` where
    ``task`` is the module-level setting.
    """
    fig, ax = plt.subplots(figsize=(6, 4))

    def draw_method(per_seed_returns, colour, symbol, legend_name):
        # Background layer: every point of every seed, faint.
        for seed_returns in per_seed_returns:
            coordinates = zip(*seed_returns.values)
            ax.scatter(*coordinates, color=colour, marker=symbol, alpha=0.2, s=40, zorder=1)

        # Foreground layer: unique non-dominated points across all seeds.
        pooled = pd.concat(per_seed_returns)
        pooled_points = np.array(pooled.values.tolist())
        keep_non_dominated = non_dominated_mask(pooled_points)
        is_duplicate = duplicated_mask(pooled_points)
        front = pooled[keep_non_dominated & ~is_duplicate]

        ax.scatter(*zip(*front.values), color=colour, marker=symbol, label=legend_name, s=60, edgecolor="k", zorder=3, alpha=0.75)

    method_specs = (
        (non_stationary_returns, 'C0', 's', 'Non-Stationary'),
        (stationary_returns, 'C1', '^', 'Stationary'),
        (naive_returns, 'C2', 'o', 'Naive Policy Selection'),
    )
    for spec in method_specs:
        draw_method(*spec)

    ax.set_xlabel('Time Penalty')
    ax.set_ylabel('Cumulative Treasure Value')
    ax.set_title('Policy Coverage')
    ax.legend()

    plt.grid(which='major', linestyle='-', color='lightgray')

    if figure_label is not None:
        add_figure_label(fig, figure_label)

    plt.tight_layout(pad=0.5)

    figures_dir = Path(f"{task}/figures")
    figures_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(f"{task}/figures/ablation_figure_1.png")

# 2. Line plot: speed of convergence & spurious dominations
def plot_ablation_figure_2(hv_non_stationary, hv_stationary, hv_spurious1, hv_spurious2, hv_spurious3, figure_label=None):
    """Plot mean hypervolume curves with a +/- one standard deviation band.

    Every ``hv_*`` argument is a list with one pandas object per seed; the
    seeds are concatenated column-wise and aggregated row-wise. The figure is
    saved to ``{task}/figures/ablation_figure_2.png`` (``task`` is the
    module-level setting).
    """
    fig, ax1 = plt.subplots(figsize=(6, 4))

    # (per-seed data, colour, marker, legend label, markevery) in drawing order.
    curve_specs = (
        (hv_non_stationary, 'C0', 's', 'Non-Stationary', (0, 2000)),
        (hv_spurious1, 'C3', 's', 'Horizon-length bias', (1000, 2000)),
        (hv_spurious2, 'C4', 's', 'Reward-frequency bias', (1000, 2000)),
        (hv_stationary, 'C1', '^', 'Stationary', (0, 2000)),
        (hv_spurious3, 'C6', '^', 'Structural-info bias', (1000, 2000)),
    )

    for seed_frames, colour, symbol, name, marker_stride in curve_specs:
        stacked = pd.concat(seed_frames, axis=1)
        centre = stacked.mean(axis=1)
        spread = stacked.std(axis=1)
        steps = np.arange(len(centre))

        # Mean line, then the shaded standard-deviation band around it.
        ax1.plot(steps, centre, color=colour, marker=symbol, markevery=marker_stride, markersize=7, alpha=0.6, label=name)
        ax1.fill_between(steps, centre - spread, centre + spread, color=colour, alpha=0.2)

    ax1.margins(x=0, y=0.01)
    ax1.set_xlabel('Time Steps')
    ax1.set_ylabel('Hypervolume')
    ax1.grid(which='major', linestyle='-', color='lightgray')

    # Legend entries are built by hand so they are drawn fully opaque.
    legend_entries = [
        mlines.Line2D([], [], color=colour, marker=symbol, markersize=7, label=name)
        for _, colour, symbol, name, _ in curve_specs
    ]
    ax1.legend(handles=legend_entries, loc='upper left', ncol=2, columnspacing=1)
    ax1.set_title('Convergence Speed and Spurious Domination')

    # Add head-room above the curves so the legend does not cover them.
    upper_limit = ax1.get_ylim()[1]
    ax1.set_ylim(0, upper_limit * 1.5)

    if figure_label is not None:
        add_figure_label(fig, figure_label)

    plt.tight_layout(pad=0.5)
    figures_dir = Path(f"{task}/figures")
    figures_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(f"{task}/figures/ablation_figure_2.png")


# 3. Violin plot: trajectories and return
def plot_ablation_figure_3(df_returns, df_lengths, methods, left_methods, right_methods, figure_label=None):
    """Draw trajectory-length violins per method with promised/realized markers on a twin axis.

    Args:
        df_returns: Frame with a ``method`` column and ``collected_return`` /
            ``desired_return`` values; they are averaged per method.
        df_lengths: Frame with ``method``, ``seed`` and ``length`` columns.
        methods: Every method name, in x-axis order.
        left_methods: Methods drawn to the left of the dashed separator.
        right_methods: Methods drawn to the right of the separator.
        figure_label: Optional label forwarded to ``add_figure_label``.

    Side effects:
        Saves ``{task}/figures/ablation_figure_3.png`` (``task`` is the
        module-level setting).
    """
    # Function-scope imports, resolved once instead of on every loop iteration.
    import matplotlib.collections as mcoll
    from matplotlib import colors as mcolors

    # Count trajectories per (method, seed, length), then average the counts
    # of each (method, length) pair over the seeds in which it was observed.
    freqs = df_lengths.groupby(["method", "seed", "length"]).size().reset_index(name="count")
    mean_freqs = freqs.groupby(["method", "length"])["count"].mean().reset_index()

    # Re-expand the rounded mean counts into one row per trajectory so that
    # seaborn can estimate a density from them.
    expanded_rows = []
    for _, row in mean_freqs.iterrows():
        n = int(round(row["count"]))  # convert mean count to integer
        expanded_rows.extend([{"method": row["method"], "length": row["length"]}] * n)
    df_expanded = pd.DataFrame(expanded_rows)

    fig, ax1 = plt.subplots(figsize=(6, 4))

    # NOTE(review): newer seaborn releases deprecate `scale=` in favour of
    # `density_norm=` -- confirm against the pinned seaborn version.
    # Left methods
    sns.violinplot(
        data=df_expanded[df_expanded["method"].isin(left_methods)],
        x="method", y="length",
        inner="box", bw_adjust=2, scale="width", cut=0,
        ax=ax1, color="lightgray",
        order=left_methods
    )

    # Right methods (the full order keeps them positioned after the left group)
    sns.violinplot(
        data=df_expanded[df_expanded["method"].isin(right_methods)],
        x="method", y="length",
        inner="box", bw_adjust=2, scale="width", cut=0,
        ax=ax1, color="lightgray",
        order=left_methods + right_methods
    )

    ax1.set_xlabel("Method - NS max length (left), SSM ON or OFF (right)")

    ax1.axvline(len(left_methods) - 0.45, color="black", linestyle="--")
    ax1.set_ylabel("Trajectory Length")
    ax1.set_xlim(-0.5, len(df_lengths["method"].unique()) - 0.5)

    # Overlay promised & realized hypervolume points on a second y-axis.
    ax2 = ax1.twinx()

    colors = {"chosen": "blue", "collected": "red"}

    for i, m in enumerate(methods):
        sub = df_returns[df_returns["method"] == m].groupby("method")[["collected_return", "desired_return"]].mean()
        count = len(sub)

        # Realized markers first, promised markers second, all at x = i.
        all_x = np.full(2 * count, i)
        all_y = np.concatenate([sub["collected_return"], sub["desired_return"]])

        # Filled red dot for "realized", larger hollow blue ring for "promised".
        sizes = [80] * count + [140] * count
        facecolors = [mcolors.to_rgba(colors["collected"], alpha=0.6)] * count + ["none"] * count
        edgecolors = [mcolors.to_rgba("black", alpha=0.6)] * count + [mcolors.to_rgba(colors["chosen"], alpha=0.6)] * count
        linewidths = [1] * count + [2] * count

        ax2.scatter(
            all_x,
            all_y,
            s=sizes,
            facecolors=facecolors,
            edgecolors=edgecolors,
            linewidths=linewidths,
            zorder=10
        )

    # Empty scatters exist only to provide legend entries.
    ax2.scatter([], [], s=140, facecolors="none", edgecolors=colors["chosen"], linewidths=2, label="Promised")
    ax2.scatter([], [], s=80, facecolors=colors["collected"], edgecolors="black", label="Realized")

    leg = ax2.legend(loc="upper right", handlelength=1, handleheight=1.4, scatterpoints=1, bbox_to_anchor=(1.01, 1))
    # Matplotlib 3.7 renamed `Legend.legendHandles` to `legend_handles` and the
    # old name was removed later; use whichever attribute this version offers.
    legend_handles = getattr(leg, "legend_handles", None)
    if legend_handles is None:
        legend_handles = leg.legendHandles
    for handle in legend_handles:
        handle.set_sizes([80])  # force uniform legend marker size

    # Dummy handles: the custom handlers draw the symbols, the objects are keys only.
    violin_dummy = object()
    ruler_dummy = object()

    ax1.legend(
        [violin_dummy, ruler_dummy],
        ["Total # of trajectories",
         "# of trajectories (width)"],
        handler_map={
            violin_dummy: HandlerCurvedFlatTopViolin(),
            ruler_dummy: HandlerRulerWithHashBelow()
        },
        loc="upper left",
        handleheight=1.4
    )

    # Head-room on the hypervolume axis so the markers stay clear of the legends.
    ax2_top = ax2.get_ylim()[1]
    ax2.set_ylim(0, ax2_top + 140)
    ax2.set_ylabel("Hypervolume")

    # --- Rulers ---
    violins = [c for c in ax1.collections if isinstance(c, mcoll.PolyCollection)]
    violin_bodies = violins[:len(methods)]  # body polygons are first

    # Choose a common baseline for all rulers
    ymin, ymax = ax1.get_ylim()
    ruler_y = ymin + 0.01 * (ymax - ymin)

    # Expand y-limits so rulers and label fit
    ax1.set_ylim(ruler_y - 2, ymax + 0.03 * (ymax - ymin) + 12)

    # Largest per-length count of every method, plus the largest such count
    # within the method's group (left / right) as the width reference.
    counts_by_method_and_length = df_expanded.groupby(["method", "length"]).size().reset_index(name="count")
    max_count_df = counts_by_method_and_length.groupby("method")["count"].max().reset_index(name="max_count")

    max_count_df.loc[max_count_df['method'].isin(left_methods), 'width_reference'] = \
        max_count_df.loc[max_count_df['method'].isin(left_methods), 'max_count'].max()

    max_count_df.loc[max_count_df['method'].isin(right_methods), 'width_reference'] = \
        max_count_df.loc[max_count_df['method'].isin(right_methods), 'max_count'].max()

    for i, (m, poly) in enumerate(zip(methods, violin_bodies)):
        # Shrink the violin so its width is proportional to its group's widest violin.
        proportion = max_count_df.loc[max_count_df['method'] == m, 'max_count'].values[0] / max_count_df.loc[
            max_count_df['method'] == m, 'width_reference'].values[0]

        # The vertices array is modified in place, which rescales the drawn polygon.
        verts = poly.get_paths()[0].vertices
        center = verts[:, 0].mean()
        verts[:, 0] = i + (verts[:, 0] - center) * proportion

        x_min, x_max = verts[:, 0].min(), verts[:, 0].max()
        # Bottom ruler spans the rescaled violin width and is labelled with the max count.
        counts_for_method = max_count_df[max_count_df["method"] == m]
        ax1.hlines(ruler_y, x_min, x_max, color="black", lw=2)
        ax1.text(i, ruler_y - 0.02 * (ymax - ymin), f"{counts_for_method['max_count'].iloc[0]}", ha="center", va="top", fontsize=10)

        # Total number of (expanded) trajectories, printed above the violin.
        n = len(df_expanded[df_expanded["method"] == m])
        y_max_violin = verts[:, 1].max()
        y_label = y_max_violin + 0.02 * (ymax - ymin)
        if m == 'l=2':
            # Hand-tuned extra offset for this one method.
            y_label += 1.3
        ax1.text(i, y_label, f"{n}", ha="center", va="bottom", fontsize=10, fontweight="bold")

    ax1.set_title("Promised vs Realized Returns")
    # `plt.xticks` would act on the current Axes, which after `twinx()` is
    # `ax2` with its hidden x-axis, so rotate the visible labels on `ax1`.
    plt.setp(ax1.get_xticklabels(), rotation=20, ha="right")

    if figure_label is not None:
        add_figure_label(fig, figure_label)

    plt.tight_layout(pad=0.5)

    Path(f"{task}/figures").mkdir(parents=True, exist_ok=True)
    plt.savefig(f"{task}/figures/ablation_figure_3.png")


def plot_ablation_figure_4(
        split__split_and_later_update, split__split_reconnect_update,
        mismatch__split_and_later_update, mismatch__split_reconnect_update,
        figure_label=None):
    """Plot the across-seed mean of split and mismatch counts for two update policies.

    Each of the four data arguments is a list with one pandas Series per
    seed. Colour encodes the policy and line style encodes the metric; two
    separate legends explain the encoding.

    The figure is saved to ``{task}/figures/ablation_figure_4.png`` where
    ``task`` is the module-level setting.
    """
    fig, ax = plt.subplots(figsize=(6, 4))

    colors = {'split_and_later_update': 'C0', 'split_reconnect_update': 'C1'}
    linestyles = {'split': '--', 'mismatch': '-'}

    def plot_seed_mean(seed_series_list, color, linestyle, label):
        # Seeds become columns; the row-wise mean is the plotted curve.
        df = pd.concat(seed_series_list, axis=1)
        mean = df.mean(axis=1)
        ax.plot(np.arange(len(mean)), mean, color=color, linestyle=linestyle, label=label)

    # Plot the mean curve of each (policy, metric) combination.
    plot_seed_mean(split__split_and_later_update,
                   colors['split_and_later_update'],
                   linestyles['split'],
                   'Split-and-later-update (# splits)')
    plot_seed_mean(split__split_reconnect_update,
                   colors['split_reconnect_update'],
                   linestyles['split'],
                   'Split-and-reconnect (# splits)')
    plot_seed_mean(mismatch__split_and_later_update,
                   colors['split_and_later_update'],
                   linestyles['mismatch'],
                   'Split-and-later-update (# mismatches)')
    plot_seed_mean(mismatch__split_reconnect_update,
                   colors['split_reconnect_update'],
                   linestyles['mismatch'],
                   'Split-and-reconnect (# mismatches)')

    ax.margins(x=0, y=0.01)
    ax.set_xlabel('Time Step')
    ax.set_ylabel('Count (cumulative)')
    ax.set_title('Trajectory Splits and Length Mismatches')

    # Legend: colors = policy type, linestyles = metric type
    color_handles = [
        mlines.Line2D([], [], color=colors['split_and_later_update'], label='Split-and-later-update'),
        mlines.Line2D([], [], color=colors['split_reconnect_update'], label='Split-and-reconnect')
    ]
    style_handles = [
        mlines.Line2D([], [], color='black', linestyle=linestyles['split'], label='# of splits'),
        mlines.Line2D([], [], color='black', linestyle=linestyles['mismatch'], label='# of mismatches')
    ]

    # The first legend must be re-added as an artist because the second
    # `ax.legend` call replaces the Axes' own legend. The second legend is
    # already owned by the Axes, so adding it again would draw it twice.
    legend1 = ax.legend(handles=color_handles, loc='upper left')
    ax.add_artist(legend1)
    ax.legend(handles=style_handles, loc='upper left', bbox_to_anchor=(0, 0.82))

    plt.grid(which='major', linestyle='-', color='lightgray')

    if figure_label is not None:
        add_figure_label(fig, figure_label)

    plt.tight_layout(pad=0.5)

    Path(f"{task}/figures").mkdir(parents=True, exist_ok=True)
    plt.savefig(f"{task}/figures/ablation_figure_4.png")


######################################################################################################################

def normalize_collected_reward(data):
    """Rescale every ``collected_return`` to the last component of its ``chosen_return``.

    ``data`` is turned into a DataFrame; each row's collected return is
    divided by the negation of its own last entry and multiplied by the
    negation of the chosen return's last entry. Only the resulting
    ``collected_return`` column is returned.
    """
    def rescale(row):
        collected = np.array(row['collected_return'])
        chosen = np.array(row['chosen_return'])
        return list(collected / -collected[-1] * -chosen[-1])

    frame = pd.DataFrame(data)
    frame['collected_return'] = frame.apply(rescale, axis=1)
    return frame['collected_return']

def flip_treasure_and_time(df):
    """Return ``df`` with the element order of every entry reversed."""
    def reversed_copy(entry):
        return list(reversed(entry))

    return df.apply(reversed_copy)

def expand_until(arr, threshold, lower=False):
    """Collect ``arr`` and its integer multiples until the last component crosses ``threshold``.

    The result always starts with ``arr`` itself (unchecked), followed by
    ``arr * 2``, ``arr * 3``, ... as lists. Expansion stops before the first
    multiple whose last component is greater than ``threshold``
    (``lower=False``) or smaller than ``threshold`` (``lower=True``).

    Args:
        arr: Sequence of numbers; its last component drives the stop test.
        threshold: Bound that the last component must not cross.
        lower: If true, stop when the last component drops below the bound.

    Returns:
        A list whose first item is ``arr`` and whose remaining items are
        the accepted multiples.

    Raises:
        ValueError: If no multiple can ever cross ``threshold`` (last
            component is zero, NaN, or has the sign that moves away from the
            bound). Without this check the loop would never terminate.
        IndexError: If ``arr`` is empty.
    """
    base = np.array(arr)
    last = base[-1]

    def crossed(value):
        return value < threshold if lower else value > threshold

    # Multiples only move towards the stopping side when the last component
    # has the matching sign; otherwise the m = 2 test is the only chance.
    approaches = last < 0 if lower else last > 0
    if not approaches and not crossed(last * 2):
        raise ValueError(
            f"expand_until cannot terminate: multiples of last component {last!r} "
            f"never cross threshold {threshold!r} (lower={lower})"
        )

    results = [arr]
    m = 2
    while not crossed(last * m):
        results.append(list(base * m))
        m += 1
    return results

def pad_stationary_cycles(df):
    """Add repeated multiples of every return and keep the unique non-dominated ones.

    Each entry is expanded with ``expand_until`` against ``-NUMBER_OF_STEPS``
    (``lower=True``), the expansions are flattened into one entry per
    multiple, and the rows selected by ``non_dominated_mask`` that are not
    flagged by ``duplicated_mask`` are returned.
    """
    step_budget = NUMBER_OF_STEPS

    def repeat_cycle(cycle_return):
        return expand_until(cycle_return, -step_budget, lower=True)

    # One row per multiple instead of one list of multiples per row.
    repeated = df.apply(repeat_cycle).explode(ignore_index=True)

    points = np.array(repeated.values.tolist())
    keep_non_dominated = non_dominated_mask(points)
    is_duplicate = duplicated_mask(points)
    return repeated[keep_non_dominated & ~is_duplicate]


def get_ablation_figure_1_data():
    """Load the per-seed returns used by ``plot_ablation_figure_1``.

    For every seed in the module-level ``seeds`` the non-stationary,
    stationary and naive result files are read, normalised, stripped of their
    last component and reversed. Stationary results are additionally padded
    with ``pad_stationary_cycles``; for seeds 3 and 4 the stationary file of
    seed 1 is read instead.

    Returns:
        ``(non_stationary, stationary, naive)``, each a list with one entry
        per seed.
    """
    def load_returns(path, pad_cycles=False):
        with open(path, 'r') as file:
            raw = json.load(file)
            returns = normalize_collected_reward(raw)
            if pad_cycles:
                returns = pad_stationary_cycles(returns)
            returns = returns.map(strip_trajectory_length)
            return flip_treasure_and_time(returns)

    non_stationary_per_seed = []
    stationary_per_seed = []
    naive_per_seed = []

    for seed in seeds:
        non_stationary_per_seed.append(
            load_returns(f"{task}/results/as1_as2/{task}/mpec_discrete/{seed}/as1_as2_NS.json"))

        opt_seed = 1 if seed in (3, 4) else seed
        stationary_per_seed.append(
            load_returns(f"{task}/results/as1_as2/{task}/mpec_discrete/{opt_seed}/as1_as2_S.json",
                         pad_cycles=True))

        naive_per_seed.append(
            load_returns(f"{task}/results/as1/{task}/mpec_discrete/{seed}/as1_naive.json"))

    return non_stationary_per_seed, stationary_per_seed, naive_per_seed


def get_ablation_figure_2_data():
    """Load the debug traces and turn them into per-seed hypervolume curves.

    For every seed in the module-level ``seeds`` five debug files are read.
    Every record's ``debug_track_policies`` is sign-flipped (and, except for
    the trajectory-length-off ablation, stripped of its last component)
    before the pymoo ``Hypervolume`` indicator with the module-level
    ``ref_point`` is evaluated; records without policies score 0.

    Returns:
        A 5-tuple of lists (non-stationary, stationary, trajectory-length
        off, average-reward off, SSM off), each holding one single-column
        DataFrame per seed.
    """
    def hypervolume_per_record(records, drop_length=True):
        indicator = Hypervolume(ref_point=ref_point)

        def score(position, points):
            print(position)
            return indicator(np.array(points))

        tracked = [record["debug_track_policies"] for record in records]

        if drop_length:
            tracked = [strip_trajectory_length_vec(invert(points)) for points in tracked]
        else:
            tracked = [invert(points) for points in tracked]

        scores = []
        for position, points in enumerate(tracked):
            if len(points) > 0:
                scores.append(score(position, points))
            else:
                scores.append(0)
        return scores

    def load_curve(path, drop_length=True):
        with open(path, 'r') as file:
            records = json.load(file)
            values = hypervolume_per_record(records, drop_length)
            return pd.DataFrame(values, columns=['debug_track_policies'])

    non_stationary_curves = []
    stationary_curves = []
    trajectory_length_off_curves = []
    average_reward_off_curves = []
    cycle_detection_off_curves = []

    for seed in seeds:
        non_stationary_curves.append(
            load_curve(f"{task}/results/as1_as2/{task}/mpec_discrete/{seed}/as1_as2_NS_debug.json"))
        stationary_curves.append(
            load_curve(f"{task}/results/as1_as2/{task}/mpec_discrete/{seed}/as1_as2_S_debug.json"))
        # This ablation keeps the last component when computing the hypervolume.
        trajectory_length_off_curves.append(
            load_curve(f"{task}/results/as3/{task}/mpec_discrete/{seed}/as3_trajectory_length_off_debug.json",
                       drop_length=False))
        average_reward_off_curves.append(
            load_curve(f"{task}/results/as4/{task}/mpec_discrete/{seed}/as4_average_reward_off_debug.json"))
        cycle_detection_off_curves.append(
            load_curve(f"{task}/results/as5/{task}/mpec_discrete/{seed}/as5_ssm_off_debug.json"))

    return (non_stationary_curves, stationary_curves, trajectory_length_off_curves,
            average_reward_off_curves, cycle_detection_off_curves)


def get_ablation_figure_3_data(as6_methods, as7_methods):
    """Load trajectory lengths and promised/realized hypervolumes for figure 3.

    Args:
        as6_methods: Method names of the "as6" ablation (e.g. ``"l=2"``).
        as7_methods: Method names of the "as7" ablation (e.g. ``"OFF"``).
            Names are turned into file-name tokens by replacing ``=`` with
            ``_`` and lower-casing.

    Returns:
        ``(df_returns, df_lengths)``. ``df_returns`` has one row per
        (method, seed) with ``desired_return`` and ``collected_return``
        hypervolumes; ``df_lengths`` has one row per recorded trajectory with
        ``method``, ``seed`` and integer ``length`` columns. Methods are
        ordered as ``as6_methods`` followed by ``as7_methods`` and seeds
        follow the module-level ``seeds``.

    Notes:
        An empty as7 result file yields a NaN row rather than no row, so the
        rows stay aligned with their (method, seed) pair.
    """
    def file_token(method):
        return method.replace('=', '_').lower()

    def load_lengths(path):
        # The last item of the debug file holds the recorded lengths.
        with open(path, 'r') as file:
            debug = json.load(file)
        return pd.DataFrame(debug[-1]).to_numpy().flatten().tolist()

    def get_hypervolume(data):
        df = pd.DataFrame(data)
        # Negated last component of each collected return, truncated to int;
        # the smallest of them is the common scale for this file.
        number_of_steps_series = df['collected_return'].apply(lambda x: int(np.array(x)[-1]*-1))
        number_of_steps_series.name = "number_of_steps"
        number_of_steps = number_of_steps_series.min()
        df = pd.concat([df, number_of_steps_series], axis=1)
        df['desired_return'] = df['chosen_return'].apply(lambda x: list((np.array(x)/-np.array(x)[-1]) * number_of_steps))
        df['collected_return'] = df['collected_return'].apply(lambda x: list((np.array(x) / -np.array(x)[-1]) * number_of_steps))
        df[['desired_return', 'collected_return']] = df[['desired_return', 'collected_return']].map(scale).map(strip_trajectory_length)

        # Keep only unique, non-dominated promised returns.
        desired_return_array = np.array(df['desired_return'].values.tolist())
        non_dominated_indices = non_dominated_mask(desired_return_array)
        duplicated_indices = duplicated_mask(desired_return_array)
        # Copy so the assignments below act on an independent frame, not a slice.
        df = df[non_dominated_indices & ~duplicated_indices].copy()

        df[['desired_return', 'collected_return']] = df[['desired_return', 'collected_return']].map(invert)
        df[['desired_return', 'collected_return']] = df[['desired_return', 'collected_return']].map(lambda x: list(np.array(x) * number_of_steps))

        hv = Hypervolume(ref_point=ref_point)
        df_hypervolume = df[['desired_return', 'collected_return']].apply(lambda col: hv(np.vstack(col.values))).to_frame().T
        return df_hypervolume

    def load_hypervolume(path, allow_empty=False):
        with open(path, 'r') as file:
            records = json.load(file)
        if allow_empty and not len(records):
            # A zero-row frame would vanish in the concat below and shift all
            # later rows onto the wrong (method, seed); keep a NaN row instead.
            return pd.DataFrame({'desired_return': [np.nan], 'collected_return': [np.nan]})
        return get_hypervolume(records)

    # Trajectory lengths, flat in (method, seed) order.
    lengths_per_run = []
    for method in as6_methods:
        token = file_token(method)
        for seed in seeds:
            lengths_per_run.append(
                load_lengths(f"{task}/results/as6/{task}/mpec_discrete/{seed}/as6_{token}_debug.json"))
    for method in as7_methods:
        token = file_token(method)
        for seed in seeds:
            lengths_per_run.append(
                load_lengths(f"{task}/results/as7/{task}/mpec_discrete/{seed}/as7_ssm_{token}_debug.json"))

    # Hypervolumes, in the same (method, seed) order.
    returns_per_run = []
    for method in as6_methods:
        token = file_token(method)
        for seed in seeds:
            returns_per_run.append(
                load_hypervolume(f"{task}/results/as6/{task}/mpec_discrete/{seed}/as6_{token}.json"))
    for method in as7_methods:
        token = file_token(method)
        for seed in seeds:
            returns_per_run.append(
                load_hypervolume(f"{task}/results/as7/{task}/mpec_discrete/{seed}/as7_ssm_{token}.json",
                                 allow_empty=True))

    all_methods = [*as6_methods, *as7_methods]

    run_index = pd.DataFrame(itertools.product(all_methods, seeds), columns=["method", "seed"])
    df_lengths = pd.concat([run_index, pd.DataFrame({'length': lengths_per_run})], axis=1).explode('length', ignore_index=True)
    df_lengths['length'] = df_lengths['length'].astype(int)

    run_index = pd.DataFrame(itertools.product(all_methods, seeds), columns=["method", "seed"])
    df_returns = pd.concat([run_index, pd.concat(returns_per_run, ignore_index=True)], axis=1)

    return df_returns, df_lengths


def get_ablation_figure_4_data():
    """Load per-seed split and mismatch counters of the two "as8" update policies.

    Returns:
        ``(splits_later, splits_reconnect, mismatches_later,
        mismatches_reconnect)``, each a list with one Series per seed taken
        from the ``debug_trajectory_splits`` / ``debug_trajectory_mismatches``
        columns of the debug files.
    """
    def load_counters(path):
        with open(path, 'r') as file:
            debug_frame = pd.DataFrame(json.load(file))
        splits = debug_frame.loc[:, 'debug_trajectory_splits']
        mismatches = debug_frame.loc[:, 'debug_trajectory_mismatches']
        return splits, mismatches

    splits_later, mismatches_later = [], []
    splits_reconnect, mismatches_reconnect = [], []

    for seed in seeds:
        splits, mismatches = load_counters(
            f"{task}/results/as8/{task}/mpec_discrete/{seed}/as8_split_and_later_update_debug.json")
        splits_later.append(splits)
        mismatches_later.append(mismatches)

        splits, mismatches = load_counters(
            f"{task}/results/as8/{task}/mpec_discrete/{seed}/as8_split_reconnect_update_debug.json")
        splits_reconnect.append(splits)
        mismatches_reconnect.append(mismatches)

    return (splits_later, splits_reconnect,
            mismatches_later, mismatches_reconnect)


# ---- Script configuration -------------------------------------------------
# These module-level names are read directly by the functions above.
# `task` selects the experiment and is also the root folder of the result
# files and of the saved figures.
task = "deep-sea-treasure-v0"
# One result sub-folder is read per seed.
seeds = range(20)
# Method names of figure 3 in x-axis order; the last two are the SSM variants.
ablation_figure_3_methods = ["l=2", "l=5", "l=10", "l=20", "l=30", "OFF", "ON"]
# Step budget used by `pad_stationary_cycles` and by the reference point below.
NUMBER_OF_STEPS=30

# Reference point handed to pymoo's `Hypervolume` indicator.
# NOTE: `ref_point` is only bound for the two tasks listed here; any other
# `task` value leaves it undefined.
if task == "fruit-tree-v0":
    ref_point = (1, 1, 1, 1, 1, 1)
elif task == 'deep-sea-treasure-v0':
    ref_point = (0, NUMBER_OF_STEPS + 1)

# Split the figure-3 methods into the "as6" group (all but the last two) and
# the "as7" group (the last two).
methods = ablation_figure_3_methods
as6_methods = methods[:-2]
as7_methods = methods[-2:]

# ---- Figure generation ----------------------------------------------------
# Deep-sea-treasure: only figure 2 is currently generated; the calls for
# figures 1, 3 and 4 are commented out.
if task == 'deep-sea-treasure-v0':
    # figure_1_data = get_ablation_figure_1_data()
    # plot_ablation_figure_1(*figure_1_data, figure_label='a)')

    figure_2_data = get_ablation_figure_2_data()
    plot_ablation_figure_2(*figure_2_data, figure_label='b)')
    #
    # figure_3_data = get_ablation_figure_3_data(as6_methods, as7_methods)
    # plot_ablation_figure_3(*figure_3_data, methods, as6_methods, as7_methods, figure_label='c)')
    #
    # figure_4_data = get_ablation_figure_4_data()
    # plot_ablation_figure_4(*figure_4_data, figure_label='d)')

# Fruit-tree: figures 2 and 3 are generated and labelled 'a)' and 'b)'.
if task == 'fruit-tree-v0':
    figure_2_data = get_ablation_figure_2_data()
    plot_ablation_figure_2(*figure_2_data, figure_label='a)')

    figure_3_data = get_ablation_figure_3_data(as6_methods, as7_methods)
    plot_ablation_figure_3(*figure_3_data, methods, as6_methods, as7_methods, figure_label='b)')
