import os
import numpy as np
import json
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

def _read_loss_history(stats_file):
    """Read the (kimg, loss) training history from a stats JSONL file.

    Each non-blank line must be a JSON object containing
    ``data["Progress/kimg"]["mean"]`` and ``data["Loss/loss"]["mean"]``.

    Args:
        stats_file (str): Path to the stats.jsonl file.

    Returns:
        tuple[list, list]: kimg values and the matching loss values, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the required entries.
    """
    kimg_list = []
    loss_list = []
    with open(stats_file, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines; json.loads("") would raise.
            if not line.strip():
                continue
            data = json.loads(line)
            kimg_list.append(data["Progress/kimg"]["mean"])
            loss_list.append(data["Loss/loss"]["mean"])
    return kimg_list, loss_list


def _plot_full_curve(kimg_list, loss_list, save_path):
    """Plot the whole loss curve, optionally save it as PNG, then show it."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(kimg_list, loss_list, label="Training Loss")
    ax.set_xlabel("KImg (Thousands of Images Seen)")
    ax.set_ylabel("Loss")
    ax.set_title("Training Loss Curve")
    ax.legend()
    # Explicitly enable the grid; a bare grid() call toggles it instead.
    ax.grid(True)
    if save_path:
        png_path = os.path.join(save_path, "loss_kimg.png")
        fig.savefig(png_path)
        print(f"Plot saved to {png_path}")
    plt.show()
    # With non-interactive backends show() does nothing and the figure would
    # stay registered with pyplot; close it so its memory is released.
    plt.close(fig)


def _select_zoom_indices(loss_list, num_regions):
    """Pick the steps where the loss changes the most and the least.

    Index ``i`` denotes the change from ``loss_list[i]`` to ``loss_list[i + 1]``.

    Returns:
        list[int] | None: Sorted, de-duplicated indices of the ``num_regions``
        largest and ``num_regions`` smallest absolute changes, or ``None`` when
        there are fewer than ``num_regions`` changes to rank.
    """
    loss_diff = np.abs(np.diff(loss_list))
    if len(loss_diff) < num_regions:
        return None
    # np.argsort places NaN last, so NaN steps are ranked as the largest changes.
    order = np.argsort(loss_diff)
    high_change = order[len(order) - num_regions:]
    plateaus = order[:num_regions]
    return sorted(set(high_change.tolist()) | set(plateaus.tolist()))


def _save_zoomed_views(kimg_list, loss_list, zoom_areas, window, pdf_name):
    """Write one zoomed-in loss plot per index in ``zoom_areas`` to a multi-page PDF."""
    with PdfPages(pdf_name) as pdf:
        for page, idx in enumerate(zoom_areas, start=1):
            # The change at idx spans points idx and idx + 1; keep `window`
            # points on each side of that pair, clipped to the data bounds.
            start = max(0, idx - window)
            end = min(len(kimg_list), idx + window + 2)

            fig, ax = plt.subplots(figsize=(8, 5))
            try:
                ax.plot(
                    kimg_list[start:end],
                    loss_list[start:end],
                    label=f"Zoom {page}",
                    marker="o",
                    markersize=3,
                )
                ax.set_xlabel("KImg")
                ax.set_ylabel("Loss")
                ax.set_title(f"Training Loss (Zoomed View {page})")
                ax.legend()
                ax.grid(True)
                pdf.savefig(fig)
            finally:
                plt.close(fig)


def plot_loss_curve(stats_file, save_path=None, *, num_regions=5, window=5):
    """
    Plots the training loss curve from a stats JSONL file.

    Shows (and optionally saves) the full loss-vs-kimg curve, then writes a
    multi-page PDF of zoomed-in views around the ``num_regions`` largest and
    ``num_regions`` smallest step-to-step loss changes.

    Args:
        stats_file (str): Path to the stats.jsonl file containing training progress.
        save_path (str, optional): Directory in which to save ``loss_kimg.png``
            and ``loss_zoomed_views.pdf``; created if missing. When None, the PNG
            is not saved and the PDF is written to the current working directory.
        num_regions (int): How many highest-change and how many lowest-change
            steps to zoom into. Defaults to 5.
        window (int): Number of extra points shown on each side of a zoomed
            change. Defaults to 5.

    Raises:
        ValueError: If ``num_regions`` < 1 or ``window`` < 0.
        OSError, json.JSONDecodeError, KeyError: From reading ``stats_file``.
    """
    if num_regions < 1:
        raise ValueError(f"num_regions must be >= 1, got {num_regions}")
    if window < 0:
        raise ValueError(f"window must be >= 0, got {window}")

    kimg_list, loss_list = _read_loss_history(stats_file)

    if save_path:
        os.makedirs(save_path, exist_ok=True)

    _plot_full_curve(kimg_list, loss_list, save_path)

    zoom_areas = _select_zoom_indices(loss_list, num_regions)
    if zoom_areas is None:
        print("Not enough data points for zoomed-in views.")
        return

    pdf_name = os.path.join(save_path, "loss_zoomed_views.pdf") if save_path else "loss_zoomed_views.pdf"
    _save_zoomed_views(kimg_list, loss_list, zoom_areas, window, pdf_name)
    print(f"Zoomed-in plots saved as {pdf_name}")

if __name__ == "__main__":
    # Command-line entry point: plot the loss curve for one stats.jsonl file.
    import argparse

    parser = argparse.ArgumentParser(description="Plot training loss curve from stats.jsonl.")
    parser.add_argument("stats_file", type=str, help="Path to stats.jsonl file.")
    parser.add_argument(
        "--save_path",
        type=str,
        default=None,
        # The value is used as a directory: both output files are joined onto it.
        help=(
            "Directory in which to save loss_kimg.png and loss_zoomed_views.pdf "
            "(if omitted, the PNG is not saved and the PDF goes to the current directory)."
        ),
    )

    args = parser.parse_args()
    plot_loss_curve(args.stats_file, save_path=args.save_path)