import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
from pathlib import Path

def read_loss_file(filepath):
    """Parse a whitespace-separated loss log into a DataFrame.

    Every line that splits into exactly three whitespace-separated fields is
    read as ``epoch train_loss val_loss`` (epoch as ``int``, losses as
    ``float``); lines with any other number of fields are ignored. A line with
    three fields that are not numeric raises ``ValueError``.

    Args:
        filepath: Path (str or ``pathlib.Path``) of the loss file to read.

    Returns:
        A DataFrame with columns ``epoch``, ``train_loss`` and ``val_loss``,
        one row per accepted line, in file order.
    """
    rows = []
    with open(filepath, 'r') as handle:
        for raw in handle:
            fields = raw.split()
            if len(fields) != 3:
                continue
            ep, tr, va = fields
            rows.append([int(ep), float(tr), float(va)])

    return pd.DataFrame(rows, columns=['epoch', 'train_loss', 'val_loss'])

def plot_loss_curves(file_pattern="loss_values_0.000001_*.txt", save_plot=True,
                     title='Training and Validation Loss Curves (Learning Rate: 1e-6)'):
    """
    Plot training and validation loss curves from multiple runs.

    Every file in the current directory matching ``file_pattern`` is drawn as a
    pair of lines: solid for train and dashed for validation, in a shared colour.
    When more than one run is found, the per-epoch mean across runs is overlaid
    with a +/- 1 standard deviation band (``np.std``, ddof=0). The y-axis uses a
    log scale.

    Args:
        file_pattern: Glob pattern, relative to the current directory, that
            matches the loss files.
        save_plot: Whether to save the plot as 'loss_curves.png' and
            'loss_curves.pdf'.
        title: Plot title. The default describes the default ``file_pattern``
            (learning rate 1e-6); pass a matching title when plotting other runs.
    """
    # Sort so runs (and therefore their colours) are in a stable order.
    files = sorted(Path('.').glob(file_pattern))

    if not files:
        print(f"No files found matching pattern: {file_pattern}")
        return

    plt.figure(figsize=(12, 8))

    # One colour per run, sampled evenly from the tab10 colormap.
    colors = plt.cm.tab10(np.linspace(0, 1, len(files)))

    all_epochs = []
    all_train_losses = []
    all_val_losses = []

    for i, filepath in enumerate(files):
        df = read_loss_file(filepath)

        # The run identifier is the last '_'-separated token of the file stem.
        run_num = filepath.stem.split('_')[-1]

        plt.plot(df['epoch'], df['train_loss'],
                 color=colors[i], linestyle='-', alpha=0.7,
                 label=f'Train Run {run_num}')
        plt.plot(df['epoch'], df['val_loss'],
                 color=colors[i], linestyle='--', alpha=0.7,
                 label=f'Val Run {run_num}')

        all_epochs.append(df['epoch'].to_numpy())
        all_train_losses.append(df['train_loss'].to_numpy())
        all_val_losses.append(df['val_loss'].to_numpy())

    if len(all_train_losses) > 1:
        # Truncate every run to the shortest one so the arrays can be stacked.
        # This assumes all runs log the same epoch sequence from the same start.
        min_length = min(len(arr) for arr in all_train_losses)
        train_losses_array = np.array([arr[:min_length] for arr in all_train_losses])
        val_losses_array = np.array([arr[:min_length] for arr in all_val_losses])

        # Use the epoch numbers recorded in the files rather than 0..N-1, so the
        # averages line up with the per-run curves even when epochs start at 1
        # or are logged at a stride.
        epochs = all_epochs[0][:min_length]

        train_mean = np.mean(train_losses_array, axis=0)
        train_std = np.std(train_losses_array, axis=0)
        val_mean = np.mean(val_losses_array, axis=0)
        val_std = np.std(val_losses_array, axis=0)

        plt.plot(epochs, train_mean, color='black', linewidth=2,
                 label='Train Average', linestyle='-')
        plt.fill_between(epochs, train_mean - train_std, train_mean + train_std,
                         color='black', alpha=0.2)

        plt.plot(epochs, val_mean, color='red', linewidth=2,
                 label='Val Average', linestyle='--')
        plt.fill_between(epochs, val_mean - val_std, val_mean + val_std,
                         color='red', alpha=0.2)

    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.title(title, fontsize=14)
    # The legend sits outside the axes on the right; with many runs it is long.
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True, alpha=0.3)

    # A log scale keeps early (large) and late (small) losses readable together.
    plt.yscale('log')

    plt.tight_layout()

    if save_plot:
        # bbox_inches='tight' makes the saved files include the outside legend.
        plt.savefig('loss_curves.png', dpi=300, bbox_inches='tight')
        plt.savefig('loss_curves.pdf', bbox_inches='tight')
        print("Plots saved as 'loss_curves.png' and 'loss_curves.pdf'")

    plt.show()

