import matplotlib.pyplot as plt


def plot_latent_grid_2d(grid, dim1_name=None, dim2_name=None, figsize=(10, 10)):
    """
    Plot a grid of images generated by varying latent dimensions.

    Cell ``grid[i, j]`` is drawn in row ``i`` / column ``j`` of a subplot
    grid, with axes hidden and no spacing between cells.

    Args:
        grid: Tensor of shape (n_rows, n_cols, C, H, W) containing generated
            images (e.g. n_rows == n_cols == n_steps). Single-channel images
            are repeated to RGB. If any value lies outside [0, 1], the whole
            grid is min-max rescaled to [0, 1] (one global min/max, so
            relative brightness between cells is preserved).
        dim1_name: Label drawn centred along the bottom edge (optional)
        dim2_name: Label drawn rotated along the left edge (optional)
        figsize: Size of the figure (width, height) in inches

    Returns:
        fig: matplotlib figure object that can be logged to wandb
    """
    # Detach from autograd, move to host and convert to float once, up front.
    # .numpy() raises on tensors that require grad and does not support
    # bfloat16. imshow would also read integer RGB data as 0-255.
    grid = grid.detach().cpu().float()
    n_rows, n_cols = grid.shape[0], grid.shape[1]

    # If single channel, repeat to make it RGB
    if grid.shape[2] == 1:
        grid = grid.repeat(1, 1, 3, 1, 1)

    # Normalize to [0, 1] if needed
    lo, hi = grid.min(), grid.max()
    if hi > 1.0 or lo < 0.0:
        span = hi - lo
        # A constant out-of-range grid would otherwise divide 0/0 -> NaN.
        grid = (grid - lo) / span if span > 0 else grid - lo

    # squeeze=False keeps `axes` 2-D even when n_rows or n_cols is 1, so the
    # axes[i, j] indexing below works for every grid shape.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    for i in range(n_rows):
        for j in range(n_cols):
            ax = axes[i, j]
            ax.imshow(grid[i, j].permute(1, 2, 0).numpy())
            ax.axis('off')

    # NOTE(review): dim1_name is placed along the horizontal (column) edge and
    # dim2_name along the vertical (row) edge, while rows are indexed by
    # grid's first axis -- confirm this matches how `grid` is generated.
    margin = 0.04  # figure fraction reserved on a side that carries a label
    if dim1_name:
        fig.text(0.5, 0.02, f'{dim1_name}', ha='center', va='center')
    if dim2_name:
        fig.text(0.02, 0.5, f'{dim2_name}', ha='center', va='center', rotation=90)

    # Explicit layout instead of tight_layout(pad=0). tight_layout ignores
    # fig.text artists, so the labels were drawn over the edge images, and it
    # overrode the preceding wspace/hspace settings anyway. Cells touch; a
    # margin is kept only where a label is drawn.
    fig.subplots_adjust(
        left=margin if dim2_name else 0.0,
        right=1.0,
        bottom=margin if dim1_name else 0.0,
        top=1.0,
        wspace=0,
        hspace=0,
    )
    fig.patch.set_visible(False)

    return fig

