import os, csv
import matplotlib.pyplot as plt
from utils import config_id, results_paths

import numpy as np
import pandas as pd
import seaborn as sns

def _load_round_csv(path):
    rounds, loss, acc = [], [], []
    rc, re, rs, cc, ce, cs, ct = [], [], [], [], [], [], []
    round_comm, cum_comm, cum_downlink, cum_uplink = [], [], [], []

    with open(path, "r") as f:
        r = csv.DictReader(f)
        for row in r:
            try:
                rounds.append(int(row["round"]))
                loss.append(float(row["loss"]))
                acc.append(float(row["acc"]))
                rc.append(float(row["round_client_time_sec"]))
                re.append(float(row["round_each_client_time_sec"]))
                rs.append(float(row["round_server_time_sec"]))
                cc.append(float(row["cum_client_time_sec"]))
                ce.append(float(row["cum_each_client_time_sec"]))
                cs.append(float(row["cum_server_time_sec"]))
                ct.append(float(row["cum_time_sec"]))
                
                round_comm.append(float(row.get("round_comm_gb", 0.0)))
                cum_comm.append(float(row.get("cum_comm_gb", 0.0)))
                cum_downlink.append(float(row.get("cum_downlink_gb", 0.0)))
                cum_uplink.append(float(row.get("cum_uplink_gb", 0.0)))
                
            except KeyError as e:
                if 'comm_gb' not in str(e).lower() and 'time_sec' not in str(e).lower():
                    print(f"Warning: Missing expected column in {path}: {e}")
            except ValueError as e:
                print(f"Warning: Corrupted data in {path} at row {row}: {e}")
    return {
        "round": rounds, "loss": loss, "acc": acc,
        "round_client_time_sec": rc, "round_each_client_time_sec": re, "round_server_time_sec": rs,
        "cum_client_time_sec": cc,  "cum_each_client_time_sec": ce,  "cum_server_time_sec": cs,
        "cum_time_sec": ct,
        "round_comm_gb": round_comm, "cum_comm_gb": cum_comm,
        "cum_downlink_gb": cum_downlink, "cum_uplink_gb": cum_uplink
    }

def _csv_path_from_cfg(cfg):
    """Return ``(config_id(cfg), round_csv_path)`` for a config.

    The path is taken from ``results_paths(cfg)["round_csv"]`` and is
    replaced by ``None`` when no file exists at that location.
    """
    identifier = config_id(cfg)
    csv_file = results_paths(cfg)["round_csv"]
    if not os.path.exists(csv_file):
        return identifier, None
    return identifier, csv_file

def aggregate_group(cfg_list):
    """Aggregate the per-round CSVs of several configs into mean/std curves.

    Configs whose CSV is missing, empty or fails to load are skipped with a
    warning. All remaining runs are truncated to the shortest run's number
    of rounds before the per-round statistics are computed across runs.

    Args:
        cfg_list: Iterable of configs; each is resolved to a CSV path via
            ``_csv_path_from_cfg``.

    Returns:
        dict with ``"round"`` (taken from the first loaded run),
        ``"<metric>_mean"`` lists for every aggregated metric,
        ``"<metric>_std"`` lists (population std, ``ddof=0``) for accuracy,
        cumulative time and the communication metrics, and ``"n_runs"``.

    Raises:
        FileNotFoundError: If no config yielded a usable CSV.
        ValueError: If the shortest loaded run has zero rounds.
    """
    loaded = []
    for cfg in cfg_list:
        csv_file = _csv_path_from_cfg(cfg)[1]
        if not csv_file:
            continue
        try:
            table = _load_round_csv(csv_file)
            if table.get("round") and len(table["round"]) > 0:
                loaded.append(table)
            else:
                print(f"Warning: No data in CSV: {csv_file}")
        except Exception as e:
            print(f"Warning: Failed to load {csv_file}: {e}")

    if not loaded:
        raise FileNotFoundError("no valid CSVs found for given cfg_list")

    common_len = min(len(table["round"]) for table in loaded)
    if common_len == 0:
        raise ValueError("CSVs found, but all have 0 rounds of data.")

    def column_matrix(key):
        # One row per run, one column per round.
        return np.stack([table[key][:common_len] for table in loaded], axis=0)

    # (CSV column, output key prefix, whether a std curve is emitted)
    spec = (
        ("acc", "acc", True),
        ("loss", "loss", False),
        ("round_client_time_sec", "round_client", False),
        ("round_server_time_sec", "round_server", False),
        ("cum_time_sec", "cum_time", True),
        ("cum_client_time_sec", "cum_client", False),
        ("cum_server_time_sec", "cum_server", False),
        ("cum_comm_gb", "cum_comm", True),
        ("cum_downlink_gb", "cum_downlink", True),
        ("cum_uplink_gb", "cum_uplink", True),
    )
    matrices = [
        (prefix, with_std, column_matrix(column))
        for column, prefix, with_std in spec
    ]

    summary = {"round": loaded[0]["round"][:common_len]}
    for prefix, with_std, matrix in matrices:
        summary[f"{prefix}_mean"] = matrix.mean(0).tolist()
        if with_std:
            summary[f"{prefix}_std"] = matrix.std(0, ddof=0).tolist()
    summary["n_runs"] = len(loaded)
    return summary


def print_group_stats(stats, label=None):
    """Print one line per round with mean accuracy, its std and mean cumulative time.

    Args:
        stats: Mapping with ``"round"``, ``"acc_mean"``, ``"acc_std"`` and
            ``"cum_time_mean"`` sequences (and optionally ``"n_runs"``).
        label: Heading text; when falsy, ``"<n_runs> runs"`` is used.
    """
    heading = label if label else f"{stats.get('n_runs', '?')} runs"
    print(f"[stats] {heading}")
    per_round = zip(
        stats["round"], stats["acc_mean"], stats["acc_std"], stats["cum_time_mean"]
    )
    for rnd, mean_acc, std_acc, elapsed in per_round:
        print(f"  round {rnd:>3d} | acc={mean_acc:6.2f}% ± {std_acc:4.2f} | cum_time={elapsed:7.2f}s")

