import os
from warnings import warn
import torch
from collections import defaultdict
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

def is_str_int(s):
    """Return True when ``int(s)`` succeeds and False when it raises ValueError.

    Any other exception raised by ``int`` propagates to the caller.
    """
    try:
        int(s)
    except ValueError:
        return False
    else:
        return True

# Number of seeds each curve is expected to average, per training context
# length. Only used to print a warning when the actual count differs.
NSEEDS = {'4k': 3, '16k': 1, '64k': 1}

# Draw a horizontal reference line at loss 3.28 on every panel.
DASHED_LINE_3p28 = True
# When true, the output file names get a '_dashed' suffix.
DASHED = False

# Font sizes shared by all figures.
title_font_size = 7
axis_label_font_size = 6
ticks_font_size = 6

# Keyword arguments forwarded to legend() / axhline().
legend_settings = dict(fontsize=5, frameon=True, handlelength=0.7)
dashed_line_settings = dict(color='gray', linestyle='--', linewidth=0.2, alpha=0.7)

# Methods in the order their curves are drawn.
pos_types = [
    'sfa_and_p_rope',
    'ntk_aware',
    'rope',
    'p_rope',
    'logn_and_p_rope',
    'logn_and_ntk_aware',
    'logn',
    'nope',
    'nousyarn',
    'alibi',
]
# The same methods in the order they are listed in the shared legend.
legend_order = [
    'rope',
    'p_rope',
    'nope',
    'ntk_aware',
    'nousyarn',
    'logn',
    'logn_and_p_rope',
    'logn_and_ntk_aware',
    'alibi',
    'sfa_and_p_rope',
]
assert len(legend_order) == len(pos_types)

# Human-readable label for every method key.
display_names = dict(
    rope='RoPE',
    p_rope='$p$-RoPE',
    ntk_aware='RoPE+NTK',
    logn='LogN+RoPE',
    nousyarn='YaRN',
    logn_and_p_rope='LogN+$p$-RoPE',
    logn_and_ntk_aware='LogN+NTK',
    sfa_and_rope='Scale-invariant RoPE',
    sfa_and_nope='Scale-invariant NoPE',
    sfa_and_p_rope='Scale-invariant $p$-RoPE (ours)',
    nope='NoPE',
    alibi='ALiBi',
)

# Eleven-colour palette; plot colours below are picked from it by index.
base_colors = [
    '#a50026',
    '#d73027',
    '#f46d43',
    '#fdae61',
    '#ffe54a',
    '#ffffbf',
    '#e0f3f8',
    '#abd9e9',
    '#74add1',
    '#4575b4',
    '#313695',
]
assert len(base_colors) == 11

# Every curve is a solid line of the same nominal width; only colour and
# alpha differ between methods.
linewidth = 0.6
_line_colors = {
    'sfa_and_p_rope': '#000',
    'sfa_and_rope': base_colors[0],
    'sfa_and_nope': base_colors[3],
    'rope': base_colors[0],
    'p_rope': base_colors[2],
    'nope': base_colors[3],
    'ntk_aware': base_colors[4],
    'alibi': '#5D3A9B',
    'nousyarn': '#9ACD32',
    'logn': base_colors[-3],
    'logn_and_p_rope': base_colors[-2],
    'logn_and_ntk_aware': base_colors[-1],
}
plot_settings = {
    key: {
        'color': color,
        'linestyle': '-',
        'linewidth': linewidth,
        # Only the highlighted method is fully opaque.
        'alpha': 1.0 if key == 'sfa_and_p_rope' else 0.8,
    }
    for key, color in _line_colors.items()
}

# Hack: each method drawn later (per pos_types order) gets a line 2% thinner
# than the one before it.
mult = 1.0
for name in pos_types:
    plot_settings[name]['linewidth'] = plot_settings[name]['linewidth'] * mult
    mult = mult * 0.98

