"""
Greedy Importance-Only Pruning - Keep Top-K clauses by importance score (no MaxSAT).

This baseline isolates whether the MaxSAT structure adds value beyond simple
importance-based ranking. Clauses are ranked by their utility score (computed
on a validation set) and the top-K are kept.

Usage:
    uv run scripts/run_greedy_pruning.py -d breast-cancer -k 20 -s 42 --results-dir results/breast-cancer/greedy-pruning
    uv run scripts/run_greedy_pruning.py -d spambase -k 10 -s 42 --results-dir results/spambase/greedy-pruning

Results are organized as: results/<dataset>/greedy-pruning/
"""

import argparse
import sys
import numpy as np
from pathlib import Path
from sklearn.model_selection import train_test_split

# Add experiments directory to path
sys.path.insert(0, str(Path(__file__).parent))

from utils import (
    ExperimentLogger,
    load_adult_uci,
    load_spambase,
    load_breast_cancer_binarized,
    load_iris_binary,
    load_iris_multiclass,
    load_wine_multiclass,
    load_phishing,
    load_banknote,
    load_mnist_binary,
    load_higgs_100k,
    load_mushroom,
    load_magic,
    load_electricity,
    load_kr_vs_kp,
    load_nursery,
    load_tictactoe,
    load_spect_heart,
    load_car,
    train_tsetlin_machine,
    build_oracle_predictor,
    predict_with_patterns
)


# Maps each value accepted by -d/--dataset to the loader function that
# produces it. Insertion order is the order shown in the CLI `choices` list.
DATASETS = {
    "adult-uci": load_adult_uci,
    "spambase": load_spambase,
    "breast-cancer": load_breast_cancer_binarized,
    "iris-binary": load_iris_binary,
    "iris-multiclass": load_iris_multiclass,
    "wine-multiclass": load_wine_multiclass,
    "phishing": load_phishing,
    "banknote": load_banknote,
    "mnist-binary": load_mnist_binary,
    "higgs-100k": load_higgs_100k,
    "mushroom": load_mushroom,
    "magic": load_magic,
    "electricity": load_electricity,
    "kr-vs-kp": load_kr_vs_kp,
    "nursery": load_nursery,
    "tictactoe": load_tictactoe,
    "spect-heart": load_spect_heart,
    "car": load_car,
}


def compute_clause_importance(tm, X_val, y_val):
    """
    Compute an importance score for each clause from a validation set.

    A clause counts as "active" on a sample when its entry in
    ``tm.transform(X_val)`` is strictly positive. The score of a clause is
    its activation rate over the samples the model predicts correctly minus
    its activation rate over the samples the model predicts incorrectly.
    Higher score = the clause fires more often on correct predictions.

    Degenerate cases:
        - every prediction is correct: the score is the activation rate
          over the (all-correct) validation set;
        - no prediction is correct (including an empty validation set):
          every score is 0.0.

    Args:
        tm: Trained model exposing ``transform(X)`` (per-sample clause
            outputs with clauses along axis 1) and ``predict(X)``.
        X_val: Validation inputs, passed unchanged to ``tm``.
        y_val: Validation labels, compared element-wise to ``tm.predict``.

    Returns:
        importance: Array of shape (n_clauses,) with importance scores
    """
    O_val = tm.transform(X_val)
    y_pred_val = tm.predict(X_val)

    n_clauses = O_val.shape[1]

    # Coerce to plain arrays so the mask is always a positional boolean
    # ndarray, whatever container the labels arrive in.
    correct_mask = np.asarray(y_pred_val) == np.asarray(y_val)
    incorrect_mask = ~correct_mask

    # Sample counts do not depend on the clause, so compute them once.
    correct_total = int(np.sum(correct_mask))
    incorrect_total = int(np.sum(incorrect_mask))

    if correct_total == 0:
        return np.zeros(n_clauses)

    # Column-wise counts replace the per-clause Python loop.
    active = O_val > 0
    correct_rate = np.sum(active[correct_mask], axis=0) / correct_total

    if incorrect_total == 0:
        return correct_rate

    incorrect_rate = np.sum(active[incorrect_mask], axis=0) / incorrect_total
    return correct_rate - incorrect_rate


