#!/usr/bin/env python3
"""
Combined plot showing CIR and SR changes for both SR-reward and CIR-reward training.

This combines figure_6 (SR-reward) and figure_7 (CIR-reward) into a single figure
with side-by-side subplots.

Usage:
  python analysis/figure_6_7_combine.py
"""

import json
import os
import re
import numpy as np
import matplotlib.pyplot as plt

# Set font to serif for consistency
plt.rcParams['font.family'] = 'serif'


def get_available_steps(teacher_dir):
    """Return the sorted, de-duplicated training steps found in a teacher directory.

    An entry counts as a step only when its name is exactly ``step_<digits>``.
    Entries such as ``step_10_backup`` are ignored: ``load_metrics_at_step``
    only ever reads ``step_<N>``, so counting them would make the same step
    appear twice and be double-counted when averaging across tasks.

    Args:
        teacher_dir: Path of the directory to scan.

    Returns:
        A sorted list of unique step numbers (ints). Empty when ``teacher_dir``
        is missing or is not a directory.
    """
    # isdir (rather than exists) so that a regular file does not reach listdir.
    if not os.path.isdir(teacher_dir):
        return []
    steps = set()
    for entry in os.listdir(teacher_dir):
        match = re.fullmatch(r'step_(\d+)', entry)
        if match:
            steps.add(int(match.group(1)))
    return sorted(steps)


def load_metrics_at_step(teacher_dir, step):
    """Load the per-step summary metrics from a teacher_responses JSON file.

    Reads ``<teacher_dir>/step_<step>/teacher_responses_step_<step>.json``, which
    is expected to hold a list of per-example dicts, and reduces it to three
    scalars (each the mean over the examples that provide the value):

    * ``cot_importance`` (CIR proxy): per example, the mean of
      ``cot_importance_evaluation.js_divergences`` sampled at 0%, 10%, ..., 100%
      of the sequence. Examples with fewer than two divergences are skipped.
    * ``accuracy``: per example, the mean ``reward_score`` over ``k_responses``.
    * ``verifier_accuracy`` (SR): 1.0 when ``verifier_comparison.answers_match``
      is truthy, else 0.0. A match where both answers are "no answer found"
      counts as 0.0.

    Args:
        teacher_dir: Directory containing the ``step_<N>`` sub-directories.
        step: Training step number to load.

    Returns:
        A dict with keys ``cot_importance``, ``accuracy`` and
        ``verifier_accuracy`` (each a float or ``None`` when unavailable), or
        ``None`` when the file is missing, unreadable, malformed, or yields no
        metric at all.
    """
    json_path = os.path.join(teacher_dir, f"step_{step}", f"teacher_responses_step_{step}.json")
    if not os.path.exists(json_path):
        return None

    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # Best-effort loader: report the problem instead of hiding it, then skip.
        print(f"  Warning: Could not read {json_path}: {exc}")
        return None

    if not isinstance(data, list):
        print(f"  Warning: Unexpected JSON structure in {json_path}")
        return None

    # Sampling grid for the CIR proxy: 0%, 10%, ..., 100%.
    percentages = range(0, 101, 10)

    cot_importance_values = []
    accuracies = []
    verifier_accuracies = []

    for item in data:
        if not isinstance(item, dict):
            continue

        # CIR proxy: use cot_importance_evaluation js_divergences if present
        if "cot_importance_evaluation" in item:
            js_divs = item["cot_importance_evaluation"].get("js_divergences") or []
            n_divs = len(js_divs)
            if n_divs >= 2:
                sampled = []
                for p in percentages:
                    # p% of the way through maps to the last element covered by
                    # that prefix; the clamp keeps 0% and 100% in range.
                    idx = max(0, min(int((p / 100.0) * n_divs) - 1, n_divs - 1))
                    sampled.append(js_divs[idx])
                cot_importance_values.append(float(np.mean(sampled)))

        # Teacher accuracy (reward_score averaged across k responses)
        if "k_responses" in item:
            # Skip responses whose score is missing or null; np.mean cannot average None.
            k_scores = [
                kr["reward_score"]
                for kr in item["k_responses"]
                if kr.get("reward_score") is not None
            ]
            if k_scores:
                accuracies.append(float(np.mean(k_scores)))

        # SR: verifier accuracy (answers_match with "no answer found" special-case)
        if "verifier_comparison" in item:
            verifier_comp = item["verifier_comparison"]
            answers_match = verifier_comp.get("answers_match", False)

            # `or ""` tolerates a JSON null stored as the answer.
            with_q = str(verifier_comp.get("with_question_answer") or "").strip().lower()
            without_q = str(verifier_comp.get("without_question_answer") or "").strip().lower()

            both_unanswered = with_q == "no answer found" and without_q == "no answer found"
            if answers_match and not both_unanswered:
                verifier_accuracies.append(1.0)
            else:
                verifier_accuracies.append(0.0)

    if not cot_importance_values and not verifier_accuracies and not accuracies:
        return None

    return {
        "cot_importance": float(np.mean(cot_importance_values)) if cot_importance_values else None,
        "accuracy": float(np.mean(accuracies)) if accuracies else None,
        "verifier_accuracy": float(np.mean(verifier_accuracies)) if verifier_accuracies else None,
    }