def extract_categories(folder_name):
    """Parse a run folder name into ``(pos_type, context_length, seed)``.

    * ``pos_type`` is chosen by the first matching name prefix in the table
      below, or ``'unknown'`` when none matches.
    * ``context_length`` is the first underscore-separated token equal to
      ``'4k'``, ``'16k'`` or ``'64k'``, else ``None``.
    * ``seed`` is the first token of the form ``s<int>`` (returned with its
      leading ``s``), else ``None``.
    """
    # Order matters: a prefix that is itself a prefix of another entry
    # (e.g. 'sfa' vs 'sfa_and_rope') must come after the longer one.
    prefix_table = (
        ('_med_', 'unknown'),
        ('p_rope', 'p_rope'),
        ('rope', 'rope'),
        ('sfa_and_p_rope', 'sfa_and_p_rope'),
        ('sfa_and_rope', 'sfa_and_rope'),
        ('sfa_and_nope', 'sfa_and_nope'),
        ('sfa', 'sfa'),
        ('logn_trick_and_p_rope_learnS', 'logn_and_p_rope'),
        ('logn_trick_learnS', 'logn'),
        ('logn_trick_and_ntk_aware_learnS', 'logn_and_ntk_aware'),
        ('ntk_aware', 'ntk_aware'),
        ('again_nope', 'nope'),
        ('nousyarn', 'nousyarn'),
        ('alibi', 'alibi'),
    )
    pos_type = next(
        (label for prefix, label in prefix_table if folder_name.startswith(prefix)),
        'unknown',
    )

    tokens = folder_name.split('_')

    context_length = next(
        (token for token in tokens if token in ['4k', '16k', '64k']),
        None,
    )

    seed = next(
        (
            token
            for token in tokens
            if token.startswith('s') and len(token) >= 2 and is_str_int(token[1:])
        ),
        None,
    )

    return pos_type, context_length, seed

def load_metrics_files(logs_dir=None):
    """Load every run's ``metrics.pt`` and group the results by method, context length and seed.

    Each immediate sub-directory of ``logs_dir`` whose name does not contain
    ``'_med_'`` is treated as one run. Runs without a ``metrics.pt`` file are
    skipped silently; runs whose file cannot be loaded are reported on stdout
    and skipped. The folder name is parsed with ``extract_categories``.

    Args:
        logs_dir: Directory holding one sub-folder per run. ``None`` means
            the current working directory.

    Returns:
        A nested ``defaultdict`` such that
        ``result[pos_type][context_length][seed]`` is the object stored in
        that run's ``metrics.pt``. ``context_length`` and ``seed`` are
        ``None`` when the folder name does not contain them.
    """
    if logs_dir is None:
        # os.listdir(None) already means "current directory"; make the
        # os.path.join calls below agree instead of raising TypeError.
        logs_dir = os.curdir

    organized_metrics = defaultdict(lambda: defaultdict(dict))

    # Sorted so that the load order (and therefore which run wins when two
    # folders map to the same slot) does not depend on the file system.
    run_folders = sorted(
        f for f in os.listdir(logs_dir)
        if '_med_' not in f and os.path.isdir(os.path.join(logs_dir, f))
    )

    for folder in run_folders:
        metrics_path = os.path.join(logs_dir, folder, 'metrics.pt')
        if not os.path.exists(metrics_path):
            continue

        pos_type, context_length, seed = extract_categories(folder)

        try:
            # NOTE: torch.load unpickles the file, so only point this at
            # trusted log directories. map_location='cpu' lets logs that
            # were saved from a GPU run load on a machine without one.
            metrics = torch.load(metrics_path, map_location='cpu')
        except Exception as e:
            # Best effort: one unreadable run must not abort the whole scan.
            print(f"Error loading {metrics_path}: {str(e)}")
            continue

        per_seed = organized_metrics[pos_type][context_length]
        if seed in per_seed:
            print(f"WARN: {folder} replaces previously loaded metrics for "
                  f"{pos_type}/{context_length}/{seed}")
        per_seed[seed] = metrics

        print(f"Loaded metrics from {folder}")

    # Summarise everything that was loaded, including method keys that are
    # not part of the default plotting list.
    print("\nSummary of loaded data:")
    for pos_type in sorted(organized_metrics):
        print(f"\n{pos_type}:")
        for context_length, per_seed in organized_metrics[pos_type].items():
            print(f"  {context_length}: {len(per_seed)} seeds")

    return organized_metrics

