import glob
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import argparse
import sys

# --- Configuration: Target Method Aliases ---
# Maps a method directory name (as it appears in a log file's path) to the
# display name used as the method key in plots and tables. Log files whose
# path contains none of these names are ignored by the loader.
METHOD_ALIAS = dict(
    cg='CG',
    # diag='Diagonal',      # disabled
    lbfgs='L-BFGS',
    lowrank='Low-Rank',
    exact='Exact Hessian',
    # neumann='Neumann',    # disabled
)

# --- ORDERING CONFIGURATION ---
# Explicit display order for methods in plots and tables. Methods not listed
# here are placed before all listed ones, sorted alphabetically among
# themselves. "Exact Hessian" is deliberately last.
SORT_ORDER = [
    'L-BFGS',
    'Low-Rank',
    'CG',
    'Exact Hessian',
]


def custom_sort_key(method_name):
    """Return a sort key that orders method display names by SORT_ORDER.

    Methods absent from SORT_ORDER come first, in alphabetical order; listed
    methods follow in the order they appear in SORT_ORDER.

    Args:
        method_name: Display name of a method.

    Returns:
        A tuple ``(group, rank)`` for use as ``sorted(..., key=custom_sort_key)``.
        ``group`` is 0 for unlisted and 1 for listed names, so the string rank
        of an unlisted name is never compared with the integer rank of a
        listed one.
    """
    try:
        return (1, SORT_ORDER.index(method_name))
    except ValueError:
        # Unlisted: before every listed method, alphabetical among its peers
        # (previously all unlisted names tied at -1 and kept input order).
        return (0, method_name)

# --- Configuration: Default Thresholds ---
# Target test accuracy (as a fraction, e.g. 0.88 == 88%) per dataset name.
# The 'default' entry applies to any dataset without its own entry.
DEFAULT_THRESHOLDS = dict(
    cifar10=0.88,
    cifar100=0.90,
    default=0.80,
)


def get_cutoff_threshold(dataset_name):
    """Return the target accuracy for *dataset_name*.

    Falls back to ``DEFAULT_THRESHOLDS['default']`` for unknown datasets.
    """
    if dataset_name in DEFAULT_THRESHOLDS:
        return DEFAULT_THRESHOLDS[dataset_name]
    return DEFAULT_THRESHOLDS['default']

def load_data(log_dir):
    """Recursively load ``.npz`` logs and group them by Hessian approximation method.

    Expected layout: ``log_dir/<method>/<seed>/stats.npz``. The method is the
    first component of the file's path (relative to ``log_dir``) that is a key
    of ``METHOD_ALIAS``; files with no such component are skipped.

    Args:
        log_dir: Root directory to search recursively for ``*.npz`` files.

    Returns:
        ``None`` if no ``.npz`` file is found. Otherwise a dict mapping each
        method's display name (a ``METHOD_ALIAS`` value) to a dict of per-run
        lists under 'acc', 'test_loss', 'train_loss', 'time' and 'gpu_mem'.
        A run is appended to 'acc', 'test_loss' and 'time' together or not at
        all, so index ``i`` refers to the same run in those three lists.
        'train_loss' and 'gpu_mem' are appended only when the file has them,
        so they may be shorter and their indices may not match 'acc'.

    Files are opened with ``allow_pickle=True``; only point this at trusted
    log directories. Unreadable or incomplete files are reported and skipped.
    """
    search_pattern = os.path.join(log_dir, "**", "*.npz")
    # Sorted so the order of runs (and of methods in the result) is reproducible.
    files = sorted(glob.glob(search_pattern, recursive=True))

    if not files:
        print(f"[Warning] No .npz files found in {log_dir}")
        return None

    print(f"Found {len(files)} log files.")

    data = {}
    for f in files:
        try:
            # Parse method name from directory structure
            path_parts = os.path.relpath(f, log_dir).split(os.sep)
            raw_method = next((part for part in path_parts if part in METHOD_ALIAS), None)
            if raw_method is None:
                continue
            display_name = METHOD_ALIAS[raw_method]

            # The context manager closes the archive's file handle.
            with np.load(f, allow_pickle=True) as loaded:
                record = data.setdefault(display_name, {
                    'acc': [], 'test_loss': [], 'train_loss': [], 'time': [], 'gpu_mem': []
                })

                if 'test_acc' not in loaded:
                    continue
                acc = loaded['test_acc']
                if len(acc) == 0:
                    continue

                # Read every required field before appending anything, so a file
                # missing one of them cannot leave 'acc', 'test_loss' and 'time'
                # with different numbers of runs.
                test_loss = loaded['test_loss']
                wall_time = loaded['wall_time']
                record['acc'].append(acc)
                record['test_loss'].append(test_loss)
                record['time'].append(wall_time)

                if 'train_loss' in loaded:
                    record['train_loss'].append(loaded['train_loss'])

                if 'gpu_mem' in loaded:
                    record['gpu_mem'].append(loaded['gpu_mem'])
                elif 'max_gpu_mem' in loaded:
                    record['gpu_mem'].append(loaded['max_gpu_mem'])

        except Exception as e:
            # Best-effort loading: report the bad file and continue with the rest.
            print(f"Error loading {f}: {e}")

    return data

