import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Rectangle
import os
import numpy as np
import pandas as pd
from load_all_result import load_all_result
from make_title import make_title
from plot_n_op import create_n_op_plot, save_legend_only_n_op
from plot_n_time import create_n_time_plot, save_legend_only_n_time
from plot_epsilon_op import create_epsilon_op_plot, save_legend_only_epsilon_op
from plot_hist import create_histogram, load_data
import argparse


# Size, in inches, of one layout cell in the merged figure: an upper plot of
# OP_HEIGHT stacked on a histogram of HIST_HEIGHT, all BLOCK_WIDTH wide.
BLOCK_WIDTH = 6.0
OP_HEIGHT, HIST_HEIGHT = 5.5, 2.5
BLOCK_HEIGHT = OP_HEIGHT + HIST_HEIGHT

def merge_plots(block_layout, output_path, df, vis_setting, mode="n_op",
                hist_dir="result/fig/hist/tmp"):
    """Render a grid of per-distribution plots, each stacked over a histogram.

    Each cell of ``block_layout`` occupies two subplot rows: the plot selected
    by ``mode`` on top and the histogram loaded from ``hist_dir`` below it.
    Falsy cells (e.g. ``""``) and cells past the end of a shorter row are
    hidden. The figure is saved to ``output_path`` and then closed.

    Args:
        block_layout: List of rows, each a list of distribution names.
        output_path: Destination file; missing parent directories are created.
        df: Result table passed to the per-cell plotting function.
        vis_setting: Visual settings passed to the per-cell plotting function
            (not used by the ``"n_time_profile_pcf"`` mode).
        mode: ``"n_op"``, ``"epsilon_op"``, ``"n_time"`` or
            ``"n_time_profile_pcf"``.
        hist_dir: Directory holding the histogram inputs, named
            ``<dist with '/' replaced by '_'>_n100000_seed1.txt``.

    Raises:
        ValueError: ``mode`` is unknown, or ``block_layout`` has no cells.
        NotImplementedError: ``mode`` is ``"n_time_profile_pcf"`` but
            ``create_n_time_profile_pcf_plot`` is not defined in this module.
    """
    # Validate everything before creating the figure so a bad call cannot
    # leave a half-drawn figure behind.
    if mode not in ("n_op", "epsilon_op", "n_time", "n_time_profile_pcf"):
        raise ValueError(f"Unknown mode: {mode}")

    rows = len(block_layout)
    cols = max((len(row) for row in block_layout), default=0)
    if rows == 0 or cols == 0:
        raise ValueError("block_layout must contain at least one cell")

    profile_plot = None
    if mode == "n_time_profile_pcf":
        # This plotter is not imported at the top of the file; resolve it
        # lazily so a missing definition fails clearly instead of with a
        # NameError in the middle of drawing.
        profile_plot = globals().get("create_n_time_profile_pcf_plot")
        if profile_plot is None:
            raise NotImplementedError(
                "mode 'n_time_profile_pcf' requires "
                "create_n_time_profile_pcf_plot, which is not defined here"
            )

    fig, axes = plt.subplots(
        rows * 2,
        cols,
        figsize=(BLOCK_WIDTH * cols, BLOCK_HEIGHT * rows),
        gridspec_kw={
            'height_ratios': [OP_HEIGHT, HIST_HEIGHT] * rows,
            'hspace': 0.0,
            'wspace': 0.0
        },
        constrained_layout=True,
        # Always get a 2-D array so axes[r, c] works for every grid shape,
        # including a single column.
        squeeze=False,
    )

    try:
        for i, row in enumerate(block_layout):
            for j in range(cols):
                top_ax = axes[2 * i, j]
                hist_ax = axes[2 * i + 1, j]
                dist = row[j] if j < len(row) else None
                if not dist:
                    # Empty cell, or padding for a row shorter than the widest.
                    top_ax.set_visible(False)
                    hist_ax.set_visible(False)
                    continue

                # Only the leftmost column carries y-axis labels.
                ylabel = j == 0
                if mode == "n_op":
                    create_n_op_plot(df, dist, vis_setting, ax=top_ax, ylabel=ylabel)
                elif mode == "epsilon_op":
                    create_epsilon_op_plot(df, dist, vis_setting, ax=top_ax, ylabel=ylabel)
                elif mode == "n_time":
                    create_n_time_plot(df, dist, vis_setting, ax=top_ax, ylabel=ylabel)
                else:  # "n_time_profile_pcf", resolved above
                    profile_plot(df, dist, ax=top_ax)

                hist_data_path = os.path.join(
                    hist_dir, f"{dist.replace('/', '_')}_n100000_seed1.txt"
                )
                if os.path.exists(hist_data_path):
                    hist_data = load_data(hist_data_path)
                    create_histogram(hist_data, ax=hist_ax, ylabel=ylabel)
                else:
                    print(f"No data for {hist_data_path}")

                # Distribution name as the title of the upper plot.
                top_ax.set_title(make_title(dist), fontsize=36)

        out_dir = os.path.dirname(output_path)
        if out_dir:  # dirname is "" for a bare filename; makedirs("") raises
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_path, bbox_inches='tight')
    finally:
        # Release the figure even if one of the plotting calls fails.
        plt.close(fig)
    print(f"Saved the merged plot: {output_path}")

