import os
import torch
from collections import defaultdict
import numpy as np

# Seed count constant; not referenced by any other code in this file.
NSEEDS = 3

# Methods in the order their rows appear in the generated LaTeX table.
legend_order = [
    "rope",
    "p_rope",
    "nope",
    "ntk_aware",
    "nousyarn",
    "logn",
    "logn_and_p_rope",
    "logn_and_ntk_aware",
    "alibi",
    "sfa_and_p_rope",
]

# Internal method key -> label written into the table (LaTeX math allowed).
display_names = {
    "rope": "RoPE",
    "p_rope": "$p$-RoPE",
    "ntk_aware": "RoPE+NTK",
    "logn": "LogN+RoPE",
    "logn_and_p_rope": "LogN+$p$-RoPE",
    "logn_and_ntk_aware": "LogN+NTK",
    "sfa_and_p_rope": "Scale-invariant $p$-RoPE (ours)",
    "nope": "NoPE",
    "alibi": "ALiBi",
    "nousyarn": "YaRN",
}

def extract_categories(folder_name):
    """Parse a run folder name into (pos_type, context_length, seed).

    pos_type is chosen by the first matching name prefix ('unknown' when no
    prefix matches). context_length is the first underscore-separated token
    equal to '4k', '16k' or '64k', and seed is the first two-character token
    of the form 's<digit>'; either is None when no such token exists.
    """
    # Checked top to bottom, so a more specific prefix must precede any
    # shorter prefix it starts with (e.g. 'rope_th500k' before 'rope',
    # 'sfa_and_*' before 'sfa', 'logn_trick_and_*' before 'logn_trick_learnS').
    prefix_table = (
        ('p_rope', 'p_rope'),
        ('rope_th500k', 'rope_th500k'),
        ('rope', 'rope'),
        ('sfa_and_p_rope', 'sfa_and_p_rope'),
        ('sfa_and_rope', 'sfa_and_rope'),
        ('sfa_and_nope', 'sfa_and_nope'),
        ('ntk_aware', 'ntk_aware'),
        ('sfa', 'sfa'),
        ('logn_trick_and_p_rope_learnS', 'logn_and_p_rope'),
        ('logn_trick_and_ntk_aware_learnS', 'logn_and_ntk_aware'),
        ('logn_trick_learnS', 'logn'),
        ('again_nope', 'nope'),
        ('alibi', 'alibi'),
        ('nousyarn', 'nousyarn'),
    )
    pos_type = next(
        (label for prefix, label in prefix_table
         if folder_name.startswith(prefix)),
        'unknown',
    )

    tokens = folder_name.split('_')

    # First token naming one of the known context lengths, if any.
    context_length = next(
        (tok for tok in tokens if tok in ('4k', '16k', '64k')),
        None,
    )

    # First token shaped like 's' followed by exactly one digit, if any.
    seed = next(
        (tok for tok in tokens
         if len(tok) == 2 and tok[0] == 's' and tok[1].isdigit()),
        None,
    )

    return pos_type, context_length, seed

def load_metrics_files(logs_dir='../gpt2/logs'):
    """Load nih_metrics files and organize them by embedding type, context length, and seed.

    Every immediate subdirectory of ``logs_dir`` is scanned for files named
    ``nih_metrics_4096_s*.pt``. The position-embedding type comes from the
    subdirectory name (via ``extract_categories``) and the seed from the file
    name (``'s1'`` for ``nih_metrics_4096_s1.pt``).

    Args:
        logs_dir: Directory whose subdirectories hold the metrics files.

    Returns:
        A nested defaultdict ``pos_type -> '4k' -> seed -> loaded object``,
        where the loaded object is whatever ``torch.load`` returns for the
        file. The context-length key is always ``'4k'``.

    Raises:
        OSError: If ``logs_dir`` or one of its subdirectories cannot be listed.
    """
    organized_metrics = defaultdict(lambda: defaultdict(dict))

    # All of these files are treated as 4k training runs, whatever context
    # length the folder name mentions.
    context_length = '4k'

    # os.listdir order is arbitrary; sorting makes it reproducible which file
    # wins when two folders map to the same (pos_type, seed) slot.
    for folder in sorted(os.listdir(logs_dir)):
        folder_path = os.path.join(logs_dir, folder)
        if not os.path.isdir(folder_path):
            continue

        metrics_files = sorted(
            f for f in os.listdir(folder_path)
            if f.startswith('nih_metrics_4096_s') and f.endswith('.pt')
        )
        if not metrics_files:
            continue

        pos_type, _, _ = extract_categories(folder)

        loaded = 0
        for metrics_file in metrics_files:
            # 'nih_metrics_4096_s1.pt' -> 's1'
            seed = metrics_file.split('_')[-1].split('.')[0]
            metrics_path = os.path.join(folder_path, metrics_file)

            # Best-effort per file: one unreadable file must not discard the
            # other seeds stored in the same folder.
            try:
                # map_location='cpu' lets files saved from a GPU run load on a
                # machine without that device.
                metrics = torch.load(metrics_path, map_location='cpu')
            except Exception as e:
                print(f"Error loading metrics from {metrics_path}: {e}")
                continue

            organized_metrics[pos_type][context_length][seed] = metrics
            loaded += 1

        print(f"Loaded {loaded} metrics files from {folder}")

    return organized_metrics

