import cyclopts
import ipdb
import os
import numpy as np

from reform.agents import agents, AgentCfg, ReFORMCfg
from reform.trainer.datasets import Dataset, DatasetCfg
from reform.trainer.trainer import TrainerCfg, Trainer
from reform.env import make_env_and_datasets


# Cyclopts application. Functions decorated with ``@app.command`` below are
# registered as CLI commands, and ``app()`` in the ``__main__`` guard parses
# the command line and dispatches to them.
app: cyclopts.App = cyclopts.App()


def _train(
        agent_cfg: AgentCfg,
        trainer_cfg: TrainerCfg,
        dataset_cfg: DatasetCfg = DatasetCfg(),
        env_name: str = "antmaze-medium-navigate-singletask-task1-v0",
        seed: int = 0,
        debug: bool = False
):
    """Build the environment, datasets and agent, then run training.

    Args:
        agent_cfg: Agent configuration. ``agent_cfg.agent_name`` selects the
            agent class from the ``agents`` registry, and the whole config is
            passed to that class's ``create``.
        trainer_cfg: Configuration forwarded to ``Trainer``.
        dataset_cfg: Configuration used to build both the train and the
            validation ``Dataset``. It is also forwarded to ``Trainer``.
        env_name: Name passed to ``make_env_and_datasets``. It is also
            forwarded to ``Trainer``.
        seed: Master seed. It seeds a local NumPy generator, which then draws
            (in this order) the seed for the legacy global NumPy RNG, the
            train dataset seed, the validation dataset seed and the agent
            seed. It is also passed to ``Trainer.train``.
        debug: If True, sets ``WANDB_MODE=disabled`` and disables JIT
            compilation. It is also passed to ``Trainer.train``.
    """
    # Set up environment variables. XLA_PYTHON_CLIENT_PREALLOCATE only takes
    # effect if it is set before JAX initializes its device backend.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    if debug:
        os.environ["WANDB_MODE"] = "disabled"
        # Kept for any child processes that start a fresh interpreter.
        os.environ["JAX_DISABLE_JIT"] = "True"
        # JAX reads JAX_DISABLE_JIT only when it is first imported, and the
        # project modules imported at the top of this file may already have
        # imported it. Updating the live config disables JIT either way.
        import jax
        jax.config.update("jax_disable_jit", True)

    # Derive every downstream seed from one generator so a run is fully
    # determined by `seed`. The draw order below must stay fixed.
    global_rng = np.random.default_rng(seed)
    np.random.seed(global_rng.integers(2 ** 31))

    # Create environment and load datasets.
    env, train_dataset, val_dataset = make_env_and_datasets(env_name, render_mode="rgb_array")
    train_dataset = Dataset.create(dataset_cfg, global_rng.integers(2**31), **train_dataset)
    val_dataset = Dataset.create(dataset_cfg, global_rng.integers(2**31), **val_dataset)

    # Create agent. A single sampled transition supplies example
    # observation/action inputs for the agent's `create`.
    example_batch = train_dataset.sample(1)
    agent_class = agents[agent_cfg.agent_name]
    agent = agent_class.create(
        global_rng.integers(2**31),
        example_batch['observations'],
        example_batch['actions'],
        agent_cfg,
    )

    # Initialize trainer.
    trainer = Trainer(trainer_cfg, env_name, env, train_dataset, val_dataset, dataset_cfg, agent)

    # Start training.
    trainer.train(seed, debug)


@app.command
def reform(
        agent_cfg: ReFORMCfg = ReFORMCfg(),
        trainer_cfg: TrainerCfg = TrainerCfg(),
        dataset_cfg: DatasetCfg = DatasetCfg(),
        env_name: str = "antmaze-medium-navigate-singletask-task1-v0",
        seed: int = 0,
        debug: bool = False
):
    """Train an agent configured by a ReFORM config.

    Parameters
    ----------
    agent_cfg
        ReFORM agent configuration.
    trainer_cfg
        Trainer configuration.
    dataset_cfg
        Dataset configuration shared by the train and validation splits.
    env_name
        Environment / dataset name to load.
    seed
        Master random seed for the run.
    debug
        Disable wandb logging and JIT compilation.
    """
    # Thin CLI wrapper: all of the work happens in the shared training routine.
    _train(
        agent_cfg=agent_cfg,
        trainer_cfg=trainer_cfg,
        dataset_cfg=dataset_cfg,
        env_name=env_name,
        seed=seed,
        debug=debug,
    )


if __name__ == "__main__":
    # Run the CLI. Any uncaught exception opens an ipdb post-mortem session
    # instead of only printing a traceback.
    post_mortem_on_error = ipdb.launch_ipdb_on_exception()
    with post_mortem_on_error:
        app()
