import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
from sklearn.tree import export_text
from scipy.stats import (
    ks_2samp, kstest, chisquare,
    wasserstein_distance, entropy,
    skew, kurtosis, energy_distance,
    gaussian_kde
)
from scipy.spatial.distance import jensenshannon, cdist
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, roc_auc_score, pairwise_kernels, f1_score
from sklearn.metrics.pairwise import rbf_kernel

# Font sizes shared by the module-level plotting helpers below.
font_a = 20  # axis labels and titles (plot_numeric_feature also uses it for its legend)
font_lgd = 14  # legend text in plot_categorical_feature
font_ticks = 13  # x tick labels in plot_categorical_feature


def get_age_group(age):
    """Bucket an age into 'young' (< 35), 'middle' (< 55) or 'old'.

    Missing ages (anything ``pd.isna`` reports as missing) yield ``np.nan``.
    """
    if pd.isna(age):
        return np.nan

    # Exclusive upper bounds, checked in ascending order.
    for upper_bound, group in ((35, "young"), (55, "middle")):
        if age < upper_bound:
            return group
    return "old"


def compute_joint_mmd(real_data, synth_data, feature_types, kernel="rbf", **kernel_kwargs):
    """
    Computes the squared Maximum Mean Discrepancy (MMD^2) between two datasets.

    Numerical features are cast to float, every other feature is one-hot
    encoded, and both design matrices are aligned on the union of their
    columns (absent columns are filled with 0) before the kernel is applied.

    Args:
        - real_data, synth_data: pd.DataFrame, dict (rows keyed by index) or
          anything accepted by the pd.DataFrame constructor
        - feature_types: dict, {feature_name: "numerical" or "categorical"}
        - kernel: metric name forwarded to sklearn's pairwise_kernels
        - kernel_kwargs: extra keyword arguments for pairwise_kernels

    Returns:
        - the estimate mean(K_xx) + mean(K_yy) - 2 * mean(K_xy)
    """

    def as_frame(obj):
        """Coerces the input to a DataFrame (DataFrames are copied)."""
        if isinstance(obj, pd.DataFrame):
            return obj.copy()
        if isinstance(obj, dict):
            return pd.DataFrame.from_dict(obj, orient='index')
        return pd.DataFrame(obj)

    def encode(frame):
        """Builds the float design matrix for one dataset."""
        blocks = [
            frame[[name]].astype(float) if kind == "numerical"
            else pd.get_dummies(frame[name], prefix=name)
            for name, kind in feature_types.items()
        ]
        return pd.concat(blocks, axis=1).fillna(0).astype(float)

    def mean_kernel(left, right):
        """Average kernel value over all pairs of rows."""
        return pairwise_kernels(left, right, metric=kernel, **kernel_kwargs).mean()

    # A dummy column seen in only one dataset becomes an all-zero column in the other.
    real_mat, synth_mat = encode(as_frame(real_data)).align(
        encode(as_frame(synth_data)), join='outer', axis=1, fill_value=0
    )

    # MMD^2 = E[k(X,X)] + E[k(Y,Y)] - 2 E[k(X,Y)]
    return (
        mean_kernel(real_mat, real_mat)
        + mean_kernel(synth_mat, synth_mat)
        - 2 * mean_kernel(real_mat, synth_mat)
    )


def plot_numeric_feature(feature, real_series, synth_series, save_path,
                         bw_method='scott', grid_points=200):
    """
    Saves an overlaid KDE plot of one numeric feature for real vs synthetic data.

    Both densities are evaluated on a common grid spanning the combined
    value range of the two series.

    Args:
        - feature: feature name, used as title and x-axis label
        - real_series: pandas Series, real data
        - synth_series: pandas Series, synthetic data
        - save_path: destination path of the figure
        - bw_method: bandwidth method passed to scipy.stats.gaussian_kde
        - grid_points: number of grid points the densities are evaluated on
    """
    palette = {
        "Real":  {"line": "#836796", "fill": "#A58DB3"},
        "Synth": {"line": "#346A97", "fill": "#7EA2DF"}
    }

    lower = min(real_series.min(), synth_series.min())
    upper = max(real_series.max(), synth_series.max())
    xs = np.linspace(lower, upper, grid_points)

    # Estimate both densities before any figure is created.
    densities = {
        name: gaussian_kde(series.dropna(), bw_method=bw_method)(xs)
        for name, series in (("Real", real_series), ("Synth", synth_series))
    }

    plt.figure(figsize=(6, 6))
    for name, ys in densities.items():
        shades = palette[name]
        plt.plot(xs, ys, color=shades["line"], label=name)
        plt.fill_between(xs, ys, color=shades["fill"], alpha=0.4)

    plt.xlabel(feature, fontsize=font_a)
    plt.ylabel("Density", fontsize=font_a)
    plt.title(f"{feature}", fontsize=font_a)
    plt.legend(fontsize=font_a)
    plt.savefig(save_path, dpi=300, bbox_inches='tight', transparent=True)
    plt.close()


