import os
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import math
import matplotlib.pyplot as plt
from scipy import stats
from scipy.special import iv
import seaborn as sns

import numpy as np

def random_VMF(mu, kappa, size=None):
    """
    Von Mises-Fisher distribution sampler with mean direction ``mu`` and
    concentration ``kappa``.

    Source: https://hal.science/hal-04004568

    Args:
        mu: Mean direction, a 1-D array-like of length d. It is normalized
            internally, so it need not have unit norm.
        kappa: Concentration parameter. Larger values concentrate samples
            more tightly around ``mu``; ``kappa == 0`` gives the uniform
            distribution on the sphere and ``kappa < 0`` concentrates samples
            around ``-mu``.
        size: None for a single sample, otherwise an int or tuple giving the
            leading shape of the output.

    Returns:
        np.ndarray of shape ``(*size, d)`` (or ``(d,)`` when ``size`` is None)
        whose last axis has unit L2 norm.
    """
    # np.prod replaces np.product, which was removed in NumPy 2.0.
    n = 1 if size is None else int(np.prod(size))
    shape = () if size is None else tuple(np.ravel(size))
    mu = np.asarray(mu)
    mu = mu / np.linalg.norm(mu)

    (d,) = mu.shape
    # z component: random unit directions orthogonal to mu (Gaussian draw,
    # remove the mu component, renormalize).
    z = np.random.normal(0, 1, (n, d))
    z /= np.linalg.norm(z, axis=1, keepdims=True)
    z = z - (z @ mu[:, None]) * mu[None, :]
    z /= np.linalg.norm(z, axis=1, keepdims=True)

    # Sample the cosine of the angle to mu, then derive the matching sine.
    cos = random_VMF_angle(d, kappa, n)
    # Clip: rounding can make 1 - cos**2 slightly negative when |cos| ~ 1,
    # which would otherwise become NaN under sqrt.
    sin = np.sqrt(np.clip(1 - cos ** 2, 0.0, None))
    # Combine: x = sin * (orthogonal direction) + cos * mu.
    x = z * sin[:, None] + cos[:, None] * mu[None, :]

    return x.reshape((*shape, d))


def random_VMF_angle(d: int, k: float, n: int):
    """
    Generate n iid samples t with density function given by
    p(t) = someConstant * (1-t**2) **((d-3)/2) * exp(kappa*t)

    This is the law of the cosine between a d-dimensional vMF(mu, k) sample
    and mu. For k > 0 it uses rejection sampling from a Moebius-transformed
    symmetric Beta proposal (see the reference in ``random_VMF``).

    Args:
        d: Dimension of the ambient space.
        k: Concentration parameter. ``k == 0`` is sampled exactly (uniform
           sphere marginal); ``k < 0`` is handled by the symmetry t -> -t.
        n: Number of samples; ``n <= 0`` returns an empty array.

    Returns:
        1-D np.ndarray of length max(n, 0) with values in [-1, 1].
    """
    alpha = (d - 1) / 2
    if n <= 0:
        # np.concatenate([]) would raise, so return an empty result directly.
        return np.empty(0)
    if k == 0:
        # For k == 0 the density is (1-t^2)^(alpha-1), which is exactly the
        # law of 2*Beta(alpha, alpha) - 1; no rejection step is needed.
        return 2 * np.random.beta(alpha, alpha, n) - 1
    if k < 0:
        # p_k(t) = p_{-k}(-t) because (1 - t^2) is symmetric in t. Without this,
        # r0 > 1 makes log(1 - r0*t) NaN and the loop below never terminates.
        return -random_VMF_angle(d, -k, n)

    # t0 == r0 maximizes k*t + (d-1)*log(1 - r0*t), so log_acc below is <= 0.
    t0 = r0 = np.sqrt(1 + (alpha / k) ** 2) - alpha / k
    log_t0 = k * t0 + (d - 1) * np.log(1 - r0 * t0)
    found = 0
    out = []

    while found < n:
        # Over-draw by ~1.5x the remaining count to reduce the number of rounds.
        m = min(n, int((n - found) * 1.5))
        t = np.random.beta(alpha, alpha, m)
        t = 2 * t - 1
        t = (r0 + t) / (1 + r0 * t)
        log_acc = k * t + (d - 1) * np.log(1 - r0 * t) - log_t0
        t = t[np.random.random(m) < np.exp(log_acc)]
        out.append(t)
        found += len(out[-1])
    return np.concatenate(out)[:n]

def sample_vmf(mu, kappa, dim, num_samples=1):
    """
    Draw ``num_samples`` independent vMF samples around ``mu`` and stack them.

    Each sample comes from its own ``random_VMF(mu, kappa)`` call, so the
    result matches calling that sampler ``num_samples`` times in a row.

    Args:
        mu: Mean direction (1-D array-like).
        kappa: Concentration parameter.
        dim: Not used; the dimension is taken from ``mu``.
        num_samples: Number of samples to draw.

    Returns:
        torch.Tensor of shape (num_samples, len(mu)), created with
        ``torch.from_numpy`` from the stacked NumPy samples.
    """
    draws = [random_VMF(mu, kappa) for _ in range(num_samples)]
    return torch.from_numpy(np.stack(draws))