def get_truncated_data(metrics, metric_key, use_cutoff, threshold, margin=5):
    """Return the per-run values of *metric_key*, optionally cut after the threshold.

    Each run ``metrics[metric_key][i]`` is truncated using the accuracy run
    ``metrics['acc'][i]`` as reference. If ``first_hit`` is the index of the
    first round with accuracy >= *threshold*, the run becomes
    ``run[:first_hit + margin]``. Runs that never reach the threshold, or that
    have no matching accuracy run, are returned unchanged.

    Args:
        metrics: Dict of per-run lists for one method (e.g. keys 'acc', 'time').
        metric_key: Key of the per-run list to return.
        use_cutoff: If False, the list is returned untouched.
        threshold: Accuracy at which to cut each run.
        margin: Number of rounds kept starting at the first round that
            reaches the threshold (default 5, i.e. that round plus 4 more).

    Returns:
        A list of runs. This is the original list object when nothing is
        truncated: for ``use_cutoff=False`` or scalar-per-run metrics. It is
        ``[]`` when the key is missing or empty.
    """
    raw_data_list = metrics.get(metric_key, [])
    if not raw_data_list:
        return []

    # Scalar-per-run metrics (e.g. a peak-memory value) have no round axis to
    # cut. np.ndim also catches 0-d arrays, which is how np.load returns
    # scalars stored in an .npz archive (they are not np.number instances).
    if np.ndim(raw_data_list[0]) == 0:
        return raw_data_list

    if not use_cutoff:
        return raw_data_list

    ref_acc_list = metrics.get('acc', [])
    processed_data = []
    for idx, run in enumerate(raw_data_list):
        if idx >= len(ref_acc_list) or np.ndim(run) == 0:
            # No reference accuracy for this run, or nothing sliceable.
            processed_data.append(run)
            continue

        cutoff_indices = np.where(np.asarray(ref_acc_list[idx]) >= threshold)[0]
        if len(cutoff_indices) > 0:
            cutoff_len = min(len(run), cutoff_indices[0] + margin)
            processed_data.append(run[:cutoff_len])
        else:
            processed_data.append(run)

    return processed_data

