import fire

import os
# print(os.getcwd())
# os.chdir(os.getcwd())

from offlinerl.algo import algo_select
from offlinerl.data.d4rl_mabe import load_d4rl_buffer
from offlinerl.evaluation import OnlineCallBackFunction


def run_algo(**kwargs):
    """Build an algorithm from keyword options and train it on two domains.

    ``kwargs`` is forwarded unchanged to ``algo_select``, which returns the
    algorithm's init function, trainer constructor and config. The config's
    ``"task"``, ``"data_path"`` and ``"isMediumExpert"`` entries are passed to
    ``load_d4rl_buffer`` to load the training data; no validation buffer is used.

    Training runs 200 rounds. Rounds 0-99 train on element 0 of the loaded
    data (the source domain) and rounds 100-199 on element 1 (the target
    domain). Each round passes its index and the switch point to
    ``train`` as ``num`` and ``domain_N``.
    """
    init_fn, trainer_cls, config = algo_select(kwargs)
    buffers = load_d4rl_buffer(config["task"], config["data_path"], config["isMediumExpert"])
    trainer = trainer_cls(init_fn(config), config)

    callback = OnlineCallBackFunction()
    callback.initialize(train_buffer=buffers, val_buffer=None, task=config["task"])

    # NOTE(review): the loader's result is indexed as a (source, target)
    # pair; confirm this against load_d4rl_buffer.
    source_buffer = buffers[0]
    target_buffer = buffers[1]

    domain_N = 100  # rounds spent on the source domain before switching
    total_rounds = 200
    for num in range(total_rounds):
        if num >= domain_N:
            print("\nTarget domain num:{}".format(num))
            active_buffer = target_buffer
        else:
            print("\nSource domain num:{}".format(num))
            active_buffer = source_buffer
        trainer.train(active_buffer, None, callback_fn=callback, num=num, domain_N=domain_N)
          

if __name__ == "__main__":
    # Expose run_algo as a command-line tool: Fire turns each --flag into a
    # keyword argument, which run_algo collects in **kwargs.
    fire.Fire(component=run_algo)
    