import os
import json
import numpy as np
import matplotlib.pyplot as plt

# Apply the base style FIRST: the seaborn-v0_8 style sheets define their own
# font.sans-serif list, so applying the style after the font settings would
# silently replace SimHei with the style's fonts.
plt.style.use('seaborn-v0_8-whitegrid')
# SimHei provides CJK glyphs; DejaVu Sans (bundled with matplotlib) is the
# fallback when SimHei is not installed on the machine.
plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
# Render minus signs with the ASCII hyphen; the Unicode minus glyph is often
# missing from CJK fonts such as SimHei.
plt.rcParams['axes.unicode_minus'] = False

# Per-method result directories, relative to the current working directory.
# Each is expected to contain order<i>/ subdirectories (see collect_orders).
# Paths are built with os.path.join rather than hard-coded "\" separators so
# they resolve on POSIX systems as well as on Windows.
_RESULTS_ROOT = os.path.join("experiment_results", "ViT-B-32")
METHODS = {
    'SAIM-NOSABCD': os.path.join(_RESULTS_ROOT, "SAIM", "NOSABCD", "IND"),
    'SAIM-SABCD': os.path.join(_RESULTS_ROOT, "SAIM", "SABCD", "IND"),
    'TA-NOSABCD': os.path.join(_RESULTS_ROOT, "TA", "NOSABCD", "IND"),
    'TA-SABCD': os.path.join(_RESULTS_ROOT, "TA", "SABCD", "IND"),
}
# Line/scatter colour per method: one colour per base method (SAIM / TA);
# the with/without SA-BCD variants are distinguished by line style and marker.
COLORS = {
    'SAIM-NOSABCD': '#FF6B6B',
    'TA-NOSABCD': '#2C73D2',
    'TA-SABCD': '#2C73D2',
    'SAIM-SABCD': '#FF6B6B'
}
# Legend text per method.
LABELS = {
    'SAIM-NOSABCD': 'SAIM w/o SA-BCD',
    'TA-NOSABCD': 'TA w/o SA-BCD',
    'TA-SABCD': 'TA w/ SA-BCD',
    'SAIM-SABCD': 'SAIM w/ SA-BCD'
}
# Matplotlib marker symbol per method.
MARKERS = {
    'SAIM-NOSABCD': 'o',
    'TA-NOSABCD': 's',
    'TA-SABCD': '^',
    'SAIM-SABCD': 'D'
}

def collect_orders(method_dir, max_order=5):
    """Return the paths of the existing ``order1`` .. ``order<max_order>`` subdirectories.

    Candidates are checked in ascending order; entries that are missing or
    are not directories are skipped.
    """
    candidates = (os.path.join(method_dir, f"order{idx}") for idx in range(1, max_order + 1))
    return [path for path in candidates if os.path.isdir(path)]

def _collect_stage_values(section, prefix, field, max_stage, order_dir, missing_label):
    """Map stage number -> ``section[key][field]`` for stages 1..max_stage.

    The entry for stage ``i`` is the first key of ``section`` that starts with
    ``f"{prefix}{i}_"`` (the trailing underscore keeps ``task_1_`` from matching
    ``task_10_``). Stages with no matching key, or whose entry has no/None
    ``field``, are reported on stdout and left out of the result.
    """
    values = {}
    for stage in range(1, max_stage + 1):
        stage_key = f"{prefix}{stage}_"
        match = next((k for k in section if k.startswith(stage_key)), None)
        if match is None:
            print(f"[{order_dir}] {missing_label} for task {stage}")
            continue
        value = section[match].get(field, None)
        if value is None:
            print(f"[{order_dir}] Task {stage} does not have {field} field")
            continue
        values[stage] = value
    return values


def load_order_metrics(order_dir, method, max_stage=20):
    """Load accuracy and SAR curve for a single order, compatible with SAIM and TA formats.

    Reads ``experiment_results.json`` (accuracy, taken from the ``SAIM`` or
    ``task_arithmetic`` section) and ``sar_analysis_results.json`` (SAR) from
    ``order_dir``. A missing file is reported and contributes no values.

    Args:
        order_dir: directory of one task order.
        method: method key; a name starting with "TA" selects the
            task-arithmetic section/prefix, anything else the SAIM one.
        max_stage: highest task number to look for.

    Returns:
        ``(stages, accs, sars)``: three equally long lists. They contain only
        the stages that have BOTH an accuracy and a SAR value, in ascending
        order, so ``accs[k]`` and ``sars[k]`` always belong to ``stages[k]``.
        Accuracy is ``average_accuracy * 100``.
    """
    exp_path = os.path.join(order_dir, "experiment_results.json")
    sar_path = os.path.join(order_dir, "sar_analysis_results.json")
    is_ta = method.startswith("TA")

    # Load accuracy
    acc_by_stage = {}
    if os.path.exists(exp_path):
        with open(exp_path, "r", encoding="utf-8") as f:
            exp_data = json.load(f)
        # Prefer the section that matches `method` (consistent with the SAR
        # prefix choice below), so a file holding both SAIM and
        # task_arithmetic results yields the right numbers; otherwise fall
        # back to whichever section exists.
        preferred = ("task_arithmetic", "SAIM") if is_ta else ("SAIM", "task_arithmetic")
        section_name = next((name for name in preferred if name in exp_data), None)
        method_data = exp_data[section_name] if section_name is not None else {}
        raw_accs = _collect_stage_values(
            method_data, "task_", "average_accuracy", max_stage, order_dir, "Missing data")
        acc_by_stage = {stage: value * 100 for stage, value in raw_accs.items()}
    else:
        print(f"[{order_dir}] experiment_results.json does not exist")

    # Load SAR
    sar_by_stage = {}
    if os.path.exists(sar_path):
        with open(sar_path, "r", encoding="utf-8") as f:
            sar_data = json.load(f)
        sar_prefix = "task_arithmetic_task_" if is_ta else "SAIM_task_"
        sar_by_stage = _collect_stage_values(
            sar_data, sar_prefix, "overall_avg_sar", max_stage, order_dir, "SAR missing data")
    else:
        print(f"[{order_dir}] sar_analysis_results.json does not exist")

    # Align by task number: a stage missing on either side is dropped rather
    # than shifting the remaining values onto the wrong task.
    stages = [s for s in range(1, max_stage + 1) if s in acc_by_stage and s in sar_by_stage]
    if len(stages) < max_stage:
        print(f"[{order_dir}] Only {len(stages)} valid tasks, should be {max_stage}")
    return stages, [acc_by_stage[s] for s in stages], [sar_by_stage[s] for s in stages]

