"""
Random Pruning Baseline - Keep K random clauses from trained TM.

This baseline establishes whether MaxSAT compression provides value beyond
random clause selection. If MaxSAT doesn't beat random pruning significantly,
the complexity isn't justified.

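Usage:
    uv run scripts/run_random_pruning.py -d breast-cancer -k 20 -s 42 --results-dir results/breast-cancer/random-pruning
    uv run scripts/run_random_pruning.py -d spambase -k 10 -s 42 --rep 3 --results-dir results/spambase/random-pruning

Results are written to the directory given by the required --results-dir
flag; by convention this is results/<dataset>/random-pruning/.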
"""

import argparse
import sys
import numpy as np
from pathlib import Path

# Add experiments directory to path
sys.path.insert(0, str(Path(__file__).parent))

from utils import (
    ExperimentLogger,
    load_adult_uci,
    load_spambase,
    load_breast_cancer_binarized,
    load_iris_binary,
    load_iris_multiclass,
    load_wine_multiclass,
    load_phishing,
    load_banknote,
    load_mnist_binary,
    load_higgs_100k,
    load_mushroom,
    load_magic,
    load_electricity,
    load_kr_vs_kp,
    load_nursery,
    load_tictactoe,
    load_spect_heart,
    load_car,
    train_tsetlin_machine,
    build_oracle_predictor,
    predict_with_patterns
)


# Maps each --dataset CLI name to its loader function from ``utils``.
# The keys are also the allowed ``choices`` for the --dataset argument.
# Every loader is called with ``seed=`` and its result is unpacked as
# (X_train, X_test, y_train, y_test, name); adult-uci and higgs-100k are
# additionally called with ``subsample=``.
DATASETS = {
    'adult-uci': load_adult_uci,
    'spambase': load_spambase,
    'breast-cancer': load_breast_cancer_binarized,
    'iris-binary': load_iris_binary,
    'iris-multiclass': load_iris_multiclass,
    'wine-multiclass': load_wine_multiclass,
    'phishing': load_phishing,
    'banknote': load_banknote,
    'mnist-binary': load_mnist_binary,
    'higgs-100k': load_higgs_100k,
    'mushroom': load_mushroom,
    'magic': load_magic,
    'electricity': load_electricity,
    'kr-vs-kp': load_kr_vs_kp,
    'nursery': load_nursery,
    'tictactoe': load_tictactoe,
    'spect-heart': load_spect_heart,
    'car': load_car,
}


# Default subsample sizes for the loaders that accept ``subsample=``.
# Every other loader ignores --samples.
_SUBSAMPLE_DEFAULTS = {
    'adult-uci': 10000,
    'higgs-100k': 100000,
}


def _build_parser():
    """Build the command-line parser for the random-pruning baseline."""
    parser = argparse.ArgumentParser(
        description='Random Pruning Baseline - Keep K random clauses',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '-d', '--dataset',
        type=str,
        required=True,
        choices=list(DATASETS.keys()),
        help='Dataset to use'
    )
    parser.add_argument(
        '-k', '--keep',
        type=int,
        required=True,
        help='Number of clauses to keep (K)'
    )
    parser.add_argument(
        '-n', '--samples',
        type=int,
        default=None,
        help='Number of samples to use. Only applies to adult-uci and '
             'higgs-100k (ignored for other datasets).'
    )
    parser.add_argument(
        '-c', '--clauses',
        type=int,
        default=100,
        help='Number of clauses per class for TM training'
    )
    parser.add_argument(
        '-e', '--epochs',
        type=int,
        default=100,
        help='Number of training epochs'
    )
    parser.add_argument(
        '-s', '--seed',
        type=int,
        default=42,
        help='Random seed for reproducibility'
    )
    parser.add_argument(
        '--rep',
        type=int,
        default=1,
        help='Repetition number (for averaging over random selections)'
    )
    parser.add_argument(
        '--results-dir',
        type=str,
        required=True,
        help='Directory to save results'
    )
    return parser


def _selection_seed(seed, rep):
    """Return the seed of the clause-selection RNG for a (seed, rep) pair.

    Kept as ``seed * 1000 + rep`` so earlier results stay reproducible.
    """
    return seed * 1000 + rep


def _validate_args(parser, args):
    """Reject invalid arguments before any expensive work (exits via parser.error)."""
    if args.keep <= 0:
        parser.error("--keep must be a positive integer.")
    if args.clauses <= 0:
        parser.error("--clauses must be a positive integer.")
    if args.epochs <= 0:
        parser.error("--epochs must be a positive integer.")
    # Previously 0 was silently treated as "not given"; reject it explicitly.
    if args.samples is not None and args.samples <= 0:
        parser.error("--samples must be a positive integer.")
    # numpy.random.RandomState only accepts seeds in [0, 2**32); without this
    # check a bad seed/rep would only fail after the TM has been trained.
    if not 0 <= _selection_seed(args.seed, args.rep) < 2**32:
        parser.error("--seed * 1000 + --rep must lie in [0, 2**32).")


def _load_dataset(args):
    """Load the selected dataset, passing ``subsample=`` only where supported.

    Returns the loader's ``(X_train, X_test, y_train, y_test, name)`` result.
    """
    loader_fn = DATASETS[args.dataset]
    default_subsample = _SUBSAMPLE_DEFAULTS.get(args.dataset)
    if default_subsample is None:
        return loader_fn(seed=args.seed)
    subsample = args.samples if args.samples is not None else default_subsample
    return loader_fn(seed=args.seed, subsample=subsample)


def _select_random_clauses(total_clauses, k, rng, logger):
    """Return sorted indices of ``k`` distinct clauses drawn with ``rng``.

    If ``k >= total_clauses``, every clause is kept and a warning is logged.
    """
    if k >= total_clauses:
        logger.log(f"\nWarning: K >= total clauses, keeping all {total_clauses}")
        return list(range(total_clauses))
    keep = sorted(rng.choice(total_clauses, k, replace=False).tolist())
    logger.log(f"\nRandom selection: keeping {len(keep)} of {total_clauses} clauses")
    return keep


def main():
    """Train a TM, keep K random clauses, and record fidelity/accuracy metrics.

    Returns:
        Process exit code: 0 on success, 1 if any step after argument
        parsing raised. Invalid arguments exit via ``parser.error``.
    """
    parser = _build_parser()
    args = parser.parse_args()
    _validate_args(parser, args)

    exp_name = f"random_{args.dataset}_k{args.keep}_s{args.seed}_r{args.rep}"
    if args.samples is not None:
        exp_name += f"_n{args.samples}"

    results_dir = Path(args.results_dir)

    # Bind before ``try``: if ExperimentLogger() itself raises, the handler
    # must still see a name (previously this raised UnboundLocalError and
    # hid the real error).
    logger = None
    try:
        logger = ExperimentLogger(exp_name, results_dir=results_dir)
        logger.log("Random Pruning Baseline")
        logger.log(f"  Dataset: {args.dataset}")
        logger.log(f"  Keep K clauses: {args.keep}")
        logger.log(f"  Clauses per class: {args.clauses}")
        logger.log(f"  Epochs: {args.epochs}")
        logger.log(f"  Seed: {args.seed}")
        logger.log(f"  Repetition: {args.rep}")
        logger.log("")

        # Load dataset
        logger.log(f"Loading {args.dataset} dataset...")
        if args.samples is not None and args.dataset not in _SUBSAMPLE_DEFAULTS:
            logger.log(f"Warning: --samples is ignored for {args.dataset}")
        X_train, X_test, y_train, y_test, name = _load_dataset(args)

        logger.log(f"Dataset: {name}")
        logger.log(f"  Train: {len(X_train)} samples, {X_train.shape[1]} features")
        logger.log(f"  Test: {len(X_test)} samples")

        # Train TM
        logger.log("\nTraining Tsetlin Machine...")
        tm, train_time = train_tsetlin_machine(
            X_train, y_train,
            n_clauses=args.clauses,
            epochs=args.epochs,
            seed=args.seed,
            logger=logger
        )

        # Clause outputs and the full TM's ("oracle") predictions.
        O_train = tm.transform(X_train)
        O_test = tm.transform(X_test)
        y_oracle_train = tm.predict(X_train)
        y_oracle_test = tm.predict(X_test)

        total_clauses = O_train.shape[1]
        tm_train_acc = np.mean(y_oracle_train == y_train)
        tm_test_acc = np.mean(y_oracle_test == y_test)

        logger.log(f"TM: {total_clauses} total clauses")
        logger.log(f"  Train acc: {tm_train_acc:.4f}")
        logger.log(f"  Test acc: {tm_test_acc:.4f}")

        # Random pruning: a separate RNG per repetition so reps draw
        # different clause subsets from the same trained TM.
        rng = np.random.RandomState(_selection_seed(args.seed, args.rep))
        keep = _select_random_clauses(total_clauses, args.keep, rng, logger)

        # Build predictor with kept clauses
        patterns = build_oracle_predictor(O_train, y_oracle_train, keep)

        # Evaluate; the training oracle labels are passed for both splits.
        y_pred_train, train_stats = predict_with_patterns(
            O_train, keep, patterns, y_train_oracle=y_oracle_train, return_stats=True
        )
        y_pred_test, test_stats = predict_with_patterns(
            O_test, keep, patterns, y_train_oracle=y_oracle_train, return_stats=True
        )

        # Metrics
        compressed_train_acc = np.mean(y_pred_train == y_train)
        compressed_test_acc = np.mean(y_pred_test == y_test)
        train_fidelity = np.mean(y_pred_train == y_oracle_train)
        test_fidelity = np.mean(y_pred_test == y_oracle_test)
        compression_ratio = 1.0 - (len(keep) / total_clauses)

        result = {
            "method": "random_pruning",
            "dataset": args.dataset,
            "dataset_name": name,
            "original_clauses": total_clauses,
            "kept_clauses": len(keep),
            "k": args.keep,
            "compression_ratio": compression_ratio,
            "tm_train_acc": tm_train_acc,
            "tm_test_acc": tm_test_acc,
            "compressed_train_acc": compressed_train_acc,
            "compressed_test_acc": compressed_test_acc,
            "train_fidelity": train_fidelity,
            "test_fidelity": test_fidelity,
            "test_acc_delta": compressed_test_acc - tm_test_acc,
            "n_patterns": len(patterns),
            "train_fallback": train_stats,
            "test_fallback": test_stats,
            "seed": args.seed,
            "rep": args.rep,
            "train_time": train_time,
            "n_train": len(X_train),
            "n_test": len(X_test)
        }

        logger.add_result(result)

        # Summary
        logger.log(f"\n{'='*60}")
        logger.log("RANDOM PRUNING COMPLETE")
        logger.log(f"{'='*60}")
        logger.log(f"Dataset: {name}")
        logger.log(f"Compression: {total_clauses} -> {len(keep)} clauses ({compression_ratio*100:.1f}%)")
        logger.log(f"Test fidelity: {test_fidelity*100:.2f}%")
        logger.log(f"Test accuracy: {compressed_test_acc:.4f} (delta: {result['test_acc_delta']:+.4f})")
        logger.log(f"Patterns: {len(patterns)}")
        logger.log(f"Test fallback rate: {test_stats['fallback_rate']*100:.1f}%")

        logger.finish("completed")
        return 0

    except Exception as e:
        # Top-level boundary: report the failure and turn it into exit code 1.
        import traceback
        if logger is not None:
            logger.log(f"ERROR: {e}", level="ERROR")
            logger.log(traceback.format_exc(), level="ERROR")
            logger.finish("failed")
        else:
            print(f"FATAL: {e}", file=sys.stderr)
            traceback.print_exc()
        return 1


if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit code.
    raise SystemExit(main())
