import time
import datetime
from pathlib import Path
import json
from typing import List
import argparse
import numpy as np
import pandas as pd
import tqdm
from sklearn.model_selection import train_test_split
import torch

try:
    from emm import MixtureModel, FlowMixtureModel, RemixMixtureModel, TrainingConfig
except ImportError:
    print("FATAL: Could not import 'emm'. Please ensure the library is installed.")
    exit(1)


def _load_dataset(csv_path: Path) -> tuple:
    """Read a CSV file and separate it into features, target and feature names.

    Every column except the last is treated as a feature; the last column is
    the target and is reshaped into a single-column 2-D array.

    Args:
        csv_path: Path to the CSV dataset file.

    Returns:
        A ``(X, Y, feature_names)`` tuple.

    Raises:
        ValueError: If the file has fewer than two columns or no data rows.
    """
    data_df = pd.read_csv(csv_path)
    n_columns = data_df.shape[1]
    if n_columns < 2:
        raise ValueError(
            "expected at least one feature column plus a target column, "
            f"found {n_columns} column(s)"
        )
    if data_df.empty:
        raise ValueError("dataset contains no rows")

    X = data_df.iloc[:, :-1].values
    Y = data_df.iloc[:, -1].values.reshape(-1, 1)
    feature_names = data_df.columns[:-1].tolist()
    return X, Y, feature_names


def _split_dataset(csv_path: Path, X, Y) -> tuple:
    """Split the data into train/test parts.

    If a ``<stem>_split_indices.json`` file sits next to the CSV, its
    ``train_indices`` / ``test_indices`` entries are used. Otherwise a
    seeded random 80/20 split is applied.

    Returns:
        A ``(X_train, X_test, Y_train, Y_test)`` tuple.
    """
    indices_path = csv_path.with_name(f"{csv_path.stem}_split_indices.json")
    if indices_path.exists():
        print(f"Loading split indices from {indices_path}")
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(indices_path, "r", encoding="utf-8") as f:
            split_data = json.load(f)
        train_indices = np.array(split_data["train_indices"])
        test_indices = np.array(split_data["test_indices"])
        return X[train_indices], X[test_indices], Y[train_indices], Y[test_indices]

    print(
        f"Warning: No split index file found at {indices_path}. Using random split."
    )
    X_train, X_test, Y_train, Y_test = train_test_split(
        X, Y, test_size=0.2, random_state=42
    )
    return X_train, X_test, Y_train, Y_test