def _sr_mean_and_sem(metrics_list, key):
    """Return ``(mean, standard error)`` of ``key`` across tasks, ignoring ``None`` entries.

    Returns ``(None, None)`` when no task reported the metric.
    """
    values = [m[key] for m in metrics_list if m[key] is not None]
    if not values:
        return None, None
    # np.std defaults to the population std (ddof=0); dividing by sqrt(n)
    # turns it into a standard error of the mean.
    return float(np.mean(values)), float(np.std(values) / np.sqrt(len(values)))


def load_sr_reward_data(base_path, tasks, verifier_scales, steps_to_keep, scale0_base_path, collect_individual=False):
    """Load data for SR-reward training (figure 6), averaged across tasks.

    For every scale and task the teacher directory is located, the steps in
    ``steps_to_keep`` are loaded with ``load_metrics_at_step``, and the
    per-task metrics are averaged per step. Tasks or steps with no data are
    skipped with a printed warning.

    Args:
        base_path: Root containing ``cot_verifier_acc_trained_<scale>/<task>/-1.0/teacher``.
        tasks: Task names to aggregate over.
        verifier_scales: SR-reward scales to load. A scale equal to 0 is read
            from ``scale0_base_path`` instead of ``base_path``.
        steps_to_keep: Training steps to include; all other steps are ignored.
        scale0_base_path: Root containing ``<task>_cot_importance/-1.0/teacher``
            for the scale-0 run.
        collect_individual: When true, also return the un-averaged per-task data.

    Returns:
        A list of ``("α=<scale>", {step: summary})`` tuples, one per scale that
        produced data. Each ``summary`` holds ``cot_importance``, ``accuracy``
        and ``verifier_accuracy`` plus matching ``*_std`` keys. Despite the key
        name, the ``*_std`` values are standard errors (population std divided
        by ``sqrt(n_tasks)``). Unavailable values are ``None``.

        With ``collect_individual=True`` the return value is
        ``(that_list, {task: {scale: {step: metrics}}})``.
    """
    wanted_steps = set(steps_to_keep)
    all_scale_data = []
    task_scale_data = {task: {} for task in tasks} if collect_individual else None

    for scale in verifier_scales:
        # Log with the same symbol (α) that labels this series in the plots.
        print(f"\n=== Processing SR α={scale} ===")
        step_aggregated = {}

        for task in tasks:
            if scale == 0:
                full_path = os.path.join(scale0_base_path, f"{task}_cot_importance/-1.0/teacher")
            else:
                full_path = os.path.join(base_path, f"cot_verifier_acc_trained_{scale}/{task}/-1.0/teacher")

            if not os.path.exists(full_path):
                print(f"  Warning: Path not found for {task}: {full_path}")
                continue

            training_steps = get_available_steps(full_path)
            if not training_steps:
                print(f"  Warning: No training steps found for {task}")
                continue

            task_step_data = {}
            for step in training_steps:
                if step not in wanted_steps:
                    continue
                metrics = load_metrics_at_step(full_path, step)
                if not metrics:
                    continue
                step_aggregated.setdefault(step, []).append(metrics)
                if collect_individual:
                    task_step_data[step] = metrics

            if collect_individual and task_step_data:
                task_scale_data[task][scale] = task_step_data

        averaged_step_data = {}
        for step, metrics_list in step_aggregated.items():
            summary = {}
            for key in ("cot_importance", "accuracy", "verifier_accuracy"):
                summary[key], summary[f"{key}_std"] = _sr_mean_and_sem(metrics_list, key)
            averaged_step_data[step] = summary

        if averaged_step_data:
            all_scale_data.append((f"α={scale}", averaged_step_data))
            print(f"  Successfully averaged {len(averaged_step_data)} steps across tasks")
        else:
            print(f"  Warning: No valid data for α={scale}")

    if collect_individual:
        return all_scale_data, task_scale_data
    return all_scale_data