def plot_individual_runs(file_pattern="loss_values_0.000001_*.txt"):
    """Draw each run in its own subplot on a grid at most three columns wide.

    Matching files (in the current directory, sorted by name) each get one
    panel with training loss (solid blue) and validation loss (dashed red) on
    a log y-axis. Unused grid cells are hidden. The figure is always saved to
    'individual_loss_curves.png' before being shown.

    Args:
        file_pattern: Glob pattern, relative to the current directory, that
            matches the loss files.
    """
    matches = sorted(Path('.').glob(file_pattern))

    if not matches:
        print(f"No files found matching pattern: {file_pattern}")
        return

    count = len(matches)
    ncols = min(3, count)
    nrows = -(-count // ncols)  # ceiling division

    # squeeze=False always yields a 2-D array, so one ravel() covers every shape.
    fig, grid = plt.subplots(nrows, ncols, figsize=(15, 5 * nrows), squeeze=False)
    panels = grid.ravel()

    for panel, path in zip(panels, matches):
        frame = read_loss_file(path)
        label = path.stem.split('_')[-1]

        panel.plot(frame['epoch'], frame['train_loss'], 'b-', label='Training Loss', alpha=0.8)
        panel.plot(frame['epoch'], frame['val_loss'], 'r--', label='Validation Loss', alpha=0.8)

        panel.set_xlabel('Epoch')
        panel.set_ylabel('Loss')
        panel.set_title(f'Run {label}')
        panel.legend()
        panel.grid(True, alpha=0.3)
        panel.set_yscale('log')

    # Cells past the last run stay empty; hide them.
    for spare in panels[count:]:
        spare.set_visible(False)

    plt.tight_layout()
    plt.savefig('individual_loss_curves.png', dpi=300, bbox_inches='tight')
    print("Individual plots saved as 'individual_loss_curves.png'")
    plt.show()

def print_summary_stats(file_pattern="loss_values_0.000001_*.txt"):
    """Print summary statistics for all runs.

    For each file in the current directory matching ``file_pattern`` (sorted by
    name), print the final and the minimum ("best") training/validation loss.
    Files with no parsable rows are reported and skipped. When more than one
    run has data, also print the mean +/- population standard deviation
    (``np.std``, ddof=0) of those values across runs.

    Args:
        file_pattern: Glob pattern, relative to the current directory, that
            matches the loss files.
    """
    files = sorted(Path('.').glob(file_pattern))

    # Match the other entry points instead of printing an empty report.
    if not files:
        print(f"No files found matching pattern: {file_pattern}")
        return

    print("="*60)
    print("SUMMARY STATISTICS")
    print("="*60)

    all_final_train = []
    all_final_val = []
    all_min_train = []
    all_min_val = []

    for filepath in files:
        df = read_loss_file(filepath)
        run_num = filepath.stem.split('_')[-1]

        if df.empty:
            # .iloc[-1] would raise IndexError on a file with no parsable rows.
            print(f"Run {run_num}: no loss data found, skipping")
            print()
            continue

        final_train = df['train_loss'].iloc[-1]
        final_val = df['val_loss'].iloc[-1]
        min_train = df['train_loss'].min()
        min_val = df['val_loss'].min()

        all_final_train.append(final_train)
        all_final_val.append(final_val)
        all_min_train.append(min_train)
        all_min_val.append(min_val)

        print(f"Run {run_num}:")
        print(f"  Final - Train: {final_train:.4f}, Val: {final_val:.4f}")
        print(f"  Best  - Train: {min_train:.4f}, Val: {min_val:.4f}")
        print()

    # Count runs that contributed data, not files matched.
    if len(all_final_train) > 1:
        print("AVERAGE ACROSS RUNS:")
        print(f"  Final - Train: {np.mean(all_final_train):.4f} ± {np.std(all_final_train):.4f}")
        print(f"          Val:   {np.mean(all_final_val):.4f} ± {np.std(all_final_val):.4f}")
        print(f"  Best  - Train: {np.mean(all_min_train):.4f} ± {np.std(all_min_train):.4f}")
        print(f"          Val:   {np.mean(all_min_val):.4f} ± {np.std(all_min_val):.4f}")

if __name__ == "__main__":
    # Produce the three reports in order:
    # combined overlay plot, per-run subplot grid, then the text summary.
    for report in (plot_loss_curves, plot_individual_runs, print_summary_stats):
        report()