from typing import Dict

import torch
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve
from tqdm import tqdm

from .base import EvalBase
from ..utils.stats import batched_cdist


class DCR(EvalBase):
    """
    Distance-to-Closest-Record (DCR) metric.

    Measures how close synthetic samples are to real samples compared to a
    validation set. Inherits from EvalBase, which provides a `_evaluation` hook
    for custom evaluations.
    """

    def compute(self, real: torch.Tensor, val: torch.Tensor, fake: torch.Tensor) -> Dict[str, float]:
        """
        Compute the DCR score by comparing each row in the synthetic data to the closest
        point in the real dataset and the closest point in the validation dataset.

        Steps:
          1. For each synthetic sample, compute the L1 distance to all real samples.
          2. Take the minimum distance to the real set.
          3. Do the same for the validation set.
          4. Compute the fraction of synthetic samples that are strictly closer to the
             real set than to the validation set (ties do not count as "closer").

        Args:
            real (torch.Tensor):
                A tensor of real samples (N, D) where N is the number of real samples, D is the dimension.
            val (torch.Tensor):
                A tensor of validation samples (M, D) for comparison.
            fake (torch.Tensor):
                A tensor of synthetic samples (K, D).

        Returns:
            Dict[str, float]: A dictionary containing the DCR score:
                {
                    "dcr_score": float
                }
            The score is NaN when `fake` contains no rows (mean over zero samples).
        """
        # NOTE(review): batched_cdist is assumed to return the (K, N) pairwise
        # distance matrix -- confirm against ..utils.stats.
        # NaN distances are replaced by +inf so they can never win the row-wise minimum.
        dcr_real = batched_cdist(fake, real, p=1).nan_to_num(nan=torch.inf).min(dim=-1).values

        # Same reduction against the validation set.
        dcr_val = batched_cdist(fake, val, p=1).nan_to_num(nan=torch.inf).min(dim=-1).values

        # Fraction of synthetic samples whose min distance to real < min distance to val
        score = (dcr_real < dcr_val).float().mean()

        return dict(dcr_score=score.item())

    @torch.no_grad()
    def _evaluation(self, real_data, val_data, fake_data) -> Dict[str, float]:
        """
        Evaluate the DCR metric on real, validation, and fake data.

        All three datasets are embedded with the same transform (min-max scaling,
        one-hot encoding) and passed to `compute`.

        Args:
            real_data (pd.DataFrame):
                Real dataset as a DataFrame.
            val_data (pd.DataFrame):
                Validation dataset as a DataFrame.
            fake_data (pd.DataFrame):
                Synthetic dataset as a DataFrame.

        Returns:
            Dict[str, float]: Dictionary containing the computed DCR score under the
                              key "dcr_score". If any of real_data, val_data or
                              fake_data is empty, returns an empty dict.
        """
        # An empty synthetic set would produce a NaN score (mean over zero rows),
        # so it is treated like the other empty-input cases.
        # `len(...) == 0` is used on purpose: DataFrame truthiness is ambiguous.
        if len(real_data) == 0 or len(val_data) == 0 or len(fake_data) == 0:
            return {}

        # Convert data to a numeric embedding (same settings for every dataset).
        # NOTE(review): `self.transform` is presumably set up by EvalBase -- confirm.
        real_embedding, val_embedding, fake_embedding = (
            self.transform.transform(data, scaler='minmax', onehot=True, return_as_tensor=True)
            for data in (real_data, val_data, fake_data)
        )

        return self.compute(real_embedding, val_embedding, fake_embedding)