def _cir_mean_and_sem(metrics_list, key):
    """Return ``(mean, standard error)`` of ``key`` across tasks, ignoring ``None`` entries.

    Returns ``(None, None)`` when no task reported the metric.
    """
    values = [m[key] for m in metrics_list if m[key] is not None]
    if not values:
        return None, None
    # np.std defaults to the population std (ddof=0); dividing by sqrt(n)
    # turns it into a standard error of the mean.
    return float(np.mean(values)), float(np.std(values) / np.sqrt(len(values)))


def load_cir_reward_data(base_path, tasks, coefficients, steps_to_keep, coeff0_base_path, collect_individual=False):
    """Load data for CIR-reward training (figure 7), averaged across tasks.

    For every coefficient and task the teacher directory is located, the steps
    in ``steps_to_keep`` are loaded with ``load_metrics_at_step``, and the
    per-task metrics are averaged per step. Tasks or steps with no data are
    skipped with a printed warning.

    Args:
        base_path: Root containing ``cot_importance_trained_<coeff>/<task>/-1.0/teacher``.
        tasks: Task names to aggregate over.
        coefficients: CIR-reward coefficients to load. A coefficient equal to
            0.0 is read from ``coeff0_base_path`` instead of ``base_path``.
        steps_to_keep: Training steps to include; all other steps are ignored.
        coeff0_base_path: Root containing ``<task>_cot_importance/-1.0/teacher``
            for the coefficient-0 run.
        collect_individual: When true, also return the un-averaged per-task data.

    Returns:
        A list of ``("β=<coeff>", {step: summary})`` tuples, one per coefficient
        that produced data. Each ``summary`` holds ``cot_importance``,
        ``accuracy`` and ``verifier_accuracy`` plus matching ``*_std`` keys.
        Despite the key name, the ``*_std`` values are standard errors
        (population std divided by ``sqrt(n_tasks)``). Unavailable values are
        ``None``.

        With ``collect_individual=True`` the return value is
        ``(that_list, {task: {coeff: {step: metrics}}})``.
    """
    wanted_steps = set(steps_to_keep)
    all_coeff_data = []
    task_coeff_data = {task: {} for task in tasks} if collect_individual else None

    for coeff in coefficients:
        # Log with the same symbol (β) that labels this series in the plots.
        print(f"\n=== Processing CIR β={coeff} ===")
        step_aggregated = {}

        for task in tasks:
            if coeff == 0.0:
                full_path = os.path.join(coeff0_base_path, f"{task}_cot_importance/-1.0/teacher")
            else:
                full_path = os.path.join(base_path, f"cot_importance_trained_{coeff}/{task}/-1.0/teacher")

            if not os.path.exists(full_path):
                print(f"  Warning: Path not found for {task}: {full_path}")
                continue

            training_steps = get_available_steps(full_path)
            if not training_steps:
                print(f"  Warning: No training steps found for {task}")
                continue

            task_step_data = {}
            for step in training_steps:
                if step not in wanted_steps:
                    continue
                metrics = load_metrics_at_step(full_path, step)
                if not metrics:
                    continue
                step_aggregated.setdefault(step, []).append(metrics)
                if collect_individual:
                    task_step_data[step] = metrics

            if collect_individual and task_step_data:
                task_coeff_data[task][coeff] = task_step_data

        averaged_step_data = {}
        for step, metrics_list in step_aggregated.items():
            summary = {}
            for key in ("cot_importance", "accuracy", "verifier_accuracy"):
                summary[key], summary[f"{key}_std"] = _cir_mean_and_sem(metrics_list, key)
            averaged_step_data[step] = summary

        if averaged_step_data:
            all_coeff_data.append((f"β={coeff}", averaged_step_data))
            print(f"  Successfully averaged {len(averaged_step_data)} steps across tasks")
        else:
            print(f"  Warning: No valid data for β={coeff}")

    if collect_individual:
        return all_coeff_data, task_coeff_data
    return all_coeff_data


