import glob
import os
import numpy as np
import matplotlib.pyplot as plt
import argparse
import re

# ==========================================
#               Configuration
# ==========================================

# --- Table configuration ---
# Methods shown as columns of the LaTeX tables, in column order.
TARGET_METHODS_TABLE = ['FedAvg', 'FedProx', 'FedSophia', 'FedNew', 'FedNewton']
# Row order of datasets in the tables; keys not listed here sort last.
DATASET_ORDER = ['mnist', 'fashionmnist', 'cifar10', 'cifar100']
# Dataset key -> name printed in the tables (unknown keys fall back to
# the upper-cased key in the table builders).
DATASET_DISPLAY = {
    'mnist': 'MNIST', 'fashionmnist': 'Fashion-MNIST', 
    'cifar10': 'CIFAR-10', 'cifar100': 'CIFAR-100'
}

# --- Plot style configuration ---
# Per-method line style; the key order also defines the drawing/legend
# order used by the plotting functions.
PLOT_STYLE = {
    'FedAvg': {'color': 'tab:blue', 'marker': 'o', 'label': 'FedAvg'},
    'FedProx': {'color': 'tab:orange', 'marker': 's', 'label': 'FedProx'},
    'FedSophia': {'color': 'tab:green', 'marker': '^', 'label': 'FedSophia'},
    'FedNew': {'color': 'tab:red', 'marker': 'D', 'label': 'FedNew'},
    'FedNewton': {'color': 'tab:purple', 'marker': '*', 'label': 'FedNewton'},
}

# --- Metrics to plot ---
# One figure per key. 'ylabel' labels the first subplot and 'is_percent'
# requests scaling of the curve by 100 before plotting.
# NOTE(review): 'title' is not read by the plotting functions visible in
# this file.
METRICS_CONFIG = {
    'test_acc': {'title': 'Test Accuracy vs Rounds', 'ylabel': 'Accuracy (%)', 'is_percent': True},
    'test_loss': {'title': 'Test Loss vs Rounds', 'ylabel': 'Loss', 'is_percent': False},
    'train_loss': {'title': 'Training Loss vs Rounds', 'ylabel': 'Loss', 'is_percent': False},
    'max_gpu_mem': {'title': 'GPU Memory Usage', 'ylabel': 'Memory (MB)', 'is_percent': False},
}

# --- Global storage ---
# One entry per (dataset, alpha) appended by process_for_table() and
# consumed by generate_latex_tables().
ALL_RESULTS_FOR_TABLE = [] 
# dataset_key -> method -> {'gpu': [...], 'time': [...]}, filled by
# process_for_table() and read by _build_resource_latex_table().
GLOBAL_RESOURCE_STATS = {} 

# ==========================================
#               Core utilities (data loading)
# ==========================================

def extract_info_from_path(exp_path, dataset_name):
    """Extract the ``alpha`` value encoded in an experiment folder name.

    The last component of ``exp_path`` is searched for the first occurrence
    of ``alpha<number>`` (for example ``exp_alpha0.5``).

    Args:
        exp_path: Path of the experiment folder.
        dataset_name: Not used; kept so existing callers keep working.

    Returns:
        A tuple ``(alpha_str, alpha_val)``. ``alpha_val`` is the parsed
        float and ``alpha_str`` its display form, with integral values
        printed without a decimal part (``"1"`` rather than ``"1.0"``).
        When the name contains no alpha value the result is
        ``("999", 999.0)``.
    """
    exp_name = os.path.basename(exp_path)
    # Match exactly one well-formed decimal number. A plain character class
    # of digits and dots would also capture trailing or repeated dots
    # (e.g. "alpha0.1.bak" -> "0.1."), which float() rejects.
    alpha_match = re.search(r"alpha(\d+(?:\.\d+)?|\.\d+)", exp_name)
    alpha_val = float(alpha_match.group(1)) if alpha_match else 999.0
    alpha_str = str(int(alpha_val)) if alpha_val.is_integer() else str(alpha_val)
    return alpha_str, alpha_val

