import re
from glob import glob
from zipimport import path_sep

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
from rdkit import Chem

from ts_utils import parse_score, get_pareto_indices, jaccard_metric


# Functions to generate the plots for the paper.
# NOTE: many of these functions are near-duplicates; it would probably have been
# better to generalize them into a couple of shared functions.

def smi2cansmi(smi_in):
    """
    Parse a SMILES string with RDKit and return the SMILES RDKit writes back for it.

    Parameters:
    - smi_in: Input SMILES string.

    Returns the string produced by Chem.MolToSmiles for the parsed molecule.
    NOTE(review): no check is made on the parse result, so an unparsable input
    is passed on to MolToSmiles as-is -- confirm callers only pass valid SMILES.
    """
    return Chem.MolToSmiles(Chem.MolFromSmiles(smi_in))


def get_color(cycle):
    """Return the palette key for a cycle label: 0 for "ref", 1 for anything else."""
    if cycle == "ref":
        return 0
    return 1


def recovery_stats(df):
    """
    Count, for every cycle, how many rows carry a SMILES that also occurs in the "ref" cycle.

    Parameters:
    - df: DataFrame with "cycle" and "SMILES" columns; rows whose cycle is "ref" form the reference set.

    Returns a DataFrame with columns "cycle" and "recovered", one row per cycle
    (in groupby order, the "ref" cycle included).
    """
    reference = df.query("cycle == 'ref'").SMILES
    rows = [
        [cycle_label, int(group.SMILES.isin(reference).sum())]
        for cycle_label, group in df.groupby("cycle")
    ]
    return pd.DataFrame(rows, columns=["cycle", "recovered"])


def plot_recovery_barplot(df, ax=None, xlabel="Cycle"):
    """
    Draw a bar chart of the "recovered" count per cycle and print each count on its bar.

    Parameters:
    - df: DataFrame with "cycle" and "recovered" columns.
    - ax: Optional matplotlib Axes, passed straight through to seaborn.
    - xlabel: Label for the x-axis.
    """
    colors = sns.color_palette("tab10")[1:3]
    # Every bar shares one colour except the last, which is highlighted.
    # The palette is sized from the data (it used to be a fixed 11 + 1) so the
    # highlight stays on the last bar for any number of cycles.
    n_bars = df["cycle"].nunique()
    pal = [colors[0]] * max(n_bars - 1, 0) + [colors[1]]
    ax = sns.barplot(x="cycle", y="recovered", data=df, palette=pal, ax=ax)
    for idx, r in enumerate(df.recovered.values):
        label = str(r)
        # Shift the text left by 0.1 per character beyond the first: 0.1 for
        # two digits and 0.2 for three, as before, and now also defined for
        # one-digit (and longer) labels instead of leaving `offset` unbound.
        offset = 0.1 * (len(label) - 1)
        ax.text(idx - offset, 30, label, color="white", fontweight="bold")
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Number of Top 100 Molecules Found")


def plot_stripplot(df, score_col, ax=None):
    """
    Draw a strip plot of a score column per cycle, colouring the "ref" cycle differently.

    Parameters:
    - df: DataFrame with a "cycle" column and the column named by score_col.
    - score_col: Name of the score column plotted on the y-axis.
    - ax: Optional matplotlib Axes, passed straight through to seaborn.

    The caller's DataFrame is left untouched; the helper "color" column is added
    to a copy only.
    """
    # assign() returns a new frame, so the hue column never leaks into `df`.
    plot_df = df.assign(color=df.cycle.apply(get_color))
    colors = sns.color_palette("tab10")[1:3]
    pal = {0: colors[1], 1: colors[0]}
    ax = sns.stripplot(x="cycle", y=score_col, data=plot_df, hue="color", palette=pal, ax=ax)
    ax.set_ylabel(f"{score_col} (Larger is Better)")
    ax.get_legend().set_visible(False)
    # The x-axis carries no label or ticks here.
    ax.set_xlabel(None)
    ax.set_xticks([])


def plot_ts(df, score_col):
    """
    Build a two-row figure: a score strip plot on top and a recovery bar plot below.

    Parameters:
    - df: DataFrame with "cycle", "SMILES" and the score column.
    - score_col: Name of the score column shown in the strip plot.
    """
    # Recovery counts are computed before any plotting touches the frame.
    recovered = recovery_stats(df)
    sns.set_style('white')
    sns.set_context('talk')
    _, (top_ax, bottom_ax) = plt.subplots(
        2, 1, figsize=(20, 10), gridspec_kw={'height_ratios': [2, 1]}
    )
    plot_stripplot(df, score_col=score_col, ax=top_ax)
    plot_recovery_barplot(recovered, ax=bottom_ax)
    plt.tight_layout()


