from __future__ import annotations

from typing import Optional, Tuple, Callable
from dataclasses import dataclass

from networkx import radius
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, TensorDataset


# ============================================================================
# Core Functions
# ============================================================================

@torch.no_grad()
def sample_scaled_ball(
    n_samples: int,
    dim: int,
    radius: float,
    device: torch.device,
) -> torch.Tensor:
    """
    Draw points uniformly from the sphere {u : ||u|| = radius}.

    Despite the function name, the samples lie on the *surface* of the ball:
    isotropic Gaussian draws are normalised to unit length and then scaled.

    Args:
        n_samples: Number of points to draw.
        dim: Dimension of the ambient space.
        radius: Radius of the sphere (callers in this file pass 1 - alpha).
        device: Torch device on which the samples are created.

    Returns:
        Tensor of shape (n_samples, dim) whose rows have norm ``radius``
        (up to the 1e-12 stabiliser used in the normalisation).
    """
    gaussian = torch.randn(n_samples, dim, device=device)
    lengths = torch.linalg.norm(gaussian, dim=-1, keepdim=True)
    # The tiny constant guards against dividing by an exactly-zero norm.
    unit_dirs = gaussian / (lengths + 1e-12)
    return radius * unit_dirs


@torch.no_grad()
def approximate_region_boundary(
    model_G: nn.Module,
    x: torch.Tensor,
    radius: float,
    n_boundary_samples: int = 256,
) -> torch.Tensor:
    """
    Approximate the boundary of G(x, radius * B^d) by pushing sphere samples
    through the model.

    One set of ``n_boundary_samples`` points is drawn on the sphere of the
    given radius and reused for every row of ``x``.

    Args:
        model_G: Model called as ``model_G(x, u)``; its ``y_dim`` attribute
            (or ``u_dim`` when ``y_dim`` is absent) gives the sample dimension.
        x: Conditioning variable, shape (B, x_dim).
        radius: Radius of the sphere the u-samples are drawn from.
        n_boundary_samples: Number of sphere samples per conditioning row.

    Returns:
        Boundary points, shape (B, n_boundary_samples, y_dim).
    """
    batch, _x_dim = x.shape
    if hasattr(model_G, 'y_dim'):
        out_dim = model_G.y_dim
    else:
        out_dim = model_G.u_dim

    # (M, d) sphere samples shared by the whole batch.
    sphere = sample_scaled_ball(n_boundary_samples, out_dim, radius, x.device)

    # Pair every x row with every sphere sample, flattened to (B*M, .):
    # row b*M + m holds (x[b], sphere[m]).
    x_rep = x.repeat_interleave(n_boundary_samples, dim=0)
    u_rep = sphere.repeat(batch, 1)

    mapped = model_G(x_rep, u_rep)  # (B*M, y_dim)
    return mapped.reshape(batch, n_boundary_samples, out_dim)


