"""
Shared utilities for SAT-based TM compression experiments.

Includes:
- Dataset loaders (binary and multi-class)
- Oracle compression functions (IMLI)
- Predictor building
- Evaluation metrics
- Logging utilities with real-time progress tracking

This is a self-contained module for reproducing all experiments.
"""

import numpy as np
import json
import time
import sys
from pathlib import Path
from datetime import datetime
from sklearn.datasets import load_breast_cancer, load_iris, load_wine
from sklearn.model_selection import train_test_split
from pyTsetlinMachine.tm import MultiClassTsetlinMachine
from pysat.formula import WCNF, IDPool
from pysat.examples.rc2 import RC2


# ==================== Logging ====================

class ExperimentLogger:
    """Real-time logging for long-running experiments.

    Two files are maintained inside ``results_dir``:

    - ``<experiment_name>_<timestamp>.json``: the full record (metadata,
      log entries and result dicts).
    - ``<experiment_name>_status.json``: a small summary intended for quick
      monitoring while the experiment is still running.
    """

    def __init__(self, experiment_name, results_dir=None):
        """Create the output directory and write the initial JSON/status files.

        Args:
            experiment_name: Used as the prefix of both output file names.
            results_dir: Directory for the output files. When None, a
                ``results`` directory next to this module's parent directory
                is used.
        """
        self.experiment_name = experiment_name

        # Use new organized structure if no custom dir specified
        if results_dir is None:
            project_root = Path(__file__).parent.parent
            results_dir = project_root / "results"

        self.results_dir = Path(results_dir)
        self.results_dir.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.json_file = self.results_dir / f"{experiment_name}_{timestamp}.json"
        self.status_file = self.results_dir / f"{experiment_name}_status.json"

        # Initialize JSON results file with log entries
        self.results = {
            "experiment": experiment_name,
            "start_time": datetime.now().isoformat(),
            "status": "running",
            "completed_tasks": 0,
            "total_tasks": 0,
            "log_entries": [],
            "results": []
        }
        self.save_json()
        self.save_status()

    def log(self, message, level="INFO"):
        """Print a timestamped line to stdout and record it in the JSON log.

        The JSON file is only rewritten for ERROR messages or on every 10th
        entry; other entries stay in memory until the next save.
        """
        # One clock reading so the printed and the stored timestamps agree.
        now = datetime.now()
        log_line = f"[{now.strftime('%Y-%m-%d %H:%M:%S')}] [{level}] {message}"
        print(log_line, flush=True)

        # Add to JSON log
        self.results["log_entries"].append({
            "timestamp": now.isoformat(),
            "level": level,
            "message": message
        })

        # Save incrementally (only if significant message or every 10 entries)
        if level == "ERROR" or len(self.results["log_entries"]) % 10 == 0:
            self.save_json()

    def update_progress(self, completed, total):
        """Update progress counters and refresh the status file."""
        self.results["completed_tasks"] = completed
        self.results["total_tasks"] = total
        self.save_status()

    def add_result(self, result_dict):
        """Append a result, set completed_tasks to the result count, and save."""
        self.results["results"].append(result_dict)
        self.results["completed_tasks"] = len(self.results["results"])
        self.save_json()
        self.save_status()

    def save_json(self):
        """Write the full results dict to the timestamped JSON file."""
        with open(self.json_file, "w", encoding="utf-8") as f:
            json.dump(self.results, indent=2, fp=f)

    def save_status(self):
        """Write the lightweight status summary for quick monitoring."""
        status = {
            "experiment": self.experiment_name,
            "status": self.results["status"],
            "completed": self.results["completed_tasks"],
            "total": self.results["total_tasks"],
            "last_update": datetime.now().isoformat(),
            "json_file": str(self.json_file),
            "num_log_entries": len(self.results["log_entries"])
        }
        with open(self.status_file, "w", encoding="utf-8") as f:
            json.dump(status, indent=2, fp=f)

    def finish(self, status="completed"):
        """Mark the experiment as finished and persist everything.

        Args:
            status: Final status string stored in both output files.
        """
        self.results["status"] = status
        self.results["end_time"] = datetime.now().isoformat()
        # Log BEFORE saving: log() only persists every 10th entry, so saving
        # first would usually leave the final message out of the JSON file
        # and the entry count in the status file one short.
        self.log(f"Experiment finished: {status}", "INFO")
        self.save_json()
        self.save_status()


# ==================== TM Training ====================

def train_tsetlin_machine(X_train, y_train, n_clauses=200, T=5000, s=10.0,
                          epochs=100, seed=42, logger=None, weighted_clauses=True):
    """Fit a MultiClassTsetlinMachine one epoch at a time and time the run.

    Args:
        X_train: Training features passed to ``fit``/``predict``.
        y_train: Training labels.
        n_clauses: Clause count passed to the machine.
        T: Threshold parameter passed to the machine.
        s: Specificity parameter passed to the machine.
        epochs: Number of single-epoch incremental ``fit`` calls.
        seed: Only written to the log message; it is not passed to the
            machine or to any random generator in this function.
        logger: Optional object with a ``log(message)`` method. When given,
            training accuracy is reported every 20 epochs and at the end.
        weighted_clauses: Forwarded to the machine.

    Returns:
        (tm, train_time): the trained machine and the wall-clock seconds
        spent in the epoch loop (including the periodic accuracy checks).
    """
    if logger:
        logger.log(f"Training TM: {n_clauses} clauses, {epochs} epochs, seed={seed}, weighted_clauses={weighted_clauses}")

    machine_options = dict(
        boost_true_positive_feedback=0,
        number_of_state_bits=8,
        append_negated=True,
        max_included_literals=32,
        weighted_clauses=weighted_clauses,
    )
    tm = MultiClassTsetlinMachine(n_clauses, T, s, **machine_options)

    t0 = time.time()
    for done in range(1, epochs + 1):
        tm.fit(X_train, y_train, epochs=1, incremental=True)
        # Report only when a logger exists and a multiple of 20 epochs is done.
        if not logger or done % 20 != 0:
            continue
        acc = np.mean(tm.predict(X_train) == y_train)
        logger.log(f"  Epoch {done}/{epochs}: acc={acc:.4f}")
    train_time = time.time() - t0

    # Final accuracy is evaluated outside the timed section.
    train_acc = np.mean(tm.predict(X_train) == y_train)

    if logger:
        logger.log(f"Training complete in {train_time:.1f}s, acc={train_acc:.4f}")

    return tm, train_time


# ==================== Dataset Loaders ====================

def _binarize_with_train_medians(X_train, X_test):
    """
    CRITICAL: Binarize features using medians computed ONLY on training data.

    This ensures no data leakage from test set into preprocessing.

    Args:
        X_train: Training features (numpy array, shape [n_train, n_features])
        X_test: Test features (numpy array, shape [n_test, n_features])

    Returns:
        X_train_bin, X_test_bin: Binarized features (uint32)
    """
    # Compute medians on TRAIN ONLY
    train_medians = np.median(X_train, axis=0)

    # Binarize both sets using train-only medians
    X_train_bin = (X_train > train_medians).astype(np.uint32)
    X_test_bin = (X_test > train_medians).astype(np.uint32)

    return X_train_bin, X_test_bin