def plot_categorical_feature(feature, real_series, synth_series, save_path):
    """
    Saves a side-by-side bar chart of category frequencies for real vs synthetic data.

    The x-axis covers the sorted union of the non-missing categories found
    in either series; bar heights are relative frequencies.

    Args:
        - feature: feature name, used as title and x-axis label
        - real_series: pandas Series, real data
        - synth_series: pandas Series, synthetic data
        - save_path: destination path of the figure
    """
    palette = {
        "Real":  {"line": "#836796", "fill": "#A58DB3"},
        "Synth": {"line": "#346A97", "fill": "#7EA2DF"}
    }

    levels = sorted(
        set(real_series.dropna().unique()) | set(synth_series.dropna().unique())
    )

    def relative_freq(series):
        """Share of each level among the counted values of the series."""
        counts = series.value_counts().reindex(levels, fill_value=0)
        return counts / counts.sum()

    shares = {"Real": relative_freq(real_series), "Synth": relative_freq(synth_series)}

    ticks = np.arange(len(levels))
    bar_width = 0.35
    half = bar_width / 2
    bar_centers = {"Real": ticks - half, "Synth": ticks + half}

    plt.figure(figsize=(6, 6))
    for name in ("Real", "Synth"):
        plt.bar(bar_centers[name], shares[name], width=bar_width,
                color=palette[name]["fill"], alpha=0.4, label=name,
                edgecolor=palette[name]["line"])

    # Only product_category gets tilted tick labels.
    tick_options = {"fontsize": font_ticks}
    if feature == "product_category":
        tick_options["rotation"] = 20
    plt.xticks(ticks, levels, **tick_options)

    plt.xlabel(feature, fontsize=font_a)
    plt.ylabel("Relative Frequency", fontsize=font_a)
    plt.title(f"{feature}", fontsize=font_a)
    plt.legend(fontsize=font_lgd)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight', transparent=True)
    plt.close()


