import os
import torch
import re
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
from pathlib import Path
import matplotlib.cm as cm

# When True, draw a grey dashed horizontal reference line at val loss 3.28
# on every panel of the comparison figure.
DASHED_LINE_3p28 = False

# Typography for the comparison figure.
title_font_size = 7
axis_label_font_size = 6
ticks_font_size = 6
# Keyword arguments forwarded verbatim to Axes.legend().
legend_settings = dict(fontsize=5, frameon=True, handlelength=0.7)

def extract_tau_info(folder_name):
    """Parse the tau value and seed tag out of an SFA tau run-folder name.

    Folder names look like ``sfa_and_p_rope_tau_<tau>_4k_s<seed>``. ``<tau>``
    is either a plain float literal (e.g. ``1.0``, ``100``) or uses ``p`` in
    place of the decimal point (e.g. ``0p1`` -> 0.1, ``0p01`` -> 0.01).

    Args:
        folder_name: Basename of a run directory.

    Returns:
        ``(tau, seed)`` where ``tau`` is a float and ``seed`` is the string
        ``"s<digits>"``, or ``(None, None)`` if the name does not match the
        pattern.

    Raises:
        ValueError: If the name matches but the tau field is not a number.
    """
    match = re.match(r'sfa_and_p_rope_tau_([^\_]+)_4k_s(\d+)', folder_name)
    if match is None:
        return None, None

    raw_tau, seed = match.groups()
    try:
        # 'p' stands in for the decimal point. No literal accepted by float()
        # contains 'p', so plain floats such as '1.0' parse exactly as before.
        tau = float(raw_tau.replace('p', '.'))
    except ValueError as err:
        raise ValueError(f"Invalid tau value: {raw_tau}") from err
    return tau, f"s{seed}"

def load_sfa_tau_metrics(logs_dir='../gpt2/logs_tau'):
    """Load ``metrics.pt`` from every SFA tau run directory under *logs_dir*.

    A run directory is any sub-directory whose name starts with
    ``sfa_and_p_rope_tau`` and contains ``4k``. Tau and seed are parsed from
    the name with ``extract_tau_info``. Directories are skipped if they lack
    a ``metrics.pt`` or have a name that does not match the pattern. Parse or
    load errors are printed rather than raised, so one bad run does not abort
    the whole sweep.

    Args:
        logs_dir: Directory containing the run sub-directories.

    Returns:
        ``defaultdict`` mapping ``tau -> {seed: metrics}``, where ``metrics``
        is whatever object was stored in ``metrics.pt``.
    """
    organized_metrics = defaultdict(dict)  # tau -> seed -> metrics

    # Sorted so load order (and therefore seed order per tau) is deterministic
    # across filesystems; os.listdir order is arbitrary.
    sfa_tau_dirs = sorted(
        d for d in os.listdir(logs_dir)
        if os.path.isdir(os.path.join(logs_dir, d))
        and d.startswith('sfa_and_p_rope_tau') and '4k' in d
    )

    for folder in sfa_tau_dirs:
        metrics_path = os.path.join(logs_dir, folder, 'metrics.pt')
        if not os.path.exists(metrics_path):
            continue

        try:
            tau, seed = extract_tau_info(folder)
            # Explicit None checks: tau == 0.0 is a valid value but is falsy.
            if tau is None or seed is None:
                continue
            # map_location='cpu' lets metrics saved from GPU tensors load on
            # CPU-only machines.
            metrics = torch.load(metrics_path, map_location='cpu')
        except Exception as e:
            # Deliberate best-effort: report and move on to the next run.
            print(f"Error loading {metrics_path}: {str(e)}")
            continue

        organized_metrics[tau][seed] = metrics
        print(f"Loaded metrics from {folder} (tau={tau}, seed={seed})")

    return organized_metrics

def _mean_seed_curves(seed_metrics, keys):
    """Average per-step metric curves for *keys* across all seeds of one tau.

    Args:
        seed_metrics: Mapping ``seed -> metrics``. Only seeds whose metrics are
            a ``list`` of dict-like records are used; other formats are skipped.
        keys: Metric names to extract from each record.

    Returns:
        ``(steps, means)``, where ``steps`` is a 1-D array and ``means`` is a
        list with one seed-averaged array per key (same order as *keys*).
        Returns ``None`` if no seed has a record containing ``'step'`` and
        every key.
    """
    steps_per_seed = []
    curves_per_key = {key: [] for key in keys}

    for metrics in seed_metrics.values():
        if not isinstance(metrics, list):
            continue
        # Keep only records that log the step and every requested metric.
        rows = [m for m in metrics if 'step' in m and all(key in m for key in keys)]
        if not rows:
            continue
        steps_per_seed.append(np.array([m['step'] for m in rows]))
        for key in keys:
            curves_per_key[key].append(np.array([m[key] for m in rows]))

    if not steps_per_seed:
        return None

    # Seeds may have different lengths: truncate all to the shortest so they
    # can be averaged point-wise. Step values come from the first seed.
    # NOTE(review): this assumes every seed logs on the same step schedule.
    min_length = min(len(steps) for steps in steps_per_seed)
    steps = steps_per_seed[0][:min_length]
    means = [
        np.mean([curve[:min_length] for curve in curves_per_key[key]], axis=0)
        for key in keys
    ]
    return steps, means