def _mean_val_loss_curve(metrics_data, pos_type, train_length, val_length, seeds):
    """Average one method's validation-loss curve over the requested seeds.

    Args:
        metrics_data: Nested mapping ``[pos_type][train_length]['s<seed>']``
            whose leaves are lists of per-evaluation dicts.
        pos_type: Method key; must be present in ``metrics_data``.
        train_length: Training context length key; must be present under
            ``pos_type``.
        val_length: Validation length; selects the ``val_loss_<val_length>``
            entry of every dict.
        seeds: Seed numbers to average; each must exist as ``'s<seed>'``.

    Returns:
        ``(steps, mean_loss)`` truncated to the shortest run, with the steps
        taken from the first run, or ``None`` when no run has an entry that
        contains both ``'step'`` and the validation key.

    Raises:
        KeyError: If one of the requested seeds was not loaded.
    """
    runs = metrics_data[pos_type][train_length]
    val_key = f'val_loss_{val_length}'
    all_steps = []
    all_losses = []

    for seed in seeds:
        try:
            metrics = runs[f"s{seed}"]
        except KeyError as e:
            # A missing seed is treated as a hard error; print what is
            # available so the gap is easy to diagnose, then re-raise.
            print(f"Error loading metrics for {pos_type} at train={train_length}, seed={seed}: {str(e)}")
            print(f"metrics_data[pos_type][train_length].keys(): {runs.keys()}")
            raise

        if not isinstance(metrics, list):
            continue
        entries = [m for m in metrics if 'step' in m and val_key in m]
        if entries:
            all_steps.append(np.array([m['step'] for m in entries]))
            all_losses.append(np.array([m[val_key] for m in entries]))

    if not all_losses:
        return None

    # Runs may have different lengths; keep only the common prefix.
    min_length = min(len(steps) for steps in all_steps)
    aligned_losses = [loss[:min_length] for loss in all_losses]
    aligned_steps = all_steps[0][:min_length]

    nseeds = len(aligned_losses)
    if nseeds != NSEEDS[train_length]:
        print(f"WARN: number of seeds is not {NSEEDS[train_length]} (got {nseeds})for tr_len:{train_length}, val_len:{val_length}, pos_type:{pos_type}")

    return aligned_steps, np.mean(aligned_losses, axis=0)


