import os
import torch
from collections import defaultdict
import numpy as np

def is_str_int(s):
    """Report whether *s* can be converted by ``int()``.

    Returns False only when ``int(s)`` raises ValueError; any other exception
    (e.g. TypeError for None) propagates to the caller.
    """
    try:
        int(s)
    except ValueError:
        return False
    return True


# LaTeX-ready labels for each position-embedding type. The order of
# ``pos_types`` follows this dict's insertion order.
display_names = dict(
    infini_rope='Infini-RoPE',
    infini_p_rope='Infini-$p$-RoPE',
)
pos_types = [*display_names]

def extract_categories(folder_name):
    """Parse a run folder name into ``(pos_type, context_length, seed)``.

    - ``pos_type``: 'infini_rope' or 'infini_p_rope' when the name starts with
      that prefix (checked in that order), otherwise 'unknown'.
    - ``context_length``: the first underscore-separated token equal to
      '4k', '16k' or '64k', or None if there is none.
    - ``seed``: the first token that starts with 's' and whose remainder is
      accepted by ``is_str_int`` (e.g. 's1'), returned whole, or None.
    """
    pos_type = next(
        (prefix for prefix in ('infini_rope', 'infini_p_rope')
         if folder_name.startswith(prefix)),
        'unknown',
    )

    tokens = folder_name.split('_')
    context_length = next(
        (tok for tok in tokens if tok in ('4k', '16k', '64k')),
        None,
    )
    seed = next(
        (tok for tok in tokens
         if tok[:1] == 's' and len(tok) > 1 and is_str_int(tok[1:])),
        None,
    )
    return pos_type, context_length, seed

def load_metrics_files(logs_dir=None):
    """Load every run's ``metrics.pt`` and group them by run category.

    Only immediate subdirectories of *logs_dir* whose name contains 'infini'
    and does not contain '_med_' are considered. Each folder name is parsed
    with ``extract_categories`` and the folder's ``metrics.pt`` (if present)
    is stored at ``organized_metrics[pos_type][context_length][seed]``.
    A file that fails to load is reported and skipped, so one bad run does not
    abort the scan. A short per-type summary (for types in ``pos_types``) is
    printed at the end.

    Args:
        logs_dir: Directory holding one subdirectory per run. ``None`` means
            the current working directory.

    Returns:
        A nested ``defaultdict``: pos_type -> context_length -> seed -> the
        object returned by ``torch.load``.

    Raises:
        OSError (e.g. FileNotFoundError) if *logs_dir* cannot be listed.
    """
    if logs_dir is None:
        # os.path.join(None, ...) raises TypeError, so resolve the default.
        logs_dir = os.curdir

    organized_metrics = defaultdict(lambda: defaultdict(dict))

    # Sorted so the load order (and hence per-seed dict order used downstream)
    # does not depend on the filesystem's arbitrary listing order.
    log_dirs = sorted(
        name
        for name in os.listdir(logs_dir)
        if 'infini' in name
        and '_med_' not in name
        and os.path.isdir(os.path.join(logs_dir, name))
    )

    for folder in log_dirs:
        metrics_path = os.path.join(logs_dir, folder, 'metrics.pt')
        if not os.path.exists(metrics_path):
            continue

        pos_type, context_length, seed = extract_categories(folder)
        try:
            # map_location='cpu': runs saved from a GPU stay loadable on a
            # CPU-only machine and convert cleanly to numpy later.
            metrics = torch.load(metrics_path, map_location='cpu')
        except Exception as e:  # best-effort: report and keep scanning
            print(f"Error loading {metrics_path}: {str(e)}")
            continue

        # Folders mapping to the same key (e.g. no seed token found -> None)
        # would otherwise silently overwrite each other. .get() avoids
        # creating empty entries in the defaultdicts.
        if seed in organized_metrics.get(pos_type, {}).get(context_length, {}):
            print(f"Warning: {folder} overwrites an earlier run for "
                  f"({pos_type}, {context_length}, {seed})")
        organized_metrics[pos_type][context_length][seed] = metrics
        print(f"Loaded metrics from {folder}")

    print("\nSummary of loaded data:")
    for pos_type in pos_types:
        if pos_type not in organized_metrics:
            continue
        print(f"\n{pos_type}:")
        for context_length, runs in organized_metrics[pos_type].items():
            print(f"  {context_length}: {len(runs)} seeds")

    return organized_metrics