def plot_groups(list_of_cfg_lists, labels=None,
                title="Group Comparison", plot_name=None,
                show=True, save_path=None):
    """Plot mean accuracy and mean cumulative time per round for several groups.

    Each entry of ``list_of_cfg_lists`` is aggregated with
    ``aggregate_group`` and its statistics are printed before plotting.

    Args:
        list_of_cfg_lists: One list of configs per group.
        labels: Legend label per group; defaults to ``group1``, ``group2``, ...
        title: Prefix for both subplot titles.
        plot_name: When given, the figure is saved to ``./plots/<plot_name>.png``
            (takes precedence over ``save_path``).
        show: Display the figure when true, otherwise close it.
        save_path: Explicit output path, used only when ``plot_name`` is falsy.

    Raises:
        FileNotFoundError: Propagated from ``aggregate_group`` when a group
            has no usable CSV.
    """
    group_stats = []
    labels = labels or [f"group{i+1}" for i in range(len(list_of_cfg_lists))]
    for i, cfgs in enumerate(list_of_cfg_lists):
        st = aggregate_group(cfgs)
        group_stats.append(st)
        print_group_stats(st, label=labels[i])

    fig, axes = plt.subplots(1, 2, figsize=(12, 5), dpi=120)

    # Both panels share the same styling; only the plotted series differs.
    panels = (
        (axes[0], "acc_mean", "Accuracy", "Accuracy (%)"),
        (axes[1], "cum_time_mean", "Cumulative Time", "Time (s)"),
    )
    for ax, key, panel_title, ylabel in panels:
        for st, lab in zip(group_stats, labels):
            ax.plot(
                st["round"], st[key],
                linestyle="-", linewidth=2, marker="o", markersize=3,
                label=lab
            )
        ax.set_title(f"{title} — {panel_title}")
        ax.set_xlabel("Round")
        ax.set_ylabel(ylabel)
        ax.grid(True)
        ax.legend(loc="best")

    if plot_name:
        out_path = os.path.join("./plots", f"{plot_name}.png")
    elif save_path:
        out_path = save_path
    else:
        out_path = None

    if out_path:
        # A bare filename has an empty dirname, which os.makedirs rejects.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path, bbox_inches="tight", dpi=300)
        print(f"[viz] saved plot to {out_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)

def _get_mean_acc_series_from_group(cfg_list):
    """Return the group's mean accuracy as a Series indexed by round, or None.

    ``None`` is returned when ``aggregate_group`` cannot produce statistics:
    either no usable CSV exists (``FileNotFoundError``) or the loaded data
    cannot be aggregated (``ValueError``; reported with a warning). Several
    callers do not guard against ``ValueError`` themselves, so it is handled
    here.
    """
    try:
        stats = aggregate_group(cfg_list)
    except FileNotFoundError:
        return None
    except ValueError as e:
        print(f"Warning: Could not aggregate runs: {e}")
        return None
    return pd.Series(stats["acc_mean"], index=stats["round"])

def plot_performance_comparison(
    models_to_compare,
    all_configs,
    non_iid_settings,
    title="Performance Comparison",
    plot_name="performance_comparison",
    last_n_rounds=1,
    save_path=None,
    show=True
):
    """Plot non-IID accuracy relative to IID accuracy for several models.

    Left panel: per-round ratio (percent) of the mean non-IID accuracy,
    averaged over ``non_iid_settings``, to the IID accuracy. Right panel: one
    bar per model with the ratio of the final accuracies, each averaged over
    the last ``last_n_rounds`` rounds. Models missing from ``all_configs`` or
    lacking IID / non-IID data are skipped with a warning.

    Args:
        models_to_compare: Model names (keys of ``all_configs``) in plot order.
        all_configs: Mapping ``model -> {partition name -> list of configs}``;
            the IID partition is read from the ``"iid"`` key.
        non_iid_settings: Partition names averaged into the non-IID curve.
        title: Accepted for interface compatibility.
            NOTE(review): the title is not rendered by this function.
        plot_name: File stem used under ``./plots`` when ``save_path`` is not
            given; nothing is saved when this is falsy.
        last_n_rounds: Number of trailing rounds averaged for the bar chart.
        save_path: Explicit output path (only used when ``plot_name`` is truthy).
        show: Display the figure when true, otherwise close it.
    """
    sns.set_style("whitegrid")
    fig, axes = plt.subplots(1, 2, figsize=(12.3, 5.5))
    ax_line, ax_bar = axes[0], axes[1]

    colors = plt.cm.viridis(np.linspace(0, 1, len(models_to_compare)))
    linestyles = ['-', '--', ':', '-.']

    resilience_scores = {}
    # Colour actually used for each plotted model, so that bars match lines
    # even when some models are skipped.
    model_colors = {}

    print("--- Generating Performance Comparison Chart ---")

    for i, model_name in enumerate(models_to_compare):
        if model_name not in all_configs:
            print(f"Warning: '{model_name}' not found in configs. Skipping.")
            continue

        print(f"Processing model: {model_name}")
        model_partitions = all_configs[model_name]

        iid_acc_series = _get_mean_acc_series_from_group(model_partitions.get("iid", []))

        non_iid_series_list = []
        for setting in non_iid_settings:
            s = _get_mean_acc_series_from_group(model_partitions.get(setting, []))
            if s is not None:
                non_iid_series_list.append(s)

        if iid_acc_series is None or not non_iid_series_list:
            print(f"Warning: Missing IID or Non-IID data for '{model_name}'. Skipping.")
            continue

        # Compare only the rounds every series has.
        min_len = min(len(iid_acc_series), *[len(s) for s in non_iid_series_list])
        iid_acc = iid_acc_series.iloc[:min_len]

        non_iid_df = pd.concat([s.iloc[:min_len] for s in non_iid_series_list], axis=1)
        non_iid_avg_acc = non_iid_df.mean(axis=1)

        # The epsilon avoids division by zero for rounds with 0% IID accuracy.
        relative_performance = (non_iid_avg_acc.values / (iid_acc.values + 1e-9)) * 100

        ax_line.plot(iid_acc.index, relative_performance, label=model_name, linewidth=2.5,
                     color=colors[i], linestyle=linestyles[i % len(linestyles)])

        iid_final_acc = iid_acc_series.tail(last_n_rounds).mean()

        non_iid_final_scores = [s.tail(last_n_rounds).mean() for s in non_iid_series_list]
        avg_non_iid_final_acc = sum(non_iid_final_scores) / len(non_iid_final_scores) if non_iid_final_scores else 0

        ratio = (avg_non_iid_final_acc / iid_final_acc) * 100 if iid_final_acc > 0 else 0
        resilience_scores[model_name] = ratio
        model_colors[model_name] = colors[i]

    ax_line.set_xlabel('Round', fontsize=14)
    ax_line.set_ylabel('Relative Performance (%)', fontsize=14)
    ax_line.tick_params(labelsize=12)
    ax_line.set_ylim(bottom=0)
    ax_line.grid(True, which='both', linestyle='--', linewidth=0.5)

    bar_labels = list(resilience_scores.keys())
    scores = list(resilience_scores.values())
    bar_colors = [model_colors[name] for name in bar_labels]

    bars = ax_bar.bar(bar_labels, scores, color=bar_colors if bar_colors else None, width=0.5)
    ax_bar.set_ylabel('Resilience Score (%)', fontsize=14)
    ax_bar.tick_params(axis='x', labelsize=12)
    ax_bar.tick_params(axis='y', labelsize=12)
    ax_bar.set_ylim(0, max(105, max(scores) * 1.1 if scores else 105))
    ax_bar.bar_label(bars, padding=3, fmt='%.1f%%', fontsize=12)
    ax_bar.grid(axis='y', linestyle='--', linewidth=0.5)

    handles, legend_labels = ax_line.get_legend_handles_labels()
    fig.legend(handles, legend_labels, loc='lower center', bbox_to_anchor=(0.5, -0.05),
               ncol=max(1, len(models_to_compare)), prop={'size': 14})

    fig.tight_layout(rect=[0, 0.1, 1, 0.95])

    if plot_name:
        out_path = save_path if save_path else os.path.join("./plots", f"{plot_name}.png")
        # A bare filename has an empty dirname, which os.makedirs rejects.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path, bbox_inches="tight", dpi=300)
        print(f"\nChart saved to: {out_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)

def plot_accuracy_comparison(
    models_to_compare,
    all_configs,
    non_iid_settings_to_average,
    title="Model Accuracy Comparison",
    plot_name="accuracy_comparison",
    save_path=None,
    show=True
):
    """Plot IID and averaged non-IID accuracy curves for several models.

    For every model two curves are drawn: the mean IID accuracy (solid) and
    the mean accuracy averaged over ``non_iid_settings_to_average`` (dashed).
    The palette holds six colour/marker pairs, so at most six curves are
    drawn; further curves are skipped with a warning.

    Args:
        models_to_compare: Model names (keys of ``all_configs``) in plot order.
        all_configs: Mapping ``model -> {partition name -> list of configs}``;
            the IID partition is read from the ``"iid"`` key.
        non_iid_settings_to_average: Partition names averaged into the
            non-IID curve.
        title: Figure title.
        plot_name: File stem used under ``./plots`` when ``save_path`` is not
            given; nothing is saved when this is falsy.
        save_path: Explicit output path (only used when ``plot_name`` is truthy).
        show: Display the figure when true, otherwise close it.
    """
    sns.set_style("whitegrid")
    fig = plt.figure(figsize=(12, 8))

    palette = sns.color_palette("deep", 6)
    markers = ['o', 's', 'v', 'p', 'D', '^']
    line_index = 0

    print("--- Generating Accuracy Comparison Chart ---")

    for model_name in models_to_compare:
        if model_name not in all_configs:
            print(f"Warning: '{model_name}' not found in configs. Skipping.")
            continue

        print(f"Processing model: {model_name}")
        model_partitions = all_configs[model_name]

        iid_acc_series = _get_mean_acc_series_from_group(model_partitions.get("iid", []))

        non_iid_series_list = []
        for setting in non_iid_settings_to_average:
            s = _get_mean_acc_series_from_group(model_partitions.get(setting, []))
            if s is not None:
                non_iid_series_list.append(s)

        if iid_acc_series is None or not non_iid_series_list:
            print(f"Warning: Missing IID or averaged Non-IID data for '{model_name}'. Skipping.")
            continue

        # Compare only the rounds every series has.
        all_series = [iid_acc_series] + non_iid_series_list
        min_len = min(len(s) for s in all_series)

        iid_acc = iid_acc_series.iloc[:min_len]

        non_iid_df = pd.concat([s.iloc[:min_len] for s in non_iid_series_list], axis=1)
        non_iid_avg_acc = non_iid_df.mean(axis=1)

        rounds = iid_acc.index.to_numpy()

        curves = (
            (iid_acc, "IID", '-'),
            (non_iid_avg_acc, "Non-IID", '--'),
        )
        for curve, suffix, linestyle in curves:
            if line_index >= len(palette):
                # Make the omission visible instead of dropping data silently.
                print(f"Warning: only {len(palette)} curves can be drawn; "
                      f"'{model_name} ({suffix})' was not plotted.")
                continue
            plt.plot(rounds, curve.to_numpy(),
                     label=f'{model_name} ({suffix})', color=palette[line_index],
                     linestyle=linestyle, marker=markers[line_index],
                     markevery=10, linewidth=3, markersize=6)
            line_index += 1

    plt.xlabel('Round', fontsize=16)
    plt.ylabel('Accuracy (%)', fontsize=16)
    plt.tick_params(labelsize=14)
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.title(title, fontsize=20)

    plt.legend(
        loc='upper center', bbox_to_anchor=(0.5, -0.15),
        fancybox=True, shadow=True, ncol=3, fontsize=14
    )

    plt.tight_layout(rect=[0, 0.1, 1, 0.95])

    if plot_name:
        out_path = save_path if save_path else os.path.join("./plots", f"{plot_name}.png")
        # A bare filename has an empty dirname, which os.makedirs rejects.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        fig.savefig(out_path, bbox_inches="tight", dpi=300)

        print(f"\nChart saved to: {out_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)
        
def plot_resilience_comparison(
    all_configs,
    models_to_compare,
    non_iid_settings,
    title="Model Resilience Comparison",
    plot_name="resilience_comparison_chart",
    save_path=None,
    show=True,
    last_n_rounds=1
):
    """Plot relative non-IID performance per round and final resilience bars.

    Left panel: for every entry of ``models_to_compare`` the per-round ratio
    (percent) of the mean non-IID accuracy, averaged over the non-IID
    settings, to the IID accuracy. Right panel: grouped bars per base model
    name (the text before the first ``"("``) for the ``"<base> (CIFAR100)"``
    and ``"<base> (TINYIMAGENET)"`` entries of ``all_configs``. Each bar is
    the ratio of final non-IID accuracy (all non-IID configs pooled into one
    group) to final IID accuracy; absent entries are drawn as 0.

    Args:
        all_configs: Mapping ``model key -> {partition name -> list of configs}``;
            the IID partition is read from the ``"iid"`` key.
        models_to_compare: Model keys in plot order.
        non_iid_settings: Partition names to treat as non-IID; any entry
            equal to ``"iid"`` (case-insensitive) is ignored.
        title: Figure title.
        plot_name: File stem used under ``./plots`` when ``save_path`` is not
            given; nothing is saved when this is falsy.
        save_path: Explicit output path (only used when ``plot_name`` is truthy).
        show: Display the figure when true, otherwise close it.
        last_n_rounds: Number of trailing rounds averaged for the bar values
            (the default of 1 uses the final round only).
    """
    sns.set_style("whitegrid")
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    ax_line, ax_bar = axes

    line_colors = list(plt.cm.tab10.colors)
    line_linestyles = ['-', '--', ':', '-.', (0, (5, 1)), (0, (3, 5, 1, 5))]

    non_iid_names = [s for s in non_iid_settings if s.lower() != 'iid']

    for i, model_key in enumerate(models_to_compare):
        if model_key not in all_configs:
            continue

        partitions = all_configs[model_key]
        iid_cfgs = partitions.get("iid", [])
        non_iid_cfgs_list = [partitions.get(setting, []) for setting in non_iid_names]

        iid_acc = _get_mean_acc_series_from_group(iid_cfgs)
        non_iid_acc_series_list = [_get_mean_acc_series_from_group(cfgs) for cfgs in non_iid_cfgs_list]
        non_iid_acc_series_list = [s for s in non_iid_acc_series_list if s is not None]

        if iid_acc is None or not non_iid_acc_series_list:
            continue

        # Compare only the rounds every series has.
        min_len = min(len(iid_acc), *[len(s) for s in non_iid_acc_series_list])
        iid_acc = iid_acc.iloc[:min_len]

        non_iid_df = pd.concat([s.iloc[:min_len] for s in non_iid_acc_series_list], axis=1)
        non_iid_avg_acc = non_iid_df.mean(axis=1)

        # The epsilon avoids division by zero for rounds with 0% IID accuracy.
        relative_performance = (non_iid_avg_acc.values / (iid_acc.values + 1e-9)) * 100

        rounds = iid_acc.index
        ax_line.plot(rounds, relative_performance, label=model_key, linewidth=2.5,
                     color=line_colors[i % len(line_colors)],
                     linestyle=line_linestyles[i % len(line_linestyles)])

    ax_line.set_xlabel('Round', fontsize=14)
    ax_line.set_ylabel('Relative Performance (%)', fontsize=14)
    ax_line.tick_params(labelsize=12)
    ax_line.set_ylim(bottom=0)
    ax_line.grid(True, linestyle="--", alpha=0.6)
    ax_line.legend(loc='best', fontsize=10)

    def base_name(name: str):
        return name.split("(")[0].strip()

    # Unique base names in first-seen order.
    bases = []
    for k in models_to_compare:
        b = base_name(k)
        if b not in bases:
            bases.append(b)

    # Guard against non-positive values, which would average an empty tail.
    tail_len = max(1, int(last_n_rounds))

    def _get_final_ratio(cfg_list, iid_cfg_list):
        """Final non-IID / IID accuracy ratio in percent, or 0.0 without data."""
        non_iid_series = _get_mean_acc_series_from_group(cfg_list)
        iid_series = _get_mean_acc_series_from_group(iid_cfg_list)
        if non_iid_series is None or iid_series is None:
            return 0.0

        non_iid_final = non_iid_series.tail(tail_len).mean()
        iid_final = iid_series.tail(tail_len).mean()

        return (non_iid_final / iid_final) * 100 if iid_final > 0 else 0.0

    def _ratio_for_key(model_key):
        """Resilience score for one ``all_configs`` entry; 0 when the key is absent."""
        if model_key not in all_configs:
            return 0
        partitions = all_configs[model_key]
        pooled_non_iid_cfgs = []
        for setting in non_iid_names:
            pooled_non_iid_cfgs.extend(partitions.get(setting, []))
        return _get_final_ratio(pooled_non_iid_cfgs, partitions.get("iid", []))

    cifar_vals, tiny_vals = [], []
    for b in bases:
        cifar_vals.append(_ratio_for_key(f"{b} (CIFAR100)"))
        tiny_vals.append(_ratio_for_key(f"{b} (TINYIMAGENET)"))

    x = np.arange(len(bases))
    width = 0.35
    bars_cifar = ax_bar.bar(x - width/2, cifar_vals, width, label='CIFAR-100', hatch='///', edgecolor='black', color='royalblue')
    bars_tiny = ax_bar.bar(x + width/2, tiny_vals, width, label='TinyImageNet', edgecolor='black', color='darkorange')

    ax_bar.set_ylabel('Resilience Score (%)', fontsize=14)
    ax_bar.set_xticks(x, [b.replace(" ", "\n") for b in bases], fontsize=12)
    ax_bar.tick_params(axis='y', labelsize=12)
    ax_bar.set_ylim(0, max(110, (max(cifar_vals + tiny_vals) * 1.1) if (cifar_vals + tiny_vals) else 110))
    ax_bar.bar_label(bars_cifar, padding=3, fmt='%.1f%%', fontsize=10)
    ax_bar.bar_label(bars_tiny, padding=3, fmt='%.1f%%', fontsize=10)
    ax_bar.legend(fontsize=12)
    ax_bar.grid(axis='y', linestyle='--', alpha=0.6)

    fig.suptitle(title, fontsize=20)
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    if plot_name:
        out_path = save_path if save_path else os.path.join("./plots", f"{plot_name}.png")
        # A bare filename has an empty dirname, which os.makedirs rejects.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path, bbox_inches="tight", dpi=300)
        print(f"\nChart saved to: {out_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)
        
def plot_technique_comparison(
    all_configs_by_technique,
    title="Performance Comparison by Technique",
    plot_name="technique_comparison_chart",
    save_path=None,
    show=True
):
    """Plot the mean test-accuracy curve of every technique on one axis.

    Techniques without usable data are skipped with a warning. Curves are
    joined on the round number with an outer merge, so techniques may cover
    different rounds. Nothing is plotted when no technique has data.

    Args:
        all_configs_by_technique: Mapping ``technique name -> list of configs``.
        title: Axes title.
        plot_name: File stem used under ``./plots`` when ``save_path`` is not
            given; nothing is saved when this is falsy.
        save_path: Explicit output path (only used when ``plot_name`` is truthy).
        show: Display the figure when true, otherwise close it.
    """
    sns.set_style("whitegrid")

    curve_df = None
    technique_names = list(all_configs_by_technique.keys())

    print("--- Generating Technique Comparison Chart ---")
    for tech_name in technique_names:
        cfg_list = all_configs_by_technique[tech_name]
        print(f"Processing technique: {tech_name}")

        try:
            mean_series = _get_mean_acc_series_from_group(cfg_list)
            if mean_series is None:
                print(f"Warning: No data found for technique '{tech_name}'. Skipping.")
                continue
        except (FileNotFoundError, ValueError) as e:
            print(f"Warning: Could not aggregate data for '{tech_name}'. Skipping. Error: {e}")
            continue

        mc = pd.DataFrame(mean_series).reset_index()
        mc.columns = ['round', tech_name]

        if curve_df is None:
            curve_df = mc
        else:
            curve_df = curve_df.merge(mc, on="round", how="outer")

    if curve_df is None:
        print("Error: No data could be processed. Aborting plot generation.")
        return

    curve_df = curve_df.sort_values("round").reset_index(drop=True)

    fig = plt.figure(figsize=(11, 8))
    ax = plt.gca()

    colors = plt.cm.tab10.colors
    linestyles = ['-', '--', ':', '-.', (0, (5, 1)), (0, (3, 5, 1, 5))]

    plot_columns = [col for col in curve_df.columns if col != 'round']

    for i, tech_name in enumerate(plot_columns):
        ax.plot(
            curve_df["round"],
            curve_df[tech_name],
            label=tech_name,
            linewidth=3.5,
            linestyle=linestyles[i % len(linestyles)],
            color=colors[i % len(colors)],
        )

    ax.set_xlabel("Round", fontsize=20)
    ax.set_ylabel("Test Accuracy (%)", fontsize=20)
    ax.tick_params(axis='both', labelsize=18)
    ax.grid(True, linestyle="--", alpha=0.5)
    ax.set_title(title, fontsize=22, pad=20)

    plt.legend(
        loc='upper center', bbox_to_anchor=(0.5, -0.15),
        fancybox=True, shadow=True, ncol=3, fontsize=18
    )

    plt.tight_layout(rect=[0, 0.1, 1, 0.95])

    if plot_name:
        out_path = save_path if save_path else os.path.join("./plots", f"{plot_name}.png")
        # A bare filename has an empty dirname, which os.makedirs rejects.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path, bbox_inches="tight", dpi=300)
        print(f"\nChart saved to: {out_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)
        
def plot_relative_performance_summary(
    all_configs_by_technique,
    baseline_technique="iid",
    title="Relative Performance and Resilience Summary",
    plot_name="relative_performance_summary",
    last_n_rounds=1,
    save_path=None,
    show=True
):
    """Plot each technique's accuracy relative to a baseline technique.

    Left panel: per-round accuracy of every non-baseline technique as a
    percentage of the baseline accuracy. Right panel: one bar per technique
    with the mean of that percentage over the last ``last_n_rounds`` rounds.
    The function prints an error and returns without plotting when the
    baseline is missing or has no data.

    Args:
        all_configs_by_technique: Mapping ``technique name -> list of configs``.
        baseline_technique: Name of the baseline key (matched case-insensitively).
        title: Figure title.
        plot_name: File stem used under ``./plots`` when ``save_path`` is not
            given; nothing is saved when this is falsy.
        last_n_rounds: Number of trailing rounds averaged for the bar chart.
        save_path: Explicit output path (only used when ``plot_name`` is truthy).
        show: Display the figure when true, otherwise close it.
    """
    sns.set_style("whitegrid")
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    ax_line, ax_bar = axes

    baseline_key_found = None
    for key in all_configs_by_technique.keys():
        if key.lower() == baseline_technique.lower():
            baseline_key_found = key
            break

    if not baseline_key_found:
        print(f"Error: Baseline technique '{baseline_technique}' not found in the provided configs.")
        plt.close(fig)
        return

    try:
        baseline_series = _get_mean_acc_series_from_group(all_configs_by_technique[baseline_key_found])
        if baseline_series is None:
            # Give the error report below something meaningful to print.
            raise FileNotFoundError("no valid CSVs found")
    except (FileNotFoundError, ValueError) as e:
        print(f"Error: Could not load data for baseline technique '{baseline_key_found}'. Skipping. Error: {e}")
        plt.close(fig)
        return

    comparison_techniques = [k for k in all_configs_by_technique.keys() if k.lower() != baseline_technique.lower()]

    colors = plt.cm.tab10.colors
    linestyles = ['--', ':', '-.', (0, (5, 1)), (0, (3, 5, 1, 5)), '-']
    final_scores = {}
    # Colour actually used for each plotted technique, so that bars match
    # lines even when some techniques are skipped.
    tech_colors = {}

    for i, tech_name in enumerate(comparison_techniques):
        try:
            tech_series = _get_mean_acc_series_from_group(all_configs_by_technique[tech_name])
            if tech_series is None:
                continue
        except (FileNotFoundError, ValueError):
            print(f"Warning: No data for comparison technique '{tech_name}'. Skipping.")
            continue

        # Compare only the rounds both series have.
        min_len = min(len(baseline_series), len(tech_series))
        if min_len == 0:
            continue

        baseline_aligned = baseline_series.iloc[:min_len]
        tech_aligned = tech_series.iloc[:min_len]

        # The epsilon avoids division by zero for rounds with 0% baseline accuracy.
        relative_performance = (tech_aligned.values / (baseline_aligned.values + 1e-9)) * 100

        line_color = colors[i % len(colors)]
        ax_line.plot(baseline_aligned.index, relative_performance, label=tech_name,
                     linestyle=linestyles[i % len(linestyles)], color=line_color, linewidth=2.5)

        final_scores[tech_name] = np.mean(relative_performance[-last_n_rounds:])
        tech_colors[tech_name] = line_color

    ax_line.set_xlabel("Round", fontsize=14)
    ax_line.set_ylabel("Relative Performance (%)", fontsize=14)
    ax_line.tick_params(labelsize=12)
    ax_line.grid(True, linestyle="--", alpha=0.6)

    labels = list(final_scores.keys())
    scores = list(final_scores.values())
    bar_colors = [tech_colors[name] for name in labels]

    if len(labels) == 1:
        bar_width = 0.2
    elif len(labels) < 5:
        bar_width = 0.3
    else:
        bar_width = 0.4

    bars = ax_bar.bar(labels, scores, color=bar_colors if bar_colors else None, width=bar_width)

    ax_bar.set_ylim(bottom=min(0, min(scores)*0.9) if scores else 0,
                    top=max(105, max(scores) * 1.1) if scores else 105)

    ax_bar.tick_params(axis='x', rotation=0, labelsize=12)
    ax_bar.tick_params(axis='y', labelsize=12)

    ax_bar.bar_label(bars, padding=3, fmt='%.1f%%', fontsize=12)

    if len(labels) == 1:
        # Widen the x-range so the single bar is not stretched across the axis.
        ax_bar.set_xlim(-1, 1)

    handles, legend_labels = ax_line.get_legend_handles_labels()
    fig.legend(handles, legend_labels, loc='lower center', bbox_to_anchor=(0.5, -0.05), ncol=3, fontsize=12)

    fig.suptitle(title, fontsize=18)
    fig.tight_layout(rect=[0, 0.1, 1, 0.95])

    if plot_name:
        out_path = save_path if save_path else os.path.join("./plots", f"{plot_name}.png")
        # A bare filename has an empty dirname, which os.makedirs rejects.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path, bbox_inches="tight", dpi=300)
        print(f"\nChart saved to: {out_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)
        
def plot_hyperparameter_impact(
    all_configs_by_hparam,
    non_iid_settings,
    title="Hyperparameter Impact on Resilience",
    plot_name="hparam_impact_chart",
    last_n_rounds=1,
    save_path=None,
    show=True
):
    """Plot how each hyperparameter setting affects non-IID vs. IID accuracy.

    Left panel: per-round ratio (percent) of the mean non-IID accuracy,
    averaged over ``non_iid_settings``, to the IID accuracy for every
    hyperparameter setting. Right panel: one bar per setting with the mean
    of that ratio over the last ``last_n_rounds`` rounds. Settings lacking
    IID or non-IID data are skipped with a warning.

    Args:
        all_configs_by_hparam: Mapping
            ``hparam name -> {partition name -> list of configs}``; the IID
            partition is read from the ``"iid"`` key.
        non_iid_settings: Partition names averaged into the non-IID curve.
        title: Figure title.
        plot_name: File stem used under ``./plots`` when ``save_path`` is not
            given; nothing is saved when this is falsy.
        last_n_rounds: Number of trailing rounds averaged for the bar chart.
        save_path: Explicit output path (only used when ``plot_name`` is truthy).
        show: Display the figure when true, otherwise close it.
    """
    sns.set_style("whitegrid")
    fig, axes = plt.subplots(1, 2, figsize=(12.3, 5))
    ax_line, ax_bar = axes

    hparam_names = list(all_configs_by_hparam.keys())
    colors = plt.cm.viridis(np.linspace(0, 1, len(hparam_names)))
    linestyles = ['-', '--', ':', '-.']
    resilience_scores = {}
    # Colour actually used for each plotted setting, so that bars match
    # lines even when some settings are skipped.
    hparam_colors = {}

    print("--- Generating Hyperparameter Impact Chart ---")

    for i, hparam_name in enumerate(hparam_names):
        print(f"Processing hparam: {hparam_name}")

        hparam_results = all_configs_by_hparam[hparam_name]
        try:
            iid_series = _get_mean_acc_series_from_group(hparam_results.get("iid", []))
        except (FileNotFoundError, ValueError) as e:
            print(f"Warning: No IID data for hparam '{hparam_name}'. Skipping. Error: {e}")
            iid_series = None

        if iid_series is None:
            continue

        non_iid_series_list = []
        for setting in non_iid_settings:
            try:
                s = _get_mean_acc_series_from_group(hparam_results.get(setting, []))
                if s is not None:
                    non_iid_series_list.append(s)
            except (FileNotFoundError, ValueError):
                print(f"Warning: No Non-IID data for '{setting}' in hparam '{hparam_name}'. Skipping setting.")

        if not non_iid_series_list:
            print(f"Warning: Missing all Non-IID data for '{hparam_name}'. Skipping.")
            continue

        # Compare only the rounds every series has.
        min_len = min(len(iid_series), *[len(s) for s in non_iid_series_list])
        if min_len == 0:
            continue

        iid_aligned = iid_series.iloc[:min_len]

        non_iid_df = pd.concat([s.iloc[:min_len] for s in non_iid_series_list], axis=1)
        non_iid_avg_acc = non_iid_df.mean(axis=1)

        # The epsilon avoids division by zero for rounds with 0% IID accuracy.
        relative_performance = (non_iid_avg_acc.values / (iid_aligned.values + 1e-9)) * 100
        ax_line.plot(iid_aligned.index, relative_performance, label=hparam_name,
                     color=colors[i], linestyle=linestyles[i % len(linestyles)], linewidth=3)

        resilience_scores[hparam_name] = np.mean(relative_performance[-last_n_rounds:])
        hparam_colors[hparam_name] = colors[i]

    ax_line.set_xlabel('Round', fontsize=14)
    ax_line.set_ylabel('Relative Performance (%)', fontsize=14)
    ax_line.tick_params(labelsize=12)
    ax_line.set_ylim(bottom=0)
    ax_line.grid(True, linestyle="--", alpha=0.6)

    labels = list(resilience_scores.keys())
    scores = list(resilience_scores.values())
    bar_colors = [hparam_colors[name] for name in labels]

    if len(labels) == 1:
        bar_width = 0.2
    elif len(labels) < 5:
        bar_width = 0.3
    else:
        bar_width = 0.4

    bars = ax_bar.bar(labels, scores, color=bar_colors if bar_colors else None, width=bar_width)
    ax_bar.set_xlabel('Hyperparameter Setting', fontsize=14)
    ax_bar.tick_params(axis='y', labelsize=12)
    ax_bar.tick_params(axis='x', rotation=0, labelsize=12)
    ax_bar.set_ylim(0, max(110, max(scores) * 1.1 if scores else 110))
    ax_bar.bar_label(bars, padding=3, fmt='%.1f%%', fontsize=14)

    if len(labels) == 1:
        # Widen the x-range so the single bar is not stretched across the axis.
        ax_bar.set_xlim(-1, 1)

    handles, legend_labels = ax_line.get_legend_handles_labels()
    fig.legend(handles, legend_labels, loc='lower center', bbox_to_anchor=(0.5, -0.07),
               ncol=max(1, len(hparam_names)), fontsize=14)

    fig.suptitle(title, fontsize=18)
    fig.tight_layout(rect=[0, 0.1, 1, 0.95])

    if plot_name:
        out_path = save_path if save_path else os.path.join("./plots", f"{plot_name}.png")
        # A bare filename has an empty dirname, which os.makedirs rejects.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path, bbox_inches="tight", dpi=300)
        print(f"\nChart saved to: {out_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)

def print_final_acc_and_convergence_stats(
    all_configs, 
    models_to_compare, 
    partitions_to_compare, 
    last_n_rounds=1, 
    target_ratio=0.95
):
    """
    Print and return per-model, per-partition summary statistics over seeds.

    For every (model, partition) pair each config is treated as one seed and
    the following are reported as mean ± sample std (ddof=1; 0 for one seed):
    1. Final Accuracy: average of the last ``last_n_rounds`` rounds.
    2. Rounds to Reach Target: first round whose accuracy is at least
       ``target_ratio`` * final accuracy (seeds that never reach it are
       excluded from the mean/std).
    3. Total Time (s): last cumulative time value.
    4. Total Comm. (GB): last cumulative communication value.
    5. Total Downlink/Uplink (GB): last cumulative downlink/uplink values.

    Seeds whose CSV is missing, unreadable or shorter than ``last_n_rounds``
    are ignored. Cumulative columns that are empty count as ``0.0``.

    Args:
        all_configs: Mapping ``model -> {partition name -> list of configs}``.
        models_to_compare: Model names to report, in print order.
        partitions_to_compare: Partition names to report for every model.
        last_n_rounds: Number of trailing rounds averaged into the final accuracy.
        target_ratio: Fraction of the final accuracy used as convergence target.

    Returns:
        dict ``model -> partition -> stats dict``; partitions without any
        usable seed are omitted from the inner dict.
    """
    stats_summary = {}
    print(f"\n--- Final Accuracy, Convergence, Time, & Communication Stats (Mean ± Std over Seeds) ---")
    print(f"Target: {target_ratio*100}% of Final Acc (avg of last {last_n_rounds} round(s))")

    print("\n{:<15} | {:<12} | {:<20} | {:<20} | {:<18} | {:<18} | {:<18} | {:<18}".format(
        "Model", "Partition", "Final Acc (%)", f"Acc@{target_ratio*100:g}% (Rounds)", "Total Time (s)", "Total Comm. (GB)", "Downlink (GB)", "Uplink (GB)"
    ))
    print("-" * 153) 

    def _last_or_default(values, default=0.0):
        """Last element of *values*, or *default* when it is missing or empty."""
        return values[-1] if values else default

    def _get_single_run_metrics(cfg, last_n_rounds, target_ratio):
        """Metrics of one seed, or None when its CSV is unusable."""
        _, p = _csv_path_from_cfg(cfg)
        if not p:
            return None
        try:
            data = _load_round_csv(p)
            if not data["acc"] or len(data["acc"]) < last_n_rounds:
                return None
        except Exception:
            # Best effort: an unreadable run is simply left out of the table.
            return None

        acc_series = np.array(data["acc"])
        rounds = np.array(data["round"])

        final_acc_val = np.mean(acc_series[-last_n_rounds:])

        target_acc = final_acc_val * target_ratio
        crossed_indices = np.where(acc_series >= target_acc)[0]
        rtt = rounds[crossed_indices[0]] if len(crossed_indices) > 0 else np.nan

        # The keys may be present with an empty list, so a dict default alone
        # does not protect the [-1] lookup.
        return {
            "final_acc": final_acc_val,
            "rounds_to_target": rtt,
            "time_sec": _last_or_default(data.get("cum_time_sec")),
            "comm_gb": _last_or_default(data.get("cum_comm_gb")),
            "comm_downlink_gb": _last_or_default(data.get("cum_downlink_gb")),
            "comm_uplink_gb": _last_or_default(data.get("cum_uplink_gb"))
        }

    for model_name in models_to_compare:
        stats_summary[model_name] = {}
        model_partitions = all_configs.get(model_name, {})

        for partition in partitions_to_compare:
            cfg_list = model_partitions.get(partition, [])

            run_metrics = [
                _get_single_run_metrics(cfg, last_n_rounds, target_ratio) for cfg in cfg_list
            ]
            run_metrics = [m for m in run_metrics if m is not None]
            n_seeds = len(run_metrics)

            if n_seeds > 0:
                df = pd.DataFrame(run_metrics)

                def _mean_std(column):
                    # Sample std is undefined for a single seed; report 0.0.
                    std = df[column].std(ddof=1) if n_seeds > 1 else 0.0
                    return df[column].mean(), std

                mean_final_acc, std_final_acc = _mean_std("final_acc")

                rounds_to_target = df["rounds_to_target"].values
                n_reached = (~np.isnan(rounds_to_target)).sum()
                # np.nanmean warns on an all-NaN input; short-circuit that case.
                mean_rtt = np.nanmean(rounds_to_target) if n_reached > 0 else float("nan")
                std_rtt = np.nanstd(rounds_to_target, ddof=1) if n_reached > 1 else 0.0

                mean_rtt_int = int(round(mean_rtt)) if not np.isnan(mean_rtt) else -1
                std_rtt_int = int(round(std_rtt)) if not np.isnan(std_rtt) else 0

                mean_time_sec, std_time_sec = _mean_std("time_sec")
                mean_comm_gb, std_comm_gb = _mean_std("comm_gb")
                mean_downlink_gb, std_downlink_gb = _mean_std("comm_downlink_gb")
                mean_uplink_gb, std_uplink_gb = _mean_std("comm_uplink_gb")

                stats_summary[model_name][partition] = {
                    "final_acc_mean": mean_final_acc, "final_acc_std": std_final_acc,
                    "rounds_to_target_mean": mean_rtt, "rounds_to_target_std": std_rtt,
                    "time_sec_mean": mean_time_sec, "time_sec_std": std_time_sec,
                    "comm_gb_mean": mean_comm_gb, "comm_gb_std": std_comm_gb,
                    "comm_downlink_mean": mean_downlink_gb, "comm_downlink_std": std_downlink_gb,
                    "comm_uplink_mean": mean_uplink_gb, "comm_uplink_std": std_uplink_gb,
                    "n_seeds": n_seeds, "n_reached_target": n_reached
                }

                acc_str = f"{mean_final_acc:6.2f} ± {std_final_acc:.2f}"
                time_str = f"{mean_time_sec:6.2f} ± {std_time_sec:.2f}"
                comm_str = f"{mean_comm_gb:6.2f} ± {std_comm_gb:.2f}"
                downlink_str = f"{mean_downlink_gb:6.2f} ± {std_downlink_gb:.2f}"
                uplink_str = f"{mean_uplink_gb:6.2f} ± {std_uplink_gb:.2f}"

                if n_reached > 0:
                    rtt_str = f"{mean_rtt_int:6d} ± {std_rtt_int:d}"
                else:
                    rtt_str = "N/A"

                print(f"  {model_name:<13} | {partition:<12} | {acc_str:<20} | {rtt_str:<20} | {time_str:<18} | {comm_str:<18} | {downlink_str:<18} | {uplink_str:<18}")

            else:
                print(f"  {model_name:<13} | {partition:<12} | {'No data':<20} | {'No data':<20} | {'No data':<18} | {'No data':<18} | {'No data':<18} | {'No data':<18}")

    return stats_summary