import torch
import matplotlib.pyplot as plt
import numpy as np
from itertools import islice

# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def slice_dict(d: dict, n: int) -> dict:
    """
    Return a new dictionary holding only the leading `n` entries of `d`.

    Entries are taken in the dictionary's iteration (insertion) order, and the
    input dictionary is left untouched.

    Args:
        d (dict): Source dictionary.
        n (int): How many leading entries to keep. When `n` is larger than
            `len(d)`, every entry is kept.

    Returns:
        dict: A fresh dictionary with at most `n` entries copied from `d`.
    """
    # islice walks the items lazily, so only the first `n` pairs are ever visited.
    leading_pairs = islice(d.items(), n)
    head = dict(leading_pairs)
    return head


def visualise_learned_scalars(dict_scalars: dict, two_over_L: float, max_points: int = 100):
    """
    Plot learned scalar values over iterations against a constant 2/L reference line.

    Args:
        dict_scalars (dict): Dictionary whose keys are numeric iteration indices and whose
            values are the learned scalars. Keys are shifted by +1 on the x-axis, so
            iteration 0 is drawn at t = 1. Values must be convertible with ``float()``.
        two_over_L (float): Reference value (2/L), drawn as a dashed red horizontal line.
        max_points (int, optional): Only the first ``max_points`` entries, in the
            dictionary's iteration order, are plotted. Defaults to 100.

    Returns:
        None: Displays the plot.
    """
    # Keep only the leading entries so the curve stays readable.
    dict_scalars = slice_dict(dict_scalars, max_points)

    # x-axis: iteration keys shifted by one. y-axis: plain Python floats; float() also
    # unwraps single-element tensors (e.g. ones that still require grad), which
    # matplotlib cannot convert on its own.
    iterations = 1 + np.array(list(dict_scalars.keys()))
    scalars = [float(value) for value in dict_scalars.values()]

    # 'seaborn-whitegrid' was renamed to 'seaborn-v0_8-whitegrid' in matplotlib 3.6 and
    # the old name was removed in 3.8; use whichever name this installation provides.
    style_name = 'seaborn-v0_8-whitegrid'
    if style_name not in plt.style.available:
        style_name = 'seaborn-whitegrid'
    plt.style.use(style_name)
    plt.figure(figsize=(8, 6))

    # Learned scalars as a curve, plus the constant 2/L reference line.
    plt.plot(iterations, scalars, linewidth=2, label=r'Learned $\alpha_t$')
    plt.axhline(y=two_over_L, color='r', linestyle='--', label=r'$2/L$', linewidth=4)

    # Axis labels and tick label sizes (major and minor ticks alike).
    plt.xlabel(r'$t$', fontsize=24)
    plt.ylabel(r'Learned $\alpha_t$', fontsize=24)
    plt.tick_params(axis='both', which='both', labelsize=24)

    # Add the legend before tight_layout so the layout pass sees the final set of artists.
    plt.legend(fontsize=20)
    plt.tight_layout()

    plt.show()


def visualise_learned_matrices(dict_matrices: dict, indices='all'):
    """
    Visualize a sequence of learned matrices as side-by-side heatmaps with a shared colorbar.

    All heatmaps share one colour scale (the global min/max over the selected matrices),
    so colours are comparable across iterations.

    Args:
        dict_matrices (dict): Dictionary of learned matrices, where keys are iteration
            numbers and values are matrices (torch.Tensor). Each tensor is squeezed before
            plotting, so singleton dimensions are allowed.
        indices (str or iterable, optional): Keys of `dict_matrices` to visualize, in the
            order they should appear. If 'all', every key of the dictionary is used, in the
            dictionary's iteration order. Defaults to 'all'.

    Returns:
        None: Displays the plot with heatmaps of the matrices.

    Raises:
        ValueError: If there are no matrices to plot (empty selection).
        KeyError: If a requested index is not a key of `dict_matrices`.
    """
    # 'all' selects the dictionary's actual keys (iteration numbers need not be 0..n-1).
    # The isinstance check avoids an elementwise comparison when `indices` is an array.
    if isinstance(indices, str) and indices == 'all':
        indices = list(dict_matrices.keys())
    else:
        # Materialize once so generators work and len() below is well-defined.
        indices = list(indices)

    if not indices:
        raise ValueError('visualise_learned_matrices: no matrices to plot (empty indices)')

    # One subplot per matrix plus a final one that hosts the colorbar.
    num_subplots = len(indices) + 1

    # Set the size of the figure based on the number of heatmaps.
    plt.figure(figsize=(11 * len(indices), 10))

    # Global min and max so every heatmap uses the same colour scale.
    vmin = min(torch.min(dict_matrices[i]).item() for i in indices)
    vmax = max(torch.max(dict_matrices[i]).item() for i in indices)

    # Adjust spacing between subplots.
    plt.subplots_adjust(wspace=0.05, hspace=0)

    image = None
    for position, i in enumerate(indices, start=1):
        plt.subplot(1, num_subplots, position)
        # Detach from the autograd graph and move to CPU before converting to NumPy.
        matrix = dict_matrices[i].detach().cpu().squeeze().numpy()
        image = plt.imshow(matrix, vmin=vmin, vmax=vmax, cmap='jet', aspect='auto')
        plt.title(f'Iteration {i}', fontsize=80)
        plt.axis('off')  # Heatmaps are shown without axes.

    # Last subplot hosts the colorbar; pass the mappable explicitly rather than relying
    # on pyplot's "current image".
    cax = plt.subplot(1, num_subplots, num_subplots)
    cbar = plt.colorbar(image, cax=cax)

    # Shrink the colorbar subplot horizontally to 10% of its default width.
    cax_position = cax.get_position()
    cax_width = cax_position.width * 0.1  # Adjust this value to change the width of the colorbar
    cax.set_position([cax_position.x0, cax_position.y0, cax_width, cax_position.height])

    # Customize the tick size of the colorbar.
    cbar.ax.tick_params(labelsize=80)

    # Show the entire plot with all subplots.
    plt.show()

    