import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
import statsmodels.api as sm
import itertools
import numpy as np
import matplotlib.cm as cm
import matplotlib.gridspec as gridspec


def _default_palette(cmap_name, keys):
    """Map each key to a distinct shade of ``cmap_name``.

    The colormap is resampled to ``len(keys) + 2`` entries and entries
    ``1 .. len(keys)`` are used, so the first and last shades are skipped.
    """
    # plt.get_cmap(name, lut) replaces matplotlib.cm.get_cmap, which was
    # removed in matplotlib 3.9.
    cmap = plt.get_cmap(cmap_name, len(keys) + 2)
    return {key: cmap(i + 1) for i, key in enumerate(keys)}


def _default_math_label(key):
    """Render a column name as an upper-cased bold mathtext label.

    Underscores become escaped mathtext spaces.
    """
    # Build the escaped text outside the f-string: a backslash inside an
    # f-string expression is a SyntaxError before Python 3.12.
    spaced = key.replace("_", r"\ ").upper()
    return rf"$\bf{{{spaced}}}$"


def _scatter_groups_with_lowess(ax, df, x_key, y_key, groups, source_col,
                                repeat_col, colors, repeat_markers, alpha,
                                frac, line_color):
    """Scatter one series per (source, repeat) group, then overlay a pooled LOWESS curve."""
    for src, rep in groups:
        subset = df[(df[source_col] == src) & (df[repeat_col] == rep)]
        ax.scatter(
            subset[x_key], subset[y_key],
            color=colors[src],
            marker=repeat_markers[rep],
            alpha=alpha,
        )
    # The LOWESS fit is computed on all rows, not per group.
    smooth = sm.nonparametric.lowess(df[y_key], df[x_key], frac=frac)
    ax.plot(smooth[:, 0], smooth[:, 1], color=line_color,
            linestyle="--", linewidth=2, label=f"{y_key} (LOWESS)")
    ax.tick_params(axis="y", labelcolor=line_color)


def _resolve_vline_x(df, flag_col="fraction", flag_value=1, x_col="obs_per_agent"):
    """Return the ``x_col`` value on rows where ``df[flag_col] == flag_value``.

    Returns None when either column is missing or no row matches, so the
    caller can skip the divider.

    Raises:
        ValueError: if the matching rows carry more than one distinct value.
    """
    if flag_col not in df.columns or x_col not in df.columns:
        return None
    candidates = df.loc[df[flag_col] == flag_value, x_col].dropna().unique()
    if len(candidates) == 0:
        return None
    if len(candidates) > 1:
        raise ValueError(
            f"Ambiguous divider position: rows with {flag_col} == {flag_value!r} "
            f"map to several {x_col} values {list(candidates)}; pass vline_x explicitly."
        )
    return candidates[0]


def _draw_regime_divider(ax, vline_x,
                         left_text="subsampling\nquestions",
                         right_text="subsampling\nendowments"):
    """Draw a dotted vertical line at ``vline_x`` with a labelled arrow pointing away on each side."""
    ax.axvline(x=vline_x, color="dimgray", linestyle=":", linewidth=1.5)

    # Place the arrows near the top of the current y-range and the text just below them.
    y_lo, y_hi = ax.get_ylim()
    span = y_hi - y_lo
    arrow_y = y_hi - 0.1 * span
    text_y = arrow_y - 0.05 * span

    # direction -1 draws the left arrow and label, +1 the right ones.
    # Offsets are in data units along x.
    for direction, text in ((-1, left_text), (1, right_text)):
        ax.annotate(
            "",
            xy=(vline_x + direction * 0.3, arrow_y),      # arrow tip
            xytext=(vline_x + direction * 0.02, arrow_y),  # arrow tail (near vline)
            textcoords="data",
            arrowprops=dict(arrowstyle="->", color="dimgray"),
        )
        ax.text(
            x=vline_x + direction * 0.2, y=text_y,
            s=text,
            fontsize=9,
            color="dimgray",
            fontstyle="italic",
            ha="center",
            va="top",
        )


