import glob
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import argparse

# ==========================================
#   Global plot style ("giant font" edition, 1.5x the original sizes)
# ==========================================
# Applied once at import time, so every figure this script creates starts
# from these defaults.
_GIANT_FONT_STYLE = {
    # Text sizes (original value x 1.5)
    'font.size': 27,          # global default font size (was 18)
    'axes.titlesize': 36,     # subplot title size (was 24)
    'axes.labelsize': 33,     # axis label size (was 22)
    'xtick.labelsize': 27,    # x tick label size (was 18)
    'ytick.labelsize': 27,    # y tick label size (was 18)
    'legend.fontsize': 27,    # legend text size (was 18)
    'figure.titlesize': 42,   # figure-level title size
    # Line styling
    'lines.linewidth': 6,     # default line width (was 4, x 1.5)
    'lines.markersize': 15,   # default marker size
}
plt.rcParams.update(_GIANT_FONT_STYLE)

# --- Configuration: backbones to analyse ---
# Maps the directory name a run was logged under (the first path component
# below the log root) to the label used in plots and tables. Runs under a
# directory that is not listed here are ignored by the loader.
BACKBONE_ALIAS = dict([
    ('resnet18', 'ResNet-18'),
    ('tv_resnet50', 'ResNet-50'),
    ('vit_base_patch14_dinov2.lvd142m', 'ViT-Base (DinoV2)'),
    ('vit_base_patch16_clip_224', 'ViT-Base (CLIP)'),
])

# --- Configuration: accuracy cutoff ---
# Test accuracy at which a run counts as having reached the target. It is
# compared directly with logged accuracies and shown as a percentage (x100),
# so logged accuracies are presumably fractions in [0, 1].
ACC_CUTOFF_THRESHOLD = 0.95

def load_data(log_dir, aliases=None):
    """
    Recursively load ``.npz`` log files under *log_dir* and group them by backbone.

    The backbone is the first path component below *log_dir*
    (``<log_dir>/<backbone>/.../*.npz``). Only backbones present in *aliases*
    are kept; they are stored under their display name.

    Args:
        log_dir: Root directory to scan.
        aliases: Mapping ``raw_backbone_dir -> display_name``. Defaults to
            ``BACKBONE_ALIAS``.

    Returns:
        ``None`` if no ``.npz`` file exists under *log_dir*. Otherwise a dict
        ``display_name -> {'acc', 'test_loss', 'train_loss', 'time', 'gpu_mem'}``
        whose values are lists with one entry per run. ``acc``, ``test_loss``
        and ``time`` are always appended together, so their indices line up.
        ``train_loss`` and ``gpu_mem`` only receive an entry for runs that
        logged them. Files that fail to load, or that lack ``test_loss`` /
        ``wall_time``, are reported and skipped entirely.
    """
    if aliases is None:
        aliases = BACKBONE_ALIAS

    search_pattern = os.path.join(log_dir, "**", "*.npz")
    # Sorted so the run order (and therefore run indices) is deterministic.
    files = sorted(glob.glob(search_pattern, recursive=True))

    if not files:
        print(f"No .npz files found in {log_dir}")
        return None

    print(f"Found {len(files)} log files.")

    data = {}
    for f in files:
        rel_path = os.path.relpath(f, log_dir)
        raw_backbone = rel_path.split(os.sep)[0]
        if raw_backbone not in aliases:
            continue
        display_name = aliases[raw_backbone]

        try:
            # NOTE(security): allow_pickle=True lets np.load unpickle object
            # arrays, which can execute arbitrary code - only scan trusted logs.
            # The context manager closes the archive. Arrays indexed out of it
            # are read into memory, so they stay valid after closing.
            with np.load(f, allow_pickle=True) as loaded:
                entry = data.setdefault(display_name, {
                    'acc': [],
                    'test_loss': [],
                    'train_loss': [],
                    'time': [],
                    'gpu_mem': [],
                })
                if 'test_acc' not in loaded:
                    continue
                acc = loaded['test_acc']
                if len(acc) == 0:
                    continue

                # Read every required field BEFORE appending anything, so a
                # missing key cannot leave acc/test_loss/time misaligned.
                test_loss = loaded['test_loss']
                wall_time = loaded['wall_time']
                train_loss = loaded['train_loss'] if 'train_loss' in loaded else None
                if 'gpu_mem' in loaded:
                    gpu_mem = loaded['gpu_mem']
                elif 'max_gpu_mem' in loaded:
                    gpu_mem = loaded['max_gpu_mem']
                else:
                    gpu_mem = None
        except Exception as e:  # corrupt archive, missing required key, ...
            print(f"Error loading {f}: {e}")
            continue

        entry['acc'].append(acc)
        entry['test_loss'].append(test_loss)
        entry['time'].append(wall_time)
        if train_loss is not None:
            entry['train_loss'].append(train_loss)
        if gpu_mem is not None:
            entry['gpu_mem'].append(gpu_mem)

    return data

