import os
import math
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from typing import List

from utils._utils import methods, graph_types, graph_sizes, graph_sizes_to_nodes, graph_densities,\
    scenario_configs, inference_scenarios, tunable_parameters

# Optional: render all figure text through LaTeX instead of matplotlib's mathtext.
# plt.rcParams['text.usetex'] = True
# Runs at import time: applies seaborn's "whitegrid" style to every figure created afterwards.
sns.set_style('whitegrid')

# ----------------------- Utilities ----------------------- #
def best_metric(series, metric):
    """Return the best value in ``series`` according to ``metric``.

    Error-style metrics (lower is better) return the minimum; score-style
    metrics (higher is better) return the maximum.

    Parameters
    ----------
    series : pandas.Series
        Values of a single metric (anything exposing ``min()``/``max()``).
    metric : str
        Metric name, e.g. ``"shd"``, ``"f1"``, ``"d_top"``, ``"dtop_fnr"``.

    Raises
    ------
    ValueError
        If ``metric`` is unknown (previously this silently returned ``None``).
    """
    lower_is_better = ("shd", "d_top", "dtop_fnr", "fnr", "fpr")
    higher_is_better = ("f1", "tpr", "tnr")
    if metric in lower_is_better:
        return series.min()
    if metric in higher_is_better:
        return series.max()
    raise ValueError(f"Unknown metric {metric!r}")


def arg_best_shd(series):
    """Return the *positional* index of the smallest SHD in ``series``.

    A position (not an index label) is returned, so callers should index
    with ``.iloc``.
    """
    lowest_position = series.argmin()
    return lowest_position


def get_logs_dir(scenario):
    """Return the root logs directory used for ``scenario``.

    The ``"linear"`` scenario reads from the ``standardized`` logs; every
    other scenario reads from the ``inference`` logs.
    """
    task = "standardized" if scenario == "linear" else "inference"
    return f"/home/ec2-user/causal-benchmark/tmp/logs/{task}"


def make_fig_title(noise, scenario_param):
    r"""Build a figure title describing the noise and the violated assumption.

    Parameters
    ----------
    noise : str
        Noise key: ``"gauss"``, ``"nonlin_weak"``, ``"nonlin_mid"`` or
        ``"nonlin_strong"``.
    scenario_param : str
        Scenario configuration, e.g. ``"timino"``, ``"measure_err_0.8"``,
        ``"confounded_0.2"``. For parametrised scenarios the text after the
        LAST underscore is rendered as :math:`\theta`.

    Returns
    -------
    str
        Title containing matplotlib mathtext (``$\theta=...$``).

    Raises
    ------
    KeyError
        If ``noise`` is not one of the keys listed above.
    """
    # Raw strings are required: in a normal string "\t" of "\theta" is a TAB.
    noise_map = {
        "gauss": 'Gaussian',
        'nonlin_weak': r"random transform ($\theta=1$)",
        'nonlin_mid': r"random transform ($\theta=2$)",
        'nonlin_strong': r"random transform ($\theta=3$)",
    }

    title = "ANM with " + noise_map[noise] + " noise "
    # Take the value after the LAST underscore: "measure_err_0.8".split("_")[1]
    # would give "err", not "0.8".
    theta = scenario_param.rsplit("_", 1)[-1]
    if scenario_param == "timino":
        title += "under i.i.d violation"
    elif scenario_param.startswith("measure_err"):
        title += rf"under measure error ($\theta={theta}$)"
    elif scenario_param.startswith("confounded"):
        title += rf"with latent confounders ($\theta={theta}$)"
    elif scenario_param.startswith("linear"):
        title += rf"with linear mechanisms ($\theta={theta}$)"
    elif scenario_param.startswith("unfaithful"):
        title += rf"with unfaithful distribution ($\theta={theta}$)"

    return title


def make_ax_title(graph_type, density):
    """Return a human-readable subplot title for ``graph_type``.

    Erdos-Renyi and scale-free titles are suffixed with ``density``; the
    fully connected and GRP titles ignore it. Unknown types yield ``None``.
    """
    fixed_titles = (("FC", "Fully connected"), ("GRP", "Gaussian Random Partition"))
    for key, label in fixed_titles:
        if graph_type == key:
            return label
    density_titles = (("ER", "Erdos Renyi"), ("SF", "Scale free"))
    for key, prefix in density_titles:
        if graph_type == key:
            return f"{prefix} {density}"
    return None

def make_ax_title_with_size(graph_type, density, size):
    """Return a compact subplot title such as ``"ER-20 dense"`` or ``"FC-10"``.

    Only ER and SF titles include ``density``; unknown graph types yield ``None``.
    """
    if graph_type in ("FC", "GRP"):
        return f"{graph_type}-{size}"
    if graph_type in ("ER", "SF"):
        return f"{graph_type}-{size} {density}"
    return None

def set_box_color(bp, color):
    """Colour the boxes, whiskers, caps and medians of boxplot artists ``bp``."""
    for part in ("boxes", "whiskers", "caps", "medians"):
        plt.setp(bp[part], color=color)


def adjacent_values(vals, q1, q3):
    """Return the ``(lower, upper)`` whisker ends for a box-and-whisker glyph.

    Each whisker reaches 1.5 interquartile ranges past its quartile, clipped
    to ``[vals[0], q1]`` below and ``[q3, vals[-1]]`` above. ``vals`` is
    presumably sorted ascending (TODO confirm at call sites).
    """
    spread = (q3 - q1) * 1.5
    upper = np.clip(q3 + spread, q3, vals[-1])
    lower = np.clip(q1 - spread, vals[0], q1)
    return lower, upper
    
