from numpy import sqrt
from run_diffusion import ex as run_diffusion
from run_baselines import ex as run_baselines
from run_canm import ex as run_canm
from itertools import product
from sacred.observers import FileStorageObserver

from tqdm import tqdm
import os

# Identify the machine running this sweep. The name selects the per-machine
# data-seed offset below and is appended to every results directory.
computer_name = os.environ.get("computer_name", "default_name")
print(f"Computer name is: {computer_name}")

# Per-machine offsets added to the repetition index to form the data seed, so
# that machines gpu1..gpu5 draw disjoint seed ranges (gpuK -> 100 * K).
# Machines not listed here fall back to an offset of 0 in the sweep loop.
computer_name_to_seed = {f"gpu{k}": 100 * k for k in range(1, 6)}

# Training hyperparameter grid: every entry is combined with every
# data-generation option in the sweep loop.
training_hyperparameter_options = [
    dict(
        beta_schedule=dict(
            type="linear",
            min=0.0001,
            max=0.02,
            timesteps=256,
        ),
        condition_embed=True,
        condition_dim=4,
        train_loss="hsic_mse_debug",
        test_loss="hsic_mse_debug",
        train_loss_schedule=dict(
            min_beta=0,
            max_beta=0,
            start_beta=0,
            # One of: constant, cyclical(M, R), linear, adaptive,
            # delayed_linear(M, R).
            type="constant",
            mi_threshold=0.001,
            adaptive_factor=1.0001,
            M=4,  # number of cycles
            R=0.5,  # proportion of each cycle used to increase beta
        ),
        # Used by the "plain" architecture.
        layer_sizes=[128, 64],
        # Used by the "resnet" architecture.
        hidden_dim=512,
        num_blocks=2,
    ),
]

# Data-generation templates. The sweep loop overwrites seed, depth, X.length,
# X.type, transformation.type, noise_type and mediator_noise_type for every
# run, so the values given here for those keys are only starting values.
data_generation_options = [
    dict(
        file_path=None,
        dictionary=dict(
            X=dict(type="normal", length=0),
            transformation=dict(
                type="neural_network",
                args=dict(
                    num_hidden=10,
                    num_parents=1,
                    # alpha=1,
                ),
            ),
            shape="multiple_mediators",
            depth=0,
            seed=42,
            standardize=True,
            noise_type="normal",
            noise_parameters=None,
            mediator_noise_type="normal",
            # NOTE(review): mean/std look like normal-distribution parameters,
            # but the sweep replaces mediator_noise_type (e.g. with "uniform");
            # confirm how the generator interprets them for other families.
            mediator_noise_parameters=dict(mean=0, std=sqrt(0.5)),
        ),
    ),
]

# Sweep axes: every combination of these values is run once per repetition.
transform_types = ["tanh"]
noises = ["uniform"]

# Sample sizes 100..1000 in steps of 100, then one larger size of 2000.
lengths = [*range(100, 1001, 100), 2000]
depths = [2]

# Number of repetitions (i.e. distinct data seeds) for each configuration.
runs_per_config = 20

def _use_file_observer(experiment, results_dir):
    """Make *results_dir* the only observer destination of a sacred experiment.

    Any previously attached observers are removed first, so each experiment
    writes exclusively to its own FileStorageObserver directory.
    """
    experiment.observers.clear()
    experiment.observers.append(FileStorageObserver(results_dir))


_use_file_observer(
    run_diffusion, f"results/sample_size_tanh_uniform_diffusion_{computer_name}"
)
_use_file_observer(
    run_canm, f"results/sample_size_tanh_uniform_canm_{computer_name}"
)
_use_file_observer(
    run_baselines, f"results/sample_size_tanh_uniform_baseline_{computer_name}"
)

# Main sweep: for every repetition index, run every combination of
# hyperparameters, data template, depth, sample size, transform and noise.
from copy import deepcopy

# Loop-invariant work hoisted out of the sweep: the combination list (needed as
# a list so tqdm knows its length) and this machine's data-seed offset.
config_combinations = list(
    product(
        training_hyperparameter_options,
        data_generation_options,
        depths,
        lengths,
        transform_types,
        noises,
    )
)
seed_offset = computer_name_to_seed.get(computer_name, 0)

for seed in tqdm(range(runs_per_config), desc="Training Iterations"):
    for (
        training_hyperparameters,
        data_template,
        depth,
        length,
        transform_type,
        noise,
    ) in tqdm(
        config_combinations,
        desc="Config Combinations",
        leave=False,
        position=1,
    ):
        # Work on a private copy so the module-level template in
        # data_generation_options is never mutated, and no two runs ever share
        # (and retroactively alter) the same config dict object.
        data_config = deepcopy(data_template)
        data_dict = data_config["dictionary"]
        # NOTE(review): seeds are only disjoint across machines while
        # runs_per_config stays below the spacing of computer_name_to_seed (100).
        data_dict["seed"] = seed + seed_offset
        data_dict["depth"] = depth
        data_dict["X"]["length"] = length
        data_dict["transformation"]["type"] = transform_type
        # The same distribution family is used for X, the noise and the
        # mediator noise.
        data_dict["noise_type"] = noise
        data_dict["X"]["type"] = noise
        data_dict["mediator_noise_type"] = noise
        # run_diffusion.run(
        #     config_updates={
        #         "epochs": 4000,
        #         "data_config": data_config,
        #         "training_hyperparameters": training_hyperparameters,
        #         "batch_size": 8048,  # Large enough
        #         "validation_split": 0.0,
        #         "train_test_split": 0.8,
        #         "use_best_validation": False,
        #         "test_evaluation_passes": 10,
        #         "different_test_sets": True,
        #     }
        # )
        # NOTE(review): only the data seed is controlled here; the sacred run
        # seed is not passed in config_updates — confirm that is intended.
        run_baselines.run(
            config_updates={
                "data_config": data_config,
            }
        )
