import time
from pathlib import Path

import pandas as pd
from emm import TrainingConfig
from emm.algorithms.mixture import RemixMixtureModel
from sklearn.model_selection import train_test_split


def run_gold_experiment(seeds=(0, 100, 1000)):
    """Train and evaluate a ``RemixMixtureModel`` on the gold data set.

    The data is read from ``data/gold/``, the ``HOMO-LUMO`` column is used
    as the regression target, and a fixed list of columns is dropped from
    the feature matrix. For every seed an independent 80/20 train/test
    split is drawn, a fresh model is fitted on the training part, and the
    negative log-likelihood is reported on that same split's train and
    test parts. One row per seed is written to
    ``results/emm_casestudy_gold.csv``.

    Args:
        seeds: Iterable of integers used as ``random_state`` for
            ``train_test_split``; one model is trained per seed.

    Returns:
        None. The results are written to disk as a CSV file.
    """
    gold = pd.read_csv("data/gold/data_Au5-14_2.0.0.csv")
    cols = pd.read_csv("data/gold/data_Au5-14_2.0.0_attributes.csv")
    gold.columns = ["Au+N", *cols["Structure"]]

    # Target: the HOMO-LUMO column as an (n_samples, 1) column vector.
    Y = gold["HOMO-LUMO"].to_numpy().reshape(-1, 1)
    # Parity indicator: 1 when N is even, 0 when N is odd.
    gold["N-even"] = 1 - gold["N"] % 2
    # Remove the target itself and the other columns excluded from the
    # feature set.
    gold = gold.drop(
        [
            "HOMO-LUMO",
            "Evdw-Evdw0",
            "Rg/Rg(ref)",
            "Au+N",
            "-LUMO",
            "-HOMO",
            "chemical hardness",
            "electronic chemical potential",
            "Evdw/N",
            "hardness-hardness0",
            "electronic chemical pot. - electronic chemical pot0",
            "|F| / N",
            "0#",
            "7#",
        ],
        axis=1,
    )
    feature_names = list(gold.columns)
    X = gold.to_numpy()

    # The same configuration is shared by every run; only the split differs.
    remix_config = TrainingConfig(
        n_mixture_components=100,
        component_train_epochs=1000,
        device="cuda",
        use_gmm_remix=True,
        n_gmm_components=40,
        component_scoring=None,
        n_gmm_extra_components=0,
        and_layer_entropy=0.05,
        partition_weight=0.1,
        min_responsibility_threshold=0.001,
        pruning_threshold=0,
        merge_components=False,
        merge_settle_epochs=40,
        check_responsibility_every=25,
    )

    # Phase 1: train one model per seed, keeping each model together with
    # the exact split it was trained on.
    runs = []
    for seed in seeds:
        X_train, X_test, Y_train, Y_test = train_test_split(
            X, Y, test_size=0.2, random_state=seed
        )
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments during a long training run.
        start = time.perf_counter()
        model = RemixMixtureModel(remix_config)
        model.fit(X_train, Y_train, feature_names=feature_names)
        training_time = time.perf_counter() - start
        runs.append(
            (seed, model, X_train, X_test, Y_train, Y_test, training_time)
        )

    # Phase 2: evaluate. Each model must be scored on its OWN split --
    # scoring it on another seed's split would put rows it was trained on
    # into the "test" set and make test_nll meaningless.
    results = []
    for seed, model, X_train, X_test, Y_train, Y_test, training_time in runs:
        results.append(
            {
                "dataset": "gold",
                "seed": seed,
                "train_nll": model.get_nll(X_train, Y_train),
                "test_nll": model.get_nll(X_test, Y_test),
                "rules": model.rules_model.debug_print_cutpoints(
                    scaler=model.preprocessor.scaler_x,
                    simple_format=True,
                    feature_names=feature_names,
                ),
                # Number of components not flagged as disabled.
                "n_rules": sum(1 for c in model.disabled_components if not c),
                "n_features": X.shape[1],
                "n_samples": X.shape[0],
                "runtime_seconds": training_time,
            }
        )

    output_path = Path("results/emm_casestudy_gold.csv")
    # Create the output directory up front so a missing folder cannot
    # discard the results of a finished training run.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(results).to_csv(output_path, index=False)


# Run the experiment only when this file is executed as a script, so that
# importing the module does not start a training run.
if __name__ == "__main__":
    run_gold_experiment()
