"""
Knowledge Distillation Baseline - Train a small TM on teacher (oracle) labels.

This is the standard knowledge distillation approach: train a small student TM
directly on the predictions of a large teacher TM, rather than on ground truth.

Comparison: our MaxSAT method is "structured distillation", which preserves
clause structure; this baseline does not preserve that structure.

Usage (--results-dir is required):
    uv run scripts/run_knowledge_distillation.py -d breast-cancer -c 10 -s 42 --results-dir results/breast-cancer/knowledge-distillation
    uv run scripts/run_knowledge_distillation.py -d spambase -c 20 -s 42 --results-dir results/spambase/knowledge-distillation

Results are written under the directory passed via --results-dir
(by convention: results/<dataset>/knowledge-distillation/).
"""

import argparse
import sys
import numpy as np
from pathlib import Path

# Make modules located next to this script (e.g. utils) importable
sys.path.insert(0, str(Path(__file__).parent))

from utils import (
    ExperimentLogger,
    load_adult_uci,
    load_spambase,
    load_breast_cancer_binarized,
    load_iris_binary,
    load_iris_multiclass,
    load_wine_multiclass,
    load_phishing,
    load_banknote,
    load_mnist_binary,
    load_higgs_100k,
    load_mushroom,
    load_magic,
    load_electricity,
    load_kr_vs_kp,
    load_nursery,
    load_tictactoe,
    load_spect_heart,
    load_car,
    train_tsetlin_machine
)


# Registry of CLI dataset names (the -d/--dataset choices, in this order) to
# the ``utils`` loader called for that dataset. main() unpacks each loader's
# result as (X_train, X_test, y_train, y_test, name).
DATASETS = dict([
    ('adult-uci', load_adult_uci),
    ('spambase', load_spambase),
    ('breast-cancer', load_breast_cancer_binarized),
    ('iris-binary', load_iris_binary),
    ('iris-multiclass', load_iris_multiclass),
    ('wine-multiclass', load_wine_multiclass),
    ('phishing', load_phishing),
    ('banknote', load_banknote),
    ('mnist-binary', load_mnist_binary),
    ('higgs-100k', load_higgs_100k),
    ('mushroom', load_mushroom),
    ('magic', load_magic),
    ('electricity', load_electricity),
    ('kr-vs-kp', load_kr_vs_kp),
    ('nursery', load_nursery),
    ('tictactoe', load_tictactoe),
    ('spect-heart', load_spect_heart),
    ('car', load_car),
])


# Datasets whose loaders accept a ``subsample`` argument, mapped to the size
# used when --samples is not given. --samples is ignored for all other datasets.
_DEFAULT_SUBSAMPLES = {
    'adult-uci': 10000,
    'higgs-100k': 100000,
}


def _build_arg_parser():
    """Build the command-line parser for the distillation baseline."""
    parser = argparse.ArgumentParser(
        description='Knowledge Distillation Baseline - Train small TM on teacher labels',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        '-d', '--dataset',
        type=str,
        required=True,
        choices=list(DATASETS.keys()),
        help='Dataset to use'
    )
    parser.add_argument(
        '-c', '--clauses',
        type=int,
        required=True,
        help='Number of clauses per class for student TM'
    )
    parser.add_argument(
        '--teacher-clauses',
        type=int,
        default=100,
        help='Number of clauses per class for teacher TM'
    )
    parser.add_argument(
        '-n', '--samples',
        type=int,
        default=None,
        help='Number of samples to use. Only applies to adult-uci and higgs-100k.'
    )
    parser.add_argument(
        '-e', '--epochs',
        type=int,
        default=100,
        help='Number of training epochs (for both teacher and student)'
    )
    parser.add_argument(
        '-s', '--seed',
        type=int,
        default=42,
        help='Random seed for reproducibility'
    )
    parser.add_argument(
        '--results-dir',
        type=str,
        required=True,
        help='Directory to save results'
    )
    return parser