def _build_arg_parser():
    """Build the command-line parser for the greedy pruning experiment."""
    parser = argparse.ArgumentParser(
        description='Greedy Importance-Only Pruning - Keep Top-K clauses',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '-d', '--dataset',
        type=str,
        required=True,
        choices=list(DATASETS.keys()),
        help='Dataset to use'
    )
    parser.add_argument(
        '-k', '--keep',
        type=int,
        required=True,
        help='Number of clauses to keep (K)'
    )
    parser.add_argument(
        '-n', '--samples',
        type=int,
        default=None,
        help='Number of samples to use. Only applies to adult-uci and higgs-100k.'
    )
    parser.add_argument(
        '-c', '--clauses',
        type=int,
        default=100,
        help='Number of clauses per class for TM training'
    )
    parser.add_argument(
        '-e', '--epochs',
        type=int,
        default=100,
        help='Number of training epochs'
    )
    parser.add_argument(
        '-s', '--seed',
        type=int,
        default=42,
        help='Random seed for reproducibility'
    )
    parser.add_argument(
        '--results-dir',
        type=str,
        required=True,
        help='Directory to save results'
    )
    return parser


def _load_dataset(args):
    """Call the loader for ``args.dataset`` and return its 5-tuple result.

    The loaders for 'adult-uci' and 'higgs-100k' additionally receive a
    ``subsample`` size: ``args.samples`` when given (and non-zero), otherwise
    a per-dataset default. All other loaders receive only the seed.

    Returns:
        (X_train, X_test, y_train, y_test, name) as produced by the loader.
    """
    loader_fn = DATASETS[args.dataset]
    default_subsample = {'adult-uci': 10000, 'higgs-100k': 100000}.get(args.dataset)

    if default_subsample is None:
        return loader_fn(seed=args.seed)

    subsample = args.samples if args.samples else default_subsample
    return loader_fn(seed=args.seed, subsample=subsample)