class Evaluator:
    """Compares synthetic records with real ones and saves diagnostic plots.

    Feature types are derived from ``config``: a feature whose config entry
    contains a ``"categories"`` key is categorical, every other feature is
    numerical.  The statistical work is delegated to a
    ``DataDistributionEvaluator``; figures are written as PNG files into
    ``result_dir``.
    """

    # Two-tone palettes shared by the joint plots.  _NEUTRAL colors any level
    # that has no dedicated palette entry, so unexpected values found in the
    # data do not abort plotting with a KeyError.
    _PURPLE = {'line': '#836796', 'fill': '#A58DB3'}
    _BLUE = {'line': '#346A97', 'fill': '#7EA2DF'}
    _NEUTRAL = {'line': '#555555', 'fill': '#BBBBBB'}

    def __init__(self, config, result_dir):
        """
        Args:
            - config: dict, {feature_name: feature spec}; a spec holding a
              "categories" key marks the feature as categorical
            - result_dir: directory the figures are saved into
        """
        self.config = config
        self.result_dir = result_dir
        feature_types = {
            k: "categorical" if "categories" in v else "numerical"
            for k, v in config.items()
        }
        categorical_values = {
            k: v['categories']
            for k, v in config.items() if 'categories' in v
        }
        self.evaluator = DataDistributionEvaluator(feature_types, categorical_values)

    @staticmethod
    def _price_stats(df):
        """Returns mean, std, and 75th quantile of price by product_category."""
        tbl = df.groupby("product_category")["price"].agg(
            mean="mean", std="std", q3=lambda x: x.quantile(0.75)
        )
        return tbl.to_dict(orient="index")

    @staticmethod
    def _build_discount_propensity(df, price_tbl):
        """
        Builds the discount propensity label ("High"/"Mid"/"Low") per row.

        The price is z-scored against the per-category statistics in
        ``price_tbl``.  Rows whose product_category has no entry in
        ``price_tbl`` are labelled NaN.
        """
        def score(row):
            stats = price_tbl.get(row["product_category"])
            if stats is None:
                # Category unseen in the reference statistics: no z-score exists.
                return np.nan
            mu = stats["mean"]
            sigma = stats["std"]
            # A single-sample category has a NaN std, which is truthy and
            # therefore slips past a plain ``std or 1`` fallback.
            if pd.isna(sigma) or sigma == 0:
                sigma = 1
            z = (row["price"] - mu) / sigma
            age = row["user_age"]
            pm = row["payment_method"]
            loc = row["location_tier"]
            s = (
                -np.tanh(z)
                + 0.01 * (age - 35) ** 2 / 100
                + (0.5 if pm == "Cash on Delivery" else -0.2)
                + (0.3 if loc == "Developing" else 0)
            )
            return "High" if s > 1 else "Low" if s < -1 else "Mid"
        return df.apply(score, axis=1)

    @staticmethod
    def _build_ltv_band(df):
        """Builds the lifetime value band ("High"/"Mid"/"Low") per row."""
        def clv_score(row):
            base = row["price"]
            age = row["user_age"]
            pm = row["payment_method"]
            cat = row["product_category"]

            age_score = np.log1p(abs(age - 35))
            method_weight = 1.2 if pm == "Online Payment" else 0.85
            category_weight = {
                "Electronics": 1.3,
                "Apparel": 1.1,
                "Food & Beverages": 0.9,
                "Furniture & Appliances": 1.4
            }.get(cat, 1.0)

            clv = np.sqrt(base) * method_weight / (age_score + 1) * category_weight
            return "High" if clv > 20 else "Low" if clv < 10 else "Mid"
        return df.apply(clv_score, axis=1)

    def _ml_efficiency(self, target, feature_cols, real_df, syn_df):
        """
        Trains classifiers on synthetic data and scores them on real data.

        Args:
            - target: name of the target column
            - feature_cols: list of feature columns
            - real_df: real data DataFrame (test set)
            - syn_df: synthetic data DataFrame (training set)

        Returns:
            - dict keyed by "LR", "DT" and "RF"; each value holds train_size,
              test_size, acc, f1_macro and auc.  The three metrics are None
              when either set has fewer than 50 complete rows or fewer than
              two target classes.
        """
        train_df = syn_df[feature_cols + [target]].dropna()
        test_df = real_df[feature_cols + [target]].dropna()

        too_small = len(train_df) < 50 or len(test_df) < 50
        if too_small or train_df[target].nunique() < 2 or test_df[target].nunique() < 2:
            # One fresh dict per model so callers can edit them independently.
            return {
                name: dict(train_size=len(train_df), test_size=len(test_df),
                           acc=None, f1_macro=None, auc=None)
                for name in ("LR", "DT", "RF")
            }

        X_train = pd.get_dummies(train_df[feature_cols], drop_first=False)
        X_test = pd.get_dummies(test_df[feature_cols], drop_first=False)
        # Test columns must match the training layout exactly.
        X_test = X_test.reindex(columns=X_train.columns, fill_value=0)
        y_train = train_df[target]
        y_test = test_df[target]

        models = {
            # NOTE(review): multi_class is deprecated since scikit-learn 1.5;
            # confirm the pinned version before upgrading.
            "LR": LogisticRegression(max_iter=1000, multi_class="multinomial"),
            "DT": DecisionTreeClassifier(random_state=0),
            "RF": RandomForestClassifier(random_state=0)
        }

        out = {}
        for name, clf in models.items():
            clf.fit(X_train, y_train)
            y_pred = clf.predict(X_test)
            acc = accuracy_score(y_test, y_pred)
            f1 = f1_score(y_test, y_pred, average="macro")
            try:
                y_prob = clf.predict_proba(X_test)
                if y_prob.shape[1] == 2:
                    # Binary targets need the 1-D score of the second class;
                    # the full (n, 2) matrix is rejected by roc_auc_score.
                    auc = roc_auc_score(y_test, y_prob[:, 1])
                else:
                    auc = roc_auc_score(y_test, y_prob, multi_class="ovr", average="macro")
            except Exception:
                # AUC is best-effort (e.g. class sets of train and test
                # differ); accuracy and F1 are still reported.
                auc = None
            out[name] = {
                "train_size": len(train_df),
                "test_size": len(test_df),
                "acc": round(float(acc), 4),
                "f1_macro": round(float(f1), 4),
                "auc": round(float(auc), 4) if auc is not None else None
            }
        return out

    @staticmethod
    def _sorted_levels(real_df, syn_df, column):
        """Sorted union of the non-missing values of ``column`` in both frames."""
        # Missing values are dropped: NaN cannot be ordered against strings.
        return sorted(set(real_df[column].dropna()) | set(syn_df[column].dropna()))

    def _plot_age_gender_product(self, real_df, syn_df):
        """Saves one age boxplot per dataset, grouped by product category and gender."""
        product_cats = self._sorted_levels(real_df, syn_df, 'product_category')
        genders = self._sorted_levels(real_df, syn_df, 'user_gender')
        palette = {'Male': self._PURPLE, 'Female': self._BLUE}
        n_cat = len(product_cats)
        n_hue = len(genders)
        width = 0.35

        for label, df in (('Real', real_df), ('Synthetic', syn_df)):
            fig, ax = plt.subplots(figsize=(6, 6))

            data = []
            positions = []
            for j, cat in enumerate(product_cats):
                for i, gender in enumerate(genders):
                    mask = (df['product_category'] == cat) & (df['user_gender'] == gender)
                    data.append(df.loc[mask, 'user_age'].dropna())
                    # Center the gender boxes of one category around tick j.
                    positions.append(j + (i - (n_hue - 1) / 2) * width)

            bp = ax.boxplot(data, positions=positions, widths=width,
                            patch_artist=True, showfliers=False)

            # Boxes were appended gender-fastest, so idx % n_hue recovers the gender.
            for idx, box in enumerate(bp['boxes']):
                shades = palette.get(genders[idx % n_hue], self._NEUTRAL)
                box.set_facecolor(to_rgba(shades['fill'], alpha=0.5))
                box.set_edgecolor(to_rgba(shades['line'], alpha=0.5))
                box.set_alpha(0.5)

            for line in bp['whiskers'] + bp['caps'] + bp['medians']:
                line.set_color('grey')

            ax.set_xticks(range(n_cat))
            ax.set_xticklabels(product_cats, fontsize=13, rotation=20)
            ax.set_xlabel('Product Category', fontsize=20)
            ax.set_ylabel('User Age', fontsize=20)
            ax.set_title(f'Product Category by Age & Gender ({label})', fontsize=20)

            # Custom legend: one colored line per gender.
            handles = [
                plt.Line2D([], [], color=palette.get(g, self._NEUTRAL)['line'], lw=3, alpha=0.7)
                for g in genders
            ]
            ax.legend(handles, genders, loc='upper left', ncol=1, fontsize=14)

            fig.savefig(os.path.join(self.result_dir, f'joint_age_gender_product_{label}.png'),
                        dpi=300, bbox_inches='tight', transparent=True)
            plt.close(fig)

    def _plot_price_by_product(self, real_df, syn_df):
        """Saves one single-color price boxplot per dataset, grouped by product category."""
        product_cats = self._sorted_levels(real_df, syn_df, 'product_category')
        fill_rgba = to_rgba(self._BLUE['fill'], alpha=0.5)
        line_rgba = to_rgba(self._BLUE['line'], alpha=0.5)

        for label, df in (('Real', real_df), ('Synthetic', syn_df)):
            fig, ax = plt.subplots(figsize=(6, 6))
            data = [df.loc[df['product_category'] == cat, 'price'].dropna() for cat in product_cats]

            bp = ax.boxplot(data, positions=range(len(product_cats)), widths=0.6,
                            patch_artist=True, showfliers=False)
            for box in bp['boxes']:
                box.set_facecolor(fill_rgba)
                box.set_edgecolor(line_rgba)
                box.set_alpha(0.5)
            for line in bp['whiskers'] + bp['caps'] + bp['medians']:
                line.set_color(line_rgba)
                line.set_alpha(0.5)

            ax.set_xticks(range(len(product_cats)))
            ax.set_xticklabels(product_cats, fontsize=13, rotation=20)
            ax.set_xlabel('Product Category', fontsize=20)
            ax.set_ylabel('Price', fontsize=20)
            ax.set_title(f'Price by Product Category ({label})', fontsize=20)

            fig.savefig(os.path.join(self.result_dir, f'joint_product_price_{label}.png'),
                        dpi=300, bbox_inches='tight', transparent=True)
            plt.close(fig)

    def _plot_payment_by_location(self, real_df, syn_df):
        """Saves one grouped bar chart per dataset: payment-method share within each location tier."""
        palette = {'Cash on Delivery': self._PURPLE, 'Online Payment': self._BLUE}
        location_tiers = self._sorted_levels(real_df, syn_df, 'location_tier')
        payment_methods = self._sorted_levels(real_df, syn_df, 'payment_method')

        fill_colors = [to_rgba(palette.get(pm, self._NEUTRAL)['fill'], alpha=0.5)
                       for pm in payment_methods]
        line_colors = [to_rgba(palette.get(pm, self._NEUTRAL)['line'], alpha=0.5)
                       for pm in payment_methods]

        for label, df in (('Real', real_df), ('Synthetic', syn_df)):
            # Reindex so both datasets share the same row/column order.
            ct = df.groupby(['location_tier', 'payment_method']).size().unstack(fill_value=0)
            ct = ct.reindex(index=location_tiers, columns=payment_methods, fill_value=0)
            freq = ct.div(ct.sum(axis=1), axis=0)

            fig, ax = plt.subplots(figsize=(6, 6))
            freq.plot(kind='bar', width=0.8, color=fill_colors, edgecolor=line_colors,
                      alpha=0.5, ax=ax)

            ax.set_title(f'Payment Method by Location Tier ({label})', fontsize=20)
            ax.set_xlabel('Location Tier', fontsize=20)
            ax.set_ylabel('Relative Frequency', fontsize=20)
            ax.set_xticklabels(location_tiers, rotation=0, fontsize=14)
            ax.legend(loc='upper right', fontsize=14)

            fig.savefig(os.path.join(self.result_dir, f'joint_location_payment_{label}.png'),
                        dpi=300, bbox_inches='tight', transparent=True)
            plt.close(fig)

    def _plot_marginals(self, real_df, syn_df):
        """Saves one marginal comparison plot per configured feature."""
        for feature in self.config:
            path = os.path.join(self.result_dir, f"marginal_{feature}_dist.png")
            # The dtype of the real column decides between KDE and bar chart.
            if pd.api.types.is_numeric_dtype(real_df[feature]):
                plot_numeric_feature(feature, real_df[feature], syn_df[feature],
                                     save_path=path)
            else:
                plot_categorical_feature(feature, real_df[feature], syn_df[feature],
                                         save_path=path)

    def evaluate(self, real_data, synthetic_data, eval_joint=True):
        """
        Evaluate the quality of synthetic data by comparing it to real data.

        Args:
            - real_data: real data as a dictionary (rows keyed by index) or DataFrame
            - synthetic_data: synthetic data as a dictionary or DataFrame
            - eval_joint: whether to evaluate and plot joint distributions (default: True)

        Returns:
            - marginal_results: evaluation of marginal distributions
            - joint_results: one entry per evaluated joint distribution (only
              when eval_joint=True) plus an "ml_efficiency" entry that is
              always present
        """
        print(type(real_data), type(synthetic_data))
        print(len(real_data), len(synthetic_data))

        marginal_results = self.evaluator.evaluate_marginals(real_data, synthetic_data)

        # Convert once through the helper that accepts both dicts and
        # DataFrames; it returns fresh frames, so the derived columns added
        # below never leak into the caller's data.
        real_df = self.evaluator._ensure_dataframe(real_data)
        syn_df = self.evaluator._ensure_dataframe(synthetic_data)

        joint_results = {}
        if eval_joint:
            # Each variable combination needs its own key; repeating a key in
            # a dict literal keeps only the last value.
            joint_specs = {
                'joint_age_gender_product': ['user_age', 'user_gender', 'product_category'],
                'joint_age_product': ['user_age', 'product_category'],
                'joint_gender_product': ['user_gender', 'product_category'],
                'joint_product_price': ['product_category', 'price'],
                'joint_location_payment': ['location_tier', 'payment_method'],
            }
            for key, var_list in joint_specs.items():
                joint_results[key] = self.evaluator.evaluate_joint(
                    var_list, real_data, synthetic_data)

            self._plot_age_gender_product(real_df, syn_df)
            self._plot_price_by_product(real_df, syn_df)
            self._plot_payment_by_location(real_df, syn_df)

        self._plot_marginals(real_df, syn_df)

        # Derived targets; both datasets are scored against the price
        # statistics of the real data.
        stats_real = self._price_stats(real_df)

        real_df["discount_propensity"] = self._build_discount_propensity(real_df, stats_real)
        syn_df["discount_propensity"] = self._build_discount_propensity(syn_df, stats_real)

        real_df["lifetime_value_band"] = self._build_ltv_band(real_df)
        syn_df["lifetime_value_band"] = self._build_ltv_band(syn_df)

        # ML-efficiency evaluation
        ml_results = {
            "discount_propensity": self._ml_efficiency(
                target="discount_propensity",
                feature_cols=["user_age", "price", "product_category",
                              "payment_method"],
                real_df=real_df, syn_df=syn_df),
            "lifetime_value_band": self._ml_efficiency(
                target="lifetime_value_band",
                feature_cols=["user_age", "price", "payment_method"],
                real_df=real_df, syn_df=syn_df),
        }

        joint_results["ml_efficiency"] = ml_results
        return marginal_results, joint_results