def load_raw_data(log_dir):
    """Load every ``*.npz`` log in ``log_dir`` and group the arrays by method.

    The method name is the part of the file name before the first ``_``.
    Each readable file contributes at most one array per metric; the key
    ``loss`` is accepted as a fallback for ``train_loss``.

    Args:
        log_dir: Directory containing the ``.npz`` files.

    Returns:
        ``{method: {'test_acc': [...], 'test_loss': [...],
        'train_loss': [...], 'max_gpu_mem': [...], 'wall_time': [...]}}``
        where each list holds one array per file that provided the metric.
        An empty dict is returned when no file is found.

    Files that cannot be read are skipped with a warning on stdout instead
    of aborting the whole analysis; nothing from such a file is kept, so a
    run never ends up with only part of its metrics.
    """
    metric_keys = ('test_acc', 'test_loss', 'train_loss', 'max_gpu_mem', 'wall_time')
    data = {}

    # Sorted so that the order of runs does not depend on the file system.
    for path in sorted(glob.glob(os.path.join(log_dir, "*.npz"))):
        method_raw = os.path.basename(path).split('_')[0]
        try:
            # NOTE(security): allow_pickle=True can execute arbitrary code
            # from a crafted file; only point this at logs you trust.
            # The context manager closes the underlying archive handle.
            with np.load(path, allow_pickle=True) as loaded:
                arrays = {key: loaded[key] for key in metric_keys if key in loaded}
                if 'train_loss' not in arrays and 'loss' in loaded:
                    arrays['train_loss'] = loaded['loss']
        except Exception as exc:
            # Best effort per file, but make the failure visible. Unlike a
            # bare ``except`` this lets KeyboardInterrupt/SystemExit through.
            print(f"[WARN] Skipping unreadable log file {path}: {exc}")
            continue

        entry = data.setdefault(method_raw, {key: [] for key in metric_keys})
        for key, values in arrays.items():
            entry[key].append(values)
    return data

# ==========================================
#               Table generation module (unchanged)
# ==========================================

def process_for_table(dataset_key, alpha_str, alpha_num, raw_data):
    """Accumulate table statistics for one (dataset, alpha) experiment.

    Args:
        dataset_key: Dataset identifier used as table row group.
        alpha_str: Alpha value as it should be printed.
        alpha_num: Alpha value used for sorting rows.
        raw_data: ``{method: {metric: [run arrays]}}`` as returned by
            ``load_raw_data``.

    Side effects:
        * Extends ``GLOBAL_RESOURCE_STATS[dataset_key][method]`` with the
          maximum of every non-empty ``max_gpu_mem`` run and, for every
          ``wall_time`` run with more than one entry, its last value
          divided by the number of entries.
          NOTE(review): this presumes ``wall_time`` is cumulative - confirm.
        * Appends one record to ``ALL_RESULTS_FOR_TABLE`` unless no method
          produced an accuracy figure.
    """
    acc_summary = {}
    per_dataset = GLOBAL_RESOURCE_STATS.setdefault(dataset_key, {})

    for method, metrics in raw_data.items():
        # Final accuracy of a run = mean of its last five entries, in percent.
        tail_means = [
            np.mean(run[-5:]) * 100
            for run in metrics['test_acc']
            if len(run) > 0
        ]
        if tail_means:
            acc_summary[method] = (np.mean(tail_means), np.std(tail_means))

        resources = per_dataset.setdefault(method, {'gpu': [], 'time': []})
        resources['gpu'].extend(
            np.max(run) for run in metrics['max_gpu_mem'] if len(run) > 0
        )
        resources['time'].extend(
            run[-1] / max(1, len(run))
            for run in metrics['wall_time']
            if len(run) > 1
        )

    if not acc_summary:
        return

    record = {
        'dataset_key': dataset_key,
        'alpha_str': alpha_str,
        'alpha_sort': alpha_num,
        'stats': {'acc': acc_summary},
    }
    ALL_RESULTS_FOR_TABLE.append(record)