def _prepare_dataset_from_arrays(X, y, dataset_name, seed=42, test_size=0.2):
    """
    Shared split-then-binarize pipeline for datasets already held as arrays.

    Order of operations (chosen to avoid test-set leakage):
    1. Stratified train/test split on the raw features
    2. Binarization of both parts via _binarize_with_train_medians, which
       derives its thresholds from the training part

    Args:
        X: Features (numpy array)
        y: Labels (numpy array); also used as the stratification target
        dataset_name: Human-readable name, returned unchanged
        seed: Random seed for splitting
        test_size: Fraction for test set (default 0.2)

    Returns:
        X_train_bin, X_test_bin, y_train, y_test, dataset_name
    """
    split_options = {"test_size": test_size, "random_state": seed, "stratify": y}
    raw_train, raw_test, y_train, y_test = train_test_split(X, y, **split_options)

    # Thresholds come from raw_train only; raw_test is merely transformed.
    bin_train, bin_test = _binarize_with_train_medians(raw_train, raw_test)

    return bin_train, bin_test, y_train, y_test, dataset_name


def load_breast_cancer_binarized(seed=42):
    """Return the split, binarized Breast Cancer Wisconsin data (binary task)."""
    bunch = load_breast_cancer()
    features, labels = bunch.data, bunch.target
    return _prepare_dataset_from_arrays(features, labels, "Breast Cancer", seed)


def load_iris_binary(seed=42):
    """Return split, binarized Iris data with target 1 for class 0 (setosa), else 0."""
    bunch = load_iris()
    is_setosa = bunch.target == 0
    return _prepare_dataset_from_arrays(
        bunch.data, is_setosa.astype(int), "Iris Binary", seed
    )


def load_iris_multiclass(seed=42):
    """Return the split, binarized Iris data with its original 3 classes."""
    bunch = load_iris()
    features, labels = bunch.data, bunch.target
    return _prepare_dataset_from_arrays(features, labels, "Iris (3-class)", seed)


def load_wine_multiclass(seed=42):
    """Return the split, binarized Wine data with its original 3 classes."""
    bunch = load_wine()
    features, labels = bunch.data, bunch.target
    return _prepare_dataset_from_arrays(features, labels, "Wine (3-class)", seed)


def load_spambase(seed=42):
    """Fetch Spambase through ucimlrepo and return the split, binarized data.

    Raises:
        RuntimeError: If importing ucimlrepo or fetching/unpacking the
            dataset fails for any reason.
    """
    try:
        from ucimlrepo import fetch_ucirepo
        payload = fetch_ucirepo(id=94).data  # 94 = Spambase
        X = payload.features.values
        y = payload.targets.values.ravel()
    except Exception as e:
        raise RuntimeError(f"Failed to load Spambase: {e}. Install: pip install ucimlrepo")

    return _prepare_dataset_from_arrays(X, y, "Spambase", seed)


def load_adult_uci(seed=42, subsample=10000):
    """Load Adult-UCI dataset (binary classification: income >50K or <=50K).

    This dataset requires special preprocessing due to:
    - Categorical features (need LabelEncoding)
    - Missing values (NaN handling)
    - Large size (optional subsampling)

    Pipeline: build labels -> drop rows with NaN features -> optional
    subsample -> stratified 80/20 split -> label-encode object columns with
    encoders fitted on the training part only (unseen test categories map to
    -1) -> binarize with train-only statistics.

    Args:
        seed: Random seed for reproducibility (subsampling and splitting)
        subsample: Number of rows to keep after NaN removal (default 10000).
            A falsy value, or a value not smaller than the number of
            available rows, keeps everything.

    Returns:
        X_train, X_test, y_train, y_test, dataset_name

    Raises:
        RuntimeError: If ucimlrepo cannot be imported or the fetch fails.
    """
    try:
        from ucimlrepo import fetch_ucirepo
        dataset = fetch_ucirepo(id=2)  # Adult dataset ID
        X = dataset.data.features
        y = dataset.data.targets
    except Exception as e:
        raise RuntimeError(f"Failed to load Adult-UCI: {e}. Install: pip install ucimlrepo") from e

    from sklearn.preprocessing import LabelEncoder

    # Convert targets BEFORE any preprocessing.
    # Rows originating from the UCI test partition spell the classes with a
    # trailing period ('>50K.' / '<=50K.'); normalise so they are not all
    # counted as the negative class.
    raw_labels = np.char.rstrip(np.char.strip(y.values.ravel().astype(str)), '.')
    y_arr = (raw_labels == '>50K').astype(int)  # Binary: 1 if income >50K

    # Remove any rows with NaN (on raw data)
    X_df = X.copy()
    valid_rows = (~X_df.isnull().any(axis=1)).to_numpy()
    X_df = X_df[valid_rows]
    y_arr = y_arr[valid_rows]

    # Subsample if requested (on raw data, before any preprocessing)
    if subsample and subsample < len(X_df):
        rng = np.random.RandomState(seed)
        indices = rng.choice(len(X_df), subsample, replace=False)
        X_df = X_df.iloc[indices].reset_index(drop=True)
        y_arr = y_arr[indices]

    # Split FIRST on raw data
    X_train_df, X_test_df, y_train, y_test = train_test_split(
        X_df, y_arr, test_size=0.2, random_state=seed, stratify=y_arr
    )

    # Fit LabelEncoders on TRAIN ONLY
    X_train_encoded = X_train_df.copy()
    X_test_encoded = X_test_df.copy()

    for col in X_train_df.columns:
        if X_train_df[col].dtype == 'object':
            le = LabelEncoder()
            X_train_encoded[col] = le.fit_transform(X_train_df[col].astype(str))
            # Transform test with one dict lookup per value; categories that
            # never occurred in the training part become -1.
            code_of = {label: code for code, label in enumerate(le.classes_)}
            X_test_encoded[col] = (
                X_test_df[col].astype(str).map(code_of).fillna(-1).astype(int)
            )

    X_train_arr = X_train_encoded.values.astype(float)
    X_test_arr = X_test_encoded.values.astype(float)

    # Binarize using train-only medians
    X_train_bin, X_test_bin = _binarize_with_train_medians(X_train_arr, X_test_arr)

    dataset_name = f"Adult-UCI ({subsample} samples)" if subsample else "Adult-UCI (full)"
    return X_train_bin, X_test_bin, y_train, y_test, dataset_name