def plot_metric(data, metric_key, y_label, title, filename, fig_dir, threshold,
                use_cutoff=False, x_label="Communication Rounds"):
    """Plot the per-round mean (± std band) of one metric for each method.

    Methods are drawn in ``custom_sort_key`` order. Each method's runs are
    optionally truncated with ``get_truncated_data``, cut to the shortest run,
    and averaged. Runs that are not 1-D series (e.g. one scalar per run) are
    ignored. If no method has plottable data, no file is written.

    Args:
        data: Mapping of method display name -> dict of per-run lists.
        metric_key: Per-run list to plot (e.g. 'acc', 'test_loss').
        y_label: Y-axis label.
        title: Plot title; " (Truncated at N%)" is appended when ``use_cutoff``.
        filename: Output file name, written inside ``fig_dir``.
        fig_dir: Existing output directory.
        threshold: Accuracy threshold for truncation; for 'acc' it is also
            drawn as a dashed reference line.
        use_cutoff: Truncate each run shortly after it reaches ``threshold``.
        x_label: X-axis label.
    """
    # Increased figure size slightly to accommodate larger fonts
    fig = plt.figure(figsize=(12, 9))
    try:
        colors = plt.cm.Dark2(np.linspace(0, 1, 8))
        markers = ['o', 's', '^', 'D', 'v', '<', '>']
        has_data = False

        # Sort keys using custom logic (Exact Hessian last)
        for i, name in enumerate(sorted(data.keys(), key=custom_sort_key)):
            truncated_list = get_truncated_data(data[name], metric_key, use_cutoff, threshold)
            # Only 1-D per-round series can be averaged across runs.
            runs = [np.asarray(run) for run in truncated_list if np.ndim(run) == 1]
            if not runs:
                continue
            min_len = min(len(run) for run in runs)
            if min_len == 0:
                continue

            # Align to the shortest run so runs stack into (n_runs, min_len).
            arr = np.stack([run[:min_len] for run in runs])
            mean = arr.mean(axis=0)
            std = arr.std(axis=0)
            x = np.arange(1, min_len + 1)
            color = colors[i % len(colors)]
            marker = markers[i % len(markers)]

            plt.plot(x, mean, label=name, linewidth=3, color=color, marker=marker,
                     markevery=max(1, len(x) // 10), markersize=10)
            plt.fill_between(x, mean - std, mean + std, color=color, alpha=0.1)
            has_data = True

        if not has_data:
            return

        final_title = title
        if use_cutoff:
            final_title += f" (Truncated at {threshold*100:.0f}%)"

        if metric_key == 'acc':
            # Drawn before plt.legend() so its label is part of the legend.
            plt.axhline(y=threshold, color='gray', linestyle='--',
                        alpha=0.5, label=f'{threshold*100:.0f}% Threshold')

        # Large fonts for readability in papers/slides.
        plt.xlabel(x_label, fontsize=24)
        plt.ylabel(y_label, fontsize=24)
        plt.title(final_title, fontsize=28)
        plt.legend(fontsize=20, loc='best')
        plt.tick_params(axis='both', which='major', labelsize=20)
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        plt.savefig(os.path.join(fig_dir, filename), dpi=300)
    finally:
        # Always release the figure, including on early return or save errors.
        plt.close(fig)
    print(f"Generated {filename}")

def plot_time_vs_acc(data, fig_dir, threshold, dataset_name, use_cutoff=False):
    """Plot mean test accuracy against mean wall-clock time for each method.

    For each method (in ``custom_sort_key`` order), the 'acc' and 'time' runs
    are optionally truncated with ``get_truncated_data`` and cut to a common
    length. Each is then averaged across runs independently. Non-1-D runs are
    ignored. Writes ``fig_dir/<dataset_name>_solvers_compare_time_vs_acc.png``
    unless no method has data.

    Args:
        data: Mapping of method display name -> dict of per-run lists.
        fig_dir: Existing output directory.
        threshold: Accuracy threshold for truncation and the reference line.
        dataset_name: Used as the file-name prefix and (upper-cased) title prefix.
        use_cutoff: Truncate each run shortly after it reaches ``threshold``.
    """
    filename = f'{dataset_name}_solvers_compare_time_vs_acc.png'
    # Increased figure size slightly to accommodate larger fonts
    fig = plt.figure(figsize=(12, 9))
    try:
        colors = plt.cm.Dark2(np.linspace(0, 1, 8))
        markers = ['o', 's', '^', 'D', 'v', '<', '>']
        has_data = False

        # Sort keys using custom logic (Exact Hessian last)
        for i, name in enumerate(sorted(data.keys(), key=custom_sort_key)):
            metrics = data[name]
            acc_runs = [np.asarray(r) for r in get_truncated_data(metrics, 'acc', use_cutoff, threshold)
                        if np.ndim(r) == 1]
            time_runs = [np.asarray(r) for r in get_truncated_data(metrics, 'time', use_cutoff, threshold)
                         if np.ndim(r) == 1]
            if not acc_runs or not time_runs:
                continue

            min_len = min(min(len(r) for r in acc_runs), min(len(r) for r in time_runs))
            if min_len == 0:
                continue

            mean_acc = np.mean([r[:min_len] for r in acc_runs], axis=0)
            mean_time = np.mean([r[:min_len] for r in time_runs], axis=0)
            color = colors[i % len(colors)]
            marker = markers[i % len(markers)]

            plt.plot(mean_time, mean_acc, label=name, linewidth=3, color=color,
                     marker=marker, markevery=max(1, len(mean_time) // 10), markersize=10)
            has_data = True

        if not has_data:
            return

        # Title format kept consistent with plot_metric's titles.
        final_title = f"{dataset_name.upper()} Test Accuracy vs. Wall-clock Time"
        if use_cutoff:
            final_title += f" (Truncated at {threshold*100:.0f}%)"

        # Drawn before plt.legend() so its label is part of the legend.
        plt.axhline(y=threshold, color='gray', linestyle='--', alpha=0.5,
                    label=f'{threshold*100:.0f}% Threshold')

        plt.xlabel("Wall-clock Time (s)", fontsize=24)
        plt.ylabel("Test Accuracy", fontsize=24)
        plt.title(final_title, fontsize=28)
        plt.legend(fontsize=20, loc='best')
        plt.tick_params(axis='both', which='major', labelsize=20)
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        plt.savefig(os.path.join(fig_dir, filename), dpi=300)
    finally:
        # Previously the no-data early return leaked the open figure.
        plt.close(fig)
    print(f"Generated {filename}")

def _latex_escape(text):
    r"""Return ``str(text)`` with LaTeX special characters escaped.

    Characters are mapped one at a time, so the output of one replacement is
    never escaped again. '±' is emitted as math-mode ``$\pm$``.
    """
    replacements = {
        '\\': r'\textbackslash{}',
        '&': r'\&',
        '%': r'\%',
        '$': r'\$',
        '#': r'\#',
        '_': r'\_',
        '{': r'\{',
        '}': r'\}',
        '~': r'\textasciitilde{}',
        '^': r'\textasciicircum{}',
        '<': r'\textless{}',
        '>': r'\textgreater{}',
        '±': r'$\pm$',
    }
    return ''.join(replacements.get(ch, ch) for ch in str(text))


def save_latex_table(df, fig_dir, threshold):
    """Write *df* as a LaTeX table to ``fig_dir/hessian_method_comparison.tex``.

    The bold column names form the single header row (the old code inserted a
    second, hard-coded header above the one pandas emits). Every header and
    cell is LaTeX-escaped; e.g. the '%' in 'Final Acc (%)' would otherwise
    start a LaTeX comment. The first column is left-aligned, the rest centred.

    Args:
        df: Summary table; values are written via ``str()``.
        fig_dir: Existing output directory.
        threshold: Accuracy threshold (fraction), shown as a percentage in the caption.
    """
    latex_path = os.path.join(fig_dir, "hessian_method_comparison.tex")

    # Escape here and pass escape=False: pandas' default for `escape` has
    # changed across versions, and the bold header markup must stay raw.
    table = df.astype(str).apply(lambda col: col.map(_latex_escape))
    table.columns = [r"\textbf{" + _latex_escape(c) + "}" for c in df.columns]
    column_format = "l" + "c" * (len(df.columns) - 1)

    latex_code = table.to_latex(
        index=False,
        escape=False,
        caption=f"Method Comparison (Threshold: {threshold*100:.0f}\\%)",
        label="tab:hessian",
        column_format=column_format,
    )

    with open(latex_path, "w", encoding="utf-8") as f:
        f.write(latex_code)
    print(f"LaTeX table saved to {latex_path}")

def _summarize_method(name, metrics, threshold):
    """Build one summary-table row for a method, or return None if no run has data.

    Per run (runs with an empty accuracy series are skipped):
      * Best Acc: maximum test accuracy.
      * Final Acc: accuracy at the first round with acc >= threshold, or at the
        last round if the threshold is never reached.
      * Time: wall time at that same round. The last entry is used if the time
        series is shorter or the threshold is never reached; empty time
        series are skipped.
      * GPU: maximum of the stored value(s).
    Rounds is averaged only over runs that reached the threshold. When some
    runs did not, "(reached/total)" is appended; "> Max" means none did.
    """
    time_runs = metrics.get('time', [])
    gpu_runs = metrics.get('gpu_mem', [])
    final_vals, best_vals, time_vals, rounds_vals, gpu_vals = [], [], [], [], []

    # Index i pairs an accuracy run with the same-index time / gpu entries.
    for i, run_acc in enumerate(metrics['acc']):
        run_arr = np.asarray(run_acc)
        if run_arr.size == 0:
            continue
        best_vals.append(np.max(run_arr) * 100)

        hits = np.where(run_arr >= threshold)[0]
        reached = len(hits) > 0
        idx = hits[0] if reached else len(run_arr) - 1
        rounds_vals.append(idx + 1 if reached else np.nan)
        final_vals.append(run_arr[idx] * 100)

        if i < len(time_runs):
            t_run = time_runs[i]
            if len(t_run) > 0:
                t_idx = idx if reached and idx < len(t_run) else -1
                time_vals.append(t_run[t_idx])

        if i < len(gpu_runs):
            g_mem = gpu_runs[i]
            gpu_vals.append(np.max(g_mem) if isinstance(g_mem, (list, np.ndarray)) else g_mem)

    if not final_vals:
        return None

    mean_final, std_final = np.mean(final_vals), np.std(final_vals)
    mean_best, std_best = np.mean(best_vals), np.std(best_vals)
    mean_time, std_time = (np.mean(time_vals), np.std(time_vals)) if time_vals else (0, 0)

    valid_rounds = [r for r in rounds_vals if not np.isnan(r)]
    if not valid_rounds:
        str_rounds = "> Max"
    elif len(valid_rounds) < len(rounds_vals):
        # Make it visible that the mean excludes runs that never got there.
        str_rounds = f"{np.mean(valid_rounds):.1f} ({len(valid_rounds)}/{len(rounds_vals)})"
    else:
        str_rounds = f"{np.mean(valid_rounds):.1f}"

    str_gpu = f"{np.mean(gpu_vals):.0f} MB" if gpu_vals else "N/A"

    return {
        "Method": name,
        "Final Acc (%)": f"{mean_final:.2f} ± {std_final:.2f}",
        "Best Acc (%)": f"{mean_best:.2f} ± {std_best:.2f}",
        "Rounds": str_rounds,
        "Time (s)": f"{mean_time:.0f} ± {std_time:.0f}",
        "Max GPU Mem": str_gpu,
    }


def generate_summary_table(data, fig_dir, threshold):
    """Print a per-method summary of *data* and save it as CSV and LaTeX.

    Methods are ordered by ``custom_sort_key``; methods without accuracy runs
    are omitted. Writes ``fig_dir/hessian_comparison.csv`` and, via
    ``save_latex_table``, ``fig_dir/hessian_method_comparison.tex``.

    Args:
        data: Mapping of method display name -> dict of per-run lists.
        fig_dir: Existing output directory.
        threshold: Target accuracy (fraction) used for Rounds/Final/Time.

    Returns:
        The summary ``pandas.DataFrame``, or ``None`` if no method had data.
    """
    rows = []
    # Sort keys using custom logic (Exact Hessian last)
    for name in sorted(data.keys(), key=custom_sort_key):
        metrics = data[name]
        if not metrics['acc']:
            continue
        row = _summarize_method(name, metrics, threshold)
        if row is not None:
            rows.append(row)

    if not rows:
        print("No matching methods found.")
        return None

    df = pd.DataFrame(rows)
    print("\n" + "="*100)
    print(f" SUMMARY TABLE (Threshold: {threshold*100:.0f}%)")
    print("="*100)
    print(df.to_string(index=False))

    csv_path = os.path.join(fig_dir, "hessian_comparison.csv")
    df.to_csv(csv_path, index=False)
    save_latex_table(df, fig_dir, threshold)
    return df

if __name__ == "__main__":
    p = argparse.ArgumentParser(description="Analyze Hessian Approximation Experiments")
    p.add_argument("--dataset", type=str, required=True, choices=['cifar10', 'cifar100'],
                   help="Name of the dataset (e.g., cifar10, cifar100)")
    p.add_argument("--base_dir", type=str, default="./",
                   help="Base directory containing dataset folders")
    p.add_argument("--output_root", type=str, default="analysis_results",
                   help="Root directory to save analysis results")

    args = p.parse_args()

    # Input:  <base_dir>/<dataset>/logs  (e.g. --base_dir solvers -> solvers/cifar100/logs)
    LOG_DIR = os.path.join(args.base_dir, args.dataset, "logs")
    # Output: <output_root>/<dataset>    (e.g. analysis_results/cifar100)
    FIG_DIR = os.path.join(args.output_root, args.dataset)

    # Per-dataset accuracy target (fraction)
    THRESHOLD = get_cutoff_threshold(args.dataset)

    print("--- Configuration ---")
    print(f"Dataset:     {args.dataset}")
    print(f"Log Dir:     {LOG_DIR}")
    print(f"Output Dir:  {FIG_DIR}")
    print(f"Threshold:   {THRESHOLD*100:.0f}%")
    print("---------------------")

    if not os.path.isdir(LOG_DIR):
        print(f"Error: Log directory does not exist: {LOG_DIR}", file=sys.stderr)
        sys.exit(1)

    os.makedirs(FIG_DIR, exist_ok=True)

    # Run Analysis
    data = load_data(LOG_DIR)

    if data:
        print("\n--- Generating Plots ---")
        ds = args.dataset
        title_prefix = ds.upper()

        # (metric key, y-axis label, title); files are <ds>_solvers_compare_<key>.png.
        # Every title carries the dataset prefix (GPU memory previously lacked it).
        plot_specs = [
            ('acc', 'Test Accuracy', f'{title_prefix} Test Accuracy'),
            ('test_loss', 'Test Loss', f'{title_prefix} Test Loss'),
            ('train_loss', 'Training Loss', f'{title_prefix} Training Loss'),
            ('gpu_mem', 'GPU Memory (MB)', f'{title_prefix} GPU Memory Usage'),
        ]
        for metric_key, y_label, title in plot_specs:
            plot_metric(data, metric_key, y_label, title,
                        f'{ds}_solvers_compare_{metric_key}.png',
                        FIG_DIR, THRESHOLD, use_cutoff=True)

        plot_time_vs_acc(data, FIG_DIR, THRESHOLD, ds, use_cutoff=True)

        generate_summary_table(data, FIG_DIR, THRESHOLD)
    else:
        print("No data loaded. Exiting.")
