import re
import json
import copy
import random
import textwrap
import numpy as np
import pandas as pd
from collections import OrderedDict, Counter, defaultdict
from scipy.stats import skew, kurtosis
from typing import List, Tuple, Dict, Any

def transform_structure(obj):
    """
    Build a schema skeleton that mirrors ``obj``.

    Mapping rules:
      - dict -> dict with the same keys, each value transformed recursively
      - int/float (bool included, being a subclass of int) -> ["min", "max"]
      - anything else -> "category"
    """
    if isinstance(obj, dict):
        return {k: transform_structure(v) for k, v in obj.items()}
    if isinstance(obj, (int, float)):
        return ["min", "max"]
    return "category"


def parse_full_bins(var_bin_freq: Dict[str, Any]) -> Dict[str, List[Tuple[float, float, str]]]:
    """
    Turn each variable's ``full_real_freq`` labels into sorted numeric bins.

    Each label must look like "lo~hi". Variables whose info is falsy, or whose
    ``full_real_freq`` is missing or empty, are skipped.

    Returns:
        {var: [(lo, hi, label), ...]} ordered by ascending ``lo`` (stable).
    """
    def _edges_of(labels):
        parsed = []
        for text in labels:
            left, right = text.split('~')
            parsed.append((float(left), float(right), text))
        return sorted(parsed, key=lambda triple: triple[0])

    result = {}
    for name, details in var_bin_freq.items():
        freq = details.get('full_real_freq') if details else None
        if freq:
            result[name] = _edges_of(freq)
    return result


def discretize_record(
    record: Dict[str, Any],
    numeric_bins: Dict[str, List[Tuple[float, float, str]]]
) -> Dict[str, str]:
    """
    Replace each numeric field of ``record`` with the label of its bin.

    A value belongs to the first bin with ``lo <= value <= hi``. Both ends are
    inclusive, so a shared edge goes to the earlier bin. When no bin contains
    the value, the bin with the smallest gap to it is used, and ties go to the
    earlier bin. Fields without an entry in ``numeric_bins`` are copied through
    unchanged.
    """
    def _bin_label(value, bins):
        for lo, hi, lbl in bins:
            if lo <= value <= hi:
                return lbl
        # Out of range: fall back to the closest bin.
        best_lbl, best_gap = None, float('inf')
        for lo, hi, lbl in bins:
            gap = (lo - value) if value < lo else (value - hi)
            if gap < best_gap:
                best_lbl, best_gap = lbl, gap
        return best_lbl

    return {
        field: (_bin_label(value, numeric_bins[field]) if field in numeric_bins else value)
        for field, value in record.items()
    }


def topk_joint_diffs(
    real_data: Dict[str, Dict[str, Any]],
    synth_data: Dict[str, Dict[str, Any]],
    variable_groups: Dict[str, List[str]],
    var_bin_freq: Dict[str, Any],
    k: int = 10
) -> Dict[str, List[Tuple[Tuple[str, ...], float]]]:
    """
    Find, per variable group, the joint value combinations most
    under-represented in the synthetic data.

    First, numeric fields of every record are discretized into the bins built
    from ``var_bin_freq`` (via ``parse_full_bins`` and ``discretize_record``).
    Then, for each group, the relative frequency of every combination seen in
    the real data is compared with its synthetic frequency. Only combinations
    with p_real - p_synth > 0 are kept.

    Returns:
        {group: [(combination, p_real - p_synth), ...]} with at most ``k``
        entries per group, largest gap first.
    """
    bins_by_var = parse_full_bins(var_bin_freq)
    real_rows = [discretize_record(r, bins_by_var) for r in real_data.values()]
    synth_rows = [discretize_record(r, bins_by_var) for r in synth_data.values()]

    def _joint_counts(rows, columns):
        return Counter(tuple(row[c] for c in columns) for row in rows)

    output: Dict[str, List[Tuple[Tuple[str, ...], float]]] = {}
    for group, columns in variable_groups.items():
        real_counts = _joint_counts(real_rows, columns)
        synth_counts = _joint_counts(synth_rows, columns)
        # `or 1` keeps an empty side from dividing by zero.
        n_real = sum(real_counts.values()) or 1
        n_synth = sum(synth_counts.values()) or 1

        gaps = []
        for combo, cnt in real_counts.items():
            gap = cnt / n_real - synth_counts.get(combo, 0) / n_synth
            if gap > 0:
                gaps.append((combo, gap))
        output[group] = sorted(gaps, key=lambda item: item[1], reverse=True)[:k]
    return output


# ---------- Helper functions for batch processing ----------
def parse_batch_plan(batch_json: Dict[str, Any]) -> List[Tuple[Dict[str, Any], int]]:
    """
    Collect (plan, count) pairs from an LLM batch response.

    Expected shape::

        {
          "plan1": {"plan": {...}, "num": n1},
          "plan2": {"plan": {...}, "num": n2},
          ...
        }

    An entry is silently skipped when it is not a dict, its "plan" is missing
    or empty, or its "num" is not an int/float. Float counts are truncated
    with int().
    """
    def _extract(entry):
        if not isinstance(entry, dict):
            return None
        plan, num = entry.get("plan", {}), entry.get("num")
        if not plan or not isinstance(num, (int, float)):
            return None
        return plan, int(num)

    extracted = (_extract(entry) for entry in batch_json.values())
    return [pair for pair in extracted if pair is not None]