class CocoValDataset(Dataset):
    """Dataset of RGB images read from a single flat directory.

    Only files ending in .jpg/.jpeg/.png (case-insensitive) are used.
    Filenames are sorted so that the selected subset and its order are
    reproducible, because ``os.listdir`` returns entries in arbitrary order.

    Args:
        root_dir: Directory containing the images.
        transform: Optional callable applied to each PIL image.
        max_images: Keep at most this many images. The default of 5000 matches
            the previous hard-coded limit; ``None`` keeps all images.
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None, max_images=5000):
        self.root_dir = root_dir
        self.transform = transform
        image_files = sorted(
            f for f in os.listdir(root_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.image_files = image_files if max_images is None else image_files[:max_images]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        # The context manager closes the file handle deterministically;
        # convert() returns a new, fully loaded image that outlives it.
        with Image.open(img_name) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image

def psnr(x, y, is_video=False):
    """
    Compute the peak signal-to-noise ratio (in dB) between two image tensors.

    Args:
        x: Image tensor with values roughly in [0, 1]. Its last three
            dimensions are treated as (C, H, W); any leading dimensions are
            flattened into a single batch axis.
        y: Reference tensor broadcastable against x (e.g. the original image).
        is_video: If True, a single PSNR is computed over the whole batch;
            otherwise one PSNR is returned per image.

    Returns:
        Tensor of shape (B,) per image, or a 0-dim tensor when is_video is True.
    """
    channels, height, width = x.shape[-3:]
    # Work on the 0-255 scale so the peak signal value is 255.
    scaled_diff = (255 * (x - y)).reshape(-1, channels, height, width)
    reduce_dims = (0, 1, 2, 3) if is_video else (1, 2, 3)
    mse = scaled_diff.pow(2).mean(dim=reduce_dims)
    peak_db = 20 * math.log10(255.0)
    return peak_db - 10 * torch.log10(mse)

def find_model_pairs(model_dir):
    """
    Group the ``*.pth`` checkpoints in ``model_dir`` into encoder/decoder pairs.

    Roles are identified by the suffix of the file stem only:
    ``<base>_enc`` / ``<base>_dec`` pair under the key ``<base>``, and
    ``<base>_enc_oop`` / ``<base>_dec_oop`` pair under the key ``<base>_oop``.
    Because only the end of the stem is inspected, names that contain
    "_enc"/"_dec" elsewhere (e.g. ``my_encoder_dec.pth``) are classified
    correctly. Files without a recognised suffix are ignored.

    Returns:
        dict mapping base name -> {'enc': path_str, 'dec': path_str}, containing
        only bases for which both an encoder and a decoder file were found.
    """
    # (stem suffix, role, text kept on the base name)
    suffix_rules = (
        ('_enc_oop', 'enc', '_oop'),
        ('_dec_oop', 'dec', '_oop'),
        ('_enc', 'enc', ''),
        ('_dec', 'dec', ''),
    )
    pairs = {}
    # sorted() makes the result independent of filesystem listing order.
    for f in sorted(Path(model_dir).glob('*.pth')):
        stem = f.stem
        for suffix, role, keep in suffix_rules:
            if stem.endswith(suffix):
                base_name = stem[:-len(suffix)] + keep
                pairs.setdefault(base_name, {'enc': None, 'dec': None})[role] = str(f)
                break

    return {k: v for k, v in pairs.items() if v['enc'] and v['dec']}

def _component_distribution_tests(noise_flat):
    """Print KS-vs-uniform and Shapiro-Wilk-vs-Gaussian results for flattened noise; return both p-values."""
    # 1. Test for Uniform distribution.
    # We test if the noise components are uniformly distributed between their min and max.
    min_noise, max_noise = noise_flat.min(), noise_flat.max()
    ks_stat_uniform, ks_p_uniform = stats.kstest(noise_flat, 'uniform', args=(min_noise, max_noise - min_noise))
    print(f"Kolmogorov-Smirnov test for Uniform Dist: statistic={ks_stat_uniform:.4f}, p-value={ks_p_uniform:.6f}")
    if ks_p_uniform < 0.05:
        print("-> Result: The noise component distribution is NOT uniform (p < 0.05). This is expected.")
    else:
        print("-> Result: The noise component distribution might be uniform.")

    # 2. Test for Gaussian distribution. Only the first 5000 components are
    # used because Shapiro-Wilk p-values are unreliable for larger samples.
    shapiro_stat, shapiro_p = stats.shapiro(noise_flat[:5000])
    print(f"Shapiro-Wilk test for Gaussian Dist: statistic={shapiro_stat:.4f}, p-value={shapiro_p:.6f}")
    if shapiro_p < 0.05:
        print("-> Result: The noise component distribution is NOT Gaussian (p < 0.05).")
    else:
        print("-> Result: The noise component distribution appears to be Gaussian.")

    return ks_p_uniform, shapiro_p


def _plot_noise_analysis(noise_flat, noise_vectors, cos_similarities, d, kappa_est):
    """Draw the 2x2 diagnostic figure, save it to noise_analysis.png, show it, then close it."""
    min_noise, max_noise = noise_flat.min(), noise_flat.max()
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    fig.suptitle('Watermark Noise Distribution Analysis', fontsize=16)

    # Plot 1: noise component histogram vs Gaussian and Uniform fits.
    sns.histplot(noise_flat, bins=100, stat='density', alpha=0.7, label='Actual Noise Components', ax=axes[0, 0])
    x = np.linspace(min_noise, max_noise, 200)
    axes[0, 0].plot(x, stats.norm.pdf(x, np.mean(noise_flat), np.std(noise_flat)),
                    'r-', lw=2, label='Gaussian Fit')
    axes[0, 0].plot(x, stats.uniform.pdf(x, min_noise, max_noise - min_noise),
                    'g--', lw=2, label='Uniform Fit')
    axes[0, 0].set_title('Noise Component Distribution')
    axes[0, 0].set_xlabel('Noise Value')
    axes[0, 0].set_ylabel('Density')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Plot 2: Q-Q plot for normality (every 20th component, for legibility).
    stats.probplot(noise_flat[::20], dist="norm", plot=axes[0, 1])
    axes[0, 1].set_title('Q-Q Plot vs. Gaussian')
    axes[0, 1].grid(True, alpha=0.3)

    # Plot 3: actual cosine similarities vs cosines simulated from a vMF with
    # the estimated kappa. A non-finite or non-positive kappa means there is no
    # concentration to simulate, and it can make the rejection sampler's
    # acceptance ratios NaN, so the overlay is skipped in that case.
    can_simulate = bool(np.isfinite(kappa_est) and kappa_est > 0)
    if can_simulate:
        simulated_cos_angles = random_VMF_angle(d, kappa_est, n=len(cos_similarities))

    sns.histplot(cos_similarities, bins=50, stat='density', alpha=0.7, label='Actual Data', ax=axes[1, 0], color='blue')
    if can_simulate:
        sns.histplot(simulated_cos_angles, bins=50, stat='density', alpha=0.7, label=f'vMF Simulation (κ={kappa_est:.1f})', ax=axes[1, 0], color='orange')
    else:
        print("-> Skipping vMF simulation plot: kappa estimate is not a finite positive number.")
    axes[1, 0].set_title('Cosine Similarity Distribution (Actual vs. vMF)')
    axes[1, 0].set_xlabel('Cosine Similarity')
    axes[1, 0].set_ylabel('Density')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

    # Plot 4: per-message L2 norm of the noise vector.
    noise_magnitudes = np.linalg.norm(noise_vectors, axis=1)
    sns.histplot(noise_magnitudes, bins=50, stat='density', alpha=0.7, ax=axes[1, 1])
    axes[1, 1].set_title('Noise Magnitude (||extracted - test||)')
    axes[1, 1].set_xlabel('L2 Norm of Noise Vector')
    axes[1, 1].set_ylabel('Density')
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig('noise_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def analyze_noise_characteristics(test_msgs, extracted_msgs):
    """
    Analyzes noise characteristics, checks for uniform distribution,
    and validates the von Mises-Fisher model.

    Runs component-wise tests on ``extracted - test`` (KS vs. uniform,
    Shapiro-Wilk vs. Gaussian) and estimates a vMF concentration from the mean
    row-wise dot product. It also saves a 2x2 diagnostic figure to
    ``noise_analysis.png``.

    Args:
        test_msgs: Sequence of 1-D torch tensors (original messages).
        extracted_msgs: Sequence of torch tensors (decoded messages). Each is
            squeezed, so a leading singleton batch dimension is allowed.

    Returns:
        dict with keys ``is_uniform``, ``is_gaussian_component_wise``,
        ``vmf_kappa_estimate`` and ``mean_cosine_similarity``.
    """
    # Convert to numpy arrays for analysis.
    test_msgs_np = np.array([msg.cpu().numpy() for msg in test_msgs])
    # The extracted messages may carry an extra batch dimension, so squeeze it.
    extracted_msgs_np = np.array([msg.cpu().numpy().squeeze() for msg in extracted_msgs])

    noise_vectors = extracted_msgs_np - test_msgs_np
    noise_flat = noise_vectors.flatten()

    print("=== Component-wise Distribution Analysis ===")
    ks_p_uniform, shapiro_p = _component_distribution_tests(noise_flat)

    print("\n=== Directional (von Mises-Fisher) Analysis ===")
    # Row-wise dot products; these equal cosine similarities only when both
    # message sets are unit-norm (assumed here - TODO confirm for all callers).
    cos_similarities = np.einsum('ij,ij->i', test_msgs_np, extracted_msgs_np)

    d = test_msgs_np.shape[1]
    R_bar = np.mean(cos_similarities)
    # Banerjee et al. approximation: kappa ~= R(d - R^2) / (1 - R^2).
    kappa_est = (R_bar * d - R_bar**3) / (1 - R_bar**2)
    print(f"Dimension (d): {d}")
    print(f"Mean Cosine Similarity (R_bar): {R_bar:.6f}")
    print(f"Estimated vMF Kappa (κ): {kappa_est:.4f}")
    if kappa_est > 10:
        print("-> Result: High kappa value suggests noise is highly concentrated around the original message direction, fitting a vMF model well.")
    else:
        print("-> Result: Low kappa value suggests significant directional noise.")

    _plot_noise_analysis(noise_flat, noise_vectors, cos_similarities, d, kappa_est)

    return {
        "is_uniform": ks_p_uniform > 0.05,
        "is_gaussian_component_wise": shapiro_p > 0.05,
        "vmf_kappa_estimate": kappa_est,
        "mean_cosine_similarity": R_bar,
    }

def find_optimal_strength(encoder, decoder, img, test_msg, target_psnr=42.0, device='cuda',
                          max_iterations=1, tolerance=0.1, low=0.0, high=2.0):
    """Binary search for a watermark strength whose PSNR is close to ``target_psnr``.

    The watermark residual ``encoder(img, msg) - img`` is computed once. It is
    then rescaled by each candidate strength; the encoder output does not
    depend on the strength, so there is no need to re-run it every step.

    NOTE(review): the default ``max_iterations=1`` is kept for backward
    compatibility. With it, only the midpoint ``(low + high) / 2`` is
    evaluated, so with the default bounds the returned strength is always 1.0
    and the search is effectively disabled. Pass a larger ``max_iterations``
    to actually search.

    Args:
        encoder: Callable ``encoder(images, msgs)`` returning watermarked images.
        decoder: Unused; kept for interface compatibility.
        img: Image batch. The PSNR is reduced with ``.item()``, so this assumes
            a single image (batch of 1).
        test_msg: 1-D message tensor; a batch dimension is added before encoding.
        target_psnr: Desired PSNR in dB.
        device: Unused; kept for interface compatibility.
        max_iterations: Number of bisection steps (must be >= 1).
        tolerance: Return early once |PSNR - target_psnr| < tolerance (dB).
        low, high: Initial strength search bracket.

    Returns:
        (strength, watermarked): the last evaluated strength and the image
        ``img + strength * (encoder(img, msg) - img)`` for it.

    Raises:
        ValueError: if ``max_iterations < 1``.
    """
    if max_iterations < 1:
        raise ValueError("max_iterations must be >= 1")

    left, right = low, high
    with torch.no_grad():
        # Loop-invariant: the encoder output does not depend on the strength.
        watermark = encoder(img, test_msg.unsqueeze(0)) - img

        for _ in range(max_iterations):
            strength = (left + right) / 2
            watermarked = img + watermark * strength

            current_psnr = psnr(watermarked, img, is_video=False)
            if isinstance(current_psnr, torch.Tensor):
                current_psnr = current_psnr.item()

            if abs(current_psnr - target_psnr) < tolerance:
                return strength, watermarked

            # PSNR decreases as strength grows: a too-low PSNR means too strong.
            if current_psnr < target_psnr:
                right = strength
            else:
                left = strength

    return strength, watermarked

def sample_from_noise_model(original_message, kappa):
    """
    Samples a noisy message based on the von Mises-Fisher noise model.

    This simulates the output of the decoder by drawing a new unit vector from
    a vMF distribution centered on the (normalized) original message.

    Args:
        original_message (torch.Tensor, np.ndarray or sequence of floats): The
            original 1-D message vector. It is L2-normalized here, so it need
            not be unit-norm. Tensors may live on any device and may require
            grad.
        kappa (float): The concentration parameter of the vMF distribution. A
            higher value means less noise; this should be the value estimated
            from the real data analysis.

    Returns:
        torch.Tensor: A float32 CPU tensor with the same length as the message.

    Raises:
        ValueError: if the message has zero norm (no direction to sample around).
    """
    if isinstance(original_message, torch.Tensor):
        # detach() so tensors that require grad can be converted to NumPy.
        mu = original_message.detach().cpu().numpy()
    else:
        # asarray so plain lists/tuples support the division below.
        mu = np.asarray(original_message)

    norm = np.linalg.norm(mu)
    if norm == 0:
        raise ValueError("original_message must be a non-zero vector")
    mu = mu / norm

    noisy_sample_np = random_VMF(mu=mu, kappa=kappa)

    return torch.from_numpy(noisy_sample_np).float()
def _report_ttest(title, real_label, real_values, sim_label, sim_values, what):
    """Run an independent two-sample t-test, print a labelled report, and return (t_stat, p_value)."""
    t_stat, p_val = stats.ttest_ind(real_values, sim_values)
    print(title)
    print(f"{real_label}{np.mean(real_values):.6f}")
    print(f"{sim_label}{np.mean(sim_values):.6f}")
    print(f"T-statistic: {t_stat:.4f}, P-value: {p_val:.4f}")
    if p_val > 0.05:
        print(f"✅ RESULT: The means are NOT statistically different. The vMF model accurately reproduces the {what}.")
    else:
        print("❌ RESULT: The means are statistically different. The model may not be perfect.")
    return t_stat, p_val


def validate_vmf_with_ttest(test_msgs_np, extracted_msgs_np, kappa):
    """
    Validates the vMF noise model against real data using a Student's t-test.

    For every original message, one "extracted" message is simulated from
    vMF(test_msg, kappa). The real and simulated sets are then compared on
    (1) the row-wise dot product with the original message and (2) the L2
    norm of the noise vector.

    Note: a p-value above 0.05 only means no difference in means was detected;
    it does not prove the two distributions are the same.

    Args:
        test_msgs_np (np.ndarray): (N, d) array of original (ground-truth) messages.
        extracted_msgs_np (np.ndarray): (N, d) array of messages extracted by the decoder.
        kappa (float): The estimated concentration parameter for the vMF model.

    Returns:
        dict with ``t_stat_cos``, ``p_val_cos``, ``t_stat_mag`` and
        ``p_val_mag``. The function previously returned None, so callers that
        ignore the result are unaffected.
    """
    print("\n" + "="*50)
    print("VMF MODEL VALIDATION WITH STUDENT'S T-TEST")
    print("="*50)

    # --- 1. Statistics from REAL data ---
    real_cos_sims = np.einsum('ij,ij->i', test_msgs_np, extracted_msgs_np)
    real_noise_mags = np.linalg.norm(extracted_msgs_np - test_msgs_np, axis=1)

    # --- 2. SIMULATED data: one vMF draw around each original message ---
    num_samples = test_msgs_np.shape[0]
    simulated_extracted_msgs = np.zeros_like(extracted_msgs_np)
    for i in tqdm(range(num_samples), desc="Simulating vMF samples"):
        simulated_extracted_msgs[i] = random_VMF(mu=test_msgs_np[i], kappa=kappa)

    # --- 3. Statistics from SIMULATED data ---
    simulated_cos_sims = np.einsum('ij,ij->i', test_msgs_np, simulated_extracted_msgs)
    simulated_noise_mags = np.linalg.norm(simulated_extracted_msgs - test_msgs_np, axis=1)

    # --- 4. t-tests and reports ---
    t_stat_cos, p_val_cos = _report_ttest(
        "\n--- T-Test on Cosine Similarities ---",
        "Mean Real Cosine Sim:      ", real_cos_sims,
        "Mean Simulated Cosine Sim: ", simulated_cos_sims,
        "directional error",
    )
    t_stat_mag, p_val_mag = _report_ttest(
        "\n--- T-Test on Noise Magnitudes (L2 Norm) ---",
        "Mean Real Noise Mag:      ", real_noise_mags,
        "Mean Simulated Noise Mag: ", simulated_noise_mags,
        "error magnitude",
    )

    return {
        "t_stat_cos": t_stat_cos,
        "p_val_cos": p_val_cos,
        "t_stat_mag": t_stat_mag,
        "p_val_mag": p_val_mag,
    }
# def analyze_noise_characteristics(test_msgs, extracted_msgs):
#     """
#     Analyze noise characteristics between normalized test messages and extracted messages
    