def plot_validation_losses(metrics_data, output_dir='./plots', train_lengths=('16k',)):
    """Plot seed-averaged validation loss against step, one figure per training length.

    Every figure has three panels, one per validation length ('4k', '16k',
    '64k'), and is saved as ``val_loss_train_<train_length>.pdf`` (with a
    ``_dashed`` suffix when ``DASHED`` is set) inside ``output_dir``.

    Args:
        metrics_data: Nested mapping produced by ``load_metrics_files``.
        output_dir: Directory the PDFs are written to; created if missing.
        train_lengths: Training context lengths to plot. Each must be one of
            '4k', '16k' or '64k'. A single string is accepted as well.

    Raises:
        ValueError: If a training length has no y-axis configuration. This is
            checked before any figure or directory is created.
        KeyError: If an expected seed is missing for a method that has data
            at the requested training length.
    """
    # Per-training-length y-axis range and ticks (None keeps default ticks).
    y_axis_by_train_length = {
        '4k': ((3.15, 5.75), [3.5, 4.0, 4.5, 5.0, 5.5]),
        '16k': ((3.15, 4.65), [3.5, 4.0, 4.5]),
        '64k': ((3.15, 4.00), None),
    }

    if isinstance(train_lengths, str):
        # Guard against iterating over the characters of a single key.
        train_lengths = [train_lengths]
    for train_length in train_lengths:
        if train_length not in y_axis_by_train_length:
            raise ValueError(f"Unexpected train_length: {train_length}")

    os.makedirs(output_dir, exist_ok=True)

    val_lengths = ['4k', '16k', '64k']

    for train_length in train_lengths:
        y_limits, y_ticks = y_axis_by_train_length[train_length]
        # Fixed seed selection: three seeds at 4k, one elsewhere.
        seeds = [5, 6, 8] if train_length == '4k' else [5]

        fig, axes = plt.subplots(1, 3, figsize=(5.5, 1.5))

        for j, val_length in enumerate(val_lengths):
            ax = axes[j]
            ax.set_title(f'Train @{train_length} / Val @{val_length}', fontsize=title_font_size)

            for pos_type in pos_types:
                if pos_type not in metrics_data or train_length not in metrics_data[pos_type]:
                    continue
                curve = _mean_val_loss_curve(metrics_data, pos_type, train_length, val_length, seeds)
                if curve is None:
                    print(f"WARN: No valid loss data for {pos_type} at train={train_length}, val={val_length}")
                    continue
                steps, mean_loss = curve
                ax.plot(steps, mean_loss, **plot_settings[pos_type])

            if DASHED_LINE_3p28:
                ax.axhline(y=3.28, **dashed_line_settings)

            if j == 0:
                ax.set_ylabel('Val Loss', fontsize=axis_label_font_size)
            else:
                # Panels share the y-range, so only the left one is labelled.
                ax.set_yticklabels([])
                ax.tick_params(axis='y', which='both', left=False)

            ax.set_xlabel('Step', fontsize=axis_label_font_size)
            ax.set_xlim(200, 4578)
            ax.set_xticks([1000, 2000, 3000, 4000])
            ax.tick_params(axis='both', which='major', labelsize=ticks_font_size)

            ax.set_ylim(*y_limits)
            if y_ticks is not None:
                ax.set_yticks(y_ticks)

        # Shared single-row legend anchored above the panels, with a uniform
        # handle line width of 1 regardless of the plotted widths.
        legend_elements = [
            Line2D(
                [0], [0],
                color=plot_settings[pos_type]['color'],
                linestyle=plot_settings[pos_type]['linestyle'],
                linewidth=1,
                alpha=plot_settings[pos_type]['alpha'],
                label=display_names[pos_type],
            )
            for pos_type in legend_order
        ]
        local_legend_settings = legend_settings.copy()
        local_legend_settings.update({'frameon': True, 'handlelength': 0.6, 'handletextpad': 0.3, 'columnspacing': 0.9, 'fontsize': 6})
        fig.legend(handles=legend_elements, loc='lower center', bbox_to_anchor=(0.5, 1.08),
                   ncol=len(legend_order), **local_legend_settings)

        fig.tight_layout()
        # Tighten the gap between panels after tight_layout has run.
        fig.subplots_adjust(wspace=0.03, top=0.98)

        fname = f'val_loss_train_{train_length}_dashed' if DASHED else f'val_loss_train_{train_length}'
        output_file = os.path.join(output_dir, f'{fname}.pdf')
        fig.savefig(output_file, format='pdf', bbox_inches='tight')
        plt.close(fig)
        print(f"Saved plot to {output_file}")

