"""
Utilities for generating out-of-distribution points.

This module provides functions for generating OOD points with controllable
difficulty levels. OOD points are generated by sampling from a grid in the
label space and selecting points based on their distance to training data.

Difficulty Levels:
- "easy": Points far from training data (mean dist ~0.30)
- "medium": Points at moderate distance (mean dist ~0.10)
- "hard": Points just outside the boundary (mean dist ~0.04)

The "hard" level creates a more challenging OOD detection task by selecting
points that are geometrically close to the training distribution but still
outside it.
"""
import torch
from typing import Optional, Tuple, Dict


# Named difficulty presets: each maps to the band of nearest-neighbour
# distances [min_dist, max_dist] that an OOD point must fall into.
# The bands are calibrated for labels normalized to the [0, 1] range.
DIFFICULTY_CONFIGS = {
    # Far from data
    "easy": {
        "min_dist": 0.15,
        "max_dist": 1.0,
    },
    # Moderate distance
    "medium": {
        "min_dist": 0.05,
        "max_dist": 0.15,
    },
    # Near boundary
    "hard": {
        "min_dist": 0.02,
        "max_dist": 0.08,
    },
}


def get_ood_points(
    train_labels: torch.Tensor,
    threshold: float = 0.05,
    grid_steps: int = 25,
    max_threshold: Optional[float] = None,
    grid_margin: float = 0.1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate OOD points as complement of training distribution.

    Creates a regular grid in the label space and keeps the grid points whose
    **normalized** distance to the nearest training point lies in the
    requested band. Distances are computed after scaling each dimension to
    [0, 1] (min-max over ``train_labels``), so thresholds are relative to the
    data range (e.g., 0.02 means 2% of the range).

    Progress information (grid size and number of selected points) is
    printed to stdout.

    Args:
        train_labels: Training labels tensor (N, L) with L in {2, 3}.
        threshold: Minimum normalized distance from training data (default: 0.05).
                   This is relative to data range per dimension.
        grid_steps: Number of grid points per dimension (default: 25).
                    The full grid has ``grid_steps ** L`` points.
        max_threshold: Optional maximum normalized distance from training data.
                       If provided, only points with distance in
                       [threshold, max_threshold] are selected.
                       This enables "near-boundary" OOD generation.
        grid_margin: Margin to extend beyond data bounds (as fraction of range).
                     Default 0.1 means 10% margin on each side.

    Returns:
        ood_points: Tensor of OOD points in original scale (M, L)
        ood_points_dists: Normalized distance to nearest training point for
            each OOD point, shape (M,). M may be 0 if no grid point qualifies.

    Raises:
        ValueError: If ``train_labels`` is not a non-empty 2-D tensor, if the
            label dimension is not 2 or 3, or if ``grid_steps`` is < 1.
    """
    if train_labels.dim() != 2:
        raise ValueError(
            f"train_labels must be a 2-D tensor of shape (N, L), got shape {tuple(train_labels.shape)}"
        )
    if train_labels.shape[0] == 0:
        raise ValueError("train_labels must contain at least one point")
    if grid_steps < 1:
        raise ValueError(f"grid_steps must be >= 1, got {grid_steps}")

    # Get label dimension
    L = train_labels.shape[1]
    if L not in (2, 3):
        raise ValueError(f"Only 2D and 3D labels supported, got {L}D")

    device = train_labels.device

    # Compute data statistics for normalization
    data_min = train_labels.min(dim=0).values
    data_max = train_labels.max(dim=0).values
    data_range = data_max - data_min

    # Avoid division by zero for constant dimensions
    data_range = torch.where(data_range > 1e-8, data_range, torch.ones_like(data_range))

    # Normalize training labels to [0, 1] per dimension
    train_labels_norm = (train_labels - data_min) / data_range

    # Grid bounds with margin (in normalized space, this is [-margin, 1+margin]).
    # The grid is created in the same floating dtype as the normalized labels:
    # torch.cdist needs both operands to share a dtype, so a default-dtype
    # (float32) grid would break for float64 labels.
    grid_1d = torch.linspace(
        -grid_margin,
        1.0 + grid_margin,
        grid_steps,
        device=device,
        dtype=train_labels_norm.dtype,
    )

    # Cartesian product of the 1-D axis, one column per label dimension
    # ('ij' indexing: the first dimension varies slowest).
    axes = torch.meshgrid(*([grid_1d] * L), indexing='ij')
    grid_norm = torch.stack([axis.flatten() for axis in axes], dim=1)

    print(f"Grid shape: {grid_norm.shape}")

    # Compute distances in normalized space
    # Process in batches to avoid memory issues
    batch_size = 1000
    ood_points_norm = []
    ood_dists = []

    for start in range(0, grid_norm.shape[0], batch_size):
        grid_batch = grid_norm[start:start + batch_size]  # (batch, L)

        # Compute pairwise distances in normalized space
        dists = torch.cdist(grid_batch, train_labels_norm, p=2)  # (batch, N_train)

        # Get minimum distance to training data for each grid point
        min_dists = dists.min(dim=1).values  # (batch,)

        # Select points based on distance thresholds (in normalized space)
        mask = min_dists >= threshold
        if max_threshold is not None:
            mask = mask & (min_dists <= max_threshold)

        ood_points_norm.append(grid_batch[mask])
        ood_dists.append(min_dists[mask])

    ood_points_norm = torch.cat(ood_points_norm, dim=0)
    ood_dists = torch.cat(ood_dists, dim=0)

    # Convert OOD points back to original scale
    ood_points = ood_points_norm * data_range + data_min

    if max_threshold is not None:
        print(f"Found {ood_points.shape[0]} OOD points out of {grid_norm.shape[0]} grid points "
              f"(normalized dist in [{threshold:.3f}, {max_threshold:.3f}])")
    else:
        print(f"Found {ood_points.shape[0]} OOD points out of {grid_norm.shape[0]} grid points")

    return ood_points, ood_dists


def get_ood_points_by_difficulty(
    train_labels: torch.Tensor,
    difficulty: str = "medium",
    n_points: Optional[int] = None,
    grid_steps: int = 30,
    seed: int = 42,
) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
    """Generate OOD points with controllable difficulty level.

    Looks up the distance band for ``difficulty`` in ``DIFFICULTY_CONFIGS``,
    delegates grid generation and filtering to ``get_ood_points`` and
    optionally subsamples the result.

    Difficulty levels control how close OOD points are to the training distribution:
    - "easy": Points far from training data (mean dist ~0.30)
    - "medium": Points at moderate distance (mean dist ~0.10)
    - "hard": Points just outside the training distribution boundary (mean dist ~0.04)

    The "hard" level is useful for evaluating UQ methods on more challenging
    OOD detection tasks where points are geometrically similar to training data.

    Args:
        train_labels: Training labels tensor (N, L)
        difficulty: One of "easy", "medium", "hard"
        n_points: Target number of OOD points. If None, returns all found points.
                  If provided and more points are found, subsamples randomly.
                  If fewer points are found, all of them are returned.
        grid_steps: Grid resolution for sampling (default: 30). For the
                    "hard" difficulty at least 40 steps are used.
        seed: Random seed for subsampling (default: 42)

    Returns:
        ood_points: Tensor of OOD points
        ood_dists: Distance to nearest training point
        config: Dictionary with the configuration used, including:
            - difficulty: The difficulty level used
            - min_dist: Minimum distance threshold
            - max_dist: Maximum distance threshold
            - n_found: Number of points found before subsampling
            - n_returned: Number of points actually returned
            - grid_steps: Grid resolution used (after any adjustment)

    Raises:
        ValueError: If ``difficulty`` is unknown or ``n_points`` is negative.

    Example:
        >>> train_labels = torch.rand(1000, 3)
        >>> ood_points, dists, config = get_ood_points_by_difficulty(
        ...     train_labels, difficulty="hard", n_points=500
        ... )
        >>> print(f"Mean dist to training: {dists.mean():.3f}")
    """
    if difficulty not in DIFFICULTY_CONFIGS:
        raise ValueError(f"Unknown difficulty: {difficulty}. Use one of: {list(DIFFICULTY_CONFIGS.keys())}")
    # A negative count would otherwise be interpreted as a "drop the last k"
    # slice below and silently return the wrong number of points.
    if n_points is not None and n_points < 0:
        raise ValueError(f"n_points must be non-negative, got {n_points}")

    config = DIFFICULTY_CONFIGS[difficulty]
    min_dist = config["min_dist"]
    max_dist = config["max_dist"]

    # Use finer grid for hard difficulty to find more near-boundary points
    if difficulty == "hard":
        grid_steps = max(grid_steps, 40)

    # Generate OOD points using the threshold-based function
    ood_points, ood_dists = get_ood_points(
        train_labels,
        threshold=min_dist,
        grid_steps=grid_steps,
        max_threshold=max_dist,
    )

    n_found = ood_points.shape[0]

    # Subsample if requested and we have more points than needed
    if n_points is not None and n_found > n_points:
        # Use a private generator so the caller's global RNG state is not
        # reseeded as a side effect of subsampling.
        generator = torch.Generator().manual_seed(seed)
        perm = torch.randperm(n_found, generator=generator)[:n_points]
        perm = perm.to(ood_points.device)
        ood_points = ood_points[perm]
        ood_dists = ood_dists[perm]

    result_config = {
        "difficulty": difficulty,
        "min_dist": min_dist,
        "max_dist": max_dist,
        "n_found": n_found,
        "n_returned": ood_points.shape[0],
        "grid_steps": grid_steps,
    }

    return ood_points, ood_dists, result_config