def _sample_numeric(low: float, high: float, is_int: bool) -> float or int:
    """
    Samples a random number from the range [low, high].
    If is_int, it samples an integer; otherwise, it samples a float.
    """
    if is_int:
        low_i, high_i = int(round(low)), int(round(high))
        return random.randint(low_i, high_i)
    return random.uniform(low, high)


def sample_from_plan(
    plan: Dict[str, Any],
    config: Dict[str, Any],
    real_df: pd.DataFrame,
) -> Dict[str, Any]:
    """
    Generate one synthetic record that follows ``plan``.

    A random row of ``real_df`` is bootstrapped as the starting point, so every
    column not overridden by ``plan`` keeps that row's value. Plan entries
    whose variable is not in ``config`` are ignored.

    Categorical variables (``"categories"`` in their config):
        - str spec -> used as-is
        - non-empty list spec -> its first element
        - anything else (including an empty list) -> a random configured category
    Numerical variables:
        - [low, high] -> random value in the range (bounds swapped if
          reversed), clipped to ``config[var]["constraints"]`` ([min, max])
          when present. If the planned range lies entirely outside the
          constraints, the constraint bound nearest the plan is used.
          Integer-dtype columns of ``real_df`` request an integer draw.
        - scalar or [value] -> used as-is

    Raises:
        ValueError: for a numerical spec of any other shape.
    """
    # to_dict() already returns a fresh dict, so no extra copy is needed.
    sample = real_df.sample(n=1).iloc[0].to_dict()

    for var, spec in plan.items():
        if var not in config:
            continue
        var_conf = config[var]

        # ----- Categorical -----
        if "categories" in var_conf:
            if isinstance(spec, str):
                sample[var] = spec
            elif isinstance(spec, list) and spec:
                sample[var] = spec[0]
            else:
                sample[var] = random.choice(var_conf["categories"])
            continue

        # ----- Numerical -----
        if isinstance(spec, list) and len(spec) == 2:
            low, high = spec
            if low > high:
                low, high = high, low
            is_int = pd.api.types.is_integer_dtype(real_df[var])
            if "constraints" in var_conf:
                c_min, c_max = var_conf["constraints"][0], var_conf["constraints"][1]
                new_low, new_high = max(low, c_min), min(high, c_max)
                if new_low > new_high:
                    # The planned range does not intersect [c_min, c_max];
                    # collapse onto the nearest allowed bound rather than
                    # handing an inverted range to the sampler.
                    new_low = new_high = min(max(low, c_min), c_max)
                low, high = new_low, new_high
            sample[var] = _sample_numeric(low, high, is_int)
        elif isinstance(spec, (int, float)):
            sample[var] = spec
        elif isinstance(spec, list) and len(spec) == 1 and isinstance(spec[0], (int, float)):
            sample[var] = spec[0]
        else:
            raise ValueError(
                f"Invalid specification for variable '{var}': {spec}. "
                "Expected a list of two numbers or a single number."
            )

    return sample


def check_marginal_data(data, config):
    """
    Validate one generated record against ``config``.

    Every key of ``config`` must be present in ``data``. An entry whose config
    contains "categories" must take one of the listed values. Every other
    entry must be an int or float.

    Returns:
        (True, data) when all checks pass.

    Raises:
        ValueError: describing the first key that fails.
    """
    for name, spec in config.items():
        if name not in data:
            raise ValueError(f"Key '{name}' not found in generated data.")
        value = data[name]

        if "categories" in spec:
            if value not in spec["categories"]:
                raise ValueError(f"Value '{value}' for key '{name}' is not in the expected categories.")
            continue

        if not isinstance(value, (int, float)):
            raise ValueError(f"Value '{value}' for key '{name}' is not a number.")

    return True, data


def check_joint_data(data, config):
    """
    Validate a variable grouping produced by the generator.

    ``data`` maps group names to lists of member names. Each group must be a
    list with at least two members, and every member must be a key of
    ``config``.

    Returns:
        (True, data) when valid.

    Raises:
        ValueError: for the first offending group or member.
    """
    for group_name, members in data.items():
        if not isinstance(members, list):
            raise ValueError(f"Group '{group_name}' is not a list.")
        if len(members) < 2:
            raise ValueError(f"Group '{group_name}' has fewer than 2 samples.")
        unknown = [m for m in members if m not in config]
        if unknown:
            raise ValueError(f"Sample '{unknown[0]}' is not in the expected configuration.")
    return True, data


def check_guide_data(data):
    """
    Require the generated payload to contain a "guide" entry.

    Returns:
        (True, data) when "guide" is present.

    Raises:
        ValueError: if "guide" is missing.
    """
    if "guide" in data:
        return True, data
    raise ValueError("Key 'guide' not found in generated data.")