#     Args:
#         test_msgs: List of normalized test messages (tensors)
#         extracted_msgs: List of corresponding extracted messages (tensors)
    
#     Returns:
#         Dictionary containing noise analysis results
#     """
    
#     # Convert to numpy arrays for analysis
#     test_msgs_np = np.array([msg.cpu().numpy() for msg in test_msgs])
#     extracted_msgs_np = np.array([msg.cpu().numpy() for msg in extracted_msgs])
    
#     # Calculate noise vectors
#     noise_vectors = extracted_msgs_np - test_msgs_np
    
#     # Flatten all noise vectors for component-wise analysis
#     noise_flat = noise_vectors.flatten()
    
#     # 1. Test for Gaussian distribution (component-wise)
#     print("=== Component-wise Gaussian Analysis ===")
    
#     # Normality tests
#     shapiro_stat, shapiro_p = stats.shapiro(noise_flat[:5000])  # Limit for Shapiro-Wilk
#     ks_stat, ks_p = stats.kstest(noise_flat, 'norm', args=(np.mean(noise_flat), np.std(noise_flat)))
#     jb_stat, jb_p = stats.jarque_bera(noise_flat)
    
#     print(f"Shapiro-Wilk test: statistic={shapiro_stat:.4f}, p-value={shapiro_p:.6f}")
#     print(f"Kolmogorov-Smirnov test: statistic={ks_stat:.4f}, p-value={ks_p:.6f}")
#     print(f"Jarque-Bera test: statistic={jb_stat:.4f}, p-value={jb_p:.6f}")
    
