from causal_discovery.new_approach import NewApproach
from causal_discovery.utils import combine_with_defaults
from experiments.run_generated_graphs import parse_args_generated_graphs
from mrunner.helpers.specification_helper import create_experiments_helper

# Experiment name: the spec file's name with its three-character ".py"
# suffix removed. `script` is not defined in this file; it is presumably
# injected into the module globals by mrunner when it executes this spec
# (TODO confirm). Note the slice assumes the name really ends in ".py".
name = globals()["script"][: -len(".py")]

# Settings shared by every experiment generated from this spec. Some keys
# (e.g. graph_type) are also listed in params_grid below, which presumably
# overrides them per experiment.
base_config = dict(
    cluster=True,
    graph_type="random",
    num_vars=25,
    edge_prob=0.3,
    max_graph_stacking=200,
    model_iters=1000,
    graph_iters=100,
    use_theta_only_stage=False,
    theta_only_iters=1000,
    sample_size_obs=5_000,
    use_neptune_logger=False,
    visualize=False,
    num_graphs=1,
    log_grads=True,
    force_online_data=False,
    interventions_policy="nonempty_round_robin",
    num_epochs=100,
    embed_dim=4,
    batch_size=128,
    num_categs=10,
    # Regularisation / optimiser settings.
    lambda_sparse=0.004,
    lr_gamma=2e-2,
    lr_model=5e-3,
    lr_theta=1e-1,
    weight_decay=1e-4,
)

# Complete the config with the defaults of the generated-graphs CLI:
# parsing an empty argv yields every argument at its default value, and
# vars() turns the result (presumably an argparse.Namespace) into a dict.
# Which side wins on key conflicts is decided inside combine_with_defaults
# and is not visible here — TODO confirm.
base_config = combine_with_defaults(
    base_config,
    defaults=vars(parse_args_generated_graphs([])),
)

# Extra settings applied after the defaults are merged in. The dotted keys
# and the "@module.Class" value look like gin-style bindings
# ("Configurable.parameter"); confirm against how the runner consumes them.
base_config["test_graph.discovery_method"] = (
    "@causal_discovery.new_approach.NewApproach"
)
base_config["Logger.log_frequency"] = 50
base_config["NewApproach.num_inner_loop_epochs"] = 30
base_config["NewApproach.int_data_collection_batch_size"] = 32
# Acquisition methods settings.
# NOTE(review): unlike the bindings above these keys carry no
# "Configurable." prefix — confirm that is intended.
base_config["BALD_num_int_samples"] = 128  # default
base_config["num_hypothetical_graphs"] = 50
base_config["policy_softmax_temperature"] = 1.0


# Per-experiment overrides of base_config. Each value is a list of
# candidate settings; with a single candidate per key this grid describes
# exactly one configuration (seed 0 on a "jungle" graph).
params_grid = [
    {
        "NewApproach.int_data_collection_policy": ["gradients_l2_hypothetical_samples"],
        "graph_type": ["jungle"],
        # Widen the range to repeat the run over more seeds.
        "seed": list(range(1)),
    }
]


# Expand base_config and params_grid into concrete experiment specs.
# mrunner presumably looks up `experiments_list` in this module's globals
# after executing the spec.
experiments_list = create_experiments_helper(
    experiment_name=name,
    tags=[name],
    # NOTE(review): empty project name / python path look like placeholders
    # — confirm they are intended for the target backend.
    project_name="",
    python_path="",
    script="python mrun_generated_graphs.py",
    base_config=base_config,
    params_grid=params_grid,
    # Paths presumably left out of the code snapshot shipped with each run.
    exclude=[
        ".pytest_cache",
        "__pycache__",
        "checkpoints",
        "out",
        "singularity",
        "data",
    ],
)