def _build_acc_latex_table(df_list):
    """Render the accuracy comparison as a LaTeX ``table*`` environment.

    Args:
        df_list: Records as stored in ``ALL_RESULTS_FOR_TABLE``; consecutive
            records sharing a ``dataset_key`` are emitted as one row group
            with a ``\\multirow`` dataset cell.

    Returns:
        The LaTeX source as a single string. Per row, the best mean (and any
        mean within 0.01 of it) is set in bold; methods without data show
        ``-``. An alpha printed as ``0.1`` is set in bold as well.
    """
    column_spec = "|l|c|" + "c|" * len(TARGET_METHODS_TABLE)
    header_cells = [r"\textbf{Dataset}", r"\textbf{$\alpha$}"]
    header_cells += [r"\textbf{" + name + "}" for name in TARGET_METHODS_TABLE]

    out = [
        r"\begin{table*}[ht]",
        r"\centering",
        r"\caption{Test accuracy comparison ($\%$). Results are averaged over the last 5 rounds.}",
        r"\label{tab:comparison_acc}",
        r"\begin{tabular}{" + column_spec + "}",
        r"\hline",
        " & ".join(header_cells) + r" \\ \hline",
    ]

    pending = []  # rows of the dataset group currently being collected
    last_index = len(df_list) - 1
    for pos, entry in enumerate(df_list):
        key = entry['dataset_key']
        display = DATASET_DISPLAY.get(key, key.upper())
        acc = entry['stats']['acc']

        # -1 doubles as the "no value" marker, so it is excluded here too.
        candidates = [acc[name][0] for name in TARGET_METHODS_TABLE if name in acc]
        candidates = [value for value in candidates if value != -1]
        best = max(candidates) if candidates else -1

        cells = []
        for name in TARGET_METHODS_TABLE:
            if name not in acc:
                cells.append("-")
                continue
            mean_acc, _std = acc[name]
            text = f"{mean_acc:.2f}"
            if best != -1 and mean_acc >= best - 0.01:
                text = r"\textbf{" + text + "}"
            cells.append(text)

        pending.append((display, entry['alpha_str'], cells))

        # Keep collecting while the next record belongs to the same dataset.
        if pos < last_index and df_list[pos + 1]['dataset_key'] == key:
            continue

        span = len(pending)
        for row_no, (label, alpha_text, value_cells) in enumerate(pending):
            if row_no == 0:
                lead = r"\multirow{" + str(span) + r"}{*}{\textbf{" + label + "}}"
            else:
                lead = ""
            if alpha_text == "0.1":
                alpha_text = r"\textbf{" + alpha_text + "}"
            out.append(" & ".join([lead, alpha_text] + value_cells) + r" \\")
        out.append(r"\hline")
        pending = []

    out.append(r"\end{tabular}")
    out.append(r"\end{table*}")
    return "\n".join(out)

def _build_resource_latex_table():
    """Render time-per-round and peak GPU memory as a LaTeX ``table*``.

    Reads ``GLOBAL_RESOURCE_STATS`` and emits two rows per dataset
    ("Time (s)" and "GPU (MB)"), one column per entry of
    ``TARGET_METHODS_TABLE``. Values are means over all collected samples;
    missing or non-positive means are shown as ``-``. The smallest value of
    a row is set in bold, together with any value within 0.01 (time) or
    1.0 (GPU) of it.

    Returns:
        The LaTeX source, or ``""`` when no resource data was collected.
    """
    if not GLOBAL_RESOURCE_STATS:
        return ""

    def order_rank(key):
        # Known datasets keep their configured order, unknown ones go last.
        return DATASET_ORDER.index(key) if key in DATASET_ORDER else 999

    def mean_per_method(per_method, field):
        # -1 marks "no samples for this method".
        result = {}
        for name in TARGET_METHODS_TABLE:
            if name in per_method and per_method[name][field]:
                result[name] = np.mean(per_method[name][field])
            else:
                result[name] = -1
        return result

    def render_cells(means, tolerance):
        positives = [value for value in means.values() if value > 0]
        lowest = min(positives) if positives else -1
        cells = []
        for name in TARGET_METHODS_TABLE:
            value = means[name]
            if not value > 0:
                cells.append("-")
                continue
            text = f"{value:.2f}"
            if lowest > 0 and value <= lowest + tolerance:
                text = r"\textbf{" + text + "}"
            cells.append(text)
        return cells

    column_spec = "|l|c|" + "c|" * len(TARGET_METHODS_TABLE)
    header_cells = [r"\textbf{Dataset}", r"\textbf{Metric}"]
    header_cells += [r"\textbf{" + name + "}" for name in TARGET_METHODS_TABLE]

    out = [
        r"\begin{table*}[ht]",
        r"\centering",
        r"\caption{Comparison of average training time per round (seconds) and peak GPU memory usage (MB). Results are averaged across all $\alpha$ settings. Lower is better.}",
        r"\label{tab:resources}",
        r"\begin{tabular}{" + column_spec + "}",
        r"\hline",
        " & ".join(header_cells) + r" \\ \hline",
    ]

    for ds_key in sorted(GLOBAL_RESOURCE_STATS.keys(), key=order_rank):
        display = DATASET_DISPLAY.get(ds_key, ds_key.upper())
        per_method = GLOBAL_RESOURCE_STATS[ds_key]
        time_means = mean_per_method(per_method, 'time')
        gpu_means = mean_per_method(per_method, 'gpu')

        time_row = [r"\multirow{2}{*}{\textbf{" + display + "}}", "Time (s)"]
        time_row += render_cells(time_means, 0.01)
        out.append(" & ".join(time_row) + r" \\")

        gpu_row = ["", "GPU (MB)"] + render_cells(gpu_means, 1.0)
        out.append(" & ".join(gpu_row) + r" \\ \hline")

    out.append(r"\end{tabular}")
    out.append(r"\end{table*}")
    return "\n".join(out)

