#!/usr/bin/env python3
"""
Run convergence analysis on a SINGLE test point (index 0 by default) with multiple N values.

This script:
1. Loads the rotated MNIST test data
2. Selects the test point at --image_idx (default: 0)
3. Runs estimation for each N in --N_values (default: [100, 500, 1000, 5000, 10000, 50000])
4. Repeats each N for --n_trials independent trials (default: 10)
5. Uses an N=50000 estimate as ground truth
6. Saves the results as JSON for convergence plotting

Usage:
    python scripts/run_single_point_convergence_analysis.py \
        --sigma 0.5 \
        --eps_y 10.0 \
        --confidence 0.95 \
        --device cpu
"""

import torch
import numpy as np
import json
import argparse
from pathlib import Path
import sys
from datetime import datetime
from tqdm import tqdm

# Add paths
sys.path.append(str(Path(__file__).resolve().parent.parent / "src"))
sys.path.append(str(Path(__file__).resolve().parent.parent / "experiments/mnist_rotation"))

from bounded_certifier_convergence_analysis import BoundedCertifierConvergenceValidator
from dataset_generator import load_mnist_rotation_datasets
from e2cnn_rotation_model import RotationEquivariantCNN_Simple, cos_sin_to_angle


def load_model_and_data(model_path, device='cpu'):
    """Build the rotation model, load its weights, and fetch the rotated-MNIST test loader.

    Args:
        model_path: Path to a saved ``state_dict`` for ``RotationEquivariantCNN_Simple(N=8)``.
        device: Torch device string; the model is moved there and weights are mapped onto it.

    Returns:
        ``(model, test_loader)`` with the model already switched to eval mode. The
        loader is built with ``batch_size=1`` and a fixed ``seed=42``.
    """
    net = RotationEquivariantCNN_Simple(N=8).to(device)
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.eval()
    print(f"✓ Loaded model from {model_path}")

    print("Loading MNIST rotation dataset with ground truth angles...")
    loader_kwargs = dict(
        rotation_range=(0.0, 360.0),
        augmentation_factor=1,
        batch_size=1,
        seed=42,
    )
    _train_loader, test_loader = load_mnist_rotation_datasets(**loader_kwargs)
    n_images = len(test_loader.dataset)
    print(f"✓ Loaded rotation test dataset ({n_images} images with ground truth angles)")

    return net, test_loader


def model_predict_angle(model, image_normalized_np, device='cpu'):
    """
    Predict the rotation angle of a single, already-normalized image.

    Args:
        model: Trained e2cnn model whose output is decoded by ``cos_sin_to_angle``
        image_normalized_np: (28, 28) numpy array that has already been normalized
        device: Torch device used for the forward pass

    Returns:
        Predicted angle in radians [-π, π]
    """
    # Add batch and channel axes: (28, 28) -> (1, 1, 28, 28).
    batch = torch.from_numpy(image_normalized_np).float()[None, None].to(device)

    with torch.no_grad():
        cos_sin = model(batch)

    # cos_sin_to_angle yields degrees in [-180, 180]; report radians instead.
    return np.radians(cos_sin_to_angle(cos_sin).item())


def model_predict_angles_batch(model, images_normalized_np, device='cpu', batch_size=1000):
    """
    Predict angles for a stack of already-normalized images, chunk by chunk.

    Args:
        model: Trained e2cnn model; it is put into eval mode by this call
        images_normalized_np: (N, 28, 28) numpy array, already normalized
        device: Torch device used for the forward passes
        batch_size: Number of images per forward pass (larger is better on GPU)

    Returns:
        float64 numpy array of shape (N,) with predicted angles in radians [-π, π]
    """
    total = images_normalized_np.shape[0]
    out = np.zeros(total)

    model.eval()
    with torch.no_grad():
        for start in range(0, total, batch_size):
            stop = min(start + batch_size, total)
            # (B, 28, 28) -> (B, 1, 28, 28): insert the channel axis.
            chunk = torch.from_numpy(images_normalized_np[start:stop]).float().unsqueeze(1).to(device)
            degrees = cos_sin_to_angle(model(chunk)).cpu().numpy()
            out[start:stop] = np.radians(degrees)

    return out


def evaluate_model_at_point(model, image, noise, device='cpu'):
    """
    Predict the angle of one noise-perturbed copy of ``image``.

    The noise is added in [0, 1] pixel space and the result is clipped back into
    [0, 1] before normalization, matching mnist_rotation_full_certification.py.

    Args:
        model: Trained model
        image: Original image numpy array [28, 28] in [0, 1] pixel space
        noise: Gaussian noise to add [28, 28]
        device: Device

    Returns:
        Predicted angle in radians (scalar)
    """
    pixels = np.clip(image + noise, 0.0, 1.0)
    # Same (x - 0.1307) / 0.3081 normalization the rest of this script applies.
    return model_predict_angle(model, (pixels - 0.1307) / 0.3081, device)