#     # 2. von Mises-Fisher distribution analysis (hypersphere)
#     print("\n=== von Mises-Fisher Distribution Analysis ===")
#     # Choose a mean direction for noise
#     mu = np.random.randn(256)
    
#     # Add vMF noise
#     # Generate vMF noise for each embedding
#     vmf_noises = []
#     # for i in tqdm(range(100),total=5000):
#         # Sample a direction similar to mu but with some randomness
#     vmf_noises = sample_vmf(mu, 64, 256, 100)
#         # vmf_noises.append(vmf_sample[0])

#     vmf_noises = vmf_noises.flatten()

#     ks_stat, ks_p = stats.kstest(noise_flat, vmf_noises, args=(np.mean(noise_flat), np.std(noise_flat)))
#     print(f"Kolmogorov-Smirnov test: statistic={ks_stat:.4f}, p-value={ks_p:.6f}")

#     # Calculate dot products between test and extracted messages (cosine similarities)
#     cos_similarities = []
#     for i in range(len(test_msgs)):
#         cos_sim = np.dot(test_msgs_np[i], extracted_msgs_np[i][0])
#         cos_similarities.append(cos_sim)
    
#     cos_similarities = np.array(cos_similarities)
    
#     # Estimate vMF parameters
#     # For vMF, we need to estimate kappa (concentration parameter)
#     # Using method of moments
#     mean_cos_sim = np.mean(cos_similarities)
    