def get_truncated_data(metrics, metric_key, use_cutoff, threshold=None):
    """
    Return the runs stored under ``metrics[metric_key]``, optionally truncated
    shortly after test accuracy first reaches the cutoff threshold.

    Args:
        metrics: Per-backbone dict of per-run lists (as built by ``load_data``).
        metric_key: Which list to return, e.g. ``'acc'`` or ``'test_loss'``.
        use_cutoff: If False, the stored list is returned unchanged.
        threshold: Accuracy cutoff. Defaults to ``ACC_CUTOFF_THRESHOLD``.

    Returns:
        A list with one entry per run (``[]`` if the key is missing or empty).
        Scalar entries (Python numbers, numpy scalars, 0-d arrays such as a
        ``max_gpu_mem`` value read from an ``.npz``) have no round axis and are
        passed through unchanged. A sequence run ``i`` is cut using
        ``metrics['acc'][i]`` as reference: it keeps every round up to and
        including the first one at or above *threshold*, plus one extra round.
        Runs without a matching accuracy run, or that never reach the
        threshold, are returned whole.
    """
    raw_data_list = metrics.get(metric_key, [])
    if not raw_data_list:
        return []
    if not use_cutoff:
        return raw_data_list
    if threshold is None:
        threshold = ACC_CUTOFF_THRESHOLD

    ref_acc_list = metrics.get('acc', [])
    processed_data = []
    for idx, run in enumerate(raw_data_list):
        # Check each run individually: a 0-d array is not an np.number, and
        # len()/slicing on it would raise.
        if np.ndim(run) == 0 or idx >= len(ref_acc_list):
            processed_data.append(run)
            continue

        hits = np.flatnonzero(np.asarray(ref_acc_list[idx]) >= threshold)
        if hits.size:
            # +2: keep the threshold round itself plus one following round.
            processed_data.append(run[:hits[0] + 2])
        else:
            processed_data.append(run)

    return processed_data

