import os
import pickle
import random
from typing import List, Dict, Union

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.patches import Patch
import wandb
from tqdm import tqdm

from rliable import library as rly
from rliable import metrics
from rliable import plot_utils


# Make sure the directories this script uses exist before anything else runs:
# "out" receives figures, "bin" holds the pickled scores, "tables" receives
# the generated LaTeX tables.
for _dir_name in ("out", "bin", "tables"):
    os.makedirs(_dir_name, exist_ok=True)

# import matplotlib.font_manager as fm
# font_path = "/usr/share/fonts/truetype/gentiumplus/GentiumPlus-Regular.ttf"
# custom_font = fm.FontProperties(fname=font_path)
# plt.rcParams['font.family'] = custom_font.get_name()
#
# fm.findSystemFonts(fontpaths=None, fontext='ttf')
# font_dirs = ["/usr/share/fonts/truetype/gentiumplus/GentiumPlus-Regular.ttf"]
# font_files = fm.findSystemFonts(fontpaths=font_dirs)
#
# for font_file in font_files:
#     fm.fontManager.addfont(font_file)
# print([f.name for f in fm.fontManager.ttflist])

# set font
# plt.rcParams['font.family'] = "Gentium Plus"

# plt.rcParams['figure.figsize'] = (10, 6)
# Global plotting defaults: high-resolution figures and seaborn's "ticks" theme.
# The plotting functions below each call sns.set() again with their own style
# and font scale, so the seaborn settings chosen here are overridden as soon
# as one of them runs.
plt.rcParams['figure.dpi'] = 300
sns.set(style="ticks", font_scale=0.75)

# from matplotlib.font_manager import fontManager, FontProperties
# path = "/usr/share/fonts/truetype/gentiumplus/GentiumPlus-Regular.ttf"
# fontManager.addfont(path)
# prop = FontProperties(fname=path)
# sns.set(font=prop.get_name(), style="whitegrid")

# Preferred typeface for the serif font family.
# NOTE(review): matplotlib only consults 'font.serif' when 'font.family'
# resolves to 'serif', and no active line in this file selects a serif
# family — this setting may have no visible effect; confirm.
plt.rcParams.update({
    # 'font.family': 'gentiumplus',
    'font.serif': 'Geologica'
})
# plt.rcParams['font.family'] = 'gentiumplus'
# plt.rcParams['font.family'] = "Gentium Plus"


# plt.text(0.5, 0.5, 'Custom Font Example', fontsize=20)
# plt.xlabel("X axis")
# plt.ylabel("Y axis")
# plt.show()

# Colour palette shared by all plots. Colours are matched to algorithms by
# position, so only the first four algorithms of a list can receive a colour.
dunno_colors = ['#2F3677', '#D56131', '#659157', '#650D1B']

def convert_scores_list_to_tex(score, algo):
    """Format one LaTeX table row for ``algo``.

    The row starts with an empty leading cell, followed by the algorithm name
    in bold (wrapped in an extra brace group), then one cell per entry of
    ``score`` (converted with ``str``), and ends with a LaTeX row break.
    """
    leading_cells = "".join(("& \\textbf{{", algo, "}} & "))
    score_cells = " & ".join(str(value) for value in score)
    return leading_cells + score_cells + " \\\\"

# Load the cached per-algorithm scores from disk.
# NOTE(review): pickle.load can execute arbitrary code contained in the file,
# so only load pickles that were produced locally / are trusted.
# NOTE(review): `profiles_data` is rebuilt from this same file (with renamed
# algorithm keys) further down, before any function is called, so this early
# load looks redundant — confirm before removing.
with open('bin/scores.pickle', 'rb') as handle:
    profiles_data = pickle.load(handle)


def extract_metric(data, split, metric, dataset_filter="", prefix='eval'):
    """Pull one metric out of the nested score dictionary.

    ``data`` is indexed as ``data[algo][dataset][key]`` where ``key`` is
    ``"{prefix}/{split}_{metric}"``. The result keeps the same
    ``algo -> dataset`` nesting but stores only the value found under that key.

    ``dataset_filter`` selects which datasets are kept:
      * ``""``     - keep every dataset;
      * ``"full"`` - keep datasets whose name contains none of
                     ``early`` / ``mid`` / ``late``;
      * otherwise  - keep datasets whose name contains the filter string.

    Every algorithm of ``data`` appears in the result, even if none of its
    datasets pass the filter.
    """
    metric_key = f'{prefix}/{split}_{metric}'

    def is_selected(dataset_name):
        if dataset_filter == "":
            return True
        if dataset_filter == 'full':
            return not any(stage in dataset_name for stage in ('early', 'mid', 'late'))
        return dataset_filter in dataset_name

    return {
        algo: {
            dataset_name: data[algo][dataset_name][metric_key]
            for dataset_name in data[algo]
            if is_selected(dataset_name)
        }
        for algo in data
    }


def flatten(data, target_lens=4):
    """Stack the per-dataset score sequences of ``data`` into a numpy array.

    ``data`` maps a dataset name to a sequence of scores. Each sequence becomes
    one entry of the result, in the mapping's iteration order (so equal-length
    sequences give a 2-D array with one row per dataset).

    ``target_lens`` is accepted but not used by this function.
    """
    rows = [list(data[dataset_name]) for dataset_name in data]
    return np.array(rows)