def plot_sfp_comparison(metrics_data, output_dir='./plots'):
    """Plot the end of training for runs validated at their own training length.

    Draws three panels (train/val = 4k/4k, 16k/16k, 64k/64k) restricted to
    steps 3578-4578 and losses 3.21-3.42, with one seed-averaged curve per
    method in ``pos_types``. The figure is saved as
    ``val_loss_sfp_only_better.pdf`` (``_dashed`` suffix when ``DASHED`` is
    set) inside ``output_dir``, which is created if missing.
    """

    def averaged_curve(method, trained_at, evaluated_at):
        """Return ``(steps, mean_loss)`` over all loaded seeds, or None without data."""
        loss_key = f'val_loss_{evaluated_at}'
        step_arrays = []
        loss_arrays = []
        for history in metrics_data[method][trained_at].values():
            if not isinstance(history, list):
                continue
            usable = [entry for entry in history if 'step' in entry and loss_key in entry]
            if usable:
                step_arrays.append(np.array([entry['step'] for entry in usable]))
                loss_arrays.append(np.array([entry[loss_key] for entry in usable]))

        if not loss_arrays:
            return None

        # Truncate every run to the shortest one before averaging.
        shortest = min(len(steps) for steps in step_arrays)
        trimmed = [losses[:shortest] for losses in loss_arrays]

        nseeds = len(trimmed)
        if nseeds != NSEEDS[trained_at]:
            print(f"WARN: number of seeds is not {NSEEDS[trained_at]} (got {nseeds})for tr_len:{trained_at}, val_len:{evaluated_at}, pos_type:{method}")

        return step_arrays[0][:shortest], np.mean(trimmed, axis=0)

    os.makedirs(output_dir, exist_ok=True)

    # (training length, validation length) shown in each column.
    configs = [('4k', '4k'), ('16k', '16k'), ('64k', '64k')]

    fig, axes = plt.subplots(1, 3, figsize=(5.5, 1.5))

    for column, (train_length, val_length) in enumerate(configs):
        ax = axes[column]
        ax.set_title(f'Train @{train_length} / Val @{val_length}', fontsize=title_font_size)

        for pos_type in pos_types:
            if pos_type not in metrics_data or train_length not in metrics_data[pos_type]:
                continue
            curve = averaged_curve(pos_type, train_length, val_length)
            if curve is not None:
                ax.plot(curve[0], curve[1], **plot_settings[pos_type])

        if DASHED_LINE_3p28:
            ax.axhline(y=3.28, **dashed_line_settings)

        if column == 0:
            ax.set_ylabel('Val Loss', fontsize=axis_label_font_size)
        else:
            # Only the left panel carries y tick labels and tick marks.
            ax.set_yticklabels([])
            ax.tick_params(axis='y', which='both', left=False)

        ax.set_xlabel('Step', fontsize=axis_label_font_size)
        ax.set_xlim(3578, 4578)

        ax.set_xticks([4000, 4500])
        ax.tick_params(axis='x')

        ax.tick_params(axis='both', which='major', labelsize=ticks_font_size)

        ax.set_ylim(3.21, 3.42)
        ax.set_yticks([3.25, 3.30, 3.35, 3.40])

    # Single-row legend anchored above the panels; handles use a fixed
    # line width of 1.
    handles = []
    for pos_type in legend_order:
        style = plot_settings[pos_type]
        handles.append(
            Line2D([0], [0], color=style['color'], linestyle=style['linestyle'],
                   linewidth=1, alpha=style['alpha'], label=display_names[pos_type])
        )
    compact_legend = {
        **legend_settings,
        'frameon': True,
        'handlelength': 0.6,
        'handletextpad': 0.3,
        'columnspacing': 0.9,
        'fontsize': 6,
    }
    fig.legend(handles=handles, loc='lower center', bbox_to_anchor=(0.5, 1.08),
               ncol=len(legend_order), **compact_legend)

    plt.tight_layout()
    # Tighten the gap between panels after tight_layout has run.
    plt.subplots_adjust(wspace=0.03, top=0.98)

    suffix = '_dashed' if DASHED else ''
    output_file = os.path.join(output_dir, f'val_loss_sfp_only_better{suffix}.pdf')
    plt.savefig(output_file, format='pdf', bbox_inches='tight')
    plt.close(fig)
    print(f"Saved plot to {output_file}")