def _aligned_series(orders_data, value_index):
    """Return ``(stages, values)`` for one method, aligned across its orders.

    ``orders_data`` is a non-empty list of ``(stages, accs, sars)`` tuples;
    ``value_index`` selects accs (1) or sars (2). Only stages present in every
    order are kept, so column k of ``values`` (one row per order) refers to
    the same task number ``stages[k]`` in every row.
    """
    per_order = [dict(zip(order[0], order[value_index])) for order in orders_data]
    stages = sorted(set(per_order[0]).intersection(*per_order[1:]))
    values = np.array([[series[s] for s in stages] for series in per_order])
    return stages, values


def _draw_metric(ax, methods_data, value_index):
    """Draw, for every method, the mean curve over orders plus faint per-order points."""
    for method, orders_data in methods_data.items():
        stages, values = _aligned_series(orders_data, value_index)
        color = COLORS.get(method, '#333333')
        # Dotted lines for the variants without SA-BCD, solid for those with it.
        linestyle = ':' if "NOSABCD" in method else '-'
        ax.plot(
            stages, np.mean(values, axis=0),
            label=LABELS.get(method, method),
            color=color,
            marker=MARKERS.get(method, 'o'),
            linestyle=linestyle,
            linewidth=2, markersize=6, alpha=0.9
        )
        for row in values:
            ax.scatter(stages, row, color=color, alpha=0.25, s=20)


def _style_axes(ax, ylabel, ylim, max_stage):
    """Apply the shared axis labels, limits, ticks, grid and spine styling."""
    ax.set_xlabel('Task Number', fontsize=22)
    ax.set_ylabel(ylabel, fontsize=22)
    ax.set_xlim(0.5, max_stage + 0.5)
    ax.set_ylim(*ylim)
    ax.set_xticks(range(1, max_stage + 1, 2))
    ax.grid(True, linestyle='--', alpha=0.4)
    for spine in ('top', 'right'):
        ax.spines[spine].set_visible(False)
    ax.tick_params(axis='both', labelsize=22)


def plot_metrics(methods_data, output_dir="analyze/fig", max_stage=20):
    """Save accuracy, SAR and side-by-side comparison figures to ``output_dir``.

    Args:
        methods_data: mapping method key -> non-empty list of
            ``(stages, accs, sars)`` tuples, one per task order.
        output_dir: destination directory (created if missing).
        max_stage: highest task number shown on the x axis.

    Writes ``accuracy_scatter.png``, ``sar_scatter.png`` and
    ``combined_metrics.png``. Each figure is closed after it is saved.
    """
    os.makedirs(output_dir, exist_ok=True)
    # (value index in the per-order tuple, y label, y limits, legend location, file name)
    panels = (
        (1, 'Test Accuracy (%)', (50, 100), 'upper right', 'accuracy_scatter.png'),
        (2, 'Average SAR', (0.75, 1.0), 'lower left', 'sar_scatter.png'),
    )

    # 1-2. One standalone figure per metric, each with its own legend.
    for value_index, ylabel, ylim, legend_loc, filename in panels:
        fig, ax = plt.subplots(figsize=(8, 4))
        _draw_metric(ax, methods_data, value_index)
        _style_axes(ax, ylabel, ylim, max_stage)
        ax.legend(loc=legend_loc, frameon=True, framealpha=0.6, facecolor='white',
                  edgecolor='gray', fontsize=22)
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, filename), dpi=300, bbox_inches='tight')
        plt.close(fig)

    # 3. Side-by-side comparison sharing one figure-level legend below the axes.
    fig, axes = plt.subplots(1, 2, figsize=(16, 5))
    for ax, (value_index, ylabel, ylim, _, _) in zip(axes, panels):
        _draw_metric(ax, methods_data, value_index)
        _style_axes(ax, ylabel, ylim, max_stage)
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.04),
               frameon=False, ncol=len(methods_data), fontsize=24)
    fig.tight_layout()
    fig.subplots_adjust(bottom=0.13)
    fig.savefig(os.path.join(output_dir, 'combined_metrics.png'), dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Figures saved to {output_dir}")

def main():
    """Load every configured method's per-order metrics and plot them.

    Orders whose stages, accuracies or SARs come back empty are dropped. A
    method with no remaining orders is reported and left out. Nothing is
    plotted if no method has data.
    """
    methods_data = {}
    for method, method_dir in METHODS.items():
        loaded = (load_order_metrics(order_dir, method) for order_dir in collect_orders(method_dir))
        valid_orders = [metrics for metrics in loaded if all(metrics)]
        if not valid_orders:
            print(f"{method} has no valid data")
            continue
        methods_data[method] = valid_orders
        print(f"{method} loaded {len(valid_orders)} orders")

    if methods_data:
        plot_metrics(methods_data)
    else:
        print("Failed to load data for any method")


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()