def plot_metrics(profiles_data, algorithms, split, metric, metric_name, dataset_filter="", pi_range=(0.4, 0.9), main_algo="AD", div_trials=1, profiles=True, min_val=0, max_val=1, suffix=""):
    """Plot rliable aggregate metrics and (optionally) performance profiles.

    Two PDFs are written to ``out/``:
      * ``metrics_<tag>.pdf``       - median / IQM / mean interval estimates;
      * ``perf_profiles_<tag>.pdf`` - performance profiles (only if ``profiles``).
    ``<tag>`` is ``{split}_{metric}`` followed by ``_{dataset_filter}`` and
    ``_{suffix}`` when those are non-empty.

    Args:
        profiles_data: nested scores, ``algo -> dataset -> key -> values``.
        algorithms: algorithms to plot; colours are assigned by position.
        split, metric: select the ``eval/{split}_{metric}`` entry.
        metric_name: axis label used in both figures.
        dataset_filter: forwarded to ``extract_metric``.
        pi_range, main_algo: not used by the active code of this function.
        div_trials: divides the number of bootstrap repetitions (10000).
        profiles: whether to draw the performance-profile figure.
        min_val, max_val: range of the 50 profile thresholds.
        suffix: optional extra tag appended to the output file names.
    """
    palette = dict(zip(algorithms, dunno_colors))

    sns.set(style="ticks", font_scale=0.5)
    algorithms = list(algorithms)
    per_algo_scores = extract_metric(profiles_data, split, metric, dataset_filter=dataset_filter)

    # Assemble the tag used in the output file names.
    tag_parts = [f"{split}_{metric}"]
    if dataset_filter != "":
        tag_parts.append(dataset_filter)
    if len(suffix) > 0:
        tag_parts.append(suffix)
    file_tag = "_".join(tag_parts)

    score_matrices = {}
    for algo in per_algo_scores:
        score_matrices[algo] = flatten(per_algo_scores[algo])

    # Transposed so that the first axis is no longer the dataset axis
    # (presumably runs x datasets, as rliable expects — verify against data).
    scores_by_algo = {algo: score_matrices[algo].T for algo in algorithms}

    def aggregate_func(scores):
        return np.array([
            metrics.aggregate_median(scores),
            metrics.aggregate_iqm(scores),
            metrics.aggregate_mean(scores),
        ])

    point_estimates, interval_estimates = rly.get_interval_estimates(
        scores_by_algo, aggregate_func, reps=10000 // div_trials)
    plot_utils.plot_interval_estimates(
        point_estimates, interval_estimates,
        metric_names=['Median', 'IQM', 'Mean'],
        algorithms=algorithms, xlabel=metric_name, colors=palette)
    plt.tight_layout()
    plt.savefig(f"out/metrics_{file_tag}.pdf", dpi=300, bbox_inches='tight')
    plt.close()

    if profiles:
        sns.set(style="ticks", font_scale=0.95)
        taus = np.linspace(min_val, max_val, 50)
        profile_values, profile_cis = rly.create_performance_profile(
            scores_by_algo, taus)

        fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
        plot_utils.plot_performance_profiles(
            profile_values, taus,
            performance_profile_cis=profile_cis,
            colors=palette,
            xlabel=metric_name + r' $(\tau)$',
            ax=ax,
            legend=True,
        )
        # Re-apply the x label with a larger font size.
        ax.set_xlabel(ax.get_xlabel(), fontsize=12)
        plt.savefig(f"out/perf_profiles_{file_tag}.pdf", dpi=300, bbox_inches='tight')
        plt.close()

    # if type(main_algo) is not list:
    #     algorithms.remove(main_algo)
    #     algorithm_pairs = {
    #         f'{k},{main_algo}': (flat_profiles_data[k].T, flat_profiles_data[main_algo].T) for k in algorithms
    #     }
    # else:
    #     algorithm_pairs = {}
    #     for ma in main_algo:
    #         algorithms.remove(ma)
    #         for alg in algorithms:
    #             if ma in alg:
    #                 algorithm_pairs[f'{alg},{ma}'] = (flat_profiles_data[alg].T, flat_profiles_data[ma].T)
    #
    # average_probabilities, average_prob_cis = rly.get_interval_estimates(
    #   algorithm_pairs, metrics.probability_of_improvement, reps=500 // div_trials)
    # ax = plot_utils.plot_probability_of_improvement(average_probabilities, average_prob_cis)
    # ax.set_xlim(pi_range[0], pi_range[1])
    # # plt.show()
    # plt.savefig(f"out/improvement_probability_{suffix}.pdf", dpi=300, bbox_inches='tight')
    # plt.close()


def get_table(profiles_data, algorithms, split, metric, dataset_filter="", save_suffix="", prefix="eval"):
    """Write a LaTeX table of per-dataset "mean +/- std" scores.

    One row is produced per dataset (taken from the first algorithm's
    entries) and one column per algorithm, followed by an "Average" row.
    The table is written to
    ``tables/{split}_{metric}_{dataset_filter}_{save_suffix}_{prefix}.tex``.

    Args:
        profiles_data: nested scores, ``algo -> dataset -> key -> values``.
        algorithms: algorithms to include, in column order.
        split, metric, prefix: select the ``{prefix}/{split}_{metric}`` entry.
        dataset_filter: forwarded to ``extract_metric``.
        save_suffix: extra tag used in the output file name.
    """
    algorithms = list(algorithms)
    extracted_metric = extract_metric(profiles_data, split, metric, dataset_filter=dataset_filter, prefix=prefix)
    datasets = list(extracted_metric[algorithms[0]])

    # One right-aligned column per algorithm instead of a fixed "rrrr".
    column_spec = "l|" + "r" * len(algorithms)

    # Backslashes are escaped explicitly: sequences such as "\c" or "\e" are
    # invalid string escapes and trigger warnings on recent Python versions.
    table_text = (
        "\\begin{table}[ht]\n"
        "    \\label{tab:}\n"
        "    \\begin{center}\n"
        "    \\caption{}\n"
        "    \\begin{small}\n"
        "    \\begin{adjustbox}{max width=\\columnwidth}\n"
        f"\t\t\\begin{{tabular}}{{{column_spec}}}\n"
        "\t\t\\toprule\n"
        "\t"
    )
    table_text += (
        "\\textbf{Dataset} & "
        + " & ".join("\\textbf{" + alg + "}" for alg in algorithms)
        + "\\\\\n"
        + "\\midrule\n"
    )

    for ds in datasets:
        row_label = (
            ds.replace('goals_', '')
            .replace('data_', '')
            .replace('_histories', '')
            .replace('_', '-')
            .replace('k2d', 'K2D9')
            .replace('DR-', 'DR9-')
        )
        cells = [row_label]
        for alg in algorithms:
            scores = extracted_metric[alg][ds]
            cells.append(f"{np.mean(scores):1.2f} $\\pm$ {np.std(scores):1.2f}")
        table_text += " & ".join(cells) + "\\\\\n"

    table_text += "\\midrule\n"
    table_text += "Average"
    for alg in algorithms:
        # Average of the per-dataset means shown above. This no longer assumes
        # exactly four runs per dataset and does not divide by zero when the
        # filter leaves no datasets.
        dataset_means = [np.mean(extracted_metric[alg][ds]) for ds in datasets]
        average = np.mean(dataset_means) if dataset_means else float("nan")
        table_text += f" & {average:1.2f}"
    table_text += "\\\\\n"

    table_text += (
        "\\end{tabular}\n"
        "        \\end{adjustbox}\n"
        "    \\end{small}\n"
        "    \\end{center}\n"
        "    \\vskip -0.1in\n"
        "\\end{table}\n"
        "    "
    )

    with open(f"tables/{split}_{metric}_{dataset_filter}_{save_suffix}_{prefix}.tex", "w") as f:
        f.write(table_text)


def plot_quadrangle(data, algorithms, split, plot_name, dataset_filter="", save_suffix="", metrics=('NAUC', 'Return after 100 episodes', 'Return after 50 episodes', 'Return after 25 episodes')):
    """Draw a four-axis radar ("quadrangle") chart comparing algorithms.

    For each algorithm the chart shows the average over datasets of
    ``eval/{split}_auc``, ``eval/{split}_mean_return``,
    ``eval/{split}_mean_return_half`` and ``eval/{split}_mean_return_quarter``,
    in that order; ``metrics`` supplies the axis labels in the same order.
    The figure is saved as ``out/quad_{save_suffix}{split}[_{dataset_filter}].pdf``.

    Args:
        data: nested scores, ``algo -> dataset -> key -> values``.
        algorithms: algorithms to draw; colours are assigned by position.
        split: data split used to build the metric keys.
        plot_name: figure title.
        dataset_filter: forwarded to ``extract_metric``.
        save_suffix: prefix of the tag used in the output file name.
        metrics: axis labels. Four values are plotted per algorithm, so four
            labels are expected. (Inside this function the name shadows the
            module-level ``rliable.metrics`` import.)
    """
    sns.set(style="whitegrid", font_scale=1.5)
    algorithms = list(algorithms)

    suffix = save_suffix + f"{split}"
    if dataset_filter != "":
        suffix += "_" + dataset_filter

    def dataset_average(metric_key):
        """Average the selected metric over all datasets, per algorithm."""
        extracted = extract_metric(data, split, metric_key, dataset_filter=dataset_filter)
        return {
            alg: np.mean([extracted[alg][ds] for ds in extracted[alg]])
            for alg in algorithms
        }

    auc = dataset_average("auc")
    ep100 = dataset_average("mean_return")
    ep50 = dataset_average("mean_return_half")
    ep25 = dataset_average("mean_return_quarter")

    performance = {
        alg: [auc[alg], ep100[alg], ep50[alg], ep25[alg]]
        for alg in algorithms
    }

    # Values are divided by the algorithm's own largest value, but only when
    # that value exceeds 1; otherwise they are left unchanged.
    # NOTE(review): this rescales each algorithm independently — confirm that
    # is intended when comparing algorithms on the same axes.
    normalized_performance = {
        algo: np.array(values) / max(max(values), 1)
        for algo, values in performance.items()
    }

    # One axis per metric label, starting at 45 degrees; the first angle is
    # repeated at the end so every polygon closes.
    angles = np.linspace(np.pi / 4, 2 * np.pi + np.pi / 4, len(metrics), endpoint=False).tolist()
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_aspect('equal')

    # One outlined polygon per algorithm.
    for algo, edge_color in zip(algorithms, dunno_colors):
        values = normalized_performance[algo]
        values = np.append(values, values[0])  # Close the loop
        x = values * np.cos(angles)
        y = values * np.sin(angles)

        polygon = Polygon(
            np.c_[x, y], closed=True, fill=False, edgecolor=edge_color, linewidth=2, label=algo
        )
        ax.add_patch(polygon)
        ax.plot(x, y, color=edge_color, linewidth=2, alpha=0.6)

    # Radial axis lines with tick marks and numeric labels.
    tick_values = np.linspace(0.2, 1., 5)
    for angle, _label in zip(angles, metrics):
        x_end = np.cos(angle)
        y_end = np.sin(angle)
        ax.plot([0, x_end], [0, y_end], color='gray', linewidth=1, linestyle='dotted')

        for tick in tick_values:
            x_tick = tick * x_end
            y_tick = tick * y_end
            ax.scatter(x_tick, y_tick, color='gray', s=10)
            ax.text(
                x_tick * 1.02, y_tick * 1.02, f'{tick:.1f}',
                ha='center', va='center', fontsize=12, color='black'
            )

    # Metric labels just outside the unit circle.
    for angle, label in zip(angles, metrics):
        ax.text(1.1 * np.cos(angle), 1.1 * np.sin(angle), label,
                ha='center', va='center', fontsize=14, weight='bold')

    # Concentric reference circles at every tick level.
    for tick in tick_values:
        circle = plt.Circle((0, 0), radius=tick, color='gray', fill=False, linestyle='dotted', alpha=0.7)
        ax.add_artist(circle)

    plt.legend(loc='center left', bbox_to_anchor=(0.85, 0.5))
    ax.set_title(plot_name, fontsize=14)

    # The background image that used to be read here was never drawn, so the
    # read was dropped: it only made the function fail when the file was absent.

    ax.axis('off')
    plt.tight_layout()
    plt.savefig(f"out/quad_{suffix}.pdf", dpi=300, bbox_inches='tight')
    plt.close()
    # plt.show()


def plot_janus(data, algorithms, split, plot_name, dataset_filter="", save_suffix=""):
    """Draw a four-axis radar chart of the Janus NAUC metrics per algorithm.

    The four axes show, averaged over datasets and in this order:
      1. ``default/{split}_auc``         labelled "Dynamic 2 NAUC, isolated"
      2. ``inverted/{split}_auc``        labelled "Dynamic 1 NAUC, isolated"
      3. ``inverted/{split}_auc_janus``  labelled "Dynamic 1 NAUC"
      4. ``default/{split}_auc_janus``   labelled "Dynamic 2 NAUC"
    NOTE(review): this key-to-label pairing is carried over unchanged from the
    previous implementation — confirm it matches the experiment definitions.

    The figure is saved as ``out/janus_{save_suffix}{split}[_{dataset_filter}].pdf``.

    Args:
        data: nested scores, ``algo -> dataset -> key -> values``.
        algorithms: algorithms to draw; colours are assigned by position.
        split: data split used to build the metric keys.
        plot_name: figure title.
        dataset_filter: forwarded to ``extract_metric``.
        save_suffix: prefix of the tag used in the output file name.
    """
    sns.set(style="whitegrid", font_scale=1.5)
    algorithms = list(algorithms)

    # (axis label, key prefix, metric) for each axis, in drawing order.
    axes_spec = [
        ('Dynamic 2 NAUC, isolated', 'default', 'auc'),
        ('Dynamic 1 NAUC, isolated', 'inverted', 'auc'),
        ('Dynamic 1 NAUC', 'inverted', 'auc_janus'),
        ('Dynamic 2 NAUC', 'default', 'auc_janus'),
    ]
    axis_labels = [label for label, _, _ in axes_spec]

    suffix = save_suffix + f"{split}"
    if dataset_filter != "":
        suffix += "_" + dataset_filter

    # Per-axis mapping: algorithm -> average of the metric over its datasets.
    axis_averages = []
    for _, key_prefix, metric_key in axes_spec:
        extracted = extract_metric(data, split, metric_key, dataset_filter=dataset_filter, prefix=key_prefix)
        axis_averages.append({
            alg: np.mean([extracted[alg][ds] for ds in extracted[alg]])
            for alg in algorithms
        })

    performance = {
        alg: [averages[alg] for averages in axis_averages]
        for alg in algorithms
    }

    # Values are divided by the algorithm's own largest value, but only when
    # that value exceeds 1; otherwise they are left unchanged.
    normalized_performance = {
        algo: np.array(values) / max(max(values), 1)
        for algo, values in performance.items()
    }

    # One axis per label, starting at 45 degrees; the first angle is repeated
    # at the end so every polygon closes.
    angles = np.linspace(np.pi / 4, 2 * np.pi + np.pi / 4, len(axis_labels), endpoint=False).tolist()
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_aspect('equal')

    # One outlined polygon per algorithm.
    for algo, edge_color in zip(algorithms, dunno_colors):
        values = normalized_performance[algo]
        values = np.append(values, values[0])  # Close the loop
        x = values * np.cos(angles)
        y = values * np.sin(angles)

        polygon = Polygon(
            np.c_[x, y], closed=True, fill=False, edgecolor=edge_color, linewidth=2, label=algo
        )
        ax.add_patch(polygon)
        ax.plot(x, y, color=edge_color, linewidth=2, alpha=0.6)

    # Radial axis lines with tick marks and numeric labels.
    tick_values = np.linspace(0.05, 0.5, 5)
    for angle, _label in zip(angles, axis_labels):
        x_end = np.cos(angle)
        y_end = np.sin(angle)
        ax.plot([0, x_end * 0.55], [0, y_end * 0.55], color='gray', linewidth=1, linestyle='dotted')

        for tick in tick_values:
            x_tick = tick * x_end
            y_tick = tick * y_end
            ax.scatter(x_tick, y_tick, color='gray', s=10)
            # Two decimals: the tick levels are not multiples of 0.1, so a
            # one-decimal label would misstate the radius it is drawn at
            # (e.g. the 0.05 ring would read "0.1").
            ax.text(
                x_tick * 1.02, y_tick * 1.02, f'{tick:.2f}',
                ha='center', va='center', fontsize=12, color='black'
            )

    # Axis labels just outside the largest reference circle.
    for angle, label in zip(angles, axis_labels):
        ax.text(0.55 * np.cos(angle), 0.55 * np.sin(angle), label,
                ha='center', va='center', fontsize=14, weight='bold')

    # Concentric reference circles at every tick level.
    for tick in tick_values:
        circle = plt.Circle((0, 0), radius=tick, color='gray', fill=False, linestyle='dotted', alpha=0.7)
        ax.add_artist(circle)

    plt.legend(loc='center left', bbox_to_anchor=(0.85, 0.5))
    ax.set_title(plot_name, fontsize=14)

    # The background image that used to be read here was never drawn, so the
    # read was dropped: it only made the function fail when the file was absent.

    ax.axis('off')
    plt.tight_layout()
    plt.savefig(f"out/janus_{suffix}.pdf", dpi=300, bbox_inches='tight')
    plt.close()
    # plt.show()


def coverage_plots(profiles_data, algorithms, split, metric, env, title, metric_name, legend=True):
    """Bar chart of a metric versus number of train targets and coverage.

    For one environment, every algorithm gets two adjacent bars per dataset
    size: one for the datasets whose name ends in ``-5`` and a hatched one for
    those ending in ``-1`` (labelled "5 histories/target" and
    "1 histories/target" in the legend). Bar height is the mean and the error
    bar the standard deviation of the stored values. Only "full" datasets
    (as defined by ``extract_metric``) are considered. The figure is saved as
    ``out/coverage_{env}_{metric}_{split}.pdf`` using the normalised
    environment name.

    Args:
        profiles_data: nested scores, ``algo -> dataset -> key -> values``.
        algorithms: algorithms to draw; colours are assigned by position.
        split, metric: select the ``eval/{split}_{metric}`` entry.
        env: first ``_``-separated token of the dataset names to include
            (e.g. ``"k2d"``, ``"K2D13"``, ``"DR"``, ``"DR19"``).
        title: figure title.
        metric_name: y-axis label.
        legend: whether to draw the legend.

    Raises:
        ValueError: if ``env`` is not one of the supported environments.
    """
    sns.set(style="whitegrid", font_scale=1.75)
    algorithms = list(algorithms)
    extracted_metric = extract_metric(profiles_data, split, metric, dataset_filter="full")

    # Keep only this environment's datasets and index them by display name.
    filtered_metric = {}
    for alg in algorithms:
        filtered_metric[alg] = {}
        for ds in extracted_metric[alg]:
            if ds.split('_')[0] != env:
                continue
            edited_name = (
                ds.replace('goals_', '')
                .replace('data_', '')
                .replace('_histories', '')
                .replace('_', '-')
                .replace('k2d', 'K2D9')
                .replace('DR-', 'DR9-')
            )
            filtered_metric[alg][edited_name] = extracted_metric[alg][ds]

    # Normalise the environment name the same way the dataset names were.
    env = env.replace('k2d', 'K2D9')
    if env == "DR":
        env = "DR9"

    if 'K2D' in env:
        sizes = ['250', '500', '1000']
    elif env == 'DR9':
        sizes = ['20', '40', '70']
    elif env == 'DR19':
        sizes = ['75', '150', '300']
    else:
        # Fail with a clear message instead of a TypeError further down.
        raise ValueError(
            f"Unsupported environment {env!r}: expected a K2D environment, 'DR'/'DR9' or 'DR19'"
        )

    # Mean / std per algorithm, coverage level and dataset size.
    coverages = ('5', '1')
    alg_data = {}
    for alg in algorithms:
        alg_data[alg] = {cov: {'means': [], 'stds': []} for cov in coverages}
        for sz in sizes:
            for cov in coverages:
                values = filtered_metric[alg][f"{env}-{sz}-{cov}"]
                alg_data[alg][cov]['means'].append(np.mean(values))
                alg_data[alg][cov]['stds'].append(np.std(values))

    # Plot parameters
    bar_width = 0.08
    n_algorithms = len(algorithms)
    x_base = np.arange(len(sizes))  # One group of bars per dataset size
    colors = dunno_colors
    single_history_hatch = '\\\\'
    # (coverage, x offset inside the algorithm's pair, hatch pattern)
    bar_styles = (('5', 0.0, None), ('1', bar_width, single_history_hatch))

    plt.figure(figsize=(12, 6))

    for size_idx, group_center in enumerate(x_base):
        # Centre the 2 * n_algorithms bars of this group on the tick.
        start_pos = group_center - (n_algorithms * bar_width) + bar_width / 2

        for alg_idx, alg in enumerate(algorithms):
            alg_pos = start_pos + alg_idx * bar_width * 2
            for cov, offset, hatch in bar_styles:
                plt.bar(alg_pos + offset,
                        alg_data[alg][cov]['means'][size_idx],
                        width=bar_width,
                        color=colors[alg_idx],
                        edgecolor='black',
                        hatch=hatch,
                        alpha=0.8,
                        yerr=alg_data[alg][cov]['stds'][size_idx],
                        error_kw={'capsize': 3})

    plt.xticks(x_base, sizes)
    plt.xlabel("Number of train targets")
    plt.ylabel(metric_name)
    plt.title(title)
    plt.grid(True, axis='y')

    # Legend: one colour patch per algorithm plus the two coverage styles.
    legend_elements = [
        *[Patch(facecolor=colors[i], edgecolor='black', label=alg) for i, alg in enumerate(algorithms)],
        Patch(facecolor='white', edgecolor='black', label='5 histories/target'),
        Patch(facecolor='white', edgecolor='black', hatch=single_history_hatch, label='1 histories/target')
    ]
    if legend:
        plt.legend(handles=legend_elements, loc='upper left')

    plt.tight_layout()
    plt.savefig(f"out/coverage_{env}_{metric}_{split}.pdf", dpi=300, bbox_inches='tight')
    plt.close()


def get_diffs(original_scores, order_scores, algorithms, data_filter=""):
    """Compute relative score changes of re-ordered datasets vs. their base.

    For every algorithm and every dataset of ``order_scores`` whose name
    contains ``data_filter``, the base dataset is the name with its last
    ``_``-separated token removed. Each metric ``k`` is mapped to
    ``(order - original) / original``, computed element-wise with numpy.

    Returns a dict ``algo -> dataset -> metric -> relative difference``.
    """
    diffs = {}
    for alg in algorithms:
        per_dataset = {}
        diffs[alg] = per_dataset
        for dataset in order_scores[alg]:
            if data_filter in dataset:
                # Drop the trailing "_<token>" to get the base dataset name.
                base_name = dataset.rpartition('_')[0]
                relative = {}
                for key in order_scores[alg][dataset]:
                    reordered = np.array(order_scores[alg][dataset][key])
                    baseline = original_scores[alg][base_name][key]
                    relative[key] = (reordered - baseline) / baseline
                per_dataset[dataset] = relative
    return diffs


def orders_plot(profiles_data, algorithms, split, metric, env, title, metric_name, legend=True):
    """Bar chart of a metric versus number of train targets and data ordering.

    For one environment, every algorithm gets two adjacent bars per dataset
    size: one for the ``...-1-random`` dataset ("Random") and one for the
    ``...-1-sample`` dataset ("Sorted Sample"). Bar height is the mean and the
    error bar the standard deviation of the stored values. Only "full"
    datasets (as defined by ``extract_metric``) are considered. The figure is
    saved as ``out/order_{env}_{metric}_{split}.pdf`` using the normalised
    environment name.

    Args:
        profiles_data: nested scores, ``algo -> dataset -> key -> values``.
        algorithms: algorithms to draw; colours are assigned by position.
        split, metric: select the ``eval/{split}_{metric}`` entry.
        env: first ``_``-separated token of the dataset names to include
            (e.g. ``"K2D13"``, ``"DR19"``; ``"k2d"`` and ``"DR"`` are
            normalised to ``K2D9`` / ``DR9``).
        title: figure title.
        metric_name: y-axis label.
        legend: whether to draw the legend.

    Raises:
        ValueError: if ``env`` is not one of the supported environments.
    """
    from matplotlib.ticker import MultipleLocator

    sns.set(style="whitegrid", font_scale=1.75)
    algorithms = list(algorithms)
    extracted_metric = extract_metric(profiles_data, split, metric, dataset_filter="full")

    # Keep only this environment's datasets and index them by display name.
    filtered_metric = {}
    for alg in algorithms:
        filtered_metric[alg] = {}
        for ds in extracted_metric[alg]:
            if ds.split('_')[0] != env:
                continue
            edited_name = (
                ds.replace('goals_', '')
                .replace('data_', '')
                .replace('_histories', '')
                .replace('_', '-')
                .replace('k2d', 'K2D9')
                .replace('DR-', 'DR9-')
            )
            filtered_metric[alg][edited_name] = extracted_metric[alg][ds]

    # Normalise the environment name the same way the dataset names were, as
    # coverage_plots does; without this "k2d" and "DR" could never match the
    # renamed dataset keys.
    env = env.replace('k2d', 'K2D9')
    if env == "DR":
        env = "DR9"

    if 'K2D' in env:
        sizes = ['250', '500', '1000']
    elif env == 'DR9':
        sizes = ['20', '40', '70']
    elif env == 'DR19':
        sizes = ['75', '150', '300']
    else:
        # Fail with a clear message instead of a TypeError further down.
        raise ValueError(
            f"Unsupported environment {env!r}: expected a K2D environment, 'DR'/'DR9' or 'DR19'"
        )

    # (legend label, dataset-name suffix, hatch pattern) per ordering category.
    categories = (
        ('Random', 'random', '-'),
        ('Sorted Sample', 'sample', '//'),
    )

    # Mean / std per algorithm, category and dataset size.
    alg_data = {}
    for alg in algorithms:
        alg_data[alg] = {label: {'means': [], 'stds': []} for label, _, _ in categories}
        for sz in sizes:
            for label, name_suffix, _ in categories:
                values = filtered_metric[alg][f"{env}-{sz}-1-{name_suffix}"]
                alg_data[alg][label]['means'].append(np.mean(values))
                alg_data[alg][label]['stds'].append(np.std(values))

    # Plot parameters
    bar_width = 0.08
    n_algorithms = len(algorithms)
    n_categories = len(categories)
    x_base = np.arange(len(sizes))  # One group of bars per dataset size
    colors = dunno_colors

    plt.figure(figsize=(12, 6))

    for size_idx, group_center in enumerate(x_base):
        # Bars are positioned by their centre, so the first bar sits half a
        # bar width right of the group's left edge; this keeps the whole
        # group centred on its tick.
        start_pos = group_center - (n_algorithms * n_categories * bar_width) / 2 + bar_width / 2

        for alg_idx, alg in enumerate(algorithms):
            alg_pos = start_pos + alg_idx * n_categories * bar_width
            for cat_idx, (label, _, hatch) in enumerate(categories):
                plt.bar(
                    alg_pos + cat_idx * bar_width,
                    alg_data[alg][label]['means'][size_idx],
                    width=bar_width,
                    color=colors[alg_idx],
                    edgecolor='black',
                    hatch=hatch,
                    alpha=0.8,
                    yerr=alg_data[alg][label]['stds'][size_idx],
                    error_kw={'capsize': 3}
                )

    plt.xticks(x_base, sizes)
    plt.xlabel("Number of train targets")
    plt.ylabel(metric_name)
    plt.title(title)
    plt.grid(True, axis='y')
    # Major y ticks (and therefore horizontal grid lines) every 0.05.
    ax = plt.gca()
    ax.yaxis.set_major_locator(MultipleLocator(0.05))

    # Legend: one colour patch per algorithm plus one hatch patch per category.
    legend_elements = [
        *[Patch(facecolor=colors[i], edgecolor='black', label=alg)
          for i, alg in enumerate(algorithms)],
        *[Patch(facecolor='white', edgecolor='black', hatch=hatch, label=label)
          for label, _, hatch in categories],
    ]
    if legend:
        plt.legend(handles=legend_elements, loc='upper left')

    plt.tight_layout()
    plt.savefig(f"out/order_{env}_{metric}_{split}.pdf", dpi=300, bbox_inches='tight')
    plt.close()


algorithms = ['AD', 'IC-DQN', 'IC-CQL', 'IC-IQL']
with open('bin/scores.pickle', 'rb') as handle:
    pre_profiles_data = pickle.load(handle)

# Re-key the loaded scores: "AD" keeps its name, every other algorithm gets
# an "IC-" prefix so the keys line up with the names in `algorithms`.
profiles_data = {}
for alg in pre_profiles_data:
    display_name = alg if alg == "AD" else "IC-" + alg
    profiles_data[display_name] = pre_profiles_data[alg]
#
# plot_metrics(profiles_data, algorithms, "test", "auc", "Test NAUC")
# plot_metrics(profiles_data, algorithms, "train", "auc", "Train NAUC")
# plot_metrics(profiles_data, algorithms, "test", "mean_return", "Test 100 episodes mean return", max_val=2)
# plot_metrics(profiles_data, algorithms, "train", "mean_return", "Train 100 episodes mean return", max_val=2)
# plot_metrics(profiles_data, algorithms, "test", "auc", "Test NAUC, early datasets", dataset_filter='early')
# plot_metrics(profiles_data, algorithms, "test", "auc", "Test NAUC, mid datasets", dataset_filter='mid')
# plot_metrics(profiles_data, algorithms, "test", "auc", "Test NAUC, late datasets", dataset_filter='late')
# plot_metrics(profiles_data, algorithms, "test", "auc", "Test NAUC, complete datasets", dataset_filter='full')
# #
# plot_metrics(profiles_data, algorithms, "test", "mean_return", "Test 100 episodes mean return, early datasets", dataset_filter='early', max_val=2)
# plot_metrics(profiles_data, algorithms, "test", "mean_return", "Test 100 episodes mean return, mid datasets", dataset_filter='mid', max_val=2)
# plot_metrics(profiles_data, algorithms, "test", "mean_return", "Test 100 episodes mean return, late datasets", dataset_filter='late', max_val=2)
# plot_metrics(profiles_data, algorithms, "test", "mean_return", "Test 100 episodes mean return, complete datasets", dataset_filter='full', max_val=2)


# get_table(profiles_data, algorithms, "test", "auc", dataset_filter="full")
# get_table(profiles_data, algorithms, "test", "auc", dataset_filter="early")
# get_table(profiles_data, algorithms, "test", "auc", dataset_filter="mid")
# get_table(profiles_data, algorithms, "test", "auc", dataset_filter="late")
#
# get_table(profiles_data, algorithms, "test", "mean_return", dataset_filter="full")
# get_table(profiles_data, algorithms, "test", "mean_return", dataset_filter="early")
# get_table(profiles_data, algorithms, "test", "mean_return", dataset_filter="mid")
# get_table(profiles_data, algorithms, "test", "mean_return", dataset_filter="late")
#
# get_table(profiles_data, algorithms, "train", "auc", dataset_filter="full")
# get_table(profiles_data, algorithms, "train", "auc", dataset_filter="early")
# get_table(profiles_data, algorithms, "train", "auc", dataset_filter="mid")
# get_table(profiles_data, algorithms, "train", "auc", dataset_filter="late")
#
# get_table(profiles_data, algorithms, "train", "mean_return", dataset_filter="full")
# get_table(profiles_data, algorithms, "train", "mean_return", dataset_filter="early")
# get_table(profiles_data, algorithms, "train", "mean_return", dataset_filter="mid")
# get_table(profiles_data, algorithms, "train", "mean_return", dataset_filter="late")

# plot_quadrangle(profiles_data, algorithms, "test", "All datasets test performance")
# plot_quadrangle(profiles_data, algorithms, "train", "All datasets train performance")
# #
# plot_quadrangle(profiles_data, algorithms, "test", "Complete datasets test performance", "full")
# plot_quadrangle(profiles_data, algorithms, "test", "Early datasets test performance", "early")
# plot_quadrangle(profiles_data, algorithms, "test", "Mid datasets test performance", "mid")
# plot_quadrangle(profiles_data, algorithms, "test", "Late datasets test performance", "late")
# #
# coverage_plots(profiles_data, algorithms, "test", "auc", "k2d", "Various K2D9 complete datasets coverage", "NAUC", legend=False)
# coverage_plots(profiles_data, algorithms, "test", "auc", "K2D13", "Various K2D13 complete datasets coverage", "NAUC", legend=False)
# coverage_plots(profiles_data, algorithms, "test", "auc", "DR", "Various DR9 complete datasets coverage", "NAUC")
# coverage_plots(profiles_data, algorithms, "test", "auc", "DR19", "Various DR19 complete datasets coverage", "NAUC", legend=False)

# print([k for k in profiles_data['AD']])
#
algorithms = ['AD', 'IC-DQN', 'IC-CQL', 'IC-IQL']
# Load the Janus scores and their "aug" variant from the local cache.
# NOTE(review): pickle.load can execute arbitrary code contained in the file,
# so only load pickles that were produced locally / are trusted.
with open('bin/janus_scores.pickle', 'rb') as handle:
    janus_data = pickle.load(handle)
with open('bin/janus_aug_scores.pickle', 'rb') as handle:
    janus_aug_data = pickle.load(handle)

# Per-algorithm, per-dataset "eval/test_auc" entries of both score sets.
# NOTE(review): the only consumer visible in this file is commented out, so
# these two results are currently unused — confirm whether they are needed.
janus_auc = extract_metric(janus_data, 'test', 'auc')
janus_aug_auc = extract_metric(janus_aug_data, 'test', 'auc')

# janus_aug_diff = {}
# for alg in algorithms:
#     janus_aug_diff[alg] = {}
#     for ds in janus_aug_auc[alg]:
#         janus_aug_diff[alg][ds] = np.array(janus_aug_auc[alg][ds]) - janus_auc[alg]['Janus19_300_5']
# for alg, color in zip(algorithms, ['blue', 'orange', 'green', 'red']):
#     means = []
#     stds = []
#     for aug in [1, 2, 4, 8, 16, 32]:
#         means.append(np.mean(janus_aug_diff[alg][f'Janus19_300_5_{aug}']))
#         stds.append(np.std(janus_aug_diff[alg][f'Janus19_300_5_{aug}']))
#     plt.plot(means, label=alg, color=color)
#     plt.fill_between([0, 1, 2, 3, 4, 5], np.array(means) - stds, np.array(means) + stds, color=color, alpha=0.3)
# plt.legend()
# plt.xticks([0, 1, 2, 3, 4, 5], [1, 2, 4, 8, 16, 32])
# plt.ylabel("NAUC difference")
# plt.xlabel("Number of mixture histories")
# plt.show()

# print(janus_data)
# Draw the Janus figure via plot_janus (defined elsewhere in this file).
# Presumably "test" selects the split, "full" the dataset filter and
# save_suffix the output file name prefix/suffix -- verify against plot_janus.
plot_janus(janus_data, algorithms, "test", "Janus complete datasets test performance", "full", save_suffix="janus_")
# get_table(janus_data, algorithms, "test", "auc", dataset_filter="full", save_suffix="janus")
# get_table(janus_data, algorithms, "test", "auc", dataset_filter="full", save_suffix="janus", prefix="inverted")
# get_table(janus_data, algorithms, "test", "auc", dataset_filter="full", save_suffix="janus", prefix="default")
# get_table(janus_data, algorithms, "test", "auc_janus", dataset_filter="full", save_suffix="janus", prefix="inverted")
# get_table(janus_data, algorithms, "test", "auc_janus", dataset_filter="full", save_suffix="janus", prefix="default")


#
# Method list for the data-ordering section (same four names as the Janus
# section above; the three-method variant is kept commented out).
algorithms = ['AD', 'IC-DQN', 'IC-CQL', 'IC-IQL']
# algorithms = ['AD', 'IC-DQN', 'IC-CQL']
# Load the pickled data-ordering scores from bin/.
with open('bin/order_scores.pickle', 'rb') as handle:
    order_data = pickle.load(handle)

# get_diffs is defined elsewhere in this file; presumably it compares
# profiles_data with order_data on datasets matching "_1_" -- verify.
# In the visible part of the script diff_scores is only referenced by
# commented-out get_table calls.
diff_scores = get_diffs(profiles_data, order_data, algorithms, data_filter="_1_")
#
# for alg in algorithms:
#     for dataset in [
#         'DR19_75_1',
#         'DR19_150_1',
#         'DR19_300_1',
#         'K2D13_250_1',
#         'K2D13_500_1',
#         'K2D13_1000_1',
#     ]:
#         order_data[alg][dataset] = profiles_data[alg][dataset]

# get_table(order_data, algorithms, "test", "auc", save_suffix="order")
# # get_table(order_data, algorithms, "test", "auc", dataset_filter="random", save_suffix="order")
# # get_table(order_data, algorithms, "test", "auc", dataset_filter="sample", save_suffix="order")
# # get_table(order_data, algorithms, "test", "auc", dataset_filter="sorted", save_suffix="order")
#
# orders_plot(order_data, algorithms, "test", "auc", "DR19", "Various types of data ordering for DR19", "NAUC", legend=True)
# orders_plot(order_data, algorithms, "test", "auc", "K2D13", "Various types of data ordering for K2D13", "NAUC", legend=False)
# # get_table(diff_scores, algorithms, "test", "auc", dataset_filter="random", save_suffix="order")
# # get_table(diff_scores, algorithms, "test", "auc", dataset_filter="sorted", save_suffix="order")
# #
# plot_metrics(order_data, algorithms, "test", "auc", "Test NAUC with random data order", dataset_filter="random")
# plot_metrics(order_data, algorithms, "test", "auc", "Test NAUC with sorted sample data order", dataset_filter="sample")
# plot_metrics(order_data, algorithms, "test", "mean_return", "Test final performance with random data order", dataset_filter="random", max_val=2)
# plot_metrics(order_data, algorithms, "test", "mean_return", "Test final performance with sorted sample order", dataset_filter="sample", max_val=2)

# Rebind the module-level method list for the cont_scores data. From here on
# `algorithms` no longer holds the discrete methods, which is why the discrete
# summary near the end of the file spells out its own method list.
algorithms = ['AD', 'IC-TD3', 'IC-TD3+BC', 'IC-IQL']
# Load the pickled scores stored in bin/cont_scores.pickle.
with open('bin/cont_scores.pickle', 'rb') as handle:
    cont_data = pickle.load(handle)
# print(cont_data)

# get_table(cont_data, algorithms, "test", "auc", save_suffix="cont", dataset_filter="full")
# get_table(cont_data, algorithms, "test", "auc", save_suffix="cont", dataset_filter="early")
# get_table(cont_data, algorithms, "test", "auc", save_suffix="cont", dataset_filter="mid")
# get_table(cont_data, algorithms, "test", "auc", save_suffix="cont", dataset_filter="late")
#
# get_table(cont_data, algorithms, "test", "mean_return", save_suffix="cont", dataset_filter="full")
# get_table(cont_data, algorithms, "test", "mean_return", save_suffix="cont", dataset_filter="early")
# get_table(cont_data, algorithms, "test", "mean_return", save_suffix="cont", dataset_filter="mid")
# get_table(cont_data, algorithms, "test", "mean_return", save_suffix="cont", dataset_filter="late")
#
# get_table(cont_data, algorithms, "test", "mean_return_quarter", save_suffix="cont", dataset_filter="full")
# get_table(cont_data, algorithms, "test", "mean_return_quarter", save_suffix="cont", dataset_filter="early")
# get_table(cont_data, algorithms, "test", "mean_return_quarter", save_suffix="cont", dataset_filter="mid")
# get_table(cont_data, algorithms, "test", "mean_return_quarter", save_suffix="cont", dataset_filter="late")

# plot_metrics(cont_data, algorithms, "test", "auc", "Test NAUC continuous environments, early datasets", dataset_filter="early", suffix="cont", max_val=1.2)
# plot_metrics(cont_data, algorithms, "test", "auc", "Test NAUC continuous environments, mid datasets", dataset_filter="mid", suffix="cont", max_val=1.2)
# plot_metrics(cont_data, algorithms, "test", "auc", "Test NAUC continuous environments, late datasets", dataset_filter="late",suffix="cont", max_val=1.2)
# plot_metrics(cont_data, algorithms, "test", "auc", "Test NAUC continuous environments, complete datasets", dataset_filter="full", suffix="cont", max_val=1.2)
# plot_metrics(cont_data, algorithms, "test", "auc", "Test NAUC continuous environments, complete datasets", suffix="cont", max_val=1.25)
# plot_quadrangle(cont_data, algorithms, "test", "All continuous datasets test performance", save_suffix="cont", metrics=['NAUC', 'Return after 4 episodes', 'Return after 2 episodes', 'Return after 1 episode',])

def average_env(data, env):
    """Average each algorithm's scores over all datasets of one environment.

    Dataset names are normalised before matching: the ``goals_``, ``data_``
    and ``_histories`` fragments are dropped, underscores become dashes,
    ``k2d`` becomes ``K2D9`` and ``DR-`` becomes ``DR9-``. A dataset belongs
    to ``env`` when ``env`` is a substring of the normalised name.

    Args:
        data: Mapping ``algorithm -> {dataset name -> iterable of scores}``.
            The scores may be lists or numpy arrays.
        env: Substring to look for in the normalised dataset name. The empty
            string matches every dataset.

    Returns:
        Dict ``algorithm -> np.mean`` of all pooled scores. When no dataset
        matches, ``np.mean`` of an empty list is returned (``nan``, with
        numpy's usual RuntimeWarning).
    """

    def _normalise(name):
        # The replacement order matters: underscores must already be dashes
        # before 'DR-' can be rewritten to 'DR9-'.
        return (
            name.replace('goals_', '')
            .replace('data_', '')
            .replace('_histories', '')
            .replace('_', '-')
            .replace('k2d', 'K2D9')
            .replace('DR-', 'DR9-')
        )

    averages = {}
    for alg, per_dataset in data.items():
        pooled = []
        for ds, values in per_dataset.items():
            if env in _normalise(ds):
                # extend() instead of `+=`: `list += ndarray` is dispatched to
                # numpy's element-wise add, which either raises a broadcast
                # error or sums the scores instead of pooling them.
                pooled.extend(values)
        averages[alg] = np.mean(pooled)
    return averages


def plot_avg_perfs(data, methods, title, save_path, legend_font_size=20):
    """Save a grouped bar chart with one group per environment and one bar per method.

    Args:
        data: Mapping ``environment -> {method -> value}``. The insertion
            order of the environments is the left-to-right order of the groups.
        methods: Method names to plot; each must be a key of every
            ``data[environment]``.
        title: Figure title.
        save_path: File stem; the figure is written to ``out/{save_path}.pdf``.
        legend_font_size: Font size of the legend entries.

    Side effects:
        Calls ``sns.set`` and therefore changes the global seaborn theme for
        every figure drawn afterwards, and writes the PDF file.
    """
    sns.set(style="whitegrid", font_scale=2.0)
    environments = list(data.keys())
    # Plot parameters
    bar_width = 0.15  # Width of each bar
    spacing = 0.02  # Space between bars within a group
    x = np.arange(len(environments))  # Positions for environment groups

    # Create plot
    fig = plt.figure(figsize=(15, 7))
    colors = dunno_colors

    try:
        # Plot bars for each method in all environments
        for i, method in enumerate(methods):
            # Centre the group on the tick: the middle method gets offset 0.
            offset = (i - (len(methods) - 1) / 2) * (bar_width + spacing)
            values = [data[env][method] for env in environments]
            # Cycle through the palette so more methods than colours do not
            # raise an IndexError.
            plt.bar(x + offset, values, bar_width, label=method, color=colors[i % len(colors)])

        # Formatting
        plt.ylabel('Average test NAUC over all datasets', fontsize=20)
        plt.title(title, fontsize=30)

        plt.xticks(x, environments, rotation=0, fontsize=20)
        for tick in plt.gca().get_xticklabels():
            if tick.get_text() == "Overall":
                tick.set_fontweight('bold')  # Make the summary group stand out

        plt.legend(fontsize=legend_font_size, loc='upper left')
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        plt.tight_layout()
        plt.savefig(f"out/{save_path}.pdf", dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)


# Mean test NAUC of every method per discrete benchmark; each entry becomes
# one bar group in the summary figure.
discrete_auc = extract_metric(profiles_data, "test", "auc")
avg_scores = {
    en: average_env(discrete_auc, en)
    for en in ['DR9', 'DR19', 'K2D9', 'K2D13']
}

# The empty pattern is a substring of every dataset name, so all Janus
# datasets are pooled together.
janus_auc = extract_metric(janus_data, "test", "auc")
avg_scores['Janus'] = average_env(janus_auc, '')
# These numbers are hard-coded here rather than computed by this script.
avg_scores['XLand-Minigrid'] = {'AD': 0.22, 'IC-CQL': 0.40, 'IC-DQN': 0.42, 'IC-IQL': 0.46}

# 'Overall' is the unweighted mean over the groups collected so far. It is
# built before the key is inserted, so it never averages itself.
overall = {
    alg: np.mean([per_env[alg] for per_env in avg_scores.values()])
    for alg in ['AD', 'IC-DQN', 'IC-CQL', 'IC-IQL']
}
avg_scores['Overall'] = overall
print(avg_scores['Overall'])
# print(avg_scores)
plot_avg_perfs(
    avg_scores,
    ['AD', 'IC-DQN', 'IC-CQL', 'IC-IQL'],
    'Discrete environments performance',
    'overall_discrete',
)

# Same summary for the cont_data scores: mean test NAUC of every method per
# environment, followed by the unweighted mean across environments.
cont_auc = extract_metric(cont_data, "test", "auc")
avg_scores_cont = {
    en: average_env(cont_auc, en)
    for en in ['HCV', 'ANT', 'HPP', 'WLP']
}

# Built before 'Overall' is inserted, so it never averages itself.
overall = {
    alg: np.mean([per_env[alg] for per_env in avg_scores_cont.values()])
    for alg in ['AD', 'IC-TD3', 'IC-TD3+BC', 'IC-IQL']
}
avg_scores_cont['Overall'] = overall
print(avg_scores_cont['Overall'])
plot_avg_perfs(
    avg_scores_cont,
    ['AD', 'IC-TD3', 'IC-TD3+BC', 'IC-IQL'],
    'Continuous environments performance',
    'overall_cont',
    legend_font_size=17,
)