def check_plan_data(data, config, expected_plans=None, expected_total=None):
    """
    Strictly validate, and normalize, an LLM-produced generation plan.

    Expected shape::

        {
          "n_plans": <int, optional>,
          "plan1": {"plan": {...}, "num": <int>, "reason": <str, optional>},
          "plan2": {...},
          ...
        }

    Normalization is applied to a deep copy of ``data``; the input is not mutated:
      1. A numeric variable given as a single value (number or numeric string)
         or as a one-element list becomes ``[v, v]``. A two-element list becomes
         ``[lo, hi]``, with numeric strings converted and ``lo <= hi`` required.
      2. A categorical variable given as a one-element list of str becomes
         that str and must be one of its configured categories.
      3. ``time_str`` given as a two-element numeric list is first rewritten
         to the string "lo~hi" (integral floats printed without ".0") and then
         checked as above.
      4. Nested dicts (e.g. ``person1``) are validated recursively and must
         not be empty.

    Every variable in ``config`` must appear somewhere in each plan. Booleans
    are rejected wherever an integer or number is required. Numeric specs of
    any other type (e.g. ``None``) are rejected.

    Args:
        data: Parsed plan payload.
        config: Variable name -> config dict (categorical if it has "categories").
        expected_plans: If given, "n_plans" (when present) must equal it.
        expected_total: If given, the sum of all "num" fields must equal it.

    Returns:
        ``(True, fixed_data)`` on success.

    Raises:
        ValueError: on the first violation found.
    """
    fixed = copy.deepcopy(data)

    if not isinstance(fixed, dict):
        raise ValueError(f"Top-level data must be a dict. | data: {fixed}")

    def _is_count(x):
        # bool is a subclass of int; JSON true/false must not pass as a count.
        return isinstance(x, int) and not isinstance(x, bool)

    def _to_number(x):
        # Accept ints/floats and numeric strings ("3" -> 3, "2.5" -> 2.5);
        # anything else, including bool, raises ValueError.
        if isinstance(x, bool):
            raise ValueError(f"boolean is not a number: {x!r}")
        if isinstance(x, (int, float)):
            return x
        if isinstance(x, str):
            try:
                return int(x)
            except ValueError:
                return float(x)
        raise ValueError(f"not a number: {x!r}")

    def _fmt(x):
        # Render integral floats without a trailing ".0" (8.0 -> "8").
        if isinstance(x, float) and x.is_integer():
            return str(int(x))
        return str(x)

    def _loc(plan_key, path):
        # Error-message prefix locating a variable, e.g. "plan1['plan']'person1.age'".
        return f"{plan_key}['plan']{'.'.join(path)!r}"

    def _validate(subdict, path, plan_key, seen_vars):
        # Validate and normalize one (possibly nested) plan dict in place,
        # recording every config variable encountered in seen_vars.
        for var, spec in list(subdict.items()):
            current_path = path + [var]
            where = _loc(plan_key, current_path)

            # Nested group (e.g. person1): recurse, but reject empty dicts.
            if isinstance(spec, dict):
                if not spec:
                    raise ValueError(
                        f"{where} must be a non-empty dict, got {{}} | data: {fixed}"
                    )
                _validate(spec, current_path, plan_key, seen_vars)
                continue

            # Variable must exist in config
            if var not in config:
                raise ValueError(f"{where} variable not in config. | data: {fixed}")
            seen_vars.add(var)
            var_conf = config[var]

            # time_str given as a numeric [lo, hi] range is rewritten to "lo~hi"
            # and then validated like any other value below.
            if var == "time_str" and isinstance(spec, (list, tuple)) and len(spec) == 2:
                try:
                    lo = _to_number(spec[0])
                    hi = _to_number(spec[1])
                except ValueError:
                    raise ValueError(
                        f"{where} time_str range items must be numeric, got {spec!r} | data: {fixed}"
                    ) from None
                if lo > hi:
                    raise ValueError(
                        f"{where} lower bound {lo} > upper bound {hi}. | data: {fixed}"
                    )
                spec = f"{_fmt(lo)}~{_fmt(hi)}"
                subdict[var] = spec

            # Categorical: unwrap a single-element list, then require a known string.
            if "categories" in var_conf:
                if isinstance(spec, (list, tuple)) and len(spec) == 1 and isinstance(spec[0], str):
                    spec = spec[0]
                    subdict[var] = spec
                if not isinstance(spec, str):
                    raise ValueError(
                        f"{where} categorical variable must be a string, got {spec!r} | data: {fixed}"
                    )
                if spec not in var_conf["categories"]:
                    raise ValueError(
                        f"{where} value {spec!r} not in categories. | data: {fixed}"
                    )
                continue

            # Numerical: normalize to a [lo, hi] list.
            if isinstance(spec, (int, float, str)):
                try:
                    num = _to_number(spec)
                except ValueError:
                    raise ValueError(
                        f"{where} invalid numeric value {spec!r}. | data: {fixed}"
                    ) from None
                subdict[var] = [num, num]

            elif isinstance(spec, (list, tuple)) and len(spec) == 1:
                try:
                    num = _to_number(spec[0])
                except ValueError:
                    raise ValueError(
                        f"{where} invalid numeric value {spec[0]!r}. | data: {fixed}"
                    ) from None
                subdict[var] = [num, num]

            elif isinstance(spec, (list, tuple)) and len(spec) == 2:
                try:
                    lo = _to_number(spec[0])
                    hi = _to_number(spec[1])
                except ValueError:
                    raise ValueError(
                        f"{where} range items must be numeric, got {spec!r} | data: {fixed}"
                    ) from None
                if lo > hi:
                    raise ValueError(
                        f"{where} lower bound {lo} > upper bound {hi}. | data: {fixed}"
                    )
                subdict[var] = [lo, hi]

            elif isinstance(spec, (list, tuple)):
                raise ValueError(
                    f"{where} numeric list must have 1 or 2 items, got {spec!r} | data: {fixed}"
                )

            else:
                # Previously None and other types slipped through unvalidated.
                raise ValueError(
                    f"{where} numeric variable must be a number, numeric string, "
                    f"or list of 1-2 numbers, got {spec!r} | data: {fixed}"
                )

    # Optional n_plans check
    if "n_plans" in fixed:
        n_plans = fixed["n_plans"]
        if not (_is_count(n_plans) and n_plans >= 0):
            raise ValueError(f"'n_plans' must be a non-negative integer, got: {n_plans!r} | data: {fixed}")
        if expected_plans is not None and n_plans != expected_plans:
            raise ValueError(f"'n_plans' does not match expected number of plans: expected {expected_plans}, got {n_plans}.")

    total_generated = 0

    for key, block in fixed.items():
        if key == "n_plans":
            continue

        if not (isinstance(key, str) and key.startswith("plan")):
            raise ValueError(f"Invalid plan key: {key!r}, expected 'plan1', 'plan2', etc. | data: {fixed}")
        if not isinstance(block, dict):
            raise ValueError(f"Value of {key!r} must be a dict. | data: {fixed}")

        if "reason" in block and not isinstance(block["reason"], str):
            raise ValueError(f"{key!r}['reason'] must be a string, got {block['reason']!r} | data: {fixed}")

        if "plan" not in block or "num" not in block:
            missing = [f for f in ("plan", "num") if f not in block]
            raise ValueError(f"{key!r} missing fields: {missing}. | data: {fixed}")

        plan_spec = block["plan"]
        count = block["num"]

        if not (_is_count(count) and count >= 0):
            raise ValueError(f"{key}['num'] must be a non-negative integer, got: {count!r} | data: {fixed}")
        total_generated += count

        if not isinstance(plan_spec, dict):
            raise ValueError(f"{key}['plan'] must be a dict. | data: {fixed}")

        seen_vars = set()
        _validate(plan_spec, [], key, seen_vars)
        missing = set(config) - seen_vars
        if missing:
            raise ValueError(f"{key}['plan'] missing variables {sorted(missing)}. | data: {fixed}")

    # Check total samples
    if expected_total is not None and total_generated != expected_total:
        raise ValueError(f"Total sample count mismatch: expected {expected_total}, got {total_generated}.")

    return True, fixed


