import numpy as np
from numpy.random import rand

# Submission code for the Novel View Editing Adaptor for unseen view consistent 3D editing
# This is the implementation code for each module of the NVE-Adaptor. 
# To facilitate understanding of the operation, we provide the pseudo code of each module below.

def gaussian_pdf(x, mean, sigma):
    """
    Isotropic Gaussian PDF: N(x | mean, sigma^2 I).

    x and mean are array-likes of equal length (3 for 3D points; lists,
    tuples and NumPy arrays are all accepted), sigma is a scalar std.
    Returns the probability density at x.
    """
    # Convert up front so plain Python sequences can be subtracted element-wise.
    x = np.asarray(x, dtype=float)
    mean = np.asarray(mean, dtype=float)

    dim = len(x)  # dimension (3 for 3D)
    variance = sigma ** 2
    diff = x - mean

    # For Σ = sigma^2 * I the quadratic form (x-μ)^T Σ^{-1} (x-μ) reduces to
    # ||x-μ||^2 / sigma^2, so no explicit inverse covariance matrix is needed.
    exponent = -0.5 * (diff @ diff) / variance
    denom = (2 * np.pi * variance) ** (dim / 2)  # normalization constant
    return np.exp(exponent) / denom  # Gaussian probability density

# Implementation of P(x | X, Y) in Equation (5) in main paper 
# (Equation (5) defines P(x | X), but it can equally be formulated with both X and Y as inputs, as mentioned in Lines 290-291.)
def compute_mixture_probability(x, ref_points, new_points, sigma):
    """
    Evaluate the equal-weight Gaussian mixture P(x | X, Y) at a single point.

      - x: the query point
      - ref_points: reference set X (one mixture component per point)
      - new_points: already-sampled set Y (one mixture component per point)
      - sigma: shared isotropic std passed to gaussian_pdf

    Returns the mean of the component densities, or 0.0 when both sets are empty.
    """
    component_count = len(ref_points) + len(new_points)
    if component_count == 0:
        return 0.0

    # Accumulate reference components first, then the new ones, in order.
    density_total = 0.0
    for centers in (ref_points, new_points):
        for center in centers:
            density_total += gaussian_pdf(x, center, sigma)

    # Equal weights: divide the summed densities by the number of components.
    return density_total / component_count

# Equation (5) in main paper: tilde(P)(x) = (1 - P(x)) / Z
# This is a function-style implementation; the same (1 - P(x)) computation is also inlined in estimate_normalizing_constant().
def inverse_probability(x, ref_points, new_points, sigma, normalize_const):
    """
    Evaluate tilde(P)(x) = max(0, 1 - P(x | X, Y)) / Z.

    P is obtained from compute_mixture_probability; normalize_const is Z.
    """
    complement = 1.0 - compute_mixture_probability(x, ref_points, new_points, sigma)
    # Keep the numerator non-negative: anything that is not strictly positive becomes 0.0.
    numerator = complement if complement > 0.0 else 0.0
    return numerator / normalize_const

# Equation (5) in main paper: Z = ∫ (1 - P(x)) dx, estimated by Monte Carlo sampling inside a bounding box
def estimate_normalizing_constant(ref_points, new_points, sigma,
                                  sample_size=100000,
                                  bbox_min=-2.0, bbox_max=2.0):
    """
    Monte Carlo estimate of Z = ∫ max(0, 1 - P(x)) dx over the cube
    [bbox_min, bbox_max]^3.

    sample_size uniform points are drawn in the cube; the estimate is the
    average integrand value times the cube volume. Increase sample_size and
    adjust the bounding box as needed.
    """
    box_volume = (bbox_max - bbox_min) ** 3
    draws = np.random.uniform(bbox_min, bbox_max, (sample_size, 3))

    # Integrand (1 - P(x)) clamped at 0, evaluated at every uniform draw.
    integrand = [
        max(0.0, 1.0 - compute_mixture_probability(point, ref_points, new_points, sigma))
        for point in draws
    ]

    # Z ≈ (mean integrand value) * (volume of the sampling region)
    return np.mean(integrand) * box_volume