def main():
    """Run one greedy top-K clause pruning experiment from the command line.

    Steps: parse and validate arguments, load the dataset, hold out 20% of
    the training data as a validation set, train a Tsetlin Machine on the
    remaining 80%, score clauses on the validation set, keep the K
    highest-scoring clauses, build a pattern predictor from the kept
    clauses, and log accuracy / fidelity metrics through ExperimentLogger.

    Returns:
        0 when the experiment completes, 1 when any exception is raised
        after argument parsing. Invalid arguments exit via ``parser.error``.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Validate arguments
    if args.keep <= 0:
        parser.error("--keep must be a positive integer.")
    if args.clauses <= 0:
        parser.error("--clauses must be a positive integer.")
    if args.samples is not None and args.samples < 0:
        parser.error("--samples must not be negative.")

    # Setup logger
    exp_name = f"greedy_{args.dataset}_k{args.keep}_s{args.seed}"
    if args.samples:
        exp_name += f"_n{args.samples}"

    results_dir = Path(args.results_dir)

    # Bind the name before the try block: if ExperimentLogger() itself raises,
    # the handler below must still be able to test `logger` without hitting
    # an UnboundLocalError that would mask the real failure.
    logger = None

    try:
        logger = ExperimentLogger(exp_name, results_dir=results_dir)
        logger.log("Greedy Importance-Only Pruning")
        logger.log(f"  Dataset: {args.dataset}")
        logger.log(f"  Keep K clauses: {args.keep}")
        logger.log(f"  Clauses per class: {args.clauses}")
        logger.log(f"  Seed: {args.seed}")
        logger.log("")

        # Load dataset
        logger.log(f"Loading {args.dataset} dataset...")
        X_train, X_test, y_train, y_test, name = _load_dataset(args)

        logger.log(f"Dataset: {name}")
        logger.log(f"  Train: {len(X_train)} samples, {X_train.shape[1]} features")
        logger.log(f"  Test: {len(X_test)} samples")

        # Split train into train/val for importance computation
        logger.log("\nSplitting training set: 80% train, 20% validation")
        X_train_sub, X_val, y_train_sub, y_val = train_test_split(
            X_train, y_train,
            test_size=0.2,
            random_state=args.seed,
            stratify=y_train
        )
        logger.log(f"  Train subset: {len(X_train_sub)} samples")
        logger.log(f"  Validation: {len(X_val)} samples")

        # Train TM on subset
        logger.log("\nTraining Tsetlin Machine...")
        tm, train_time = train_tsetlin_machine(
            X_train_sub, y_train_sub,
            n_clauses=args.clauses,
            epochs=args.epochs,
            seed=args.seed,
            logger=logger
        )

        # Get clause outputs on the full training set (train subset plus the
        # validation rows) for predictor building. The TM's own predictions
        # serve as the "oracle" labels the compressed model is compared to.
        O_train = tm.transform(X_train)
        O_test = tm.transform(X_test)
        y_oracle_train = tm.predict(X_train)
        y_oracle_test = tm.predict(X_test)

        total_clauses = O_train.shape[1]
        tm_train_acc = np.mean(y_oracle_train == y_train)
        tm_test_acc = np.mean(y_oracle_test == y_test)

        logger.log(f"TM: {total_clauses} total clauses")
        logger.log(f"  Train acc: {tm_train_acc:.4f}")
        logger.log(f"  Test acc: {tm_test_acc:.4f}")

        # Compute clause importance on validation set
        logger.log("\nComputing clause importance scores...")
        importance = compute_clause_importance(tm, X_val, y_val)

        logger.log(f"  Importance range: [{np.min(importance):.4f}, {np.max(importance):.4f}]")
        logger.log(f"  Mean importance: {np.mean(importance):.4f}")

        # Greedy selection: keep top-K by importance
        if args.keep >= total_clauses:
            keep = list(range(total_clauses))
            logger.log(f"\nWarning: K >= total clauses, keeping all {total_clauses}")
        else:
            # Sort by importance (descending) and take top K; the kept
            # indices are then stored in ascending index order.
            ranked_indices = np.argsort(importance)[::-1]
            keep = sorted(ranked_indices[:args.keep].tolist())
            logger.log(f"\nGreedy selection: keeping top {len(keep)} of {total_clauses} clauses")
            logger.log(f"  Top-K importance range: [{importance[ranked_indices[args.keep-1]]:.4f}, {importance[ranked_indices[0]]:.4f}]")

        # Build predictor with kept clauses
        patterns = build_oracle_predictor(O_train, y_oracle_train, keep)

        # Evaluate
        y_pred_train, train_stats = predict_with_patterns(
            O_train, keep, patterns, y_train_oracle=y_oracle_train, return_stats=True
        )
        y_pred_test, test_stats = predict_with_patterns(
            O_test, keep, patterns, y_train_oracle=y_oracle_train, return_stats=True
        )

        # Metrics: accuracy is agreement with the true labels, fidelity is
        # agreement with the unpruned TM's predictions.
        compressed_train_acc = np.mean(y_pred_train == y_train)
        compressed_test_acc = np.mean(y_pred_test == y_test)
        train_fidelity = np.mean(y_pred_train == y_oracle_train)
        test_fidelity = np.mean(y_pred_test == y_oracle_test)
        compression_ratio = 1.0 - (len(keep) / total_clauses)

        result = {
            "method": "greedy_pruning",
            "dataset": args.dataset,
            "dataset_name": name,
            "original_clauses": total_clauses,
            "kept_clauses": len(keep),
            "k": args.keep,
            "compression_ratio": compression_ratio,
            "tm_train_acc": tm_train_acc,
            "tm_test_acc": tm_test_acc,
            "compressed_train_acc": compressed_train_acc,
            "compressed_test_acc": compressed_test_acc,
            "train_fidelity": train_fidelity,
            "test_fidelity": test_fidelity,
            "test_acc_delta": compressed_test_acc - tm_test_acc,
            "n_patterns": len(patterns),
            "train_fallback": train_stats,
            "test_fallback": test_stats,
            "seed": args.seed,
            "train_time": train_time,
            "n_train": len(X_train),
            "n_test": len(X_test),
            "importance_stats": {
                "min": float(np.min(importance)),
                "max": float(np.max(importance)),
                "mean": float(np.mean(importance)),
                "kept_min": float(np.min(importance[keep])),
                "kept_max": float(np.max(importance[keep]))
            }
        }

        logger.add_result(result)

        # Summary
        logger.log(f"\n{'='*60}")
        logger.log("GREEDY PRUNING COMPLETE")
        logger.log(f"{'='*60}")
        logger.log(f"Dataset: {name}")
        logger.log(f"Compression: {total_clauses} -> {len(keep)} clauses ({compression_ratio*100:.1f}%)")
        logger.log(f"Test fidelity: {test_fidelity*100:.2f}%")
        logger.log(f"Test accuracy: {compressed_test_acc:.4f} (delta: {result['test_acc_delta']:+.4f})")
        logger.log(f"Patterns: {len(patterns)}")
        logger.log(f"Test fallback rate: {test_stats['fallback_rate']*100:.1f}%")

        logger.finish("completed")
        return 0

    except Exception as e:
        # Top-level boundary: report the failure and turn it into exit code 1.
        import traceback

        if logger is not None:
            logger.log(f"ERROR: {e}", level="ERROR")
            logger.log(traceback.format_exc(), level="ERROR")
            logger.finish("failed")
        else:
            # The logger could not be created, so fall back to stderr.
            print(f"FATAL: {e}", file=sys.stderr)
            traceback.print_exc()
        return 1


if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit status.
    raise SystemExit(main())