def load_phishing(seed=42):
    """Fetch the UCI Phishing Websites data and return it split and binarized.

    11,055 samples × 30 features, binary classification.
    Mixed categorical & continuous features, moderate class imbalance.

    Raises:
        RuntimeError: If the import, the fetch, or the two-class label check
            fails (the underlying error text is embedded in the message).
    """
    try:
        from ucimlrepo import fetch_ucirepo
        payload = fetch_ucirepo(id=327).data  # 327 = Phishing Websites
        X = payload.features.values
        y_raw = payload.targets.values.ravel()

        # The raw labels are expected to take exactly two values
        # (typically -1 = legitimate and 1 = phishing).
        unique_labels = np.unique(y_raw)
        if len(unique_labels) != 2:
            raise ValueError(f"Expected binary labels, got {len(unique_labels)} unique values: {unique_labels}")

        # np.unique sorts ascending, so the last entry is the larger label:
        # larger value -> 1, smaller value -> 0.
        larger_label = unique_labels[-1]
        y = (y_raw == larger_label).astype(int)
    except Exception as e:
        raise RuntimeError(f"Failed to load Phishing: {e}. Install: pip install ucimlrepo")

    return _prepare_dataset_from_arrays(X, y, "Phishing Websites", seed)


def load_banknote(seed=42):
    """Fetch the UCI Banknote Authentication data and return it split and binarized.

    1,372 samples × 4 features, binary classification.
    Low-dimensional but noisy, continuous features only.

    Raises:
        RuntimeError: If the import, the fetch, or the two-class label check
            fails (the underlying error text is embedded in the message).
    """
    try:
        from ucimlrepo import fetch_ucirepo
        payload = fetch_ucirepo(id=267).data  # 267 = Banknote Authentication
        X = payload.features.values
        y_raw = payload.targets.values.ravel()

        # Exactly two distinct raw labels are required.
        unique_labels = np.unique(y_raw)
        if len(unique_labels) != 2:
            raise ValueError(f"Expected binary labels, got {len(unique_labels)} unique values: {unique_labels}")

        # Sorted ascending by np.unique: larger label -> 1, smaller -> 0.
        larger_label = unique_labels[-1]
        y = (y_raw == larger_label).astype(int)
    except Exception as e:
        raise RuntimeError(f"Failed to load Banknote: {e}. Install: pip install ucimlrepo")

    return _prepare_dataset_from_arrays(X, y, "Banknote Authentication", seed)


def load_pendigits(seed=42):
    """Fetch Pendigits from OpenML and return it split and binarized.

    10,992 samples × 16 features, 10-class classification.
    Medium-sized multiclass with correlated features.

    Raises:
        RuntimeError: If the import, the fetch, or the conversion fails.
    """
    try:
        from sklearn.datasets import fetch_openml
        bunch = fetch_openml('pendigits', version=1, parser='auto')
        X = bunch.data
        y = bunch.target.astype(int)

        # Unwrap pandas containers (anything exposing ``.values``) to arrays.
        X = getattr(X, 'values', X)
        y = getattr(y, 'values', y)
    except Exception as e:
        raise RuntimeError(f"Failed to load Pendigits: {e}. Install: pip install scikit-learn")

    return _prepare_dataset_from_arrays(X, y, "Pendigits (10-class)", seed)


def load_mnist_binary(seed=42):
    """Fetch MNIST from OpenML, keep digits 0 and 1, return split and binarized.

    ~12,665 samples × 784 features, binary classification.
    High-dimensional image dataset to address "UCI ghetto" concern.
    The returned label is 1 for digit 1 and 0 for digit 0.

    Raises:
        RuntimeError: If the import, the fetch, or the filtering fails.
    """
    try:
        from sklearn.datasets import fetch_openml
        bunch = fetch_openml('mnist_784', version=1, parser='auto', as_frame=False)
        digits = bunch.target.astype(int)

        # Keep only the rows showing a 0 or a 1.
        keep = np.isin(digits, (0, 1))
        X = bunch.data[keep]

        # Within the kept rows, digit 1 is the positive class.
        y = (digits[keep] == 1).astype(int)

    except Exception as e:
        raise RuntimeError(f"Failed to load MNIST: {e}. Install: pip install scikit-learn")

    return _prepare_dataset_from_arrays(X, y, "MNIST Binary (0 vs 1)", seed)


def load_higgs_100k(seed=42, subsample=100000):
    """Load HIGGS dataset (subsampled to 100K samples).

    Original: 11M samples × 28 features, binary classification.
    Large-scale benchmark for scalability validation.

    Downloads from UCI repository and caches locally for reproducibility.
    The download goes to a temporary ``.part`` file that is renamed to the
    cache name only after it completed, so an interrupted download is never
    mistaken for a valid cache on the next run.

    Only the first ``subsample`` lines of the file are read; they are then
    shuffled with ``seed`` before the usual split/binarize pipeline.

    Args:
        seed: Random seed for reproducibility
        subsample: Number of samples to use (default 100K); a falsy value
            reads the whole file

    Returns:
        X_train_bin, X_test_bin, y_train, y_test, dataset_name

    Raises:
        RuntimeError: If the download or the parsing of the cached file fails.
    """
    import gzip
    import urllib.request

    # Cache directory for downloaded datasets
    cache_dir = Path(__file__).parent.parent / "data_cache"
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_file = cache_dir / "HIGGS.csv.gz"

    # UCI repository URL for HIGGS dataset
    url = "https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz"

    # Download if not cached
    if not cache_file.exists():
        print("Downloading HIGGS dataset from UCI repository...")
        print("This is a large file (~2.6GB compressed, 11M samples).")
        print(f"Saving to: {cache_file}")
        partial_file = cache_dir / "HIGGS.csv.gz.part"
        try:
            urllib.request.urlretrieve(url, partial_file)
            # Publish under the cache name only once the transfer finished.
            partial_file.replace(cache_file)
            print("Download complete.")
        except Exception as e:
            # Do not leave a truncated file behind.
            if partial_file.exists():
                partial_file.unlink()
            raise RuntimeError(f"Failed to download HIGGS dataset: {e}") from e

    # Load data from gzipped CSV
    print("Loading HIGGS dataset from cache...")
    try:
        with gzip.open(cache_file, 'rt') as f:
            # HIGGS format: label (column 0), 28 features (columns 1-28)
            # Read only the needed number of samples for efficiency
            lines_to_read = subsample if subsample else None

            data = []
            labels = []
            for i, line in enumerate(f):
                if lines_to_read and i >= lines_to_read:
                    break
                parts = line.strip().split(',')
                labels.append(float(parts[0]))
                data.append([float(x) for x in parts[1:]])

        X = np.array(data, dtype=np.float32)
        y = np.array(labels, dtype=np.int32)

    except Exception as e:
        raise RuntimeError(f"Failed to load HIGGS from cache: {e}") from e

    # Shuffle with seed for reproducibility (since we read first N lines)
    rng = np.random.RandomState(seed)
    shuffle_idx = rng.permutation(len(X))
    X = X[shuffle_idx]
    y = y[shuffle_idx]

    print(f"Loaded {len(X)} samples with {X.shape[1]} features.")

    dataset_name = f"HIGGS ({len(X)//1000}K samples)"
    return _prepare_dataset_from_arrays(X, y, dataset_name, seed)


