from __future__ import annotations

import argparse
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def _sans_font_stack(preferred: str = "") -> list[str]:
    """Return the sans-serif fallback list, with *preferred* (if given) first.

    Any built-in entry that matches *preferred* case-insensitively is dropped
    so the family does not appear twice in the stack.
    """
    stack = [
        "Inter",
        "SF Pro Display",
        "Segoe UI",
        "Helvetica Neue",
        "Helvetica",
        "Arial",
        "DejaVu Sans",
    ]
    if not preferred:
        return stack
    wanted = preferred.lower()
    return [preferred] + [name for name in stack if name.lower() != wanted]


def _gap_above_independent(k: np.ndarray, full_gap: float, min_gap: float) -> np.ndarray:
    """Return how far the "ours" curve sits above the Independent baseline.

    The gap starts at *full_gap* for k=0 and halves every 2 units of k, but is
    clamped so it never drops below *min_gap*.
    """
    decay = np.exp(-(np.log(2.0) / 2.0) * k)
    return np.maximum(full_gap * decay, min_gap)


def main() -> None:
    """Draw the toy AR / Independent / Ours comparison plot for four datasets.

    Command-line options:
        --out         Output filepath; when empty the figure is shown instead.
        --title       Optional figure suptitle.
        --title-size  Suptitle font size.
        --title-y     Suptitle vertical position in figure coordinates.
        --font        Preferred sans-serif family, placed first in the stack.

    For every dataset three lines are drawn in the dataset's color: a dashed
    horizontal AR line, a dotted horizontal Independent line below it, and a
    solid "ours" line that decays from AR towards Independent as K grows.
    """
    from matplotlib.lines import Line2D

    ap = argparse.ArgumentParser(description="Toy plot with two horizontal dashed lines (AR above, Independent below)")
    ap.add_argument("--out", type=str, default="", help="Optional output filepath (PNG/PDF)")
    ap.add_argument("--title", type=str, default="", help="Figure suptitle (big title)")
    ap.add_argument("--title-size", type=float, default=28.0, help="Suptitle font size")
    ap.add_argument("--title-y", type=float, default=0.97, help="Suptitle vertical position [0-1]")
    ap.add_argument("--font", type=str, default="", help="Preferred sans-serif font family name")
    args = ap.parse_args()

    # Consistent style (white background), no grids, larger fonts
    sns.set_theme(
        style="white",
        font="sans-serif",
        font_scale=2.0,
        rc={
            "axes.facecolor": "#ffffff",
            "figure.facecolor": "#ffffff",
            "grid.linestyle": "",
            "axes.grid": False,
            # Prefer nicer sans-serif fonts with graceful fallback
            "font.family": "sans-serif",
            "font.sans-serif": _sans_font_stack(args.font),
            "axes.titleweight": "bold",
        },
    )

    # K from 0 to 32 (integer-valued, stored as floats)
    k_values = np.arange(0, 33, dtype=float)

    # Baseline levels before the per-dataset offset: AR above, Independent below
    ar_level = 1.0
    ind_level = 0.7
    y_ar = np.full_like(k_values, ar_level)
    y_ind = np.full_like(k_values, ind_level)

    # The per-dataset offset shifts AR and Independent equally, so it cancels
    # in (AR - Independent); the gap of "ours" above Independent is therefore
    # the same for every dataset and is computed once here.
    gap = _gap_above_independent(k_values, ar_level - ind_level, min_gap=0.08)

    # Wider figure for readability
    fig, ax = plt.subplots(figsize=(14, 7))

    # Dataset color mapping
    dataset_colors = {
        'cars': 'tab:green',     # explicitly green
        'airplanes': 'tab:blue',
        'electric': 'tab:orange',
        'bikes': 'tab:red',
    }

    # Dataset-specific vertical offsets so each dataset's lines are at different values
    offsets = {
        'cars': 0.00,
        'airplanes': 0.03,
        'electric': -0.02,
        'bikes': 0.06,
    }

    # Repeat the three lines for each dataset with its color and offset
    for name, color in dataset_colors.items():
        off = offsets.get(name, 0.0)
        y_ar_d = y_ar + off
        y_ind_d = y_ind + off
        y_ours_d = y_ind_d + gap

        # AR (above): dashed '--'
        ax.plot(
            k_values, y_ar_d,
            linestyle='--', linewidth=8.0, alpha=0.8,
            color=color,
        )
        # Independent (below): dotted ':'
        ax.plot(
            k_values, y_ind_d,
            linestyle=':', linewidth=8.0, alpha=0.8,
            color=color,
        )
        # Ours: solid line with a marker at every K
        ax.plot(
            k_values, y_ours_d,
            linestyle='-', marker='o', markersize=12, linewidth=8.0, alpha=0.8,
            markeredgecolor='none', markeredgewidth=0.0,
            color=color,
        )

    # K axis: 0..32
    ax.set_xlim(0, 32)
    ax.set_xticks(list(range(0, 33, 4)))
    ax.set_xlabel('$K$')
    ax.set_ylabel('Log-likelihood')

    # Clean spines
    for spine in ['left', 'bottom']:
        ax.spines[spine].set_linewidth(2)
        ax.spines[spine].set_color('black')
    for spine in ['right', 'top']:
        ax.spines[spine].set_visible(False)

    # Build two legends: one for line styles, one for datasets (colors)
    style_handles = [
        Line2D([0], [0], color='black', linestyle='--', linewidth=6, label='AR'),
        Line2D([0], [0], color='black', linestyle=':', linewidth=6, label='Independent'),
        Line2D([0], [0], color='black', linestyle='-', marker='o', markersize=8, linewidth=6, label='Ours'),
    ]
    dataset_handles = [
        Line2D([0], [0], color=c, linestyle='-', linewidth=6, label=n.title())
        for n, c in dataset_colors.items()
    ]

    # Leave room on the right for the legends and, when a suptitle is drawn,
    # room at the top for it. The margins go through tight_layout's ``rect``
    # because tight_layout recomputes the subplot parameters and would discard
    # a margin set earlier with subplots_adjust.
    if args.title:
        fig.suptitle(args.title, fontsize=args.title_size, fontweight='bold', y=args.title_y)
        layout_top = 0.92
    else:
        layout_top = 1.0
    fig.tight_layout(rect=[0, 0, 0.72, layout_top])

    # Place line-type legend to the right of the plot. It is re-added as a
    # plain artist because the second ax.legend() call below replaces the
    # Axes' current legend.
    leg1 = ax.legend(handles=style_handles, loc='center left', bbox_to_anchor=(1.02, 0.65), frameon=False, title='Line type')
    ax.add_artist(leg1)
    # Dataset legend below it, also to the right
    ax.legend(handles=dataset_handles, loc='center left', bbox_to_anchor=(1.02, 0.25), frameon=False, title='Dataset')

    if args.out:
        fig.savefig(args.out, bbox_inches='tight')
        print(f"Saved: {args.out}")
    else:
        plt.show()


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
