import json
import matplotlib.pyplot as plt
import pandas as pd
import os

def load_training_metrics(results_dir="./results"):
    """Load logged training metrics from ``<results_dir>/training_metrics.json``.

    Args:
        results_dir: Directory expected to contain ``training_metrics.json``.

    Returns:
        A ``pandas.DataFrame`` built from the parsed JSON, or ``None`` if the
        metrics file does not exist. (Presumably the file holds a list of
        per-logging-step dicts, giving one row per entry — confirm against
        the writer in training.py.)

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    metrics_file = os.path.join(results_dir, "training_metrics.json")

    if not os.path.exists(metrics_file):
        print(f"Metrics file not found: {metrics_file}")
        print("Make sure training has been run with the updated training.py")
        return None

    # JSON is UTF-8 by spec; open()'s default encoding is platform-dependent
    # (e.g. cp1252 on Windows), so state it explicitly.
    with open(metrics_file, 'r', encoding='utf-8') as f:
        metrics = json.load(f)

    return pd.DataFrame(metrics)

def _plot_metric(ax, df, column, style, label, ylabel, title):
    """Draw one metric curve on *ax*; return True if anything was plotted.

    Rows where *column* is NaN are dropped first. The panel is left untouched
    when the column is absent or has no values, so ``legend()`` is never
    called on an axes without labelled artists.
    """
    if column not in df.columns:
        return False
    data = df[df[column].notna()]
    if data.empty:
        return False

    # Fall back to the row position when no 'step' column was logged,
    # instead of failing with a KeyError.
    if 'step' in data.columns:
        x, xlabel = data['step'], 'Step'
    else:
        x, xlabel = data.index, 'Log entry'

    ax.plot(x, data[column], style, label=label)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)
    ax.legend()
    return True


def plot_training_metrics(df, save_dir="./results"):
    """Plot training/evaluation metrics on a 2x2 grid and save it as a PNG.

    Panels: training loss (top-left), evaluation loss (top-right), learning
    rate (bottom-left) and mean token accuracy (bottom-right). A panel stays
    empty when its column is missing or contains only NaN.

    Args:
        df: Metrics table, e.g. from ``load_training_metrics``.
        save_dir: Directory for ``training_metrics_plot.png``; created if it
            does not exist.
    """
    # (grid position, column, line style, legend label, y-axis label, title)
    panels = [
        ((0, 0), 'loss', 'b-', 'Training Loss', 'Loss', 'Training Loss'),
        ((0, 1), 'eval_loss', 'r-', 'Evaluation Loss', 'Loss', 'Evaluation Loss'),
        ((1, 0), 'learning_rate', 'g-', 'Learning Rate', 'Learning Rate',
         'Learning Rate Schedule'),
        ((1, 1), 'mean_token_accuracy', 'm-', 'Token Accuracy', 'Accuracy',
         'Mean Token Accuracy'),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    for (row, col), column, style, label, ylabel, title in panels:
        _plot_metric(axes[row, col], df, column, style, label, ylabel, title)

    fig.tight_layout()

    # Save plot; make sure the target directory exists so savefig can't fail on it.
    os.makedirs(save_dir, exist_ok=True)
    plot_path = os.path.join(save_dir, "training_metrics_plot.png")
    fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    print(f"Training metrics plot saved to: {plot_path}")

    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

def _print_metric_summary(name, series, best):
    """Print the initial, final and best non-NaN value of *series*.

    *best* is a callable (e.g. ``pd.Series.min`` for losses, ``pd.Series.max``
    for accuracies) that selects the best value. Nothing is printed when the
    series has no non-NaN values.
    """
    values = series.dropna()
    if values.empty:
        return
    print(f"Initial {name}: {values.iloc[0]:.4f}")
    print(f"Final {name}: {values.iloc[-1]:.4f}")
    print(f"Best {name}: {best(values):.4f}")


def print_training_summary(df):
    """Print summary statistics for the logged training metrics.

    For each of ``loss``, ``eval_loss`` and ``mean_token_accuracy`` that is
    present, prints its first, last and best value. Then prints the maximum
    ``step`` and ``epoch``. Missing columns are skipped rather than raising
    ``KeyError``.
    """
    print("\n=== Training Summary ===")

    # (column, display name, how to pick the best value)
    for column, name, best in (
        ('loss', 'Training Loss', pd.Series.min),
        ('eval_loss', 'Eval Loss', pd.Series.min),
        ('mean_token_accuracy', 'Token Accuracy', pd.Series.max),
    ):
        if column in df.columns:
            _print_metric_summary(name, df[column], best)

    # Guard these like the metric columns above; absent columns used to raise KeyError.
    if 'step' in df.columns:
        print(f"Total Training Steps: {df['step'].max()}")
    if 'epoch' in df.columns:
        print(f"Total Epochs: {df['epoch'].max():.2f}")

def main(results_dir="./results"):
    """Load metrics from *results_dir*, summarize, export to CSV, and plot.

    Args:
        results_dir: Directory holding ``training_metrics.json``. The CSV
            export and the plot image are written to the same directory.
    """
    df = load_training_metrics(results_dir)
    if df is None:
        return

    print("Available columns:", df.columns.tolist())

    # Print summary
    print_training_summary(df)

    # Export the CSV before plotting. plt.show() may block until the plot
    # window is closed, and a plotting failure should not lose the export.
    csv_path = os.path.join(results_dir, "training_metrics.csv")
    df.to_csv(csv_path, index=False)
    print(f"Metrics also saved to: {csv_path}")

    # Plot metrics
    plot_training_metrics(df, save_dir=results_dir)

# Run the report only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()