import os
import torch
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

def is_str_int(s):
    """Report whether ``int(s)`` would succeed.

    Returns False only when ``int`` raises ``ValueError``; any other exception
    (for example ``TypeError`` when *s* is ``None``) propagates to the caller.
    """
    try:
        int(s)
    except ValueError:
        return False
    else:
        return True

# Expected number of seeds per (variant, train length, val length); a warning is
# printed by the plotting code when the number of loaded seeds differs.
NSEED = 1
# If True, draw a horizontal reference line at val loss 3.28 on every subplot.
DASHED_LINE_3p28 = False
# In this script DASHED only selects the output file name ('_dashed' suffix).
DASHED = False
title_font_size = 7
axis_label_font_size = 6
ticks_font_size = 6

legend_settings = {'fontsize': 5, 'frameon': True, 'handlelength': 0.7}
dashed_line_settings = {'color': 'gray', 'linestyle': '--', 'linewidth': 0.2, 'alpha': 0.8}

# All positional-encoding variants handled by this script. The list order also
# drives the progressive line thinning applied after plot_settings is built.
pos_types = ['sfa_and_p_rope', 'ntk_aware', 'rope', 'p_rope', 'logn_and_p_rope',
             'logn_and_ntk_aware', 'logn', 'nope', 'alibi']
# Order in which variants appear in plot legends.
legend_order = ['rope', 'p_rope', 'nope', 'ntk_aware', 'logn', 'logn_and_p_rope',
                'logn_and_ntk_aware', 'alibi', 'sfa_and_p_rope']
# Both lists must name exactly the same variants; comparing lengths alone would
# not catch a misspelled or duplicated entry.
assert sorted(legend_order) == sorted(pos_types), \
    'legend_order and pos_types must list the same variants'

# Human-readable legend label for each positional-encoding variant key.
display_names = dict(
    rope='RoPE',
    p_rope='p-RoPE',
    ntk_aware='RoPE+NTK',
    logn='LogN+RoPE',
    logn_and_p_rope='LogN+p-RoPE',
    logn_and_ntk_aware='LogN+NTK',
    sfa_and_rope='Scale-invariant RoPE',
    sfa_and_nope='Scale-invariant NoPE',
    sfa_and_p_rope='Scale-invariant p-RoPE\n(ours)',
    nope='NoPE',
    alibi='ALiBi',
)

# 11-colour palette, ordered from dark red through yellow to dark blue.
base_colors = [
    '#a50026', '#d73027', '#f46d43', '#fdae61', '#ffe54a', '#ffffbf',
    '#e0f3f8', '#abd9e9', '#74add1', '#4575b4', '#313695',
]
assert len(base_colors) == 11

## version 2, same line style
linewidth = 0.6
# Curve colour for each variant, in declaration order.
_variant_colors = (
    ('sfa_and_p_rope', '#000'),
    ('sfa_and_rope', base_colors[0]),
    ('sfa_and_nope', base_colors[3]),
    ('rope', base_colors[0]),
    ('p_rope', base_colors[2]),
    ('nope', base_colors[3]),
    ('ntk_aware', base_colors[4]),
    ('alibi', '#5D3A9B'),
    ('logn', base_colors[-3]),
    ('logn_and_p_rope', base_colors[-2]),
    ('logn_and_ntk_aware', base_colors[-1]),
)
# Every curve is a solid line at alpha 0.8; only 'sfa_and_p_rope' (ours) is 10%
# wider. Each variant gets its own dict because widths are mutated in place later.
plot_settings = {
    variant: {
        'color': color,
        'linestyle': '-',
        'linewidth': linewidth * 1.1 if variant == 'sfa_and_p_rope' else linewidth,
        'alpha': 0.8,
    }
    for variant, color in _variant_colors
}
## hack to make lines draw later slightly thinner
# Scales the width of each entry of `pos_types`, in list order, by a compounding
# factor: the first keeps its width, the next is x0.98, the next x0.98*0.98, etc.
# Mutates the dicts in `plot_settings` in place.
# NOTE(review): plot_validation_losses draws its own local pos_types list, so this
# module-level order may not match the actual draw order there — confirm intent.
mult = 1.
for name in pos_types:
    plot_settings[name]['linewidth'] *= mult
    mult *= 0.98

