from sacred import Experiment
from sacred.observers import FileStorageObserver
from tqdm import tqdm

from discovery_baselines import *

from data_generation.data_from_dict import data_ingredient, load_data
import time

# Sacred experiment object; the @ex.config and @ex.automain functions in this
# file register themselves against it.
ex = Experiment("baseline_discovery")

# Add a file storage observer to save results to a directory
# (here: "results/baselines", given as a relative path).
ex.observers.append(FileStorageObserver("results/baselines"))


# Default configuration for the experiment.
# Sacred collects the local variables assigned inside an @ex.config function
# as configuration entries, so the variable names and dictionary keys below
# are part of the experiment's interface: keep them stable.
@ex.config
def config():
    verbose = False  # not read by main() as shown in this file
    # Dataset specification; main() passes it unchanged to load_data().
    data_config = {
        # NOTE(review): all three are None by default; presumably load_data
        # then builds the dataset from "dictionary" instead of loading an
        # existing one -- confirm against load_data.
        "file_path": None,
        "name": None,
        "id": None,
        # NOTE(review): the meaning of the entries below is defined by
        # data_generation.data_from_dict, which is not visible here.
        "dictionary": {
            "X": {
                "type": "normal",
                "length": 1500,
            },
            "transformation": {
                "type": "neural_network",
                "args": {
                    "num_hidden": 10,
                    "num_parents": 1,
                },
            },
            "shape": "sequence",
            "depth": 2,
            "seed": 10,
            "noise_type": "normal",
            "noise_parameters": None,
            "standardize": True,
        },
    }


@ex.automain
def main(
    data_config,
):
    """Run every baseline discovery method on one dataset and score it.

    Args:
        data_config: Dataset specification taken from the experiment
            configuration; passed unchanged to ``load_data``.

    Returns:
        A dict keyed by each method's ``__name__``:

        * a method returning a list or tuple is recorded as
          ``{"correct": 0 or 1, "time": seconds}``, where ``correct`` is 1
          only when the returned order is exactly ``0, 1``;
        * a method returning a dict is recorded as a copy of that dict with
          a ``"time"`` entry added (overwriting any existing ``"time"``);
        * a method returning anything else is not recorded.
    """
    data = load_data(data_config)

    methods = [
        linear_ANM,
        additive_ANM,
        additive_ANM_UV,
        nonlinear_ANM,
        score_ANM,
        nogam_ANM,
        var_sort,
        r2_sort,
        entropy_knn,
        get_dagmalinear_order,
        pnl,  # slow
        causal_score_matching,
    ]

    # NOTE(review): the correct answer is hard-coded; presumably this is the
    # true causal order of the generated two-variable data -- confirm.
    expected_order = [0, 1]

    results = {}
    # Run each method on this dataset
    for method in tqdm(methods):
        # perf_counter() is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        result = method(data)
        elapsed_time = time.perf_counter() - start_time

        if isinstance(result, (list, tuple)):
            # Normalise to a list before comparing: a tuple never compares
            # equal to a list, so (0, 1) would otherwise be scored as wrong.
            correct = int(list(result) == expected_order)
            results[method.__name__] = {"correct": correct, "time": elapsed_time}
        elif isinstance(result, dict):
            # Build a new dict instead of mutating the object the method
            # returned, in case the method keeps a reference to it.
            results[method.__name__] = {**result, "time": elapsed_time}
        # Any other return type is left out of the results.

    return results