@torch.no_grad()
def distance_to_region(
    y: torch.Tensor,
    boundary_points: torch.Tensor,
    k_neighbors: int = 32,
) -> torch.Tensor:
    """
    Compute a SIGNED distance proxy from y to the boundary of G(x, (1-α)B^d).

    Sign convention:
        - d < 0 : y is INSIDE the region
        - d = 0 : y is on the BOUNDARY
        - d > 0 : y is OUTSIDE the region

    Method:
        1. Take the centroid of the boundary samples as the region centre.
        2. Find the k boundary samples whose direction (from the centroid) has
           the highest cosine similarity with the direction towards y.
        3. Average their projections onto that direction with softmax weights
           (temperature 10) to obtain an interpolated radius.
        4. Use the larger of the interpolated radius and the maximum
           projection over all boundary samples as the radius in that
           direction, and return ``||y - centroid|| - radius``.

    NOTE(review): the softmax weights are non-negative and sum to one, so the
    interpolated radius cannot exceed the maximum projection (up to floating
    point rounding). The returned radius is therefore effectively the maximum
    projection, i.e. the same quantity ``distance_to_region_simple`` uses.

    Args:
        y: Target points, shape (B, y_dim); or (N, y_dim) together with a
            single shared boundary of shape (1, M, y_dim).
        boundary_points: Boundary samples, shape (B, M, y_dim).
        k_neighbors: Upper bound on the number of directional neighbours; the
            effective k is also capped by M // 4 and by max(8, 2 * y_dim).

    Returns:
        Signed distances, one per target point (negative inside, positive
        outside).
    """
    n_regions, n_pts, dim = boundary_points.shape

    # Shared-boundary mode: N targets against a single boundary set.
    if y.dim() == 2 and boundary_points.dim() == 3:
        n_targets = y.shape[0]
        if n_regions == 1 and n_targets != n_regions:
            boundary_points = boundary_points.expand(n_targets, n_pts, dim)

    # Effective neighbour count, limited by the available samples and by the
    # dimension.
    k = min(k_neighbors, n_pts // 4, max(8, 2 * dim))

    # Region centre and the offset of each target from it.
    center = boundary_points.mean(dim=1)  # (B, d)
    offset = y - center  # (B, d)
    offset_len = torch.linalg.norm(offset, dim=-1, keepdim=True)  # (B, 1)
    d_to_center = offset_len.squeeze(-1)  # (B,)
    heading = (offset / (offset_len + 1e-12)).unsqueeze(1)  # (B, 1, d)

    # Boundary samples relative to the centre, and their unit directions.
    rel = boundary_points - center.unsqueeze(1)  # (B, M, d)
    rel_len = torch.linalg.norm(rel, dim=-1, keepdim=True) + 1e-12  # (B, M, 1)

    cos_sim = ((rel / rel_len) * heading).sum(dim=-1)  # (B, M)
    projections = (rel * heading).sum(dim=-1)  # (B, M)

    # k samples best aligned with the target direction.
    top_cos, top_idx = torch.topk(cos_sim, k=k, dim=1)  # (B, k)
    top_proj = projections.gather(1, top_idx)  # (B, k)

    # Softmax (temperature 10) keeps weights positive and favours the best
    # aligned samples.
    weights = torch.softmax(10 * top_cos, dim=1)
    knn_radius = (weights * top_proj).sum(dim=1)  # (B,)
    max_radius = projections.max(dim=1).values  # (B,)

    radius_in_direction = torch.maximum(knn_radius, max_radius).clamp(min=1e-8)
    return d_to_center - radius_in_direction


@torch.no_grad()
def distance_to_region_simple(
    y: torch.Tensor,
    boundary_points: torch.Tensor,
) -> torch.Tensor:
    """
    Simple (original) variant of ``distance_to_region``.

    Kept for compatibility and comparison. The radius in the direction of y
    is the largest projection of the centred boundary samples onto that
    direction (floored at 1e-8); the result is ``||y - centroid|| - radius``.

    Args:
        y: Target points, shape (B, y_dim).
        boundary_points: Boundary samples, shape (B, M, y_dim).

    Returns:
        Signed distances, shape (B,): negative inside, positive outside.
    """
    # Unpacking also enforces that the boundary tensor is 3-dimensional.
    _n_regions, _n_pts, _dim = boundary_points.shape

    center = boundary_points.mean(dim=1)
    offset = y - center
    offset_len = torch.linalg.norm(offset, dim=-1, keepdim=True)
    heading = offset / (offset_len + 1e-12)

    rel = boundary_points - center.unsqueeze(1)
    reach = (rel * heading.unsqueeze(1)).sum(dim=-1).max(dim=1).values
    return offset_len.squeeze(-1) - reach.clamp(min=1e-8)


# ============================================================================
# Optimized Distance Computation (Gradient-Based)
# ============================================================================

def _project_to_ball(u: Tensor, r0: float, eps: float = 1e-12) -> Tensor:
    """Radially shrink rows of u lying outside the ball of radius r0; rows inside are unchanged."""
    row_norm = torch.clamp(u.norm(dim=1, keepdim=True), min=eps)
    # Scale factor r0/||u|| capped at 1 so interior points are left alone.
    shrink = torch.clamp(r0 / row_norm, max=1.0)
    return shrink * u


def _normalize(v: Tensor, eps: float = 1e-12) -> Tensor:
    """Scale each row of v to unit Euclidean norm (norms below eps are treated as eps)."""
    safe_norm = torch.clamp(v.norm(dim=1, keepdim=True), min=eps)
    return v / safe_norm


@torch.no_grad()
def _sample_u_on_sphere(n: int, k: int, d: int, r0: float, device, dtype) -> Tensor:
    """
    Draw k points per row uniformly on the sphere of radius r0.

    Gaussian draws are normalised along the last axis and scaled by r0.

    Returns:
        Tensor of shape (n, k, d) with ||u|| = r0
    """
    gauss = torch.randn((n, k, d), device=device, dtype=dtype)
    lengths = torch.clamp(gauss.norm(dim=2, keepdim=True), min=1e-12)
    unit = gauss / lengths
    return r0 * unit


@torch.no_grad()
def _sample_u_in_ball(n: int, k: int, d: int, r0: float, device, dtype) -> Tensor:
    """
    Draw k points per row inside the ball of radius r0.

    Each point is a uniform random direction scaled by r0 * U**(1/d) with
    U ~ Uniform[0, 1), the radial law of a uniform distribution in a d-ball.

    Returns:
        Tensor of shape (n, k, d) with ||u|| <= r0
    """
    # Direction first, then radius: keeps the RNG consumption order fixed.
    gauss = torch.randn((n, k, d), device=device, dtype=dtype)
    unit = gauss / torch.clamp(gauss.norm(dim=2, keepdim=True), min=1e-12)
    radial = torch.rand((n, k, 1), device=device, dtype=dtype).pow(1.0 / d) * r0
    return unit * radial


def distance_to_region_uopt(
    Q: Callable[[Tensor, Tensor], Tensor],
    x: Tensor,                 # (N, p)
    y: Tensor,                 # (N, d)
    r0: float,                 # 1 - alpha
    n_sphere: int = 256,       # candidates on ||u||=r0
    n_ball: int = 256,         # candidates in ||u||<=r0
    refine_steps: int = 30,    # GD refinement
    lr: float = 5e-2,
    eps_u: float = 1e-3,       # "strictly inside" threshold on ||u|| (in u-space)
    chunk_size: int = 4096,    # for forward eval of candidates (flattened)
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Signed distance proxy to the boundary of C_base(x) = Q(x, r0 * B^d),
    obtained by candidate sampling followed by gradient refinement in u-space.

    Method:
        1. Sample candidates on the sphere (||u|| = r0) and in the ball
           (||u|| <= r0).
        2. For each row keep the candidate minimising ||y - Q(x, u)||.
        3. Refine the boundary candidate with Adam over v, using
           u = r0 * v / ||v|| so the iterate stays on the sphere.
        4. Refine the ball candidate with Adam, projecting back onto the ball
           after every step.
        5. Sign: a row is marked inside when the refined ball solution has
           ||u|| < r0 - eps_u.

    Side effects:
        If Q has ``eval``/``train`` methods it is switched to eval mode for
        the duration of the call and its previous training mode is restored
        afterwards, even if an exception is raised. Gradients are taken only
        with respect to the optimised u-variables, so the ``.grad`` fields of
        Q's parameters are left untouched. Gradient tracking is enabled
        internally, so the function also works when called under
        ``torch.no_grad()``.

    Args:
        Q: Model function Q(x, u) -> y.
        x: Conditioning variables, shape (N, p).
        y: Target points, shape (N, d). The u-dimension is assumed to equal d.
        r0: Radius of the ball (callers in this file pass 1 - alpha).
        n_sphere: Number of candidates on the sphere ||u|| = r0.
        n_ball: Number of candidates in the ball ||u|| <= r0.
        refine_steps: Number of Adam refinement steps for each problem.
        lr: Learning rate for Adam.
        eps_u: Threshold for "strictly inside" in u-space.
        chunk_size: Number of flattened candidates per forward pass.

    Returns:
        signed_dist: (N,) ``d_boundary`` with a negative sign on rows marked
            inside, positive otherwise.
        d_boundary: (N,) approximate inf_{||u||=r0} ||y - Q(x, u)||.
        u_in_norm: (N,) norm of the refined ball solution.
    """
    device, dtype = x.device, x.dtype
    N = x.shape[0]
    d_u = y.shape[1]  # assuming u-dim == y-dim; adapt if different

    def best_u_from_candidates(U: Tensor) -> Tensor:
        """Pick, per row, the candidate in U (N, K, d) closest to y after mapping through Q."""
        K = U.shape[1]
        U_flat = U.reshape(N * K, d_u)
        x_rep = x[:, None, :].expand(N, K, x.shape[1]).reshape(N * K, x.shape[1])
        y_rep = y[:, None, :].expand(N, K, d_u).reshape(N * K, d_u)

        # Forward in chunks to avoid memory spikes.
        d2 = torch.empty(N * K, device=device, dtype=dtype)
        for s in range(0, N * K, chunk_size):
            e = min(s + chunk_size, N * K)
            y_hat = Q(x_rep[s:e], U_flat[s:e])
            d2[s:e] = ((y_hat - y_rep[s:e]) ** 2).sum(dim=1)

        idx = d2.view(N, K).argmin(dim=1)  # (N,)
        return U[torch.arange(N, device=device), idx]  # (N, d)

    def fit_loss(u: Tensor) -> Tensor:
        """Mean over rows of the squared error ||Q(x, u) - y||^2."""
        return ((Q(x, u) - y) ** 2).sum(dim=1).mean()

    def load_grad(loss: Tensor, param: Tensor) -> None:
        """Store d(loss)/d(param) on param only, leaving Q's parameter grads untouched."""
        (grad,) = torch.autograd.grad(loss, [param], allow_unused=True)
        param.grad = grad

    Q_was_training = getattr(Q, "training", False)
    if hasattr(Q, "eval"):
        Q.eval()

    try:
        # ---------- 1) init by sampling (no grad) ----------
        with torch.no_grad():
            U_s = _sample_u_on_sphere(N, n_sphere, d_u, r0, device, dtype)  # (N,K,d)
            U_b = _sample_u_in_ball(N, n_ball, d_u, r0, device, dtype)      # (N,K,d)
            u_bd0 = best_u_from_candidates(U_s)  # init on sphere
            u_in0 = best_u_from_candidates(U_b)  # init in ball

        with torch.enable_grad():
            # ---------- 2) refine boundary distance: optimize v on sphere ----------
            v = (u_bd0 / max(r0, 1e-12)).clone().detach().requires_grad_(True)  # (N,d)
            opt_bd = torch.optim.Adam([v], lr=lr)
            for _ in range(refine_steps):
                load_grad(fit_loss(r0 * _normalize(v)), v)
                opt_bd.step()

            # ---------- 3) refine inside-proxy: optimize u in ball ----------
            u_in = u_in0.clone().detach().requires_grad_(True)
            opt_in = torch.optim.Adam([u_in], lr=lr)
            for _ in range(refine_steps):
                load_grad(fit_loss(_project_to_ball(u_in, r0)), u_in)
                opt_in.step()
                # Keep the iterate feasible after the unconstrained Adam step.
                with torch.no_grad():
                    u_in.copy_(_project_to_ball(u_in, r0))

        with torch.no_grad():
            u_bd = r0 * _normalize(v)
            d_boundary = (Q(x, u_bd) - y).norm(dim=1)

            u_in_norm = _project_to_ball(u_in, r0).norm(dim=1)
            # Inside proxy: strict interior in u-space.
            inside = u_in_norm < (r0 - eps_u)
            signed_dist = d_boundary.clone()
            signed_dist[inside] = -signed_dist[inside]
    finally:
        # Restore the caller's training mode even when refinement fails.
        if Q_was_training and hasattr(Q, "train"):
            Q.train(True)

    return signed_dist, d_boundary, u_in_norm



def compute_nonconformity_scores(
    model_G: nn.Module,
    X: torch.Tensor,
    Y: torch.Tensor,
    alpha: float,
    n_boundary_samples: int = 256,
    batch_size: int = 128,
    refine_steps: int = 15,
    lr: float = 5e-2,
    optimize: bool = True,
) -> torch.Tensor:
    """
    Compute signed-distance nonconformity scores for the pairs (X[i], Y[i]).

    The base region uses radius ``1 - alpha``. Data is processed in batches
    on the device of the model's first parameter; scores are returned on CPU.
    The model is switched to eval mode and left in eval mode.

    Args:
        model_G: Model called as ``model_G(x, u)``; must have at least one
            parameter (used to determine the device).
        X: Features, shape (N, x_dim).
        Y: Targets, shape (N, y_dim).
        alpha: Miscoverage level; the base radius is ``1 - alpha``.
        n_boundary_samples: Number of sphere samples (and, on the optimised
            path, also of ball samples) per point.
        batch_size: Number of points processed per batch.
        refine_steps: Refinement steps for the optimised path.
        lr: Learning rate for the optimised path.
        optimize: If truthy, use ``distance_to_region_uopt``; otherwise use
            the sampling-based ``approximate_region_boundary`` +
            ``distance_to_region``.

    Returns:
        Tensor of shape (N,) with one signed score per point (negative
        inside the base region, positive outside). An empty tensor is
        returned when N == 0.

    Raises:
        ValueError: If X and Y do not have the same number of rows.
    """
    device = next(model_G.parameters()).device
    model_G.eval()

    N = X.shape[0]
    if Y.shape[0] != N:
        raise ValueError(
            f"X and Y must have the same number of rows, got {N} and {Y.shape[0]}"
        )
    if N == 0:
        # torch.cat cannot concatenate an empty list; return an empty score vector.
        empty_dtype = Y.dtype if Y.is_floating_point() else torch.get_default_dtype()
        return torch.empty(0, dtype=empty_dtype)

    r0 = 1.0 - alpha
    scores = []

    for start in range(0, N, batch_size):
        x_batch = X[start:start + batch_size].to(device)
        y_batch = Y[start:start + batch_size].to(device)

        if optimize:
            # Gradients are required because the distance is optimised over u,
            # even when the caller runs under torch.no_grad().
            with torch.enable_grad():
                signed_dist, _, _ = distance_to_region_uopt(
                    Q=model_G,
                    x=x_batch,
                    y=y_batch,
                    r0=r0,
                    n_sphere=n_boundary_samples,
                    n_ball=n_boundary_samples,
                    refine_steps=refine_steps,
                    lr=lr,
                )
        else:
            boundary = approximate_region_boundary(
                model_G, x_batch, r0, n_boundary_samples
            )
            signed_dist = distance_to_region(y_batch, boundary)

        scores.append(signed_dist.detach().cpu())

    return torch.cat(scores, dim=0)



# ============================================================================
# Split Conformal Prediction
# ============================================================================

@dataclass
class ConformalPICNNResult:
    """Outcome of a conformal calibration run."""
    # Miscoverage level the calibration was run for.
    alpha: float
    # Radius of the initial region, (1 - α).
    radius: float
    # Calibrated margin q.
    margin: float
    # Calibration scores the margin was selected from.
    cal_scores: np.ndarray

    def __repr__(self) -> str:
        # Compact summary; the raw calibration scores are deliberately omitted.
        return "ConformalPICNNResult(α={:.3f}, radius={:.3f}, margin={:.4f})".format(
            self.alpha, self.radius, self.margin
        )


def calibrate_conformal_margin(
    model_G: nn.Module,
    X_cal: torch.Tensor,
    Y_cal: torch.Tensor,
    alpha: float = 0.1,
    n_boundary_samples: int = 256,
    batch_size: int = 128,
) -> ConformalPICNNResult:
    """
    Calibrate the conformal margin q using split conformal prediction.

    The final prediction region is G(x, (1-α)B^d) ⊕ B(0, q), where q is the
    k-th smallest calibration score with k = ceil((n + 1)(1 - α)).

    NOTE: when k exceeds n (small calibration sets), k is clamped to n and
    the largest calibration score is used instead of an infinite margin.

    Args:
        model_G: Trained model called as ``model_G(x, u)``.
        X_cal: Calibration features, shape (n_cal, x_dim).
        Y_cal: Calibration targets, shape (n_cal, y_dim).
        alpha: Target miscoverage level.
        n_boundary_samples: Samples for boundary approximation.
        batch_size: Batch size.

    Returns:
        ConformalPICNNResult holding alpha, the base radius 1 - alpha, the
        calibrated margin (as a Python float) and the raw calibration scores.

    Raises:
        ValueError: If the calibration set is empty.
    """
    if X_cal.shape[0] == 0:
        raise ValueError("Calibration set is empty; cannot calibrate a conformal margin.")

    # Nonconformity scores on the calibration set.
    scores = compute_nonconformity_scores(
        model_G, X_cal, Y_cal, alpha, n_boundary_samples, batch_size
    )
    scores_np = scores.numpy()
    n_cal = len(scores_np)

    # Conformal rank k = ceil((n + 1)(1 - α)), clamped to the valid range [1, n].
    rank = int(np.ceil((n_cal + 1) * (1 - alpha)))
    rank = max(1, min(rank, n_cal))

    # Convert the numpy scalar so the stored margin matches the declared float type.
    margin = float(np.sort(scores_np)[rank - 1])

    return ConformalPICNNResult(
        alpha=alpha,
        radius=1.0 - alpha,
        margin=margin,
        cal_scores=scores_np,
    )


# ============================================================================
# Prediction Region
# ============================================================================

@torch.no_grad()
def predict_region(
    model_G: nn.Module,
    X: torch.Tensor,
    result: ConformalPICNNResult,
    n_boundary_samples: int = 256,
    batch_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Produce the ingredients of the conformal prediction region for each row
    of X.

    The region is the Minkowski sum of the base region G(x, radius * B^d),
    represented here by sampled boundary points, and a ball of radius
    ``margin``. The model is switched to eval mode; batches are evaluated on
    the device of the model's first parameter and collected on CPU.

    Args:
        model_G: Trained model called as ``model_G(x, u)``.
        X: Test features, shape (N, x_dim).
        result: Calibration result supplying ``radius`` and ``margin``.
        n_boundary_samples: Number of boundary samples per row.
        batch_size: Number of rows processed per batch.

    Returns:
        boundary_points: Shape (N, n_boundary_samples, y_dim).
        margin: ``result.margin``, returned unchanged.
    """
    device = next(model_G.parameters()).device
    model_G.eval()

    per_batch = [
        approximate_region_boundary(
            model_G,
            X[start:start + batch_size].to(device),
            result.radius,
            n_boundary_samples,
        ).cpu()
        for start in range(0, X.shape[0], batch_size)
    ]

    return torch.cat(per_batch, dim=0), result.margin


@torch.no_grad()
def check_coverage(
    model_G: nn.Module,
    X_test: torch.Tensor,
    Y_test: torch.Tensor,
    result: ConformalPICNNResult,
    n_boundary_samples: int = 256,
    batch_size: int = 128,
) -> Tuple[float, torch.Tensor]:
    """
    Measure empirical coverage of the calibrated region on a test set.

    A point is counted as covered iff its signed distance to the base region
    is at most the calibrated margin:

        covered  <=>  signed_dist <= margin

    where signed_dist is negative inside C_base(x), zero on its boundary and
    positive outside. This single rule handles both signs of the margin:

    - margin >= 0 (expansion): the base region plus an outer shell of width
      ``margin`` is covered.
    - margin < 0 (contraction): only points inside the base region whose
      signed distance is at most ``margin`` (i.e. at depth >= |margin|) are
      covered.

    Args:
        model_G: Trained model called as ``model_G(x, u)``.
        X_test: Test features.
        Y_test: Test targets.
        result: Calibration result supplying ``alpha`` and ``margin``.
        n_boundary_samples: Samples used by the score computation.
        batch_size: Batch size.

    Returns:
        coverage: Fraction of test points that are covered.
        is_covered: Boolean tensor with one entry per test point.
    """
    signed = compute_nonconformity_scores(
        model_G=model_G,
        X=X_test,
        Y=Y_test,
        alpha=result.alpha,
        n_boundary_samples=n_boundary_samples,
        batch_size=batch_size,
    )

    is_covered = signed <= result.margin
    coverage = is_covered.float().mean().item()
    return coverage, is_covered


# ============================================================================
# Volume / Size Estimation
# ============================================================================

@torch.no_grad()
def estimate_region_volume(
    boundary_points: torch.Tensor,
    margin: float,
    n_mc_samples: int = 10000,
    mc_chunk_size: int = 2048,
) -> torch.Tensor:
    """
    Estimate the volume of each prediction region via Monte Carlo.

    For every region, points are drawn uniformly in an axis-aligned box
    around the boundary samples (padded by ``margin + 0.1`` on each side);
    the volume estimate is the box volume times the fraction of points whose
    signed distance to the base region is at most ``margin``.

    The signed distance comes from this module's ``distance_to_region``,
    called with (n, d) samples and an (n, M, d) view of the shared boundary.
    Samples are processed in chunks to bound the size of the intermediate
    (n, M, d) tensors.

    Args:
        boundary_points: Boundary samples, shape (B, M, d).
        margin: Calibrated margin (may be negative for contraction).
        n_mc_samples: Number of Monte Carlo samples per region.
        mc_chunk_size: Number of Monte Carlo samples evaluated at once.

    Returns:
        Estimated volumes, shape (B,).
    """
    B, M, d = boundary_points.shape
    device = boundary_points.device

    if B == 0:
        # torch.stack cannot stack an empty list.
        return boundary_points.new_empty(0)

    volumes = []
    for pts in boundary_points:  # pts: (M, d)
        # Bounding box with margin
        mins = pts.min(dim=0).values - margin - 0.1
        maxs = pts.max(dim=0).values + margin + 0.1
        extent = maxs - mins

        # Sample uniformly in the bounding box
        samples = torch.rand(n_mc_samples, d, device=device) * extent + mins

        # Signed distance of every sample to the base region; the boundary is
        # shared by all samples, so a broadcast view is enough.
        shared_boundary = pts.unsqueeze(0)  # (1, M, d)
        signed_dists = torch.cat([
            distance_to_region(chunk, shared_boundary.expand(chunk.shape[0], M, d))
            for chunk in samples.split(mc_chunk_size)
        ])

        # A sample is in the final region iff signed_dist <= margin
        # (covers both expansion and contraction).
        inside = (signed_dists <= margin).float()

        volumes.append(extent.prod() * inside.mean())

    return torch.stack(volumes)


# ============================================================================
# Full Pipeline
# ============================================================================

class ConformalPICNN:
    """
    Convenience wrapper bundling calibration, prediction and evaluation for a
    PICNN geometric quantile map.

    Usage:
        # After training model_G
        cp = ConformalPICNN(model_G, alpha=0.1)
        cp.calibrate(X_cal, Y_cal)

        # Predict
        coverage, is_covered = cp.evaluate(X_test, Y_test)
        boundary, margin = cp.predict(X_new)
    """

    def __init__(
        self,
        model_G: nn.Module,
        alpha: float = 0.1,
        n_boundary_samples: int = 2000,
        batch_size: int = 128,
    ):
        """
        Store the model and settings; no calibration is performed here.

        Args:
            model_G: Trained PICNN model G(x, u) -> y
            alpha: Target miscoverage level (default 0.1 for 90% coverage)
            n_boundary_samples: Number of samples to approximate boundary
            batch_size: Batch size for processing
        """
        self.model_G = model_G
        self.alpha = alpha
        self.n_boundary_samples = n_boundary_samples
        self.batch_size = batch_size
        # Filled in by calibrate(); None means "not calibrated yet".
        self.result: Optional[ConformalPICNNResult] = None

    def _require_result(self, caller: str) -> ConformalPICNNResult:
        """Return the calibration result, or raise if calibrate() has not run yet."""
        if self.result is None:
            raise ValueError(f"Must call calibrate() before {caller}()")
        return self.result

    def calibrate(
        self,
        X_cal: torch.Tensor,
        Y_cal: torch.Tensor,
    ) -> ConformalPICNNResult:
        """
        Calibrate the conformal margin and remember the result on the instance.

        A one-line summary of the result is printed.

        Args:
            X_cal: Calibration features, shape (n_cal, x_dim)
            Y_cal: Calibration targets, shape (n_cal, y_dim)

        Returns:
            ConformalPICNNResult with calibrated margin
        """
        outcome = calibrate_conformal_margin(
            model_G=self.model_G,
            X_cal=X_cal,
            Y_cal=Y_cal,
            alpha=self.alpha,
            n_boundary_samples=self.n_boundary_samples,
            batch_size=self.batch_size,
        )
        self.result = outcome
        print(f"[ConformalPICNN] Calibrated: {self.result}")
        return outcome

    def predict(
        self,
        X: torch.Tensor,
    ) -> Tuple[torch.Tensor, float]:
        """
        Compute prediction regions for new inputs.

        Args:
            X: Test features, shape (N, x_dim)

        Returns:
            boundary_points: Shape (N, n_boundary_samples, y_dim)
            margin: Scalar margin (region = boundary ⊕ B(0, margin))

        Raises:
            ValueError: If calibrate() has not been called.
        """
        calibrated = self._require_result("predict")
        return predict_region(
            model_G=self.model_G,
            X=X,
            result=calibrated,
            n_boundary_samples=self.n_boundary_samples,
            batch_size=self.batch_size,
        )

    def evaluate(
        self,
        X_test: torch.Tensor,
        Y_test: torch.Tensor,
    ) -> Tuple[float, torch.Tensor]:
        """
        Measure empirical coverage on a test set.

        Args:
            X_test: Test features
            Y_test: Test targets

        Returns:
            coverage: Empirical coverage rate
            is_covered: Boolean tensor for each point

        Raises:
            ValueError: If calibrate() has not been called.
        """
        calibrated = self._require_result("evaluate")
        return check_coverage(
            model_G=self.model_G,
            X_test=X_test,
            Y_test=Y_test,
            result=calibrated,
            n_boundary_samples=self.n_boundary_samples,
            batch_size=self.batch_size,
        )

    def score(
        self,
        X: torch.Tensor,
        Y: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute nonconformity scores; does not require prior calibration.

        Args:
            X: Features
            Y: Targets

        Returns:
            scores: Distance to region for each point
        """
        return compute_nonconformity_scores(
            model_G=self.model_G,
            X=X,
            Y=Y,
            alpha=self.alpha,
            n_boundary_samples=self.n_boundary_samples,
            batch_size=self.batch_size,
        )


# ============================================================================
# Example Usage
# ============================================================================

# Demo script: trains a small model on synthetic data, calibrates the
# conformal margin and prints the empirical test coverage. Statement order
# matters here because all random draws follow the single seed set below.
if __name__ == "__main__":
    # Project-local model and trainer; only needed for this demo.
    from geo_conditional_quantile_icnn_gpu_amp_stable_widecore import (
        ICNNGeometricQuantileMap,
        GeoQuantileTrainer,
    )

    # Synthetic data: Y is the first y_dim coordinates of X plus Gaussian noise.
    torch.manual_seed(42)
    N = 1000
    x_dim, y_dim = 5, 2

    X = torch.randn(N, x_dim)
    Y = X[:, :y_dim] + 0.3 * torch.randn(N, y_dim)

    # Split: train / cal / test (600 / 200 / remaining 200)
    n_train = 600
    n_cal = 200
    X_train, Y_train = X[:n_train], Y[:n_train]
    X_cal, Y_cal = X[n_train:n_train+n_cal], Y[n_train:n_train+n_cal]
    X_test, Y_test = X[n_train+n_cal:], Y[n_train+n_cal:]

    # Train PICNN (on GPU when available)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ICNNGeometricQuantileMap(
        x_dim=x_dim, u_dim=y_dim, y_dim=y_dim,
        width=64, depth=3
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    trainer = GeoQuantileTrainer(model, optimizer, device=device)

    train_loader = DataLoader(
        TensorDataset(X_train, Y_train),
        batch_size=64, shuffle=True
    )
    trainer.fit(train_loader, n_epochs=50, verbose=True)

    # Conformal prediction: calibrate the margin on the held-out calibration split
    cp = ConformalPICNN(model, alpha=0.1)
    cp.calibrate(X_cal, Y_cal)

    # Evaluate empirical coverage on the test split against the 1 - alpha target
    coverage, is_covered = cp.evaluate(X_test, Y_test)
    print(f"\nTest coverage: {coverage:.1%} (target: {1 - cp.alpha:.1%})")
