import matplotlib.pyplot as plt
import re
import pandas as pd
import os
import sys # Import the sys module to access command-line arguments
import json
import numpy as np
from datetime import datetime

def parse_log_metadata(file_path):
    """
    Parse the log header for the noise-regularization hyperparameters.

    Lines are scanned until the first training line (one that starts with
    "Epoch <n>/"); the values following "Noise Regularization Strength:" and
    "Noise Regularization Parameter:" are collected. If a key appears more
    than once in the header, the last occurrence wins.

    Numbers may be written in plain or scientific notation (e.g. "1e-05",
    which is how Python prints small floats), so such values are not
    truncated to their mantissa.

    Args:
        file_path: Path to the training log.

    Returns:
        tuple: (noise_strength, noise_param), each a float, or None if the
        corresponding header line was not found or the file could not be read.
    """
    # Optional sign, integer/decimal mantissa, optional exponent.
    number = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?"
    strength_re = re.compile(rf"Noise Regularization Strength:\s*({number})")
    param_re = re.compile(rf"Noise Regularization Parameter:\s*({number})")
    epoch_re = re.compile(r"Epoch \d+/")

    noise_strength = None
    noise_param = None
    try:
        # errors='replace' keeps non-UTF-8 bytes (or a non-UTF-8 locale
        # default) from aborting the header scan.
        with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
            for line in f:
                # Stop reading once we hit actual training data (e.g., "Epoch 1/4000")
                if epoch_re.match(line):
                    break

                strength_match = strength_re.search(line)
                if strength_match:
                    noise_strength = float(strength_match.group(1))

                param_match = param_re.search(line)
                if param_match:
                    noise_param = float(param_match.group(1))
    except OSError as e:
        # Best effort: report and return whatever was parsed so far.
        print(f"Error reading metadata from '{file_path}': {e}", file=sys.stderr)

    return noise_strength, noise_param

def load_metrics_json(log_file_path):
    """
    Load a metrics JSON file that lives next to a training log.

    Candidates are tried in this order, and the first one that can be read
    and decoded is returned:
      1. ``<log_dir>/combined_results.json``
      2. ``<log_dir>/seeds/seed_*/metrics.json``, in sorted path order, so
         the seed used as the run's representative is deterministic.

    Unreadable or invalid files are reported on stderr and skipped.

    Args:
        log_file_path: Path to the training log file.

    Returns:
        The decoded JSON object, or None if no candidate could be loaded.
    """
    import glob  # only needed here

    log_dir = os.path.dirname(log_file_path)

    candidates = [os.path.join(log_dir, "combined_results.json")]
    # glob.escape protects directory names containing '[', '*' or '?';
    # glob returns filesystem order, so sort for a stable choice of seed.
    seed_pattern = os.path.join(glob.escape(log_dir), "seeds", "seed_*", "metrics.json")
    candidates.extend(sorted(glob.glob(seed_pattern)))

    for path in candidates:
        if not os.path.isfile(path):
            continue
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, ValueError) as e:
            # json.JSONDecodeError and UnicodeDecodeError are ValueError
            # subclasses; fall through to the next candidate.
            print(f"Warning: could not load metrics from '{path}': {e}", file=sys.stderr)

    return None

def parse_log_data(file_path):
    """
    Parse per-epoch training metrics from a log file.

    Recognized lines look like::

        Epoch 3/4000, Train Loss: 2.1, Val Loss: 2.3, Val Perplexity: 9.9, Val Accuracy: 41.2%

    The ", Val Perplexity: ..." field is optional (older logs omit it), and
    its value is not captured, so any token there (e.g. "inf") is accepted.
    Numbers may use scientific notation or be "nan"/"inf". Lines that do not
    match are ignored rather than aborting the whole file.

    Args:
        file_path: Path to the training log.

    Returns:
        pandas.DataFrame with columns 'Epoch', 'Train Loss', 'Val Loss' and
        'Val Accuracy' (one row per matched line, in file order). An empty
        DataFrame is returned if the file cannot be read or nothing matched.
    """
    # Optional sign, then a decimal/scientific literal or nan/inf. Every
    # string this matches is accepted by float().
    number = r"[-+]?(?:(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|nan|inf)"
    epoch_line = re.compile(
        rf"Epoch (\d+)/\d+, Train Loss: ({number}), Val Loss: ({number})"
        rf"(?:, Val Perplexity: [^,]+)?, Val Accuracy: ({number})%"
    )

    data = []
    try:
        with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
            for line in f:
                match = epoch_line.search(line)
                if not match:
                    continue
                data.append({
                    'Epoch': int(match.group(1)),
                    'Train Loss': float(match.group(2)),
                    'Val Loss': float(match.group(3)),
                    'Val Accuracy': float(match.group(4)),
                })
    except FileNotFoundError:
        print(f"Error: File not found at '{file_path}'. Please check the path.", file=sys.stderr)
        return pd.DataFrame() # Return an empty DataFrame on error
    except OSError as e:
        print(f"Error parsing file '{file_path}': {e}", file=sys.stderr)
        return pd.DataFrame() # Return an empty DataFrame on error

    return pd.DataFrame(data)