#     # For high-dimensional vMF, the relationship between mean cosine similarity and kappa is:
#     # mean_cos_sim ≈ I_{d/2}(kappa) / I_{d/2-1}(kappa)
#     # where d is the dimension and I_v is the modified Bessel function
    
#     d = test_msgs_np.shape[1]  # dimension
    
#     # Approximate kappa estimation for high dimensions
#     if mean_cos_sim > 0.95:
#         kappa_est = d * (1 - mean_cos_sim) / (2 * (1 - mean_cos_sim**2))
#     else:
#         # More general approximation
#         kappa_est = mean_cos_sim * d / (1 - mean_cos_sim**2)
    
#     print(f"Estimated kappa (concentration parameter): {kappa_est:.4f}")
#     print(f"Mean cosine similarity: {mean_cos_sim:.6f}")
#     print(f"Std cosine similarity: {np.std(cos_similarities):.6f}")
    
#     # 3. Angular analysis
#     print("\n=== Angular Analysis ===")
    
#     # Calculate angles between test and extracted messages
#     angles = np.arccos(np.clip(cos_similarities, -1, 1))
#     angles_degrees = np.degrees(angles)
    
#     print(f"Mean angle: {np.mean(angles_degrees):.4f} degrees")
#     print(f"Std angle: {np.std(angles_degrees):.4f} degrees")
#     print(f"Min angle: {np.min(angles_degrees):.4f} degrees")
#     print(f"Max angle: {np.max(angles_degrees):.4f} degrees")
    
#     # 4. Noise magnitude analysis
#     print("\n=== Noise Magnitude Analysis ===")
    
#     noise_magnitudes = np.linalg.norm(noise_vectors, axis=1)
#     print(f"Mean noise magnitude: {np.mean(noise_magnitudes):.6f}")
#     print(f"Std noise magnitude: {np.std(noise_magnitudes):.6f}")
    
#     # 5. Generate plots
#     fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    
#     # Plot 1: Noise component histogram
#     axes[0, 0].hist(noise_flat, bins=50, density=True, alpha=0.7, label='Noise components')
#     x = np.linspace(noise_flat.min(), noise_flat.max(), 100)
#     axes[0, 0].plot(x, stats.norm.pdf(x, np.mean(noise_flat), np.std(noise_flat)), 
#                     'r-', label='Gaussian fit')
#     axes[0, 0].set_title('Noise Component Distribution')
#     axes[0, 0].set_xlabel('Noise Value')
#     axes[0, 0].set_ylabel('Density')
#     axes[0, 0].legend()
#     axes[0, 0].grid(True, alpha=0.3)
    
#     # Plot 2: Q-Q plot for normality
#     stats.probplot(noise_flat[::10], dist="norm", plot=axes[0, 1])  # Subsample for clarity
#     axes[0, 1].set_title('Q-Q Plot (Gaussian)')
#     axes[0, 1].grid(True, alpha=0.3)
    
#     # Plot 3: Cosine similarity distribution
#     axes[0, 2].hist(cos_similarities, bins=30, density=True, alpha=0.7)
#     axes[0, 2].set_title('Cosine Similarity Distribution')
#     axes[0, 2].set_xlabel('Cosine Similarity')
#     axes[0, 2].set_ylabel('Density')
#     axes[0, 2].grid(True, alpha=0.3)
    
