from typing import Dict, Tuple

import numpy as np
import torch
import pandas as pd

from .base import EvalBase
from ..utils.stats import batched_cdist


class AlphaPrecision(EvalBase):
    """
    Evaluate the "alpha precision", "beta coverage" and "authenticity" of a synthetic
    dataset relative to a real dataset.

    Alpha precision asks how much of the synthetic data falls inside the region
    occupied by the real data; beta coverage asks how well the real data is covered
    by nearby synthetic points.

    `_evaluation` turns the two data frames into numeric embeddings and delegates
    the distance-based computation to `metrics_torch()`.
    """

    @torch.no_grad()
    def metrics_torch(
        self,
        X: torch.Tensor,
        X_syn: torch.Tensor,
        emb_center: torch.Tensor = None,
        device: str = "cuda",
        magic_number: float = 2.5e8
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float, float, float]:
        """
        Compute alpha-precision and beta-coverage curves plus an authenticity score
        by comparing real and synthetic data in an embedding space.

        All intermediate values are kept in float64 on the device that holds `X`.

        Args:
            X (torch.Tensor):
                Real data embedding of shape (N, D).
            X_syn (torch.Tensor):
                Synthetic data embedding of shape (N, D). It must have the same
                number of rows as `X`.
            emb_center (torch.Tensor, optional):
                Reference point of shape (D,) used for the alpha-precision radii.
                If None, the mean of `X` is used.
            device (str):
                Forwarded unchanged to `batched_cdist`. Default is "cuda".
            magic_number (float):
                Forwarded unchanged to `batched_cdist`
                (presumably a batch-size budget -- confirm against `batched_cdist`).

        Returns:
            Tuple containing:
              1) alphas (np.ndarray): 30 evenly spaced thresholds in [0, 1].
              2) alpha_precision_curve (np.ndarray): For each alpha, the fraction of
                 synthetic points whose distance to `emb_center` does not exceed the
                 alpha-quantile of the real points' distances to `emb_center`.
              3) beta_coverage_curve (np.ndarray): For each alpha, the fraction of real
                 points whose nearest synthetic neighbor is at least as close as their
                 nearest other real point AND whose nearest synthetic neighbor lies
                 within the alpha-quantile radius around the synthetic mean (radii are
                 quantiles of the matched synthetic points' distances to that mean).
              4) Delta_precision_alpha (float): 1 - sum|alpha - precision| / sum(alpha).
              5) Delta_coverage_beta (float): 1 - sum|alpha - coverage| / sum(alpha).
              6) authenticity (float): Fraction of real points i for which
                 `real_to_real[j] < real_to_synth[i]`, where j is the row index of the
                 synthetic point closest to i.

        Raises:
            RuntimeError: If `X` and `X_syn` differ in length, or if fewer than two
                real samples are given (a nearest *other* real point is required).
        """
        # Work in float64 on the device that holds X so that every tensor combined
        # below (quantile levels, radii, masks) lives in one place.
        work_device = X.device
        X_torch = X.double()
        X_syn_torch = X_syn.to(device=work_device, dtype=torch.double)

        n_real = X_torch.shape[0]
        if n_real != X_syn_torch.shape[0]:
            raise RuntimeError("The real and synthetic data must have the same length.")
        if n_real < 2:
            raise RuntimeError(
                "At least two real samples are required to compute "
                "nearest-neighbor distances."
            )

        if emb_center is None:
            emb_center_torch = X_torch.mean(dim=0)
        else:
            emb_center_torch = emb_center.to(device=work_device, dtype=torch.double)

        # Build the alpha grid directly in float64: generating it in float32 and
        # upcasting leaves ~1e-8 errors that can move a quantile off a sample.
        # torch.quantile also needs q on the same device/dtype as its input.
        n_steps = 30
        alphas_torch = torch.linspace(
            0, 1, n_steps, dtype=torch.double, device=work_device
        )

        # Radii: alpha-quantiles of the real points' distances to the center
        dist_X_center = torch.sqrt(((X_torch - emb_center_torch) ** 2).sum(dim=1))  # (N,)
        Radii_torch = torch.quantile(dist_X_center, alphas_torch)  # (n_steps,)

        # Synthetic points are measured against the same (real-derived) center
        synth_center_torch = X_syn_torch.mean(dim=0)
        synth_to_center = torch.sqrt(((X_syn_torch - emb_center_torch) ** 2).sum(dim=1))  # (N,)

        # Nearest *other* real point: the two smallest entries per row are taken and
        # the second one kept, skipping the zero self-distance.
        real_dist = batched_cdist(X_torch, X_torch, p=2.0, magic_number=magic_number, device=device)  # (N, N)
        real_values, _ = real_dist.topk(k=2, dim=1, largest=False)  # (N, 2)
        # .to() is a no-op when batched_cdist already returns on work_device
        real_to_real = real_values[:, 1].to(work_device)  # (N,)
        del real_dist, real_values  # release the (N, N) matrix before building the next

        # Nearest synthetic point for every real point
        real_synth_dist = batched_cdist(X_torch, X_syn_torch, p=2.0, magic_number=magic_number, device=device)
        real_to_synth, real_to_synth_args = real_synth_dist.min(dim=1)
        real_to_synth = real_to_synth.to(work_device)
        real_to_synth_args = real_to_synth_args.to(work_device)
        del real_synth_dist

        # Distance of each matched synthetic point to the synthetic mean, and the
        # alpha-quantile radii of that distribution
        real_synth_closest = X_syn_torch[real_to_synth_args]
        real_synth_closest_d = torch.sqrt(((real_synth_closest - synth_center_torch) ** 2).sum(dim=1))
        closest_synth_Radii_torch = torch.quantile(real_synth_closest_d, alphas_torch)

        # Evaluate every alpha step at once via broadcasting: rows index alpha,
        # columns index samples -> (n_steps, N) masks averaged over samples.
        alpha_precision_curve_t = (
            synth_to_center.unsqueeze(0) <= Radii_torch.unsqueeze(1)
        ).double().mean(dim=1)

        # This condition does not depend on alpha, so it is computed once.
        synth_not_farther = real_to_synth <= real_to_real  # (N,)
        beta_coverage_curve_t = (
            synth_not_farther.unsqueeze(0)
            & (real_synth_closest_d.unsqueeze(0) <= closest_synth_Radii_torch.unsqueeze(1))
        ).double().mean(dim=1)

        # Authenticity: real_to_real is indexed with the *synthetic* neighbor's row
        # index, which is only valid because both sets have the same length.
        # NOTE(review): this pairs synthetic row j with real row j -- confirm intended.
        real_to_real_for_synth = real_to_real[real_to_synth_args]
        authenticity = (real_to_real_for_synth < real_to_synth).double().mean().item()

        # Summaries stay in float64 (no float32 round trip through Python lists)
        alphas_total = alphas_torch.sum()
        Delta_precision_alpha = (
            1.0 - torch.abs(alphas_torch - alpha_precision_curve_t).sum() / alphas_total
        )
        Delta_coverage_beta = (
            1.0 - torch.abs(alphas_torch - beta_coverage_curve_t).sum() / alphas_total
        )

        return (
            alphas_torch.cpu().numpy(),
            alpha_precision_curve_t.cpu().numpy(),
            beta_coverage_curve_t.cpu().numpy(),
            Delta_precision_alpha.item(),
            Delta_coverage_beta.item(),
            authenticity,
        )

    def _evaluation(self, real_data: pd.DataFrame, fake_data: pd.DataFrame) -> Dict[str, float]:
        """
        Evaluate alpha precision and beta coverage metrics on real vs. synthetic data.

        Steps:
          1. Trim `fake_data` to at most the number of rows in `real_data`.
          2. Transform both frames with `self.transform.transform`
             (requested with scaler='minmax', onehot=True, return_as_tensor=True).
          3. Call `metrics_torch` (with the real mean as center) and collect results.

        Args:
            real_data (pd.DataFrame):
                The real dataset.
            fake_data (pd.DataFrame):
                The synthetic dataset. It must have at least as many rows as
                `real_data`; otherwise `metrics_torch` raises RuntimeError.

        Returns:
            Dict[str, float]: An empty dict if `real_data` is empty or if either
            embedding contains NaN. Otherwise a dictionary containing:
                - alpha_precision (float): Delta_precision_alpha.
                - alpha_precision_error (float): 1 - alpha_precision.
                - beta_recall (float): Delta_coverage_beta.
                - beta_recall_error (float): 1 - beta_recall.
                - alpha_precision_curve (np.ndarray): The alpha precision curve.
                - beta_coverage_curve (np.ndarray): The beta coverage curve.
                - authenticity (float): The authenticity score.

        Raises:
            RuntimeError: Propagated from `metrics_torch` when the lengths differ
                after trimming or when fewer than two real rows are supplied.
        """
        if len(real_data) == 0:
            return {}

        # Match length (only shortens fake_data; it cannot lengthen it)
        fake_data = fake_data.iloc[:len(real_data)]

        # Both frames go through the same transform so they share one embedding space
        real_embedding = self.transform.transform(
            real_data, scaler='minmax', onehot=True, return_as_tensor=True
        )
        fake_embedding = self.transform.transform(
            fake_data, scaler='minmax', onehot=True, return_as_tensor=True
        )

        # Distances are meaningless with NaNs present; report nothing instead
        if real_embedding.isnan().any() or fake_embedding.isnan().any():
            return {}

        (
            _alphas,
            alpha_precision_curve_naive,
            beta_coverage_curve_naive,
            Delta_precision_alpha_naive,
            Delta_coverage_beta_naive,
            authenticity_naive,
        ) = self.metrics_torch(real_embedding, fake_embedding, emb_center=None)

        alpha_precision = Delta_precision_alpha_naive
        beta_recall = Delta_coverage_beta_naive

        return {
            'alpha_precision': alpha_precision,
            'alpha_precision_error': 1 - alpha_precision,
            'beta_recall': beta_recall,
            'beta_recall_error': 1 - beta_recall,
            'alpha_precision_curve': alpha_precision_curve_naive,
            'beta_coverage_curve': beta_coverage_curve_naive,
            'authenticity': authenticity_naive
        }


