#!/usr/bin/env python3
"""
plot_histograms.py

Load a predictions file and interactively plot histograms for pred_noise, noise,
and c across epochs, with left/right arrow navigation.
"""
import argparse
import torch
import numpy as np
import matplotlib.pyplot as plt

from plot_training_curve import load_predictions


def _to_flat_array(value):
    """Return a torch tensor as a flattened NumPy array on the CPU."""
    # detach() first: Tensor.numpy() raises on tensors that still require grad.
    return value.detach().cpu().numpy().ravel()


def interactive_histograms(predictions, bins=50):
    """
    Create an interactive histogram viewer for pred_noise, noise, and c.
    Use left/right arrow keys to navigate through epochs (wrapping around).

    Parameters
    ----------
    predictions : dict
        Mapping direction -> per-epoch sequence. Each epoch entry is indexed
        as ``entry[0]``, ``entry[1]``, ``entry[2]`` for pred_noise, noise and
        c respectively; each of those must support
        ``.detach().cpu().numpy()`` (i.e. a torch tensor).
    bins : int
        Number of bins for the histograms.

    Raises
    ------
    ValueError
        If ``predictions`` is empty or any direction has zero epochs.

    Notes
    -----
    Only the first ``min(len(v) for v in predictions.values())`` epochs are
    navigable; extra epochs in longer directions are not shown.
    """
    directions = list(predictions.keys())
    if not directions:
        raise ValueError("predictions is empty; nothing to plot")
    n_dirs = len(directions)

    # Only epochs available for every direction are navigable.
    n_epochs = min(len(predictions[d]) for d in directions)
    if n_epochs == 0:
        # Without this guard the arrow-key handler would divide by zero.
        raise ValueError("at least one direction has no epochs; nothing to plot")

    # Column order matches the entry layout: entry[0], entry[1], entry[2].
    quantities = ("pred_noise", "noise", "c")
    epoch = 0

    # squeeze=False always yields a 2-D (n_dirs, 3) array of Axes, even for n_dirs == 1.
    fig, axes = plt.subplots(
        n_dirs, len(quantities), figsize=(15, 5 * n_dirs), squeeze=False
    )

    def plot_epoch(idx):
        """Redraw every subplot with the histograms for epoch *idx*."""
        for ax in axes.flat:
            ax.clear()

        for row, direction in zip(axes, directions):
            try:
                entry = predictions[direction][idx]
            except (IndexError, KeyError):
                # Best-effort: leave this direction's row empty if the epoch is missing.
                continue
            for col, (ax, name) in enumerate(zip(row, quantities)):
                ax.hist(_to_flat_array(entry[col]), bins=bins)
                ax.set_title(f"{direction} {name} (epoch {idx})")
                ax.set_xlabel("Value")
                ax.set_ylabel("Frequency")

        fig.suptitle(f"Epoch {idx}", fontsize=16)
        # Leave room at the top for the suptitle.
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        fig.canvas.draw_idle()

    def on_key(event):
        """Step one epoch forward (right arrow) or backward (left arrow), wrapping."""
        nonlocal epoch
        step = {"right": 1, "left": -1}.get(event.key)
        if step is None:
            return
        epoch = (epoch + step) % n_epochs
        plot_epoch(epoch)

    # Connect key press events
    fig.canvas.mpl_connect("key_press_event", on_key)

    # Initial display
    plot_epoch(epoch)
    plt.show()


def main():
    """Parse command-line arguments, load predictions, and open the histogram viewer."""
    parser = argparse.ArgumentParser(
        description="Interactive histograms for predictions"
    )
    parser.add_argument(
        "--predictions",
        type=str,
        required=True,
        help="Folder containing predictions file (PyTorch .pt or .pth)",
    )
    parser.add_argument("--bins", type=int, default=50, help="Number of histogram bins")
    args = parser.parse_args()

    # Reject non-positive bin counts up front with a usage message (exit code 2)
    # rather than loading data and failing later during plotting.
    if args.bins < 1:
        parser.error("--bins must be a positive integer")

    predictions = load_predictions(args.predictions)
    interactive_histograms(predictions, bins=args.bins)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