def plot_4k_validation_losses_with_zoom(metrics_data, output_dir='./plots'):
    """Plot 4k-trained runs validated at 16k and 64k, each as a full and a zoomed panel.

    The figure has two wide "zoomed out" panels (steps 200-4578, loss
    3.15-5.75) followed by a gap and two narrow "zoomed in" panels (steps
    3578-4578, loss 3.21-3.43). It is saved as ``val_loss_4k_with_zoom.pdf``
    (``_dashed`` suffix when ``DASHED`` is set) inside ``output_dir``, which
    is created if missing.
    """

    def averaged_curve(method, trained_at, evaluated_at):
        """Return ``(steps, mean_loss)`` over all loaded seeds, or None without data."""
        loss_key = f'val_loss_{evaluated_at}'
        step_arrays = []
        loss_arrays = []
        for history in metrics_data[method][trained_at].values():
            if not isinstance(history, list):
                continue
            usable = [entry for entry in history if 'step' in entry and loss_key in entry]
            if usable:
                step_arrays.append(np.array([entry['step'] for entry in usable]))
                loss_arrays.append(np.array([entry[loss_key] for entry in usable]))

        if not loss_arrays:
            return None

        # Truncate every run to the shortest one before averaging.
        shortest = min(len(steps) for steps in step_arrays)
        trimmed = [losses[:shortest] for losses in loss_arrays]

        nseeds = len(trimmed)
        if nseeds != NSEEDS[trained_at]:
            print(f"WARN: number of seeds is not {NSEEDS[trained_at]} (got {nseeds})for tr_len:{trained_at}, val_len:{evaluated_at}, pos_type:{method}")

        return step_arrays[0][:shortest], np.mean(trimmed, axis=0)

    os.makedirs(output_dir, exist_ok=True)

    train_length = '4k'
    val_lengths = ['16k', '64k']

    # Five grid columns: two full views, an empty spacer (index 2), and two
    # narrower zoomed views.
    fig = plt.figure(figsize=(5.5, 1.5))
    gs = plt.GridSpec(1, 5, width_ratios=[1, 1, 0.5, 0.5, 0.5], wspace=0.03)
    full_axes = [fig.add_subplot(gs[0]), fig.add_subplot(gs[1])]
    zoom_axes = [fig.add_subplot(gs[3]), fig.add_subplot(gs[4])]

    # This figure uses its own, smaller title size than the module default.
    panel_title_size = 6

    for column, (val_length, ax_full, ax_zoom) in enumerate(zip(val_lengths, full_axes, zoom_axes)):
        ax_full.set_title(f'Train @4k / Val @{val_length}', fontsize=panel_title_size)
        # The zoomed panels are narrower, so their title is smaller still.
        ax_zoom.set_title(f'Train @4k / Val @{val_length}', fontsize=panel_title_size - 1.95)

        for pos_type in pos_types:
            if pos_type not in metrics_data or train_length not in metrics_data[pos_type]:
                continue
            curve = averaged_curve(pos_type, train_length, val_length)
            if curve is None:
                continue
            steps, mean_loss = curve
            ax_full.plot(steps, mean_loss, **plot_settings[pos_type])
            ax_zoom.plot(steps, mean_loss, **plot_settings[pos_type])

        if DASHED_LINE_3p28:
            ax_full.axhline(y=3.28, **dashed_line_settings)
            ax_zoom.axhline(y=3.28, **dashed_line_settings)

        # Full view: the whole plotted training range.
        ax_full.set_xlabel('Step', fontsize=axis_label_font_size)
        ax_full.set_xlim(200, 4578)
        ax_full.set_xticks([1000, 2000, 3000, 4000])
        ax_full.set_ylim(3.15, 5.75)
        ax_full.set_yticks([3.5, 4.0, 4.5, 5.0, 5.5])

        # Zoomed view: the last 1000 steps.
        ax_zoom.set_xlabel('Step', fontsize=axis_label_font_size)
        ax_zoom.set_xlim(3578, 4578)
        ax_zoom.set_xticks([4000, 4500])
        ax_zoom.set_ylim(3.21, 3.43)
        ax_zoom.set_yticks([3.25, 3.30, 3.35, 3.40])

        for axis in (ax_full, ax_zoom):
            axis.tick_params(axis='both', which='major', labelsize=ticks_font_size)

        if column == 0:
            ax_full.set_ylabel('Val Loss', fontsize=axis_label_font_size)
            ax_zoom.set_ylabel('Val Loss', fontsize=axis_label_font_size)
        else:
            # Only the first panel of each pair carries y tick labels.
            for axis in (ax_full, ax_zoom):
                axis.set_yticklabels([])
                axis.tick_params(axis='y', which='both', left=False)

    # Rotated captions placed to the right of each pair of panels.
    caption_style = dict(rotation=270, va='center', ha='center',
                         fontsize=axis_label_font_size, style='italic')
    fig.text(0.575, 0.5, 'Zoomed Out', **caption_style)
    fig.text(0.91, 0.5, 'Zoomed In', **caption_style)

    suffix = '_dashed' if DASHED else ''
    output_file = os.path.join(output_dir, f'val_loss_4k_with_zoom{suffix}.pdf')
    plt.savefig(output_file, format='pdf', bbox_inches='tight')
    plt.close(fig)
    print(f"Saved plot to {output_file}")