def compile_results(file_spec, ref_file=None):
    """
    Collect per-run result CSVs into one DataFrame, adding pooled "concat" rows per warmup.

    Parameters:
    - file_spec: Glob pattern of result CSVs. Each file needs "score" and "SMILES"
      columns, and its path must end with two integers: cycle, then warmup.
    - ref_file: Optional CSV of reference molecules, tagged with cycle "ref" and warmup "ref".

    Returns a DataFrame holding, per file, the best-scoring row of each SMILES
    (at most 100 rows), plus the same selection pooled over all files with
    warmup 3 and with warmup 10 (cycle "concat"), plus the reference rows.
    The "cycle" column is converted to str.

    Raises ValueError when no file matches file_spec or a path holds fewer than
    two integers.
    """
    int_re = re.compile("[0-9]+")
    # sorted() makes the row order independent of the filesystem listing order.
    filenames = sorted(glob(file_spec))
    if not filenames:
        raise ValueError(f"No files match {file_spec!r}")

    df_list = []
    for filename in filenames:
        numbers = [int(x) for x in int_re.findall(filename)]
        if len(numbers) < 2:
            raise ValueError(
                f"Expected cycle and warmup integers in {filename!r}, found {numbers}"
            )
        # Only the last two integers are used, so digits in directory names
        # earlier in the path no longer break the parsing.
        cycle, warmup = numbers[-2:]
        df = pd.read_csv(filename)
        df = df.sort_values("score", ascending=False).drop_duplicates("SMILES").head(100).copy()
        df['warmup'] = warmup
        df['cycle'] = cycle
        df_list.append(df)

    combo_df = pd.concat(df_list)
    for warmup_value in (3, 10):
        # Pool every cycle of this warmup and keep the overall top 100.
        pooled = combo_df.query("warmup == @warmup_value")
        pooled = pooled.sort_values("score", ascending=False).drop_duplicates("SMILES").head(100).copy()
        pooled['warmup'] = warmup_value
        pooled['cycle'] = "concat"
        df_list.append(pooled)

    if ref_file:
        ref_df = pd.read_csv(ref_file)
        ref_df['warmup'] = "ref"
        ref_df['cycle'] = "ref"
        df_list.append(ref_df)

    combo_df = pd.concat(df_list)
    combo_df.cycle = combo_df.cycle.astype(str)
    return combo_df


def plot_stripplot2(combo_df, include_ref=True, ax=None):
    """
    Strip plot of "score" per cycle, dodged and coloured by warmup.

    Parameters:
    - combo_df: DataFrame with "cycle", "score" and "warmup" columns.
    - include_ref: When True the "ref" cycle / warmup is added to the plotted categories.
    - ax: Optional matplotlib Axes, passed straight through to seaborn.
    """
    order = [str(n) for n in range(1, 11)] + ["concat"]
    hue_order = [3, 10]
    if include_ref:
        order.append("ref")
        hue_order.append("ref")

    # seaborn falls back to the current Axes when ax is None, so one call covers both cases.
    ax = sns.stripplot(
        x="cycle", y="score", hue="warmup", data=combo_df,
        dodge=True, palette="tab10",
        order=order, hue_order=hue_order, ax=ax,
    )

    ax.legend(loc='upper left', bbox_to_anchor=(1.00, 0.75), ncol=1, title='Warmup')
    ax.set_xlabel(None)
    ax.set_xticks([])
    ax.set_ylabel("Tanimoto Coefficient (Bigger is Better)")

    # Enlarge the legend markers.
    for handle in ax.legend_.legend_handles:
        handle.set_markersize(15)


def plot_recovery_barplot2(combo_df, ax=None, xlabel="Cycle"):
    """
    Bar plot of how many reference SMILES each (cycle, warmup) combination contains.

    Parameters:
    - combo_df: DataFrame with "cycle", "warmup" and "SMILES" columns; rows with warmup "ref" are the reference set.
    - ax: Optional matplotlib Axes, passed straight through to seaborn.
    - xlabel: Label for the x-axis.

    Returns the DataFrame of counts (columns "cycle", "count", "warmup") that was plotted.
    """
    ref_df = combo_df.query("warmup == 'ref'")

    # One row per (cycle, warmup); the reference itself is entered as a fixed 100.
    match_list = []
    for i in combo_df.cycle.unique():
        if i == "ref":
            continue
        for warmup in [3, 10]:
            current_df = combo_df.query("cycle == @i and warmup == @warmup and SMILES in @ref_df.SMILES")
            match_list.append([i, len(current_df), warmup])
    match_list.append(["ref", 100, "ref"])

    match_df = pd.DataFrame(match_list, columns=['cycle', 'count', 'warmup'])
    list_ordering = [str(n) for n in range(1, 11)] + ["concat", "ref"]
    match_df.cycle = pd.Categorical(match_df.cycle, list_ordering, ordered=True)

    # seaborn falls back to the current Axes when ax is None.
    ax = sns.barplot(x="cycle", y="count", hue="warmup", data=match_df,
                     palette="tab10", ax=ax)
    ax.set_ylim(0, 105)
    ax.legend_.remove()
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Number of Top 100 Molecules Found")

    # Write the counts onto the bars: two labels per cycle group for the
    # 11 non-reference groups, then a single "100" for the reference bar.
    labels = match_df.sort_values(['cycle', 'warmup'])["count"].values
    x_val = -0.37
    for first in range(0, 22, 2):
        ax.text(x_val, 50, labels[first], color="white", fontweight="bold")
        ax.text(x_val + 0.28, 50, labels[first + 1], color="white", fontweight="bold")
        x_val = x_val + 1
    ax.text(x_val + 0.48, 50, 100, color="white", fontweight="bold")
    return match_df