def _draw_combined_panel(ax, series, metric_key, ylabel, title, colors, legend_fontsize=None):
    """Draw one panel of the combined figure: one line per ``(label, step_data)`` series.

    Steps whose ``metric_key`` value is ``None`` are left out of the line.
    """
    for idx, (label, step_data) in enumerate(series):
        steps = [s for s in sorted(step_data) if step_data[s][metric_key] is not None]
        if not steps:
            continue
        values = [step_data[s][metric_key] for s in steps]
        ax.plot(steps, values, linewidth=4.0, color=colors[idx % len(colors)], marker="o",
                markersize=10, label=label, alpha=0.85, markeredgewidth=1.5,
                markeredgecolor='white')

    ax.set_xlabel("RL Training Step", fontsize=35, fontweight="bold", family="serif")
    ax.set_ylabel(ylabel, fontsize=35, fontweight="bold", family="serif")
    ax.set_title(title, fontsize=35, fontweight="bold", pad=15, family="serif")
    ax.tick_params(axis="both", which="major", labelsize=18, width=1.5)
    ax.tick_params(axis="both", which="minor", labelsize=16, width=1)
    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
        tick_label.set_family("serif")
    ax.set_ylim(-0.02, 1.02)
    ax.grid(True, alpha=0.25, linestyle='-', linewidth=0.8, which='major')
    ax.grid(True, alpha=0.15, linestyle=':', linewidth=0.5, which='minor')
    ax.minorticks_on()

    # matplotlib ignores legend(fontsize=...) whenever `prop` is passed, so the
    # size has to travel inside the prop dict to have any effect.
    legend_font = {"family": "serif"}
    if legend_fontsize is not None:
        legend_font["size"] = legend_fontsize
    ax.legend(loc="upper left", framealpha=0.95, prop=legend_font,
              edgecolor='gray', fancybox=True, shadow=True, markerscale=1)
    ax.set_axisbelow(True)
    for spine in ax.spines.values():
        spine.set_linewidth(1.5)


def plot_combined(sr_data, cir_data, sr_tasks, cir_tasks, output_dir, legend_fontsize=None):
    """Plot combined figure with SR-reward (top row) and CIR-reward (bottom row) as subplots.

    The figure is a 2x3 grid. Columns show CIR, SR and accuracy against the RL
    training step; each line is one ``(label, step_data)`` entry of the row's
    data.

    Args:
        sr_data: List of ``(label, {step: metrics})`` for SR-reward training.
        cir_data: List of ``(label, {step: metrics})`` for CIR-reward training.
        sr_tasks: Accepted for call compatibility; not used by the plot.
        cir_tasks: Accepted for call compatibility; not used by the plot.
        output_dir: Directory the PNG and PDF are written to (created if missing).
        legend_fontsize: Optional legend font size. ``None`` keeps matplotlib's
            default size.

    Returns:
        ``(png_path, pdf_path)`` of the saved figure.
    """
    fig = plt.figure(figsize=(30, 14))

    # Create 2x3 grid: Row 1 = SR-reward (CIR, SR, Acc), Row 2 = CIR-reward (CIR, SR, Acc)
    gs = fig.add_gridspec(2, 3, hspace=0.35, wspace=0.35)

    # Dark gray, blue, orange, green, red, purple; shared by both rows.
    colors = ["#2E2E2E", "#1E88E5", "#FFA726", "#66BB6A", "#EF5350", "#AB47BC"]

    metric_columns = [
        ("cot_importance", "CIR"),
        ("verifier_accuracy", "SR"),
        ("accuracy", "Accuracy"),
    ]
    reward_rows = [("SR-reward", sr_data), ("CIR-reward", cir_data)]

    for row, (reward_name, series) in enumerate(reward_rows):
        for col, (metric_key, metric_label) in enumerate(metric_columns):
            ax = fig.add_subplot(gs[row, col])
            _draw_combined_panel(ax, series, metric_key, metric_label,
                                 f"{reward_name} → {metric_label}", colors,
                                 legend_fontsize=legend_fontsize)

    fig.tight_layout(pad=2.0)

    os.makedirs(output_dir, exist_ok=True)

    png_path = os.path.join(output_dir, "figure6_7_combine.png")
    fig.savefig(png_path, dpi=300, bbox_inches="tight")

    pdf_path = os.path.join(output_dir, "figure6_7_combine.pdf")
    fig.savefig(pdf_path, bbox_inches="tight")

    # Close this figure explicitly rather than whichever one is current.
    plt.close(fig)
    return png_path, pdf_path


