import time
import pickle
import datetime
from pathlib import Path
from typing import Dict, List, Any, Tuple, Type

import numpy as np
import pandas as pd
import tqdm

from emm import MixtureModel, FlowMixtureModel, RemixMixtureModel, TrainingConfig

from .metrics import (
    adjusted_rand_index,
    normalized_mutual_information,
)
from emm.data_gen.mixture.mixture_gen import (
    compare_component_assignments,
)


def load_dataset_from_path(
    dataset_path: Path,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Any]:
    """
    Reads the pickled synthetic dataset stored in a dataset directory.

    The directory must contain a file named 'dataset.pkl' holding a mapping
    with the keys "X", "y", "component_labels" and "config".

    NOTE: the file is read with ``pickle``, which can execute arbitrary code
    while loading. Only point this at dataset files you trust.

    Args:
        dataset_path: Directory that contains 'dataset.pkl'. Anything
            accepted by ``pathlib.Path`` works.

    Returns:
        The tuple (X, y, component_labels, config) taken from the pickled
        mapping, in that order.

    Raises:
        FileNotFoundError: If the directory does not exist, or if it does
            not contain 'dataset.pkl'.
    """
    root = Path(dataset_path)
    if not root.is_dir():
        raise FileNotFoundError(f"Dataset directory not found: {root}")

    payload_file = root / "dataset.pkl"
    if not payload_file.exists():
        raise FileNotFoundError(f"Could not find 'dataset.pkl' in {root}")

    with payload_file.open("rb") as handle:
        payload = pickle.load(handle)

    # Look the entries up in a fixed order so the returned tuple layout is
    # independent of the key order inside the pickle.
    features, targets, labels, dataset_config = (
        payload[key] for key in ("X", "y", "component_labels", "config")
    )
    return features, targets, labels, dataset_config


def run_single_test(dataset_path: Path, config: TrainingConfig) -> Dict[str, Any]:
    """
    Runs a single training and evaluation trial on a given dataset using the
    specified configuration.

    The model class is chosen from ``config.use_gmm_remix`` and trained through
    the class-based model API (``model_class(config).fit(...)``).

    Clustering metrics ("accuracy", "ARI", "NMI") are always computed on the
    full dataset. The "test_*" metrics are computed only when the dataset
    configuration defines BOTH ``train_indices`` and ``test_indices``;
    otherwise the model is trained on the full dataset and the "test_*"
    entries are None.

    Args:
        dataset_path: Path to the dataset directory (a ``str`` is accepted too).
        config: A TrainingConfig object specifying the model and hyperparameters.

    Returns:
        A dictionary containing the results and metrics of the trial. If
        training raises, the error is printed and a reduced dictionary with
        only "dataset", "error" and the config fields is returned instead.
    """
    # 1. Load Data
    # Coerce up front so `.name` works for str input, matching what
    # load_dataset_from_path already accepts.
    dataset_path = Path(dataset_path)
    X, Y, true_labels, dataset_config = load_dataset_from_path(dataset_path)
    dataset_name = dataset_path.name
    feature_names = [f"X{i}" for i in range(X.shape[1])]
    true_n_components = len(dataset_config.components)

    # 2. Handle Train/Test Split
    train_indices = dataset_config.train_indices
    test_indices = dataset_config.test_indices

    # A split only counts when both halves are defined. The same flag gates the
    # test-set evaluation below, so the test labels and test predictions are
    # always sliced with the same indices.
    has_split = train_indices is not None and test_indices is not None

    if has_split:
        X_train, Y_train = X[train_indices], Y[train_indices]
        X_test, Y_test = X[test_indices], Y[test_indices]
        true_labels_test = true_labels[test_indices]
    else:
        # If no split is defined, train on the full dataset.
        X_train, Y_train = X, Y

    # 3. Select and Train the Model
    model_class: Type[MixtureModel] = (
        RemixMixtureModel if config.use_gmm_remix else FlowMixtureModel
    )
    search_history = None

    # perf_counter is monotonic, so the measured runtime is not affected by
    # system clock adjustments during a long training run.
    start_time = time.perf_counter()

    try:
        if config.verbose:
            print(
                f"Instantiating and training {model_class.__name__} for {dataset_name}..."
            )
        model: MixtureModel = model_class(config)
        model.fit(X_train, Y_train, feature_names=feature_names)

        # If history was requested, retrieve it from the model instance
        if config.model_finder_return_history:
            search_history = model.search_history

    except Exception as e:
        # Best-effort: report the failure and hand back a stub row so a test
        # suite can keep going with the remaining datasets/configs.
        print(f"ERROR: Training failed for {dataset_name} with exception: {e}")
        import traceback

        traceback.print_exc()
        return {
            "dataset": dataset_name,
            "error": str(e),
            **config.to_dict(),
        }

    runtime = time.perf_counter() - start_time

    # 4. Evaluate the Trained Model
    if config.verbose:
        print(f"Evaluating model for {dataset_name}...")
    metrics = model.get_metrics()
    rule_complexity_metrics = model.get_rule_complexity_metrics()

    # Get component assignments for the entire dataset
    predicted_labels = model.get_labels(X)

    # Calculate clustering metrics on the full dataset
    accuracy, _ = compare_component_assignments(true_labels, predicted_labels)
    ari = adjusted_rand_index(true_labels, predicted_labels)
    nmi = normalized_mutual_information(true_labels, predicted_labels)

    # Calculate test set specific metrics if a split exists
    test_ari, test_nmi, test_accuracy, test_nll = None, None, None, None
    if has_split:
        predicted_labels_test = predicted_labels[test_indices]
        test_accuracy, _ = compare_component_assignments(
            true_labels_test, predicted_labels_test
        )
        test_ari = adjusted_rand_index(true_labels_test, predicted_labels_test)
        test_nmi = normalized_mutual_information(
            true_labels_test, predicted_labels_test
        )

        test_nll = model.get_nll(X_test, Y_test)

    # 5. Compile and Return Results
    results = {
        "dataset": dataset_name,
        "timestamp": datetime.datetime.now().isoformat(),
        "model_type": model_class.__name__,
        "accuracy": accuracy,
        "ARI": ari,
        "NMI": nmi,
        "test_ARI": test_ari,
        "test_NMI": test_nmi,
        "test_accuracy": test_accuracy,
        "final_nll": metrics.get("final_nll"),
        "test_nll": test_nll,
        "bic": metrics.get("bic"),
        "aic": metrics.get("aic"),
        "rules": model.rules_model.debug_print_cutpoints(
            scaler=model.preprocessor.scaler_x,
            simple_format=True,
            feature_names=feature_names,
        ),
        # Count the entries of disabled_components that are falsy.
        "n_rules": sum(1 for disabled in model.disabled_components if not disabled),
        "n_features": X.shape[1],
        "n_noise_features": dataset_config.n_noise_features,
        "n_samples": X.shape[0],
        "true_components": true_n_components,
        "runtime_seconds": runtime,
        **rule_complexity_metrics,  # Rule complexity metrics from the model
        **config.to_dict(),  # Append config for traceability
    }

    # Only attach the history when one was requested and is non-empty.
    if search_history:
        results["search_history"] = search_history

    return results