def extract_categories(folder_name):
    """Extract the position embedding type, context length, and seed from a run folder name.

    Args:
        folder_name: Run directory name whose underscore-separated parts encode
            the configuration (a name such as ``'logn_and_ntk_aware_med_s1'``).

    Returns:
        ``(pos_type, context_length, seed)`` where

        * ``pos_type`` is the longest known embedding-type prefix of
          *folder_name*, or ``'unknown'`` if none matches;
        * ``context_length`` is the first part equal to ``'4k'``, ``'16k'`` or
          ``'64k'``, else ``None``;
        * ``seed`` is the first part of the form ``s<int>`` (e.g. ``'s1'``),
          returned as the raw string, else ``None``.
    """
    parts = folder_name.split('_')

    # Several type names are prefixes of others ('logn' of 'logn_and_ntk_aware',
    # 'sfa' of 'sfa_and_rope', ...). Trying the longest names first guarantees the
    # most specific one wins; a hand-ordered if/elif chain previously let 'logn'
    # shadow 'logn_and_ntk_aware'.
    known_pos_types = (
        'sfa_and_p_rope', 'sfa_and_rope', 'sfa_and_nope', 'sfa',
        'logn_and_ntk_aware', 'logn_and_p_rope', 'logn',
        'ntk_aware', 'p_rope', 'rope', 'nope', 'alibi',
    )
    pos_type = next(
        (name for name in sorted(known_pos_types, key=len, reverse=True)
         if folder_name.startswith(name)),
        'unknown',
    )

    # Context length: first recognised length token, if any.
    context_length = next((part for part in parts if part in ('4k', '16k', '64k')), None)

    # Seed: first token of the form 's<int>'.
    seed = next(
        (part for part in parts
         if part.startswith('s') and len(part) >= 2 and is_str_int(part[1:])),
        None,
    )

    return pos_type, context_length, seed

def load_metrics_files(logs_dir=None):
    """Load all metrics.pt files and organize them by embedding type, context length, and seed.

    Only immediate sub-directories of *logs_dir* whose name contains ``'_med_'``
    and does not contain ``'_tau_'`` are considered. Each is classified with
    ``extract_categories``; its ``metrics.pt`` (if present) is loaded with
    ``torch.load``. Load failures are reported and skipped (best effort).

    Args:
        logs_dir: Directory holding one sub-directory per run. ``None`` means
            the current working directory.

    Returns:
        Nested ``defaultdict`` mapping ``pos_type -> context_length -> seed -> metrics``.
    """
    if logs_dir is None:
        # os.path.join(None, ...) raises TypeError, so resolve the default up front.
        logs_dir = os.curdir

    organized_metrics = defaultdict(lambda: defaultdict(dict))

    # Cheap name filters first, then the filesystem check. Sorting makes the load
    # order (and thus the per-seed order seen by the plotting code) independent of
    # os.listdir's arbitrary ordering.
    log_dirs = sorted(
        f for f in os.listdir(logs_dir)
        if '_tau_' not in f and '_med_' in f and os.path.isdir(os.path.join(logs_dir, f))
    )

    for folder in log_dirs:
        metrics_path = os.path.join(logs_dir, folder, 'metrics.pt')
        if not os.path.exists(metrics_path):
            continue
        try:
            pos_type, context_length, seed = extract_categories(folder)
            # NOTE: torch.load unpickles the file; only point this at trusted logs.
            metrics = torch.load(metrics_path)
        except Exception as e:
            # Best effort: one unreadable run should not abort the whole plot.
            print(f"Error loading {metrics_path}: {str(e)}")
        else:
            organized_metrics[pos_type][context_length][seed] = metrics
            print(f"Loaded metrics from {folder}")

    print("\nSummary of loaded data:")
    for pos_type in pos_types:
        if pos_type in organized_metrics:
            print(f"\n{pos_type}:")
            for context_length, seeds in organized_metrics[pos_type].items():
                print(f"  {context_length}: {len(seeds)} seeds")

    return organized_metrics

def _mean_val_curve(seed_metrics, val_key):
    """Average one validation-loss curve across seeds.

    Args:
        seed_metrics: Mapping ``seed -> metrics`` (one leaf level of the structure
            built by ``load_metrics_files``). Non-list metrics are ignored; list
            elements are checked with ``in`` for ``'step'`` and *val_key*.
        val_key: Name of the logged loss to read, e.g. ``'val_loss_4k'``.

    Returns:
        ``(steps, mean_loss, nseeds)``, or ``None`` when no seed logged both
        ``'step'`` and *val_key*. Curves are truncated to the shortest seed
        before averaging.
    """
    all_steps = []
    all_losses = []
    for metrics in seed_metrics.values():
        if not isinstance(metrics, list):
            continue
        logged = [(m['step'], m[val_key]) for m in metrics if 'step' in m and val_key in m]
        if logged:
            steps, losses = zip(*logged)
            all_steps.append(np.array(steps))
            all_losses.append(np.array(losses))

    if not all_losses:
        return None

    min_length = min(len(s) for s in all_steps)
    aligned_losses = [loss[:min_length] for loss in all_losses]
    # Assumes every seed logs at the same steps: the x-axis comes from the first seed.
    aligned_steps = all_steps[0][:min_length]
    return aligned_steps, np.mean(aligned_losses, axis=0), len(aligned_losses)