class EnergyDistance(EvalBase):
    """
    Quantify how far apart the real and synthetic embeddings are using the
    sample energy distance.
    """

    @torch.no_grad()
    def energy_distance_torch(
        self,
        X: torch.Tensor,
        X_syn: torch.Tensor,
        device: str = "cuda",
        magic_number: float = 2.5e8
    ) -> Tuple[float, float, float, float]:
        """
        Compute the sample energy distance `2 * cross - real_self - synth_self`.

        Pairwise distances are requested from `batched_cdist` with p=1.0
        (presumably the L1 norm -- confirm against `batched_cdist`).

        Args:
            X (torch.Tensor): Real data embedding of shape (N, D).
            X_syn (torch.Tensor): Synthetic data embedding of shape (M, D).
            device (str): Forwarded unchanged to `batched_cdist`.
            magic_number (float): Forwarded unchanged to `batched_cdist`.

        Returns:
            Tuple[float, float, float, float]:
                (energy distance, mean real-to-synthetic distance,
                real self-distance, synthetic self-distance).
                All four are 0.0 when either input has no rows. A self-distance
                is 0.0 when its set holds a single row.
        """
        real = X.double()
        synth = X_syn.double()
        num_real, num_synth = real.shape[0], synth.shape[0]

        # Nothing to compare when either side is empty
        if not (num_real and num_synth):
            return 0.0, 0.0, 0.0, 0.0

        def _self_term(points: torch.Tensor, count: int) -> torch.Tensor:
            # Sum of all pairwise distances divided by the number of ordered
            # pairs of distinct rows; a lone row has no such pair.
            if count <= 1:
                return torch.tensor(0.0, dtype=torch.double, device=points.device)
            pairwise = batched_cdist(
                points, points, p=1.0, magic_number=magic_number, device=device
            )
            return pairwise.sum() / (count * (count - 1))

        between = batched_cdist(
            real, synth, p=1.0, magic_number=magic_number, device=device
        ).mean()
        within_real = _self_term(real, num_real)
        within_synth = _self_term(synth, num_synth)

        energy = 2.0 * between - within_real - within_synth

        return tuple(
            term.item() for term in (energy, between, within_real, within_synth)
        )

    def _evaluation(self, real_data: pd.DataFrame, fake_data: pd.DataFrame) -> Dict[str, float]:
        """
        Embed both data frames and report the energy distance and its components.

        `fake_data` is cut down to at most `len(real_data)` rows before embedding.

        Returns:
            Dict[str, float]: Empty when either frame has no rows or when an
            embedding contains NaN; otherwise the keys `energy_distance`,
            `cross_mean_distance`, `real_self_distance` and
            `synthetic_self_distance`.
        """
        if min(len(real_data), len(fake_data)) == 0:
            return {}

        trimmed_fake = fake_data.iloc[:len(real_data)]

        # Identical transform options for both frames
        transform_options = dict(scaler='minmax', onehot=True, return_as_tensor=True)
        real_embedding = self.transform.transform(real_data, **transform_options)
        fake_embedding = self.transform.transform(trimmed_fake, **transform_options)

        # Skip the metric entirely if any embedding value is NaN
        if any(emb.isnan().any() for emb in (real_embedding, fake_embedding)):
            return {}

        components = self.energy_distance_torch(real_embedding, fake_embedding)
        labels = (
            'energy_distance',
            'cross_mean_distance',
            'real_self_distance',
            'synthetic_self_distance',
        )
        return dict(zip(labels, components))