#     # Plot 4: Angular distribution
#     axes[1, 0].hist(angles_degrees, bins=30, density=True, alpha=0.7)
#     axes[1, 0].set_title('Angular Distribution')
#     axes[1, 0].set_xlabel('Angle (degrees)')
#     axes[1, 0].set_ylabel('Density')
#     axes[1, 0].grid(True, alpha=0.3)
    
#     # Plot 5: Noise magnitude distribution
#     axes[1, 1].hist(noise_magnitudes, bins=30, density=True, alpha=0.7)
#     axes[1, 1].set_title('Noise Magnitude Distribution')
#     axes[1, 1].set_xlabel('Noise Magnitude')
#     axes[1, 1].set_ylabel('Density')
#     axes[1, 1].grid(True, alpha=0.3)
#     print(len(cos_similarities),len(noise_magnitudes))
#     # Plot 6: Correlation between cos similarity and noise magnitude
#     # axes[1, 2].scatter(cos_similarities[:99], noise_magnitudes[:99], alpha=0.6)
#     # axes[1, 2].set_title('Cosine Similarity vs Noise Magnitude')
#     # axes[1, 2].set_xlabel('Cosine Similarity')
#     # axes[1, 2].set_ylabel('Noise Magnitude')
#     # axes[1, 2].grid(True, alpha=0.3)
    
#     plt.tight_layout()
#     plt.savefig('noise_analysis.png', dpi=300, bbox_inches='tight')
#     plt.show()
    
#     # Return analysis results
#     return {
#         'gaussian_tests': {
#             'shapiro_wilk': {'statistic': shapiro_stat, 'p_value': shapiro_p},
#             'kolmogorov_smirnov': {'statistic': ks_stat, 'p_value': ks_p},
#             'jarque_bera': {'statistic': jb_stat, 'p_value': jb_p}
#         },
#         'vmf_analysis': {
#             'estimated_kappa': kappa_est,
#             'mean_cosine_similarity': mean_cos_sim,
#             'std_cosine_similarity': np.std(cos_similarities)
#         },
#         'angular_stats': {
#             'mean_angle_degrees': np.mean(angles_degrees),
#             'std_angle_degrees': np.std(angles_degrees),
#             'min_angle_degrees': np.min(angles_degrees),
#             'max_angle_degrees': np.max(angles_degrees)
#         },
#         'noise_stats': {
#             'mean_noise_magnitude': np.mean(noise_magnitudes),
#             'std_noise_magnitude': np.std(noise_magnitudes),
#             'mean_component_noise': np.mean(noise_flat),
#             'std_component_noise': np.std(noise_flat)
#         }
#     }

def collect_noise_data(encoder, decoder, val_loader, device='cuda', num_samples=200,
                       target_psnr=42.0, msg_dim=256):
    """Collect test messages and extracted messages for noise analysis.

    For each image (up to ``num_samples`` in total) the function:
    1. draws a random unit-norm message of length ``msg_dim``;
    2. embeds it with ``find_optimal_strength`` at ``target_psnr``;
    3. decodes the watermarked image;
    4. L2-normalizes the decoder output over all of its elements.

    Args:
        encoder: Watermark encoder passed to ``find_optimal_strength``.
        decoder: Decoder called on the watermarked image.
        val_loader: Iterable of image batches (tensors with a leading batch dim).
        device: Device the images and messages are moved to.
        num_samples: Maximum number of (message, extraction) pairs to collect.
        target_psnr: PSNR target forwarded to ``find_optimal_strength``
            (previously hard-coded to 42.0).
        msg_dim: Length of the random messages (previously hard-coded to 256).

    Returns:
        (test_msgs, extracted_msgs): two equally long lists of tensors. Each
        extracted tensor keeps the shape returned by the decoder.
    """
    test_msgs = []
    extracted_msgs = []

    with torch.no_grad():
        for batch in tqdm(val_loader):
            if len(test_msgs) >= num_samples:
                break

            img = batch.to(device)
            for i in range(img.size(0)):
                if len(test_msgs) >= num_samples:
                    break

                single_img = img[i:i + 1]

                # Drawn on the CPU generator and then moved, so a given torch
                # seed yields the same messages regardless of the device.
                test_msg = torch.randn(msg_dim).to(device)
                test_msg = test_msg / torch.norm(test_msg)

                _, watermarked = find_optimal_strength(
                    encoder, decoder, single_img,
                    test_msg,
                    target_psnr, device
                )

                extracted = decoder(watermarked)
                extracted = extracted / torch.norm(extracted)

                test_msgs.append(test_msg)
                extracted_msgs.append(extracted)

    return test_msgs, extracted_msgs