# single novel view point sampling function in Line 724 of 3DG-PVS algorithm.
def sample_from_inverse_distribution(ref_points, new_points, sigma,
                                     bbox_min=-2.0, bbox_max=2.0,
                                     max_iter=100000):
    """
    Rejection sampling to draw a single sample x from
    tilde(P)(x) ∝ max(0, 1 - P(x | X, Y)) inside the cube [bbox_min, bbox_max]^3.

    Candidates are proposed uniformly in the cube and accepted with
    probability max(0, 1 - P(candidate)). With a uniform proposal the
    normalizing constant Z cancels out of the acceptance test, so it is not
    estimated here: the clamped value already lies in [0, 1] and is a valid
    acceptance probability for any box size. (Dividing by 2Z instead is not
    bounded by 1 when 2Z < 1, and for larger boxes only lowers the acceptance
    rate while costing a full Monte Carlo estimate per draw.)

    Returns the accepted 3D point, or None if no candidate was accepted
    within max_iter proposals.
    """
    for _ in range(max_iter):
        candidate = np.random.uniform(bbox_min, bbox_max, 3)  # uniform proposal in the box
        p_val = compute_mixture_probability(candidate, ref_points, new_points, sigma)
        inv_val = max(1.0 - p_val, 0.0)  # (1 - P(x)), clamped at 0
        if rand() < inv_val:  # accept with probability proportional to tilde(P)
            return candidate

    return None

def is_in_exclusion_region(x, exclusion_region_func=None):
    """
    Report whether point x lies in the exclusion region.

    With no exclusion_region_func supplied nothing is excluded and the result
    is False; otherwise the result is whatever exclusion_region_func(x) returns.
    """
    return False if exclusion_region_func is None else exclusion_region_func(x)

# 3DG-PVS algorithm in Line 703 - Line 740 of Supplementary
def sample_3dg_pvs(ref_points, T_star,
                   exclusion_region_func=None,
                   sigma=0.4,
                   bbox_min=-2.0, bbox_max=2.0):
    """
    Main function to sample T_star novel viewpoints.

    ref_points: list or array of shape (T, 3) - reference viewpoints
    T_star: number of novel viewpoints to sample
    exclusion_region_func: function that checks outlier region
    sigma: isotropic Gaussian std
    bbox_min, bbox_max: bounding box for sampling

    Returns an array of shape (T_star, 3); for T_star <= 0 this is an empty
    (0, 3) array.
    Raises RuntimeError if sample_from_inverse_distribution returns None.
    """
    Y = []

    for _ in range(T_star):
        # Every draw is conditioned on ref_points plus all points accepted so
        # far; on the first pass Y is empty, so no special first-sample branch
        # is needed.
        # NOTE: this retry loop only ends once a candidate falls outside the
        # exclusion region (or sampling fails), so an exclusion region that
        # covers the whole bounding box will never terminate.
        while True:
            candidate = sample_from_inverse_distribution(
                ref_points, Y, sigma, bbox_min, bbox_max
            )
            if candidate is None:
                raise RuntimeError("Sampling failed. Consider expanding bounding box or increasing max_iter.")
            if not is_in_exclusion_region(candidate, exclusion_region_func):
                Y.append(candidate)
                break

    # reshape keeps the documented (T_star, 3) layout even when Y is empty
    return np.asarray(Y, dtype=float).reshape(-1, 3)

# Building cube-shaped bounding box in terms of Equation (6) ~ (8)
def compute_bounding_box(points, coverage=0.9):
    """
    Build an axis-aligned box from per-axis percentiles of points.

    Along each axis (percentiles are taken over axis 0) the box spans the
    central `coverage` fraction of the coordinates, i.e. from the
    (1 - coverage) / 2 quantile to the (1 + coverage) / 2 quantile.
    Returns (center, size) of that box.
    """
    low_pct = (1 - coverage) / 2 * 100
    high_pct = (1 + coverage) / 2 * 100

    box_min, box_max = (np.percentile(points, pct, axis=0) for pct in (low_pct, high_pct))

    return (box_max + box_min) / 2, box_max - box_min

# Filtering (outlier removal) in terms of Equation (9)
def make_box_exclusion_region(center, size, offset=(0.0, 0.0, 0.0)):
    """
    Build a predicate that is true for points inside an axis-aligned box.

    The box is centered at `center`; its full extent along each axis is the
    matching entry of `size` plus the matching entry of `offset` (so the
    offset is split evenly between the two sides). Boundaries are inclusive.
    """
    cx, cy, cz = center
    sx, sy, sz = size
    dx, dy, dz = offset

    def _interval(mid, extent, pad):
        # Closed interval of total width (extent + pad) around mid.
        return mid - (extent + pad) / 2, mid + (extent + pad) / 2

    x_lo, x_hi = _interval(cx, sx, dx)
    y_lo, y_hi = _interval(cy, sy, dy)
    z_lo, z_hi = _interval(cz, sz, dz)

    def exclusion_func(point):
        px, py, pz = point
        return (x_lo <= px <= x_hi) and (y_lo <= py <= y_hi) and (z_lo <= pz <= z_hi)

    return exclusion_func

