import json
import os
import shutil
import subprocess
import time

import numpy as np


def run_algorithm(algo_name, algo_args):
    """Run one training script in a subprocess and time it.

    The script is launched as ``python3 experiments/atari/<algo_name>.py``
    (a path relative to the current working directory) with
    ``--experiment_name _time_<algo_name>_Pong`` followed by ``algo_args``.
    The matching output folder, resolved relative to this file, is removed
    before the run (so each timing starts fresh) and after it (so nothing is
    left behind), including when the run fails or the launch raises.

    Args:
        algo_name: Name of the training script, without ``.py``, under
            ``experiments/atari/``.
        algo_args: Extra command-line arguments as one whitespace-separated
            string.

    Returns:
        The elapsed training time in seconds, or ``None`` if the training
        subprocess exited with a non-zero return code.
    """
    print(f"\n\n\n--------------- Time {algo_name} ---------------", flush=True)
    save_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), f"../../experiments/atari/exp_output/_time_{algo_name}_Pong"
    )
    # Start from a clean output folder so a previous run cannot influence this one.
    if os.path.exists(save_path):
        shutil.rmtree(save_path)

    # str.split() with no argument collapses runs of whitespace, so stray double or
    # trailing spaces in algo_args do not become empty "" command-line arguments.
    command = f"python3 experiments/atari/{algo_name}.py --experiment_name _time_{algo_name}_Pong {algo_args}".split()
    try:
        # perf_counter is monotonic (unlike time.time), so system clock changes
        # during a long training run cannot distort the measured duration.
        time_begin = time.perf_counter()
        returncode = subprocess.run(command).returncode
        elapsed = time.perf_counter() - time_begin
    finally:
        # A failed run may never have created the folder; only remove what exists.
        if os.path.exists(save_path):
            shutil.rmtree(save_path)

    if returncode != 0:
        print(
            f"Training {algo_name} should not have raised an error. The training time has to be recomputed.", flush=True
        )
        return None

    print(f"{algo_name} trained in {np.around(elapsed)} seconds.", flush=True)
    return elapsed


if __name__ == "__main__":
    # Command-line arguments shared by both timed algorithms; each run appends
    # its own algorithm-specific flags below.
    base_args = (
        "--seed 1 --disable_wandb --features 32 64 64 512 --replay_buffer_capacity 1_000_000 --batch_size 32 "
        + "--update_horizon 1 --learning_rate 6.25e-5 --horizon 27_000 --n_epochs 2 "
        + "--n_training_steps_per_epoch 250_000 --update_to_data 4 --target_update_frequency 8000 "
        + "--n_initial_samples 20_000 --epsilon_duration 1 --epsilon_end 0.01"
    )

    time_metadqn = run_algorithm(
        "metadqn",
        base_args + " --gamma_init 0.99 --gamma_validation 0.99 --meta_learning_rate 0.001",
    )
    time_adadqn = run_algorithm(
        "adadqn",
        base_args
        + " --gamma_range 0.985 0.995 --gamma_validation 0.99 --n_networks 5 --exploitation_type elitism --hp_update_frequency 80000",
    )

    # Durations in seconds; a failed run (run_algorithm returned None) is written as null.
    # The path is relative to the current working directory, like the scripts launched above.
    # Use a context manager so the file is flushed and closed deterministically.
    with open("tests/time_computation/time_algorithms.json", "w", encoding="utf-8") as results_file:
        json.dump(
            {"metadqn": time_metadqn, "adadqn": time_adadqn},
            results_file,
            indent=4,
        )