def df2json(df):
    """
    Return the rows of ``df`` as a list of dicts (column name -> cell value),
    in row order.

    Each dict comes from ``Series.to_dict`` on a row yielded by
    ``DataFrame.iterrows``.
    """
    row_series = (series for _index, series in df.iterrows())
    return [series.to_dict() for series in row_series]


def assign_ctrl_vars(args, data_, config):
    """
    Build a control summary for every variable in ``config``.

    Categorical variables (config entry has "categories") get an occurrence
    count per value. Every configured category starts at 0, in config order.
    Values seen in ``data_`` but not configured are appended with their
    counts. Any other variable gets
    ``compute_stats(values, args.n_quantiles, verbal=True)``.

    Args:
        args: Object providing ``n_quantiles``. It is only read for
            non-categorical variables.
        data_: Sequence of records (mappings keyed by variable name). It is
            iterated once per variable, so it must be re-iterable.
        config: Variable name -> config dict.

    Returns:
        dict: variable name -> summary.
    """
    summaries = {}
    for name, spec in config.items():
        values = [record[name] for record in data_]
        if "categories" in spec:
            tally = dict.fromkeys(spec["categories"], 0)
            for v in values:
                tally[v] = tally.get(v, 0) + 1
            summaries[name] = tally
        else:
            summaries[name] = compute_stats(values, args.n_quantiles, verbal=True)
    return summaries