def load_mushroom(seed=42):
    """Fetch Mushroom from OpenML, one-hot encode it, return split and binarized.

    8,124 samples × 22 categorical features, binary classification.
    Classify mushrooms as edible (e) or poisonous (p).
    TM achieves 96-97% accuracy on this dataset (per TM literature).
    The one-hot encoder is fitted on all rows before the split; the target
    is integer-coded by a LabelEncoder.
    """
    from sklearn.datasets import fetch_openml
    from sklearn.preprocessing import LabelEncoder, OneHotEncoder

    bunch = fetch_openml(name='mushroom', version=1, as_frame=False, parser='auto')

    # Categorical string features -> dense indicator columns
    one_hot = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
    X = one_hot.fit_transform(bunch.data)

    # String target -> integer codes
    y = LabelEncoder().fit_transform(bunch.target)

    return _prepare_dataset_from_arrays(
        X, y, f"Mushroom ({X.shape[1]} one-hot features)", seed
    )


def load_magic(seed=42):
    """Fetch MAGIC Gamma Telescope from OpenML and return it split and binarized.

    19,020 samples × 10 features, binary classification.
    Distinguishes gamma ray signals from hadronic background.
    Challenging non-linear decision boundary, ~87% best accuracy with RF.
    """
    from sklearn.datasets import fetch_openml
    from sklearn.preprocessing import LabelEncoder

    bunch = fetch_openml(name='magic', version=1, as_frame=False, parser='auto')
    features = bunch.data.astype(np.float64)
    labels = LabelEncoder().fit_transform(bunch.target)

    return _prepare_dataset_from_arrays(features, labels, "MAGIC Gamma Telescope", seed)


def load_phoneme(seed=42):
    """Fetch Phoneme from OpenML and return it split and binarized.

    5,404 samples × 5 features, binary classification.
    Speech recognition task distinguishing nasal vs oral sounds.
    Low-dimensional but challenging (~88% best accuracy with RF).
    """
    from sklearn.datasets import fetch_openml
    from sklearn.preprocessing import LabelEncoder

    bunch = fetch_openml(name='phoneme', version=1, as_frame=False, parser='auto')
    features = bunch.data.astype(np.float64)
    labels = LabelEncoder().fit_transform(bunch.target)

    return _prepare_dataset_from_arrays(features, labels, "Phoneme", seed)


def load_electricity(seed=42, subsample=20000):
    """Fetch Electricity from OpenML, optionally subsample, return split and binarized.

    Original: 45,312 samples × 8 features, binary classification.
    Australian electricity market price movement prediction.
    Time-series data with concept drift, ~84% best accuracy.

    Args:
        seed: Random seed for reproducibility (subsampling and splitting)
        subsample: Number of samples to use (default 20K for reasonable
            runtime); a falsy value keeps all rows
    """
    from sklearn.datasets import fetch_openml
    from sklearn.preprocessing import LabelEncoder

    bunch = fetch_openml(name='electricity', version=1, as_frame=False, parser='auto')
    X = bunch.data.astype(np.float64)
    y = LabelEncoder().fit_transform(bunch.target)

    # Draw a seeded random subset (without replacement) when one is requested
    # and the data set is larger than the request.
    n_rows = len(X)
    if subsample and n_rows > subsample:
        picked = np.random.RandomState(seed).choice(n_rows, size=subsample, replace=False)
        X, y = X[picked], y[picked]

    return _prepare_dataset_from_arrays(
        X, y, f"Electricity ({len(X)//1000}K samples)", seed
    )


def load_kr_vs_kp(seed=42):
    """Fetch Chess (King-Rook vs King-Pawn) from OpenML, one-hot encode, split, binarize.

    3,196 samples × 36 categorical features, binary classification.
    Predict whether white can win in KRKPA7 chess endgame.
    All features are categorical (board relations) - ideal for TM.
    The target strings are integer-coded by a LabelEncoder.
    """
    from sklearn.datasets import fetch_openml
    from sklearn.preprocessing import LabelEncoder, OneHotEncoder

    bunch = fetch_openml(name='kr-vs-kp', version=1, as_frame=False, parser='auto')

    # Categorical features -> dense indicator columns
    one_hot = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
    X = one_hot.fit_transform(bunch.data)

    # String target -> integer codes
    y = LabelEncoder().fit_transform(bunch.target)

    return _prepare_dataset_from_arrays(
        X, y, f"Kr-vs-Kp ({X.shape[1]} one-hot features)", seed
    )


def load_nursery(seed=42, binary_mode='recommend'):
    """Fetch Nursery from OpenML, one-hot encode, reduce to two classes, split, binarize.

    12,960 samples × 8 categorical features, originally 5-class.
    Nursery school admission recommendations.
    All features are categorical - ideal for TM.

    Args:
        seed: Random seed
        binary_mode: Which original classes form the positive class:
            - 'recommend': priority, spec_prior, very_recom
            - 'priority': priority, spec_prior
            - 'not_recom': not_recom

    Raises:
        ValueError: If binary_mode is none of the values above.
    """
    from sklearn.datasets import fetch_openml
    from sklearn.preprocessing import LabelEncoder, OneHotEncoder

    bunch = fetch_openml(name='nursery', version=1, as_frame=False, parser='auto')
    original_classes = bunch.target

    # Categorical features -> dense indicator columns
    one_hot = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
    X = one_hot.fit_transform(bunch.data)

    # Collapse the original classes into positive (1) / negative (0)
    if binary_mode == 'not_recom':
        # Rejected vs any positive recommendation
        y = (original_classes == 'not_recom').astype(int)
    elif binary_mode == 'priority':
        # Priority admissions vs others
        y = np.isin(original_classes, ['priority', 'spec_prior']).astype(int)
    elif binary_mode == 'recommend':
        # High recommendation vs low/none
        y = np.isin(original_classes, ['priority', 'spec_prior', 'very_recom']).astype(int)
    else:
        raise ValueError(f"Unknown binary_mode: {binary_mode}")

    return _prepare_dataset_from_arrays(
        X, y, f"Nursery-{binary_mode} ({X.shape[1]} one-hot)", seed
    )


def load_splice(seed=42, binary_mode='boundary'):
    """Fetch Splice-junction Gene Sequences from OpenML, one-hot encode, binarize target.

    3,190 samples × 60 categorical features (DNA positions), originally 3-class.
    Classify DNA sequences as exon-intron (EI), intron-exon (IE), or neither (N).
    All features are categorical (nucleotides) - ideal for TM.

    Args:
        seed: Random seed
        binary_mode: Which original classes form the positive class:
            - 'boundary': EI or IE (any splice boundary)
            - 'ei': EI only
            - 'ie': IE only

    Raises:
        ValueError: If binary_mode is none of the values above.
    """
    from sklearn.datasets import fetch_openml
    from sklearn.preprocessing import OneHotEncoder

    bunch = fetch_openml(name='splice', version=1, as_frame=False, parser='auto')
    original_classes = bunch.target

    # DNA nucleotides per position -> dense indicator columns
    one_hot = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
    X = one_hot.fit_transform(bunch.data)

    # Collapse the original classes into positive (1) / negative (0)
    if binary_mode == 'ei':
        y = (original_classes == 'EI').astype(int)
    elif binary_mode == 'ie':
        y = (original_classes == 'IE').astype(int)
    elif binary_mode == 'boundary':
        # Splice boundary vs non-boundary
        y = np.isin(original_classes, ['EI', 'IE']).astype(int)
    else:
        raise ValueError(f"Unknown binary_mode: {binary_mode}")

    return _prepare_dataset_from_arrays(
        X, y, f"Splice-{binary_mode} ({X.shape[1]} one-hot)", seed
    )