def generate_latex_tables():
    """Build the accuracy and resource tables from the collected results.

    Records in ``ALL_RESULTS_FOR_TABLE`` are ordered by dataset (following
    ``DATASET_ORDER``, unknown datasets last) and then by alpha.

    Returns:
        Both tables as one string, each followed by a blank line, or ``""``
        when nothing was collected.
    """
    if not ALL_RESULTS_FOR_TABLE:
        return ""

    def sort_key(record):
        key = record['dataset_key']
        rank = DATASET_ORDER.index(key) if key in DATASET_ORDER else 999
        return (rank, record['alpha_sort'])

    ordered = sorted(ALL_RESULTS_FOR_TABLE, key=sort_key)
    sections = [_build_acc_latex_table(ordered), _build_resource_latex_table()]
    return "".join(section + "\n\n" for section in sections)

# ==========================================
#               Plotting module (core changes)
# ==========================================

def get_curve_stats(run_list):
    """Compute the element-wise mean and standard deviation over runs.

    Runs are truncated to the length of the shortest non-empty run so they
    can be stacked. Empty runs are ignored rather than forcing the common
    length to zero, so a single empty log no longer discards the curves of
    every other run; scalar (0-d) runs are treated as length-1 curves.

    Args:
        run_list: Sequence of 1-D array-likes, one per run.

    Returns:
        ``(mean_curve, std_curve)`` as NumPy arrays, or ``None`` when there
        is no run with at least one value.
    """
    if not run_list:
        return None

    # atleast_1d makes len() safe for scalars stored as 0-d arrays.
    usable = [arr for arr in (np.atleast_1d(run) for run in run_list) if len(arr) > 0]
    if not usable:
        return None

    # Truncate to a common length so the runs form a rectangular array.
    min_len = min(len(arr) for arr in usable)
    stacked = np.array([arr[:min_len] for arr in usable])
    return np.mean(stacked, axis=0), np.std(stacked, axis=0)

def process_for_plot_all_metrics(raw_data):
    """Reduce the raw runs of every method to the curves needed for plotting.

    Args:
        raw_data: ``{method: {metric: [run arrays]}}`` as returned by
            ``load_raw_data``.

    Returns:
        A dict with one entry per key of ``METRICS_CONFIG`` mapping
        ``method -> (mean_curve, std_curve)``, plus ``'time_vs_acc'``
        mapping ``method -> (mean_time, mean_acc)`` cut to a common length.
        Methods for which ``get_curve_stats`` yields nothing are omitted.
    """
    processed = {key: {} for key in METRICS_CONFIG}
    processed['time_vs_acc'] = {}
    time_vs_acc = processed['time_vs_acc']

    for method, runs_by_metric in raw_data.items():
        # 1. Regular per-round metrics: keep both mean and std.
        for key in METRICS_CONFIG:
            stats = get_curve_stats(runs_by_metric.get(key, []))
            if stats is not None:
                processed[key][method] = stats

        # 2. Accuracy over wall-clock time: means only, because the time
        #    axes of different runs do not line up for a shaded band.
        acc_runs = runs_by_metric.get('test_acc', [])
        time_runs = runs_by_metric.get('wall_time', [])
        if not (acc_runs and time_runs):
            continue

        acc_stats = get_curve_stats(acc_runs)
        time_stats = get_curve_stats(time_runs)
        if acc_stats is None or time_stats is None:
            continue

        mean_acc = acc_stats[0]
        mean_time = time_stats[0]
        shared = min(len(mean_acc), len(mean_time))
        time_vs_acc[method] = (mean_time[:shared], mean_acc[:shared])

    return processed

