from typing import Dict

import numpy as np
from sdmetrics.reports.single_table import QualityReport
from sdmetrics.single_table import LogisticDetection

from .base import EvalBase


class ColumnShapeTrend(EvalBase):
    """
    Evaluate synthetic data by comparing column-level shapes and pairwise trends
    against real data using SDMetrics' QualityReport.

    This class computes two main metrics:
      1. Column Shapes: How well each individual column in the synthetic data
         matches the real data distribution.
      2. Column Pair Trends: How well pairs of columns in the synthetic data
         preserve relationships found in the real data.

    The final score is the unweighted average of the "shape" and "trend" scores.

    Attributes:
        transform: A transform object that has a `columns` property specifying
            which columns should be included in the evaluation.
        metadata: Metadata needed by SDMetrics to understand data types or constraints.

    Inherits:
        EvalBase: A base class that defines an `_evaluation` method, which is
                  overridden here.
    """

    def _evaluation(self, real_data, fake_data) -> Dict:
        """
        Generate a QualityReport for the selected columns of real_data and fake_data,
        then compute a combined score from column shapes and pairwise trends.

        Args:
            real_data (pd.DataFrame): The real dataset.
            fake_data (pd.DataFrame): The synthetic dataset to evaluate. It is not
                modified; dtype alignment is done on a copy.

        Returns:
            Dict: A dictionary with the following structure:
                  {
                      "score": float,         # Average of shape and trend scores
                      "shape": float,         # Column Shapes property score
                      "shape_error": float,   # (1 - shape)
                      "trend": float,         # Mean of per-pair trend scores (see Notes)
                      "trend_error": float,   # (1 - trend)
                      "shapes": pd.DataFrame, # Detailed column shapes evaluation
                      "trends": pd.DataFrame, # Detailed pairwise trends evaluation
                  }

        Notes:
            Pairs whose trend score is below 1e-8 are set to NaN in the returned
            ``trends`` frame and excluded from ``trend``. If no pair has a score
            >= 1e-8 (for instance when there are no column pairs at all), ``trend``
            and therefore ``score`` are NaN.
        """
        columns = self.transform.columns
        real_subset = real_data[columns]

        # Align synthetic dtypes with the real data so both frames are described
        # consistently to SDMetrics. Work on a copy: assigning into ``fake_data``
        # directly would silently mutate the caller's DataFrame.
        fake_subset = fake_data[columns].copy()
        for col in columns:
            fake_subset[col] = fake_subset[col].astype(real_subset[col].dtype)

        report = QualityReport()
        # Progress output is only enabled for wide tables (>= 100 columns).
        report.generate(
            real_subset,
            fake_subset,
            self.metadata,
            verbose=len(columns) >= 100,
        )

        properties = report.get_properties()
        shapes = report.get_details(property_name='Column Shapes')
        trends = report.get_details(property_name='Column Pair Trends')

        # Look the shape score up by property name instead of relying on the
        # "Column Shapes" row happening to sit at index label 0.
        shape = properties.loc[properties['Property'] == 'Column Shapes', 'Score'].iloc[0]

        # Treat (near-)zero pair scores as missing so they are excluded from the
        # trend average. NOTE(review): this also drops pairs whose score is
        # legitimately 0, which biases ``trend`` upward — confirm this is intended.
        trends.loc[trends['Score'] < 1e-8, 'Score'] = np.nan

        trend = trends['Score'].dropna().mean()
        score = (shape + trend) / 2

        return {
            'score': score,
            'shape': shape,
            'shape_error': 1 - shape,
            'trend': trend,
            'trend_error': 1 - trend,
            'shapes': shapes,
            'trends': trends
        }


class C2ST(EvalBase):
    """
    Evaluate synthetic data with a classifier two-sample test (C2ST) using
    SDMetrics' LogisticDetection metric.

    A logistic regression model is trained to distinguish real rows from
    synthetic rows. SDMetrics turns the classifier's ROC AUC into a score of
    ``1 - (max(AUC, 0.5) * 2 - 1)``:

      * 1.0 (best): the classifier cannot tell the datasets apart, so the
        synthetic data is close to the real distribution.
      * 0.0 (worst): the classifier separates real from synthetic perfectly.

    Higher scores therefore mean more realistic synthetic data.

    Attributes:
        transform: A transform object that has a `columns` property specifying
            which columns should be included in the evaluation.
        metadata: Metadata needed by SDMetrics to understand data types or constraints.

    Inherits:
        EvalBase: A base class defining an `_evaluation` method to be overridden.
    """

    def _evaluation(self, real_data, fake_data) -> Dict:
        """
        Compute SDMetrics' LogisticDetection score of the synthetic data against
        the real data, restricted to ``self.transform.columns``.

        Args:
            real_data (pd.DataFrame): The real dataset.
            fake_data (pd.DataFrame): The synthetic dataset to evaluate.

        Returns:
            Dict: A dictionary with:
                  {
                      "score": float,   # LogisticDetection score; higher = harder to detect
                      "error": float,   # 1 - score; higher = easier to detect
                  }
        """
        columns = self.transform.columns
        metric = LogisticDetection()
        score = metric.compute(
            real_data[columns],
            fake_data[columns],
            self.metadata
        )

        return {
            'score': score,
            'error': 1 - score
        }
