
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def create_mappings(seed=0):
    """
    Randomly partition all 65536 (R, G) byte pairs into 4096 bins of 16 colors.

    Only the red and green channels take part in the mapping: a color
    ``(p, q)`` is encoded as the 16-bit key ``(p << 8) + q``. The keys are
    shuffled with NumPy's global RNG seeded by ``seed``, and each consecutive
    run of 16 shuffled keys forms one bin.

    Side effects:
        Reseeds NumPy's global RNG via ``np.random.seed(seed)`` and writes four
        files to the current working directory:

        - ``indices_rand2num.npy``: key stored at each shuffled position.
        - ``indices_num2rand.npy``: inverse permutation (shuffled position of
          each key).
        - ``color_to_bin.npy``: same array as the first return value.
        - ``bin_to_colors.npy``: same array as the second return value.

    Args:
        seed: Seed passed to ``np.random.seed`` for reproducibility.

    Returns:
        tuple[np.ndarray, np.ndarray]:
            - ``color_to_bin``: shape ``(65536,)``, integer; maps a key
              ``(p << 8) + q`` to its bin index in ``[0, 4095]``.
            - ``bin_to_colors``: shape ``(4096, 16, 2)``, ``int32``; the
              ``(p, q)`` pairs belonging to each bin.
    """
    num_colors = 256 * 256
    colors_per_bin = 16
    num_bins = num_colors // colors_per_bin  # 4096

    # Seed for reproducibility (global RNG, as callers may rely on).
    np.random.seed(seed)

    # All (p, q) pairs in row-major order, so row k holds (k >> 8, k & 255).
    colors = np.array(np.meshgrid(np.arange(256), np.arange(256))).T.reshape(-1, 2)

    # Shuffle the keys and apply the same permutation to the colors.
    indices = np.arange(num_colors)
    np.random.shuffle(indices)
    colors = colors[indices]
    np.save('indices_rand2num', indices.astype(int))

    # Inverse permutation: indices_num2rand[indices[k]] == k.
    indices_num2rand = np.empty_like(indices)
    indices_num2rand[indices] = np.arange(num_colors)
    np.save('indices_num2rand', indices_num2rand.astype(int))

    # Consecutive groups of 16 shuffled colors share one bin.
    bins = np.arange(num_bins).repeat(colors_per_bin)
    keys = (colors[:, 0] << 8) + colors[:, 1]
    color_to_bin = np.empty(num_colors, dtype=int)
    # keys is a permutation of 0..65535, so every slot is written exactly once.
    color_to_bin[keys] = bins
    np.save('color_to_bin', color_to_bin)

    bin_to_colors = colors.reshape(num_bins, colors_per_bin, 2).astype(np.int32)
    np.save('bin_to_colors', bin_to_colors)

    return color_to_bin, bin_to_colors

def image_to_bin_indices_pil(image_path, color_to_bin):
    """
    Map every pixel of an image to its bin index and histogram the result.

    Each pixel is looked up by its red and green channels only, using the key
    ``(R << 8) + G``. The blue channel is ignored.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        color_to_bin: Length-65536 array mapping a key to a bin index, as
            returned by ``create_mappings``.

    Returns:
        tuple[np.ndarray, np.ndarray]:
            - ``bin_indices``: shape ``(H, W)``, ``int32`` bin index per pixel.
            - ``bin_histogram``: length-4096 pixel count per bin.
    """
    # Load the image as RGB; the context manager closes the file handle promptly.
    with Image.open(image_path) as image:
        image_np = np.array(image.convert('RGB'), dtype=np.uint8)

    # Widen before shifting: under NumPy >= 2 promotion rules a uint8 value
    # shifted left by 8 stays uint8 and wraps to 0, which would collapse
    # every key onto the green channel alone.
    red = image_np[..., 0].astype(np.int64)
    green = image_np[..., 1].astype(np.int64)
    keys = (red << 8) + green
    bin_indices = np.asarray(color_to_bin)[keys].astype(np.int32)

    # Per-bin pixel counts over all 4096 bins.
    bin_histogram, _ = np.histogram(bin_indices, bins=np.arange(4097))

    return bin_indices, bin_histogram

def plot_colors_bins(color_to_bin, save_path):
    """
    Save an 8x8 grid visualizing the colors of bins 0-63.

    Each subplot shows one bin's 16 colors as a 4x4 image. Each key
    ``(R << 8) + G`` is decoded to the pixel ``(R, G, 0)``, with blue fixed at 0.

    Args:
        color_to_bin: Length-65536 array mapping a key to a bin index. Each of
            bins 0-63 must contain exactly 16 keys (required by the 4x4 reshape).
        save_path: Destination passed to ``Figure.savefig``.
    """
    color_to_bin = np.asarray(color_to_bin)
    fig, axes = plt.subplots(8, 8, figsize=(12, 12))

    # axes.flat is row-major, so bin b lands on axes[b // 8, b % 8].
    for target_bin, ax in enumerate(axes.flat):
        # Keys of this bin in ascending order (same order as a linear scan).
        keys = np.flatnonzero(color_to_bin == target_bin)
        rgb = np.stack([keys >> 8, keys & 255, np.zeros_like(keys)], axis=-1)
        color_grid = rgb.reshape(4, 4, 3).astype(np.uint8)

        ax.imshow(color_grid)
        ax.set_title(f"Bin {target_bin}")
        ax.axis('off')  # Hide the axis for cleaner visual

    fig.tight_layout()
    fig.savefig(save_path)
    # Release the figure; pyplot otherwise keeps every created figure alive.
    plt.close(fig)

# Example usage: runs only when executed as a script, so importing this module
# does not write .npy/.png files or require 'sample.png' to exist.
if __name__ == "__main__":
    image_path = 'sample.png'
    color_to_bin, bin_to_colors = create_mappings()
    plot_colors_bins(color_to_bin, 'demo.png')
    bin_indices, bin_histogram = image_to_bin_indices_pil(image_path, color_to_bin)
    print(f'bin_indices shape: {bin_indices.shape}')
    print(f'bin_histogram shape: {bin_histogram.shape}')