def plot_4k_training_comparison(metrics_data, output_dir='./plots'):
    """Compare the three scale-invariant variants trained at 4k across validation lengths.

    Draws three panels (validation at 4k, 16k, 64k) with seed-averaged curves
    for ``sfa_and_nope``, ``sfa_and_rope`` and ``sfa_and_p_rope``; the legend
    sits inside the first panel. The figure is saved as
    ``val_loss_4k_sfa_comparison.pdf`` (``_dashed`` suffix when ``DASHED`` is
    set) inside ``output_dir``, which is created if missing.
    """

    def get_display_name(pos_type):
        """Return the method's display label without a trailing '(ours)' marker."""
        label = display_names.get(pos_type, pos_type)
        marker = '(ours)'
        return label[:-len(marker)].strip() if label.endswith(marker) else label

    def averaged_curve(method, trained_at, evaluated_at):
        """Return ``(steps, mean_loss)`` over all loaded seeds, or None without data."""
        loss_key = f'val_loss_{evaluated_at}'
        step_arrays = []
        loss_arrays = []
        for history in metrics_data[method][trained_at].values():
            if not isinstance(history, list):
                continue
            usable = [entry for entry in history if 'step' in entry and loss_key in entry]
            if usable:
                step_arrays.append(np.array([entry['step'] for entry in usable]))
                loss_arrays.append(np.array([entry[loss_key] for entry in usable]))

        if not loss_arrays:
            return None

        # Truncate every run to the shortest one before averaging.
        shortest = min(len(steps) for steps in step_arrays)
        trimmed = [losses[:shortest] for losses in loss_arrays]

        nseeds = len(trimmed)
        if nseeds != NSEEDS[trained_at]:
            print(f"WARN: number of seeds is not {NSEEDS[trained_at]} (got {nseeds})for tr_len:{trained_at}, val_len:{evaluated_at}, pos_type:{method}")

        return step_arrays[0][:shortest], np.mean(trimmed, axis=0)

    os.makedirs(output_dir, exist_ok=True)

    train_length = '4k'
    val_lengths = ['4k', '16k', '64k']

    # Methods drawn, and the order they appear in the legend.
    selected_pos_types = ['sfa_and_nope', 'sfa_and_rope', 'sfa_and_p_rope']
    selected_legend_order = ['sfa_and_nope', 'sfa_and_rope', 'sfa_and_p_rope']

    fig, axes = plt.subplots(1, 3, figsize=(5.5, 1.5))

    for column, val_length in enumerate(val_lengths):
        ax = axes[column]
        ax.set_title(f'Train @{train_length} / Val @{val_length}', fontsize=title_font_size)

        for pos_type in selected_pos_types:
            if pos_type not in metrics_data or train_length not in metrics_data[pos_type]:
                continue
            curve = averaged_curve(pos_type, train_length, val_length)
            if curve is not None:
                ax.plot(curve[0], curve[1], **plot_settings[pos_type])

        if DASHED_LINE_3p28:
            ax.axhline(y=3.28, **dashed_line_settings)

        if column == 0:
            ax.set_ylabel('Val Loss', fontsize=axis_label_font_size)

            # The legend lives inside the left panel and reuses each
            # method's plotted line style.
            handles = [
                Line2D([0], [0], **plot_settings[pos_type], label=get_display_name(pos_type))
                for pos_type in selected_legend_order
            ]
            ax.legend(handles=handles, loc='upper right', **legend_settings)
        else:
            # Only the left panel carries y tick labels and tick marks.
            ax.set_yticklabels([])
            ax.tick_params(axis='y', which='both', left=False)

        ax.set_xlabel('Step', fontsize=axis_label_font_size)
        ax.set_xlim(200, 4578)

        ax.set_xticks([1000, 2000, 3000, 4000])

        ax.tick_params(axis='both', which='major', labelsize=ticks_font_size)

        ax.set_ylim(3.15, 4.5)
        ax.set_yticks([3.5, 4.0, 4.5])

    plt.tight_layout()
    # Tighten the gap between panels after tight_layout has run.
    plt.subplots_adjust(wspace=0.03)

    suffix = '_dashed' if DASHED else ''
    output_file = os.path.join(output_dir, f'val_loss_4k_sfa_comparison{suffix}.pdf')
    plt.savefig(output_file, format='pdf', bbox_inches='tight')
    plt.close(fig)
    print(f"Saved plot to {output_file}")