def parse_json(response: str) -> dict:
    """
    Extract and decode a JSON value from a (typically LLM-generated) response.

    Strategies are tried in order, and the first successful decode is returned:
      1. the whole response;
      2. the response with every backslash removed;
      3. the response with its first 8 and last 4 characters cut off, then
         its first 8 and last 5 (this strips a "```json" fence by position);
      4. the body of the first ```json ... ``` fenced block;
      5. that body, dedented, stripped and decoded with ``strict=False``
         (tolerates raw control characters such as newlines inside strings);
      6. the span from the first "{" to the last "}".

    Args:
        response: Raw response text.

    Returns:
        The decoded JSON value (normally a dict).

    Raises:
        ValueError: if every strategy fails.
    """
    # Failures that just mean "try the next strategy":
    # - JSONDecodeError is a ValueError;
    # - TypeError/AttributeError arise for non-str input;
    # - RecursionError arises for pathologically nested input.
    # Unlike a bare `except:`, this no longer swallows KeyboardInterrupt/SystemExit.
    recoverable = (ValueError, TypeError, AttributeError, RecursionError)

    def _fenced_block():
        match = re.search(r"```json\s*(.*?)\s*```", response, re.DOTALL)
        if match is None:
            raise ValueError("no ```json fenced block found")
        return match.group(1)

    def _brace_span():
        start = response.find('{')
        end = response.rfind('}')
        if start == -1 or end == -1 or start >= end:
            raise ValueError("no {...} span found")
        return response[start:end + 1]

    lenient = json.JSONDecoder(strict=False)
    strategies = (
        lambda: json.loads(response),
        lambda: json.loads(response.replace("\\", "")),
        lambda: json.loads(response[8:-4]),
        lambda: json.loads(response[8:-5]),
        lambda: json.loads(_fenced_block()),
        lambda: lenient.decode(textwrap.dedent(_fenced_block()).strip()),
        lambda: json.loads(_brace_span()),
    )
    for strategy in strategies:
        try:
            return strategy()
        except recoverable:
            continue

    raise ValueError(f"Failed to parse JSON: |{response}|")
    

EPS = 1e-6  # Small epsilon used to separate duplicate bin boundaries

def _fix_edges(edges):
    """
    Return a float copy of ``edges`` made strictly increasing.

    Any edge that is not greater than its predecessor is moved to
    ``predecessor + EPS``. At large magnitudes (roughly 1e11 and above) the
    spacing between floats exceeds ``EPS``, so adding it changes nothing. In
    that case the next representable float is used instead, which guarantees
    termination (the previous ``while`` loop spun forever there).

    Args:
        edges: 1-D array-like of bin edges. The input is not modified.

    Returns:
        np.ndarray of float edges, strictly increasing except around
        NaN or +inf values.
    """
    e = np.array(edges, dtype=float)  # np.array copies by default
    for i in range(1, len(e)):
        if e[i] <= e[i - 1]:
            bumped = e[i - 1] + EPS
            if bumped <= e[i - 1]:
                bumped = np.nextafter(e[i - 1], np.inf)
            e[i] = bumped
    return e

def _label(a, b, is_int, rd):
    """
    Format the bin bounded by edges ``a`` and ``b`` as a "lo~hi" label.

    - Float data: both edges rounded to ``rd`` digits.
    - Integer data, bin narrower than 1 (within EPS): collapsed to "k~k",
      where k is the rounded midpoint.
    - Integer data otherwise: "ceil(a)~floor(b - EPS)", i.e. roughly the
      integers in [a, b).
    """
    if not is_int:
        return f"{round(a, rd)}~{round(b, rd)}"
    if (b - a) < 1.0 - EPS:
        mid = int(round((a + b) / 2.0))
        return f"{mid}~{mid}"
    return f"{int(np.ceil(a))}~{int(np.floor(b - EPS))}"

def _freq_dict(data, edges, is_int, rd):
    """
    Percentage of ``data`` falling in each bin defined by ``edges``.

    Values outside [edges[0], edges[-1]] are clipped into the first/last bin,
    mirroring the nearest-bin rule used when discretizing records.
    ``np.histogram`` would otherwise drop them, and the rounding correction
    below would then dump their whole mass into the last bin, even for
    values below the minimum.

    Bins that share a label have their percentages summed. After rounding to
    ``rd`` digits, the residual from 100 is added to the last label so the
    values total 100.

    Args:
        data: 1-D array-like of values.
        edges: Strictly increasing bin edges.
        is_int: Whether labels use integer formatting (see ``_label``).
        rd: Number of digits to round percentages and float labels to.

    Returns:
        OrderedDict label -> percentage (all 0.0 when ``data`` is empty).
    """
    data = np.asarray(data)
    clipped = np.clip(data, edges[0], edges[-1]) if data.size else data
    counts, _ = np.histogram(clipped, bins=edges)
    N = data.size

    # Generate all labels (integer bins may share a label)
    labels = [_label(edges[i], edges[i + 1], is_int, rd)
              for i in range(len(counts))]

    # Handle empty data
    if N == 0:
        return OrderedDict((lbl, 0.0) for lbl in labels)

    # Raw frequency accumulation, merging bins with identical labels
    raw = {}
    for i, lbl in enumerate(labels):
        raw[lbl] = raw.get(lbl, 0.0) + counts[i] / N * 100

    # Round, then push the rounding residual onto the last label
    rounded = OrderedDict((lbl, round(raw[lbl], rd)) for lbl in labels)
    diff = round(100.0 - sum(rounded.values()), rd)
    if labels:
        last = labels[-1]
        rounded[last] = round(rounded[last] + diff, rd)

    return rounded