def SORT2MARKER(sort):
    """Return the matplotlib marker symbol for the sort named ``sort``.

    Everything from the first ``"_approx"`` onward is ignored, so approximate
    variants share the marker of their base sort. Unknown names print a
    notice and fall back to ``"o"``.
    """
    if "_approx" in sort:
        sort = sort.split("_approx")[0]

    markers = {
        "pcf": "o",
        "quick_sort": "^",
        "std": "^",
        "radix": "D",
        "boost": "s",
        "ips4o": "p",
        "learned_AniKristo": "x",
        "bls": "*",
        "uls": "<",
        "ls21": ">",
        "learned_sort_using_learned_index_binary_search": "v",
        "learned_sort_using_learned_index_btree": "s",
        "learned_sort_using_learned_index_espc": "^",
        "learned_sort_using_learned_index_rmi": "d",
        "learned_sort_using_learned_index_pgm": "x",
    }
    marker = markers.get(sort)
    if marker is None:
        print(f"SORT2MARKER: {sort} not found")
        return "o"
    return marker

def SORT2COLOR(sort):
    """Return the matplotlib color name for the sort named ``sort``.

    Everything from the first ``"_approx"`` onward is ignored, so approximate
    variants are drawn in the color of their base sort. Unknown names print a
    notice and fall back to ``"tab:gray"``.
    """
    if "_approx" in sort:
        sort = sort.split("_approx")[0]

    palette = {
        "pcf": "tab:cyan",
        "quick_sort": "tab:orange",
        "std": "tab:orange",
        "radix": "tab:gray",
        "boost": "black",
        "ips4o": "tab:olive",
        "learned_AniKristo": "tab:pink",
        "bls": "goldenrod",
        "uls": "teal",
        "ls21": "navy",
        "learned_sort_using_learned_index_binary_search": "tab:purple",
        "learned_sort_using_learned_index_btree": "tab:blue",
        "learned_sort_using_learned_index_espc": "tab:red",
        "learned_sort_using_learned_index_rmi": "tab:brown",
        "learned_sort_using_learned_index_pgm": "tab:green",
    }
    color = palette.get(sort)
    if color is None:
        print(f"SORT2COLOR: {sort} not found")
        return "tab:gray"
    return color

def SORT2LINES(sort):
    """Return the matplotlib linestyle for the sort named ``sort``.

    Unlike the marker/color helpers, the full name is looked up, so each
    approximation level of the learned-index sorts gets its own dash pattern:
    exact -> ``"-"``, ``_approx_8`` -> ``"--"``, ``_approx_32`` -> ``"-."``,
    ``_approx_128`` -> ``":"``. Unknown names print a notice and fall back
    to ``"-"``.
    """
    styles = {
        "pcf": "-.",
        "quick_sort": "--",
        "std": "--",
        "radix": ":",
        "boost": "-.",
        "ips4o": "-.",
        "learned_AniKristo": "-.",
        "bls": "-.",
        "uls": "-.",
        "ls21": "-.",
    }

    # (name suffix, linestyle) for each approximation level of the
    # learned-index sorts; "" is the exact variant.
    levels = (("", "-"), ("_approx_8", "--"), ("_approx_32", "-."), ("_approx_128", ":"))
    for index_name in ("binary_search", "btree", "espc", "rmi", "pgm"):
        for suffix, style in levels:
            styles["learned_sort_using_learned_index_" + index_name + suffix] = style

    # Epsilon-free "_approx" names, as used by merge_plot_epsilon_op.
    styles["learned_sort_using_learned_index_btree_approx"] = "--"
    styles["learned_sort_using_learned_index_pgm_approx"] = ":"

    if sort in styles:
        return styles[sort]
    print(f"SORT2LINES: {sort} not found")
    return "-"