def load_connect4(seed=42, binary_mode='win', subsample=None):
    """Fetch Connect-4 from OpenML, one-hot encode, binarize target, split, binarize.

    67,557 samples × 42 categorical features (board cells), originally 3-class.
    Predict game outcome: win, loss, or draw.
    All features are categorical (x, o, b) - ideal for TM.

    Note: Weighted TM achieved 87.9% accuracy with 50x compression on this dataset.

    Args:
        seed: Random seed (subsampling and splitting)
        binary_mode: Which original classes form the positive class:
            - 'win': win only
            - 'decisive': win or loss (i.e. not a draw)
        subsample: Optional number of samples to use (full dataset is large)

    Raises:
        ValueError: If binary_mode is none of the values above.
    """
    from sklearn.datasets import fetch_openml
    from sklearn.preprocessing import OneHotEncoder
    from scipy import sparse

    bunch = fetch_openml(name='connect-4', version=1, as_frame=False, parser='auto')
    cells = bunch.data
    outcomes = bunch.target

    # The encoder below is given a dense array, so densify sparse input first.
    if sparse.issparse(cells):
        cells = cells.toarray()

    # Board cells (x, o, b) -> dense indicator columns
    one_hot = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
    X = one_hot.fit_transform(cells)

    # Collapse the original classes into positive (1) / negative (0)
    if binary_mode == 'decisive':
        y = np.isin(outcomes, ['win', 'loss']).astype(int)
    elif binary_mode == 'win':
        y = (outcomes == 'win').astype(int)
    else:
        raise ValueError(f"Unknown binary_mode: {binary_mode}")

    # Seeded random subset (without replacement) when requested
    if subsample and len(X) > subsample:
        picked = np.random.RandomState(seed).choice(len(X), size=subsample, replace=False)
        X, y = X[picked], y[picked]

    n_samples, n_features = len(X), X.shape[1]
    return _prepare_dataset_from_arrays(X, y, f"Connect4-{binary_mode} ({n_samples} samples, {n_features} one-hot)", seed)


def load_tictactoe(seed=42):
    """Fetch Tic-Tac-Toe Endgame from OpenML, one-hot encode, split, binarize.

    958 samples × 9 categorical features (board cells), binary classification.
    Predict whether X wins (all possible endgame board states).
    All features are categorical (x, o, b) - ideal for TM.
    The target strings are integer-coded by a LabelEncoder.
    """
    from sklearn.datasets import fetch_openml
    from sklearn.preprocessing import LabelEncoder, OneHotEncoder

    bunch = fetch_openml(name='tic-tac-toe', version=1, as_frame=False, parser='auto')

    # Board cells (x, o, b) -> dense indicator columns
    one_hot = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
    X = one_hot.fit_transform(bunch.data)

    # String target -> integer codes
    y = LabelEncoder().fit_transform(bunch.target)

    return _prepare_dataset_from_arrays(
        X, y, f"TicTacToe ({X.shape[1]} one-hot features)", seed
    )


def load_vote(seed=42):
    """Fetch Congressional Voting Records from OpenML, one-hot encode, split, binarize.

    435 samples × 16 categorical features (votes), binary classification.
    Classify US Congress members as Democrat or Republican based on voting record.
    All features are categorical (y, n, ?) - ideal for TM.
    The target strings are integer-coded by a LabelEncoder.
    """
    from sklearn.datasets import fetch_openml
    from sklearn.preprocessing import LabelEncoder, OneHotEncoder

    bunch = fetch_openml(name='vote', version=1, as_frame=False, parser='auto')

    # Votes (y, n, ?) -> dense indicator columns
    one_hot = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
    X = one_hot.fit_transform(bunch.data)

    # String target -> integer codes
    y = LabelEncoder().fit_transform(bunch.target)

    return _prepare_dataset_from_arrays(
        X, y, f"Vote ({X.shape[1]} one-hot features)", seed
    )


def load_german_credit(seed=42):
    """Load German Credit dataset from OpenML.

    1,000 samples × 20 mixed features, binary classification.
    Classify credit applicants as good or bad risk.
    Mostly categorical features with some integer attributes.

    Every numeric column (any numeric dtype) is kept as a number and
    binarized later by the shared pipeline; every other column is one-hot
    encoded. The output places the one-hot columns first, then the numeric
    ones. The target strings are integer-coded by a LabelEncoder.
    """
    from sklearn.datasets import fetch_openml
    from sklearn.preprocessing import LabelEncoder, OneHotEncoder

    data = fetch_openml(name='credit-g', version=1, as_frame=True, parser='auto')
    X_df = data.data
    y_str = data.target

    # Partition the columns so that none is dropped: numeric dtypes of any
    # width go one way, everything else is treated as categorical.
    num_cols = X_df.select_dtypes(include='number').columns.tolist()
    cat_cols = [col for col in X_df.columns if col not in num_cols]

    # One-hot encode categorical columns
    if cat_cols:
        encoder = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
        X_cat_encoded = encoder.fit_transform(X_df[cat_cols])
    else:
        X_cat_encoded = np.empty((len(X_df), 0))

    # Keep numeric columns as-is (will be binarized later by _prepare_dataset_from_arrays)
    if num_cols:
        X_num = X_df[num_cols].values.astype(np.float64)
    else:
        X_num = np.empty((len(X_df), 0))

    # Combine
    X = np.hstack([X_cat_encoded, X_num])

    # String target -> integer codes
    y = LabelEncoder().fit_transform(y_str)

    n_features = X.shape[1]
    return _prepare_dataset_from_arrays(X, y, f"German Credit ({n_features} features)", seed)


def load_spect_heart(seed=42):
    """Fetch SPECT Heart from OpenML and return it split and binarized.

    267 samples × 22 binary features, binary classification.
    Diagnose cardiac SPECT images as normal or abnormal.
    The features are cast to float64 and still go through the shared
    split/binarize pipeline; the target is integer-coded by a LabelEncoder.
    """
    from sklearn.datasets import fetch_openml
    from sklearn.preprocessing import LabelEncoder

    bunch = fetch_openml(name='spect', version=1, as_frame=False, parser='auto')
    features = bunch.data.astype(np.float64)  # described as already 0/1
    labels = LabelEncoder().fit_transform(bunch.target)

    return _prepare_dataset_from_arrays(
        features, labels, "SPECT Heart (22 binary features)", seed
    )