def _sub_freq(data, a, b, parent_pct, is_int, rd, n_sub_bins):
    """
    Split the parent bin [a, b] into ``n_sub_bins`` equal-width sub-bins and
    share the parent's percentage ``parent_pct`` among them, in proportion to
    how many ``data`` values fall in each.

    Only values inside [a, b] are counted. If none fall there, or
    ``parent_pct`` is 0, every sub-bin gets 0.0. Otherwise the rounding
    residual is added to the last sub-bin so the entries sum to ``parent_pct``.

    Returns:
        OrderedDict sub-bin label -> percentage.
    """
    sub_edges = _fix_edges(np.linspace(a, b, n_sub_bins + 1))
    hist, _ = np.histogram(data, bins=sub_edges)
    in_parent = hist.sum()

    names = [_label(lo, hi, is_int, rd)
             for lo, hi in zip(sub_edges[:-1], sub_edges[1:])]

    if in_parent == 0 or parent_pct == 0:
        return OrderedDict((name, 0.0) for name in names)

    share = {}
    for name, c in zip(names, hist):
        share[name] = share.get(name, 0.0) + c / in_parent * parent_pct

    result = OrderedDict((name, round(share[name], rd)) for name in names)
    residual = round(parent_pct - sum(result.values()), rd)
    if names:
        result[names[-1]] = round(result[names[-1]] + residual, rd)
    return result

def analyze_bin_freqs(real_data,
                      synthetic_data,
                      n_bins: int = 8,
                      n_sub_bins: int = 4,
                      round_digits: int = 1):
    """
    Compare real and synthetic frequencies over quantile bins of the real data.

    Steps:
      1. Build ``n_bins`` main bins from the real data's percentiles, with
         edges made strictly increasing by ``_fix_edges``. Constant real data
         gets the single bin [c - 0.5, c + 0.5].
      2. Compute the percentage of real and synthetic values per bin label.
      3. Pick the label where real exceeds synthetic the most. If synthetic is
         never below real, pick the label with the largest absolute gap. When
         that label's range is wide enough, split it into ``n_sub_bins``
         sub-bins.

    Note: ``synthetic_data`` is converted to the real data's dtype.

    Returns:
        OrderedDict with keys "real_freq", "synthetic_freq", "sub_real_freq",
        "sub_synthetic_freq", "full_real_freq" and "full_synthetic_freq". The
        "full_*" entries are the main frequencies with the refined label
        replaced by its sub-bins.
    """
    real = np.asarray(real_data)
    synth = np.asarray(synthetic_data, dtype=real.dtype)
    is_int = np.issubdtype(real.dtype, np.integer)

    # Main bin edges. The constant-data test must use the raw percentiles:
    # _fix_edges makes edges strictly increasing, so testing afterwards could
    # never detect constant data.
    pct = np.linspace(0, 100, n_bins + 1)
    raw_edges = np.percentile(real, pct, method="linear")
    if np.unique(raw_edges).size == 1:
        edges = np.array([raw_edges[0] - 0.5, raw_edges[0] + 0.5])
    else:
        edges = _fix_edges(raw_edges)

    # Main frequency calculation
    main_real  = _freq_dict(real,  edges, is_int, round_digits)
    main_synth = _freq_dict(synth, edges, is_int, round_digits)

    # Label of each individual bin. Integer bins may share a label, in which
    # case main_real has fewer keys than there are bins, so key positions
    # cannot be used as edge indices.
    bin_labels = [_label(edges[i], edges[i + 1], is_int, round_digits)
                  for i in range(len(edges) - 1)]

    # Find the label with the largest real - synthetic difference
    keys = list(main_real.keys())
    rv   = np.fromiter(main_real.values(),  float)
    sv   = np.fromiter(main_synth.values(), float)
    diff = rv - sv
    pos  = np.where(diff > 0)[0]
    idx  = pos[np.argmax(diff[pos])] if len(pos) else int(np.argmax(np.abs(diff)))

    refine_key = keys[idx]
    # Span every bin carrying this label, so the refined range matches the
    # frequency actually reported for it.
    members = [i for i, lbl in enumerate(bin_labels) if lbl == refine_key]
    a, b = edges[members[0]], edges[members[-1] + 1]
    pr = main_real[refine_key]
    ps = main_synth[refine_key]

    # Refine the selected bin if it's wide enough
    if (b - a) > (1 if is_int else EPS):
        sub_real  = _sub_freq(real,  a, b, pr, is_int, round_digits, n_sub_bins)
        sub_synth = _sub_freq(synth, a, b, ps, is_int, round_digits, n_sub_bins)
    else:
        sub_real = sub_synth = OrderedDict()

    # Replace the refined label with its sub-bins
    def _merge(main, sub):
        out = OrderedDict()
        for k, v in main.items():
            if k == refine_key and sub:
                out.update(sub)
            else:
                out[k] = v
        return out

    return OrderedDict([
        ("real_freq",          main_real),
        ("synthetic_freq",     main_synth),
        ("sub_real_freq",      sub_real),
        ("sub_synthetic_freq", sub_synth),
        ("full_real_freq",     _merge(main_real,  sub_real)),
        ("full_synthetic_freq", _merge(main_synth, sub_synth)),
    ])