def SORT2MARKERSZIE(sort):
    """Return the marker size for the sort named ``sort``.

    Every known sort currently uses 6.5. Everything from the first
    ``"_approx"`` onward is ignored. Unknown names print a notice and also
    get 6.5.
    """
    if "_approx" in sort:
        sort = sort.split("_approx")[0]

    default_size = 6.5
    sizes = dict.fromkeys(
        (
            "pcf",
            "quick_sort",
            "std",
            "radix",
            "boost",
            "ips4o",
            "learned_AniKristo",
            "bls",
            "uls",
            "ls21",
            "learned_sort_using_learned_index_binary_search",
            "learned_sort_using_learned_index_btree",
            "learned_sort_using_learned_index_espc",
            "learned_sort_using_learned_index_rmi",
            "learned_sort_using_learned_index_pgm",
        ),
        default_size,
    )
    size = sizes.get(sort)
    if size is None:
        print(f"SORT2MARKERSZIE: {sort} not found")
        return default_size
    return size

def tag_approx_sort_names(result_df):
    """Return a copy of ``result_df`` whose ``_approx`` sort names carry epsilon.

    Rows whose ``sort`` contains ``"_approx"`` are renamed to
    ``f"{sort}_{epsilon}"``. Integral epsilons are written without a decimal
    point (8.0 -> "8"); others use ``"{:g}"``. Rows with a missing epsilon
    keep their name. The input frame is not modified.
    """
    tagged = result_df.copy()
    mask = tagged["sort"].astype(str).str.contains("_approx", na=False, regex=False)
    if not mask.any():
        # Nothing to rename: skip the per-row work entirely.
        return tagged

    def fmt_epsilon(eps):
        if pd.isna(eps):
            return None
        eps = float(eps)
        return str(int(eps)) if eps.is_integer() else f"{eps:g}"

    def tag(sort_name, eps):
        label = fmt_epsilon(eps)
        return sort_name if label is None else f"{sort_name}_{label}"

    approx = tagged.loc[mask, ["sort", "epsilon"]]
    # A plain list is matched to the masked rows by position, so index
    # alignment (e.g. a non-unique index) cannot misplace the new names.
    tagged.loc[mask, "sort"] = [
        tag(name, eps) for name, eps in zip(approx["sort"], approx["epsilon"])
    ]
    return tagged