def run_single_real_test(csv_path: Path, config: TrainingConfig) -> dict:
    """
    Runs a single training and evaluation trial on a real-world CSV dataset.

    Args:
        csv_path: Path to the CSV dataset file. The last column is the target,
            all preceding columns are features.
        config: A TrainingConfig object specifying the model and hyperparameters.

    Returns:
        A dictionary containing the results of the trial. On success it holds
        the dataset name, timestamp, model type, train/test NLL, BIC/AIC, the
        extracted rules, size information, the runtime in seconds, the rule
        complexity metrics and the flattened config. If the data cannot be
        loaded (or is unusable) it holds only ``dataset`` and ``error``; if
        training fails it holds ``dataset``, ``error`` and the flattened config.
    """
    # 1. Load Data from CSV
    try:
        X, Y, feature_names = _load_dataset(csv_path)
    except Exception as e:
        print(f"ERROR: Failed to load or parse {csv_path.name}: {e}")
        return {"dataset": csv_path.name, "error": f"Data loading failed: {e}"}

    # 2. Handle Train/Test Split using a corresponding '_split_indices.json' file
    X_train, X_test, Y_train, Y_test = _split_dataset(csv_path, X, Y)

    # 3. Select and Train the Model
    model_class = RemixMixtureModel if config.use_gmm_remix else FlowMixtureModel
    model: MixtureModel
    search_history = None
    # perf_counter is monotonic, so the measured runtime cannot be distorted
    # by wall-clock adjustments during a long training run.
    start_time = time.perf_counter()

    try:
        if config.verbose:
            print(
                f"Instantiating and training {model_class.__name__} for {csv_path.name}..."
            )
        model = model_class(config)
        model.fit(X_train, Y_train, feature_names=feature_names)

        # If history was requested, retrieve it from the model instance
        if config.model_finder_return_history:
            search_history = model.search_history

    except Exception as e:
        print(f"ERROR: Training failed for {csv_path.name} with exception: {e}")
        import traceback

        traceback.print_exc()
        return {"dataset": csv_path.name, "error": str(e), **config.to_dict()}

    runtime = time.perf_counter() - start_time

    # 4. Evaluate - Calculate only metrics that don't require true labels
    metrics = model.get_metrics()
    rule_complexity_metrics = model.get_rule_complexity_metrics()
    train_nll = model.get_nll(X_train, Y_train)
    test_nll = model.get_nll(X_test, Y_test)

    # 5. Compile and Return Results
    results = {
        "dataset": csv_path.name,
        "timestamp": datetime.datetime.now().isoformat(),
        "model_type": model_class.__name__,
        "train_nll": train_nll,
        "test_nll": test_nll,
        "bic": metrics.get("bic"),
        "aic": metrics.get("aic"),
        "rules": model.rules_model.debug_print_cutpoints(
            scaler=model.preprocessor.scaler_x,
            simple_format=True,
            feature_names=feature_names,
        ),
        # Count the components whose "disabled" flag is falsy.
        "n_rules": sum(1 for c in model.disabled_components if not c),
        "n_features": X.shape[1],
        "n_samples": X.shape[0],
        "runtime_seconds": runtime,
        # Later keys win on collision: config entries override metric entries.
        **rule_complexity_metrics,
        **config.to_dict(),
    }

    if search_history:
        results["search_history"] = str(
            search_history
        )  # Stringify for CSV compatibility

    return results


def _append_result_row(results_path: Path, result: dict) -> None:
    """Append one result record to the results CSV, keeping columns aligned.

    Result records do not all share the same keys (error records are much
    smaller than success records, and some carry extra fields). Appending
    such rows without a header would silently shift values under the wrong
    column names, so each row is aligned to the file's existing header. When
    a row introduces columns the file does not have yet, the file is
    rewritten with the union of all columns.
    """
    row_df = pd.DataFrame([result])

    # No file yet (or an empty one): start it with a header.
    if not results_path.exists() or results_path.stat().st_size == 0:
        row_df.to_csv(results_path, mode="w", header=True, index=False)
        return

    existing_columns = pd.read_csv(results_path, nrows=0).columns.tolist()
    unseen_columns = [c for c in row_df.columns if c not in existing_columns]

    if not unseen_columns:
        # Cheap path: reorder/pad the row to the header and append it.
        row_df.reindex(columns=existing_columns).to_csv(
            results_path, mode="a", header=False, index=False
        )
        return

    # The row brings new columns: rewrite the file with the widened header.
    # Existing cells are read back as raw strings so they are written out
    # unchanged rather than being re-parsed and re-formatted.
    existing_df = pd.read_csv(results_path, dtype=str, keep_default_na=False)
    combined_df = pd.concat([existing_df, row_df], ignore_index=True, sort=False)
    combined_df.to_csv(results_path, mode="w", header=True, index=False)