def plot_dual_axis_scatter_with_smooth_gridspec(
    df_merged,
    x_key: str,
    y1_key: str,
    y2_key: str,
    source_col: str = "source",
    repeat_col: str = "repeat",
    figsize=(8, 7),
    frac=0.3,
    title: str = None,
    alpha: float = 0.8,
    colors_y1: dict = None,
    colors_y2: dict = None,
    lowess_colors: tuple = None,
    y1_label: str = None,
    y2_label: str = None,
    x_label: str = None,
    save_path: str = None,
    vline_x: float = None,
):
    """Scatter two y-columns against one x-column on twin y-axes, each with a LOWESS trend.

    Points are coloured by ``source_col`` (separate palettes per axis) and
    shaped by ``repeat_col``. Legends go in separate GridSpec cells: the
    LOWESS legend below the plot and the repeat-marker legend to its right.

    Args:
        df_merged: DataFrame holding ``x_key``, ``y1_key``, ``y2_key``,
            ``source_col`` and ``repeat_col``.
        x_key, y1_key, y2_key: Column names for x, left-axis y and right-axis y.
        source_col, repeat_col: Grouping columns (colour / marker).
        figsize: Figure size passed to ``plt.figure``.
        frac: LOWESS smoothing fraction.
        title: Plot title; defaults to "<y1_key> vs <y2_key> across <x_key>".
        alpha: Scatter marker alpha.
        colors_y1, colors_y2: ``{source: color}`` maps; default to Blues / Oranges shades.
        lowess_colors: ``(left_color, right_color)`` for the trend lines,
            axis labels and tick labels; defaults to ("navy", "darkorange").
        y1_label, y2_label, x_label: Axis labels. The y labels default to a
            bold upper-cased form of the column name; the x label defaults to ``x_key``.
        save_path: If given, the figure is saved there (300 dpi, transparent).
        vline_x: x-position of the dotted divider with its "subsampling"
            arrows. If None, it is taken from the ``obs_per_agent`` value of
            rows where ``fraction == 1``. If that cannot be determined, the
            divider is omitted.

    Raises:
        ValueError: if ``vline_x`` is None and the derived divider position is ambiguous.
    """
    sources = sorted(df_merged[source_col].unique())
    repeats = sorted(df_merged[repeat_col].unique())
    marker_cycle = itertools.cycle(['o', 's', 'D', '^', 'v', '<', '>', 'P', 'X', '*'])
    repeat_markers = dict(zip(repeats, marker_cycle))
    # Draw order matches a nested loop over sorted sources, then sorted repeats.
    groups = [(src, rep) for src in sources for rep in repeats]

    # --- Colours and labels ---
    if colors_y1 is None:
        colors_y1 = _default_palette("Blues", sources)
    if colors_y2 is None:
        colors_y2 = _default_palette("Oranges", sources)
    if lowess_colors is None:
        lowess_colors = ("navy", "darkorange")
    if x_label is None:
        x_label = x_key
    if y1_label is None:
        y1_label = _default_math_label(y1_key)
    if y2_label is None:
        y2_label = _default_math_label(y2_key)

    # --- GridSpec: main plot (0,0), repeat legend (0,1), LOWESS legend (1,0) ---
    fig = plt.figure(figsize=figsize, constrained_layout=True)
    gs = gridspec.GridSpec(nrows=2, ncols=2, height_ratios=[10, 1], width_ratios=[6, 1], figure=fig)
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = ax1.twinx()

    # --- Left Y-axis (y1) ---
    ax1.set_xlabel(x_label)
    ax1.set_ylabel(y1_label, color=lowess_colors[0], labelpad=10)
    _scatter_groups_with_lowess(ax1, df_merged, x_key, y1_key, groups, source_col,
                                repeat_col, colors_y1, repeat_markers, alpha,
                                frac, lowess_colors[0])

    # --- Right Y-axis (y2) ---
    ax2.set_ylabel(y2_label, color=lowess_colors[1], labelpad=10)
    ax2.yaxis.set_label_coords(1.08, 0.5)
    _scatter_groups_with_lowess(ax2, df_merged, x_key, y2_key, groups, source_col,
                                repeat_col, colors_y2, repeat_markers, alpha,
                                frac, lowess_colors[1])

    # --- Optional regime divider with arrows ---
    if vline_x is None:
        vline_x = _resolve_vline_x(df_merged)
    if vline_x is not None:
        _draw_regime_divider(ax1, vline_x)

    # --- Legend handles ---
    lowess_handle_y1 = mlines.Line2D([], [], color=lowess_colors[0], linestyle='--', label=f"{y1_label} (LOWESS)")
    lowess_handle_y2 = mlines.Line2D([], [], color=lowess_colors[1], linestyle='--', label=f"{y2_label} (LOWESS)")
    repeat_handles = [
        mlines.Line2D([], [], color="grey", marker=repeat_markers[rep], linestyle='None', label=f"{rep}")
        for rep in repeats
    ]

    # --- Bottom legend for the LOWESS curves (row 1, col 0) ---
    # ax.legend() already registers the legend on its axes; calling
    # add_artist on it as well would draw it twice.
    ax_legend = fig.add_subplot(gs[1, 0])
    ax_legend.axis('off')
    ax_legend.legend(
        handles=[lowess_handle_y1, lowess_handle_y2],
        loc="center",
        ncol=2,
        frameon=False,
        handletextpad=0.5,
        columnspacing=2.0,
        fontsize=10,
    )

    # --- Right-side vertical legend for repeat markers (row 0, col 1) ---
    ax_legend_right = fig.add_subplot(gs[0, 1])
    ax_legend_right.axis('off')
    legend_repeat = ax_legend_right.legend(
        handles=repeat_handles,
        loc="center",
        frameon=False,
        title="Repeat",
        handletextpad=0.6,
        borderaxespad=0.0,
        ncol=1,
    )
    legend_repeat.set_title("Repeat", prop={'weight': 'bold'})

    # --- Title ---
    if title is None:
        title = f"{y1_key} vs {y2_key} across {x_key}"
    ax1.set_title(title, y=1.01)

    # Layout is handled by the figure's constrained layout. Calling
    # plt.tight_layout() here would replace that engine and emit a warning.
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight', transparent=True)
    plt.show()