def _final_logged_value(metrics, key):
    """Return the last logged value of *key* in one run's metrics, or None.

    *metrics* is presumably a list of dicts (one per logging step) -- TODO
    confirm against the writer of ``metrics.pt``. Only entries that also
    contain 'step' are considered; anything that is not a list yields None.
    """
    if not isinstance(metrics, list):
        return None
    for entry in reversed(metrics):
        if 'step' in entry and key in entry:
            # float() also unwraps single-element tensors and numpy scalars.
            return float(entry[key])
    return None


def _mean_and_sem(values):
    """Return ``(mean, standard error of the mean)`` of a non-empty sequence.

    The SEM uses the sample standard deviation (ddof=1). With a single value
    it is reported as 0.0 instead of NaN.
    """
    arr = np.asarray(values, dtype=float)
    mean = float(arr.mean())
    if arr.size < 2:
        return mean, 0.0
    return mean, float(arr.std(ddof=1) / np.sqrt(arr.size))


def mk_table(metrics_data, train_length='4k', val_lengths=('4k', '16k', '64k')):
    """Create a LaTeX table of final validation losses for one training length.

    For every type in ``pos_types`` trained at *train_length*, the last logged
    ``val_loss_<length>`` of each seed is averaged and reported as
    ``mean $\\pm$ SEM``. The lowest mean in each column is bold, with ties going
    to the earlier type in ``pos_types``. Missing cells show "N/A". The
    output uses booktabs rules (``\\toprule`` etc.).

    Args:
        metrics_data: Nested mapping pos_type -> context_length -> seed ->
            metrics, as returned by ``load_metrics_files``.
        train_length: Training context length whose runs are tabulated.
        val_lengths: Validation lengths, one table column each.

    Returns:
        A complete ``table`` environment as a string.
    """
    val_lengths = list(val_lengths)

    # table_data[val_length][pos_type] = (mean, sem) over seeds.
    table_data = {val_length: {} for val_length in val_lengths}
    for pos_type in pos_types:
        if pos_type not in metrics_data or train_length not in metrics_data[pos_type]:
            continue
        runs = metrics_data[pos_type][train_length]
        for val_length in val_lengths:
            key = f'val_loss_{val_length}'
            final_losses = [
                loss
                for loss in (_final_logged_value(m, key) for m in runs.values())
                if loss is not None
            ]
            if final_losses:
                table_data[val_length][pos_type] = _mean_and_sem(final_losses)

    # Best (lowest mean) method per column; strict < keeps the first on ties.
    best_per_length = {}
    for val_length in val_lengths:
        best_method, best_loss = None, float('inf')
        for pos_type in pos_types:
            if pos_type in table_data[val_length]:
                mean_loss = table_data[val_length][pos_type][0]
                if mean_loss < best_loss:
                    best_method, best_loss = pos_type, mean_loss
        best_per_length[val_length] = best_method

    lines = [
        "\\begin{table}[h]",
        "\\centering",
        f"\\begin{{tabular}}{{l{'c' * len(val_lengths)}}}",
        "\\toprule",
        "Method & " + " & ".join(f"Val @ {length}" for length in val_lengths) + " \\\\",
        "\\midrule",
    ]
    for pos_type in pos_types:
        row = [display_names[pos_type]]
        for val_length in val_lengths:
            stats = table_data[val_length].get(pos_type)
            if stats is None:
                row.append("N/A")
                continue
            mean_loss, std_err = stats
            mean_str = f"{mean_loss:.3f}"
            if pos_type == best_per_length[val_length]:
                mean_str = f"\\textbf{{{mean_str}}}"
            row.append(f"{mean_str} $\\pm$ {std_err:.3f}")
        lines.append(" & ".join(row) + " \\\\")

    # Close both environments; previously the table stopped at \bottomrule
    # and did not compile.
    lines += ["\\bottomrule", "\\end{tabular}", "\\end{table}"]
    return "\n".join(lines) + "\n"

if __name__ == "__main__":
    # Load every run's metrics and print the final-loss table. The guard keeps
    # importing this module (e.g. to reuse mk_table) free of filesystem I/O.
    # NOTE: the logs path is resolved relative to the current working
    # directory, not to this file's location.
    metrics_data = load_metrics_files(logs_dir='../gpt2/logs')
    print(mk_table(metrics_data))