#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Radar chart for Merge-{1,3,6,9} across 4 quadrants × 5 dims.
Saves a PNG to ./figs/radar_merges.png at 300 dpi.
"""

from __future__ import annotations

from pathlib import Path
from typing import Tuple, Sequence

import numpy as np
import pandas as pd  # noqa: F401  # kept if later needed
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import patheffects as pe
from matplotlib.lines import Line2D

# =========================
# Global style (NeurIPS-like)
# =========================
# Applied once at import time; every figure created afterwards inherits these defaults.
mpl.rcParams.update(
    {
        # --- figure / output ---
        "figure.dpi": 160,
        "figure.facecolor": "white",
        "savefig.bbox": "tight",
        # --- axes and grid ---
        "axes.facecolor": "#fcfcfd",
        "axes.linewidth": 0.9,
        "grid.linewidth": 0.7,
        # --- typography ---
        "font.family": "serif",
        "mathtext.fontset": "cm",
        "font.size": 11.2,
        "axes.titlesize": 14,
        "axes.labelsize": 11.2,
        "legend.fontsize": 10.5,
        "xtick.labelsize": 10.2,
        "ytick.labelsize": 10.2,
    }
)

# Loss tables for Merge-1/3/6/9, each flattened row by row into 20 values.
# Rows follow QUADRANTS (7B, 14B, 32B, 72B) and columns follow DIMS
# (AVG, TA, TIES, DARE, MultiTask, abbreviated "MT" below), i.e. entry
# q * 5 + d belongs to quadrant q and dimension d. The values are converted to
# plotted scores by loss2score, which maps lower losses to higher scores.
merge1_loss = [
    # AVG TA TIES DARE MT
    0.613099042603274, 0.586, 0.586, 0.6126, 0.6130990260444132,  # 7B
    0.5560376859487295, 0.5393, 0.528556, 0.5548, 0.5560376859487294,  # 14B
    0.5173048848081686, 0.5036, 0.4956623, 0.5165, 0.5173055070441479,  # 32B
    0.5275673645302832, 0.5068, 0.4872287, 0.5285, 0.5275673645302832  # 72B
]

merge3_loss = [
    # AVG TA TIES DARE MT
    0.5581826633767137, 0.548, 0.542679, 0.5581, 0.5518,  # 7B
    0.49121946628097446, 0.4886, 0.48749, 0.4915, 0.4890,  # 14B
    0.4704905380850974, 0.4675, 0.4655754, 0.4708, 0.4794,  # 32B
    0.4641109168273429, 0.4589, 0.4539, 0.4644, 0.4622  # 72B
]

merge6_loss = [
    # AVG TA TIES DARE MT
    0.5473081616059823, 0.5399, 0.53912430, 0.5474, 0.5286,  # 7B
    0.4755628969548889, 0.4794, 0.4705158, 0.4759, 0.4506,  # 14B
    0.4576205182769025, 0.4601, 0.4511964, 0.4580, 0.4382,  # 32B
    0.44952614152268744, 0.4501, 0.4406559, 0.4499, 0.4369  # 72B
]

merge9_loss = [
    # AVG TA TIES DARE MT
    0.543777, 0.5381, 0.53513, 0.5440, 0.49378053118842496,  # 7B
    0.47000209482221295, 0.4766, 0.4599, 0.4692, 0.4244308040733097,  # 14B
    0.4529993244926477, 0.4586, 0.4427, 0.4523, 0.40635109124937585,  # 32B
    0.4437873007086783, 0.4469, 0.43272, 0.4437, 0.41101645017534755  # 72B
]

def loss2score(x, *, offset=0.65, scale=0.26) -> np.ndarray:
    """Convert loss values into radar scores via ``(offset - loss) / scale``.

    Lower losses give higher scores. With the defaults, a loss of 0.65 maps to
    a score of 0.0 and a loss of 0.39 maps to a score of 1.0; values outside
    that loss range are not clipped.

    Args:
        x: Array-like of loss values.
        offset: Loss value that maps to a score of 0.
        scale: Loss decrease that corresponds to one unit of score. Must be
            non-zero.

    Returns:
        NumPy array of scores with the same shape as ``x``.
    """
    # Negate first, then shift and rescale: same arithmetic order as the
    # original hard-coded formula, so default results are bit-identical.
    negated = -np.array(x)
    return (negated + offset) / scale

# =========================
# Data generation
# =========================
# Sector names (one per model size) and the dimensions repeated inside each sector.
QUADRANTS: Sequence[str] = ("7B", "14B", "32B", "72B")
DIMS: Sequence[str] = ("AVG", "TA", "TIES", "DARE", "MultiTask")
Q = len(QUADRANTS)
D = len(DIMS)
N = Q * D  # total number of plotted points
# One evenly spaced angle per point; the endpoint 2*pi is excluded so the last
# point does not coincide with the first.
ANGLES = np.linspace(0, 2 * np.pi, N, endpoint=False)

# Fixed seed so the synthetic quantities derived from this generator are reproducible.
RNG = np.random.default_rng(20240921)

# task difficulty per dim (harder → lower base, more gain potential),
# limited to the range [0, 0.35]
TASK_DIFFICULTY = np.clip(np.array([0.10, 0.18, 0.25, 0.22, 0.15]), 0.0, 0.35)


def base_for_quadrant(q_index: int) -> float:
    """Return the synthetic baseline score for quadrant ``q_index``.

    The baseline starts at 0.48, rises by 0.10 per quadrant index, and gets an
    additional 0.02 for indices greater than 1, so larger models start higher.
    """
    extra = 0.02 if q_index > 1 else 0.0
    return 0.48 + 0.10 * q_index + extra


def make_block(q_index: int) -> np.ndarray:
    """Generate one quadrant's synthetic scores, one value per dimension.

    Each score is the quadrant baseline minus the per-dimension task
    difficulty, plus one noise term shared by the whole quadrant and one
    independent noise term per dimension, clipped to [0.18, 0.97]. Every call
    draws from the module-level ``RNG`` and therefore advances its state.
    """
    level = base_for_quadrant(q_index)
    # Draw order matters for reproducibility: shared offset first, then per-dim noise.
    quadrant_offset = RNG.normal(0, 0.022)
    per_dim_noise = RNG.normal(0, 0.030, size=D)
    unclipped = level - TASK_DIFFICULTY + quadrant_offset + per_dim_noise
    return np.clip(unclipped, 0.18, 0.97)


def close_polygon(values: np.ndarray, angles: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Append the first vertex to the end so a polar outline closes on itself.

    Args:
        values: Radial coordinate of each vertex.
        angles: Angular coordinate of each vertex.

    Returns:
        ``(closed_values, closed_angles)``, each one element longer than its
        input and ending with that input's first element.
    """
    closed_values = np.r_[values, values[0]]
    closed_angles = np.r_[angles, angles[0]]
    return closed_values, closed_angles