def _format_val_axis(ax, is_first, shown_pos_types):
    """Apply the shared styling to one validation-length panel.

    The left-most panel (*is_first*) gets the y-axis label and a legend listing
    *shown_pos_types* in ``legend_order``; the other panels hide their y tick
    marks and labels so the panels can sit close together.
    """
    if DASHED_LINE_3p28:
        ax.axhline(y=3.28, **dashed_line_settings)

    if is_first:
        ax.set_ylabel('Val Loss', fontsize=axis_label_font_size)
        legend_elements = [
            Line2D([0], [0], **plot_settings[name], label=display_names[name])
            for name in legend_order if name in shown_pos_types
        ]
        ax.legend(handles=legend_elements, loc='upper right', **legend_settings)
    else:
        # Keep the left spine visible but drop its tick marks and tick labels.
        ax.tick_params(axis='y', which='both', left=False, labelleft=False)

    ax.set_xlabel('Step', fontsize=axis_label_font_size)
    ax.set_xlim(1000, 10900)
    ax.tick_params(axis='both', which='major', labelsize=ticks_font_size)
    ax.set_ylim(2.90, 3.30)
    ax.set_yticks([2.9, 3.0, 3.1, 3.2, 3.3])


def plot_validation_losses(metrics_data, output_dir='./plots'):
    """Plot validation losses for different training context lengths and validation lengths.

    One figure is produced per training length, with one panel per validation
    length (4k, 16k, 64k). Each curve is the mean over the seeds available in
    *metrics_data* for that variant; a warning is printed when the seed count
    differs from ``NSEED`` or when no data is found.

    Args:
        metrics_data: Nested mapping ``pos_type -> context_length -> seed -> metrics``
            as returned by ``load_metrics_files``.
        output_dir: Directory the PDF is written to (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)

    train_lengths = [None]  # this is 4k in the med setting (no length tag in the folder name)
    val_lengths = ['4k', '16k', '64k']
    # Variants shown in this figure; intentionally shadows the module-level list.
    pos_types = ['alibi', 'logn_and_p_rope', 'sfa_and_p_rope']

    # Create a separate figure for each training length
    for train_length in train_lengths:
        fig, axes = plt.subplots(1, 3, figsize=(5.5, 1.9))

        # Each column is a different validation length
        for j, (ax, val_length) in enumerate(zip(axes, val_lengths)):
            ax.set_title(f'Train @4k / Val @{val_length}', fontsize=title_font_size)

            for pos_type in pos_types:
                # Short-circuit keeps the defaultdict from gaining empty entries.
                if pos_type not in metrics_data or train_length not in metrics_data[pos_type]:
                    continue
                curve = _mean_val_curve(metrics_data[pos_type][train_length], f'val_loss_{val_length}')
                if curve is None:
                    print(f"WARN: No valid loss data for {pos_type} at train={train_length}, val={val_length}")
                    continue
                steps, mean_loss, nseeds = curve
                if nseeds != NSEED:
                    print(f"WARN: number of seeds is not {NSEED} (got {nseeds}) for tr_len:{train_length}, val_len:{val_length}, pos_type:{pos_type}")
                ax.plot(steps, mean_loss, **plot_settings[pos_type])

            _format_val_axis(ax, j == 0, pos_types)

        fig.tight_layout()
        fig.subplots_adjust(wspace=0.03)  # Reduce spacing between subplots even more
        fname = 'val_loss_med_train_4k_dashed' if DASHED else 'val_loss_med_train_4k'
        output_file = os.path.join(output_dir, f'{fname}.pdf')
        fig.savefig(output_file, format='pdf', bbox_inches='tight')
        plt.close(fig)
        print(f"Saved plot to {output_file}")

if __name__ == '__main__':
    # Relative paths are resolved against the current working directory.
    metrics_data = load_metrics_files('../gpt2/logs/')
    plot_validation_losses(metrics_data, output_dir='./')