def main():
    """Legacy entry point: collect 200 samples, analyze them and print a summary.

    The summary reads only the keys that ``analyze_noise_characteristics``
    returns (``is_gaussian_component_wise``, ``vmf_kappa_estimate``,
    ``mean_cosine_similarity``). The former ``gaussian_tests`` /
    ``vmf_analysis`` / ``angular_stats`` keys no longer exist and caused a
    KeyError. The mean angular error is recomputed locally.
    """
    model_dir = "/home/gevennou/videoseal/models"
    val_dir = "/home/gevennou/BIG_storage/Paper2/coco_dataset/val2017"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_size = 1

    model_pairs = find_model_pairs(model_dir)
    print(f"Found {len(model_pairs)} encoder/decoder pairs")

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor()
    ])

    val_dataset = CocoValDataset(val_dir, transform=transform)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4
    )

    print(f"Loaded {len(val_dataset)} validation images")

    encoder = torch.jit.load(os.path.join(model_dir, "nautilus_900_enc.pth")).to(device).eval()
    decoder = torch.jit.load(os.path.join(model_dir, "nautilus_900_dec.pth")).to(device).eval()

    # Collect noise data
    print("Collecting test and extracted messages...")
    test_msgs, extracted_msgs = collect_noise_data(encoder, decoder, val_loader, device)

    # Analyze noise characteristics
    print(f"\nAnalyzing noise characteristics from {len(test_msgs)} samples...")
    noise_analysis = analyze_noise_characteristics(test_msgs, extracted_msgs)

    # Mean angle between original and extracted messages, in degrees.
    # Clipping keeps arccos defined despite rounding just outside [-1, 1].
    test_np = np.array([m.cpu().numpy() for m in test_msgs])
    extracted_np = np.array([m.cpu().numpy().squeeze() for m in extracted_msgs])
    cos_sims = np.clip(np.einsum('ij,ij->i', test_np, extracted_np), -1.0, 1.0)
    mean_angle_deg = float(np.degrees(np.arccos(cos_sims)).mean())

    is_gaussian = bool(noise_analysis['is_gaussian_component_wise'])
    mean_cos_sim = noise_analysis['mean_cosine_similarity']

    # Print summary
    print("\n" + "="*50)
    print("NOISE ANALYSIS SUMMARY")
    print("="*50)

    print(f"Gaussian tests passed: {int(is_gaussian)}/1")
    print(f"Mean cosine similarity: {mean_cos_sim:.6f}")
    print(f"Estimated vMF kappa: {noise_analysis['vmf_kappa_estimate']:.4f}")
    print(f"Mean angular error: {mean_angle_deg:.4f}°")

    if is_gaussian:
        print("\n🔍 CONCLUSION: Noise appears to be more GAUSSIAN-like")
    else:
        print("\n🔍 CONCLUSION: Noise does NOT follow a simple Gaussian distribution")

    if mean_cos_sim > 0.95:
        print("🔍 CONCLUSION: High cosine similarity suggests vMF-like behavior on hypersphere")
    else:
        print("🔍 CONCLUSION: Lower cosine similarity suggests significant directional noise")

from scipy import stats
import torch.nn.functional as F

def compare_noise_models(test_msgs_np, extracted_msgs_np, kappa_estimate, device='cuda'):
    """
    Compares the real noise distribution against two models:
    1. The Projected Normal ("add Gaussian then normalize") model.
    2. The von Mises-Fisher (vMF) model.

    The comparison uses the distribution of row-wise dot products between the
    original messages and the real or simulated decoded messages. Each model
    is tested against the real data with a two-sample Kolmogorov-Smirnov test,
    and a KDE plot is saved to ``gaussian_or_wmf_noise.png``.

    Args:
        test_msgs_np (np.ndarray): (N, d) array of original (ground-truth) messages.
        extracted_msgs_np (np.ndarray): (N, d) array of messages extracted by the decoder.
        kappa_estimate (float): The estimated kappa for the vMF model.
        device (str or torch.device): Torch device used for the Projected Normal simulation.

    Returns:
        dict with ``sigma_estimate`` and the KS statistic / p-value of each model
        versus the real data. The function previously returned None, so callers
        that ignore the result are unaffected.

    Raises:
        ValueError: if ``kappa_estimate`` is not a finite positive number, since
            sigma = 1/sqrt(kappa) would be NaN or infinite.
    """
    if not (np.isfinite(kappa_estimate) and kappa_estimate > 0):
        raise ValueError(f"kappa_estimate must be a finite positive number, got {kappa_estimate!r}")

    print("\n" + "="*60)
    print("NOISE MODEL COMPARISON: Real vs. Projected Normal vs. vMF")
    print("="*60)

    # --- 1. Statistics from REAL data ---
    real_cos_sims = np.einsum('ij,ij->i', test_msgs_np, extracted_msgs_np)

    # --- 2. Projected Normal simulation ---
    # For concentrated noise the projected normal behaves like a vMF with
    # kappa ~= 1/sigma^2, so sigma is derived from the kappa estimate.
    sigma_estimate = 1.0 / np.sqrt(kappa_estimate)
    print(f"Derived sigma for Projected Normal model: {sigma_estimate:.6f} (from kappa ≈ 1/sigma²)")

    z = torch.from_numpy(test_msgs_np).to(device)
    noise = torch.randn_like(z) * sigma_estimate
    z_noisy_projected = F.normalize(z + noise, p=2, dim=-1)

    projected_normal_cos_sims = np.einsum('ij,ij->i',
                                          test_msgs_np,
                                          z_noisy_projected.cpu().numpy())

    # --- 3. vMF simulation: one draw around each original message ---
    num_samples = test_msgs_np.shape[0]
    simulated_vmf_msgs = np.zeros_like(extracted_msgs_np)
    for i in tqdm(range(num_samples), desc="Simulating vMF samples"):
        simulated_vmf_msgs[i] = random_VMF(mu=test_msgs_np[i], kappa=kappa_estimate)

    vmf_cos_sims = np.einsum('ij,ij->i', test_msgs_np, simulated_vmf_msgs)

    # --- 4. K-S tests: which simulation is closer to the real data ---
    ks_stat_proj, p_val_proj = stats.ks_2samp(real_cos_sims, projected_normal_cos_sims)
    print("\n--- K-S Test: Real Noise vs. Projected Normal ---")
    print(f"K-S Statistic: {ks_stat_proj:.4f}, P-value: {p_val_proj:.4f}")

    ks_stat_vmf, p_val_vmf = stats.ks_2samp(real_cos_sims, vmf_cos_sims)
    print("\n--- K-S Test: Real Noise vs. von Mises-Fisher ---")
    print(f"K-S Statistic: {ks_stat_vmf:.4f}, P-value: {p_val_vmf:.4f}")

    # --- 5. Conclusion (smaller KS statistic = closer fit) ---
    print("\n--- Conclusion ---")
    if ks_stat_proj < ks_stat_vmf:
        print("🏆 The Projected Normal model (your method) is a slightly better fit to the real noise.")
        print("   This is great news, as it's much faster for training!")
    else:
        print("🏆 The von Mises-Fisher model is a slightly better statistical fit.")
        print("   However, the Projected Normal is likely a very close second and more practical.")

    # --- 6. Visualization ---
    fig = plt.figure(figsize=(12, 7))
    sns.kdeplot(real_cos_sims, label='Real Decoder Noise', lw=3, color='black', fill=True, alpha=0.1)
    sns.kdeplot(projected_normal_cos_sims, label=f'Projected Normal (σ={sigma_estimate:.4f})', lw=2, color='blue', linestyle='--')
    sns.kdeplot(vmf_cos_sims, label=f'vMF (κ={kappa_estimate:.1f})', lw=2, color='red', linestyle=':')
    plt.title('Comparison of Cosine Similarity Distributions')
    plt.xlabel('Cosine Similarity')
    plt.ylabel('Density')
    plt.legend()
    plt.grid(alpha=0.4)
    plt.savefig("gaussian_or_wmf_noise.png")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    return {
        "sigma_estimate": sigma_estimate,
        "ks_stat_projected_normal": ks_stat_proj,
        "p_value_projected_normal": p_val_proj,
        "ks_stat_vmf": ks_stat_vmf,
        "p_value_vmf": p_val_vmf,
    }

