import h5py
import numpy as np
from scipy.interpolate import UnivariateSpline
from scipy.stats import norm
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from scipy.special import comb
from scipy.integrate import quad
import scipy.stats as stats
import math

# --- 1. Data Loading or Generation ---
def load_vectors(file_name, dataset_name='embeddings', num_to_load=50000):
    """
    Return L2-normalized row vectors, read from disk or synthesized.

    If `file_name` exists, the first `num_to_load` rows of the HDF5 dataset
    `dataset_name` are read. Otherwise `num_to_load` synthetic 50-dimensional
    vectors are generated: isotropic Gaussian noise (weight 0.7) plus a
    Gaussian-weighted component along one random unit direction (weight 0.3),
    which makes the cloud anisotropic.

    In both cases each row is divided by its L2 norm, so the dot product of
    two returned rows is their cosine similarity. All-zero rows are divided
    by 1 instead and therefore stay zero.

    Args:
        file_name (str): Path of the HDF5 file to read.
        dataset_name (str): Name of the dataset inside the HDF5 file.
        num_to_load (int): Maximum number of rows to read, or the number of
            synthetic rows to generate.

    Returns:
        np.ndarray: 2D array with one normalized vector per row.
    """
    if os.path.exists(file_name):
        print(f"Loading vectors from {file_name}...")
        with h5py.File(file_name, 'r') as h5_file:
            # Only the leading rows are read, to keep the experiment cheap.
            raw = h5_file[dataset_name][:num_to_load]
    else:
        print(f"File {file_name} not found. Generating synthetic anisotropic data.")
        synthetic_dim = 50
        # Isotropic background noise.
        isotropic_part = np.random.randn(num_to_load, synthetic_dim)
        # A single unit-length "preferred" direction shared by every sample.
        shared_axis = np.random.randn(synthetic_dim)
        shared_axis /= np.linalg.norm(shared_axis)
        # Per-sample strength of the shared-direction component.
        axis_weights = np.random.randn(num_to_load)
        raw = isotropic_part * 0.7 + np.outer(axis_weights, shared_axis) * 0.3

    # Shared normalization step for both data sources.
    row_norms = np.linalg.norm(raw, axis=1, keepdims=True)
    # Replace zero norms by 1 to avoid dividing by zero.
    row_norms[row_norms == 0] = 1
    return raw / row_norms

# --- 2. Monte Carlo Simulation for a single vector pair ---
def estimate_collision_prob_monte_carlo(u, v, num_simulations=100_000):
    """
    Monte Carlo estimate of the SimHash collision probability of `u` and `v`.

    Draws `num_simulations` standard-normal projection vectors, takes the
    sign of each vector's projection onto them (one SimHash bit per
    projection) and reports the fraction of projections on which the two
    signs agree.

    Args:
        u (np.ndarray): The first vector (1D).
        v (np.ndarray): The second vector (1D), same length as `u`.
        num_simulations (int): The number of random projections to draw.

    Returns:
        float: Fraction of projections where both vectors got the same sign.
    """
    # One Gaussian projection vector per row, drawn in a single call.
    hyperplanes = np.random.randn(num_simulations, u.shape[0])

    # Sign of each projection is the hash bit (np.sign yields -1, 0 or 1).
    bits_u = np.sign(hyperplanes.dot(u))
    bits_v = np.sign(hyperplanes.dot(v))

    # A collision is a projection on which both bits coincide.
    agreement = bits_u == bits_v
    return agreement.sum() / num_simulations