def _load_dataset(dataset, seed, samples):
    """Load ``dataset`` through its registered loader.

    Datasets listed in ``_DEFAULT_SUBSAMPLES`` also receive a ``subsample``
    size: ``samples`` when truthy, otherwise the dataset's default.

    Returns:
        The loader's ``(X_train, X_test, y_train, y_test, name)`` tuple.
    """
    loader_fn = DATASETS[dataset]
    if dataset in _DEFAULT_SUBSAMPLES:
        subsample = samples if samples else _DEFAULT_SUBSAMPLES[dataset]
        return loader_fn(seed=seed, subsample=subsample)
    return loader_fn(seed=seed)


def main():
    """Run the knowledge-distillation baseline end to end.

    Steps:
      1. Train a teacher TM on ground-truth labels.
      2. Train a smaller student TM on the teacher's training-set predictions.
      3. Record both models' accuracy (vs ground truth) and the student's
         fidelity (agreement with the teacher) through ``ExperimentLogger``.

    Returns:
        0 on success. Returns 1 if any step raised; the error is logged, or
        printed to stderr if the logger itself could not be created. Invalid
        arguments exit through ``parser.error`` (SystemExit) before any work
        starts.
    """
    import traceback

    parser = _build_arg_parser()
    args = parser.parse_args()

    # Validate arguments
    if args.clauses <= 0:
        parser.error("--clauses must be a positive integer.")
    if args.teacher_clauses <= 0:
        parser.error("--teacher-clauses must be a positive integer.")
    if args.epochs <= 0:
        parser.error("--epochs must be a positive integer.")
    if args.samples is not None and args.samples <= 0:
        parser.error("--samples must be a positive integer.")

    # Setup logger
    exp_name = f"kd_{args.dataset}_c{args.clauses}_s{args.seed}"
    if args.samples:
        exp_name += f"_n{args.samples}"

    results_dir = Path(args.results_dir)

    # Bind before the try: if ExperimentLogger() itself raises, the handler
    # below must still be able to tell that no logger exists. Otherwise the
    # handler raises UnboundLocalError and hides the original error.
    logger = None
    try:
        logger = ExperimentLogger(exp_name, results_dir=results_dir)
        logger.log("Knowledge Distillation Baseline")
        logger.log(f"  Dataset: {args.dataset}")
        logger.log(f"  Teacher clauses per class: {args.teacher_clauses}")
        logger.log(f"  Student clauses per class: {args.clauses}")
        logger.log(f"  Epochs: {args.epochs}")
        logger.log(f"  Seed: {args.seed}")
        if args.samples and args.dataset not in _DEFAULT_SUBSAMPLES:
            # Still recorded in exp_name/results for compatibility, but the
            # loader for this dataset does not take a subsample size.
            logger.log(f"  WARNING: --samples is ignored for {args.dataset}")
        logger.log("")

        # Load dataset
        logger.log(f"Loading {args.dataset} dataset...")
        X_train, X_test, y_train, y_test, name = _load_dataset(
            args.dataset, args.seed, args.samples
        )

        logger.log(f"Dataset: {name}")
        logger.log(f"  Train: {len(X_train)} samples, {X_train.shape[1]} features")
        logger.log(f"  Test: {len(X_test)} samples")
        logger.log(f"  Class distribution: {np.bincount(y_train)}")

        # Class count is taken from the training labels only.
        n_classes = len(np.unique(y_train))
        teacher_total_clauses = args.teacher_clauses * n_classes
        student_total_clauses = args.clauses * n_classes

        # Step 1: Train teacher TM
        logger.log(f"\n{'='*60}")
        logger.log("Step 1: Training Teacher TM")
        logger.log(f"{'='*60}")

        teacher_tm, teacher_train_time = train_tsetlin_machine(
            X_train, y_train,
            n_clauses=args.teacher_clauses,
            epochs=args.epochs,
            seed=args.seed,
            logger=logger
        )

        # Teacher predictions
        teacher_pred_train = teacher_tm.predict(X_train)
        teacher_pred_test = teacher_tm.predict(X_test)

        teacher_train_acc = float(np.mean(teacher_pred_train == y_train))
        teacher_test_acc = float(np.mean(teacher_pred_test == y_test))

        logger.log(f"\nTeacher TM: {teacher_total_clauses} total clauses")
        logger.log(f"  Train accuracy: {teacher_train_acc:.4f}")
        logger.log(f"  Test accuracy: {teacher_test_acc:.4f}")

        # Step 2: Train student TM on teacher predictions (distillation)
        logger.log(f"\n{'='*60}")
        logger.log("Step 2: Training Student TM on Teacher Labels (Distillation)")
        logger.log(f"{'='*60}")

        # Use a different seed for the student so that its initialization
        # is not tied to the teacher's.
        student_seed = args.seed + 1000

        student_tm, student_train_time = train_tsetlin_machine(
            X_train, teacher_pred_train,  # Key: train on teacher predictions, not ground truth
            n_clauses=args.clauses,
            epochs=args.epochs,
            seed=student_seed,
            logger=logger
        )

        # Student predictions
        student_pred_train = student_tm.predict(X_train)
        student_pred_test = student_tm.predict(X_test)

        # Fidelity: fraction of samples where the student agrees with the teacher
        train_fidelity = float(np.mean(student_pred_train == teacher_pred_train))
        test_fidelity = float(np.mean(student_pred_test == teacher_pred_test))

        # Accuracy: student vs ground truth
        student_train_acc = float(np.mean(student_pred_train == y_train))
        student_test_acc = float(np.mean(student_pred_test == y_test))

        logger.log(f"\nStudent TM: {student_total_clauses} total clauses")
        logger.log(f"  Train fidelity (vs teacher): {train_fidelity:.4f}")
        logger.log(f"  Test fidelity (vs teacher): {test_fidelity:.4f}")
        logger.log(f"  Train accuracy (vs ground truth): {student_train_acc:.4f}")
        logger.log(f"  Test accuracy (vs ground truth): {student_test_acc:.4f}")

        # Collect results. compression_ratio is the fraction of per-class
        # clauses removed; it is negative when the student is larger.
        result = {
            "method": "knowledge_distillation",
            "dataset": args.dataset,
            "dataset_name": name,
            "teacher_clauses_per_class": args.teacher_clauses,
            "teacher_total_clauses": teacher_total_clauses,
            "student_clauses_per_class": args.clauses,
            "student_total_clauses": student_total_clauses,
            "n_classes": n_classes,
            "compression_ratio": 1.0 - (args.clauses / args.teacher_clauses),
            "teacher_train_acc": teacher_train_acc,
            "teacher_test_acc": teacher_test_acc,
            "student_train_acc": student_train_acc,
            "student_test_acc": student_test_acc,
            "train_fidelity": train_fidelity,
            "test_fidelity": test_fidelity,
            "test_acc_delta": student_test_acc - teacher_test_acc,
            "teacher_train_time": teacher_train_time,
            "student_train_time": student_train_time,
            "total_train_time": teacher_train_time + student_train_time,
            "seed": args.seed,
            "epochs": args.epochs,
            "n_train": len(X_train),
            "n_test": len(X_test)
        }

        if args.samples:
            result["n_samples"] = args.samples

        logger.add_result(result)

        # Summary
        logger.log(f"\n{'='*60}")
        logger.log("KNOWLEDGE DISTILLATION COMPLETE")
        logger.log(f"{'='*60}")
        logger.log(f"Dataset: {name}")
        logger.log(f"Teacher: {teacher_total_clauses} clauses, test acc: {teacher_test_acc:.4f}")
        logger.log(f"Student: {student_total_clauses} clauses, test acc: {student_test_acc:.4f}")
        logger.log(f"Compression: {result['compression_ratio']*100:.1f}%")
        logger.log(f"Test fidelity: {test_fidelity*100:.2f}%")
        logger.log(f"Test accuracy delta: {result['test_acc_delta']:+.4f}")
        logger.log(f"Total train time: {result['total_train_time']:.1f}s")

        logger.finish("completed")
        return 0

    except Exception as e:
        if logger is not None:
            logger.log(f"ERROR: {e}", level="ERROR")
            logger.log(traceback.format_exc(), level="ERROR")
            logger.finish("failed")
        else:
            print(f"FATAL: {e}", file=sys.stderr)
            traceback.print_exc()
        return 1


if __name__ == "__main__":
    # Propagate main()'s return code (0 success, 1 failure) as the exit status.
    raise SystemExit(main())
