from run_diffusion import ex as run_diffusion
from run_canm import ex as run_canm
from run_baselines import ex as run_baselines
from itertools import product
from sacred.observers import FileStorageObserver

from tqdm import tqdm
import os

# Machine tag appended to every results directory so sweeps launched on
# different computers do not write into the same folder. Read from the
# ``computer_name`` environment variable; "default_name" is used when unset.
computer_name: str = os.getenv("computer_name", "default_name")
print(f"Computer name is: {computer_name}")

# Training hyperparameter grid. Each dict is one configuration; the sweep at the
# bottom of this script runs every configuration against every dataset entry.
training_hyperparameter_options = [
    {
        # Diffusion noise (beta) schedule; presumably beta goes linearly from
        # "min" to "max" over "timesteps" steps — confirm in run_diffusion.
        "beta_schedule": {
            "type": "linear",
            "min": 0.0001,
            "max": 0.02,
            "timesteps": 256,
        },
        "condition_embed": True,
        "condition_dim": 4,
        "train_loss": "mse",
        "test_loss": "mse_and_hsic",
        # Schedule for the weight (beta) applied to the extra training-loss term.
        # With type "constant" and all beta values at 0 the term is effectively
        # off; the remaining keys are presumably only read by the other schedule
        # types — confirm in run_diffusion.
        "train_loss_schedule": {
            "min_beta": 0,
            "max_beta": 0,
            "start_beta": 0,
            "type": "constant",  # constant, cyclical(M, R), linear, adaptive, delayed_linear(M, R)
            "mi_threshold": 0.001,
            "adaptive_factor": 1.0001,
            "M": 4,  # number of cycles
            "R": 0.5,  # proportion of each cycle during which beta increases
        },
        # Only used by the "plain" network architecture
        "layer_sizes": [128, 64],
        # Only used by the "resnet" network architecture
        "hidden_dim": 512,
        "num_blocks": 2,
    },
]

# One data configuration per Tuebingen cause-effect pair, each subsampled to
# 3000 points. ``file_path`` is None so the data is located by "name"/"id".
# NOTE(review): ids 0..98 are hard-coded; confirm this range matches the
# pair indexing used by the dataset loader.
data_generation_options = [
    {
        "file_path": None,
        "name": "tuebingen",
        "id": pair_id,  # renamed from ``id`` to avoid shadowing the builtin
        "n_samples": 3000,
    }
    for pair_id in range(99)
]

# Give each experiment exactly one observer: drop any previously attached
# observers, then store its runs under a per-method, per-machine directory.
for _experiment, _results_dir in (
    (run_diffusion, f"results/tuebingen_3000_diffusion_{computer_name}"),
    (run_canm, f"results/tuebingen_3000_canm_{computer_name}"),
    (run_baselines, f"results/tuebingen_3000_baselines_{computer_name}"),
):
    _experiment.observers.clear()
    _experiment.observers.append(FileStorageObserver(_results_dir))
# Keep the module namespace identical to the unrolled version.
del _experiment, _results_dir

# Number of times the full (hyperparameters x dataset) sweep is repeated.
runs_per_config = 5

# Every (training hyperparameters, data config) pairing. It is identical for
# each repetition, so it is built once instead of on every outer iteration.
# It is materialised as a list (not a bare ``product`` iterator) so tqdm can
# display a total.
config_combinations = list(
    product(training_hyperparameter_options, data_generation_options)
)

# NOTE(review): the repetition index is not forwarded to the experiments, so
# repeats differ only through whatever seed sacred assigns each run (random
# unless the experiment's config pins ``seed``). Add ``"seed": repetition`` to
# the ``config_updates`` below if reproducible repeats are required.
for repetition in tqdm(range(runs_per_config), desc="Training Iterations"):
    for training_hyperparameters, data_config in tqdm(
        config_combinations,
        desc="Config Combinations",
        leave=False,
        position=1,  # inner bar sits below the outer repetition bar
    ):
        # Diffusion-based method. No validation split; 80/20 train/test split.
        run_diffusion.run(
            config_updates={
                "epochs": 4000,
                "data_config": data_config,
                "training_hyperparameters": training_hyperparameters,
                # Exceeds the 0.8 * 3000 = 2400 training samples, so presumably
                # each epoch is a single full batch.
                "batch_size": 8048,
                "validation_split": 0.0,
                "train_test_split": 0.8,
                "use_best_validation": False,
                "test_evaluation_passes": 10,
                "different_test_sets": False,
            }
        )
        # Baseline methods on the same data configuration.
        run_baselines.run(
            config_updates={
                "data_config": data_config,
            }
        )
        # CANM comparison model.
        run_canm.run(
            config_updates={
                "epochs": 2000,
                "data_config": data_config,
                "depth": 2,
                "training_hyperparameters": training_hyperparameters,
                "batch_size": 128,
            }
        )