class DPIMIA(EvalBase):
    """
    An implementation of the Data Plagiarism Index Membership Inference Attack (DPIMIA)
    from the paper:

        Ward, Joshua, Chi-Hua Wang, and Guang Cheng. "Data Plagiarism Index: Characterizing
        the Privacy Risk of Data-Copying in Tabular Generative Models."
        arXiv preprint arXiv:2406.13012 (2024).

    This evaluates how similar holdout data is to the reference (real) data vs.
    synthetic data, using an approach that looks for evidence of data copying.
    """

    @staticmethod
    def batched_cdist_score(
        x: torch.Tensor,
        y: torch.Tensor,
        group: torch.Tensor,
        p: float = 2.0,
        magic_number: float = 2.5e8,
        k: int = 30,
        device: str = 'cuda'
    ) -> torch.Tensor:
        """
        Compute a distance-based score for membership inference, comparing each row
        in `x` to rows in `y` using a batched distance computation. Instead of returning
        the full distance matrix, it only keeps the k nearest neighbors (lowest distances).
        Then, it computes a cumulative average of the group labels for those neighbors,
        effectively measuring how often `x` is closer to group=1 vs. group=0 samples.

        Args:
            x (torch.Tensor):
                The source tensor of shape (N, D).
            y (torch.Tensor):
                The target tensor of shape (M, D).
            group (torch.Tensor):
                A tensor of shape (M,) indicating which rows of `y` belong to group 0 or 1.
            p (float):
                The norm for distance calculation (2.0 = Euclidean).
            magic_number (float):
                Memory budget expressed as a number of distance-matrix entries; rows of
                `x` are processed in batches of `magic_number / M` (at least 1).
                A negative value disables scoring, see Returns.
            k (int):
                The number of nearest neighbors to consider for each row in `x`.
                Clamped to M when `y` has fewer than `k` rows.
            device (str):
                The device used for batch distance computation (e.g., "cuda" or "cpu").
                If a CUDA device is requested but CUDA is unavailable, the CPU is used.

        Returns:
            torch.Tensor: A tensor of shape (N, k + 1) on the original device of `x`.
            Column 0 holds the negated distance to the nearest row of `y` (so that a
            larger value means "closer"); column j (1 <= j <= k) holds the mean group
            label among the j nearest neighbors.

            NOTE: when `magic_number < 0` the raw (N, M) distance matrix is returned
            instead of scores. This pre-existing behavior is kept for compatibility.

        Raises:
            ValueError: If `y` has no rows (only checked when `magic_number >= 0`).
        """
        # The default device is 'cuda'; degrade gracefully on CPU-only hosts
        # instead of failing inside `.to(device)`.
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            device = 'cpu'

        _device = x.device
        y = y.to(device)
        group = group.to(device)

        if magic_number < 0:
            # Perform direct cdist in one shot if magic_number < 0
            return torch.cdist(x.to(device), y, p=p).to(_device)

        n_targets = len(y)
        if n_targets == 0:
            raise ValueError("`y` must contain at least one row to compute neighbor scores.")

        # topk raises when asked for more neighbors than there are rows in y.
        k = min(k, n_targets)

        # Rows of x per batch; never 0, otherwise torch.split would raise.
        batch_size = max(1, int(magic_number / n_targets))

        # Loop-invariant: y with a leading batch dimension for torch.cdist.
        y_batched = y[None]

        scores = []
        # Batch the rows of x to avoid large memory usage
        for batch in tqdm(torch.split(x, batch_size)):
            # (batch_size, M)
            dist_batch = torch.cdist(batch[None].to(device), y_batched, p=p)[0]
            # NaN distances are replaced by the largest finite distance in the batch
            # so they rank last among the neighbors.
            dist_batch = dist_batch.nan_to_num(dist_batch.nan_to_num().max().item())

            # Negated so that a higher score means a closer nearest neighbor.
            nn_dist = -dist_batch.min(dim=-1, keepdim=True).values

            # For each row in the batch, find indices of the k closest neighbors
            ind = dist_batch.topk(k=k, dim=-1, largest=False).indices
            group_vals = group[ind]  # (batch_size, k)
            cumsum_vals = group_vals.cumsum(dim=-1)  # (batch_size, k)

            # For each row, compute partial means of group membership in the top neighbors
            div = torch.arange(1, k + 1, device=device, dtype=cumsum_vals.dtype).view(1, -1)
            means = cumsum_vals / div  # (batch_size, k)

            scores.append(torch.cat([nn_dist, means], dim=1).to(_device))

        return torch.cat(scores, dim=0)

    @torch.no_grad()
    def _evaluation(
        self,
        train,
        holdout,
        reference,
        fake
    ) -> Dict[str, object]:
        """
        Evaluate membership inference risk by comparing holdout and training data against
        a combined reference+synthetic set. The approach is:
          1. Embed or transform all datasets using the same transformations.
          2. Concatenate the reference dataset (label=0) with the fake dataset (label=1).
          3. Concatenate the holdout dataset (label=0) with the training dataset (label=1).
          4. For each row in holdout+train, compute top-k neighbors from reference+fake.
             Track how often the neighbors belong to label=1 (i.e., fake), plus the
             negated distance to the single nearest neighbor.
          5. Compare this to ground truth (which indicates whether the sample was from holdout or train).
          6. Compute ROC AUC for every score column to see how well membership can be inferred.
          7. Return the best AUC, its ROC curve and the corresponding k.

        Args:
            train (pd.DataFrame): The training dataset used by the generative model.
            holdout (pd.DataFrame): A holdout dataset not seen by the model.
            reference (pd.DataFrame): The original real dataset (or a portion of it).
            fake (pd.DataFrame): The synthetic dataset produced by the model.

        Returns:
            Dict[str, object]: Dictionary containing:
                {
                    "roc_auc_score": float,   # Best ROC AUC over all score columns
                    "roc_curve": tuple,       # Output of sklearn's roc_curve for the best column
                    "best_k": int or str      # The k that yielded the best AUC, or
                                              # 'nn_distance' if the nearest-neighbor
                                              # distance column scored best
                }

        Note:
            Both `holdout` and `train` should be non-empty; with a single class in the
            ground truth the ROC AUC is not well-defined.
        """
        # Transform data into embeddings/tensors (same settings for every dataset).
        # NOTE(review): `self.transform` is presumably set up by EvalBase -- confirm.
        train, holdout, reference, fake = (
            self.transform.transform(data, scaler='minmax', onehot=True, return_as_tensor=True)
            for data in (train, holdout, reference, fake)
        )

        # group=0 for reference, group=1 for fake
        group = torch.cat([torch.zeros(len(reference)), torch.ones(len(fake))])

        # ground truth: 0 for holdout, 1 for train
        gt = torch.cat([torch.zeros(len(holdout)), torch.ones(len(train))])

        # Combine reference and fake
        reference_fake = torch.cat([reference, fake])

        # Combine holdout and train
        holdout_train = torch.cat([holdout, train])

        # Compute group membership scores for each row in holdout_train
        scores = self.batched_cdist_score(
            holdout_train, reference_fake, group, p=1
        ).cpu().numpy()

        gt = gt.cpu().numpy()

        # Column 0 is the negated nearest-neighbor distance; column j >= 1 is the
        # fake-fraction among the j nearest neighbors, so the column index equals k.
        roc_auc_scores = [roc_auc_score(gt, column) for column in scores.T]

        # Plain int (not numpy.int64) so the result serializes cleanly.
        best_k = int(np.argmax(roc_auc_scores))
        score = roc_auc_scores[best_k]

        return dict(
            roc_auc_score=score,
            roc_curve=roc_curve(gt, scores[:, best_k]),
            best_k=best_k if best_k > 0 else 'nn_distance'
        )