def update_params_transfer_records(
        method : str,
        graph_type : str,
        scenario : str,
        scenario_param : str,
        noise : str,
        sample_size : str,
        density : str,
        graph_size : str,
        metric : str,
        seaborn_data : List,
        vanilla_scenario : str = None,
        vanilla_scenario_param : str = None,
        default_param = 0.05
):
    """Update records for hyperparams transfer comparison.
    NOTE: vanilla is w.r.t. model assumptions.

    Reads the raw logs of ``scenario`` and keeps the runs matching noise type,
    sample size, graph size and density. For each run it computes
    ``|metric(best param) - metric(reference param)|``. The best param is the
    one selected by ``find_best_param`` on ``scenario`` itself. The reference
    param is selected on ``vanilla_scenario``/``vanilla_scenario_param``, or
    is ``default_param`` when ``vanilla_scenario`` is None.

    The differences, padded to 10 values by ``sanitize_y``, are appended to
    ``seaborn_data`` as ``[scenario, value, method, metric]`` rows.
    """
    method_param = tunable_parameters(method)

    # Read the log file
    df_raw = get_raw_df(method, graph_type, scenario, scenario_param)

    # Filter by noise and sample size
    df_raw = df_raw[(df_raw.noise == noise) & (df_raw.samples == sample_size)]

    # Filter by density and graph size
    df_raw_density_size = df_raw[(df_raw.density == density) & (df_raw.loc[:, "size"] == graph_size)]

    # Runs obtained with the param tuned on this scenario
    best_param_value = find_best_param(
        method, graph_type, scenario, scenario_param, noise, sample_size, density, graph_size
    )
    best_series = df_raw_density_size[df_raw_density_size.loc[:, method_param] == best_param_value]

    # Runs obtained with the vanilla-tuned or the given default param
    if vanilla_scenario is not None:
        best_param_value_vanilla = find_best_param(
            method, graph_type, vanilla_scenario, vanilla_scenario_param, noise, sample_size, density, graph_size
        )
    else:
        best_param_value_vanilla = default_param
    best_series_vanilla_param = df_raw_density_size[df_raw_density_size.loc[:, method_param] == best_param_value_vanilla]

    if len(best_series) == len(best_series_vanilla_param):
        # Runs are paired by row position (assumes both slices list seeds in
        # the same order — TODO confirm). Convert to a list so sanitize_y
        # pads it instead of evaluating ``ndarray += list`` as a broadcast add.
        y = np.abs(
            np.array(best_series.loc[:, metric].to_list())
            - np.array(best_series_vanilla_param.loc[:, metric].to_list())
        ).tolist()
    else:
        # Pair runs explicitly by seed id (0-9); seeds missing on either side count as 0.
        y = []
        for seed in range(10):
            seed_record = best_series[best_series.seed_id == seed]
            seed_record_vanilla = best_series_vanilla_param[best_series_vanilla_param.seed_id == seed]
            # Exactly one run per side is needed: the old "sum == 2" check also
            # accepted (2, 0) and then crashed in .item().
            if len(seed_record) == 1 and len(seed_record_vanilla) == 1:
                # Use the requested metric (this branch was hard-coded to f1).
                y.append(abs(
                    seed_record[metric].item() - seed_record_vanilla[metric].item()
                ))
            else:
                y.append(0)

    y = sanitize_y(y, target_len=10)
    for record in y:
        seaborn_data.append([scenario, record, method, metric])


def update_ax_records(
        method : str,
        graph_type : str,
        scenario : str,
        scenario_param : str,
        noise : str,
        sample_size : int,
        density : str,
        graph_size : str,
        metric : str,
        seaborn_data : List,
        vanilla_scenario : str = None,
        vanilla_scenario_param : str = None,
        log=False
):
    """Append one violin-plot record per run of ``method`` to ``seaborn_data``.

    Runs are read from the raw logs of ``scenario``/``scenario_param`` and kept
    when they match noise type, sample size, density and graph size. The
    hyperparameter value is chosen by ``find_best_param`` on the reference
    configuration (``vanilla_scenario``/``vanilla_scenario_param``, each
    defaulting to the plotted scenario). Only runs with that value are kept.

    The metric values are padded to 10 entries by ``sanitize_y``. They are
    optionally converted to log10 when ``log`` is True. Exact zeros are then
    jittered, and each value is appended as
    ``[scenario, value, method, metric, noise]``.
    """
    ref_scenario = scenario if vanilla_scenario is None else vanilla_scenario
    ref_param = scenario_param if vanilla_scenario_param is None else vanilla_scenario_param

    param_column = tunable_parameters(method)

    runs = get_raw_df(method, graph_type, scenario, scenario_param)
    runs = runs[(runs.noise == noise) & (runs.samples == sample_size)]
    runs = runs[(runs.density == density) & (runs.loc[:, "size"] == graph_size)]

    chosen_value = find_best_param(
        method, graph_type, ref_scenario, ref_param, noise, sample_size, density, graph_size
    )
    selected = runs[runs.loc[:, param_column] == chosen_value]
    values = sanitize_y(selected.loc[:, metric].to_list(), target_len=10)

    for value in values:
        if log:
            value = math.log(value, 10)
        if value == 0:
            # Tiny positive jitter so an all-zero violin still has a width.
            value += abs(np.random.normal(0, 0.001))
        seaborn_data.append([scenario, value, method, metric, noise])


def update_ax_with_sample_size(
        method : str,
        graph_type : str,
        scenario : str,
        scenario_param : str,
        noise : str,
        sample_size : int,
        density : str,
        graph_size : str,
        metric : str,
        seaborn_data : List,
        vanilla_scenario : str = None,
        vanilla_scenario_param : str = None,
        log=False
):
    """Run ``update_ax_records`` and tag every newly added record with ``sample_size``.

    Returns
    -------
    List
        The same ``seaborn_data`` list, extended in place; each new record
        is replaced by a copy with ``sample_size`` appended.
    """
    first_new = len(seaborn_data)
    update_ax_records(
        method=method,
        graph_type=graph_type,
        scenario=scenario,
        scenario_param=scenario_param,
        noise=noise,
        sample_size=sample_size,
        density=density,
        graph_size=graph_size,
        metric=metric,
        seaborn_data=seaborn_data,
        vanilla_scenario=vanilla_scenario,
        vanilla_scenario_param=vanilla_scenario_param,
        log=log,
    )

    for idx in range(first_new, len(seaborn_data)):
        seaborn_data[idx] = seaborn_data[idx] + [sample_size]

    return seaborn_data


def find_best_param(
        method : str,
        graph_type : str,
        scenario : str,
        scenario_param : str,
        noise : str,
        sample_size : str,
        density : str,
        graph_size : str,
):
    """Return the hyperparameter value with the lowest mean SHD.

    Reads ``stats_{method}.csv`` for the given scenario configuration and
    keeps the rows matching noise, sample size, density and graph size. It
    returns the value of the method's tunable-parameter column on the row
    with the lowest mean SHD. SHD cells are parsed as the text before
    ``"+-"`` (``"mean+-std"``); plain numeric cells are accepted too.

    Raises
    ------
    ValueError
        If no stats row matches the requested configuration.
    """
    method_param = tunable_parameters(method)
    scenario_dir = os.path.join(get_logs_dir(scenario), graph_type, scenario)
    logs_dir = os.path.join(scenario_dir, scenario_param, method)
    stats_file = os.path.join(logs_dir, f"stats_{method}.csv")

    # Read the log file
    df_stats = pd.read_csv(stats_file)

    # Filter by noise and sample size
    df_stats = df_stats[(df_stats.noise == noise) & (df_stats.samples == sample_size)]

    # Filter by density and graph size
    df_stats_density_size = df_stats[(df_stats.density == density) & (df_stats.loc[:, "size"] == graph_size)]

    if df_stats_density_size.empty:
        # Fail with context instead of argmin's "empty sequence" error.
        raise ValueError(
            f"No stats rows in {stats_file} for noise={noise!r}, samples={sample_size!r}, "
            f"density={density!r}, size={graph_size!r}"
        )

    # Mean SHD per hyperparameter; str() also lets purely numeric cells through.
    series = df_stats_density_size.loc[:, "shd"].apply(lambda val: float(str(val).split("+-")[0]))

    best_param_index = arg_best_shd(series)  # positional index
    best_param_value = df_stats_density_size.loc[:, method_param].iloc[best_param_index]

    return best_param_value