def second_main():
    """Main pipeline: collect 5000 samples, analyze them, demo the sampler, then validate the models.

    Steps:
    - Collects (message, extraction) pairs with the ``nautilus_900`` models.
    - Runs ``analyze_noise_characteristics`` and prints its summary.
    - Prints conclusions derived from the analysis results (they used to be
      printed unconditionally).
    - Demonstrates ``sample_from_noise_model`` and prints the demo output
      straight away.
    - Runs the t-test validation and the Projected Normal vs. vMF comparison.

    The sampler demo and the validations are skipped when the kappa estimate
    is not a finite positive number.
    """
    model_dir = "/home/gevennou/videoseal/models"
    val_dir = "/home/gevennou/BIG_storage/Paper2/coco_dataset/val2017"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_size = 1

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor()
    ])

    val_dataset = CocoValDataset(val_dir, transform=transform)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4
    )

    print(f"Loaded {len(val_dataset)} validation images")

    encoder = torch.jit.load(os.path.join(model_dir, "nautilus_900_enc.pth")).to(device).eval()
    decoder = torch.jit.load(os.path.join(model_dir, "nautilus_900_dec.pth")).to(device).eval()

    # Collect noise data
    print("Collecting test and extracted messages...")
    test_msgs, extracted_msgs = collect_noise_data(encoder, decoder, val_loader, device, num_samples=5000)

    # Analyze noise characteristics
    print(f"\nAnalyzing noise characteristics from {len(test_msgs)} samples...")
    analysis_results = analyze_noise_characteristics(test_msgs, extracted_msgs)
    estimated_kappa = analysis_results['vmf_kappa_estimate']

    # Print summary
    print("\n" + "="*50)
    print("NOISE ANALYSIS SUMMARY")
    print("="*50)

    print(f"Is noise component distribution uniform? -> {analysis_results['is_uniform']}")
    print(f"Is noise component distribution Gaussian? -> {analysis_results['is_gaussian_component_wise']}")
    print(f"Mean Cosine Similarity: {analysis_results['mean_cosine_similarity']:.6f}")
    print(f"Estimated vMF Kappa (Concentration): {estimated_kappa:.2f}")

    # Conclusions follow the actual test outcomes instead of being hard-coded.
    print("\n🔍 CONCLUSION:")
    if analysis_results['is_uniform']:
        print("The noise components could not be distinguished from a uniform distribution (KS p > 0.05).")
    else:
        print("The noise is NOT uniform.")
    if analysis_results['is_gaussian_component_wise']:
        print("Component-wise, the noise is consistent with a Gaussian (Shapiro-Wilk p > 0.05).")
    else:
        print("Component-wise, the noise is NOT Gaussian (Shapiro-Wilk p < 0.05).")
    # Same kappa threshold as used in analyze_noise_characteristics.
    if estimated_kappa > 10:
        print("The high kappa value indicates that the extracted messages are tightly clustered around the original messages on the hypersphere.")
    else:
        print("The low kappa value indicates that the extracted messages are widely spread around the original messages on the hypersphere.")

    # The vMF samplers and the model comparison need a finite, positive kappa.
    if not (np.isfinite(estimated_kappa) and estimated_kappa > 0):
        print("\nSkipping the noise sampler demo and model validation: kappa estimate is not a finite positive number.")
        return

    # --- DEMONSTRATION OF THE NOISE SAMPLER ---
    print("\n" + "="*50)
    print("DEMONSTRATING THE NOISE SAMPLER")
    print("="*50)

    # Use the first test message as an example.
    original_message = test_msgs[0]
    sampled_noisy_message = sample_from_noise_model(original_message, estimated_kappa)

    original_cpu = original_message.cpu()
    cos_sim_demo = torch.dot(original_cpu, sampled_noisy_message) / (torch.norm(original_cpu) * torch.norm(sampled_noisy_message))

    # Print the demo results under their own header, before the validation output.
    print(f"Original message norm: {torch.norm(original_message):.4f}")
    print(f"Sampled noisy message norm: {torch.norm(sampled_noisy_message):.4f}")
    print(f"Cosine similarity between original and sampled: {cos_sim_demo:.6f}")
    print("\nThis new `sampled_noisy_message` is a realistic simulation of a decoded message.")
    print(f"You can now generate realistic noise by calling `sample_from_noise_model(msg, kappa={estimated_kappa:.2f})`.")

    # Convert lists to numpy arrays for the validation functions.
    test_msgs_np = np.array([msg.cpu().numpy() for msg in test_msgs])
    extracted_msgs_np = np.array([msg.cpu().numpy().squeeze() for msg in extracted_msgs])

    validate_vmf_with_ttest(test_msgs_np, extracted_msgs_np, estimated_kappa)
    compare_noise_models(test_msgs_np, extracted_msgs_np, estimated_kappa, device=device)


if __name__ == "__main__":
    # Script entry point: runs second_main(). main(), defined above, is not invoked here.
    second_main()