def merge_plot_n_op(result_df_all, output_path, layout):
    """Render the merged ``n_op`` figure for the Index Sort variants.

    Approximate-sort rows are renamed with ``tag_approx_sort_names`` so they
    match the ``*_approx_<eps>`` keys in the visual settings below. A
    legend-only file ``<stem>_legend<ext>`` is written next to
    ``output_path``. ``result_df_all`` itself is not modified.
    """
    sort_list = [
        {"sort": "learned_sort_using_learned_index_binary_search", "label": "Index Sort"},
        {"sort": "learned_sort_using_learned_index_btree", "label": "w/ B-tree"},
        {"sort": "learned_sort_using_learned_index_espc", "label": "w/ ESPC-Index"},
        {"sort": "learned_sort_using_learned_index_rmi", "label": "w/ RMI-Index"},
        {"sort": "learned_sort_using_learned_index_pgm", "label": "w/ PGM-Index"},
        {"sort": "learned_sort_using_learned_index_binary_search_approx_8", "label": "Index Sort (8)"},
        {"sort": "learned_sort_using_learned_index_btree_approx_8", "label": "w/ B-tree (8)"},
        {"sort": "learned_sort_using_learned_index_espc_approx_8", "label": "w/ ESPC-Index (8)"},
        {"sort": "learned_sort_using_learned_index_rmi_approx_8", "label": "w/ RMI-Index (8)"},
        {"sort": "learned_sort_using_learned_index_pgm_approx_8", "label": "w/ PGM-Index (8)"},
        {"sort": "learned_sort_using_learned_index_binary_search_approx_32", "label": "Index Sort (32)"},
        {"sort": "learned_sort_using_learned_index_btree_approx_32", "label": "w/ B-tree (32)"},
        {"sort": "learned_sort_using_learned_index_espc_approx_32", "label": "w/ ESPC-Index (32)"},
        {"sort": "learned_sort_using_learned_index_rmi_approx_32", "label": "w/ RMI-Index (32)"},
        {"sort": "learned_sort_using_learned_index_pgm_approx_32", "label": "w/ PGM-Index (32)"},
        {"sort": "learned_sort_using_learned_index_binary_search_approx_128", "label": "Index Sort (128)"},
        {"sort": "learned_sort_using_learned_index_btree_approx_128", "label": "w/ B-tree (128)"},
        {"sort": "learned_sort_using_learned_index_espc_approx_128", "label": "w/ ESPC-Index (128)"},
        {"sort": "learned_sort_using_learned_index_rmi_approx_128", "label": "w/ RMI-Index (128)"},
        {"sort": "learned_sort_using_learned_index_pgm_approx_128", "label": "w/ PGM-Index (128)"},
    ]

    vis_setting = {
        "sorts": {
            entry["sort"]: {
                "label": entry["label"],
                "marker": SORT2MARKER(entry["sort"]),
                "markersize": SORT2MARKERSZIE(entry["sort"]),
                "linestyle": SORT2LINES(entry["sort"]),
                "color": SORT2COLOR(entry["sort"]),
            }
            for entry in sort_list
        }
    }

    result_df = tag_approx_sort_names(result_df_all)

    merge_plots(layout, output_path, result_df, vis_setting, mode="n_op")
    # splitext only touches the real extension, so a non-.pdf output can no
    # longer be overwritten by its own legend.
    root, ext = os.path.splitext(output_path)
    save_legend_only_n_op(vis_setting, f"{root}_legend{ext}")

def merge_plot_epsilon_op(result_df_all, output_path, layout, n=10000000):
    """Render the merged ``epsilon_op`` figure for the approximate index sorts.

    Each series is described by a ``{"cond": ..., "vis_setting": ...}`` entry
    and passed to ``merge_plots``. A legend-only file
    ``<stem>_legend<ext>`` is written next to ``output_path``.

    Args:
        result_df_all: Result table; a copy is handed to the plotting code.
        output_path: Destination of the merged figure.
        layout: Grid of distribution names (see ``merge_plots``).
        n: Value stored in each series' ``cond["n"]``. This is presumably the
            input size the curve is drawn for; confirm in
            ``create_epsilon_op_plot``. Defaults to the previously hard-coded
            10,000,000.
    """
    result_df = result_df_all.copy()

    sort_list = [
        {"sort": "learned_sort_using_learned_index_btree_approx", "label": "w/ B-tree"},
        {"sort": "learned_sort_using_learned_index_pgm_approx", "label": "w/ PGM-Index"},
    ]

    cond_and_vis_settings = [
        {
            "cond": {
                "sort": entry["sort"],
                "n": n,
            },
            "vis_setting": {
                "label": entry["label"],
                "marker": SORT2MARKER(entry["sort"]),
                "markersize": SORT2MARKERSZIE(entry["sort"]),
                "linestyle": SORT2LINES(entry["sort"]),
                "color": SORT2COLOR(entry["sort"]),
            },
        }
        for entry in sort_list
    ]

    merge_plots(layout, output_path, result_df, cond_and_vis_settings, mode="epsilon_op")
    # splitext only touches the real extension, so a non-.pdf output can no
    # longer be overwritten by its own legend.
    root, ext = os.path.splitext(output_path)
    save_legend_only_epsilon_op(cond_and_vis_settings, f"{root}_legend{ext}")