def get_raw_df(
    method : str,
    graph_type : str,
    scenario : str,
    scenario_param : str,
):
    """Load the per-run results file ``raw_{method}.csv`` of one configuration.

    The file is looked up under
    ``<get_logs_dir(scenario)>/<graph_type>/<scenario>/<scenario_param>/<method>/``.
    """
    raw_file = os.path.join(
        get_logs_dir(scenario),
        graph_type,
        scenario,
        scenario_param,
        method,
        f"raw_{method}.csv",
    )
    return pd.read_csv(raw_file)


def sanitize_y(y : List[float], target_len=20, seed=17):
    """Pad ``y`` with Gaussian samples so it holds at least ``target_len`` values.

    Missing entries are drawn from ``N(mean(y), std(y))``.

    Parameters
    ----------
    y : list of float or numpy.ndarray
        Observed values. The input object is never modified.
    target_len : int
        Desired minimum length; inputs already this long are returned as-is.
    seed : int
        Seed for NumPy's global RNG. It is reset on every call, even when no
        padding is needed. This also makes later ``np.random`` draws
        reproducible, e.g. the zero-jitter in ``update_ax_records``.

    Returns
    -------
    list or the original object
        ``y`` itself when it is long enough, otherwise a new padded list.
    """
    np.random.seed(seed)
    missing_values = target_len - len(y)
    if missing_values <= 0:
        return y
    # Build a new list rather than ``y += ...``: for a numpy array that would
    # broadcast an element-wise addition instead of extending, and for a list
    # it would mutate the caller's object.
    observed = y.tolist() if isinstance(y, np.ndarray) else list(y)
    padding = np.random.normal(np.mean(observed), np.std(observed), (missing_values,))
    return observed + padding.tolist()


def make_ax_violinplot(
        seaborn_df : pd.DataFrame,
        ax : plt.Axes,
        metric : str,
        colors : List[str],
        set_ylim : bool,
        title : str = None,
        x : str = "scenario",
        duplicate = True
):
    """Draw a violin plot of ``metric`` on ``ax`` and return the seaborn axes.

    Parameters
    ----------
    seaborn_df : pandas.DataFrame
        Long-format records containing the ``metric`` column, the ``x``
        column and a ``"hue"`` column.
    ax : matplotlib.axes.Axes
        Target subplot.
    metric : str
        Column plotted on the y axis.
    colors : list of str
        Palette, assigned to the plotted levels in order.
    set_ylim : bool
        If True, clamp the y axis at 0 from below. The metrics f1, fnr, fpr,
        tpr, tnr and dtop_fnr are always limited to [0, 1].
    title : str, optional
        Axes title.
    x : str
        Column used on the x axis. With ``"hue"`` there is one violin per hue
        level; otherwise violins are additionally split by the ``"hue"`` column.
    duplicate : bool
        If True, every 4th collection starting at index 4 is restyled as a
        white violin outlined in its own colour. NOTE(review): the stride
        depends on how many artists seaborn creates per violin and may break
        across seaborn versions.
    """
    if x == "hue":
        sns_ax = sns.violinplot(
            ax=ax, data=seaborn_df, x=x, y=metric, palette=colors, scale='width', cut=0
        )
    else:
        sns_ax = sns.violinplot(
            ax=ax, data=seaborn_df, x=x, y=metric, hue="hue", palette=colors, scale='width', cut=0
        )

    if duplicate:
        # Presumably targets the paired baseline violins: outline-only in the same colour.
        for violin in sns_ax.collections[4::4]:
            violin.set_edgecolor(violin.get_facecolor())
            violin.set_facecolor('white')

    if title is not None:
        sns_ax.set_title(title)

    sns_ax.spines["bottom"].set_color('black')
    # Restrict to this axes: without ``ax`` despine acts on every axes of the
    # *current* figure, which need not be the figure that owns ``ax``.
    sns.despine(ax=sns_ax, left=True)

    if metric in ("f1", "fnr", "fpr", "tpr", "tnr", "dtop_fnr"):
        sns_ax.set_ylim(bottom=0, top=1)
    elif set_ylim:
        sns_ax.set_ylim(bottom=0)
    return sns_ax

# Method palette: husl base with hand-picked overrides (some are Kelly colours).
def get_custom_palette(duplicate=True, lingam=False):
    """Return the per-method palette as a list of hex colours.

    Parameters
    ----------
    duplicate : bool
        If True, every colour after the first is repeated once, giving
        ``[c0, c1, c1, c2, c2, ...]``; the first colour stays single.
    lingam : bool
        If True, ``lingam_color()`` is inserted before the last two colours
        (before any duplication).
    """
    colours = sns.color_palette('husl', 10).as_hex()
    overrides = {4: "#a1caf1", 5: "#C45673", 6: "#f3c300", 7: "#2485C0"}
    for position, hex_code in overrides.items():
        colours[position] = hex_code

    if lingam:
        colours.insert(-2, lingam_color())

    if duplicate:
        # Keep the leading colour single and pair every following one (in place).
        colours[:] = colours[:1] + [c for c in colours[1:] for _ in range(2)]

    return colours


def get_seaborn_palette(scenario = None, n_colors=11):
    """Return a husl palette of ``n_colors`` colours.

    Unless ``scenario == "vanilla"``, every colour after the first is repeated
    once, giving each method an adjacent pair of slots for its
    misspecified/vanilla violins.
    """
    palette = sns.color_palette('husl', n_colors)
    if scenario == "vanilla":
        return palette
    # Duplicate in place so the returned object keeps seaborn's palette type.
    palette[:] = palette[:1] + [rgb for rgb in palette[1:] for _ in range(2)]
    return palette