def load_car(seed=42, binary_mode='acceptable'):
    """Fetch Car Evaluation from OpenML, one-hot encode, binarize target, split, binarize.

    1,728 samples × 6 categorical features, originally 4-class.
    Evaluate car acceptability based on price, maintenance, etc.
    All features are categorical - ideal for TM.

    Args:
        seed: Random seed
        binary_mode: Which original classes form the positive class:
            - 'acceptable': acc, good, vgood (everything except unacc)
            - 'good': good, vgood

    Raises:
        ValueError: If binary_mode is none of the values above.
    """
    from sklearn.datasets import fetch_openml
    from sklearn.preprocessing import OneHotEncoder

    bunch = fetch_openml(name='car', version=1, as_frame=False, parser='auto')
    original_classes = bunch.target

    # Categorical features -> dense indicator columns
    one_hot = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
    X = one_hot.fit_transform(bunch.data)

    # Collapse the original classes into positive (1) / negative (0)
    if binary_mode == 'good':
        # High quality vs low/medium
        y = np.isin(original_classes, ['good', 'vgood']).astype(int)
    elif binary_mode == 'acceptable':
        # Acceptable vs unacceptable
        y = np.isin(original_classes, ['acc', 'good', 'vgood']).astype(int)
    else:
        raise ValueError(f"Unknown binary_mode: {binary_mode}")

    return _prepare_dataset_from_arrays(
        X, y, f"Car-{binary_mode} ({X.shape[1]} one-hot)", seed
    )


# ==================== Oracle Compression ====================

def compute_clause_utilities(tm, X_val, y_val, X_train, y_train, logger=None):
    """
    Compute a per-clause MaxSAT deletion cost from validation-set behaviour.

    Algorithm:
    For each clause k, compute a utility
      u_k = (activation rate of clause k on correctly predicted validation
             samples) - (activation rate on incorrectly predicted samples)
    If there are no incorrect predictions, u_k is just the activation rate
    on the correct ones; if there are no correct predictions, u_k = 0.

    Utilities are then min-max normalised to [0, 1] and mapped linearly to
    costs: cost = 1 + 9 * normalised_utility, i.e. costs lie in [1, 10].
    When all utilities are equal, every clause gets cost 10.

    Args:
        tm: Trained model exposing ``transform(X)`` (clause outputs) and
            ``predict(X)``
        X_val: Validation features
        y_val: Validation labels
        X_train: Training features (currently unused; kept for interface
                 compatibility)
        y_train: Training labels (currently unused; kept for interface
                 compatibility)
        logger: Optional logger

    Returns:
        costs: Float array of shape (n_clauses,) with values in [1, 10].
               Higher = clause is more important to keep (more expensive
               to delete). Empty array if the model has no clauses.
    """
    if logger:
        logger.log("Computing clause utilities on validation set...")

    # Get clause outputs and predictions on the validation set
    O_val = np.asarray(tm.transform(X_val))
    y_pred_val = tm.predict(X_val)

    n_clauses = O_val.shape[1]
    if n_clauses == 0:
        # Nothing to score; avoids min/max on an empty array below.
        return np.zeros(0)

    correct_mask = (y_pred_val == y_val)
    correct_total = np.sum(correct_mask)
    incorrect_total = np.sum(~correct_mask)

    # Per-clause activation counts, split by whether the sample was
    # predicted correctly. Computed for all clauses at once.
    active = O_val > 0
    correct_active = active[correct_mask].sum(axis=0)
    incorrect_active = active[~correct_mask].sum(axis=0)

    # Utility = activation rate difference (correct - incorrect)
    # Higher = clause is more discriminative for correct predictions
    if correct_total > 0 and incorrect_total > 0:
        utilities = correct_active / correct_total - incorrect_active / incorrect_total
    elif correct_total > 0:
        utilities = correct_active / correct_total
    else:
        utilities = np.zeros(n_clauses)

    # Min-max normalise to [0, 1]; a constant utility vector maps to all ones
    # (this branch also guards the division by zero).
    min_util = np.min(utilities)
    max_util = np.max(utilities)

    if max_util > min_util:
        utilities_norm = (utilities - min_util) / (max_util - min_util)
    else:
        utilities_norm = np.ones_like(utilities)

    # Convert to MaxSAT costs: high utility = high cost to delete.
    # utility=1 → cost=10, utility=0 → cost=1
    costs = 1.0 + 9.0 * utilities_norm

    if logger:
        logger.log("Clause utilities computed:")
        logger.log(f"  Min utility: {np.min(utilities):.4f}, Max: {np.max(utilities):.4f}")
        logger.log(f"  Mean cost: {np.mean(costs):.2f}, Std: {np.std(costs):.2f}")
        logger.log(f"  High-value clauses (cost>5): {np.sum(costs > 5)}")

    return costs


def compress_oracle_oneshot(O, y_oracle, timeout_seconds=600, logger=None):
    """
    One-shot oracle-based compression: a single MaxSAT instance over all samples.

    One Boolean variable per clause says whether the clause is kept. Soft
    unit clauses (weight 1) prefer dropping every clause; for each pair of
    samples with oracle labels 1 and 0, a hard clause requires keeping at
    least one clause on which the two samples differ. Pairs with identical
    clause outputs cannot be separated and produce no constraint.

    Args:
        O: Clause outputs (n_samples, n_clauses); only the lowest bit of
           each entry is used
        y_oracle: Oracle labels from TM.predict() (values 1 and 0 are used)
        timeout_seconds: Intended MaxSAT solver timeout (default 10 min).
            NOTE: currently NOT enforced - the value is ignored and the
            solver runs to completion.
        logger: Optional ExperimentLogger for progress tracking

    Returns: (keep_indices, solve_time)
        keep_indices: ascending list of clause indices set to true in the
            MaxSAT model
        solve_time: wall-clock seconds covering both constraint
            construction and solving

    Raises:
        RuntimeError: If the solver returns no model.
    """
    O = np.asarray(O).astype(np.uint8) & 1
    # Coerce so that a plain list compares element-wise below instead of
    # silently yielding no positive/negative samples.
    y_oracle = np.asarray(y_oracle)
    n, m = O.shape

    if logger:
        logger.log(f"Starting oracle compression: {n} samples, {m} clauses")

    start_time = time.time()

    vpool = IDPool()
    z = [vpool.id(f'z_{k}') for k in range(m)]

    wcnf = WCNF()

    # Soft: minimize clauses
    for var in z:
        wcnf.append([-var], weight=1)

    # Hard: separate opposite-label pairs
    pos = np.where(y_oracle == 1)[0]
    neg = np.where(y_oracle == 0)[0]

    if logger:
        logger.log(f"Adding pairwise constraints: {len(pos)} × {len(neg)} = {len(pos) * len(neg):,}")

    for i in pos:
        for j in neg:
            diff = np.flatnonzero(O[i] ^ O[j])
            if diff.size > 0:
                wcnf.append([z[k] for k in diff])

    if logger:
        logger.log("Solving MaxSAT with RC2...")

    with RC2(wcnf, solver='g3') as rc2:
        model = rc2.compute()

    solve_time = time.time() - start_time

    if model is None:
        raise RuntimeError("MaxSAT returned UNSAT (should not happen)")

    # Decode by literal membership rather than by list position, so the
    # result does not depend on the model listing every variable in order.
    true_literals = {lit for lit in model if lit > 0}
    keep = [k for k in range(m) if z[k] in true_literals]

    if logger:
        logger.log(f"Compression complete: {m} → {len(keep)} clauses in {solve_time:.1f}s")

    return keep, solve_time