def plot_iteration_stripplot2(combo_df, include_ref=True, ax=None):
    """
    Strip plot of "score" per iteration count, dodged and coloured by warmup.

    Parameters:
    - combo_df: DataFrame with "iterations", "score" and "warmup" columns.
    - include_ref: When True the "ref" category is added to the x-axis order and the hue order.
    - ax: Optional matplotlib Axes, passed straight through to seaborn.
    """
    order = ["2000", "5000", "10000", "50000", "100000"]
    hue_order = [3, 10]
    if include_ref:
        order.append("ref")
        hue_order.append("ref")

    # seaborn falls back to the current Axes when ax is None, so one call covers both cases.
    ax = sns.stripplot(
        x="iterations", y="score", hue="warmup", data=combo_df,
        order=order,
        dodge=True, palette="tab10",
        hue_order=hue_order, ax=ax,
    )

    ax.legend(loc='upper left', bbox_to_anchor=(1.00, 0.75), ncol=1, title='Warmup')
    ax.set_xlabel(None)
    ax.set_xticks([])
    ax.set_ylabel("Tanimoto Coefficient (Bigger is Better)")

    # Enlarge the legend markers.
    for handle in ax.legend_.legend_handles:
        handle.set_markersize(15)


def plot_iteration_barplot(combo_df, ax=None, xlabel="Cycle"):
    """
    Bar plot of how many reference SMILES each (iteration count, warmup) combination contains.

    Parameters:
    - combo_df: DataFrame with "warmup" and "SMILES" columns plus a grouping column,
      either "cycle" or "iterations"; rows with warmup "ref" are the reference set.
    - ax: Optional matplotlib Axes, passed straight through to seaborn.
    - xlabel: Label for the x-axis.

    Returns the DataFrame of counts (columns "cycle", "count", "warmup") that was plotted.
    """
    # compile_iteration_data() stores the run label in "iterations", whereas this
    # function used to read "cycle" only; accept either so both kinds of frame work.
    group_col = "cycle" if "cycle" in combo_df.columns else "iterations"

    ref_df = combo_df.query("warmup == 'ref'")
    match_list = []
    for i in combo_df[group_col].unique():
        if i == "ref":
            continue
        for warmup in [3, 10]:
            selected = combo_df[(combo_df[group_col] == i) & (combo_df.warmup == warmup)]
            match_list.append([i, int(selected.SMILES.isin(ref_df.SMILES).sum()), warmup])
    # The reference itself is entered as a fixed 100.
    match_list.append(["ref", 100, "ref"])

    # The output column keeps the name "cycle" so existing callers see the same frame.
    match_df = pd.DataFrame(match_list, columns=['cycle', 'count', 'warmup'])
    list_ordering = ["2000", "5000", "10000", "50000", "100000", "ref"]
    match_df.cycle = pd.Categorical(match_df.cycle, list_ordering, ordered=True)

    # seaborn falls back to the current Axes when ax is None.
    ax = sns.barplot(x="cycle", y="count", hue="warmup", data=match_df,
                     palette="tab10", ax=ax)
    ax.set_ylim(0, 105)
    ax.legend(loc='upper left', bbox_to_anchor=(1.00, 0.75), ncol=1, title='Warmup')
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Number of Top 100 Molecules Found")
    ax.legend_.remove()

    # Write the counts onto the bars: two labels per group for the five
    # iteration settings, then a single "100" for the reference bar.
    x_val = -0.33
    labels = match_df.sort_values(['cycle', 'warmup'])["count"].values
    for i in range(0, 5):
        ax.text(x_val, 50, labels[i * 2], color="white", fontweight="bold")
        ax.text(x_val + 0.28, 50, labels[i * 2 + 1], color="white", fontweight="bold")
        x_val = x_val + 1
    ax.text(x_val + 0.52, 50, 100, color="white", fontweight="bold")

    return match_df


def compile_iteration_data(file_spec, reference_file=None):
    """
    Collect per-run result CSVs into one DataFrame tagged with warmup and iteration count.

    Parameters:
    - file_spec: Glob pattern of result CSVs. Each file needs "score" and "SMILES"
      columns, and its path must end with two integers: warmup, then iterations.
    - reference_file: Optional CSV of reference molecules, tagged with warmup "ref" and iterations "ref".

    Returns a DataFrame holding, per file, the best-scoring row of each SMILES
    (at most 100 rows), plus the reference rows. The "iterations" column is
    converted to str.

    Raises ValueError when a path holds fewer than two integers, or when there
    is nothing to combine (no matching file and no reference file).
    """
    df_list = []
    int_re = re.compile("[0-9]+")
    # sorted() makes the row order independent of the filesystem listing order.
    for filename in sorted(glob(file_spec)):
        numbers = [int(x) for x in int_re.findall(filename)]
        if len(numbers) < 2:
            raise ValueError(
                f"Expected warmup and iterations integers in {filename!r}, found {numbers}"
            )
        # Only the last two integers are used, so digits in directory names
        # earlier in the path no longer break the parsing.
        warmup, iterations = numbers[-2:]
        df = pd.read_csv(filename)
        df = df.sort_values("score", ascending=False).drop_duplicates("SMILES").head(100).copy()
        df['warmup'] = warmup
        df['iterations'] = iterations
        df_list.append(df)

    if reference_file:
        ref_df = pd.read_csv(reference_file)
        ref_df['warmup'] = "ref"
        ref_df['iterations'] = "ref"
        df_list.append(ref_df)

    if not df_list:
        raise ValueError(f"No files match {file_spec!r} and no reference file was given")

    combo_df = pd.concat(df_list)
    combo_df.iterations = combo_df.iterations.astype(str)
    return combo_df