def execution_time(
    scenarios, base_fig_dir, graph_type="ER", density="dense", sample_size=1000
):
    """Plot per-method execution times (log10 seconds) on the vanilla scenario.

    Draws one violin subplot per graph size of ``graph_type`` and saves the
    figure to ``<base_fig_dir>/time.pdf``.

    Parameters
    ----------
    scenarios : list of str or None
        Not used for plotting: only the ``"vanilla"`` scenario is drawn.
        ``None`` is replaced by ``inference_scenarios()``.
    base_fig_dir : str
        Output directory.
    graph_type, density : str
        Graph family and density to read.
    sample_size : int
        Dataset size to read.
    """
    # NOTE(review): ``get_colors`` is not defined in the visible part of this
    # module — confirm it exists.
    colors = get_colors()

    if scenarios is None:
        scenarios = inference_scenarios()

    sizes = graph_sizes(graph_type)
    n_cols = len(sizes)
    # squeeze=False keeps ``axes`` 2-D even when there is a single graph size.
    fig, axes = plt.subplots(1, n_cols, figsize=(5*n_cols, 5), squeeze=False)
    fig_path = base_fig_dir + "/time.pdf"

    scenario = "vanilla"
    scenario_param = "vanilla"
    metric = "time [s]"
    noise = "gauss"
    for col, graph_size in enumerate(sizes):
        seaborn_data = []
        for method in methods():
            # Random is skipped; GES and PC are skipped on "large50" graphs.
            if method == "random" or (method in ["ges", "pc"] and graph_size == "large50"):
                continue
            update_ax_records(
                method=method,
                graph_type=graph_type,
                scenario=scenario,
                scenario_param=scenario_param,
                noise=noise,
                sample_size=sample_size,
                density=density,
                graph_size=graph_size,
                metric=metric,
                seaborn_data=seaborn_data,
                log=True
            )

        # update_ax_records appends [scenario, value, method, metric, noise].
        seaborn_df = pd.DataFrame(seaborn_data, columns=["scenario", metric, "method", "metric", "noise"])
        seaborn_df["hue"] = seaborn_df["method"]

        ax_title = scenario + " " + graph_type + "-" + str(graph_sizes_to_nodes()[graph_size])
        ax = make_ax_violinplot(
            seaborn_df=seaborn_df,
            ax=axes[0, col],
            metric=metric,
            colors=colors,
            title=ax_title,
            # Values are log10 seconds and are negative below 1 s: do not clamp at 0.
            set_ylim=False,
            x="hue",
            # Single scenario: there are no paired baseline violins to restyle.
            duplicate=False,
        )
        ax.tick_params(axis='x', rotation=45)
        tick_locs = ax.get_yticks()[1::2]
        tick_lbls = [f"{10 ** val:.3g}" for val in tick_locs]
        ax.set_yticks(tick_locs, tick_lbls)
        ax.set_ylabel(ax.get_ylabel() + " (log scale)")

    # Save once, after every subplot has been drawn.
    fig.tight_layout()
    fig.savefig(fig_path)
    plt.close("all")


# ------------ Paper plots utils ------------ # 
def varsort_baseline(
        graph_type,
        noise,
        scenario,
        scenario_param,
        graph_size,
        density,
        n_samples,

):
    """Return the mean ``dtop_fnr`` reported for the ``varsort`` method.

    Reads ``stats_varsort.csv`` from the ``inference`` logs of the given
    scenario configuration. The value is parsed from the first matching
    ``"mean+-std"`` cell.

    Raises
    ------
    IndexError
        If no row matches the requested configuration.
    """
    stats_file = (
        f"/home/ec2-user/causal-benchmark/tmp/logs/inference/{graph_type}/"
        f"{scenario}/{scenario_param}/varsort/stats_varsort.csv"
    )
    stats = pd.read_csv(stats_file, header=0)
    matches = (
        (stats["samples"] == n_samples)
        & (stats["noise"] == noise)
        & (stats["size"] == graph_size)
        & (stats["density"] == density)
    )
    first_cell = stats.loc[matches, "dtop_fnr"].iloc[0]
    mean_text, _, _ = first_cell.partition("+-")
    return float(mean_text.strip())

def make_legend(colors):
    """Return one thick ``Line2D`` legend handle per colour in ``colors``."""
    handles = []
    for colour in colors:
        handles.append(Line2D([0], [0], color=colour, lw=4))
    return handles

def make_vanilla_scenario_legend():
    """Return two legend handles: a filled gray patch and a white patch outlined in gray.

    ``vanilla_scenario_legend`` pairs them with labels such as
    ``["misspecified", "vanilla"]``.
    """
    return [
        Patch(facecolor='gray', edgecolor='gray'),
        # Numeric linewidth instead of the string "2".
        Patch(facecolor='white', edgecolor='gray', lw=2),
    ]

def lingam_color():
    """Return the hex colour reserved for the LiNGAM method in the palettes."""
    colour = "#00998F"
    return colour

def get_scoresort_palette():
    """Return the three colours used for the SCORE / NoGAM / ScoreSort plots.

    The first two are entries 3 and 4 of the non-duplicated custom palette;
    the third is a fixed hex colour.
    """
    base = get_custom_palette(duplicate=False)
    first, second = base[3], base[4]
    return [first, second, "#E8C663"]

def get_density(density, graph_type):
    """Map a requested density to the density label used in the logs.

    ER and SF graphs keep ``density``; FC maps to ``"full"`` and GRP to
    ``"cluster"``.

    Raises
    ------
    ValueError
        For any other ``graph_type``.
    """
    if graph_type in ("ER", "SF"):
        return density
    for name, label in (("FC", "full"), ("GRP", "cluster")):
        if graph_type == name:
            return label
    raise ValueError(f"Graph type {graph_type} not recognized")

# --------------- Paper plots --------------- #

def vanilla_scenario_legend(
    base_fig_dir = "/home/ec2-user/causal-benchmark/paper_plots",
    labels = ["misspecified", "vanilla"]
):
    """Save a standalone legend of a filled and an outlined patch.

    The two entries are labelled by ``labels``; output goes to
    ``<base_fig_dir>/vanilla_scenario_legend.pdf``.
    """
    out_path = base_fig_dir + "/vanilla_scenario_legend.pdf"
    fig, ax = plt.subplots(1, 1)
    handles = make_vanilla_scenario_legend()
    ax.legend(
        handles,
        labels,
        ncol=len(labels),
        loc='best',
        fontsize=17,
        bbox_to_anchor=(1, 1.),
        borderaxespad=0,
    )
    plt.axis("off")
    fig.savefig(out_path)
    plt.close("all")


def legend(
    base_fig_dir = "/home/ec2-user/causal-benchmark/paper_plots",
    lingam_method=False
):
    """Save the method-colour legend used by the paper plots.

    Writes ``<base_fig_dir>/legend.pdf``, or ``legend_lingam.pdf`` when
    ``lingam_method`` is true, with one entry per method returned by
    ``methods(lingam=lingam_method)``.
    """
    # One colour per method (no duplication); get_custom_palette inserts the
    # LiNGAM colour itself when asked to.
    legend_colors = get_custom_palette(duplicate=False, lingam=lingam_method)

    fig, ax = plt.subplots(1, 1, figsize=(20, 20))
    fig_path = base_fig_dir + f"/legend{'_lingam' if lingam_method else ''}.pdf"
    lines = make_legend(legend_colors)
    labels = list(methods(lingam=lingam_method))
    ax.legend(lines, labels, ncol=len(labels), loc='best', fontsize=17, bbox_to_anchor=(1.1, 1.), borderaxespad=0)
    plt.axis("off")
    fig.savefig(fig_path, pad_inches=0.0)
    plt.close("all")



