"""
Universal IMLI compression experiment runner.

Usage:
    uv run scripts/run_imli.py --dataset adult-uci --partitions 80
    uv run scripts/run_imli.py --dataset spambase -p 16 -s 42
    uv run scripts/run_imli.py --dataset breast-cancer -p 32 --clauses 50

Results are organized as: results/<dataset>/p<partitions>/
"""

import argparse
import sys
import numpy as np
from pathlib import Path
from sklearn.model_selection import train_test_split

# Add experiments directory to path
sys.path.insert(0, str(Path(__file__).parent))

from utils import (
    ExperimentLogger,
    load_adult_uci,
    load_spambase,
    load_breast_cancer_binarized,
    load_iris_binary,
    load_iris_multiclass,
    load_wine_multiclass,
    load_phishing,
    load_banknote,
    load_pendigits,
    load_mnist_binary,
    load_higgs_100k,
    load_mushroom,
    load_magic,
    load_phoneme,
    load_electricity,
    load_kr_vs_kp,
    load_nursery,
    load_splice,
    load_connect4,
    load_tictactoe,
    load_vote,
    load_german_credit,
    load_spect_heart,
    load_car,
    train_tsetlin_machine,
    compress_oracle_imli,
    compute_clause_utilities,
    build_oracle_predictor,
    evaluate_compression
)


# Registry of --dataset CLI names -> loader functions. Insertion order is
# preserved and becomes the order of the --dataset choices listed in --help.
DATASETS = dict([
    ('adult-uci', load_adult_uci),
    ('spambase', load_spambase),
    ('breast-cancer', load_breast_cancer_binarized),
    ('iris-binary', load_iris_binary),
    ('iris-multiclass', load_iris_multiclass),
    ('wine-multiclass', load_wine_multiclass),
    ('phishing', load_phishing),
    ('banknote', load_banknote),
    ('pendigits', load_pendigits),
    ('mnist-binary', load_mnist_binary),
    ('higgs-100k', load_higgs_100k),
    ('mushroom', load_mushroom),
    ('magic', load_magic),
    ('phoneme', load_phoneme),
    ('electricity', load_electricity),
    # New categorical datasets (ideal for TM)
    ('kr-vs-kp', load_kr_vs_kp),
    ('nursery', load_nursery),
    ('splice', load_splice),
    ('connect4', load_connect4),
    ('tictactoe', load_tictactoe),
    ('vote', load_vote),
    ('german-credit', load_german_credit),
    ('spect-heart', load_spect_heart),
    ('car', load_car),
])


# Datasets whose loaders accept a ``subsample`` keyword, mapped to the
# subsample size used when --samples is not given. This single table drives
# the loader call, the --samples help text, the "ignored" warning, and
# whether the sample count is recorded in the experiment name/results.
_SUBSAMPLE_DEFAULTS = {
    'adult-uci': 10000,
    'higgs-100k': 100000,
}


def _build_parser():
    """Return the command-line parser for this runner."""
    parser = argparse.ArgumentParser(
        description='Universal IMLI compression experiment runner',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '-d', '--dataset',
        type=str,
        required=True,
        choices=list(DATASETS.keys()),
        help='Dataset to use'
    )
    parser.add_argument(
        '-p', '--partitions',
        type=int,
        required=True,
        help='Number of IMLI partitions (higher = faster, lower quality)'
    )
    # Built from _SUBSAMPLE_DEFAULTS so the help text cannot drift from the
    # datasets that actually honour --samples.
    subsample_help = ", ".join(
        f"{ds} (default {size})" for ds, size in _SUBSAMPLE_DEFAULTS.items()
    )
    parser.add_argument(
        '-n', '--samples',
        type=int,
        default=None,
        help=f'Number of samples to use. Only applies to: {subsample_help}.'
    )
    parser.add_argument(
        '-c', '--clauses',
        type=int,
        default=100,
        help='Number of clauses per class for TM training'
    )
    parser.add_argument(
        '-e', '--epochs',
        type=int,
        default=100,
        help='Number of training epochs'
    )
    parser.add_argument(
        '-s', '--seed',
        type=int,
        default=42,
        help='Random seed for reproducibility'
    )
    parser.add_argument(
        '--weighted',
        action='store_true',
        help='Use clause-importance weighting (requires validation split)'
    )
    parser.add_argument(
        '--results-dir',
        type=str,
        required=True,
        help='Directory to save results'
    )
    parser.add_argument(
        '--check',
        action='store_true',
        help='Only load dataset and verify setup, do not run experiment'
    )
    return parser