def plot_training_metrics_multi(datasets):
    """
    Save two overlay plots comparing several training runs.

    Writes ``./figures/losses_wikitext_<timestamp>.png`` (train and validation
    loss of every run) and ``./figures/accuracies_wikitext_<timestamp>.png``
    (validation accuracy of every run). Runs with ``noise_strength == 0`` are
    treated as the baseline: thicker, more opaque, with markers and a dotted
    validation line.

    Args:
        datasets: list of tuples ``(df, label, noise_strength)``; ``df`` must
            have the columns 'Epoch', 'Train Loss', 'Val Loss', 'Val Accuracy'.

    Side effects:
        Sets the global matplotlib style and DPI rcParams, creates
        ``./figures`` if needed, and prints the path of each saved file.
    """
    plt.style.use('seaborn-v0_8-darkgrid') # A nice style for plots

    # Set matplotlib to high resolution
    plt.rcParams['figure.dpi'] = 300
    plt.rcParams['savefig.dpi'] = 300

    # tab10 has 10 entries and maps out-of-range integer indices to its last
    # color, so cycle through it instead of letting runs 10+ share one color.
    cmap = plt.cm.tab10
    colors = [cmap(i % cmap.N) for i in range(len(datasets))]

    folder_name = "./figures"
    os.makedirs(folder_name, exist_ok=True)
    # One timestamp shared by both files so they can be paired afterwards.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

    # Plot 1: Training and Validation Losses
    fig = plt.figure(figsize=(10, 6))
    for color, (df, label, noise_strength) in zip(colors, datasets):
        is_baseline = noise_strength == 0
        linewidth = 2.5 if is_baseline else 1.5
        alpha = 0.9 if is_baseline else 0.7
        markevery = max(1, len(df) // 20)

        # Train loss is dashed for every run; the baseline adds circle markers.
        plt.plot(df['Epoch'], df['Train Loss'],
                 label=f'{label} Train',
                 linestyle='--',
                 marker='o' if is_baseline else None,
                 markersize=3 if is_baseline else 0,
                 markevery=markevery,
                 linewidth=linewidth,
                 alpha=alpha,
                 color=color)
        # Validation loss is dotted for the baseline and solid otherwise.
        plt.plot(df['Epoch'], df['Val Loss'],
                 label=f'{label} Val',
                 linestyle=':' if is_baseline else '-',
                 marker='x' if is_baseline else None,
                 markersize=4 if is_baseline else 0,
                 markevery=markevery,
                 linewidth=linewidth,
                 alpha=alpha,
                 color=color)

    plt.title('Training and Validation Losses: WikiText NTK', fontsize=18)
    plt.xlabel('Epoch', fontsize=15)
    plt.ylabel('Loss', fontsize=15)
    plt.legend(fontsize=10, loc='best', ncol=2)
    plt.grid(True)
    plt.tight_layout()

    file_name = os.path.join(folder_name, f'losses_wikitext_{timestamp}.png')
    plt.savefig(file_name)
    plt.close(fig)  # free the figure; otherwise every call leaks one
    print(f"Saved: {file_name}")

    # Plot 2: Validation Accuracies
    fig = plt.figure(figsize=(10, 6))
    for color, (df, label, noise_strength) in zip(colors, datasets):
        # Baseline: dotted, thicker, with triangle markers.
        is_baseline = noise_strength == 0
        plt.plot(df['Epoch'], df['Val Accuracy'],
                 label=f'{label}',
                 linestyle=':' if is_baseline else '-',
                 marker='^' if is_baseline else None,
                 markersize=4 if is_baseline else 0,
                 markevery=max(1, len(df) // 20),
                 linewidth=2.5 if is_baseline else 1.5,
                 alpha=0.9 if is_baseline else 0.7,
                 color=color)

    plt.title('Validation Accuracies: WikiText NTK', fontsize=18)
    plt.xlabel('Epoch', fontsize=15)
    plt.ylabel('Accuracy (%)', fontsize=15)
    plt.legend(fontsize=10, loc='best')
    plt.grid(True)
    plt.tight_layout()

    file_name = os.path.join(folder_name, f'accuracies_wikitext_{timestamp}.png')
    plt.savefig(file_name)
    plt.close(fig)
    print(f"Saved: {file_name}")

def plot_noise_stability(datasets_with_metrics, fixed_rho=0.5):
    """
    Plot noise stability against epoch for every run that recorded it.

    For each run whose metrics dict has a non-empty ``'noise_stabilities'``
    mapping (rho key -> values), the entry whose rho is closest to
    ``fixed_rho`` is plotted. Values may be a flat list (one seed) or a list
    of per-seed lists; per-seed lists are averaged over the epochs all seeds
    share. The x-axis is the 1-based position in the series (presumably one
    value per epoch). The figure is saved to
    ``./figures/noise_stability_<timestamp>.png``.

    Args:
        datasets_with_metrics: list of tuples (df, label, noise_strength, metrics);
            ``metrics`` may be None for runs without a metrics file.
        fixed_rho: Target rho; the nearest available rho is used (default 0.5).
    """
    plt.style.use('seaborn-v0_8-darkgrid')
    plt.rcParams['figure.dpi'] = 300
    plt.rcParams['savefig.dpi'] = 300

    def _mean_over_seeds(values):
        # Seeds may have logged different numbers of epochs, which np.array
        # cannot stack; average over the common prefix instead.
        if isinstance(values[0], list):
            n_epochs = min(len(seed) for seed in values)
            return np.mean([seed[:n_epochs] for seed in values], axis=0)
        return values

    # Collect (label, noise_strength, series, rho) for runs with usable data.
    curves = []
    for _df, label, noise_strength, metrics in datasets_with_metrics:
        if not isinstance(metrics, dict):
            continue
        ns_dict = metrics.get('noise_stabilities')
        if not isinstance(ns_dict, dict) or not ns_dict:
            continue

        # Keep the original JSON key for each numeric rho: str(float(key))
        # does not round-trip keys such as "0.50".
        rho_to_key = {}
        for key in ns_dict:
            try:
                rho_to_key[float(key)] = key
            except (TypeError, ValueError):
                continue  # ignore non-numeric keys
        if not rho_to_key:
            continue

        closest_rho = min(rho_to_key, key=lambda rho: abs(rho - fixed_rho))
        ns_values = ns_dict[rho_to_key[closest_rho]]
        if not isinstance(ns_values, list) or not ns_values:
            continue
        series = _mean_over_seeds(ns_values)
        if len(series) == 0:
            continue
        curves.append((label, noise_strength, series, closest_rho))

    if not curves:
        print("No noise stability data found in any dataset.")
        return

    # Plot noise stability
    fig = plt.figure(figsize=(10, 6))
    cmap = plt.cm.tab10
    for i, (label, noise_strength, series, actual_rho) in enumerate(curves):
        # Baseline (noise_strength == 0): dotted, thicker, triangle markers.
        is_baseline = noise_strength == 0
        marker = '^' if is_baseline else None
        plt.plot(range(1, len(series) + 1), series,
                 label=f'{label} (ρ={actual_rho})',
                 linestyle=':' if is_baseline else '-',
                 marker=marker,
                 markersize=4 if marker else 0,
                 markevery=max(1, len(series) // 20),
                 linewidth=2.5 if is_baseline else 1.5,
                 alpha=0.9 if is_baseline else 0.7,
                 color=cmap(i % cmap.N))  # cycle: tab10 clamps indices >= 10

    plt.title('Noise Stability vs Epoch', fontsize=18)
    plt.xlabel('Epoch', fontsize=15)
    plt.ylabel('Noise Stability', fontsize=15)
    plt.legend(fontsize=10, loc='best')
    plt.grid(True)
    plt.tight_layout()

    folder_name = "./figures"
    os.makedirs(folder_name, exist_ok=True)  # don't rely on an earlier plot creating it
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    file_name = os.path.join(folder_name, f'noise_stability_{timestamp}.png')
    plt.savefig(file_name)
    plt.close(fig)
    print(f"Saved: {file_name}")

def calculate_spectral_tail_percentage(noise_stability, f_l2_squared, fixed_rho=0.5, degree=15):
    """
    Compute ``100 * delta / (1 - fixed_rho ** degree)`` with
    ``delta = noise_stability / f_l2_squared``.

    Degenerate inputs short-circuit to the integer ``0``:
      * ``f_l2_squared == 0`` (delta is undefined), or
      * ``fixed_rho ** degree == 1`` (zero denominator).

    Args:
        noise_stability: Measured noise stability value.
        f_l2_squared: Squared L2 norm of the model output, ||f||_2^2.
        fixed_rho: Correlation parameter rho used in the formula (default 0.5).
        degree: Minimum degree d of the tail (default 15).

    Returns:
        The quantity above expressed as a percentage, or 0 for degenerate inputs.

    NOTE(review): if ``noise_stability`` is Stab_rho[f] = sum_S rho^|S| fhat(S)^2,
    the standard bound on the tail weight is
    W^{>=d}[f] / ||f||^2 <= (1 - delta) / (1 - rho^d), i.e. it uses
    ``1 - delta``, not ``delta``. Confirm what the stored values represent
    before reading this as a tail percentage.
    """
    if f_l2_squared == 0:
        return 0

    # delta is evaluated before the denominator so that any error raised
    # by the inputs surfaces in the same order as before.
    normalized_stability = noise_stability / f_l2_squared
    tail_denominator = 1 - (fixed_rho ** degree)
    if tail_denominator == 0:
        return 0

    return normalized_stability / tail_denominator * 100

def plot_spectral_tail(datasets_with_metrics, fixed_rho=0.5, degree=15):
    """
    Plot spectral tail percentage vs epoch.

    Formula: percentage = delta / (1 - fixed_rho^degree)
    where delta = Stability / ||f||_2^2

    A run is included only if its metrics dict has a ``'noise_stabilities'``
    entry whose rho is within 0.01 of ``fixed_rho`` and a non-empty
    ``'model_l2_squared'`` list. Either series may be a flat list (one seed)
    or a list of per-seed lists; per-seed lists are averaged over the epochs
    all seeds share before the ratio is taken. The figure is saved to
    ``./figures/spectral_tail_<timestamp>.png``.

    Args:
        datasets_with_metrics: list of tuples (df, label, noise_strength, metrics);
            ``metrics`` may be None for runs without a metrics file.
        fixed_rho: The rho value to use (0.5 in this script's usage).
        degree: Minimum degree for tail calculation (default 15)
    """
    plt.style.use('seaborn-v0_8-darkgrid')
    plt.rcParams['figure.dpi'] = 300
    plt.rcParams['savefig.dpi'] = 300

    def _mean_over_seeds(values):
        # Seeds may have logged different numbers of epochs; average over the
        # common prefix. Flat (single-seed) lists pass through unchanged.
        if isinstance(values[0], list):
            n_epochs = min(len(seed) for seed in values)
            return list(np.mean([seed[:n_epochs] for seed in values], axis=0))
        return list(values)

    # Collect (label, noise_strength, percentages) for runs with usable data.
    curves = []
    for _df, label, noise_strength, metrics in datasets_with_metrics:
        if not isinstance(metrics, dict):
            continue
        ns_dict = metrics.get('noise_stabilities')
        l2_sq_values = metrics.get('model_l2_squared')
        if not isinstance(ns_dict, dict) or not ns_dict:
            continue
        if not isinstance(l2_sq_values, list) or not l2_sq_values:
            continue

        # Keep the original JSON key per numeric rho: str(float(key)) does
        # not round-trip keys such as "0.50".
        rho_to_key = {}
        for key in ns_dict:
            try:
                rho_to_key[float(key)] = key
            except (TypeError, ValueError):
                continue  # ignore non-numeric keys
        if not rho_to_key:
            continue

        closest_rho = min(rho_to_key, key=lambda rho: abs(rho - fixed_rho))
        # Only include if we have rho very close to fixed_rho
        if abs(closest_rho - fixed_rho) >= 0.01:
            continue
        ns_values = ns_dict[rho_to_key[closest_rho]]
        if not isinstance(ns_values, list) or not ns_values:
            continue

        ns_series = _mean_over_seeds(ns_values)
        l2_series = _mean_over_seeds(l2_sq_values)
        # zip stops at the shorter series. fixed_rho (not closest_rho) is used
        # in the formula; they differ by < 0.01.
        spectral_percentages = [
            calculate_spectral_tail_percentage(ns_val, l2_val, fixed_rho, degree)
            for ns_val, l2_val in zip(ns_series, l2_series)
        ]
        if spectral_percentages:
            curves.append((label, noise_strength, spectral_percentages))

    if not curves:
        print("No spectral tail data found (need noise_stabilities at ρ=0.5 and model_l2_squared in metrics.json).")
        return

    # Plot spectral tail percentage
    fig = plt.figure(figsize=(10, 6))
    cmap = plt.cm.tab10
    for i, (label, noise_strength, spectral_percentages) in enumerate(curves):
        # Baseline (noise_strength == 0): dotted, thicker, square markers.
        is_baseline = noise_strength == 0
        marker = 's' if is_baseline else None
        plt.plot(range(1, len(spectral_percentages) + 1), spectral_percentages,
                 label=f'{label} (ρ={fixed_rho}, d≥{degree})',
                 linestyle=':' if is_baseline else '-',
                 marker=marker,
                 markersize=4 if marker else 0,
                 markevery=max(1, len(spectral_percentages) // 20),
                 linewidth=2.5 if is_baseline else 1.5,
                 alpha=0.9 if is_baseline else 0.7,
                 color=cmap(i % cmap.N))  # cycle: tab10 clamps indices >= 10

    plt.title(f'Spectral Tail Percentage (degrees ≥ {degree}) vs Epoch', fontsize=18)
    plt.xlabel('Epoch', fontsize=15)
    plt.ylabel('Spectral Tail %', fontsize=15)
    plt.legend(fontsize=10, loc='best')
    plt.grid(True)
    plt.tight_layout()

    folder_name = "./figures"
    os.makedirs(folder_name, exist_ok=True)  # don't rely on an earlier plot creating it
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    file_name = os.path.join(folder_name, f'spectral_tail_{timestamp}.png')
    plt.savefig(file_name)
    plt.close(fig)
    print(f"Saved: {file_name}")

if __name__ == "__main__":
    # Every command-line argument is a log path; a comparison needs two or more.
    log_files = sys.argv[1:]
    if len(log_files) < 2:
        print("Usage: python create_plots.py <path_to_log_file_1> <path_to_log_file_2> [path_to_log_file_3] ...", file=sys.stderr)
        sys.exit(1) # Exit with an error code

    datasets = []
    datasets_with_metrics = []

    for log_path in log_files:
        strength, rho_param = parse_log_metadata(log_path)

        # Logs without a strength header are labelled and styled as the
        # unregularized baseline (gamma = 0).
        if strength is None:
            strength = 0.0
        # Label: gamma (γ) for the strength, rho (ρ) for the parameter if present.
        label = f"γ={strength}" if rho_param is None else f"γ={strength}, ρ={rho_param}"

        df = parse_log_data(log_path)
        if df.empty:
            print(f"Warning: Skipping '{log_path}' due to missing or unparseable data.", file=sys.stderr)
            continue

        # Optional metrics JSON next to the log feeds the extra plots below.
        metrics = load_metrics_json(log_path)

        datasets.append((df, label, strength))
        datasets_with_metrics.append((df, label, strength, metrics))

    if len(datasets) < 2:
        print("Error: Need at least 2 valid log files to generate plots.", file=sys.stderr)
        sys.exit(1)

    # Loss and accuracy overlays from the parsed logs.
    plot_training_metrics_multi(datasets)

    # Metrics-based plots; each prints a message and returns early when no run
    # has the data it needs.
    plot_noise_stability(datasets_with_metrics, fixed_rho=0.5)
    plot_spectral_tail(datasets_with_metrics, fixed_rho=0.5, degree=15)