def hyperparameters_stability(
    base_fig_dir = "/home/ec2-user/causal-benchmark/paper_plots",
    scenario : str = "vanilla",
    noise : str = "gauss",
    graph_type : str = "ER",
    sample_size : int = 1000,
    metric : str = "f1",
):
    """Plot how ``metric`` changes when a default hyperparameter replaces the tuned one.

    For every (density, graph size) pair, one violin per method shows the
    absolute difference in ``metric`` between the tuned and the default
    hyperparameter, as computed by ``update_params_transfer_records``. The
    default is 0.5 for GES and 0.05 otherwise. ``random`` and ``lingam`` are
    excluded, and GES/PC are skipped on ``large50`` graphs.

    Saves ``stability_vs_default_{metric}.pdf`` and ``legend_hyperparms.pdf``
    in ``base_fig_dir``.
    """
    scenario_params_dict = {
        "vanilla" : "vanilla",
        "timino" : "timino",
        "linear" : "linear_0.99",
        "unfaithful" : "unfaithful_1.0",
        "confounded" : "confounded_0.2",
        "measure_err" : "measure_err_0.8",
        "pnl" : "pnl_3.0"
    }
    scenario_param = scenario_params_dict[scenario]
    # Drop the first colour, which pairs with "random" (also dropped from the legend labels).
    colors = get_custom_palette(duplicate=False)[1:]

    densities = ["sparse", "dense"]
    sizes = graph_sizes(graph_type)
    n_rows = len(densities)
    n_cols = len(sizes)
    # squeeze=False keeps ``axes`` 2-D even when there is a single graph size.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 5*n_rows), squeeze=False)
    fig_path = base_fig_dir + f"/stability_vs_default_{metric}.pdf"

    for row, density in enumerate(densities):
        for col, graph_size in enumerate(sizes):
            seaborn_data = []
            for method in methods(lingam=False):
                if method in ["random", "lingam"] or (method in ["ges", "pc"] and graph_size == "large50"):
                    continue
                default_param = 0.5 if method == "ges" else 0.05
                update_params_transfer_records(
                    method=method,
                    graph_type=graph_type,
                    scenario=scenario,
                    scenario_param=scenario_param,
                    noise=noise,
                    sample_size=sample_size,
                    density=density,
                    graph_size=graph_size,
                    metric=metric,
                    seaborn_data=seaborn_data,
                    vanilla_scenario=None,
                    vanilla_scenario_param=None,
                    default_param=default_param
                )

            ax_title = scenario + " " + graph_type + "-" + str(graph_sizes_to_nodes()[graph_size]) + "-" + density
            seaborn_df = pd.DataFrame(seaborn_data, columns=["scenario", metric, "method", "metric"])
            seaborn_df["hue"] = seaborn_df["method"]
            ax = make_ax_violinplot(
                seaborn_df=seaborn_df,
                ax=axes[row, col],
                metric=metric,
                colors=colors,
                set_ylim=True,
                x="hue",
                duplicate=False
            )
            ax.set(xlabel=None)
            ax.tick_params(axis='x', rotation=45, labelsize=13, left=True)
            # Label follows the plotted metric instead of always reading "f1".
            ax.set_ylabel(f"|{metric}" + r"$_{diff}$|", fontsize=18)
            ax.yaxis.set_tick_params(labelsize=14)
            ax.set_title(ax_title, fontsize=18)

    # Save once, after every subplot has been drawn.
    fig.tight_layout()
    fig.savefig(fig_path)
    plt.close("all")

    # Make legend
    fig, ax = plt.subplots(1, 1, figsize=(20, 20))
    lines = make_legend(colors)
    labels = methods(lingam=False)[1:]  # Remove random, matching colors[1:]
    ax.legend(lines, labels, ncol=len(labels), loc='best', fontsize=17, bbox_to_anchor=(1, 1.), borderaxespad=0)
    plt.axis("off")
    fig.savefig(base_fig_dir + "/legend_hyperparms.pdf")
    plt.close("all")