def _extract_test_point(dataset, image_idx):
    """Fetch one dataset item as ``(image_np, true_angle_deg)``.

    ``image_np`` is the squeezed image as a numpy array; ``true_angle_deg`` is the
    label as a Python number, or ``None`` when the dataset exposes no label.
    """
    sample = dataset[image_idx]

    if isinstance(sample, (list, tuple)) and len(sample) >= 2:
        image_tensor, true_angle_deg = sample[0], sample[1]
        if isinstance(true_angle_deg, torch.Tensor):
            true_angle_deg = true_angle_deg.item()
    else:
        # Unlabelled sample: fall back to a dataset-level accessor if one exists.
        image_tensor = sample[0] if isinstance(sample, (list, tuple)) else sample
        true_angle_deg = dataset.get_true_angle(image_idx) if hasattr(dataset, 'get_true_angle') else None

    if isinstance(image_tensor, torch.Tensor):
        image_np = image_tensor.squeeze().cpu().numpy()
    else:
        image_np = np.array(image_tensor).squeeze()
    return image_np, true_angle_deg


def _to_model_input(pixels):
    """Clip pixel-space image(s) to [0, 1] and apply (x - 0.1307) / 0.3081 normalization."""
    return (np.clip(pixels, 0.0, 1.0) - 0.1307) / 0.3081


def _predict_under_noise(model, image_np, eta, device):
    """Predicted angles (radians, shape (N,)) for ``image_np + eta[k]`` for every noise sample k."""
    # Batch size is only a throughput knob: large for GPU, small for CPU.
    batch_size = 1000 if str(device).startswith('cuda') else 100
    return model_predict_angles_batch(
        model, _to_model_input(image_np[None, ...] + eta), device=device, batch_size=batch_size
    )


def _mean_confidence_interval(values, confidence, n_quantities=3):
    """Return ``(mean, lower, upper)``: a two-sided Student-t interval for the mean.

    The total miscoverage ``1 - confidence`` is split evenly over ``n_quantities``
    jointly reported quantities (union bound), so this interval is built at level
    ``1 - (1 - confidence) / n_quantities``.
    """
    from scipy.stats import t as student_t

    alpha_split = (1 - confidence) / n_quantities
    n = len(values)
    mean = float(np.mean(values))
    se = float(np.std(values, ddof=1)) / np.sqrt(n)
    t_critical = student_t.ppf(1 - alpha_split / 2, df=n - 1)
    return mean, mean - t_critical * se, mean + t_critical * se


def _estimate_ground_truth(model, image_np, sigma, validator, ground_truth_N, rng, device):
    """Reference values ``(g_z, C, ||G||)`` from one large draw of ``ground_truth_N`` noise samples.

    Arrays created here (noise, perturbed images) are released on return, so they
    do not stay resident during the per-N trials.
    """
    eta = rng.normal(0.0, sigma, size=(ground_truth_N,) + image_np.shape)
    f_values = _predict_under_noise(model, image_np, eta, device)

    # NOTE(review): plain mean/variance of angles in radians; only meaningful if the
    # predictions do not straddle the ±π wrap-around — confirm for the chosen point.
    g_z = float(np.mean(f_values))
    C = float(np.var(f_values, ddof=0))
    G_norm, _, _ = validator.u_statistic_gradient_norm_estimator_alpha_half(
        f_values, eta.reshape(ground_truth_N, -1)
    )
    return g_z, C, float(G_norm)