def mk_table(metrics_data):
    """Return a LaTeX table of final validation losses for runs trained at 4k.

    For every method in ``pos_types`` and every validation length ('4k',
    '16k', '64k'), the last logged entry containing both ``'step'`` and the
    validation key is taken from each seed. Cells show the mean over seeds
    and ``np.std(...) / sqrt(n)`` (``np.std`` with its default arguments).
    The lowest mean in each column is set in bold; methods without data show
    ``N/A``.
    """
    train_length = '4k'
    val_lengths = ['4k', '16k', '64k']
    # Row order of the table.
    first_group = ['rope', 'p_rope', 'nope', 'ntk_aware', 'nousyarn']
    second_group = ['logn', 'logn_and_p_rope', 'logn_and_ntk_aware', 'alibi', 'sfa_and_p_rope']
    all_methods = first_group + second_group

    not_found = object()

    def last_recorded(history, key):
        """Return the value of ``key`` in the last qualifying entry, or ``not_found``."""
        return next(
            (entry[key] for entry in reversed(history) if 'step' in entry and key in entry),
            not_found,
        )

    # table_data[val_length][pos_type] = (mean, standard error)
    table_data = {val_length: {} for val_length in val_lengths}
    for pos_type in pos_types:
        if pos_type not in metrics_data or train_length not in metrics_data[pos_type]:
            continue
        runs = metrics_data[pos_type][train_length]
        for val_length in val_lengths:
            val_key = f'val_loss_{val_length}'
            candidates = (
                last_recorded(history, val_key)
                for history in runs.values()
                if isinstance(history, list)
            )
            final_losses = [value for value in candidates if value is not not_found]
            if not final_losses:
                continue
            mean_loss = np.mean(final_losses)
            std_err = np.std(final_losses) / np.sqrt(len(final_losses))
            table_data[val_length][pos_type] = (mean_loss, std_err)

    # Lowest mean per column; on ties the earlier row wins.
    best_per_length = {}
    for val_length in val_lengths:
        winner = None
        lowest = float('inf')
        for pos_type in all_methods:
            stats = table_data[val_length].get(pos_type)
            if stats is not None and stats[0] < lowest:
                lowest = stats[0]
                winner = pos_type
        best_per_length[val_length] = winner

    rows = []
    for pos_type in all_methods:
        cells = [display_names[pos_type]]
        for val_length in val_lengths:
            if pos_type not in table_data[val_length]:
                cells.append("N/A")
                continue
            mean_loss, std_err = table_data[val_length][pos_type]
            shown = f"{mean_loss:.3f}"
            if pos_type == best_per_length[val_length]:
                shown = f"\\textbf{{{shown}}}"
            cells.append(f"{shown} $\\pm$ {std_err:.3f}")
        rows.append(" & ".join(cells) + " \\\\")

    column_spec = 'l' + 'c' * len(val_lengths)
    header = "Method & " + " & ".join(f"Val @ {length}" for length in val_lengths) + " \\\\"
    lines = [
        "\\begin{table}[h]",
        "\\centering",
        f"\\begin{{tabular}}{{{column_spec}}}",
        "\\toprule",
        header,
        "\\midrule",
        *rows,
        "\\bottomrule",
        "\\end{tabular}",
        "\\caption{Final validation losses (step 4578) for different methods when training at 4k context length}",
        "\\label{tab:final_losses}",
        "\\end{table}",
    ]
    return "\n".join(lines)

"""old"""

# Script driver. NOTE(review): these statements run at module level with no
# ``if __name__ == "__main__"`` guard, so importing this file also loads the
# logs and writes the figures.

# Load all run metrics; the log path is relative to the current working
# directory.
metrics_data = load_metrics_files(logs_dir='../gpt2/logs')

# Generate the figures; each call writes its PDF(s) into the current
# working directory.
plot_validation_losses(metrics_data, output_dir='.')
plot_sfp_comparison(metrics_data, output_dir='.')
plot_4k_validation_losses_with_zoom(metrics_data, output_dir='.')
plot_4k_training_comparison(metrics_data, output_dir='.')

# Print the LaTeX table of final validation losses (4k training runs) to
# stdout.
print(mk_table(metrics_data))