def fnrtop_f1(
    base_fig_dir = "/home/ec2-user/causal-benchmark/paper_plots",
    graph_type : str = "ER",
    graph_size : str = 'large20',
    noise="gauss",
    sample_size : int = 1000,
    density : str = "dense",
    metrics : List[str] = ["f1", "dtop_fnr"],
    lingam_scenario : bool = False,
    lingam_model : bool = False,
    baseline : str = "vanilla",
    vanilla : bool = False
):
    """Plot the requested metrics (default: F1 and FNR-pi_hat) for each misspecified scenario.

    Saves one figure per scenario as
    ``<base_fig_dir>/<graph_type>_<metrics in reverse order joined by "_vs_">_<scenario_param>_<noise>.pdf``,
    then a method legend.

    Parameters
    ----------
    base_fig_dir : str
        Output directory.
    graph_type : str
        Graph family.
    graph_size : str
        Graph size key (e.g. ``"large20"``).
    noise : str
        Distribution of the noise terms. The ``"linear"`` scenario is always
        read with ``"nonlin_strong"`` noise.
    sample_size : int
        Size of the dataset.
    density : str
        Density of the graph, mapped through ``get_density`` (FC/GRP graphs).
    metrics : list of str
        Metrics to plot, one subplot each. GES and PC are skipped for
        ``"d_top"``/``"dtop_fnr"`` and on ``"large50"`` graphs.
    lingam_scenario : bool
        If True, include the LiNGAM scenario.
    lingam_model : bool
        If True, include the LiNGAM method.
    baseline : str
        Records paired with each non-random method: ``"vanilla"`` uses the
        vanilla scenario with Gaussian noise; ``"gauss"`` uses the same
        scenario with Gaussian noise. Any other value adds no pair.
    vanilla : bool
        If True, include the vanilla scenario.
    """
    scenario_title = {
        "vanilla" : "Vanilla",
        "timino" : "Autoregressive model",
        "linear" : "LiNGAM",
        "unfaithful" : "Unfaithful model",
        "confounded" : "Confounded model",
        "measure_err" : "Measurement error model",
        "pnl" : "PNL model"
    }
    ax_title = {
        "f1" : "F1 score",
        "dtop_fnr" : "FNR" + '-' + r'${\hat \pi}$',
        "fnr" : "FNR",
        "fpr" : "FPR"
    }
    y_label = {
        "f1": "f1",
        "dtop_fnr" : "fnr",
        "fnr" : "fnr",
        "fpr" : "fpr"
    }

    density = get_density(density, graph_type)

    # Duplicated palette: a single leading colour, then an identical pair per
    # following method (violin + paired baseline violin).
    colors = get_custom_palette(duplicate=True)
    if lingam_model:
        colors.insert(-4, lingam_color())
        colors.insert(-4, lingam_color())

    scenarios = inference_scenarios(vanilla=vanilla, lingam=lingam_scenario)

    n_rows = len(scenarios)
    n_cols = len(metrics)
    # For two metrics this equals f"{metrics[1]}_vs_{metrics[0]}" as before.
    metrics_tag = "_vs_".join(reversed(metrics))
    for row, scenario in enumerate(scenarios):
        scenario_param = scenario_configs(scenario)[-1]
        # squeeze=False keeps ``axes`` 2-D even for a single metric.
        fig, axes = plt.subplots(1, n_cols, figsize=(5*n_cols, 4), squeeze=False)
        fig_path = base_fig_dir + f"/{graph_type}_{metrics_tag}_{scenario_param}_{noise}.pdf"
        for col, metric in enumerate(metrics):
            seaborn_data = []
            for method in methods(lingam=lingam_model):
                if method in ["ges", "pc"] and (
                    metric in ("d_top", "dtop_fnr") or graph_size == "large50"
                ):
                    continue
                noise_plot = noise if scenario != "linear" else "nonlin_strong"
                update_ax_records(
                    method=method,
                    graph_type=graph_type,
                    scenario=scenario,
                    scenario_param=scenario_param,
                    noise=noise_plot,
                    sample_size=sample_size,
                    density=density,
                    graph_size=graph_size,
                    metric=metric,
                    seaborn_data=seaborn_data
                )

                # Pair each method (except random) with a baseline run
                if method == "random":
                    continue
                if baseline == "vanilla":
                    update_ax_records(
                        method=method,
                        graph_type=graph_type,
                        scenario="vanilla",
                        scenario_param="vanilla",
                        noise='gauss',
                        sample_size=sample_size,
                        density=density,
                        graph_size=graph_size,
                        metric=metric,
                        seaborn_data=seaborn_data
                    )
                elif baseline == "gauss":
                    update_ax_records(
                        method=method,
                        graph_type=graph_type,
                        scenario=scenario,
                        scenario_param=scenario_param,
                        noise="gauss",
                        sample_size=sample_size,
                        density=density,
                        graph_size=graph_size,
                        metric=metric,
                        seaborn_data=seaborn_data
                    )

            # The last scenario row is not clamped at 0.
            set_ylim = row + 1 != n_rows

            # Records are [scenario, value, method, metric, noise]
            seaborn_df = pd.DataFrame(seaborn_data, columns=["scenario", metric, "method", "metric", "noise"])
            seaborn_df["hue"] = (
                seaborn_df["method"] + " (" + seaborn_df["scenario"] + " " + seaborn_df["noise"] + ")"
            )
            seaborn_df["metric"] = seaborn_df["method"] + " (" + seaborn_df["scenario"] + ")"
            ax = make_ax_violinplot(
                seaborn_df=seaborn_df,
                ax=axes[0, col],
                metric=metric,
                colors=colors,
                set_ylim=set_ylim,
                x="hue",
                duplicate=True
            )
            ax.tick_params(axis='x', which="both", bottom=False, labelbottom=False)
            ax.tick_params(left=False)
            ax.set_ylabel(y_label[metric], fontsize=18)
            ax.yaxis.set_tick_params(labelsize=14)
            ax.set(xlabel=None)
            ax.set_title(ax_title[metric], fontsize=18)

        title = scenario_title[scenario]
        if noise != "gauss":
            title += " (non-Gaussian) "
        title += f" - {graph_type}-{graph_sizes_to_nodes()[graph_size]}"
        if density == "dense" or density == "sparse":
            title += f" {density}"
        fig.suptitle(title, fontsize=20)
        fig.tight_layout()
        fig.savefig(fig_path)
        plt.close("all")

    # Legend: one entry per method, so undo the palette duplication (keep the
    # leading colour, then the first colour of every pair). Zipping the
    # duplicated list with single labels would shift colours from label 2 on.
    legend_colors = colors[:1] + colors[1::2]
    fig, ax = plt.subplots(1, 1, figsize=(20, 20))
    lines = make_legend(legend_colors)
    labels = methods(lingam=lingam_model)
    ax.legend(lines, labels, ncol=len(labels), loc='best', fontsize=17, bbox_to_anchor=(0.5, 1.), borderaxespad=0)
    plt.axis("off")
    # NOTE(review): scoresortability() writes the same file name; whichever
    # function runs last overwrites the other's legend.
    fig.savefig(base_fig_dir + "/legend_scoresort.pdf")
    plt.close("all")



def scoresortability(
    base_fig_dir = "/home/ec2-user/causal-benchmark/paper_plots",
    graph_type : str = "ER",
    graph_size : str = 'large20',
    noise="gauss",
    sample_size : int = 1000,
    density : str = "dense",
    lingam_scenario : bool = False,
    duplicate : bool = False
):
    """Plot FNR-pi_hat (``dtop_fnr``) of SCORE, NoGAM and the ScoreSort baseline.

    One figure per scenario (vanilla included) is saved as
    ``<base_fig_dir>/dtop_fnr_vs_f1_<scenario_param>_scoresort.pdf``,
    followed by a three-entry legend.

    Parameters
    ----------
    base_fig_dir : str
        Output directory.
    graph_type : str
        Graph family.
    graph_size : str
        Graph size key (e.g. ``"large20"``).
    noise: str
        Distribution of the noise terms. The ``"linear"`` scenario is always
        read with ``"nonlin_strong"`` noise.
    sample_size : int
        Size of the dataset.
    density : str
        Density of the graph, mapped through ``get_density`` (FC/GRP graphs).
    lingam_scenario : bool
        If True, include the LiNGAM scenario.
    duplicate : bool
        Currently unused. NOTE(review): documented as "add vanilla baseline for
        comparison", but no paired records are drawn.
    """
    scenario_title = {
        "vanilla" : "Vanilla",
        "timino" : "Autoregressive model",
        "linear" : "LiNGAM",
        "unfaithful" : "Unfaithful model",
        "confounded" : "Confounded model",
        "measure_err" : "Measurement error model",
        "pnl" : "PNL model"
    }
    ax_title = {
        "f1" : "F1 score",
        "dtop_fnr" : "FNR" + '-' + r'${\hat \pi}$'
    }
    y_label = {
        "f1": "f1",
        "dtop_fnr" : "fnr"
    }

    # Same density mapping as fnrtop_f1 (FC -> "full", GRP -> "cluster").
    density = get_density(density, graph_type)

    colors = get_scoresort_palette()
    scenarios = inference_scenarios(vanilla=True, lingam=lingam_scenario)
    plot_methods = ["score", "nogam", "scoresort"]

    n_rows = len(scenarios)
    n_cols = 1
    metric = "dtop_fnr"
    for row, scenario in enumerate(scenarios):
        scenario_param = scenario_configs(scenario)[-1]
        fig, ax = plt.subplots(1, n_cols, figsize=(5*n_cols, 4))
        fig_path = base_fig_dir + f"/dtop_fnr_vs_f1_{scenario_param}_scoresort.pdf"
        noise_plot = noise if scenario != "linear" else "nonlin_strong"
        seaborn_data = []
        for method in plot_methods:
            update_ax_records(
                method=method,
                graph_type=graph_type,
                scenario=scenario,
                scenario_param=scenario_param,
                noise=noise_plot,
                sample_size=sample_size,
                density=density,
                graph_size=graph_size,
                metric=metric,
                seaborn_data=seaborn_data,
            )

        # The last scenario row is not clamped at 0.
        set_ylim = row + 1 != n_rows

        # Records are [scenario, value, method, metric, noise]
        seaborn_df = pd.DataFrame(seaborn_data, columns=["scenario", metric, "method", "metric", "noise"])
        seaborn_df["hue"] = seaborn_df["method"] + " (" + seaborn_df["scenario"] + ")"
        seaborn_df["metric"] = seaborn_df["hue"]
        ax = make_ax_violinplot(
            seaborn_df=seaborn_df,
            ax=ax,
            metric=metric,
            colors=colors,
            set_ylim=set_ylim,
            x="hue",
            duplicate=False
        )
        ax.tick_params(axis='x', which="both", bottom=False, labelbottom=False)
        ax.tick_params(left=False)
        ax.set_ylabel(y_label[metric], fontsize=16)
        ax.yaxis.set_tick_params(labelsize=14)
        ax.set(xlabel=None)
        ax.set_title(ax_title[metric], fontsize=16)

        # Build the title from the arguments (it used to be hard-coded to "-20 dense").
        title = scenario_title[scenario] + f" - {graph_type}-{graph_sizes_to_nodes()[graph_size]}"
        if density in ("dense", "sparse"):
            title += f" {density}"
        fig.suptitle(title, fontsize=17.5)
        fig.tight_layout()
        fig.savefig(fig_path)
        plt.close("all")

    # Make legend
    fig, ax = plt.subplots(1, 1, figsize=(20, 20))
    lines = make_legend(colors)
    labels = plot_methods
    ax.legend(lines, labels, ncol=len(labels), loc='best', fontsize=17, bbox_to_anchor=(0.5, 1.), borderaxespad=0)
    plt.axis("off")
    # NOTE(review): fnrtop_f1() writes the same file name; whichever function
    # runs last overwrites the other's legend.
    fig.savefig(base_fig_dir + "/legend_scoresort.pdf")
    plt.close("all")