def _existing_csv_columns(results_path: Path) -> List[str]:
    """Return the header columns of an existing results CSV, or [] if there is none."""
    if not results_path.exists() or results_path.stat().st_size == 0:
        return []
    try:
        # nrows=0 parses only the header line.
        return list(pd.read_csv(results_path, nrows=0).columns)
    except pd.errors.EmptyDataError:
        return []


def _write_result_row(
    results_path: Path, result: Dict[str, Any], columns: List[str]
) -> List[str]:
    """Persist one result row aligned with the CSV header; return the current header."""
    row = pd.DataFrame([result])

    if not columns:
        # Nothing on disk yet: this row defines the header.
        row.to_csv(results_path, mode="w", header=True, index=False)
        return list(row.columns)

    added = [name for name in row.columns if name not in columns]
    if not added:
        # Same or narrower schema: reorder/pad to the header and append.
        row.reindex(columns=columns).to_csv(
            results_path, mode="a", header=False, index=False
        )
        return columns

    # The row brings columns the file does not have (e.g. a successful run
    # following an error row, or an optional "search_history"). Appending
    # would shift values under the wrong header, so rewrite the file with
    # the widened header. Existing cells are read back as raw strings so
    # their text is written out unchanged.
    merged = columns + added
    existing = pd.read_csv(results_path, dtype=str, keep_default_na=False)
    combined = pd.concat([existing, row], ignore_index=True)
    combined.reindex(columns=merged).to_csv(
        results_path, mode="w", header=True, index=False
    )
    return merged


def run_test_suite(
    configs: List[TrainingConfig],
    data_dir: str,
    results_file: str = "mixture_results.csv",
    overwrite: bool = False,
) -> pd.DataFrame:
    """
    Runs a full test suite over multiple datasets and configurations.

    Every directory below ``data_dir`` that contains a 'dataset.pkl' is treated
    as a dataset, and each one is run with every config. Each result is written
    to the results CSV as soon as it is available. Rows are aligned with the
    CSV header; if a result introduces columns the file does not have yet, the
    file is rewritten with the widened header. Unless ``overwrite`` is set,
    results are added to an already existing results file.

    A failure while running or recording one (dataset, config) pair is printed
    and the suite continues with the next pair.

    Args:
        configs: A list of TrainingConfig objects to test.
        data_dir: The root directory containing synthetic dataset folders.
        results_file: The path to the output CSV file for storing results.
        overwrite: If True, deletes the existing results file.

    Returns:
        A pandas DataFrame containing the aggregated results of this call
        (empty if no datasets were found).
    """
    root_path = Path(data_dir)
    dataset_paths = sorted(p.parent for p in root_path.glob("**/dataset.pkl"))

    if not dataset_paths:
        print(f"No datasets found in '{data_dir}'.")
        return pd.DataFrame()

    print(f"Found {len(dataset_paths)} datasets. Starting test suite...")

    results_list = []
    results_path = Path(results_file)
    results_path.parent.mkdir(parents=True, exist_ok=True)

    if overwrite and results_path.exists():
        results_path.unlink()

    # Header currently on disk; empty when the file is missing or has no header.
    columns = _existing_csv_columns(results_path)

    # Create a progress bar for the outer loop (datasets)
    for dataset_path in tqdm.tqdm(dataset_paths, desc="Testing Datasets"):
        for config in configs:
            try:
                result = run_single_test(dataset_path, config)
                results_list.append(result)

                # Persist immediately so partial progress survives a crash.
                columns = _write_result_row(results_path, result, columns)

            except Exception as e:
                # Top-level boundary: report and move on to the next pair.
                print(f"FATAL ERROR while testing {dataset_path.name}: {e}")
                import traceback

                traceback.print_exc()

    print(f"\nTest suite complete. Results saved to {results_file}")
    return pd.DataFrame(results_list)