def scores_for_merge(
    merge_k: int, merge1_scores: np.ndarray, base_gain: np.ndarray
) -> np.ndarray:
    """Derive Merge-k scores by adding a scaled gain to the Merge-1 baseline.

    The gain multiplier is ``(merge_k - 1) / 4``: 0 for Merge-1, 0.5 for
    Merge-3, 1.25 for Merge-6 and 2.0 for Merge-9. The sum is clipped to
    [0, 1], so results are non-decreasing in ``merge_k`` for non-negative gains.

    NOTE(review): multipliers of 0, 2/8, 5/8 and 8/8 have also been quoted for
    this function; those would require a divisor of 8 — confirm the intended scale.

    Args:
        merge_k: Merge level (e.g. 1, 3, 6, 9).
        merge1_scores: Baseline score per point.
        base_gain: Gain per point before scaling by the merge level.

    Returns:
        Array of scores clipped to [0, 1].
    """
    gain_multiplier = (merge_k - 1) / 4.0
    uplift = base_gain * gain_multiplier
    return np.clip(merge1_scores + uplift, 0.0, 1.0)


# ---- per-point scores for every merge level, converted from the loss tables ----
merge1_scores, merge3_scores, merge6_scores, merge9_scores = (
    loss2score(losses)
    for losses in (merge1_loss, merge3_loss, merge6_loss, merge9_loss)
)

# ---- synthetic per-point potential gain (harder tasks → slightly larger gains) ----
# The mean gain ranges from 0.06 for the easiest dimension to about 0.18 for the
# hardest; the 1e-6 keeps the denominator non-zero.
gain_mean = 0.06 + 0.12 * (TASK_DIFFICULTY - TASK_DIFFICULTY.min()) / (
    np.ptp(TASK_DIFFICULTY) + 1e-6
)
# One draw of D gains per quadrant, in quadrant order, limited to [0.04, 0.16].
base_gain = np.clip(
    np.concatenate([RNG.normal(gain_mean, 0.012, size=D) for _ in range(Q)]),
    0.04,
    0.16,
)