def merge_plot_n_time(result_df_all, output_path, layout):
    """Render the merged ``n_time`` figure comparing baseline and learned sorts.

    A copy of ``result_df_all`` is passed to ``merge_plots``. A legend-only
    file ``<stem>_legend<ext>`` is written next to ``output_path``.
    """
    result_df = result_df_all.copy()

    sort_list = [
        {"sort": "std", "label": "std::sort"},
        {"sort": "radix", "label": "Radix Sort"},
        {"sort": "learned_sort_using_learned_index_binary_search", "label": "Index Sort"},
        {"sort": "boost", "label": "boost::sort::spreadsort::float_sort"},
        # The "ips4o" sorter is IPS4o (In-place Parallel Super Scalar Samplesort).
        {"sort": "ips4o", "label": "IPS4o"},
        {"sort": "learned_AniKristo", "label": "Learned Sort 2.0"},
        {"sort": "pcf", "label": "PCF Learned Sort"},
        {"sort": "bls", "label": "BLS"},
        {"sort": "uls", "label": "ULS"},
        {"sort": "ls21", "label": "Learned Sort 2.1"},
        {"sort": "learned_sort_using_learned_index_btree", "label": "w/ B-tree"},
        {"sort": "learned_sort_using_learned_index_espc", "label": "w/ ESPC-Index"},
        {"sort": "learned_sort_using_learned_index_rmi", "label": "w/ RMI-Index"},
        {"sort": "learned_sort_using_learned_index_pgm", "label": "w/ PGM-Index"},
    ]

    vis_setting = {
        "sorts": {
            entry["sort"]: {
                "label": entry["label"],
                "marker": SORT2MARKER(entry["sort"]),
                "markersize": SORT2MARKERSZIE(entry["sort"]),
                "linestyle": SORT2LINES(entry["sort"]),
                "color": SORT2COLOR(entry["sort"]),
            }
            for entry in sort_list
        }
    }

    merge_plots(layout, output_path, result_df, vis_setting, mode="n_time")
    # splitext only touches the real extension, so a non-.pdf output can no
    # longer be overwritten by its own legend.
    root, ext = os.path.splitext(output_path)
    save_legend_only_n_time(vis_setting, f"{root}_legend{ext}")


# Distribution grid for the full merged figures: one inner list per row of
# panels, left to right. An empty string leaves that cell blank.
LAYOUT = [
    ["uniform", "normal", "exponential", "lognormal", "data/chic_start"],
    ["uniform_shift", "normal_shift", "exponential_shift", "lognormal_shift", "data/chic_tot"],
    ["data/nyc_pickup", "data/nyc_dist", "data/nyc_tot", "data/sof_hum", "data/sof_press"],
    [
        "data/sof_temp",
        "data/wiki_ts_200M_uint64",
        "data/osm_cellids_800M_uint64",
        "data/books_800M_uint64",
        "data/fb_200M_uint64",
    ],
    ["data/stks_vol", "data/stks_open", "data/stks_date", "data/stks_low", ""],
]

# Four-panel subset used for the "_small" variant of each figure.
LAYOUT_SMALL = [
    ["normal", "normal_shift", "data/wiki_ts_200M_uint64", "data/osm_cellids_800M_uint64"],
]

if __name__ == "__main__":
    merged_dir = "result/fig/merged"

    def load_filtered(result_type):
        """Load every ``result_type`` result and keep rows with 1e4 <= n <= 1e7."""
        df = load_all_result("result", result_type)
        return df[(df["n"] >= 10000) & (df["n"] <= 10000000)]

    # The sort_op results feed both the n_op and the epsilon_op figures, so
    # they are loaded once. Each merge_plot_* helper works on its own copy,
    # so reusing op_df is safe.
    op_df = load_filtered("sort_op")
    merge_plot_n_op(op_df, f"{merged_dir}/n_op_index2sort.pdf", layout=LAYOUT)
    merge_plot_n_op(op_df, f"{merged_dir}/n_op_index2sort_small.pdf", layout=LAYOUT_SMALL)

    time_df = load_filtered("sort_time")
    merge_plot_n_time(time_df, f"{merged_dir}/n_time_index2sort.pdf", layout=LAYOUT)
    merge_plot_n_time(time_df, f"{merged_dir}/n_time_index2sort_small.pdf", layout=LAYOUT_SMALL)

    merge_plot_epsilon_op(op_df, f"{merged_dir}/epsilon_op_index2sort.pdf", layout=LAYOUT)
    merge_plot_epsilon_op(op_df, f"{merged_dir}/epsilon_op_index2sort_small.pdf", layout=LAYOUT_SMALL)
