import jax 
import json
import yaml 
import argparse
import matplotlib

import numpy as np

from ood_detection import ood_detection
from uci_regression import uci_regression
from toy_regression import toy_regression
from posterior_comparison import posterior_comparaison

# Request TrueType (Type 42) font embedding in PDF figures instead of
# matplotlib's default Type 3 fonts.
matplotlib.rcParams.update({"pdf.fonttype": 42})

# Process-wide JAX switches: 64-bit dtypes and NaN checking are both turned on.
for _jax_flag in ("jax_enable_x64", "jax_debug_nans"):
    jax.config.update(_jax_flag, True)
del _jax_flag


def run_experiment(config, experiments=None):
    """Run the experiment named by ``config["experiment"]["name"]``.

    Args:
        config: Parsed configuration mapping. It must contain
            ``config["experiment"]["name"]``; the whole mapping is passed on
            to the selected experiment.
        experiments: Optional mapping from experiment name to a callable that
            takes ``config``. Defaults to the four experiments imported at the
            top of this script.

    Returns:
        Whatever the selected experiment callable returns.

    Raises:
        NotImplementedError: If no experiment is registered under the
            configured name. The message lists the accepted names.
    """
    if experiments is None:
        experiments = {
            "toy_regression": toy_regression,
            "uci_regression": uci_regression,
            "ood_detection": ood_detection,
            # The imported function is spelled "comparaison", while the
            # config name (and module) use "comparison".
            "posterior_comparison": posterior_comparaison,
        }

    name = config["experiment"]["name"]
    try:
        runner = experiments[name]
    except KeyError:
        raise NotImplementedError(
            f"Unknown experiment {name!r}; expected one of {sorted(experiments)}"
        ) from None
    return runner(config)


if __name__ == "__main__":

    # Fix the global NumPy seed so NumPy-based randomness is reproducible.
    np.random.seed(0)

    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config_file', required=True, type=str,
        help="Config file for the model and experiment"
    )
    args = parser.parse_args()

    # Load configuration
    with open(args.config_file, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # An empty YAML file loads as None (and a bare scalar is not a mapping);
    # report a usage error instead of a TypeError on the first subscript.
    if not isinstance(config, dict):
        parser.error(f"{args.config_file} must contain a YAML mapping")

    # Print configuration. YAML can produce values JSON cannot encode
    # (e.g. dates); stringify those because this dump is for display only.
    print(json.dumps(config, sort_keys=True, indent=4, default=str))

    # Experiment selection
    run_experiment(config)