def compress_oracle_imli(O, y_oracle, n_partitions=16, clause_weights=None, logger=None):
    """
    IMLI-style incremental MaxSAT compression.

    Splits the samples into consecutive partitions and solves one MaxSAT
    instance per partition. In each instance the soft clauses prefer
    keeping the clauses selected by the previous partition and dropping all
    others; the hard clauses require every (label 1, label 0) pair inside
    the partition to differ on at least one kept clause. The selection of
    the last partition is returned, and is then checked against all
    cross-label pairs of the full dataset.

    Args:
        O: Clause outputs (n_samples, n_clauses); only the lowest bit of
           each entry is used
        y_oracle: Oracle labels from TM.predict() (values 1 and 0 are used)
        n_partitions: Number of partitions (higher = faster, potentially worse quality).
                      The last partition absorbs the remainder samples.
        clause_weights: Optional array of shape (n_clauses,) with importance weights,
                       used as the weight of each clause's soft constraint.
                       Each weight is truncated with int() before use.
                       If None, uses uniform weights (all clauses weight 1).
        logger: Optional ExperimentLogger

    Returns: (keep_indices, total_solve_time, verification_stats)
        keep_indices: ascending list of clause indices kept after the last
            partition
        total_solve_time: summed wall-clock seconds of the solver calls only
        verification_stats: dict with "total_pairs", "violations",
            "violation_rate" and "global_separation_preserved"

    Raises:
        ValueError: If n_partitions is less than 1.
        RuntimeError: If the solver returns no model for a partition.
    """
    if n_partitions < 1:
        raise ValueError(f"n_partitions must be >= 1, got {n_partitions}")

    O = np.asarray(O).astype(np.uint8) & 1
    y_oracle = np.asarray(y_oracle)
    n, m = O.shape
    partition_size = n // n_partitions

    if logger:
        logger.log(f"IMLI p={n_partitions}, partition size: ~{partition_size} samples")
        if clause_weights is not None:
            logger.log(f"Using clause importance weighting (mean={np.mean(clause_weights):.2f})")

    keep_prev = []
    total_solve_time = 0.0

    for i in range(n_partitions):
        start = i * partition_size
        end = (i+1) * partition_size if i < n_partitions-1 else n
        O_part = O[start:end]
        y_part = y_oracle[start:end]

        if logger:
            logger.log(f"\nPartition {i+1}/{n_partitions}: samples [{start}, {end})")

        vpool = IDPool()
        z = [vpool.id(f'z_{k}') for k in range(m)]
        wcnf = WCNF()

        # Soft: prefer clauses kept in previous partitions, weighted by importance.
        # A set keeps the membership test O(1) per clause.
        kept_before = set(keep_prev)
        for k in range(m):
            # Determine weight: use clause_weights if provided, otherwise uniform
            weight = int(clause_weights[k]) if clause_weights is not None else 1
            wcnf.append([z[k]] if k in kept_before else [-z[k]], weight=weight)

        # Hard: pairwise separation
        pos = np.where(y_part == 1)[0]
        neg = np.where(y_part == 0)[0]
        added = 0
        for pi in pos:
            for nj in neg:
                diff = np.flatnonzero(O_part[pi] ^ O_part[nj])
                if diff.size > 0:
                    wcnf.append([z[k] for k in diff])
                    added += 1

        if logger:
            logger.log(f"  Constraints: {added:,}, solving...")

        t0 = time.time()
        with RC2(wcnf, solver='g3') as rc2:
            model = rc2.compute()
        solve_time = time.time() - t0

        if model is None:
            raise RuntimeError(f"Partition {i+1} UNSAT!")

        # Decode by literal membership rather than by list position, so the
        # result does not depend on the model listing every variable in order.
        true_literals = {lit for lit in model if lit > 0}
        keep_prev = [k for k in range(m) if z[k] in true_literals]
        total_solve_time += solve_time

        if logger:
            logger.log(f"  Solved in {solve_time:.1f}s, kept {len(keep_prev)} clauses")

    # VERIFICATION: Check global separation property
    if logger:
        logger.log("\n=== Verifying Global Separation Property ===")

    O_compressed = O[:, keep_prev]
    pos_all = np.where(y_oracle == 1)[0]
    neg_all = np.where(y_oracle == 0)[0]

    total_pairs = len(pos_all) * len(neg_all)

    # A violation is a (positive, negative) pair with identical compressed
    # rows. Count negatives per distinct row once, then sum the matching
    # count for every positive row - same total as comparing all pairs.
    neg_pattern_counts = {}
    for row in O_compressed[neg_all]:
        pattern = row.tobytes()
        neg_pattern_counts[pattern] = neg_pattern_counts.get(pattern, 0) + 1
    violations = sum(
        neg_pattern_counts.get(row.tobytes(), 0) for row in O_compressed[pos_all]
    )

    verification_stats = {
        "total_pairs": total_pairs,
        "violations": violations,
        "violation_rate": violations / total_pairs if total_pairs > 0 else 0.0,
        "global_separation_preserved": (violations == 0)
    }

    if logger:
        logger.log(f"Total pos/neg pairs: {total_pairs:,}")
        logger.log(f"Separation violations: {violations:,} ({verification_stats['violation_rate']*100:.2f}%)")
        if violations == 0:
            logger.log("✓ Global separation PRESERVED (all pairs differ in ≥1 clause)")
        else:
            logger.log(f"⚠ Global separation VIOLATED ({violations} pairs are identical)")

    return keep_prev, total_solve_time, verification_stats


# ==================== Predictor Building ====================

def build_oracle_predictor(O_train, y_oracle, keep_indices):
    """
    Build pattern dictionary from compressed clauses.

    Each training sample is reduced to the outputs (lowest bit) of the kept
    clauses; samples sharing the same reduced pattern are grouped and the
    group is assigned its most frequent oracle label. Ties go to the
    smallest label, so for 0/1 labels a 50/50 split resolves to 0.

    Args:
        O_train: Training clause outputs (n_samples, n_clauses)
        y_oracle: Oracle label for each training sample
        keep_indices: Indices of kept clauses

    Returns: dict mapping pattern tuples to majority class (as int)
    """
    O_compressed = O_train[:, keep_indices]
    O_compressed = (O_compressed.astype(np.uint8) & 1)

    # Group the oracle labels by compressed pattern
    labels_by_pattern = {}
    for i, pattern in enumerate(O_compressed):
        labels_by_pattern.setdefault(tuple(pattern), []).append(y_oracle[i])

    # Majority vote for each pattern. np.unique returns the labels sorted,
    # and argmax picks the first maximum, hence ties go to the smallest
    # label. Unlike rounding the mean, this is also correct for labels
    # other than 0/1.
    pattern_dict = {}
    for key, labels in labels_by_pattern.items():
        values, counts = np.unique(labels, return_counts=True)
        pattern_dict[key] = int(values[np.argmax(counts)])

    return pattern_dict


