from UtilsRL.exp import setup, parse_args
from UtilsRL.logger import TensorboardLogger
from offlinerl.algo.redm_offline import REDMTrainer

# --- configuration -----------------------------------------------------------
# Parse the REDM offline config module (plus any command-line overrides) into
# the `args` mapping used by the rest of this script.
import offlinerl.config.redm_offline as redm_config

args = parse_args(redm_config, convert=True)

# --- logging -----------------------------------------------------------------
# TensorBoard logger rooted at args["tb_log_path"] and named after the
# experiment; txt=True presumably also mirrors output to a text log (confirm
# against UtilsRL's TensorboardLogger).
logger = TensorboardLogger(
    txt=True,
    log_path=args["tb_log_path"],
    name=args["exp_name"],
)

# --- experiment setup --------------------------------------------------------
# setup() receives the parsed args and the logger; its return value replaces
# `args` for everything below.
args = setup(args, logger)

# When test_mode is on, cap every training phase at a tiny epoch budget
# (presumably for quick smoke runs of the whole pipeline).
if args["test_mode"]:
    for section, option, value in (
        ("Dynamics", "max_epoch", 3),
        ("Candidate", "train_epoch", 1),
        ("Meta", "train_epoch", 1),
        ("BC", "train_epoch", 1),
    ):
        args[section][option] = value

# --- trainer and stage tables -------------------------------------------------
# Every stage routine below is called with exactly one argument:
# args[<stage> + "_path"].
trainer = REDMTrainer(args)
train_stage = {
    "dynamics": trainer.train_dynamics,
    "bc_policy": trainer.train_bc_policy,
    "mainloop": trainer.train_mainloop,
}
# Only these stages can be restored from disk. There is no load routine for
# "mainloop", so training must start at "mainloop" at the latest.
# (The table previously held an unreachable "meta_policy" entry that pointed
# at trainer.train_mainloop, a train routine rather than a loader.)
load_stage = {
    "dynamics": trainer.load_dynamics,
    "bc_policy": trainer.load_bc_policy,
}

# Pipeline stages, in execution order.
STAGES = ("bc_policy", "dynamics", "mainloop")

# "candidate_model" and "meta_policy" are folded into the "mainloop" stage,
# so they are accepted as aliases for it in both args["from"] and args["to"].
MAINLOOP_ALIASES = ("candidate_model", "meta_policy")
for bound_key in ("from", "to"):
    if args[bound_key] in MAINLOOP_ALIASES:
        args[bound_key] = "mainloop"

# Every stage before args["from"] is loaded and the rest are trained. An
# unknown start stage would otherwise try to *load* "mainloop" and fail with
# an opaque KeyError, so reject it up front with a clear message.
start_stage = args["from"]
if start_stage not in STAGES:
    raise ValueError(
        f"args['from'] must be one of {STAGES} "
        f"(or an alias in {MAINLOOP_ALIASES}), got {start_stage!r}"
    )
start_index = STAGES.index(start_stage)

# --- run the pipeline ----------------------------------------------------------
# Stop right after the stage named by args["to"]. If args["to"] matches no
# stage, all stages run.
for stage_index, stage in enumerate(STAGES):
    stage_path = args[stage + "_path"]
    if stage_index < start_index:
        load_stage[stage](stage_path)
    else:
        train_stage[stage](stage_path)
    if stage == args["to"]:
        break