def run_mixture_on_real_data(
    data_folder: str, results_file: str, configs: List[TrainingConfig]
):
    """
    Loads all CSV files from a folder, runs mixture models on them, and saves results.

    Every (dataset, config) pair is run through ``run_single_real_test`` and
    its result is written to ``results_file`` immediately, so completed runs
    survive a later failure. If ``results_file`` already exists, new rows are
    added to it.

    Args:
        data_folder: Folder that is searched (non-recursively) for ``*.csv`` files.
        results_file: Path of the CSV file collecting the results. Missing
            parent directories are created.
        configs: Training configurations to run on every dataset.

    Raises:
        FileNotFoundError: If ``data_folder`` is not an existing directory.
        ValueError: If ``data_folder`` contains no CSV files.
    """
    data_path = Path(data_folder)
    if not data_path.is_dir():
        raise FileNotFoundError(f"Data folder not found: {data_folder}")

    csv_files = sorted(data_path.glob("*.csv"))
    if not csv_files:
        raise ValueError(f"No CSV files found in {data_folder}")

    results_path = Path(results_file)
    results_path.parent.mkdir(parents=True, exist_ok=True)

    for file_path in tqdm.tqdm(csv_files, desc="Running Mixture Models"):
        for config in configs:
            # Best effort: a failure on one (dataset, config) pair is reported
            # and must not abort the remaining runs.
            try:
                print(f"\nProcessing file: {file_path.name} with config...")
                result = run_single_real_test(file_path, config)
                _append_result_row(results_path, result)
            except Exception as e:
                print(f"FATAL ERROR on file {file_path.name}: {e}")

    print(f"\nCompleted processing. Final results saved to {results_path}")


if __name__ == "__main__":
    # Command-line interface: where the datasets live and where results go.
    cli = argparse.ArgumentParser(
        description="Run mixture models on real-world CSV datasets."
    )
    cli.add_argument(
        "--data_folder",
        type=str,
        default="data/real",
        help="Path to the folder containing CSV datasets.",
    )
    cli.add_argument(
        "--output_dir",
        type=str,
        default="results/real",
        help="Directory to save results.",
    )
    args = cli.parse_args()

    # Train on the GPU when one is available, otherwise fall back to the CPU.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    # Flow-based (NSF) model with a single candidate component count.
    nsf_settings = {
        "use_model_finder": True,
        "model_finder_component_range": [100],
        "component_train_epochs": 1500,
        "device": device,
        "flow_gen": ("zuko_nsf", {}),
        "use_gmm_remix": False,
        "and_layer_entropy": 0.1,
        "flow_steps": 1,
        "rules_steps": 1,
        "partition_weight": 0.1,
        "min_responsibility_threshold": 0.01,
        "lr_flow": 5e-3,
        "merge_components": True,
        "merge_settle_epochs": 100,
        "merge_adjacency_tol": 0.5,
        "check_responsibility_every": 25,
    }
    nsf_config = TrainingConfig(**nsf_settings)
    # Same model, but with two candidate component counts to search over.
    nsf_config_bic = nsf_config.copy()
    nsf_config_bic.model_finder_component_range = [20, 100]

    # GMM-remix model with a single candidate component count.
    remix_settings = {
        "use_model_finder": True,
        "model_finder_component_range": [100],
        "component_train_epochs": 1500,
        "device": device,
        "use_gmm_remix": True,
        "n_gmm_components": 30,
        "component_scoring": "BIC",
        "n_gmm_extra_components": 3,
        "and_layer_entropy": 0.1,
        "partition_weight": 0.1,
        "min_responsibility_threshold": 0.005,
        "merge_components": True,
        "merge_settle_epochs": 100,
        "merge_adjacency_tol": 0.5,
        "check_responsibility_every": 25,
    }
    remix_config = TrainingConfig(**remix_settings)
    # Same model, but with two candidate component counts to search over.
    remix_config_bic = remix_config.copy()
    remix_config_bic.model_finder_component_range = [20, 100]

    # One results file per configuration, processed in this order.
    experiments = (
        ("emm_gmm.csv", remix_config),
        ("emm_gmm_bic.csv", remix_config_bic),
        ("emm_nsf.csv", nsf_config),
        ("emm_nsf_bic.csv", nsf_config_bic),
    )
    for results_name, experiment_config in experiments:
        run_mixture_on_real_data(
            data_folder=args.data_folder,
            results_file=f"{args.output_dir}/{results_name}",
            configs=[experiment_config],
        )