def _draw_sr_task_panel(ax, per_scale_data, scales, metric_key, ylabel, colors,
                        title=None, show_legend=False, show_xlabel=False, legend_fontsize=None):
    """Draw one metric of one task for SR-reward training, one line per scale.

    ``per_scale_data`` maps scale -> {step: metrics}. Scales without data and
    steps whose ``metric_key`` value is ``None`` are skipped.
    """
    for scale_idx, scale in enumerate(scales):
        step_data = per_scale_data.get(scale)
        if not step_data:
            continue
        steps = [s for s in sorted(step_data) if step_data[s][metric_key] is not None]
        if not steps:
            continue
        values = [step_data[s][metric_key] for s in steps]
        ax.plot(steps, values, linewidth=3.5, color=colors[scale_idx % len(colors)], marker="o",
                markersize=8, label=f'α={scale}', alpha=0.85, markeredgewidth=1.2,
                markeredgecolor='white')

    ax.set_ylabel(ylabel, fontsize=20, fontweight='bold', family='serif')
    if title is not None:
        ax.set_title(title, fontsize=35, fontweight='bold', family='serif', pad=15)
    ax.tick_params(axis='both', which='major', labelsize=16, width=1.2)
    ax.tick_params(axis='both', which='minor', labelsize=14, width=0.8)
    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
        tick_label.set_family('serif')
    ax.set_ylim(-0.02, 1.02)
    ax.grid(True, alpha=0.25, linestyle='-', linewidth=0.8, which='major')
    ax.grid(True, alpha=0.15, linestyle=':', linewidth=0.5, which='minor')
    ax.minorticks_on()
    if show_legend:
        # matplotlib ignores legend(fontsize=...) whenever `prop` is passed, so
        # the size has to travel inside the prop dict to have any effect.
        legend_font = {'family': 'serif'}
        if legend_fontsize is not None:
            legend_font['size'] = legend_fontsize
        ax.legend(loc='best', framealpha=0.95, prop=legend_font,
                  edgecolor='gray', fancybox=True, shadow=True, markerscale=1)
    ax.set_axisbelow(True)
    for spine in ax.spines.values():
        spine.set_linewidth(1.2)
    if show_xlabel:
        ax.set_xlabel('RL Training Step', fontsize=20, fontweight='bold', family='serif')


def plot_individual_tasks_sr(task_scale_data, output_dir, tasks, scales, legend_fontsize=None):
    """Plot individual task metrics for SR-reward training in a grid layout.

    One row per task; columns are CIR, SR and accuracy. The task name titles
    the first column, the legend is drawn only in the top-left panel, and the
    x-axis label only on the bottom row.

    Args:
        task_scale_data: ``{task: {scale: {step: metrics}}}``.
        output_dir: Directory the PNG and PDF are written to (created if missing).
        tasks: Task names, in row order.
        scales: SR-reward scales, in colour order.
        legend_fontsize: Optional legend font size. ``None`` keeps matplotlib's
            default size.

    Returns:
        ``(png_path, pdf_path)`` of the saved figure.

    Raises:
        ValueError: If ``tasks`` is empty.
    """
    num_tasks = len(tasks)
    if num_tasks == 0:
        raise ValueError("tasks must contain at least one task")
    fig_height = max(12, num_tasks * 4)

    # squeeze=False keeps `axes` two-dimensional even for a single task.
    fig, axes = plt.subplots(num_tasks, 3, figsize=(24, fig_height), squeeze=False)

    # More visually distinct color palette (same as main plot)
    colors = ["#2E2E2E", "#1E88E5", "#FFA726", "#66BB6A", "#EF5350", "#AB47BC"]

    metric_columns = [
        ("cot_importance", "CIR"),
        ("verifier_accuracy", "SR"),
        ("accuracy", "Accuracy"),
    ]

    for task_idx, task in enumerate(tasks):
        per_scale_data = task_scale_data.get(task, {})
        is_last_row = task_idx == num_tasks - 1
        for col, (metric_key, ylabel) in enumerate(metric_columns):
            first_column = col == 0
            _draw_sr_task_panel(
                axes[task_idx, col], per_scale_data, scales, metric_key, ylabel, colors,
                title=f'{task.replace("_", " ").title()}' if first_column else None,
                show_legend=first_column and task_idx == 0,
                show_xlabel=is_last_row,
                legend_fontsize=legend_fontsize,
            )

    fig.tight_layout()

    os.makedirs(output_dir, exist_ok=True)

    png_path = os.path.join(output_dir, 'figure6_sr_reward_individual.png')
    fig.savefig(png_path, dpi=300, bbox_inches='tight')

    pdf_path = os.path.join(output_dir, 'figure6_sr_reward_individual.pdf')
    fig.savefig(pdf_path, bbox_inches='tight')

    # Close this figure explicitly rather than whichever one is current.
    plt.close(fig)

    return png_path, pdf_path