def plot_generic_metric(dataset_name, metric_key, alpha_data_map, output_dir):
    """Plot one metric against rounds, one subplot per alpha value.

    Each method is drawn as its mean curve with a shaded band of one
    standard deviation around it.

    Args:
        dataset_name: Used as prefix of the output file name.
        metric_key: Key of ``METRICS_CONFIG`` selecting the metric.
        alpha_data_map: ``{alpha: {metric_key: {method: (mean, std)}}}``.
        output_dir: Directory in which ``<dataset>_<metric>.png`` is saved.

    Nothing is written when no alpha has data for ``metric_key``.
    """
    alphas = sorted(alpha_data_map.keys())
    if not any(alpha_data_map[a].get(metric_key) for a in alphas):
        return

    config = METRICS_CONFIG[metric_key]
    num_plots = len(alphas)
    # squeeze=False always yields a 2-D array of Axes, so a single alpha
    # needs no special casing. The previous code wrapped a lone Axes in a
    # list twice, which made ``axes[0]`` a list and crashed on ``ax.plot``.
    fig, axes = plt.subplots(1, num_plots, figsize=(5 * num_plots, 3.5),
                             sharey=False, squeeze=False)
    axes = axes[0]

    # Methods are drawn in PLOT_STYLE order; unknown methods come last.
    style_rank = {name: rank for rank, name in enumerate(PLOT_STYLE)}

    for idx, alpha in enumerate(alphas):
        ax = axes[idx]
        curves_map = alpha_data_map[alpha].get(metric_key, {})
        sorted_methods = sorted(curves_map, key=lambda name: style_rank.get(name, 999))

        for method in sorted_methods:
            mean_curve, std_curve = curves_map[method]

            if config['is_percent']:
                mean_curve = mean_curve * 100
                std_curve = std_curve * 100

            x = np.arange(len(mean_curve))
            style = PLOT_STYLE.get(method, {})

            # 1. Mean curve.
            line, = ax.plot(x, mean_curve, label=style.get('label', method),
                            color=style.get('color', None),
                            marker=style.get('marker', None),
                            markevery=max(1, len(x) // 10),
                            linewidth=2, alpha=0.9)

            # 2. Standard-deviation band. Reusing the line's colour keeps the
            #    band matched to its curve for methods without a style entry.
            ax.fill_between(x, mean_curve - std_curve, mean_curve + std_curve,
                            color=line.get_color(), alpha=0.2, linewidth=0)

        ax.set_title(f"$\\alpha = {alpha}$", fontsize=16)
        ax.set_xlabel("Rounds", fontsize=12)
        ax.grid(True, linestyle='--', alpha=0.3)
        if idx == 0:
            ax.set_ylabel(config['ylabel'], fontsize=12)
            # The legend lives in the first subplot; skip it when that
            # subplot has no curves, which would only trigger a warning.
            if curves_map:
                ax.legend(fontsize=10, frameon=True, loc='best')

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, f"{dataset_name}_{metric_key}.png"),
                dpi=300, bbox_inches='tight')
    plt.close(fig)

