import fire

from offlinerl.algo import algo_select
from offlinerl.data.d4rl import load_d4rl_buffer
from offlinerl.evaluation import OnlineCallBackFunction
import torch
# Debug aid: print whether PyTorch reports a usable CUDA device ("True"/"False").
# NOTE(review): this executes on import as well as when run as a script.
print(f"{torch.cuda.is_available()}")

def run_algo(**kwargs):
    """Build the selected offline-RL algorithm and train it on a D4RL buffer.

    Every keyword argument (for example, the command-line flags forwarded by
    ``fire``) is passed as one dict to ``algo_select``. That call returns three
    things: the algorithm's init function, its trainer factory, and the
    resolved config.

    The training buffer is loaded from the config's ``task``, ``data_path``,
    ``isMediumExpert`` and ``data_proportion`` entries, in that positional
    order. Training runs with no validation buffer. It reports through an
    ``OnlineCallBackFunction`` (from ``offlinerl.evaluation``) that has been
    initialized with the training buffer and task.

    Returns:
        None.
    """
    init_fn, trainer_cls, config = algo_select(kwargs)

    # Positional arguments for load_d4rl_buffer, in the order it expects them.
    buffer_args = (
        config["task"],
        config["data_path"],
        config["isMediumExpert"],
        config["data_proportion"],
    )
    buffer = load_d4rl_buffer(*buffer_args)

    algo_init = init_fn(config)
    trainer = trainer_cls(algo_init, config)

    eval_callback = OnlineCallBackFunction()
    # "task" is read from the config again here (not cached above), matching
    # the original lookup order in case construction steps touch the config.
    eval_callback.initialize(
        train_buffer=buffer,
        val_buffer=None,
        task=config["task"],
    )

    trainer.train(buffer, None, callback_fn=eval_callback)

if __name__ == "__main__":
    # Expose run_algo as a command-line interface; fire turns each
    # --flag=value pair into a keyword argument of run_algo.
    fire.Fire(component=run_algo)

# Example invocation:
#   python3 examples/train_d4rl.py --algo_name=combo --exp_name=d4rl-walker2d-random-v0-combo --task=d4rl-walker2d-random-v0 --seed=0 --isMediumExpert=False --data_path=None