# ---- closed outlines (first point repeated at the end) for line/fill drawing ----
(m1_vals, m1_ang), (m3_vals, m3_ang), (m6_vals, m6_ang), (m9_vals, m9_ang) = (
    close_polygon(scores, ANGLES)
    for scores in (merge1_scores, merge3_scores, merge6_scores, merge9_scores)
)


# =========================
# Plotting
# =========================
def _draw_radar(ax, *, quadrants, dims, angles, polygons, scatter_values) -> None:
    """Draw the frame, labels, merge series, guides and legend on polar axes ``ax``."""
    n_quadrants = len(quadrants)
    n_points = n_quadrants * len(dims)
    sector_width = 2 * np.pi / n_quadrants
    rmin, rmax = 0.0, 1.0

    # Angle 0 points up and angles increase clockwise.
    ax.set_theta_offset(np.pi / 2)
    ax.set_theta_direction(-1)

    # Radial ticks
    radial_ticks = [0.2, 0.4, 0.6, 0.8, 1.0]
    ax.set_yticks(radial_ticks)
    ax.set_yticklabels([f"{t:.1f}" for t in radial_ticks], color="gray", fontsize=12)
    ax.yaxis.grid(True, lw=0.7, alpha=0.52, linestyle="--")
    ax.xaxis.grid(False)
    ax.tick_params(axis="y", pad=6)
    ax.set_ylim(rmin, rmax)

    # One frame colour per quadrant, lightest first. Colours are reused
    # cyclically if more quadrants than colours are supplied.
    quad_colors = ("#8f9bad", "#687287", "#444c61", "#23283e")

    # Quadrant frames: two thin radial boundary lines plus a thick outer arc.
    frame_lw = 6
    frame_alpha = 0.9
    # Angular inset of the radial boundaries towards the inside of their own
    # quadrant. At 0 the boundaries of neighbouring quadrants coincide; a small
    # positive value (a few percent of a degree up to ~1 degree) separates them.
    boundary_inset = 0.0
    # The arc sits slightly inside the outermost circle so the two do not overlap.
    arc_r = rmax - 0.005
    for q in range(n_quadrants):
        a0 = q * sector_width
        a1 = a0 + sector_width
        color = quad_colors[q % len(quad_colors)]
        for boundary in (a0 + boundary_inset, a1 - boundary_inset):
            ax.plot(
                [boundary, boundary], [rmin, rmax],
                lw=frame_lw / 5, color=color, alpha=frame_alpha,
                solid_capstyle="round", zorder=0.1, linestyle="-",
            )
        thetas = np.linspace(a0, a1, 180)
        ax.plot(
            thetas, np.full_like(thetas, arc_r),
            lw=frame_lw, color=color, alpha=frame_alpha, zorder=4,
        )

    # Faint ray per dimension and a slightly stronger ray at each sector border.
    for a in angles:
        ax.plot([a, a], [rmin, rmax], lw=0.5, alpha=0.16, color="0.2")
    for q in range(n_quadrants):
        a = q * sector_width
        ax.plot([a, a], [rmin, rmax], lw=1.05, alpha=0.33, color="0.2")

    # Angle ticks (no labels); indexing raises IndexError if `angles` is too short.
    tick_angles = [angles[i] for i in range(n_points)]
    ax.set_xticks(tick_angles)
    ax.set_xticklabels([""] * n_points)

    # Dimension labels, placed just outside the outer circle with a white stroke.
    label_radius = 1.06
    for ang, lab in zip(tick_angles, list(dims) * n_quadrants):
        txt = ax.text(
            ang, label_radius, lab, fontsize=15, ha="center", va="center",
            rotation=0, rotation_mode="anchor", zorder=6,
        )
        txt.set_path_effects([pe.withStroke(linewidth=2.5, foreground="w", alpha=0.85)])

    # Quadrant titles, centred on each sector.
    for q in range(n_quadrants):
        mid = q * sector_width + sector_width / 2
        txt = ax.text(
            mid, 1.28, quadrants[q], ha="center", va="center",
            fontsize=18.5, fontweight="bold",
            color=quad_colors[q % len(quad_colors)],
        )
        txt.set_path_effects([pe.withStroke(linewidth=2.2, foreground="w", alpha=0.9)])

    # Merge series, in ascending merge level; colour emphasis grows with the level.
    labels = ("Merge-1", "Merge-3", "Merge-6", "Merge-9")
    series_colors = ("#70a3b2", "#578ca0", "#3e768f", "#004c6d")
    line_alphas = (0.45, 0.55, 0.55, 0.55)
    fill_alphas = (0.18, 0.11, 0.03, 0.1)
    marker_sizes = (76, 76, 76, 79)
    # Only the Merge-9 markers are lifted above the default z-order.
    scatter_extras = ({}, {}, {}, {"zorder": 3})
    # Stacking is decided by z-order (lines above fills/markers by default) and,
    # within one z-order, by insertion order, so the series must stay ascending.
    for label, color, (vals, angs), scores, line_alpha, fill_alpha, size, extra in zip(
        labels, series_colors, polygons, scatter_values,
        line_alphas, fill_alphas, marker_sizes, scatter_extras,
    ):
        ax.plot(angs, vals, lw=3.1, color=color, alpha=line_alpha)
        ax.fill(angs, vals, alpha=fill_alpha, color=color)
        ax.scatter(angles, scores, s=size, alpha=1, color=color, label=label, **extra)

    # Dashed guide along every dimension angle, from the centre to the outer
    # radius. One guide per entry of `angles`, whatever the grid size.
    for ang in angles:
        ax.plot(
            [ang, ang], [rmin, rmax],
            linestyle=(0, (2.5, 2.5)), linewidth=1.5,
            color=series_colors[0], alpha=0.25, zorder=2,
        )

    # Legend proxies: large round markers, in ascending merge level.
    handles = [
        Line2D(
            [0], [0],
            marker="o",
            linestyle="None",
            markersize=15.0,
            markerfacecolor=color,
            markeredgecolor="white",
            markeredgewidth=0.8,
            label=label,
        )
        for color, label in zip(series_colors, labels)
    ]

    # Thicker outer circle.
    outer_circle = np.linspace(0, 2 * np.pi, 500)
    ax.plot(outer_circle, np.full_like(outer_circle, rmax),
            color="black", lw=2.2, alpha=0.75, zorder=5)

    ax.legend(
        handles=handles,
        title="Models",
        title_fontsize=15,
        fontsize=13,
        loc="upper center",
        bbox_to_anchor=(0.5, -0.08),
        ncol=4,
        frameon=True,
        fancybox=True,
        framealpha=0.95,
        borderpad=0.6,
        handlelength=1.8,
        handletextpad=0.6,
        labelspacing=0.4,
    )