def predict_with_patterns(O_test, keep_indices, pattern_dict, y_train_oracle=None, return_stats=False):
    """
    Predict using pattern dictionary with 1-NN Hamming fallback for unseen patterns.

    The predictor is fully self-contained: K clauses + pattern dict + 1-NN fallback.
    For unseen patterns, 1-NN finds the nearest known pattern in Hamming space
    over the compressed representation. Distance ties go to the pattern that
    comes first in pattern_dict's iteration order.

    FALLBACK MECHANISM CHOICE (Nov 5, 2025):
    - Current: 1-NN Hamming (empirically 7% better on Spambase)
    - Alternative: Majority class (simpler claims, worse accuracy)
    - To switch: Replace the body of the fallback branch (the `else` in the
      loop below) with: predictions[i] = majority_class
      where majority_class = int(np.argmax(np.bincount(y_train_oracle)))

    Args:
        O_test: Test clause outputs (n_test, n_clauses)
        keep_indices: Indices of kept clauses
        pattern_dict: Pattern dictionary from build_oracle_predictor()
        y_train_oracle: Currently unused; kept for interface compatibility
            (it would only be needed by the majority-class alternative above)
        return_stats: If True, also return fallback usage statistics

    Returns:
        If return_stats=False: predictions array
        If return_stats=True: (predictions, stats_dict)
            stats_dict contains:
                - n_exact_matches: Number of exact pattern matches
                - n_fallback: Number of 1-NN fallback predictions
                - fallback_rate: Fraction using fallback

    Raises:
        ValueError: If a test pattern needs the fallback but pattern_dict
            is empty, so there is no known pattern to fall back to.
    """
    O_compressed = O_test[:, keep_indices]
    O_compressed = (O_compressed.astype(np.uint8) & 1)

    predictions = np.zeros(len(O_test), dtype=int)
    exact_matches = 0
    fallback_used = 0

    # Convert pattern_dict to arrays for vectorized Hamming distance.
    # Keys and values iterate in the same order, so rows and labels align.
    known_patterns = np.array([list(k) for k in pattern_dict])
    known_labels = np.array(list(pattern_dict.values()))

    for i, pattern in enumerate(O_compressed):
        key = tuple(pattern)
        if key in pattern_dict:
            predictions[i] = pattern_dict[key]
            exact_matches += 1
        else:
            if known_labels.size == 0:
                # Fail with a clear message instead of an opaque numpy
                # axis/broadcast error on the empty pattern array.
                raise ValueError(
                    "pattern_dict is empty: no known pattern for 1-NN fallback"
                )
            # 1-NN Hamming distance fallback on compressed representation
            distances = np.sum(known_patterns != pattern, axis=1)
            nearest = np.argmin(distances)
            predictions[i] = known_labels[nearest]
            fallback_used += 1

    if return_stats:
        stats = {
            "n_exact_matches": exact_matches,
            "n_fallback": fallback_used,
            "fallback_rate": fallback_used / len(O_test) if len(O_test) > 0 else 0.0
        }
        return predictions, stats
    else:
        return predictions


# ==================== Evaluation ====================

def evaluate_compression(tm, X_train, X_test, y_train, y_test, keep_indices,
                        pattern_dict, logger=None):
    """
    Evaluate a compressed model against the original TM and the true labels.

    Runs the TM on the train and test sets to obtain clause outputs and
    oracle predictions, predicts with the compressed pattern predictor, and
    reports accuracy (vs. true labels), fidelity (vs. TM predictions),
    the fraction of clauses removed, and fallback usage.

    Args:
        tm: Trained model exposing ``transform(X)`` and ``predict(X)``
        X_train, X_test: Feature matrices
        y_train, y_test: True labels
        keep_indices: Indices of kept clauses
        pattern_dict: Pattern dictionary for the compressed predictor
        logger: Optional logger; when given, a summary is logged

    Returns: dict with all metrics including fallback statistics
    """
    clause_out_train = tm.transform(X_train)
    clause_out_test = tm.transform(X_test)
    oracle_train = tm.predict(X_train)
    oracle_test = tm.predict(X_test)

    # Predictions of the compressed model, with exact-match/fallback counts
    pred_train, fb_train = predict_with_patterns(
        clause_out_train, keep_indices, pattern_dict,
        y_train_oracle=oracle_train, return_stats=True,
    )
    pred_test, fb_test = predict_with_patterns(
        clause_out_test, keep_indices, pattern_dict,
        y_train_oracle=oracle_train, return_stats=True,
    )

    # Clause counts and the fraction of clauses removed
    n_original = clause_out_train.shape[1]
    n_compressed = len(keep_indices)
    ratio = 1.0 - (n_compressed / n_original)

    # Accuracy of the original TM and of the compressed model
    acc_tm_train = np.mean(oracle_train == y_train)
    acc_tm_test = np.mean(oracle_test == y_test)
    acc_comp_train = np.mean(pred_train == y_train)
    acc_comp_test = np.mean(pred_test == y_test)

    # Fidelity: agreement of the compressed model with the original TM
    fid_train = np.mean(pred_train == oracle_train)
    fid_test = np.mean(pred_test == oracle_test)

    result = {
        "original_clauses": n_original,
        "compressed_clauses": n_compressed,
        "compression_ratio": ratio,
        "tm_train_acc": acc_tm_train,
        "tm_test_acc": acc_tm_test,
        "compressed_train_acc": acc_comp_train,
        "compressed_test_acc": acc_comp_test,
        "train_fidelity": fid_train,
        "test_fidelity": fid_test,
        "test_acc_delta": acc_comp_test - acc_tm_test,
        "n_patterns": len(pattern_dict),
        "n_train": len(X_train),
        "n_test": len(X_test),
        "train_fallback": fb_train,
        "test_fallback": fb_test
    }

    if not logger:
        return result

    summary_lines = (
        "\n=== Evaluation Results ===",
        f"Compression: {n_original} → {n_compressed} ({ratio*100:.1f}%)",
        f"TM accuracy: train={acc_tm_train:.4f}, test={acc_tm_test:.4f}",
        f"Compressed accuracy: train={acc_comp_train:.4f}, test={acc_comp_test:.4f}",
        f"Fidelity: train={fid_train:.4f}, test={fid_test:.4f}",
        f"Test accuracy delta: {result['test_acc_delta']:+.4f}",
        f"Patterns: {len(pattern_dict)}",
        "\n=== Prediction Statistics ===",
        f"Train: {fb_train['n_exact_matches']} exact / {fb_train['n_fallback']} fallback ({fb_train['fallback_rate']*100:.1f}%)",
        f"Test:  {fb_test['n_exact_matches']} exact / {fb_test['n_fallback']} fallback ({fb_test['fallback_rate']*100:.1f}%)",
    )
    for line in summary_lines:
        logger.log(line)

    return result