def _accs_at_step(runs_by_seed, metric_key, step):
    """Collect ``metric_key`` from each seed's log entry recorded at ``step``."""
    accs = []
    for metrics in runs_by_seed.values():
        if not isinstance(metrics, list):
            continue
        # Only the first entry logged at `step` is considered for a seed; if it
        # lacks the metric, that seed contributes nothing.
        entry = next((m for m in metrics if m.get('step') == step), None)
        if entry is not None and metric_key in entry:
            accs.append(entry[metric_key])
    return accs


def create_final_acc_table(metrics_data, step=300, val_lengths=('4k', '16k', '64k')):
    """Create a LaTeX table of validation accuracies at a given training step.

    One row is emitted per entry of ``legend_order`` (labelled through
    ``display_names``) and one column per validation length. Each cell is the
    mean over seeds of ``val_acc_numbers_and_cities_<length>`` taken from the
    ``'4k'`` runs, followed by its standard error; a cell with no data is
    rendered as ``-``.

    Args:
        metrics_data: Mapping ``pos_type -> '4k' -> seed -> metrics``. Only
            seeds whose metrics value is a list are used; its elements are
            looked up with ``.get('step')``, so they are expected to be dicts.
        step: Training step whose logged accuracies are reported.
        val_lengths: Validation context lengths, one table column each.

    Returns:
        The LaTeX source of the table, which is also printed to stdout.
    """
    rows = []
    for pos_type in legend_order:
        cells = [display_names[pos_type]]
        has_runs = pos_type in metrics_data and '4k' in metrics_data[pos_type]
        runs_by_seed = metrics_data[pos_type]['4k'] if has_runs else {}

        for val_length in val_lengths:
            metric_key = f'val_acc_numbers_and_cities_{val_length}'
            accs = _accs_at_step(runs_by_seed, metric_key, step)
            if accs:
                mean_acc = np.mean(accs)
                # NOTE(review): np.std defaults to ddof=0 (population std), so
                # this is smaller than the usual sample-based standard error.
                # Left as-is so previously reported numbers do not change.
                std_error = np.std(accs) / np.sqrt(len(accs))
                cells.append(f"{mean_acc:.3f} $\\pm$ {std_error:.3f}")
            else:
                cells.append("-")

        rows.append(" & ".join(cells) + " \\\\\n")

    # Column spec and header are derived from val_lengths so they cannot drift
    # out of sync with the number of data columns.
    column_spec = "l" + "c" * len(val_lengths)
    header = "Method" + "".join(f" & Val Acc @{v}" for v in val_lengths) + " \\\\\n"

    latex_table = "".join([
        "\\begin{table}[htbp]\n",
        "\\centering\n",
        "\\begin{tabular}{" + column_spec + "}\n",
        "\\hline\n",
        header,
        "\\hline\n",
        *rows,
        "\\hline\n",
        "\\end{tabular}\n",
        f"\\caption{{Final validation accuracies at step {step} for different context lengths. Values shown as mean $\\pm$ standard error.}}\n",
        "\\label{tab:final_acc}\n",
        "\\end{table}\n",
    ])

    print(latex_table)
    return latex_table

# Script body: runs whenever this module is executed or imported. Loads the
# metrics from the default logs directory and prints the LaTeX table.
metrics_data = load_metrics_files()
create_final_acc_table(metrics_data)