class DataDistributionEvaluator:
    def __init__(self, feature_types, categorical_values=None):
        """
        Initializes the evaluator with feature types and possible categorical values.
        
        Args:
            - feature_types: dict, mapping feature names to "numerical"/"categorical"
            - categorical_values: dict, mapping categorical feature names to possible categories
        """
        self.feature_types = feature_types
        self.categorical_values = categorical_values or {}

    def _ensure_dataframe(self, data):
        """Ensures input data is a DataFrame."""
        if isinstance(data, pd.DataFrame):
            return data.copy()
        elif isinstance(data, dict):
            return pd.DataFrame.from_dict(data, orient='index')
        else:
            return pd.DataFrame(data)

    def evaluate_marginals(self, real_data, gen_data, real_pdf_funcs=None):
        """
        Evaluates marginal distributions for real vs synthetic data.

        Args:
            - real_data: real data (DataFrame or dict)
            - gen_data: generated data (DataFrame or dict)
            - real_pdf_funcs: optional PDF functions for analytic tests
        
        Returns:
            - results: Dictionary containing evaluation results
        """
        real_df = self._ensure_dataframe(real_data)
        gen_df = self._ensure_dataframe(gen_data)
        results = {}

        for feature, ftype in self.feature_types.items():
            results[feature] = {}
            if ftype == 'numerical':
                if feature in real_df and feature in gen_df:
                    # Sample KS Test
                    data_real = real_df[feature].dropna().values
                    data_gen  = gen_df[feature].dropna().values
                    try:
                        ks_res = ks_2samp(data_gen, data_real)
                    except Exception as e:
                        print(f'Error in KS test for {feature}: {e}, data_gen: {data_gen}')

                    results[feature]['sample_ks'] = {
                        'statistic': ks_res.statistic.item(),
                        'pvalue':    ks_res.pvalue.item()
                    }

                    # Wasserstein Distance
                    wd = wasserstein_distance(data_real, data_gen)
                    results[feature]['wasserstein_distance'] = float(wd)

                    # Energy Distance
                    ed = energy_distance(data_real, data_gen)
                    results[feature]['energy_distance'] = float(ed)

                    # Maximum Mean Discrepancy (MMD)
                    vals = np.concatenate([data_real, data_gen])[:, None]
                    dists = np.abs(vals - vals.T)
                    bw = np.median(dists)
                    if bw <= 0:
                        bw = 1.0
                    gamma = 1.0 / (2 * bw**2)
                    K_rr = rbf_kernel(data_real[:, None], data_real[:, None], gamma=gamma)
                    K_gg = rbf_kernel(data_gen[:, None],  data_gen[:, None],  gamma=gamma)
                    K_rg = rbf_kernel(data_real[:, None], data_gen[:, None], gamma=gamma)
                    mmd = K_rr.mean() + K_gg.mean() - 2 * K_rg.mean()
                    results[feature]['mmd'] = float(mmd)

                    # Moments Difference
                    results[feature]['mean_diff']  = abs(data_real.mean() - data_gen.mean())
                    results[feature]['std_diff']   = abs(data_real.std()  - data_gen.std())
                    results[feature]['skew_diff']  = abs(skew(data_real)   - skew(data_gen))
                    results[feature]['kurtosis_diff'] = abs(kurtosis(data_real) - kurtosis(data_gen))
                else:
                    results[feature]['sample_ks'] = None

                # Analytic KS
                if real_pdf_funcs and feature in real_pdf_funcs:
                    cdf_f = real_pdf_funcs[feature].get('cdf')
                    if cdf_f is not None:
                        data_gen = gen_df[feature].dropna().values
                        ks_an = kstest(data_gen, cdf_f)
                        results[feature]['analytic_ks'] = {
                            'statistic': ks_an.statistic.item(),
                            'pvalue':    ks_an.pvalue.item()
                        }
                    else:
                        results[feature]['analytic_ks'] = None
            # Handle categorical features
            elif ftype == 'categorical':
                if feature in self.categorical_values:
                    cats = self.categorical_values[feature]
                else:
                    cats = sorted(set(real_df[feature].dropna().unique()) | set(gen_df[feature].dropna().unique()))

                if feature in real_df and feature in gen_df:
                    real_counts = real_df[feature].value_counts().reindex(cats, fill_value=0)
                    gen_counts  = gen_df[feature].value_counts().reindex(cats, fill_value=0)
                    if real_counts.sum() > 0:
                        exp = gen_counts.sum() * (real_counts / real_counts.sum())
                        chi = chisquare(f_obs=gen_counts, f_exp=exp)
                        results[feature]['sample_chi2'] = {
                            'statistic': chi.statistic.item(),
                            'pvalue':    chi.pvalue.item()
                        }
                    else:
                        results[feature]['sample_chi2'] = None
                else:
                    results[feature]['sample_chi2'] = None

                # Analytic Chi2
                if real_pdf_funcs and feature in real_pdf_funcs:
                    probs = real_pdf_funcs[feature]
                    total = gen_df.shape[0]
                    exp_a, obs = [], []
                    for cat in cats:
                        exp_a.append(total * probs.get(cat, 0))
                        obs.append(gen_df[feature].value_counts().get(cat, 0))
                    if sum(exp_a) > 0:
                        chi_a = chisquare(f_obs=obs, f_exp=exp_a)
                        results[feature]['analytic_chi2'] = {
                            'statistic': chi_a.statistic.item(),
                            'pvalue':    chi_a.pvalue.item()
                        }
                    else:
                        results[feature]['analytic_chi2'] = None

                # Total Variation Distance (TVD)
                gen_freq = gen_df[feature].value_counts().reindex(cats, fill_value=0) / gen_df.shape[0]
                real_freq = real_df[feature].value_counts().reindex(cats, fill_value=0) / real_df.shape[0]
                results[feature]['sample_tvd'] = float(0.5 * np.sum(np.abs(gen_freq - real_freq)))

                # Symmetric KL & Hellinger
                eps = 1e-8
                p = real_freq.values + eps
                q = gen_freq.values  + eps
                kl_rq = entropy(p, q)
                kl_qr = entropy(q, p)
                results[feature]['sample_kl'] = 0.5 * (kl_rq + kl_qr)
                results[feature]['sample_hellinger'] = float(
                    np.sqrt(0.5 * np.sum((np.sqrt(p) - np.sqrt(q))**2))
                )

                # Analytic TVD/KL/Hellinger
                if real_pdf_funcs and feature in real_pdf_funcs:
                    a_probs = np.array([real_pdf_funcs[feature].get(cat, 0) for cat in cats]) + eps
                    results[feature]['analytic_tvd'] = float(
                        0.5 * np.sum(np.abs(gen_freq.values - a_probs))
                    )
                    kl_ra = entropy(a_probs, gen_freq.values + eps)
                    kl_ar = entropy(gen_freq.values + eps, a_probs)
                    results[feature]['analytic_kl'] = 0.5 * (kl_ra + kl_ar)
                    results[feature]['analytic_hellinger'] = float(
                        np.sqrt(0.5 * np.sum((np.sqrt(a_probs) - np.sqrt(gen_freq.values + eps))**2))
                    )

            else:
                results[feature]['error'] = "Unknown feature type"

            # === New C2ST: LR, DT & RF Accuracy and AUC ===
            Xr = real_df[[feature]].dropna()
            Xg = gen_df[[feature]].dropna()
            n_min = min(len(Xr), len(Xg))
            if n_min >= 10:
                X = pd.concat([
                    Xr.sample(n_min, random_state=0),
                    Xg.sample(n_min, random_state=0)
                ], ignore_index=True)
                y = np.array([1]*n_min + [0]*n_min)

                if ftype == 'categorical':
                    X_enc = pd.get_dummies(X[feature], prefix=feature)
                else:
                    X_enc = X

                X_tr, X_te, y_tr, y_te = train_test_split(
                    X_enc, y, test_size=0.3, random_state=0
                )

                # Logistic Regression
                clf_lr = LogisticRegression(max_iter=1000)
                clf_lr.fit(X_tr, y_tr)
                y_pred_lr = clf_lr.predict(X_te)
                y_prob_lr = clf_lr.predict_proba(X_te)[:, 1]
                results[feature]['sample_c2st_acc']     = float(accuracy_score(y_te, y_pred_lr))
                results[feature]['sample_c2st_auc']     = float(roc_auc_score(y_te, y_prob_lr))

                # Decision Tree
                clf_dt = DecisionTreeClassifier(random_state=0)
                clf_dt.fit(X_tr, y_tr)
                y_pred_dt = clf_dt.predict(X_te)
                y_prob_dt = clf_dt.predict_proba(X_te)[:, 1]
                results[feature]['sample_c2st_acc_dt']  = float(accuracy_score(y_te, y_pred_dt))
                results[feature]['sample_c2st_auc_dt']  = float(roc_auc_score(y_te, y_prob_dt))

                # Random Forest
                clf_rf = RandomForestClassifier(random_state=0)
                clf_rf.fit(X_tr, y_tr)
                y_pred_rf = clf_rf.predict(X_te)
                y_prob_rf = clf_rf.predict_proba(X_te)[:, 1]
                results[feature]['sample_c2st_acc_rf']  = float(accuracy_score(y_te, y_pred_rf))
                results[feature]['sample_c2st_auc_rf']  = float(roc_auc_score(y_te, y_prob_rf))
            else:
                results[feature]['sample_c2st_acc']     = None
                results[feature]['sample_c2st_auc']     = None
                results[feature]['sample_c2st_acc_dt']  = None
                results[feature]['sample_c2st_auc_dt']  = None
                results[feature]['sample_c2st_acc_rf']  = None
                results[feature]['sample_c2st_auc_rf']  = None

        return results

    def evaluate_joint(self, var_list, real_data, gen_data, bins=5):
        """
        Evaluates joint distributions between multiple variables.

        Numerical variables are discretized on a grid shared by both datasets,
        the joint frequency table over `var_list` is built for each dataset,
        and the two tables are compared with several distances plus a
        classifier two-sample test (C2ST).

        Args:
            - var_list: list of variables to evaluate jointly
            - real_data: real data (DataFrame or dict)
            - gen_data: synthetic data (DataFrame or dict)
            - bins: number of bins to discretize numerical features (default: 5)

        Returns:
            - results: dictionary with evaluation results for joint distributions:
                'joint_tvd'       total variation distance between the joint tables
                'joint_jsd'       value of scipy's `jensenshannon` with base 2
                                  (the Jensen-Shannon *distance*, i.e. the square
                                  root of the divergence)
                'joint_kl'        mean of KL(p||q) and KL(q||p) on eps-smoothed tables
                'joint_hellinger' Hellinger distance on the eps-smoothed tables
                'joint_l2'        Euclidean norm of the difference of the tables
                'joint_c2st_acc' / 'joint_c2st_auc'        logistic regression
                'joint_c2st_acc_dt' / 'joint_c2st_auc_dt'  decision tree
                'joint_c2st_acc_rf' / 'joint_c2st_auc_rf'  random forest
              All C2ST entries are None when either dataset has fewer than 10
              complete rows; the AUC entries are also None when the held-out
              split happens to contain a single class.
        """
        real_df = self._ensure_dataframe(real_data).copy()
        gen_df = self._ensure_dataframe(gen_data).copy()

        # Discretize numerical features
        for var in var_list:
            if var == "user_age":
                # Age uses the fixed young/middle/old grouping instead of equal-width bins.
                real_df['user_age'] = real_df['user_age'].map(get_age_group)
                gen_df['user_age'] = gen_df['user_age'].map(get_age_group)
                continue

            if self.feature_types.get(var, 'numerical') == 'numerical':
                all_vals = np.concatenate([
                    real_df[var].dropna().values,
                    gen_df[var].dropna().values
                ])
                if len(all_vals) == 0:
                    continue
                lo, hi = all_vals.min(), all_vals.max()
                if lo == hi:
                    # A constant column would produce identical bin edges, which
                    # pd.cut rejects ("Bin edges must be unique"). Widen the range
                    # so the single observed value falls inside a valid grid.
                    lo, hi = lo - 0.5, hi + 0.5
                # Both datasets share the same edges so their bins are comparable.
                edges = np.linspace(lo, hi, bins + 1)
                real_df[var] = pd.cut(real_df[var], bins=edges, include_lowest=True)
                gen_df[var] = pd.cut(gen_df[var], bins=edges, include_lowest=True)

        # Joint frequency counts, aligned on the union of observed cells
        real_cnt = real_df.groupby(var_list).size()
        gen_cnt = gen_df.groupby(var_list).size()
        idx_all = real_cnt.index.union(gen_cnt.index)
        r_cnt = real_cnt.reindex(idx_all, fill_value=0)
        g_cnt = gen_cnt.reindex(idx_all, fill_value=0)

        # Normalize the frequencies (left as raw zeros when a table is empty)
        r_prob = r_cnt / r_cnt.sum() if r_cnt.sum() > 0 else r_cnt
        g_prob = g_cnt / g_cnt.sum() if g_cnt.sum() > 0 else g_cnt

        # Calculate TVD, JSD, and other distances
        tvd = 0.5 * np.sum(np.abs(r_prob - g_prob))
        jsd = jensenshannon(r_prob, g_prob, base=2)

        results = {
            'joint_tvd':  float(tvd),
            'joint_jsd':  float(jsd)
        }

        # Symmetric KL; eps keeps the log ratio finite for cells that are empty
        # in one of the two tables.
        eps = 1e-8
        p = r_prob.values + eps
        q = g_prob.values + eps
        kl_pq = entropy(p, q)
        kl_qp = entropy(q, p)
        # Cast to a plain float like every other metric in the result dict.
        results['joint_kl'] = float(0.5 * (kl_pq + kl_qp))

        # Hellinger Distance
        results['joint_hellinger'] = float(
            np.sqrt(0.5 * np.sum((np.sqrt(p) - np.sqrt(q))**2))
        )

        # L2 Norm
        results['joint_l2'] = float(np.linalg.norm(p - q))

        # Classifier Two-Sample Test (C2ST)
        # Key suffixes, in the order the entries are written to `results`.
        c2st_suffixes = ('', '_dt', '_rf')

        Xr = real_df[var_list].dropna()
        Xg = gen_df[var_list].dropna()
        n_min = min(len(Xr), len(Xg))
        if n_min >= 10:
            # Balance the two classes by subsampling both to the smaller size.
            X = pd.concat([
                Xr.sample(n_min, random_state=0),
                Xg.sample(n_min, random_state=0)
            ], ignore_index=True)
            y = np.array([1]*n_min + [0]*n_min)  # 1 = real, 0 = synthetic

            # One-hot encode mixed variables
            X_enc = pd.get_dummies(X, drop_first=False)

            X_tr, X_te, y_tr, y_te = train_test_split(
                X_enc, y, test_size=0.3, random_state=0
            )

            # roc_auc_score is undefined (and raises) when the held-out labels
            # contain a single class, which the unstratified split can produce
            # for small samples.
            auc_defined = np.unique(y_te).size == 2

            classifiers = (
                LogisticRegression(max_iter=1000),
                DecisionTreeClassifier(random_state=0),
                RandomForestClassifier(random_state=0),
            )
            for suffix, clf in zip(c2st_suffixes, classifiers):
                clf.fit(X_tr, y_tr)
                y_pred = clf.predict(X_te)
                results['joint_c2st_acc' + suffix] = float(accuracy_score(y_te, y_pred))
                if auc_defined:
                    y_prob = clf.predict_proba(X_te)[:, 1]
                    results['joint_c2st_auc' + suffix] = float(roc_auc_score(y_te, y_prob))
                else:
                    results['joint_c2st_auc' + suffix] = None
        else:
            # Set to None if sample size is insufficient
            for suffix in c2st_suffixes:
                results['joint_c2st_acc' + suffix] = None
                results['joint_c2st_auc' + suffix] = None

        return results