# --- Main Script ---
if __name__ == '__main__':
    # End-to-end experiment: sample vector pairs, estimate each pair's SimHash
    # collision probability by Monte Carlo, fit a smoothing spline mapping
    # cosine similarity -> collision probability, and plot it against the
    # theoretical curve 1 - arccos(s)/pi.

    # --- Configuration ---
    HDF5_FILE = ""  # A missing path makes load_vectors generate synthetic data
    DATASET_NAME = 'train'  # Assumed name of the vector dataset
    NUM_VECTOR_PAIRS_TO_SAMPLE = 2000  # Number of (u,v) pairs to sample
    NUM_MONTE_CARLO_SIMULATIONS = 10_000  # Projections per pair. Increase for more precision.

    # --- Load Data ---
    vectors = load_vectors(HDF5_FILE, DATASET_NAME)
    num_vectors, dim = vectors.shape
    print(f"Loaded {num_vectors} vectors of dimension {dim}.")

    # --- 3. Run Experiment: Sample pairs and run simulations ---
    print(f"Sampling {NUM_VECTOR_PAIRS_TO_SAMPLE} vector pairs and running simulations...")

    # Store results here
    cosine_similarities = []
    empirical_probabilities = []

    # Both index arrays are drawn independently and uniformly WITH
    # replacement, so a pair can repeat and a vector can be paired with
    # itself; self-pairs are skipped in the loop below.
    indices1 = np.random.randint(0, num_vectors, NUM_VECTOR_PAIRS_TO_SAMPLE)
    indices2 = np.random.randint(0, num_vectors, NUM_VECTOR_PAIRS_TO_SAMPLE)

    # `total` is passed explicitly because zip() has no length for tqdm to read.
    for idx1, idx2 in tqdm(zip(indices1, indices2), total=NUM_VECTOR_PAIRS_TO_SAMPLE):
        # Ensure we are not comparing a vector to itself
        if idx1 == idx2:
            continue

        u, v = vectors[idx1], vectors[idx2]

        # a) Cosine similarity: the rows are already L2-normalized, so this is
        # a plain dot product. Clip to absorb floating point drift outside
        # [-1, 1].
        s = np.clip(np.dot(u, v), -1.0, 1.0)

        # b) Run Monte Carlo simulation to get the empirical collision probability
        p_empirical = estimate_collision_prob_monte_carlo(u, v, NUM_MONTE_CARLO_SIMULATIONS)

        cosine_similarities.append(s)
        empirical_probabilities.append(p_empirical)

    # --- 4. Fit the Calibrated Function ---
    print("Fitting the calibration function to the empirical data...")

    cosine_similarities = np.array(cosine_similarities)
    empirical_probabilities = np.array(empirical_probabilities)

    # The default cubic spline (k=3) needs more than 3 data points.
    if len(cosine_similarities) < 4:
        raise SystemExit(
            "Need at least 4 sampled vector pairs to fit a cubic spline; "
            "increase NUM_VECTOR_PAIRS_TO_SAMPLE or load more vectors."
        )

    # Sort the values by cosine similarity for spline fitting
    sort_indices = np.argsort(cosine_similarities)
    s_sorted = cosine_similarities[sort_indices]
    p_empirical_sorted = empirical_probabilities[sort_indices]

    # Use UnivariateSpline for non-parametric fitting.
    # The smoothing factor `s` bounds the sum of squared residuals. s=0 means
    # interpolation (noisy); a larger `s` creates a smoother curve. The ideal
    # value may require tuning. NOTE: the common "s = number of points" rule
    # of thumb only applies when weights w = 1/std(y) are passed; this fit is
    # unweighted, so the value below is just an empirical choice.
    smoothing_factor = len(s_sorted) * 0.01
    calibrated_function = UnivariateSpline(s_sorted, p_empirical_sorted, s=smoothing_factor)

    print("Calibration function fitted successfully.")

    # --- 5. Visualize the Results ---
    print("Generating comparison plot...")

    # The theoretical curve is defined on the whole [-1, 1] interval.
    s_range = np.linspace(-1.0, 1.0, 500)
    p_theoretical = 1 - np.arccos(s_range) / np.pi

    # The spline is only supported by data inside the observed similarity
    # range. Outside it the curve would be unconstrained extrapolation, so
    # evaluate it on the observed interval only and keep it a valid
    # probability by clipping to [0, 1].
    s_fit_range = np.linspace(s_sorted[0], s_sorted[-1], 500)
    p_calibrated = np.clip(calibrated_function(s_fit_range), 0.0, 1.0)

    plt.style.use('seaborn-v0_8-whitegrid')
    plt.figure(figsize=(12, 8))

    # Plot the raw empirical data points
    plt.scatter(cosine_similarities, empirical_probabilities,
                alpha=0.2, s=10, label='Empirical Data Points (Monte Carlo)')

    # Plot the theoretical curve
    plt.plot(s_range, p_theoretical, 'r-', linewidth=3,
             label='Theoretical Probability: $p(s) = 1 - \\arccos(s)/\\pi$')

    # Plot our new calibrated curve (observed similarity range only)
    plt.plot(s_fit_range, p_calibrated, 'g--', linewidth=3,
             label='Calibrated Probability: $p_{calibrated}(s)$ (Spline Fit)')

    plt.title('SimHash Collision Probability: Theoretical vs. Calibrated', fontsize=16)
    plt.xlabel('Cosine Similarity (s)', fontsize=12)
    plt.ylabel('Probability of Collision', fontsize=12)
    plt.legend(fontsize=12)
    plt.xlim(-1.05, 1.05)
    plt.ylim(-0.05, 1.05)
    plt.grid(True)
    # The figure is written to disk rather than shown interactively; with no
    # extension given, matplotlib uses its default output format.
    plt.savefig("calibrated")

   