def _run_trial(model, image_np, sigma, N, validator, confidence, rng, device):
    """Run one estimation trial with ``N`` fresh noise samples; return the estimate fields."""
    eta = rng.normal(0.0, sigma, size=(N,) + image_np.shape)
    f_values = _predict_under_noise(model, image_np, eta, device)
    eta_flat = eta.reshape(N, -1)  # (N, D) layout passed to the gradient estimators

    # Variance via U-statistic (consistent with mnist_rotation_full_certification.py)
    C_hat, C_lower, C_upper = validator.u_statistic_variance_estimator_alpha_half(f_values)
    # θ = ||G||² via U-statistic with z-critical CI
    theta_hat, theta_lower, theta_upper = validator.compute_theta_ci_with_z_critical(
        f_values, eta_flat, confidence=validator.confidence
    )
    # ||G|| point estimate and CI directly
    G_hat, G_lower, G_upper = validator.u_statistic_gradient_norm_estimator_alpha_half(
        f_values, eta_flat
    )
    # Mean (g_z): union bound over the three reported quantities (E, C, G) -> α/3 each
    g_z_hat, g_z_lower, g_z_upper = _mean_confidence_interval(f_values, confidence, n_quantities=3)

    return {
        'C_hat': float(C_hat),
        'C_lower_analytical': float(C_lower),
        'C_upper_analytical': float(C_upper),
        'G_norm_hat': float(G_hat),
        'G_norm_lower': float(G_lower),
        'G_norm_upper': float(G_upper),
        'theta_hat': float(theta_hat),
        'theta_lower': float(theta_lower),
        'theta_upper': float(theta_upper),
        'g_z_hat': float(g_z_hat),
        'g_z_lower': float(g_z_lower),
        'g_z_upper': float(g_z_upper),
    }


def run_single_point_convergence(model, test_loader, image_idx, sigma, eps_y_rad,
                                 N_values, n_trials, confidence, device='cpu', seed=42,
                                 ground_truth_N=50000):
    """
    Run convergence analysis on a single test point.

    A high-fidelity reference ``(g_z, C, ||G||)`` is estimated once from
    ``ground_truth_N`` noise samples. Then, for every N in ``N_values``,
    ``n_trials`` independent estimates (with confidence bounds) are computed.

    Args:
        model: Trained rotation model.
        test_loader: Loader whose ``.dataset`` holds the test images.
        image_idx: Index of the test point in ``test_loader.dataset``.
        sigma: Std of the Gaussian pixel-space noise.
        eps_y_rad: Output tolerance in radians (passed to the validator).
        N_values: Sample sizes to evaluate.
        n_trials: Number of independent trials per N.
        confidence: Joint confidence level for the reported intervals.
        device: Torch device string.
        seed: Base seed. Trial k (counted across all N) uses ``seed + k``. The
            ground-truth draw uses a stream spawned from ``seed`` that is
            independent of every trial. ``None`` means unseeded.
        ground_truth_N: Number of samples for the reference estimate.

    Returns:
        Results dictionary with the ground truth, run parameters, and a list of
        per-trial estimates for each N under ``'results_by_N'``.
    """
    image_np, true_angle_deg = _extract_test_point(test_loader.dataset, image_idx)

    # The noise pipeline assumes a [0, 1] pixel-space image; flag anything else.
    if image_np.min() < 0.0 or image_np.max() > 1.0:
        print(f"⚠ Test image values lie outside [0, 1] (min={image_np.min():.3f}, "
              f"max={image_np.max():.3f}); clipping + normalization assume pixel-space input.")

    # The model expects normalized input, so the clean prediction uses the same
    # clip + normalize path as the noisy samples.
    clean_pred_rad = model_predict_angle(model, _to_model_input(image_np), device)
    clean_pred_deg = np.degrees(clean_pred_rad)

    print(f"\nTest point {image_idx}:")
    print(f"  True angle: {true_angle_deg:.2f}°" if true_angle_deg is not None else "  True angle: unknown")
    print(f"  Clean prediction: {clean_pred_deg:.2f}°")

    validator = BoundedCertifierConvergenceValidator(
        sigma=sigma,
        eps_y=eps_y_rad,
        confidence=confidence
    )

    print(f"\nEstimating ground truth using N={ground_truth_N}...")
    # Trials use default_rng(seed + k). Seeding the ground truth with default_rng(seed)
    # would make trial 0 replay the ground-truth noise exactly (zero error when
    # N == ground_truth_N). A spawned child SeedSequence gives an independent stream.
    gt_seed = None if seed is None else np.random.SeedSequence(seed).spawn(1)[0]
    g_z_true, C_true, G_norm_true = _estimate_ground_truth(
        model, image_np, sigma, validator, ground_truth_N, np.random.default_rng(gt_seed), device
    )
    print(f"  Ground truth: g_z={g_z_true:.6f}, C={C_true:.6f}, ||G||={G_norm_true:.6f}")

    results = {
        'image_idx': image_idx,
        'true_angle_deg': float(true_angle_deg) if true_angle_deg is not None else None,
        'clean_pred_deg': float(clean_pred_deg),
        'ground_truth': {
            'g_z': g_z_true,
            'C': C_true,
            'G_norm': G_norm_true
        },
        'parameters': {
            'sigma': sigma,
            'eps_y_deg': float(np.degrees(eps_y_rad)),
            'eps_y_rad': float(eps_y_rad),
            'confidence': confidence,
            'N_values': N_values,
            'n_trials': n_trials,
            'ground_truth_N': ground_truth_N
        },
        'results_by_N': {N: [] for N in N_values}
    }

    trial_count = 0
    for N in N_values:
        print(f"\nTesting N={N}: ", end="", flush=True)
        for i in range(n_trials):
            trial_seed = seed + trial_count if seed is not None else None
            rng = np.random.default_rng(trial_seed)

            estimates = _run_trial(model, image_np, sigma, N, validator, confidence, rng, device)
            results['results_by_N'][N].append({
                **estimates,
                'C_true': C_true,
                'G_norm_true': G_norm_true,
                'g_z_true': g_z_true,
                'N_samples': N,
                'trial': i
            })
            trial_count += 1

            if (i + 1) % 5 == 0:
                print(f"{i+1} ", end="", flush=True)
        print(f"({n_trials} trials)")

    return results