def compute_stats(data, n, verbal=False):
    """
    Compute statistics of the dataset, including quantiles, mean, std (standard deviation),
    skewness, kurtosis, MAD, and Z-scores.
    """
    if len(data) == 0:
        return {
            "min": 0,
            "max": 0,
            "quantile_5": 0,
            "quantile_25": 0,
            "quantile_75": 0,
            "quantile_95": 0,
            "median": 0,
            "mean": 0,
            "std": 0,
            "skewness": 0,
            "kurtosis": 0,
            "mad": 0,
            "count": 0
        }

    data = np.array(data)
    sorted_data = np.sort(data)

    # Basic statistics
    count = len(data)
    mean = np.mean(data)
    std = np.std(data)
    median = np.median(data)
    max_value = np.max(data)
    min_value = np.min(data)

    skewness_value = skew(data)
    kurtosis_value = kurtosis(data)

    mad = np.median(np.abs(data - median))

    z_scores = (data - mean) / std

    quantile_percentages = [100 * i / n for i in range(1, n)]
    quantile_values = np.percentile(sorted_data, quantile_percentages)

    quantile_5 = np.percentile(sorted_data, 5)
    quantile_25 = np.percentile(sorted_data, 25)
    quantile_75 = np.percentile(sorted_data, 75)
    quantile_95 = np.percentile(sorted_data, 95)

    result_dict = {
        "min": min_value.item(),
        "max": max_value.item(),
        "quantile_5": quantile_5.item(),
        "quantile_25": quantile_25.item(),
        "quantile_75": quantile_75.item(),
        "quantile_95": quantile_95.item(),
        "median": median.item(),
        "mean": mean.item(), 
        "std": std.item(), 
        "skewness": skewness_value.item(), 
        "kurtosis": kurtosis_value.item(), 
        "mad": mad.item(),
        "count": count
    }

    return result_dict


def numerical_marginal_table(variable: str,
                             real_stats: dict,
                             syn_stats: dict,
                             bin_freq: dict | None = None,
                             only_positive: bool = True) -> str:
    """
    Markdown prompt for one numerical feature.
    - The first table: overall statistics (Real / Synthetic / Δ).
    - The second table (optional): histogram bin frequencies (%), with 3 rows.
    Δ row = Real – Synthetic.
    If only_positive=True, only rows where Δ > 0 are kept.
    """
    # -------- 统计量表 --------
    cols = [
        ("Min",            "min"),
        ("Quantile 5%",    "quantile_5"),
        ("Quantile 25%",   "quantile_25"),
        ("Median",         "median"),
        ("Quantile 75%",   "quantile_75"),
        ("Quantile 95%",   "quantile_95"),
        ("Max",            "max"),
        ("Std",            "std"),
        ("MAD",            "mad"),
        ("Skewness",       "skewness"),
        ("Kurtosis",       "kurtosis"),
    ]

    # If only_positive=True, filter out columns where Δ <= 0
    if only_positive:
        cols = [
            (name, key)
            for name, key in cols
            if (real_stats[key] - syn_stats[key]) > 0
        ]

    def _stat_row(label, a, b=None):
        # When b is not None, calculate real - syn (Δ)
        vals = [
            f"{(a[key] - b[key]):+.4g}" if b else f"{a[key]:.4g}"
            for _, key in cols
        ]
        return "| " + label + " | " + " | ".join(vals) + " |\n"

    header = "| Statistic | " + " | ".join(n for n, _ in cols) + " |\n"
    sep    = "|" + " --- |" * (len(cols) + 1) + "\n"
    stat_table  = header + sep
    stat_table += _stat_row("Δ (Real-Syn) pp", real_stats, syn_stats)

    # -------- bin 频率表（可选）--------
    bin_table = ""
    if bin_freq:
        bins = list(bin_freq["full_real_freq"].keys())   # Maintain bin order

        # If only_positive=True, filter bins where Δ <= 0
        if only_positive:
            bins = [
                b for b in bins
                if (bin_freq["full_real_freq"][b] - bin_freq["full_synthetic_freq"][b]) > 0
            ]

        def _fmt_row(label, freq_a, freq_b=None):
            vals = [
                f"{(freq_a[b] - freq_b[b]):+.2f}" if freq_b else f"{freq_a[b]:.2f} %"
                for b in bins
            ]
            return "| " + label + " | " + " | ".join(vals) + " |\n"

        head = "| Histogram Bin | " + " | ".join(bins) + " |\n"
        sep2 = "|" + " --- |" * (len(bins) + 1) + "\n"
        bin_table  = "\n**Histogram bin frequencies (%):**\n\n"
        bin_table += head + sep2
        bin_table += _fmt_row("Δ (Real-Syn) pp", bin_freq["full_real_freq"], bin_freq["full_synthetic_freq"])

    return (
        f"#### {variable}\n"
        f"The tables compare **real** vs **synthetic** distributions. "
        f"The rows labelled ‘Δ (Real‑Syn)’ show absolute differences (real minus synthetic).\n\n"
        f"{stat_table}\n{bin_table}\n"
    )