def plot_random_stripplot(ref_filespec, random_filespec, ts_filespec, ax=None):
    """
    Strip plot comparing the scores of the random and "ts" runs against the reference set.

    Parameters:
    - ref_filespec: Path of the reference CSV (read directly with pandas, not globbed).
    - random_filespec: Glob pattern of the random-run CSVs.
    - ts_filespec: Glob pattern of the "ts" run CSVs.
    - ax: Optional matplotlib Axes, passed straight through to seaborn.

    Every CSV needs "score" and "SMILES" columns. The first integer found in each
    globbed path is used (as a string) as that file's cycle label.

    Returns the combined DataFrame that was plotted, with categorical "cycle" and
    "method" columns and the old index kept as an "index" column.
    """
    colors = sns.color_palette("tab10")[1:4]
    int_re = re.compile("([0-9]+)")

    def top_unique(frame):
        # Best-scoring row of each SMILES, at most 100 rows. The .copy() ensures the
        # column assignments below never write into a slice of another frame.
        return frame.sort_values("score", ascending=False).drop_duplicates("SMILES").head(100).copy()

    ref_df = pd.read_csv(ref_filespec)
    ref_df['cycle'] = 'ref'
    ref_df['method'] = 'ref'

    # Random runs: every row of every file is kept; only the pooled "concat" set is trimmed.
    random_df_list = []
    for filename in glob(random_filespec):
        cycle = int_re.findall(filename)[0]
        df = pd.read_csv(filename)
        df['cycle'] = cycle
        df['method'] = 'random'
        random_df_list.append(df)
    concat_df = top_unique(pd.concat(random_df_list))
    concat_df['cycle'] = 'concat'
    concat_df['method'] = 'random'
    random_df_list.append(concat_df)

    # "ts" runs: each file is trimmed to its top 100 before pooling.
    ts_df_list = []
    for filename in glob(ts_filespec):
        cycle = int_re.findall(filename)[0]
        df = top_unique(pd.read_csv(filename))
        df['cycle'] = cycle
        df['method'] = 'ts'
        ts_df_list.append(df)
    concat_df = top_unique(pd.concat(ts_df_list))
    concat_df['cycle'] = 'concat'
    concat_df['method'] = 'ts'
    ts_df_list.append(concat_df)

    display_df = pd.concat([ref_df] + random_df_list + ts_df_list)
    display_df.cycle = pd.Categorical(display_df.cycle,
                                      categories=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "concat", "ref"],
                                      ordered=True)
    display_df.method = pd.Categorical(display_df.method,
                                       categories=["random", "ts", "ref"], ordered=True)
    display_df = display_df.reset_index()
    ax = sns.stripplot(x="cycle", y="score", hue="method", data=display_df, ax=ax, dodge=True,
                       palette={"ref": colors[1], "ts": colors[0], "random": colors[2]})
    ax.legend(loc='upper left', bbox_to_anchor=(1.00, 0.75), ncol=1, title='Method')
    # Enlarge the legend markers.
    for handle in ax.legend_.legend_handles:
        handle.set_markersize(15)
    ax.set_ylabel("Tanimoto Coefficient (Bigger is Better)")
    ax.set_xlabel(None)
    ax.set_xticks([])
    return display_df