def plot_size_density(
    base_fig_dir = "/home/ec2-user/causal-benchmark/paper_plots",
    scenario : str = "vanilla",
    noise : str = "gauss",
    graph_type : str = "ER",
    sample_size : int = 1000,
    metrics : "List[str] | None" = None,
):
    """Violin plots of each method's metric across graph sizes, per metric and density.

    For every metric in ``metrics`` and every density in ("sparse", "dense"), a
    single-row figure with one panel per graph size of ``graph_type`` is saved to
    ``{base_fig_dir}/{metric}_{density}_size_density.pdf``. A shared legend is
    written to ``{base_fig_dir}/legend_size_density.pdf``.

    Args:
        base_fig_dir: directory the PDF files are written to.
        scenario: scenario key, mapped to a scenario parameter string through
            ``scenario_params_dict`` below.
        noise: noise type used for every method except "lingam", which always
            uses "nonlin_strong".
        graph_type: graph model; its ``graph_sizes(graph_type)`` define the panels.
        sample_size: sample size of the records to plot.
        metrics: metrics to plot; defaults to ``["f1", "dtop_fnr"]``.

    Raises:
        KeyError: if ``scenario`` or a metric is missing from the lookup tables.
    """
    if metrics is None:
        metrics = ["f1", "dtop_fnr"]

    metric_title = {
        "f1" : "F1 score",
        "dtop_fnr" : "FNR" + '-' + r'${\hat \pi}$'
    }
    y_label = {
        "f1": "f1",
        "dtop_fnr" : "fnr"
        # "dtop_fnr" : r'$FNR_{\hat \pi}$'
    }
    scenario_params_dict = {
        "vanilla" : "vanilla",
        "timino" : "timino",
        "linear" : "linear_0.99",
        "unfaithful" : "unfaithful_1.0",
        "confounded" : "confounded_0.2",
        "measure_err" : "measure_err_0.8",
        "pnl" : "pnl_3.0"
    }

    scenario_param = scenario_params_dict[scenario]
    colors = get_custom_palette(duplicate=False, lingam=True)[1:] # No Random
    # Loop-invariant lookups, hoisted out of the plotting loops.
    sizes = graph_sizes(graph_type)
    all_methods = methods(lingam=True)

    for metric in metrics:
        for density in ["sparse", "dense"]:
            n_rows, n_cols = 1, len(sizes)
            # squeeze=False keeps `axes` 2-D, so indexing also works with a single graph size.
            fig, axes = plt.subplots(
                n_rows, n_cols, figsize=(5*n_cols, 5*n_rows), squeeze=False
            )
            fig_path = base_fig_dir + f"/{metric}_{density}_size_density.pdf"
            for col, graph_size in enumerate(sizes):
                seaborn_data = []
                for method in all_methods:
                    if method == "random":
                        continue
                    # PC and GES are skipped on "large50" graphs and for the dtop_fnr metric.
                    if method in ("ges", "pc") and (graph_size == "large50" or metric == "dtop_fnr"):
                        continue
                    update_ax_records(
                        method=method,
                        graph_type=graph_type,
                        scenario=scenario,
                        scenario_param=scenario_param,
                        noise=noise if method != "lingam" else "nonlin_strong",
                        sample_size=sample_size,
                        density=density,
                        graph_size=graph_size,
                        metric=metric,
                        seaborn_data=seaborn_data
                    )

                ax_title = metric_title[metric] + " - " + graph_type + "-" + str(graph_sizes_to_nodes()[graph_size]) + "-" + density
                # Make df for violinplot
                seaborn_df = pd.DataFrame(seaborn_data, columns = ["scenario", metric, "method", "metric", "noise"])
                # Explicit "hue" column: one violin per method (column copy also works on an empty frame).
                seaborn_df["hue"] = seaborn_df["method"]
                ax = make_ax_violinplot(
                    seaborn_df = seaborn_df,
                    ax = axes[0, col],
                    metric = metric,
                    colors = colors,
                    title=ax_title,
                    set_ylim = True,
                    x="hue", # method
                    duplicate=False
                )

                ax.set(xlabel=None)
                # `left=` only applies to the y-axis, so it is not passed for axis='x'.
                ax.tick_params(axis='x', rotation=45, labelsize=13)
                ax.set_ylabel(y_label[metric], fontsize = 18)
                ax.yaxis.set_tick_params(labelsize=14)
                ax.set_title(ax_title, fontsize=18)

            # Lay out, save and close once, after every panel has been drawn.
            fig.tight_layout()
            fig.savefig(fig_path)
            plt.close("all")

    # Make legend
    fig, ax = plt.subplots(1, 1, figsize=(20, 20))
    lines = make_legend(colors) # No transparent
    labels = all_methods[1:] # No random
    ax.legend(lines, labels, ncol=len(labels), loc='best', fontsize=17,  bbox_to_anchor=(1.1, 1.), borderaxespad=0)
    plt.axis("off")
    fig.savefig(base_fig_dir + "/legend_size_density.pdf")
    plt.close("all")


