import torch
import matplotlib.pyplot as plt
import os

def load_and_display_metrics(metrics_path):
    """
    Load a PyTorch metrics file and display its contents
    
    Parameters
    ----------
    metrics_path : str
        Path to the metrics file (.pt)
    """
    # Check if file exists
    if not os.path.exists(metrics_path):
        print(f"Error: File '{metrics_path}' not found.")
        return
    
    # Load the metrics file
    metrics = torch.load(metrics_path)
    
    # Display the metrics information
    print("Training Metrics Summary:")
    print("-" * 50)
    
    # Display all available keys in the metrics file
    print(f"Metrics contains the following data: {list(metrics.keys())}")
    
    # Print general metrics
    if 'current_epoch' in metrics:
        print(f"Current Epoch: {metrics['current_epoch']}")
    if 'total_epochs' in metrics:
        print(f"Total Epochs: {metrics['total_epochs']}")
    if 'total_time' in metrics:
        print(f"Total Training Time: {metrics['total_time']:.2f} seconds")
    if 'average_epoch_time' in metrics:
        print(f"Average Epoch Time: {metrics['average_epoch_time']:.2f} seconds")
    
    # Plot loss history if available
    if 'loss_history' in metrics:
        loss_history = metrics['loss_history']
        plt.figure(figsize=(10, 6))
        plt.plot(range(1, len(loss_history) + 1), loss_history)
        plt.title('Training Loss History')
        plt.xlabel('Epoch')
        plt.ylabel('Loss (SSIM)')
        plt.grid(True)
        plt.tight_layout()
        
        # Save the plot
        plot_path = os.path.join(os.path.dirname(metrics_path), 'loss_history_plot.png')
        plt.savefig(plot_path)
        print(f"Loss history plot saved to: {plot_path}")
        
        # Display the plot
        plt.show()
        
        # Print final loss
        print(f"Final Loss: {loss_history[-1]:.6f}")
        print(f"Best Loss: {min(loss_history):.6f} (Epoch {loss_history.index(min(loss_history)) + 1})")
    
    # Plot epoch times if available
    if 'epoch_times' in metrics:
        epoch_times = metrics['epoch_times']
        plt.figure(figsize=(10, 6))
        plt.plot(range(1, len(epoch_times) + 1), epoch_times)
        plt.title('Epoch Completion Times')
        plt.xlabel('Epoch')
        plt.ylabel('Time (seconds)')
        plt.grid(True)
        plt.tight_layout()
        
        # Save the plot
        plot_path = os.path.join(os.path.dirname(metrics_path), 'epoch_times_plot.png')
        plt.savefig(plot_path)
        print(f"Epoch times plot saved to: {plot_path}")
        
        # Display the plot
        plt.show()
    
    return metrics

# Path to the metrics file to inspect. Other runs that can be swapped in:
#   /home/dan5/optics_recon/Optics_Recon_Project/vit_param/metrics_epoch49_lr5e-05_batch80_ssim.pt
#   /home/dan5/optics_recon/Optics_Recon_Project/unet_param/metrics_epoch47_lr0.0001_batch80_ssim.pt
metrics_path = '/home/dan5/optics_recon/Optics_Recon_Project/transunet_param/metrics_final_lr2e-05_batch32_ssim.pt'

# Only load and plot when run as a script. Importing this module to reuse
# load_and_display_metrics then does not read files or open plot windows.
if __name__ == "__main__":
    metrics = load_and_display_metrics(metrics_path)