def plot_random_recovery_barplot(combo_df, ax=None, xlabel="Replicate"):
    """
    Bar plot of how many reference SMILES each (cycle, method) combination contains.

    Parameters:
    - combo_df: DataFrame with "cycle", "method" and "SMILES" columns; rows with method "ref" are the reference set.
    - ax: Optional matplotlib Axes to draw on; seaborn uses the current Axes when None.
    - xlabel: Label for the x-axis.
    """
    colors = sns.color_palette("tab10")[1:4]
    ref_df = combo_df.query("method == 'ref'")
    match_list = []
    for i in combo_df.cycle.unique():
        if i == "ref":
            continue
        for method in ["random", "ts"]:
            current_df = combo_df.query("cycle == @i and method == @method and  SMILES in @ref_df.SMILES.values")
            match_list.append([i, len(current_df), method])
    # The reference itself is entered as a fixed 100.
    match_list.append(["ref", 100, "ref"])
    match_df = pd.DataFrame(match_list, columns=['cycle', 'count', 'method'])
    match_df.cycle = pd.Categorical(match_df.cycle,
                                    categories=["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "concat", "ref"],
                                    ordered=True)
    match_df.method = pd.Categorical(match_df.method,
                                     categories=["random", "ts", "ref"], ordered=True)
    # ax is forwarded so a caller-supplied Axes is actually used; it was
    # previously ignored and the bars always went to the current Axes.
    ax = sns.barplot(x="cycle", y="count", hue="method", dodge=True, data=match_df,
                     palette={"ref": colors[1], "ts": colors[0], "random": colors[2]}, ax=ax)
    ax.legend(loc='upper left', bbox_to_anchor=(1.00, 0.75), ncol=1, title='Method')
    labels = match_df.sort_values(['cycle', 'method'], ascending=[True, True])["count"].values

    # Write the counts onto the bars: two labels per cycle group for the
    # 11 non-reference groups, then a single "100" for the reference bar.
    x_val = -0.39
    for i in range(0, 11):
        ax.text(x_val + 0.0, 50, labels[i * 2], color="black", fontweight="bold")
        ax.text(x_val + 0.28, 50, labels[i * 2 + 1], color="white", fontweight="bold")
        x_val = x_val + 1
    ax.text(x_val + 0.50, 50, 100, color="white", fontweight="bold")
    ax.set_ylabel("Number of Top 100 Molecules Found")
    ax.set_xlabel(xlabel)
    ax.legend_.remove()
    plt.tight_layout()


def plot_scores_compare(
        dfs,
        labels,
        title="Comparison of Molecule Scores",
        xlabel="Score 1",
        ylabel="Score 2",
        save_png=False,
        save_file=None,
        set_title=True
):
    """
    Scatter-plot two-component scores of several DataFrames against each other.

    Parameters:
    - dfs: DataFrames with a "score" column; each entry is run through parse_score.
    - labels: Legend label for each DataFrame (paired with dfs by position).
    - title/xlabel/ylabel: Text for the plot; the title is only drawn when
      set_title is True and save_png is False.
    - save_png/save_file: When save_png is True and save_file is given, the figure
      is written to save_file as a 300 dpi PNG.

    Raises ValueError when the stacked scores of a DataFrame are not of shape (n, 2).
    The input DataFrames are not modified.
    """
    datasets = []
    for df, label in zip(dfs, labels):
        parsed = df['score'].apply(parse_score)
        score_matrix = np.stack(parsed.values)
        if not (score_matrix.ndim == 2 and score_matrix.shape[1] == 2):
            raise ValueError("DataFrame score arrays must be 2D with shape (n, 2).")
        datasets.append((score_matrix, label))

    plt.figure(figsize=(7, 5))

    # Marker and colour repeat after four datasets.
    markers = ('o', 's', '^', 'x')
    colors = ('C0', 'C1', 'C2', 'C3')
    for idx, (score_matrix, label) in enumerate(datasets):
        plt.scatter(
            score_matrix[:, 0],
            score_matrix[:, 1],
            alpha=0.7,
            label=label,
            marker=markers[idx % len(markers)],
            color=colors[idx % len(colors)],
        )

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    show_title = set_title and not save_png
    if show_title:
        plt.title(title)
    plt.grid(True)
    plt.legend()

    write_file = save_png and save_file is not None
    if write_file:
        plt.savefig(save_file, format="png", dpi=300)

    plt.show()


def plot_pareto_percent_learning_curves(
        dfs_list,
        method_names,
        pareto_smiles,
        steps,
        xlabel='Number of Search Steps',
        ylabel='Percent of Pareto Front Discovered',
        title='Pareto Optimal Molecules Rediscovery Learning Curve',
        figsize=(8, 6),
        ylim=None,
        set_title=True,
        save_png=False,
        save_file=None,
):
    """
    Plot, per method, the mean percentage of the Pareto front found so far at every step.

    Parameters:
    - dfs_list: List of lists of dataframes (one list per method, multiple replicates each);
      each dataframe needs a "SMILES" column and is read row by row.
    - method_names: List of method names (strings) for each entry in dfs_list.
    - pareto_smiles: Set of Pareto-optimal SMILES strings.
    - steps: Array of step numbers (used for plotting x-axis).
    - xlabel/ylabel/title: Labels for the plot; the title is only drawn when
      set_title is True and save_png is False.
    - figsize: Figure size.
    - ylim: Optional y-axis limits.
    - save_png/save_file: When save_png is True and save_file is given, the figure
      is written to save_file as a 300 dpi PNG.

    The shaded band is mean +/- 1.96 * std / sqrt(number of replicates).
    """
    num_pareto = len(pareto_smiles)

    def percent_curve(frame):
        # Running share (in percent) of Pareto SMILES seen up to each row.
        found = set()
        curve = []
        for _, record in frame.iterrows():
            candidate = record['SMILES']
            if candidate in pareto_smiles:
                found.add(candidate)
            curve.append(len(found) / num_pareto * 100)
        return curve

    learning_curves = [
        [percent_curve(df) for df in df_list]
        for df_list in dfs_list
    ]

    plt.figure(figsize=figsize)
    for replicate_curves, method_name in zip(learning_curves, method_names):
        mean_curve = np.mean(replicate_curves, axis=0)
        half_width = 1.96 * np.std(replicate_curves, axis=0) / np.sqrt(len(replicate_curves))
        plt.plot(steps, mean_curve, label=method_name)
        plt.fill_between(steps, mean_curve - half_width, mean_curve + half_width, alpha=0.2)

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if set_title and not save_png:
        plt.title(title)
    plt.grid()
    plt.legend(loc='best')
    if ylim:
        plt.ylim(ylim)
    plt.tight_layout()
    if save_png and save_file is not None:
        plt.savefig(save_file, format="png", dpi=300)
    plt.show()


def plot_pareto_jaccard_learning_curves(
        dfs_list,
        method_names,
        pareto_smiles,
        steps,
        xlabel='Number of Search Steps',
        ylabel='Jaccard Metric',
        title='Pareto Optimal Molecules Rediscovery Learning Curve',
        figsize=(8, 6),
        ylim=None,
        set_title=True,
        save_png=False,
        save_file=None,
):
    """
    Plot learning curves for multiple methods showing the evolution of the Jaccard similarity between the optimal discovered molecules and the ground truth over steps.

    Parameters:
    - dfs_list: List of lists of dataframes (one list per method, multiple replicates each);
      each dataframe needs "SMILES" and "score_array" columns and is read row by row.
    - method_names: List of method names (strings) for each entry in dfs_list.
    - pareto_smiles: Set of Pareto-optimal SMILES strings.
    - steps: Array of step numbers (used for plotting x-axis).
    - xlabel/ylabel/title: Labels for the plot; the title is only drawn when
      set_title is True and save_png is False.
    - figsize: Figure size.
    - ylim: Optional y-axis limits.
    - save_png/save_file: When save_png is True and save_file is given, the figure
      is written to save_file as a 300 dpi PNG.

    The shaded band is mean +/- 1.96 * std / sqrt(number of replicates).
    """
    learning_curves = []

    for df_list in dfs_list:
        method_learning_curves = []
        for df in df_list:
            pareto_scores = []      # score arrays of the current front
            pareto_positions = []   # row positions (not index labels) of the current front
            method_curve = []
            # enumerate() supplies true row positions, so the .iloc lookup below
            # stays correct even when the dataframe index is not 0..n-1.
            for position, (_, row) in enumerate(df.iterrows()):
                pareto_scores.append(row['score_array'])
                pareto_positions.append(position)
                # Re-filter the previous front plus the new row.
                scores_array = np.vstack(pareto_scores)
                optimal_indices = get_pareto_indices(scores_array)
                pareto_positions = [pareto_positions[k] for k in optimal_indices]
                pareto_scores = [scores_array[k] for k in optimal_indices]
                discovered_smiles = set(df.iloc[pareto_positions]['SMILES'].values)
                method_curve.append(jaccard_metric(discovered_smiles, pareto_smiles))
            # Each replicate is recorded exactly once; recording it twice would
            # double the replicate count and narrow the confidence band.
            method_learning_curves.append(method_curve)
        learning_curves.append(method_learning_curves)

    plt.figure(figsize=figsize)
    for method_curve, method_name in zip(learning_curves, method_names):
        mean_curve = np.mean(method_curve, axis=0)
        std_curve = np.std(method_curve, axis=0)
        plt.plot(steps, mean_curve, label=method_name)
        ci = 1.96 * std_curve / np.sqrt(len(method_curve))
        plt.fill_between(steps, mean_curve - ci, mean_curve + ci, alpha=0.2)

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if set_title and not save_png:
        plt.title(title)
    plt.grid()
    plt.legend(loc='best')
    if ylim:
        plt.ylim(ylim)
    plt.tight_layout()
    if save_png and save_file is not None:
        plt.savefig(save_file, format="png", dpi=300)
    plt.show()


def plot_pareto_jaccard_and_percent_learning_curves(
        dfs_list,
        method_names,
        pareto_smiles,
        steps,
        xlabel='Number of Search Steps',
        left_ylabel='Jaccard Metric',
        right_ylabel='Percent of Pareto Front Discovered',
        title='Pareto Optimal Molecules Rediscovery Learning Curves',
        figsize=(8, 6),
        jaccard_ylim=(0, 1),
        percent_ylim=(0, 100),
        set_title=True,
        save_png=False,
        save_file=None
):
    """
    Plot both Jaccard and % Pareto rediscovery learning curves on one figure
    with two y-axes.

    Parameters:
    - dfs_list: List of lists of dataframes (one list per method, multiple replicates each);
      each dataframe needs "SMILES" and "score_array" columns and is read row by row.
    - method_names: List of method names (strings) for each entry in dfs_list.
    - pareto_smiles: Set of Pareto-optimal SMILES strings.
    - steps: Array of step numbers (used for plotting x-axis).
    - xlabel/left_ylabel/right_ylabel/title: Labels for the plot; the title is only
      drawn when set_title is True and save_png is False.
    - figsize: Figure size.
    - jaccard_ylim/percent_ylim: Limits of the left / right y-axis; None leaves them automatic.
    - save_png/save_file: When save_png is True and save_file is given, the figure
      is written to save_file as a 300 dpi PNG.

    Each shaded band is mean +/- 1.96 * std / sqrt(number of replicates).
    """
    num_pareto = len(pareto_smiles)

    # ---- precompute curves for each method ----
    jaccard_curves_all_methods = []
    percent_curves_all_methods = []

    for df_list in dfs_list:
        method_jaccard_curves = []
        method_percent_curves = []

        for df in df_list:
            # percent learning curve
            discovered = set()
            percent_curve = []

            # jaccard learning curve
            pareto_scores = []      # score arrays of the current front
            pareto_positions = []   # row positions (not index labels) of the current front
            jaccard_curve = []

            # enumerate() supplies true row positions, so the .iloc lookup below
            # stays correct even when the dataframe index is not 0..n-1.
            for position, (_, row) in enumerate(df.iterrows()):
                # percent curve
                smiles = row['SMILES']
                if smiles in pareto_smiles:
                    discovered.add(smiles)
                percent_found = len(discovered) / num_pareto * 100.0
                percent_curve.append(percent_found)

                # jaccard curve: re-filter the previous front plus the new row
                pareto_scores.append(row['score_array'])
                pareto_positions.append(position)

                scores_array = np.vstack(pareto_scores)
                optimal_indices = get_pareto_indices(scores_array)
                pareto_positions = [pareto_positions[k] for k in optimal_indices]
                pareto_scores = [scores_array[k] for k in optimal_indices]

                discovered_smiles = set(df.iloc[pareto_positions]['SMILES'].values)
                jaccard_curve.append(jaccard_metric(discovered_smiles, pareto_smiles))

            method_percent_curves.append(percent_curve)
            method_jaccard_curves.append(jaccard_curve)

        jaccard_curves_all_methods.append(method_jaccard_curves)
        percent_curves_all_methods.append(method_percent_curves)

    # ---- plotting with twin y-axes ----
    fig, ax_left = plt.subplots(figsize=figsize)
    ax_right = ax_left.twinx()

    # style cycles so corresponding methods share color but different linestyle
    color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']

    for idx, method_name in enumerate(method_names):
        color = color_cycle[idx % len(color_cycle)]

        # Jaccard (left axis)
        jaccard_curves = jaccard_curves_all_methods[idx]
        j_mean = np.mean(jaccard_curves, axis=0)
        j_std = np.std(jaccard_curves, axis=0)
        j_ci = 1.96 * j_std / np.sqrt(len(jaccard_curves))
        ax_left.plot(steps, j_mean, color=color, linestyle='-', label=f'{method_name} (Jaccard)')
        ax_left.fill_between(steps, j_mean - j_ci, j_mean + j_ci,
                             color=color, alpha=0.15)

        # Percent (right axis)
        p_curves = percent_curves_all_methods[idx]
        p_mean = np.mean(p_curves, axis=0)
        p_std = np.std(p_curves, axis=0)
        p_ci = 1.96 * p_std / np.sqrt(len(p_curves))
        ax_right.plot(steps, p_mean, color=color, linestyle='--', label=f'{method_name} (%)')
        ax_right.fill_between(steps, p_mean - p_ci, p_mean + p_ci,
                              color=color, alpha=0.10)

    ax_left.set_xlabel(xlabel)
    ax_left.set_ylabel(left_ylabel)
    ax_right.set_ylabel(right_ylabel)

    if jaccard_ylim is not None:
        ax_left.set_ylim(jaccard_ylim)
    if percent_ylim is not None:
        ax_right.set_ylim(percent_ylim)

    ax_left.grid(True)

    # build a joint legend
    lines_left, labels_left = ax_left.get_legend_handles_labels()
    lines_right, labels_right = ax_right.get_legend_handles_labels()
    ax_left.legend(lines_left + lines_right, labels_left + labels_right,
                   loc='best')

    if set_title and not save_png:
        ax_left.set_title(title)
    fig.tight_layout()
    if save_png and save_file is not None:
        plt.savefig(save_file, format='png', dpi=300)
    plt.show()


def plot_pareto_component_metrics(dfs_per_comp, component_names, metrics):
    """
    Plot mean +/- std of each metric for each component on a grid of subplots.

    Parameters:
    - dfs_per_comp: One entry per replicate; entry[comp][metric] must give that
      replicate's series of metric values for component index comp.
    - component_names: Names of the components (one subplot column each).
    - metrics: Metric keys to plot (one subplot row each).

    The mean and standard deviation are taken across replicates. The y-axis of
    every subplot is fixed to [0, 1].
    """
    num_components = len(component_names)
    # squeeze=False keeps `axes` two-dimensional, so the [row][column] lookup
    # below also works with a single metric and/or a single component.
    fig, axes = plt.subplots(len(metrics), num_components,
                             figsize=(6 * num_components, 5 * len(metrics)),
                             squeeze=False)

    for i, metric in enumerate(metrics):
        for comp in range(num_components):
            ax = axes[i][comp]

            metrics_for_comp = [df[comp][metric] for df in dfs_per_comp]
            mean_metric = np.mean(metrics_for_comp, axis=0)
            std_metric = np.std(metrics_for_comp, axis=0)

            ax.plot(mean_metric, label=f'Mean {metric}')
            ax.fill_between(range(len(mean_metric)), mean_metric - std_metric, mean_metric + std_metric, alpha=0.3)

            ax.set_title(f'{metric} for {component_names[comp]}')
            ax.set_xlabel('Step')
            ax.set_ylabel(metric)
            ax.set_ylim([0, 1])
            ax.legend()

    plt.tight_layout()
    plt.show()


def plot_pareto_jaccard_and_uncertainty_learning_curves(
        dfs_list,
        method_names,
        pareto_smiles,
        steps,
        uncertainty_df,
        xlabel='Search Steps',
        left_ylabel='Jaccard Metric',
        right_ylabel='Uncertainty (Bhattacharyya Avg)',
        title='Pareto Optimal Molecules Rediscovery Learning Curves',
        figsize=(8, 6),
        jaccard_ylim=(0, 1),
        uncertainty_ylim=(0, 0.008),
        set_title=True,
        save_png=False,
        save_file=None
):
    """
    Plot the Jaccard learning curves of several methods (left y-axis) together
    with an uncertainty curve (right y-axis) on one figure.

    Parameters:
    - dfs_list: List of lists of dataframes (one list per method, multiple replicates each);
      each dataframe needs "SMILES" and "score_array" columns and is read row by row.
    - method_names: List of method names (strings) for each entry in dfs_list;
      the method named "Random" is drawn with a thicker line.
    - pareto_smiles: Set of Pareto-optimal SMILES strings.
    - steps: Array of step numbers (x-axis of the Jaccard curves).
    - uncertainty_df: DataFrame with "timestep" and "uncertainty" columns; it is
      grouped by timestep and the per-timestep mean is plotted.
    - xlabel/left_ylabel/right_ylabel/title: Labels for the plot; the title is only
      drawn when set_title is True and save_png is False.
    - figsize: Figure size.
    - jaccard_ylim/uncertainty_ylim: Limits of the left / right y-axis; None leaves them automatic.
    - save_png/save_file: When save_png is True and save_file is given, the figure
      is written to save_file as a 500 dpi PNG.

    Each shaded band is mean +/- 1.96 * std / sqrt(count).
    """
    # ---- precompute curves for each method ----
    jaccard_curves_all_methods = []

    for df_list in dfs_list:
        method_jaccard_curves = []

        for df in df_list:
            pareto_scores = []      # score arrays of the current front
            pareto_positions = []   # row positions (not index labels) of the current front
            jaccard_curve = []

            # enumerate() supplies true row positions, so the .iloc lookup below
            # stays correct even when the dataframe index is not 0..n-1.
            for position, (_, row) in enumerate(df.iterrows()):
                pareto_scores.append(row['score_array'])
                pareto_positions.append(position)

                # Re-filter the previous front plus the new row.
                scores_array = np.vstack(pareto_scores)
                optimal_indices = get_pareto_indices(scores_array)
                pareto_positions = [pareto_positions[k] for k in optimal_indices]
                pareto_scores = [scores_array[k] for k in optimal_indices]

                discovered_smiles = set(df.iloc[pareto_positions]['SMILES'].values)
                jaccard_curve.append(jaccard_metric(discovered_smiles, pareto_smiles))

            method_jaccard_curves.append(jaccard_curve)

        jaccard_curves_all_methods.append(method_jaccard_curves)

    # ---- plotting with twin y-axes ----
    fig, ax_left = plt.subplots(figsize=figsize)
    ax_right = ax_left.twinx()

    for idx, method_name in enumerate(method_names):

        # Jaccard (left axis)
        jaccard_curves = jaccard_curves_all_methods[idx]
        j_mean = np.mean(jaccard_curves, axis=0)
        j_std = np.std(jaccard_curves, axis=0)
        j_ci = 1.96 * j_std / np.sqrt(len(jaccard_curves))
        linewidth = 2.5 if method_name == "Random" else 1.5
        ax_left.plot(steps, j_mean, linestyle='-', label=f'{method_name} (Jaccard)', linewidth=linewidth)
        ax_left.fill_between(steps, j_mean - j_ci, j_mean + j_ci, alpha=0.15)

    # Uncertainty (right axis): per-timestep mean with its own band.
    stats = uncertainty_df.groupby("timestep")["uncertainty"].agg(["mean", "std", "count"])
    stats["ci"] = 1.96 * stats["std"] / np.sqrt(stats["count"])
    ax_right.plot(stats.index, stats["mean"], label="TTPFTS (Uncertainty)", linestyle='--', color='tab:green')
    ax_right.fill_between(stats.index, stats["mean"] - stats["ci"], stats["mean"] + stats["ci"], alpha=0.15, color='tab:green')

    ax_left.set_xlabel(xlabel)
    ax_left.set_ylabel(left_ylabel)
    ax_right.set_ylabel(right_ylabel)

    if jaccard_ylim is not None:
        ax_left.set_ylim(jaccard_ylim)
    if uncertainty_ylim is not None:
        ax_right.set_ylim(uncertainty_ylim)

    ax_left.grid(True)

    # build a joint legend
    lines_left, labels_left = ax_left.get_legend_handles_labels()
    lines_right, labels_right = ax_right.get_legend_handles_labels()
    ax_left.legend(lines_left + lines_right, labels_left + labels_right,
                   loc='best')

    if set_title and not save_png:
        ax_left.set_title(title)
    fig.tight_layout()
    if save_png and save_file is not None:
        plt.savefig(save_file, format='png', dpi=500)
    plt.show()