def categorical_marginal_table(variable: str,
                               real_freq: dict,
                               syn_freq: dict,
                               only_positive: bool = True) -> str:
    """
    Render the real-vs-synthetic frequency gap for one categorical feature as Markdown.

    Counts are converted to percentages of their own side's total; an all-zero
    total is treated as 1. The table has a single row, "Δ (Real-Syn) pp",
    holding real minus synthetic percentage points per category, with 1
    decimal. Categories come from ``real_freq`` in its key order. A category
    missing from ``syn_freq`` counts as 0 synthetic samples instead of
    raising KeyError.

    Args:
        variable: Feature name used as the section heading.
        real_freq, syn_freq: Category -> count mappings.
        only_positive: If True, keep only categories where real > synthetic.

    Returns:
        Markdown string.
    """
    categories = list(real_freq.keys())  # real_freq's key order defines column order

    # Convert counts to percentages; an all-zero side is divided by 1 instead.
    total_real = sum(real_freq.values()) or 1
    total_syn  = sum(syn_freq.values()) or 1
    perc_real = {c: real_freq[c] / total_real * 100 for c in categories}
    perc_syn  = {c: syn_freq.get(c, 0) / total_syn * 100 for c in categories}

    # If only_positive=True, keep only categories where real - synthetic > 0
    if only_positive:
        categories = [
            c for c in categories
            if (perc_real[c] - perc_syn[c]) > 0
        ]

    header = "| Category | " + " | ".join(categories) + " |\n"
    sep    = "|" + " --- |" * (len(categories) + 1) + "\n"

    def fmt_delta(label, pct_dict_a, pct_dict_b):
        vals = [f"{(pct_dict_a[c] - pct_dict_b[c]):+.1f}" for c in categories]
        return "| " + label + " | " + " | ".join(vals) + " |\n"

    table  = header + sep
    table += fmt_delta("Δ (Real‑Syn) pp", perc_real, perc_syn)

    return (
        f"#### {variable}\n"
        f"Frequency comparison (**percentage of samples**); Δ row shows "
        f"the absolute percentage‑point gap (real minus synthetic).\n\n"
        f"{table}\n"
    )


def get_topk_gaps(
    dist_summary: Dict[str, Dict[str, Any]],
    var_bin_freq: Dict[str, Any],
    vars_: List[str],
    k: int = 5,
) -> Dict[str, List[Tuple[str, float, float]]]:
    """
    Find, per variable, the labels most under-represented in the synthetic data.

    For each variable in ``vars_``:
    - If ``var_bin_freq[v]`` is truthy, use its "full_real_freq" and
      "full_synthetic_freq" (label -> percentage).
    - Otherwise, use the raw counts ``dist_summary["real"][v]`` and
      ``dist_summary["synthetic"][v]``.

    Each side is normalized by its own total, over all of its labels.
    Labels with p_real > p_synth are kept and sorted by p_real - p_synth,
    descending.

    Args:
        dist_summary (dict): Holds "real" and "synthetic" count summaries per variable.
        var_bin_freq (dict): Precomputed bin frequencies, if available.
        vars_ (list): Variables to evaluate.
        k (int): Maximum number of labels returned per variable.

    Returns:
        dict: {v1: [(label, real_prob, synth_prob), ...], v2: [...], ...}.
        A variable whose real side has zero total mass maps to [].
    """

    def _topk_from_mapping(
        real_map: Dict[Any, float],
        syn_map: Dict[Any, float],
        k: int
    ) -> List[Tuple[str, float, float]]:
        """
        Top-k labels by p_real - p_synth among labels with p_real > p_synth.

        Both maps hold non-negative masses (counts or percentages). Each is
        normalized by its own total. An empty synthetic side is treated as
        all zeros.
        """
        total_r = sum(real_map.values())
        if not total_r:
            # No real mass: nothing can be under-represented (and avoid 0/0).
            return []
        total_s = sum(syn_map.values()) or total_r  # Prevent division by zero

        diffs: List[Tuple[str, float, float]] = []
        for key, r_val in real_map.items():
            r_prob = r_val / total_r
            s_prob = syn_map.get(key, 0.0) / total_s
            if r_prob > s_prob:
                diffs.append((str(key), r_prob, s_prob))

        # Largest gap first, then keep the top k
        diffs.sort(key=lambda x: x[1] - x[2], reverse=True)
        return diffs[:k]

    result: Dict[str, List[Tuple[str, float, float]]] = {}

    for v in vars_:
        freq_info = var_bin_freq.get(v)
        if freq_info:
            real_map = freq_info["full_real_freq"]
            syn_map = freq_info["full_synthetic_freq"]
        else:
            # Pass raw counts straight through; the helper normalizes each
            # side by its full total. Pre-normalizing synthetic counts over
            # only the real categories, and then renormalizing, inflated the
            # synthetic probabilities whenever synthetic data had extra
            # categories.
            real_map = dist_summary["real"][v]
            syn_map = dist_summary["synthetic"][v]
        result[v] = _topk_from_mapping(real_map, syn_map, k)

    return result