import matplotlib.pyplot as plt
import json
import sys
import pandas as pd
from datetime import datetime
import os

def _load_metrics_json(json_file_path):
    """Return the parsed JSON at *json_file_path*; print an error and exit(1) on failure."""
    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: JSON file not found at '{json_file_path}'. Please check the path.", file=sys.stderr)
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from '{json_file_path}'. Ensure it's a valid JSON file.", file=sys.stderr)
    except Exception as e:  # CLI boundary: e.g. PermissionError, IsADirectoryError, UnicodeDecodeError
        print(f"An unexpected error occurred while reading the JSON file: {e}", file=sys.stderr)
    # Only reached when one of the handlers above ran.
    sys.exit(1)


def _expand_stability_values(raw_values, repeat, max_length):
    """Repeat each value *repeat* times in order, truncated to at most *max_length* points."""
    expanded = [value for value in raw_values for _ in range(repeat)]
    return expanded[:max_length]


def plot_json_metrics_overlayed(json_file_path, *, output_dir="./figures", stability_repeat=2):
    """
    Plot validation loss and noise stabilities from a JSON metrics file on one figure.

    Validation loss is drawn on the left y-axis. Every noise-stability series shares
    a right-hand twin y-axis. All series are plotted against 1-based epoch numbers.
    Each noise-stability value is repeated ``stability_repeat`` times so the series
    lines up with the per-epoch validation loss. The default of 2 presumably means
    stability was recorded every 2 epochs (TODO confirm against the training code).

    JSON layout read by this function::

        {"val_losses": [[loss_1, loss_2, ...]],
         "noise_stabilities": {"<level>": [[s_1, s_2, ...]], ...}}

    Args:
        json_file_path: Path to the JSON metrics file.
        output_dir: Directory the PNG is written to. It is created if missing.
        stability_repeat: Number of epochs each noise-stability value spans (>= 1).

    Returns:
        The path of the saved PNG (``<output_dir>/json_losses_<timestamp>.png``).

    Raises:
        ValueError: If ``stability_repeat`` is less than 1.

    Exits:
        Prints to stderr and calls ``sys.exit(1)`` if the file cannot be read or
        parsed, or if ``val_losses`` / ``noise_stabilities`` are missing or malformed.
        Malformed individual noise levels are skipped with a warning.
    """
    if stability_repeat < 1:
        raise ValueError("stability_repeat must be >= 1")

    data = _load_metrics_json(json_file_path)

    # A top-level JSON array/scalar has no keys; treat it as "data missing".
    metrics = data if isinstance(data, dict) else {}
    val_losses = metrics.get('val_losses')
    noise_stabilities = metrics.get('noise_stabilities')

    if not (isinstance(val_losses, list) and val_losses
            and isinstance(val_losses[0], list) and val_losses[0]):
        print("Error: 'val_losses' data is missing or in an unexpected format.", file=sys.stderr)
        sys.exit(1)

    if not (isinstance(noise_stabilities, dict) and noise_stabilities):
        print("Error: 'noise_stabilities' data is missing or in an unexpected format.", file=sys.stderr)
        sys.exit(1)

    # Validation losses define the full epoch range (epochs are numbered from 1).
    val_losses_list = val_losses[0]
    full_epochs_count = len(val_losses_list)
    epochs = list(range(1, full_epochs_count + 1))

    # NOTE: plt.style.use mutates global rcParams for the rest of the process.
    plt.style.use('seaborn-v0_8-darkgrid')

    # Primary y-axis: validation loss.
    fig, ax1 = plt.subplots(figsize=(14, 7), dpi=300)
    ax1.plot(epochs, val_losses_list, label='Validation Loss', color='blue', linestyle='-', marker='o', markersize=4, alpha=0.8)
    ax1.set_xlabel('Epoch', fontsize=15)
    ax1.set_ylabel('Validation Loss', color='blue', fontsize=15)
    ax1.tick_params(axis='y', labelcolor='blue')
    ax1.grid(True)

    # Secondary y-axis: one line per noise level.
    ax2 = ax1.twinx()
    # The style also turns on a grid for the twin axis. Its lines would not align with
    # ax1's ticks, so only the primary-axis grid is kept.
    ax2.grid(False)
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is the supported form.
    color_cycle = plt.get_cmap('Dark2', len(noise_stabilities))

    for i, (noise_level, stability_data_list) in enumerate(noise_stabilities.items()):
        if not (isinstance(stability_data_list, list) and stability_data_list
                and isinstance(stability_data_list[0], list) and stability_data_list[0]):
            print(f"Warning: Noise stability data for level '{noise_level}' is missing or in an unexpected format. Skipping.", file=sys.stderr)
            continue

        stability_values = _expand_stability_values(stability_data_list[0], stability_repeat, full_epochs_count)
        if len(stability_values) < full_epochs_count:
            print(f"Warning: Noise stability data for level '{noise_level}' covers only "
                  f"{len(stability_values)} of {full_epochs_count} epochs.", file=sys.stderr)

        # Plot only the epochs the series covers. Otherwise matplotlib rejects
        # mismatched x/y lengths.
        ax2.plot(epochs[:len(stability_values)], stability_values,
                 label=f'Noise Stability (level: {noise_level})',
                 linestyle='--',
                 marker='x',
                 markersize=4,
                 alpha=0.8,
                 color=color_cycle(i))

    ax2.set_ylabel('Noise Stability', color='firebrick', fontsize=15)
    ax2.tick_params(axis='y', labelcolor='firebrick', labelsize=15)

    # Merge handles from both axes into a single legend.
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines1 + lines2, labels1 + labels2, loc='best')

    ax1.set_title('Validation Loss and Noise Stabilities Over Epochs', fontsize=18)
    fig.tight_layout()

    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    file_name = os.path.join(output_dir, f'json_losses_{timestamp}.png')
    fig.savefig(file_name)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    return file_name

if __name__ == "__main__":
    # CLI entry point: expects exactly one argument, the path to the JSON metrics file.
    if len(sys.argv) != 2:
        # Show the actual script name instead of a hard-coded placeholder.
        script_name = os.path.basename(sys.argv[0]) or "script.py"
        print(f"Usage: python {script_name} <path_to_json_file>", file=sys.stderr)
        sys.exit(1)

    plot_json_metrics_overlayed(sys.argv[1])