def _validate_args(parser, args):
    """Reject non-positive numeric options via ``parser.error`` (which exits)."""
    if args.partitions <= 0:
        parser.error("--partitions must be a positive integer.")
    if args.clauses <= 0:
        parser.error("--clauses must be a positive integer.")
    if args.epochs <= 0:
        parser.error("--epochs must be a positive integer.")
    if args.samples is not None and args.samples <= 0:
        parser.error("--samples must be a positive integer.")


def _load_dataset(args):
    """Call the selected loader, passing ``subsample`` only where supported.

    Returns whatever the loader returns; callers unpack it as
    ``(X_train, X_test, y_train, y_test, name)``.
    """
    loader_fn = DATASETS[args.dataset]
    if args.dataset in _SUBSAMPLE_DEFAULTS:
        subsample = args.samples if args.samples else _SUBSAMPLE_DEFAULTS[args.dataset]
        return loader_fn(seed=args.seed, subsample=subsample)
    return loader_fn(seed=args.seed)


def main():
    """Run one IMLI compression experiment configured from the command line.

    The steps are:
    1. Parse and validate the arguments.
    2. Load the dataset.
    3. With ``--weighted``, split off 20% of the training set as a validation
       set for clause weighting.
    4. Train a Tsetlin Machine.
    5. Compress its clause outputs with IMLI.
    6. Evaluate the compressed predictor and log a summary.

    Returns:
        Process exit code: 0 on success or a passed ``--check``, and 1 if any
        exception is raised once the experiment starts. Invalid arguments make
        ``argparse`` exit on its own.
    """
    parser = _build_parser()
    args = parser.parse_args()
    _validate_args(parser, args)

    # Record --samples only when the chosen loader actually uses it. Otherwise
    # the experiment name and results would claim a subsample that never
    # happened.
    supports_subsample = args.dataset in _SUBSAMPLE_DEFAULTS
    samples = args.samples if supports_subsample else None
    if args.samples and not supports_subsample:
        supported = ", ".join(f"'{ds}'" for ds in _SUBSAMPLE_DEFAULTS)
        print(f"Warning: --samples argument is only supported for {supported} and will be ignored for '{args.dataset}'.", file=sys.stderr)

    # Experiment name encodes the configuration.
    weighted_str = "_w" if args.weighted else ""
    exp_name = f"imli_{args.dataset}_p{args.partitions}_s{args.seed}{weighted_str}"
    if samples:
        exp_name += f"_n{samples}"

    results_dir = Path(args.results_dir)

    logger = None
    try:
        logger = ExperimentLogger(exp_name, results_dir=results_dir)
        logger.log("IMLI Compression Experiment")
        logger.log(f"  Dataset: {args.dataset}")
        logger.log(f"  Partitions: {args.partitions}")
        if samples:
            logger.log(f"  Samples: {samples}")
        logger.log(f"  Clauses: {args.clauses} per class")
        logger.log(f"  Epochs: {args.epochs}")
        logger.log(f"  Seed: {args.seed}")
        logger.log(f"  Weighted: {args.weighted}")
        logger.log("")

        # Load dataset
        logger.log(f"Loading {args.dataset} dataset...")
        X_train, X_test, y_train, y_test, name = _load_dataset(args)

        logger.log(f"Dataset: {name}")
        logger.log(f"  Train: {len(X_train)} samples, {X_train.shape[1]} features")
        logger.log(f"  Test: {len(X_test)} samples")
        logger.log(f"  Class distribution: {np.bincount(y_train)}")

        # Check mode: just verify setup and exit
        if args.check:
            logger.log(f"\n✅ CHECK PASSED: Dataset '{args.dataset}' loaded successfully")
            logger.log(f"  Ready to run with p={args.partitions}, seed={args.seed}")
            logger.finish("check_passed")
            return 0

        # Weighted compression needs a held-out validation set to score
        # clause utility, so the TM is trained on the remaining 80%.
        if args.weighted:
            logger.log("\n⚙️  Weighted compression mode enabled")
            logger.log("Splitting training set: 80% train, 20% validation")
            X_train_sub, X_val, y_train_sub, y_val = train_test_split(
                X_train, y_train,
                test_size=0.2,
                random_state=args.seed,
                stratify=y_train
            )
            logger.log(f"  Train subset: {len(X_train_sub)} samples")
            logger.log(f"  Validation: {len(X_val)} samples")
            X_train_for_tm = X_train_sub
            y_train_for_tm = y_train_sub
        else:
            X_train_for_tm = X_train
            y_train_for_tm = y_train
            X_val = None
            y_val = None

        # Train Tsetlin Machine
        logger.log("\nTraining Tsetlin Machine...")
        tm, train_time = train_tsetlin_machine(
            X_train_for_tm, y_train_for_tm,
            n_clauses=args.clauses,
            epochs=args.epochs,
            seed=args.seed,
            logger=logger
        )

        # Get oracle outputs (always on FULL training set for compatibility)
        O_train = tm.transform(X_train)
        O_test = tm.transform(X_test)
        y_oracle_train = tm.predict(X_train)
        y_oracle_test = tm.predict(X_test)

        tm_train_acc = np.mean(y_oracle_train == y_train)
        tm_test_acc = np.mean(y_oracle_test == y_test)

        logger.log(f"\nTM: {O_train.shape[1]} total clauses")
        logger.log(f"  Train acc: {tm_train_acc:.4f}")
        logger.log(f"  Test acc: {tm_test_acc:.4f}")

        # Compute clause weights (if weighted mode)
        clause_weights = None
        if args.weighted:
            logger.log(f"\n{'='*60}")
            logger.log("Computing Clause Importance Weights")
            logger.log(f"{'='*60}")
            clause_weights = compute_clause_utilities(
                tm, X_val, y_val, X_train_for_tm, y_train_for_tm,
                logger=logger
            )

        # IMLI compression
        logger.log(f"\n{'='*60}")
        mode_str = "Weighted " if args.weighted else ""
        logger.log(f"{mode_str}IMLI Compression (p={args.partitions})")
        logger.log(f"{'='*60}")

        keep, solve_time, verification_stats = compress_oracle_imli(
            O_train, y_oracle_train,
            n_partitions=args.partitions,
            clause_weights=clause_weights,
            logger=logger
        )

        # Evaluate
        logger.log("\nEvaluating compressed model...")
        patterns = build_oracle_predictor(O_train, y_oracle_train, keep)
        result = evaluate_compression(
            tm, X_train, X_test, y_train, y_test,
            keep, patterns, logger
        )

        # Add metadata including verification stats
        result.update({
            "solve_time": solve_time,
            "train_time": train_time,
            "n_partitions": args.partitions,
            "n_clauses": args.clauses,
            "epochs": args.epochs,
            "seed": args.seed,
            "dataset": args.dataset,
            "dataset_name": name,
            "weighted": args.weighted,
            "verification": verification_stats
        })

        if samples:
            result["n_samples"] = samples

        logger.add_result(result)

        # Summary
        logger.log(f"\n{'='*60}")
        logger.log("✅ EXPERIMENT COMPLETE")
        logger.log(f"{'='*60}")
        logger.log(f"Dataset: {name}")

        # Safe compression summary (handle zero clauses edge case)
        total_clauses = O_train.shape[1]
        if total_clauses > 0:
            compression_str = f"{total_clauses} → {len(keep)} clauses ({result['compression_ratio']*100:.1f}%)"
        else:
            compression_str = "0 clauses, no compression possible"
        logger.log(f"Compression: {compression_str}")
        logger.log(f"Solve time: {solve_time:.1f}s")
        logger.log(f"Train time: {train_time:.1f}s")
        logger.log(f"Test fidelity: {result['test_fidelity']*100:.2f}%")
        logger.log(f"Test acc delta: {result['test_acc_delta']:+.4f}")
        logger.log(f"Patterns: {len(patterns)}")
        logger.log(f"Test fallback rate: {result['test_fallback']['fallback_rate']*100:.1f}%")
        logger.log(f"Global separation: {'✓ PRESERVED' if verification_stats['global_separation_preserved'] else '⚠ VIOLATED'}")

        logger.finish("completed")
        return 0

    except Exception as e:
        # Top-level boundary: report the failure through the logger when it
        # exists, otherwise fall back to stderr, and signal failure via exit code.
        import traceback
        if logger is not None:
            logger.log(f"❌ ERROR: {e}", level="ERROR")
            logger.log(traceback.format_exc(), level="ERROR")
            logger.finish("failed")
        else:
            print(f"FATAL: Failed to initialize logger or early error. Error: {e}", file=sys.stderr)
            traceback.print_exc()
        return 1


if __name__ == "__main__":
    sys.exit(main())