def plot_metric(data, metric_key, y_label, title, filename, fig_dir, use_cutoff=False):
    """
    Plot the per-round mean +/- std of one metric for every backbone and save it.

    Each backbone's runs are (optionally) truncated via ``get_truncated_data``.
    They are then trimmed to the shortest run and drawn as a mean curve with a
    +/-1 std band, using oversized (1.5x) fonts. Entries without a round axis
    (scalars / 0-d arrays) are skipped. If nothing can be plotted, a message is
    printed and no file is written.

    Args:
        data: ``display_name -> metrics dict`` as returned by ``load_data``.
        metric_key: Key inside each metrics dict to plot (e.g. ``'acc'``).
        y_label: Y-axis label.
        title: Plot title. A truncation note is appended when *use_cutoff*.
        filename: Output image file name, written inside *fig_dir*.
        fig_dir: Output directory (must already exist).
        use_cutoff: Truncate runs shortly after the accuracy threshold.
    """
    # The canvas must be large, otherwise the oversized text overlaps.
    fig, ax = plt.subplots(figsize=(18, 12))
    try:
        colors = plt.cm.tab10(np.linspace(0, 1, 10))
        markers = ['o', 's', '^', 'D', 'v', '<', '>']
        has_data = False

        # Colour/marker follow the backbone's position in the sorted key list
        # (not the number of curves drawn), so styles match across figures.
        for i, name in enumerate(sorted(data)):
            runs = get_truncated_data(data[name], metric_key, use_cutoff)
            # Only entries with a round axis can be drawn as curves.
            runs = [run for run in runs if np.ndim(run) >= 1]
            if not runs:
                continue
            min_len = min(len(run) for run in runs)
            if min_len == 0:
                continue

            # Trim every run to the shortest so they stack into (runs, rounds).
            arr = np.array([run[:min_len] for run in runs])
            if arr.ndim != 2:
                continue

            has_data = True
            mean = np.mean(arr, axis=0)
            std = np.std(arr, axis=0)
            x = np.arange(1, min_len + 1)
            color = colors[i % len(colors)]
            marker = markers[i % len(markers)]

            # Line width 6.0, marker size 18, a marker on roughly every 10th point.
            ax.plot(x, mean, label=name, linewidth=6.0, color=color, marker=marker,
                    markevery=max(1, len(x) // 10), markersize=18)
            ax.fill_between(x, mean - std, mean + std, color=color, alpha=0.15)

        if not has_data:
            print(f"No valid sequence data to plot for {metric_key}.")
            return

        final_title = title
        if use_cutoff:
            final_title += f" (Truncated at {ACC_CUTOFF_THRESHOLD*100:.0f}%)"

        # Explicit oversized fonts (labels: 33, title: 36).
        ax.set_xlabel("Communication Rounds", fontsize=33, fontweight='bold')
        ax.set_ylabel(y_label, fontsize=33, fontweight='bold')
        ax.set_title(final_title, fontsize=36, fontweight='bold', pad=25)

        # Draw the threshold line BEFORE building the legend so its label shows up.
        if metric_key == 'acc':
            ax.axhline(y=ACC_CUTOFF_THRESHOLD, color='gray', linestyle='--',
                       alpha=0.6, linewidth=4, label=f'{ACC_CUTOFF_THRESHOLD*100:.0f}% Threshold')

        # Legend (27 pt).
        ax.legend(fontsize=27, loc='best', framealpha=0.9)
        # Tick labels 27 pt, with thicker and longer tick marks.
        ax.tick_params(axis='both', which='major', labelsize=27, width=4, length=10)
        # Thicker grid lines.
        ax.grid(True, linestyle='--', alpha=0.4, linewidth=2.5)
        fig.tight_layout()

        fig.savefig(os.path.join(fig_dir, filename), dpi=300, bbox_inches='tight')
        print(f"Generated {filename}")
    finally:
        # Release the figure on every path (early return, error, success).
        plt.close(fig)

def plot_time_vs_acc(data, fig_dir, use_cutoff=False):
    """
    Plot mean test accuracy against mean wall-clock time for every backbone.

    For each backbone, the accuracy and ``time`` runs (optionally truncated via
    ``get_truncated_data``) are trimmed to a common length and averaged per
    round. The resulting curve is saved as ``compare_time_vs_acc.png`` in
    *fig_dir*. If nothing can be plotted, a message is printed and no file is
    written.

    Args:
        data: ``display_name -> metrics dict`` as returned by ``load_data``.
        fig_dir: Output directory (must already exist).
        use_cutoff: Truncate runs shortly after the accuracy threshold.
    """
    # Large canvas (18x12) so the oversized text does not overlap.
    fig, ax = plt.subplots(figsize=(18, 12))
    try:
        colors = plt.cm.tab10(np.linspace(0, 1, 10))
        markers = ['o', 's', '^', 'D', 'v', '<', '>']
        has_data = False

        for i, name in enumerate(sorted(data)):
            metrics = data[name]
            acc_list = get_truncated_data(metrics, 'acc', use_cutoff)
            time_list = get_truncated_data(metrics, 'time', use_cutoff)
            if not acc_list or not time_list:
                continue

            min_len = min(len(r) for r in acc_list + time_list)
            if min_len == 0:
                continue

            # Per-round averages across runs, all trimmed to the shortest run.
            mean_acc = np.mean(np.array([r[:min_len] for r in acc_list]), axis=0)
            mean_time = np.mean(np.array([r[:min_len] for r in time_list]), axis=0)

            has_data = True
            color = colors[i % len(colors)]
            marker = markers[i % len(markers)]

            # Line width 6.0, marker size 18.
            ax.plot(mean_time, mean_acc, label=name, linewidth=6.0, color=color,
                    marker=marker, markevery=max(1, len(mean_time) // 10), markersize=18)

        if not has_data:
            print("No valid data to plot for time vs. accuracy.")
            return

        final_title = "Test Accuracy vs. Wall-clock Time"
        if use_cutoff:
            final_title += " (Truncated)"

        # Explicit oversized fonts.
        ax.set_xlabel("Wall-clock Time (s)", fontsize=33, fontweight='bold')
        ax.set_ylabel("Test Accuracy", fontsize=33, fontweight='bold')
        ax.set_title(final_title, fontsize=36, fontweight='bold', pad=25)
        ax.legend(fontsize=27, loc='best', framealpha=0.9)

        ax.axhline(y=ACC_CUTOFF_THRESHOLD, color='gray', linestyle='--', alpha=0.6, linewidth=4)

        # Tick labels 27 pt, with thicker and longer tick marks.
        ax.tick_params(axis='both', which='major', labelsize=27, width=4, length=10)
        ax.grid(True, linestyle='--', alpha=0.4, linewidth=2.5)
        fig.tight_layout()

        fig.savefig(os.path.join(fig_dir, 'compare_time_vs_acc.png'), dpi=300, bbox_inches='tight')
        print(f"Generated compare_time_vs_acc.png")
    finally:
        # Release the figure on every path, including the no-data early return.
        plt.close(fig)

def save_latex_table(df, fig_dir):
    """
    Write *df* as a LaTeX table to ``<fig_dir>/backbone_comparison.tex`` and echo it.

    Cell values and headers are LaTeX-escaped here, and ``to_latex`` is called
    with ``escape=False``. This keeps the output independent of pandas'
    ``escape`` default, which differs between major versions; an unescaped
    ``%`` in a header would otherwise start a LaTeX comment. ``±`` and ``>``
    are rewritten as math-mode symbols. Columns produced by
    ``generate_summary_table`` get short bold headers, which form the table's
    only header row.

    Args:
        df: Summary table (one row per backbone, string cells).
        fig_dir: Output directory (must already exist).
    """
    latex_path = os.path.join(fig_dir, "backbone_comparison.tex")

    # Single-pass character mapping, so replacements never re-escape each
    # other's output (e.g. the braces inserted for a backslash).
    latex_chars = str.maketrans({
        "\\": r"\textbackslash{}",
        "&": r"\&",
        "%": r"\%",
        "$": r"\$",
        "#": r"\#",
        "_": r"\_",
        "{": r"\{",
        "}": r"\}",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
        "±": r"$\pm$",
        ">": r"$>$",
        "<": r"$<$",
    })

    def escape(value):
        return str(value).translate(latex_chars)

    # Short display names for the columns built by generate_summary_table.
    # Any other column keeps its own (escaped) name.
    short_headers = {
        "Backbone Model": "Model",
        "Final Acc (%)": "Final Acc",
        "Best Acc (%)": "Best Acc",
        "Rounds": "Rounds",
        "Time (s)": "Time (s)",
        "Max GPU Mem": "GPU Mem",
    }

    latex_df = df.astype(str).apply(lambda col: col.map(escape))
    latex_df.columns = [r"\textbf{" + escape(short_headers.get(c, c)) + "}" for c in df.columns]

    # First column (model name) left-aligned, the rest centred.
    column_format = "l" + "c" * max(len(df.columns) - 1, 0)
    latex_code = latex_df.to_latex(
        index=False,
        escape=False,
        caption=f"Model Comparison (Cutoff Acc: {ACC_CUTOFF_THRESHOLD*100:.0f}\\%)",
        label="tab:backbone_comparison",
        column_format=column_format,
    )

    with open(latex_path, "w", encoding="utf-8") as f:
        f.write(latex_code)

    print("\n" + "="*30 + " LaTeX Output " + "="*30)
    print(latex_code)
    print("="*74)
    print(f"LaTeX table saved to {latex_path}")

def generate_summary_table(data, fig_dir):
    """
    Summarise each backbone's runs into one table row, then print and export the table.

    Per run, with ``ACC_CUTOFF_THRESHOLD`` as the target accuracy:
      * Best Acc  - peak test accuracy (x100).
      * Final Acc - accuracy (x100) at the first round reaching the threshold,
        or at the last round if the run never reaches it.
      * Rounds    - 1-based round at which the threshold was first reached.
      * Time (s)  - logged ``wall_time`` value at that round (last value if the
        threshold is never reached).
      * GPU Mem   - maximum of the logged GPU-memory value(s).

    Cells show mean +/- std over runs. "Rounds" averages only the runs that
    reached the threshold. It shows "(k/n)" when only k of n runs did, and
    "> Max" when none did. Time and GPU memory show "N/A" when nothing was
    logged. The table is printed, written to ``backbone_comparison.csv`` in
    *fig_dir*, and passed to ``save_latex_table``.

    Args:
        data: ``display_name -> metrics dict`` as returned by ``load_data``.
        fig_dir: Output directory (must already exist).
    """
    rows = []

    for name in sorted(data):
        metrics = data[name]
        if not metrics['acc']:
            continue

        # NOTE(review): time/gpu_mem entries are paired with acc runs by index.
        # That holds for 'time', but 'gpu_mem' only has entries for runs that
        # logged it - confirm all runs of a backbone log GPU memory.
        times = metrics['time']
        gpu_mems = metrics['gpu_mem']
        final_vals, best_vals, time_vals, rounds_vals, gpu_vals = [], [], [], [], []

        for i, run_acc in enumerate(metrics['acc']):
            if len(run_acc) == 0:
                continue
            run_arr = np.asarray(run_acc)

            best_vals.append(np.max(run_arr) * 100)

            hits = np.flatnonzero(run_arr >= ACC_CUTOFF_THRESHOLD)
            reached = hits.size > 0
            idx = int(hits[0]) if reached else len(run_arr) - 1
            rounds_vals.append(idx + 1 if reached else np.nan)
            final_vals.append(run_arr[idx] * 100)

            # Skip empty time series instead of adding a bogus 0 (or crashing).
            if i < len(times) and len(times[i]) > 0:
                t_run = times[i]
                time_vals.append(t_run[min(idx, len(t_run) - 1)] if reached else t_run[-1])

            if i < len(gpu_mems):
                g_mem = gpu_mems[i]
                if isinstance(g_mem, (list, np.ndarray)):
                    gpu_vals.append(np.max(g_mem))
                else:
                    gpu_vals.append(g_mem)

        if not final_vals:
            continue

        mean_final, std_final = np.mean(final_vals), np.std(final_vals)
        mean_best, std_best = np.mean(best_vals), np.std(best_vals)
        str_time = f"{np.mean(time_vals):.0f} ± {np.std(time_vals):.0f}" if time_vals else "N/A"

        reached_rounds = [r for r in rounds_vals if not np.isnan(r)]
        if not reached_rounds:
            str_rounds = "> Max"
        elif len(reached_rounds) < len(rounds_vals):
            # Make partial success visible; otherwise the runs that never
            # reached the threshold silently drop out of the mean.
            str_rounds = f"{np.mean(reached_rounds):.1f} ({len(reached_rounds)}/{len(rounds_vals)})"
        else:
            str_rounds = f"{np.mean(reached_rounds):.1f}"

        # The " MB" unit is assumed from the logger - TODO confirm.
        str_gpu = f"{np.mean(gpu_vals):.0f} MB" if gpu_vals else "N/A"

        rows.append({
            "Backbone Model": name,
            "Final Acc (%)": f"{mean_final:.2f} ± {std_final:.2f}",
            "Best Acc (%)": f"{mean_best:.2f} ± {std_best:.2f}",
            "Rounds": str_rounds,
            "Time (s)": str_time,
            "Max GPU Mem": str_gpu
        })

    if not rows:
        print("No matching models found.")
        return

    df = pd.DataFrame(rows)

    print("\n" + "="*100)
    print(f" SUMMARY TABLE (Threshold: {ACC_CUTOFF_THRESHOLD*100:.0f}%)")
    print("="*100)
    print(df.to_string(index=False))

    csv_path = os.path.join(fig_dir, "backbone_comparison.csv")
    df.to_csv(csv_path, index=False)
    print(f"\nTable saved to {csv_path}")

    save_latex_table(df, fig_dir)

if __name__ == "__main__":
    p = argparse.ArgumentParser(
        description="Compare backbone runs from .npz logs: plots plus CSV/LaTeX summary tables."
    )
    p.add_argument("--log_dir", type=str, default=".", help="Root log directory to scan")
    p.add_argument("--fig_dir", type=str, default=None,
                   help="Directory for output figures and tables "
                        "(default: <log_dir>/analysis_results)")
    args = p.parse_args()

    LOG_DIR = os.path.abspath(args.log_dir)
    # A None sentinel (instead of the old "figs" magic value) means an explicit
    # `--fig_dir figs` is honoured rather than mistaken for "not given".
    if args.fig_dir is None:
        FIG_DIR = os.path.join(LOG_DIR, "analysis_results")
    else:
        FIG_DIR = os.path.abspath(args.fig_dir)
    os.makedirs(FIG_DIR, exist_ok=True)

    print(f"Scanning: {LOG_DIR}")
    data = load_data(LOG_DIR)

    if data:
        print("\n--- Generating Plots ---")
        plot_metric(data, 'acc', 'Test Accuracy', 'Test Accuracy', 'compare_acc.png', FIG_DIR, use_cutoff=True)
        plot_metric(data, 'test_loss', 'Test Loss', 'Test Loss', 'compare_test_loss.png', FIG_DIR, use_cutoff=True)
        plot_metric(data, 'train_loss', 'Training Loss', 'Training Loss', 'compare_train_loss.png', FIG_DIR, use_cutoff=True)
        plot_metric(data, 'gpu_mem', 'GPU Memory (MB)', 'GPU Memory Usage', 'compare_gpu_mem.png', FIG_DIR, use_cutoff=True)
        plot_time_vs_acc(data, FIG_DIR, use_cutoff=True)

        generate_summary_table(data, FIG_DIR)
    else:
        print("No data found.")
