import numpy as np
import torch
import scipy.ndimage as ndimage


def match_histograms(source, template, alpha=1.0):
    """
    Remap ``source`` so that its value distribution follows that of ``template``.

    Each distinct source value is replaced by the template value at the same
    position in the cumulative distribution (linear interpolation between the
    template's quantiles). The result is then blended with the original values.

    Args:
        source (np.ndarray): Image data to transform; the output has this shape.
        template (np.ndarray): Image data whose histogram is the target. Its
            shape does not need to match ``source``.
        alpha (float): Blend weight of the matched values. 0 returns ``source``
            unchanged and 1 returns the fully matched data.

    Returns:
        matched (np.ndarray): ``alpha * matched + (1 - alpha) * source``,
            shaped like ``source``.
    """

    def empirical_cdf(counts):
        # Fraction of samples at or below each sorted unique value.
        return np.cumsum(counts).astype(np.float64) / np.sum(counts)

    out_shape = source.shape
    flat_source = source.ravel()

    # level_of_pixel[k] is the index of flat_source[k] among the sorted unique values.
    _, level_of_pixel, src_freq = np.unique(
        flat_source, return_inverse=True, return_counts=True
    )
    tmpl_levels, tmpl_freq = np.unique(template.ravel(), return_counts=True)

    # Send every source level to the template level sitting at the same CDF position.
    remapped_levels = np.interp(
        empirical_cdf(src_freq), empirical_cdf(tmpl_freq), tmpl_levels
    )
    matched = remapped_levels[level_of_pixel]

    return (alpha * matched + (1 - alpha) * flat_source).reshape(out_shape)


def enhance_contrast(latents):
    """
    Enhance the contrast of the latents by min-max rescaling each 2-D slice to [0, 1].

    Every slice over the last two dimensions (``latents[i, j, :, :]`` for 4-D
    input) is rescaled independently: its minimum maps to 0 and its maximum to 1.
    A constant slice (min == max) cannot be stretched, so its values are clipped
    to [0, 1] instead. This mirrors
    ``skimage.exposure.rescale_intensity(x, in_range="image", out_range=(0, 1))``
    without depending on scikit-image.

    Args:
        latents (torch.Tensor): The latents to enhance. Rescaling is applied over
            the last two dimensions. CPU or GPU tensors, with or without grad,
            are accepted.

    Returns:
        enhanced_latents (torch.Tensor): CPU tensor of the same shape with every
            slice rescaled to [0, 1]. Integer input is promoted to float32.
    """
    values = latents.detach().cpu()
    if not values.is_floating_point():
        values = values.to(torch.float32)

    lo = values.amin(dim=(-2, -1), keepdim=True)
    hi = values.amax(dim=(-2, -1), keepdim=True)
    span = hi - lo
    has_range = span > 0

    # Replace zero spans by 1 so the division is always defined. Those slices
    # are taken from the clipped branch below anyway.
    safe_span = torch.where(has_range, span, torch.ones_like(span))
    stretched = (values - lo) / safe_span

    enhanced_latents = torch.where(has_range, stretched, values.clamp(0.0, 1.0))
    return enhanced_latents


def correct_color_offset(
    latents_first_window, latents_other_window, alpha=1.0, process_constrat=False
):
    """
    Correct the color offset of later-window latents using the first window as reference.

    The method runs in two steps:

    1. Each channel of ``latents_other_window`` is standardized and rescaled to the
       per-channel mean and standard deviation of ``latents_first_window``. The
       statistics are taken over dimensions 1-3.
    2. Every ``[channel, frame]`` slice is histogram-matched against the
       corresponding slice of the first window, with strength ``alpha``.

    Args:
        latents_first_window (torch.Tensor): Reference latents of the first window,
            with shape (4, N1, 64, 64).
        latents_other_window (torch.Tensor): Latents for the other windows, with
            shape (4, N2, 64, 64). The caller's tensor is not modified.
        alpha (float): Strength of the histogram matching, from 0 (none) to 1 (full).
        process_constrat (bool): If True, the result is additionally passed through
            ``enhance_contrast``. The misspelled name is kept for backward
            compatibility.

    Returns:
        corrected_latents (torch.Tensor): Corrected latents for the other windows, on
            the device and with the dtype of ``latents_first_window``.
    """
    device = latents_first_window.device
    dtype = latents_first_window.dtype
    # The numpy-based histogram matching needs CPU float32 data. Detach so that
    # `.numpy()` also works on tensors that track gradients.
    reference = latents_first_window.detach().to(device="cpu", dtype=torch.float32)
    current = latents_other_window.detach().to(device="cpu", dtype=torch.float32)

    # Per-channel statistics over frames, height and width.
    stat_dims = (1, 2, 3)
    mean_first = reference.mean(dim=stat_dims, keepdim=True)
    std_first = reference.std(dim=stat_dims, keepdim=True)
    mean_current = current.mean(dim=stat_dims, keepdim=True)
    std_current = current.std(dim=stat_dims, keepdim=True)

    # Standardize the current window, then rescale it to the first window's mean/std.
    # The 1e-8 guards against division by zero for constant channels. This also
    # produces a fresh tensor, so the in-place slice writes below never touch the
    # caller's data.
    corrected_latents = (current - mean_current) / (
        std_current + 1e-8
    ) * std_first + mean_first

    # Histogram-match each [channel, frame] slice against the reference window.
    # The windows may have different frame counts (N1 vs N2). Frames beyond the
    # end of the reference window are matched against its last frame rather than
    # raising IndexError.
    last_ref_frame = reference.shape[1] - 1
    for i in range(corrected_latents.shape[0]):
        for j in range(corrected_latents.shape[1]):
            ref_slice = reference[i, min(j, last_ref_frame)]
            corrected_latents[i, j] = torch.from_numpy(
                match_histograms(
                    corrected_latents[i, j].numpy(),
                    ref_slice.numpy(),
                    alpha=alpha,
                )
            )

    if process_constrat:
        corrected_latents = enhance_contrast(corrected_latents)

    return corrected_latents.to(device=device, dtype=dtype)