def plot_time_vs_acc(dataset_name, alpha_data_map, output_dir):
    """Plot test accuracy against wall-clock time, one subplot per alpha.

    Only mean curves are drawn; no deviation band is shown because the time
    axes of individual runs do not line up.

    Args:
        dataset_name: Used as prefix of the output file name.
        alpha_data_map: ``{alpha: {'time_vs_acc': {method: (time, acc)}}}``.
        output_dir: Directory in which ``<dataset>_time_acc.png`` is saved.

    Nothing is written when no alpha has ``time_vs_acc`` data.
    """
    alphas = sorted(alpha_data_map.keys())
    if not any(alpha_data_map[a].get('time_vs_acc') for a in alphas):
        return

    num_plots = len(alphas)
    # squeeze=False always yields a 2-D array of Axes, so a single alpha
    # needs no special casing. The previous code wrapped a lone Axes in a
    # list twice, which made ``axes[0]`` a list and crashed on ``ax.plot``.
    fig, axes = plt.subplots(1, num_plots, figsize=(5 * num_plots, 3.5),
                             sharey=True, squeeze=False)
    axes = axes[0]

    # Methods are drawn in PLOT_STYLE order; unknown methods come last.
    style_rank = {name: rank for rank, name in enumerate(PLOT_STYLE)}

    for idx, alpha in enumerate(alphas):
        ax = axes[idx]
        curves_map = alpha_data_map[alpha].get('time_vs_acc', {})
        sorted_methods = sorted(curves_map, key=lambda name: style_rank.get(name, 999))

        for method in sorted_methods:
            t_curve, acc_curve = curves_map[method]
            style = PLOT_STYLE.get(method, {})
            ax.plot(t_curve, acc_curve * 100,
                    label=style.get('label', method),
                    color=style.get('color', None),
                    marker=style.get('marker', None),
                    markevery=max(1, len(t_curve) // 10),
                    linewidth=2, alpha=0.85)

        ax.set_title(f"$\\alpha = {alpha}$", fontsize=16)
        ax.set_xlabel("Time (s)", fontsize=12)
        ax.grid(True, linestyle='--', alpha=0.3)
        if idx == 0:
            ax.set_ylabel("Test Accuracy (%)", fontsize=12)
            # Skip the legend when the first subplot has no curves, which
            # would only trigger a warning.
            if curves_map:
                ax.legend(fontsize=10, frameon=True, loc='best')

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, f"{dataset_name}_time_acc.png"),
                dpi=300, bbox_inches='tight')
    plt.close(fig)

# ==========================================
#               Main program
# ==========================================

if __name__ == "__main__":
    # Expected layout: <root_dir>/<experiment containing "alpha">/<dataset>/logs/*.npz
    parser = argparse.ArgumentParser()
    parser.add_argument("--root_dir", type=str, default=".", help="Root directory")
    parser.add_argument("--output_dir", type=str, default="analysis_results_full", help="Output directory")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    # dataset name -> alpha value -> curves from process_for_plot_all_metrics
    PLOT_DATA_STORAGE = {}
    print(f"Scanning directory: {args.root_dir} ...")

    # Experiment folders are the sub-directories whose name mentions "alpha".
    exp_folders = []
    for entry in os.listdir(args.root_dir):
        if 'alpha' in entry and os.path.isdir(os.path.join(args.root_dir, entry)):
            exp_folders.append(entry)

    for exp_name in sorted(exp_folders):
        exp_path = os.path.join(args.root_dir, exp_name)

        for ds_name in os.listdir(exp_path):
            # Dataset folders: non-hidden sub-directories of the experiment.
            if ds_name.startswith('.'):
                continue
            if not os.path.isdir(os.path.join(exp_path, ds_name)):
                continue

            logs_path = os.path.join(exp_path, ds_name, "logs")
            if not os.path.exists(logs_path):
                continue

            alpha_str, alpha_val = extract_info_from_path(exp_path, ds_name)
            raw_data = load_raw_data(logs_path)
            if not raw_data:
                continue

            process_for_table(ds_name, alpha_str, alpha_val, raw_data)
            per_alpha = PLOT_DATA_STORAGE.setdefault(ds_name, {})
            per_alpha[alpha_val] = process_for_plot_all_metrics(raw_data)

    print("\nGenerating LaTeX Tables...")
    latex_code = generate_latex_tables()
    tex_path = os.path.join(args.output_dir, "unified_tables.tex")
    with open(tex_path, "w") as f:
        f.write(latex_code)
    print(f"    -> Saved: unified_tables.tex")

    print("\nGenerating Plots...")
    for ds_name, alpha_map in PLOT_DATA_STORAGE.items():
        for metric_key in METRICS_CONFIG:
            plot_generic_metric(ds_name, metric_key, alpha_map, args.output_dir)
        plot_time_vs_acc(ds_name, alpha_map, args.output_dir)

    print("\nAll tasks completed.")
