from sacred import Experiment
from sacred.observers import FileStorageObserver
import numpy
import torch
import os
from scipy.stats import gaussian_kde
from data_generation.data_from_dict import data_ingredient, load_data

from canm_discovery import CANM_loss, fit

# The Sacred experiment for CANM-based causal discovery; the data ingredient is
# attached so its configuration and captured functions belong to this run.
ex = Experiment(
    "canm_discovery",
    ingredients=[data_ingredient],
)

# Record every run on disk under results/canm via Sacred's file-storage observer.
ex.observers.extend([FileStorageObserver("results/canm")])


# Sacred config scope: every local variable assigned in config() becomes a
# config entry that can be overridden on the command line and is injected by
# name into captured functions such as main(). Sacred also reads the comment
# on (or directly above) each top-level assignment as that entry's help text.
@ex.config
def config():
    depth = 2  # depth passed to CANM_loss (replaced by None when find_best_N)
    epochs = 50  # number of epochs passed to CANM_loss
    verbose = False  # verbosity flag forwarded to CANM_loss
    beta = 1.0  # beta passed to CANM_loss (see training_hyperparameters)
    find_best_N = False  # if True, main() withholds depth from CANM_loss
    # NOTE(review): 8096 is not a power of two; 8192 may have been intended.
    batch_size = 8096  # mini-batch size passed to CANM_loss
    # Schedule settings for beta, forwarded unchanged to CANM_loss.
    training_hyperparameters = {
        "type": "cyclical",  # constant, cyclical, or linear
        "M": 4,  # number of cycles
        "R": 0.5,  # proportion of each cycle used to increase beta
    }
    # Description of the dataset, passed unchanged to load_data().
    data_config = {
        # NOTE(review): presumably load_data() falls back to generating data
        # from "dictionary" when these three are None -- confirm.
        "file_path": None,
        "config": None,
        "name": None,
        "dictionary": {
            # presumably the distribution and sample count of the cause X
            "X": {
                "type": "uniform",
                "length": 500,
            },
            # presumably the mechanism linking successive variables -- confirm
            "transformation": {
                "type": "tanh",
                "args": {
                    "num_hidden": 10,
                    "num_parents": 1,
                },
            },
            "shape": "sequence",
            # Literal value, NOT linked to the top-level `depth` entry:
            # overriding one does not change the other.
            "depth": 2,
            "seed": 42,
            "noise_type": "uniform",
        },
    }


@ex.automain
def main(
    data_config,
    depth,
    epochs,
    seed,
    verbose,
    find_best_N,
    beta,
    training_hyperparameters,
    batch_size,
):
    """Score both causal directions of a bivariate dataset with CANM.

    Every parameter is injected by Sacred from the experiment config; ``seed``
    is Sacred's automatically managed run seed. The dataset described by
    ``data_config`` is loaded once. ``CANM_loss`` is then run on columns
    ``[0, 1]`` and afterwards on the reversed order ``[1, 0]``, with identical
    settings, so that the two results can be compared.

    Returns:
        dict with ``"results_xy"`` (columns ``[0, 1]``) and ``"results_yx"``
        (columns ``[1, 0]``); Sacred stores it as the run result.

    Raises:
        ValueError: if the loaded data is not 2-D with at least two columns.
    """
    data = torch.from_numpy(load_data(data_config))
    # NOTE(review): assumes column 0 holds X and column 1 holds Y. If
    # load_data also returns intermediate variables of the generated chain,
    # column 1 would not be the effect -- confirm against load_data.
    if data.ndim != 2 or data.shape[1] < 2:
        raise ValueError(
            "expected 2-D data with at least two columns, "
            f"got shape {tuple(data.shape)}"
        )

    if find_best_N:
        # Do not leak the configured depth; presumably CANM_loss selects the
        # number of latent layers itself when find_best_N is set.
        depth = None

    # One shared set of settings so both directions are trained identically.
    canm_kwargs = {
        "epochs": epochs,
        "seed": seed,
        "verbose": verbose,
        "find_best_N": find_best_N,
        "depth": depth,
        "beta": beta,
        "training_hyperparameters": training_hyperparameters,
        "batch_size": batch_size,
    }

    results_xy = CANM_loss(data[:, [0, 1]], **canm_kwargs)
    results_yx = CANM_loss(data[:, [1, 0]], **canm_kwargs)

    return {"results_xy": results_xy, "results_yx": results_yx}
