"""
Baseline TM experiment runner (no compression).

Trains Tsetlin Machines with various clause counts to establish
baseline performance for comparison with compressed models.

Usage:
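    uv run scripts/run_baseline.py --dataset spambase --clauses 25 --results-dir results/spambase/baseline
    uv run scripts/run_baseline.py --dataset adult-uci -c 50 -n 10000 --results-dir results/adult-uci/baseline
    uv run scripts/run_baseline.py --dataset breast-cancer -c 75 --results-dir results/breast-cancer/baseline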

Results are organized as: results/<dataset>/baseline/
"""

import argparse
import sys
import numpy as np
from pathlib import Path

# Add experiments directory to path
sys.path.insert(0, str(Path(__file__).parent))

from utils import (
    ExperimentLogger,
    load_adult_uci,
    load_spambase,
    load_breast_cancer_binarized,
    load_iris_binary,
    load_iris_multiclass,
    load_wine_multiclass,
    load_phishing,
    load_banknote,
    load_pendigits,
    load_mnist_binary,
    load_higgs_100k,
    load_mushroom,
    load_magic,
    load_phoneme,
    load_electricity,
    load_kr_vs_kp,
    load_nursery,
    load_splice,
    load_connect4,
    load_tictactoe,
    load_vote,
    load_german_credit,
    load_spect_heart,
    load_car,
    train_tsetlin_machine
)


# Registry mapping each --dataset CLI choice to its loader function from utils.
# The keys (in this order) are offered as the argparse ``choices`` in main().
# main() calls every loader with ``seed=...`` and unpacks five return values
# (X_train, X_test, y_train, y_test, name); the 'adult-uci' and 'higgs-100k'
# loaders are additionally called with ``subsample=...``.
DATASETS = {
    'adult-uci': load_adult_uci,
    'spambase': load_spambase,
    'breast-cancer': load_breast_cancer_binarized,
    'iris-binary': load_iris_binary,
    'iris-multiclass': load_iris_multiclass,
    'wine-multiclass': load_wine_multiclass,
    'phishing': load_phishing,
    'banknote': load_banknote,
    'pendigits': load_pendigits,
    'mnist-binary': load_mnist_binary,
    'higgs-100k': load_higgs_100k,
    'mushroom': load_mushroom,
    'magic': load_magic,
    'phoneme': load_phoneme,
    'electricity': load_electricity,
    # New categorical datasets (ideal for TM)
    'kr-vs-kp': load_kr_vs_kp,
    'nursery': load_nursery,
    'splice': load_splice,
    'connect4': load_connect4,
    'tictactoe': load_tictactoe,
    'vote': load_vote,
    'german-credit': load_german_credit,
    'spect-heart': load_spect_heart,
    'car': load_car,
}


# Datasets whose loaders take a ``subsample`` keyword, mapped to the size used
# when --samples is not given on the command line.
_SUBSAMPLE_DEFAULTS = {
    'adult-uci': 10000,
    'higgs-100k': 100000,
}


def _build_parser():
    """Build the argument parser for the baseline runner.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description='Baseline TM experiment runner (no compression)',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '-d', '--dataset',
        type=str,
        required=True,
        choices=list(DATASETS.keys()),
        help='Dataset to use'
    )
    parser.add_argument(
        '-c', '--clauses',
        type=int,
        required=True,
        help='Number of clauses per class for TM training'
    )
    parser.add_argument(
        '-n', '--samples',
        type=int,
        default=None,
        help=(
            'Number of samples to use. Only applies to adult-uci '
            '(default: 10000) and higgs-100k (default: 100000).'
        )
    )
    parser.add_argument(
        '-e', '--epochs',
        type=int,
        default=100,
        help='Number of training epochs'
    )
    parser.add_argument(
        '-s', '--seed',
        type=int,
        default=42,
        help='Random seed for reproducibility'
    )
    parser.add_argument(
        '--results-dir',
        type=str,
        required=True,
        help='Directory to save results'
    )
    parser.add_argument(
        '--check',
        action='store_true',
        help='Only load dataset and verify setup, do not run experiment'
    )
    return parser


def main():
    """Run one baseline Tsetlin Machine experiment from the command line.

    Parses and validates the CLI arguments, loads the requested dataset,
    trains a Tsetlin Machine via ``train_tsetlin_machine``, evaluates train
    and test accuracy, and records everything through ``ExperimentLogger``.
    With ``--check`` the function stops after the dataset has been loaded.

    Returns:
        0 when the experiment (or the ``--check`` dry run) succeeds, 1 when
        any exception is raised after argument parsing. Invalid arguments
        are reported through ``parser.error``, which exits the process.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Validate arguments
    if args.clauses <= 0:
        parser.error("--clauses must be a positive integer.")
    if args.epochs <= 0:
        parser.error("--epochs must be a positive integer.")
    if args.samples is not None and args.samples <= 0:
        parser.error("--samples must be a positive integer.")

    # Warn only for datasets whose loader really ignores --samples.
    uses_subsample = args.dataset in _SUBSAMPLE_DEFAULTS
    if args.samples is not None and not uses_subsample:
        supported = ", ".join(f"'{d}'" for d in _SUBSAMPLE_DEFAULTS)
        print(
            f"Warning: --samples argument is only supported for {supported} "
            f"and will be ignored for '{args.dataset}'.",
            file=sys.stderr,
        )

    # Setup logger. ``logger`` stays None until construction succeeds so the
    # exception handler below knows whether it can log through it.
    logger = None
    exp_name = f"baseline_{args.dataset}_c{args.clauses}_s{args.seed}"
    # The suffix is added whenever --samples was given, so experiment names
    # stay the same as in earlier runs of this script.
    if args.samples is not None:
        exp_name += f"_n{args.samples}"

    results_dir = Path(args.results_dir)

    try:
        logger = ExperimentLogger(exp_name, results_dir=results_dir)
        logger.log("Baseline TM Experiment (No Compression)")
        logger.log(f"  Dataset: {args.dataset}")
        logger.log(f"  Clauses: {args.clauses} per class")
        if args.samples is not None:
            logger.log(f"  Samples: {args.samples}")
        logger.log(f"  Epochs: {args.epochs}")
        logger.log(f"  Seed: {args.seed}")
        logger.log("")

        # Load dataset
        logger.log(f"Loading {args.dataset} dataset...")
        loader_fn = DATASETS[args.dataset]

        # Only the loaders listed in _SUBSAMPLE_DEFAULTS accept ``subsample``.
        loader_kwargs = {"seed": args.seed}
        if uses_subsample:
            loader_kwargs["subsample"] = (
                args.samples
                if args.samples is not None
                else _SUBSAMPLE_DEFAULTS[args.dataset]
            )
        X_train, X_test, y_train, y_test, name = loader_fn(**loader_kwargs)

        logger.log(f"Dataset: {name}")
        logger.log(f"  Train: {len(X_train)} samples, {X_train.shape[1]} features")
        logger.log(f"  Test: {len(X_test)} samples")
        # NOTE(review): np.bincount needs non-negative integer labels;
        # presumably every loader returns such labels - confirm in utils.
        logger.log(f"  Class distribution: {np.bincount(y_train)}")

        # Check mode: just verify setup and exit
        if args.check:
            logger.log(f"\n✅ CHECK PASSED: Dataset '{args.dataset}' loaded successfully")
            logger.log(f"  Ready to run with c={args.clauses}, seed={args.seed}")
            logger.finish("check_passed")
            return 0

        # Infer number of classes from the labels present in the train split
        n_classes = len(np.unique(y_train))
        total_clauses = args.clauses * n_classes

        # Train Tsetlin Machine
        logger.log("\nTraining Tsetlin Machine...")
        logger.log(f"  Clauses per class: {args.clauses}")
        logger.log(f"  Total clauses: {total_clauses}")

        tm, train_time = train_tsetlin_machine(
            X_train, y_train,
            n_clauses=args.clauses,
            epochs=args.epochs,
            seed=args.seed,
            logger=logger
        )

        # Evaluate
        logger.log("\nEvaluating model...")
        y_pred_train = tm.predict(X_train)
        y_pred_test = tm.predict(X_test)

        # Convert to built-in floats so the result record holds plain Python
        # values rather than NumPy scalars.
        train_acc = float(np.mean(y_pred_train == y_train))
        test_acc = float(np.mean(y_pred_test == y_test))

        logger.log(f"  Train accuracy: {train_acc:.4f}")
        logger.log(f"  Test accuracy: {test_acc:.4f}")

        # Collect results
        result = {
            "dataset": args.dataset,
            "dataset_name": name,
            "n_clauses_per_class": args.clauses,
            "total_clauses": total_clauses,
            "n_classes": n_classes,
            "epochs": args.epochs,
            "seed": args.seed,
            "train_time": train_time,
            "train_acc": train_acc,
            "test_acc": test_acc,
            "n_train": len(X_train),
            "n_test": len(X_test),
            "n_features": X_train.shape[1]
        }

        if args.samples is not None:
            result["n_samples"] = args.samples

        logger.add_result(result)

        # Summary
        logger.log(f"\n{'='*60}")
        logger.log("✅ BASELINE EXPERIMENT COMPLETE")
        logger.log(f"{'='*60}")
        logger.log(f"Dataset: {name}")
        logger.log(f"Clauses: {args.clauses} per class ({total_clauses} total)")
        logger.log(f"Train time: {train_time:.1f}s")
        logger.log(f"Train accuracy: {train_acc:.4f}")
        logger.log(f"Test accuracy: {test_acc:.4f}")

        logger.finish("completed")
        return 0

    except Exception as e:
        # Top-level boundary: report the failure and turn it into exit code 1.
        import traceback

        if logger:
            logger.log(f"❌ ERROR: {e}", level="ERROR")
            logger.log(traceback.format_exc(), level="ERROR")
            logger.finish("failed")
        else:
            # The logger itself could not be created, so fall back to stderr.
            print(f"FATAL: Failed to initialize logger or early error. Error: {e}", file=sys.stderr)
            traceback.print_exc()
        return 1


# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    raise SystemExit(main())