def _draw_cir_task_panel(ax, per_coeff_data, coefficients, metric_key, ylabel, colors,
                         title=None, show_legend=False, show_xlabel=False, legend_fontsize=None):
    """Draw one metric of one task for CIR-reward training, one line per coefficient.

    ``per_coeff_data`` maps coefficient -> {step: metrics}. Coefficients without
    data and steps whose ``metric_key`` value is ``None`` are skipped.
    """
    for coeff_idx, coeff in enumerate(coefficients):
        step_data = per_coeff_data.get(coeff)
        if not step_data:
            continue
        steps = [s for s in sorted(step_data) if step_data[s][metric_key] is not None]
        if not steps:
            continue
        values = [step_data[s][metric_key] for s in steps]
        ax.plot(steps, values, linewidth=3.5, color=colors[coeff_idx % len(colors)], marker="o",
                markersize=8, label=f'β={coeff}', alpha=0.85, markeredgewidth=1.2,
                markeredgecolor='white')

    ax.set_ylabel(ylabel, fontsize=20, fontweight='bold', family='serif')
    if title is not None:
        ax.set_title(title, fontsize=35, fontweight='bold', family='serif', pad=15)
    ax.tick_params(axis='both', which='major', labelsize=16, width=1.2)
    ax.tick_params(axis='both', which='minor', labelsize=14, width=0.8)
    for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
        tick_label.set_family('serif')
    ax.set_ylim(-0.02, 1.02)
    ax.grid(True, alpha=0.25, linestyle='-', linewidth=0.8, which='major')
    ax.grid(True, alpha=0.15, linestyle=':', linewidth=0.5, which='minor')
    ax.minorticks_on()
    if show_legend:
        # matplotlib ignores legend(fontsize=...) whenever `prop` is passed, so
        # the size has to travel inside the prop dict to have any effect.
        legend_font = {'family': 'serif'}
        if legend_fontsize is not None:
            legend_font['size'] = legend_fontsize
        ax.legend(loc='best', framealpha=0.95, prop=legend_font,
                  edgecolor='gray', fancybox=True, shadow=True, markerscale=1)
    ax.set_axisbelow(True)
    for spine in ax.spines.values():
        spine.set_linewidth(1.2)
    if show_xlabel:
        ax.set_xlabel('RL Training Step', fontsize=20, fontweight='bold', family='serif')


def plot_individual_tasks_cir(task_coeff_data, output_dir, tasks, coefficients, legend_fontsize=None):
    """Plot individual task metrics for CIR-reward training in a grid layout.

    One row per task; columns are CIR, SR and accuracy. The task name titles
    the first column, the legend is drawn only in the top-left panel, and the
    x-axis label only on the bottom row.

    Args:
        task_coeff_data: ``{task: {coeff: {step: metrics}}}``.
        output_dir: Directory the PNG and PDF are written to (created if missing).
        tasks: Task names, in row order.
        coefficients: CIR-reward coefficients, in colour order.
        legend_fontsize: Optional legend font size. ``None`` keeps matplotlib's
            default size.

    Returns:
        ``(png_path, pdf_path)`` of the saved figure.

    Raises:
        ValueError: If ``tasks`` is empty.
    """
    num_tasks = len(tasks)
    if num_tasks == 0:
        raise ValueError("tasks must contain at least one task")
    fig_height = max(12, num_tasks * 4)

    # squeeze=False keeps `axes` two-dimensional even for a single task.
    fig, axes = plt.subplots(num_tasks, 3, figsize=(24, fig_height), squeeze=False)

    # More visually distinct color palette (same as main plot)
    colors = ["#2E2E2E", "#1E88E5", "#FFA726", "#66BB6A", "#EF5350", "#AB47BC"]

    metric_columns = [
        ("cot_importance", "CIR"),
        ("verifier_accuracy", "SR"),
        ("accuracy", "Accuracy"),
    ]

    for task_idx, task in enumerate(tasks):
        per_coeff_data = task_coeff_data.get(task, {})
        is_last_row = task_idx == num_tasks - 1
        for col, (metric_key, ylabel) in enumerate(metric_columns):
            first_column = col == 0
            _draw_cir_task_panel(
                axes[task_idx, col], per_coeff_data, coefficients, metric_key, ylabel, colors,
                title=f'{task.replace("_", " ").title()}' if first_column else None,
                show_legend=first_column and task_idx == 0,
                show_xlabel=is_last_row,
                legend_fontsize=legend_fontsize,
            )

    fig.tight_layout()

    os.makedirs(output_dir, exist_ok=True)

    png_path = os.path.join(output_dir, 'figure7_cir_reward_individual.png')
    fig.savefig(png_path, dpi=300, bbox_inches='tight')

    pdf_path = os.path.join(output_dir, 'figure7_cir_reward_individual.pdf')
    fig.savefig(pdf_path, bbox_inches='tight')

    # Close this figure explicitly rather than whichever one is current.
    plt.close(fig)

    return png_path, pdf_path


