import os
import pickle
import time
from math import sqrt

from torch.cuda import is_available as is_cuda_available

from args import DatasetConfig, SeedConfig, parse_args_to_dataclass
from bandit2.bandit_env import generate_bandit_trajectories
from mdp.chain_env import generate_trajectories as generate_chain_trajectories
from mdp.darkroom_env import generate_trajectories as generate_darkroom_trajectories
from util.seed import set_seed


def _split_env_counts(n_envs):
    """Split ``n_envs`` into ``(train, test, eval)`` counts with a 70/20/10 ratio.

    Exact integer arithmetic is used on purpose. ``int(0.7 * n)`` can truncate
    one below the intended value because 0.7 is not exactly representable
    (``0.7 * 90`` evaluates to ``62.99999999999999``). That would silently move
    an environment from the train split into the eval split. The eval split
    absorbs the remainder, so the three counts always sum to ``n_envs``.
    """
    n_train = n_envs * 7 // 10
    n_test = n_envs * 2 // 10
    n_eval = n_envs - n_train - n_test
    return n_train, n_test, n_eval


def main(seed_config: SeedConfig, dataset_config: DatasetConfig):
    """Generate train/test/eval trajectory datasets and pickle them under ``datasets/``.

    ``dataset_config.n_envs`` environments are split 70/20/10 into train, test
    and eval. For each split, the generator matching ``dataset_config.env``
    is called and its result is pickled to
    ``datasets/{dataset_config.get_summary()}_{split}.pkl``.

    Args:
        seed_config: Provides ``seed``, passed to ``set_seed`` once before any
            split is generated.
        dataset_config: Provides ``n_envs``, ``env``, ``context_len``, and the
            env-specific fields (``n_actions``, ``n_states``, ``variance``)
            read below.

    Raises:
        ValueError: If ``env == "darkroom"`` and ``n_states`` is not a perfect
            square (the darkroom generator takes a square side length).
        NotImplementedError: If ``env`` is not one of the supported environments.
    """
    n_envs_train, n_envs_test, n_envs_eval = _split_env_counts(dataset_config.n_envs)
    # Passing None lets the generators pick their own default device.
    device = "cuda" if is_cuda_available() else None

    os.makedirs("datasets", exist_ok=True)

    # Seed once, so all three splits come from a single reproducible stream.
    set_seed(seed_config.seed)

    filename = dataset_config.get_summary()
    splits = (("train", n_envs_train), ("test", n_envs_test), ("eval", n_envs_eval))
    for split_name, n_envs_split in splits:
        pbar_desc = f"{split_name.capitalize()} Split"

        if dataset_config.env == "bandit":
            dataset = generate_bandit_trajectories(
                n_envs_split, dataset_config.context_len, dataset_config.n_actions,
                dataset_config.variance, device=device, pbar_desc=pbar_desc,
            )
        elif dataset_config.env == "chain":
            dataset = generate_chain_trajectories(
                n_envs_split, dataset_config.context_len, dataset_config.n_states,
                dataset_config.variance, device=device, pbar_desc=pbar_desc,
            )
        elif dataset_config.env == "darkroom":
            square_len = int(sqrt(dataset_config.n_states))
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if square_len * square_len != dataset_config.n_states:
                raise ValueError(
                    f"darkroom requires n_states to be a perfect square, got {dataset_config.n_states}"
                )
            dataset = generate_darkroom_trajectories(
                n_envs_split, dataset_config.context_len, square_len,
                device=device, pbar_desc=pbar_desc,
            )
        else:
            raise NotImplementedError(f"Unsupported environment: {dataset_config.env!r}")

        path = f"datasets/{filename}_{split_name}.pkl"
        with open(path, "wb") as f:
            pickle.dump(dataset, f)
        print(f"Saved file to '{path}'.")


if __name__ == "__main__":
    # Assumes parse_args_to_dataclass returns one instance per class, in tuple
    # order, matching main()'s (seed_config, dataset_config) parameters.
    args = parse_args_to_dataclass((SeedConfig, DatasetConfig))

    # Echo the resolved configuration, one config object per line.
    print(*args, sep="\n")

    # Use perf_counter rather than time.time(): it is monotonic and
    # high-resolution, so system clock adjustments cannot skew the elapsed time.
    time_start = time.perf_counter()
    main(*args)
    time_end = time.perf_counter()

    print(f"Total runtime: {time_end - time_start:.2f} s")