def main():
    """Parse CLI arguments, run the single-point convergence analysis, and save JSON results."""
    parser = argparse.ArgumentParser(description="Run convergence analysis on single test point")
    parser.add_argument("--model_path", type=str,
                       default="experiments/mnist_rotation/e2cnn_rotation_model.pth",
                       help="Path to trained model")
    parser.add_argument("--image_idx", type=int, default=0,
                       help="Test image index (default: 0)")
    parser.add_argument("--sigma", type=float, default=0.5,
                       help="Noise standard deviation")
    parser.add_argument("--eps_y", type=float, default=10.0,
                       help="Output tolerance in degrees")
    parser.add_argument("--N_values", nargs="+", type=int,
                       default=[100, 500, 1000, 5000, 10000, 50000],
                       help="Sample sizes to test")
    parser.add_argument("--n_trials", type=int, default=10,
                       help="Number of trials per N")
    parser.add_argument("--confidence", type=float, default=0.95,
                       help="Confidence level")
    parser.add_argument("--device", type=str, default="cpu",
                       choices=["cpu", "cuda"],
                       help="Device to use")
    parser.add_argument("--seed", type=int, default=42,
                       help="Random seed")
    parser.add_argument("--output", type=str, default=None,
                       help="Output JSON file path (auto-generated if not provided)")

    args = parser.parse_args()

    # Reject inputs that would otherwise only fail deep inside a long run.
    if args.device == "cuda" and not torch.cuda.is_available():
        parser.error("--device cuda requested but torch.cuda.is_available() is False")
    if not 0.0 < args.confidence < 1.0:
        parser.error("--confidence must be strictly between 0 and 1")
    if args.sigma <= 0:
        parser.error("--sigma must be positive")
    if args.n_trials < 1:
        parser.error("--n_trials must be at least 1")
    # Sample std (ddof=1) and pairwise U-statistics need at least two samples.
    if min(args.N_values) < 2:
        parser.error("every --N_values entry must be at least 2")

    eps_y_rad = np.radians(args.eps_y)

    print("="*80)
    print("SINGLE POINT CONVERGENCE ANALYSIS")
    print("="*80)
    print(f"Model: {args.model_path}")
    print(f"Test point index: {args.image_idx}")
    print(f"Sigma: {args.sigma}")
    print(f"Epsilon_y: {args.eps_y}° ({eps_y_rad:.4f} rad)")
    print(f"N values: {args.N_values}")
    print(f"Trials per N: {args.n_trials}")
    print(f"Confidence: {args.confidence}")
    print(f"Device: {args.device}")
    print("="*80)

    model, test_loader = load_model_and_data(args.model_path, args.device)

    results = run_single_point_convergence(
        model=model,
        test_loader=test_loader,
        image_idx=args.image_idx,
        sigma=args.sigma,
        eps_y_rad=eps_y_rad,
        N_values=args.N_values,
        n_trials=args.n_trials,
        confidence=args.confidence,
        device=args.device,
        seed=args.seed
    )

    if args.output is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        args.output = f"mnist_rotation_convergence_img{args.image_idx}_sigma{args.sigma}_N{'_'.join(map(str, args.N_values))}_n{args.n_trials}_{timestamp}.json"

    # A user-supplied --output may point into a directory that does not exist yet.
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"\n✓ Saved results to: {args.output}")
    print("\nNext step: Generate plots using:")
    print("  python3 scripts/plot_convergence_from_rotated_mnist_estimation.py \\")
    print(f"      --estimation_file {args.output} \\")
    print(f"      --output figures/convergence_mnist_rotated_bounded_sigma{args.sigma}.png")


if __name__ == '__main__':
    # Run the analysis only when executed as a script, never on import.
    main()