def plot_sample_size(
    base_fig_dir = "/home/ec2-user/causal-benchmark/paper_plots",
    scenario : str = "vanilla",
    noise : str = "gauss",
    graph_type : str = "ER",
    metrics : "List[str] | None" = None,
):
    """Violin plots comparing each method at 100 vs 1000 samples, across graph sizes.

    For every metric in ``metrics`` and every density in ("sparse", "dense"), a
    single-row figure with one panel per graph size of ``graph_type`` is saved to
    ``{base_fig_dir}/{metric}_{density}_sample_size.pdf``. Every method gets a
    100-sample violin; every method except "random" also gets a 1000-sample violin.
    A legend for the two sample sizes is written to
    ``{base_fig_dir}/100_1000_legend.pdf``.

    Args:
        base_fig_dir: directory the PDF files are written to.
        scenario: scenario key, mapped to a scenario parameter string through
            ``scenario_params_dict`` below.
        noise: noise type used for every method except "lingam", which always
            uses "nonlin_strong".
        graph_type: graph model; its ``graph_sizes(graph_type)`` define the panels.
        metrics: metrics to plot; defaults to ``["f1", "dtop_fnr"]``.

    Raises:
        KeyError: if ``scenario`` or a metric is missing from the lookup tables.
    """
    if metrics is None:
        metrics = ["f1", "dtop_fnr"]

    metric_title = {
        "f1" : "F1 score",
        "dtop_fnr" : "FNR" + '-' + r'${\hat \pi}$'
    }
    y_label = {
        "f1": "f1",
        "dtop_fnr" : "fnr"
        # "dtop_fnr" : r'$FNR_{\hat \pi}$'
    }
    scenario_params_dict = {
        "vanilla" : "vanilla",
        "timino" : "timino",
        "linear" : "linear_0.99",
        "unfaithful" : "unfaithful_1.0",
        "confounded" : "confounded_0.2",
        "measure_err" : "measure_err_0.8",
        "pnl" : "pnl_3.0"
    }

    scenario_param = scenario_params_dict[scenario]
    # NOTE(review): duplicate=True presumably yields one colour per (method, sample size) violin — confirm.
    colors = get_custom_palette(duplicate=True, lingam=True)
    # Loop-invariant lookups, hoisted out of the plotting loops.
    sizes = graph_sizes(graph_type)
    all_methods = methods(lingam=True)

    for metric in metrics:
        for density in ["sparse", "dense"]:
            n_rows, n_cols = 1, len(sizes)
            # squeeze=False keeps `axes` 2-D, so indexing also works with a single graph size.
            fig, axes = plt.subplots(
                n_rows, n_cols, figsize=(5*n_cols, 5*n_rows), squeeze=False
            )
            fig_path = base_fig_dir + f"/{metric}_{density}_sample_size.pdf"
            for col, graph_size in enumerate(sizes):
                seaborn_data = []
                for method in all_methods:
                    # PC and GES are skipped on "large50" graphs and for the dtop_fnr metric.
                    if method in ("ges", "pc") and (graph_size == "large50" or metric == "dtop_fnr"):
                        continue
                    noise_plot = noise if method != "lingam" else "nonlin_strong"
                    # The random baseline is only plotted at 100 samples; records are
                    # appended 100 first, then 1000, for each method.
                    run_sizes = (100,) if method == "random" else (100, 1000)
                    for run_size in run_sizes:
                        update_ax_with_sample_size(
                            method=method,
                            graph_type=graph_type,
                            scenario=scenario,
                            scenario_param=scenario_param,
                            noise=noise_plot,
                            sample_size=run_size,
                            density=density,
                            graph_size=graph_size,
                            metric=metric,
                            seaborn_data=seaborn_data
                        )

                ax_title = metric_title[metric] + " - "+ graph_type + "-" + str(graph_sizes_to_nodes()[graph_size]) + "-" + density
                # Make df for violinplot
                seaborn_df = pd.DataFrame(seaborn_data, columns = ["scenario", metric, "method", "metric", "noise", "sample_size"])
                # Explicit "hue" column "<method> (<sample_size>)", built column-wise so it
                # also works on an empty frame.
                seaborn_df["hue"] = (
                    seaborn_df["method"] + " (" + seaborn_df["sample_size"].astype(str) + ")"
                )
                ax = make_ax_violinplot(
                    seaborn_df = seaborn_df,
                    ax = axes[0, col],
                    metric = metric,
                    colors = colors,
                    title=ax_title,
                    set_ylim = True,
                    x="hue", # method (sample size)
                    duplicate=True
                )

                ax.set(xlabel=None)
                ax.tick_params(axis='x', which="both", bottom=False, labelbottom=False)
                ax.set_ylabel(y_label[metric], fontsize = 18)
                ax.yaxis.set_tick_params(labelsize=14)
                ax.set_title(ax_title, fontsize=18)

            # Lay out, save and close once, after every panel has been drawn.
            fig.tight_layout()
            fig.savefig(fig_path)
            plt.close("all")

    # Make legend
    fig_path = base_fig_dir + "/100_1000_legend.pdf"
    fig, ax = plt.subplots(1, 1)
    lines = make_vanilla_scenario_legend()
    labels = ["100 samples", "1000 samples"]
    ax.legend(lines, labels, ncol=len(labels),loc='best', fontsize=17,  bbox_to_anchor=(1, 1.), borderaxespad=0)
    plt.axis("off")
    fig.savefig(fig_path)
    plt.close("all")



def graph_plots():
    """Plot graphs generated according to ER, SF, GRP, and FC models.
    """
    import networkx as nx
    import plotly.graph_objects as go

    edge_x = []
    edge_y = []
    for edge in G.edges():
        x0, y0 = G.nodes[edge[0]]['pos']
        x1, y1 = G.nodes[edge[1]]['pos']
        edge_x.append(x0)
        edge_x.append(x1)
        edge_x.append(None)
        edge_y.append(y0)
        edge_y.append(y1)
        edge_y.append(None)

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    node_x = []
    node_y = []
    for node in G.nodes():
        x, y = G.nodes[node]['pos']
        node_x.append(x)
        node_y.append(y)

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            # colorscale options
            #'Greys' | 'YlGnBu' | 'Greens' | 'YlOrRd' | 'Bluered' | 'RdBu' |
            #'Reds' | 'Blues' | 'Picnic' | 'Rainbow' | 'Portland' | 'Jet' |
            #'Hot' | 'Blackbody' | 'Earth' | 'Electric' | 'Viridis' |
            colorscale='YlGnBu',
            reversescale=True,
            color=[],
            size=10,
            colorbar=dict(
                thickness=15,
                title='Node Connections',
                xanchor='left',
                titleside='right'
            ),
            line_width=2))