def plot_radar_and_save(
    *,
    quadrants: Sequence[str] = QUADRANTS,
    dims: Sequence[str] = DIMS,
    angles: np.ndarray = ANGLES,
    m1: Tuple[np.ndarray, np.ndarray] = (m1_vals, m1_ang),
    m3: Tuple[np.ndarray, np.ndarray] = (m3_vals, m3_ang),
    m6: Tuple[np.ndarray, np.ndarray] = (m6_vals, m6_ang),
    m9: Tuple[np.ndarray, np.ndarray] = (m9_vals, m9_ang),
    scatter_values: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] = (
        merge1_scores,
        merge3_scores,
        merge6_scores,
        merge9_scores,
    ),
    out_dir: Path | str = "figs",
    out_name: str = "radar_merges.png",
    dpi: int = 300,
) -> Path:
    """Render the Merge-1/3/6/9 radar chart and write it to disk.

    The polar axes are split into one sector per quadrant, each holding one
    point per dimension. Every merge level is drawn as a line, a translucent
    fill and markers, on a fixed radial range of [0, 1].

    Args:
        quadrants: Sector titles, in clockwise order starting at the top.
        dims: Dimension labels repeated inside every sector.
        angles: Angle (radians) of each point, ``len(quadrants) * len(dims)``
            entries in quadrant-major order.
        m1, m3, m6, m9: ``(values, angles)`` of the closed outline for each
            merge level.
        scatter_values: Un-closed scores for Merge-1, -3, -6 and -9, plotted
            as markers at ``angles``.
        out_dir: Output directory; created if it does not exist.
        out_name: Output file name.
        dpi: Resolution of the saved image.

    Returns:
        Path of the saved figure.
    """
    fig = plt.figure(figsize=(7.2, 7.2))
    try:
        ax = fig.add_subplot(111, polar=True)
        _draw_radar(
            ax,
            quadrants=quadrants,
            dims=dims,
            angles=angles,
            polygons=(m1, m3, m6, m9),
            scatter_values=scatter_values,
        )
        fig.tight_layout()

        out_path = Path(out_dir)
        out_path.mkdir(parents=True, exist_ok=True)
        save_path = out_path / out_name
        fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving fails, so it
        # does not linger in pyplot's figure registry.
        plt.close(fig)
    return save_path


def main() -> None:
    """Render the radar chart with default settings and print where it was saved."""
    saved_to = plot_radar_and_save().resolve()
    print(f"Saved figure to: {saved_to}")


# Call main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()