def _tau_label(tau):
    """Return the LaTeX legend label for *tau*.

    Exact powers of ten are rendered as ``tau=10^{k}`` (e.g. 0.01 -> 10^{-2});
    any other value falls back to its ``:g`` formatting.
    """
    if tau > 0:
        exponent = np.log10(tau)
        nearest = np.rint(exponent)
        if np.isclose(exponent, nearest):
            return f"$\\tau=10^{{{int(nearest)}}}$"
    return f"$\\tau={tau:g}$"


def plot_sfa_tau_comparison(metrics_data, output_file=None):
    """Plot seed-averaged validation loss vs. step for each SFA tau value.

    Draws three side-by-side panels (``val_loss_4k``, ``val_loss_16k`` and
    ``val_loss_64k``) with one line per tau. Lines are coloured by ascending
    tau rank, and the tau legend sits on the left panel.

    Args:
        metrics_data: Mapping ``tau -> {seed: metrics}`` as produced by
            ``load_sfa_tau_metrics``. Taus with no usable list-of-dict metrics
            are skipped (they still consume a colour slot).
        output_file: Path of the PDF to write; missing parent directories are
            created. If ``None``, the figure is shown with ``plt.show()``
            instead of being saved.
    """
    val_keys = ('val_loss_4k', 'val_loss_16k', 'val_loss_64k')
    titles = ('Train @4k / Val @4k', 'Train @4k / Val @16k', 'Train @4k / Val @64k')
    # Custom red -> yellow -> blue palette, indexed by the tau's sorted rank.
    colors = ['#d7191c', '#fdae61', '#f5f556', '#abd9e9', '#2c7bb6']

    fig, axes = plt.subplots(1, 3, figsize=(5.5, 2))

    legend_elements = []
    for rank, tau in enumerate(sorted(metrics_data)):
        curves = _mean_seed_curves(metrics_data[tau], val_keys)
        if curves is None:
            continue
        steps, mean_losses = curves
        color = colors[rank % len(colors)]
        for ax, mean_loss in zip(axes, mean_losses):
            ax.plot(steps, mean_loss, color=color, linewidth=1)
        legend_elements.append(
            Line2D([0], [0], color=color, linewidth=2, label=_tau_label(tau)))

    for i, (ax, title) in enumerate(zip(axes, titles)):
        ax.set_title(title, fontsize=title_font_size)
        if DASHED_LINE_3p28:
            ax.axhline(y=3.28, color='gray', linestyle='--', alpha=0.7)

        ax.set_xlabel('Step', fontsize=axis_label_font_size)
        ax.set_xlim(3250, 4578)
        # Same y range on all three panels so they are directly comparable.
        ax.set_ylim(3.20, 3.4)
        ax.set_xticks([3500, 4000, 4500])

        # Only the first panel carries the y label and ticks.
        if i == 0:
            ax.set_ylabel('Val Loss', fontsize=axis_label_font_size)
            ax.set_yticks([3.2, 3.3, 3.4])
        else:
            ax.set_yticks([])

        ax.tick_params(axis='both', which='major', labelsize=ticks_font_size)

    axes[0].legend(handles=legend_elements, loc='upper right', **legend_settings)

    fig.tight_layout()
    fig.subplots_adjust(wspace=0.03)  # pull the panels close together

    if output_file is None:
        plt.show()
    else:
        out_dir = os.path.dirname(output_file)
        # dirname is '' for a bare filename, and os.makedirs('') would raise.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_file, format='pdf', bbox_inches='tight')
        print(f"Plot saved to {output_file}")
    plt.close(fig)

if __name__ == "__main__":
    # Collect every SFA tau run from the default log directory, then write the
    # three-panel comparison PDF into the current working directory.
    metrics_data = load_sfa_tau_metrics(logs_dir='../gpt2/logs_tau')
    plot_sfa_tau_comparison(metrics_data, output_file='./sfa_tau_4k.pdf')