def main():
    """Load SR-reward and CIR-reward results and render every figure.

    Both datasets are loaded with ``collect_individual=True`` so that the
    per-task breakdowns are available alongside the aggregated data.  If
    either aggregated dataset comes back empty, an error is printed and
    nothing is plotted.  Otherwise three figures are produced, in order:
    the combined plot, the SR-reward per-task plot and the CIR-reward
    per-task plot.  The PNG and PDF path of each figure is printed.
    """
    # Figure 6 (SR-reward) configuration.
    sr_tasks = [
        "binary_alternation",
        "binary_matrix",
        "bitwise_arithmetic",
        "count_bits",
        "manipulate_matrix",
        "futoshiki",
        "mini_sudoku",
        "rotate_matrix",
        "string_manipulation",
        "tsumego",
    ]
    verifier_scales = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

    # Figure 7 (CIR-reward) configuration: same tasks and grid values, but
    # kept as separate list objects from the SR ones.
    cir_tasks = list(sr_tasks)
    coefficients = list(verifier_scales)

    steps_to_keep = [2, 30, 60, 90, 120, 150, 156]

    # The scale-0 and coefficient-0 runs are both read from the
    # cot_importance folder under the shared results root.
    results_root = "/nlp/scr/qinanyu/rl-explanations/evaluate/results/grpo_qwen-3b-instruct"
    cot_importance_root = f"{results_root}/cot_importance"

    sr_data, task_scale_data = load_sr_reward_data(
        results_root, sr_tasks, verifier_scales, steps_to_keep,
        cot_importance_root, collect_individual=True,
    )
    cir_data, task_coeff_data = load_cir_reward_data(
        results_root, cir_tasks, coefficients, steps_to_keep,
        cot_importance_root, collect_individual=True,
    )

    # Both aggregated datasets are required before anything is drawn.
    if not (sr_data and cir_data):
        print("\nError: Missing data for SR or CIR training")
        return

    output_dir = "/nlp/scr/qinanyu/rl-explanations/analysis/graph"
    os.makedirs(output_dir, exist_ok=True)
    print(f"\nOutput directory: {output_dir}")

    # Combined figure (Figure 6 + Figure 7 side by side).
    print("\n=== Generating combined plot ===")
    combined_png, combined_pdf = plot_combined(sr_data, cir_data, sr_tasks, cir_tasks, output_dir)
    for saved_path in (combined_png, combined_pdf):
        print(f"Saved: {saved_path}")

    # Per-task figure for SR-reward training.
    print("\n=== Generating individual task plots for SR-reward ===")
    sr_png, sr_pdf = plot_individual_tasks_sr(task_scale_data, output_dir, sr_tasks, verifier_scales)
    for saved_path in (sr_png, sr_pdf):
        print(f"Saved: {saved_path}")

    # Per-task figure for CIR-reward training.
    print("\n=== Generating individual task plots for CIR-reward ===")
    cir_png, cir_pdf = plot_individual_tasks_cir(task_coeff_data, output_dir, cir_tasks, coefficients)
    for saved_path in (cir_png, cir_pdf):
        print(f"Saved: {saved_path}")


# Run the full figure-generation pipeline only when this file is executed as
# a script; importing it as a module does not call main().
if __name__ == "__main__":
    main()
