import datetime
import warnings

import embodied
import ruamel.yaml as yaml

import car_dreamer
import dreamerv3

# Silence every warning whose message matches "truncated to dtype int32"
# (presumably the dtype-truncation notices emitted by the numeric backend
# during training — TODO confirm the exact source).
warnings.filterwarnings("ignore", ".*truncated to dtype int32.*")


def wrap_env(env, config):
    """Wrap ``env`` in the stack of ``embodied`` wrappers selected by ``config``.

    The wrappers are applied in this order:

    1. ``SetEWrapper`` and ``InfoWrapper`` unconditionally.
    2. One action-encoding wrapper per action key other than ``"reset"``
       (see ``encode_action`` below for the selection rule).
    3. ``ExpandScalars`` unconditionally.
    4. ``TimeLimit`` when ``config.wrapper.length`` is truthy.
    5. ``CheckSpaces`` when ``config.wrapper.checks`` is truthy.
    6. ``ClipAction`` for every action key whose space is not discrete.

    Args:
        env: Environment exposing an ``act_space`` mapping of name -> space.
        config: Model config providing ``wrapper`` (with ``discretize``,
            ``length``, ``reset`` and ``checks``) and ``actor_dist_disc``.

    Returns:
        The fully wrapped environment.
    """
    wrapper_cfg = config.wrapper
    wrappers = embodied.wrappers

    def encode_action(wrapped, key, act):
        """Return ``wrapped`` with the action encoding chosen for ``key``."""
        if config.actor_dist_disc == "twohot":
            # TODO: special case for hierarchy
            return wrappers.TwoHotAction(wrapped, key)
        if act.discrete:
            return wrappers.OneHotAction(wrapped, key)
        if wrapper_cfg.discretize:
            return wrappers.DiscretizeAction(wrapped, key, wrapper_cfg.discretize)
        return wrappers.NormalizeAction(wrapped, key)

    env = wrappers.SetEWrapper(env)
    env = wrappers.InfoWrapper(env)

    # The action spaces are read once, before any encoding wrapper is added.
    for key, act in env.act_space.items():
        if key != "reset":
            env = encode_action(env, key, act)

    env = wrappers.ExpandScalars(env)
    if wrapper_cfg.length:
        env = wrappers.TimeLimit(env, wrapper_cfg.length, wrapper_cfg.reset)
    if wrapper_cfg.checks:
        env = wrappers.CheckSpaces(env)

    # Clipping is decided from the action spaces of the already-wrapped env.
    for key, act in env.act_space.items():
        if act.discrete:
            continue
        env = wrappers.ClipAction(env, key)
    return env


def main(argv=None):
    """Train a DreamerV3 agent on a CarDreamer task.

    Reads the ``defaults`` and ``small`` presets from ``dreamerv3.yaml`` next
    to this file, merges in the config of each requested task, applies the
    remaining command-line flags, then builds the logger, wrapped environment,
    agent and replay buffer and hands them to ``embodied.run.train``.

    Args:
        argv: Optional list of command-line arguments. It is forwarded to the
            flag parsers and to ``car_dreamer.create_task``.
    """
    preset_path = embodied.Path(__file__).parent / "dreamerv3.yaml"
    presets = yaml.YAML(typ="safe").load(preset_path.read())
    config = embodied.Config({"dreamerv3": presets["defaults"]}).update({"dreamerv3": presets["small"]})

    task_flags, remaining = embodied.Flags(task=["carla_navigation"]).parse_known(argv)
    # NOTE: ``env`` is rebound on every iteration, so only the environment of
    # the last listed task is used below; all task configs are merged.
    for task_name in task_flags.task:
        print("Using task: ", task_name)
        env, task_config = car_dreamer.create_task(task_name, argv)
        config = config.update(task_config)
    config = embodied.Flags(config).parse(remaining)

    logdir = embodied.Path(config.dreamerv3.logdir)
    step = embodied.Counter()
    outputs = [
        embodied.logger.TerminalOutput(),
        embodied.logger.JSONLOutput(logdir, "metrics.jsonl"),
        embodied.logger.TensorBoardOutput(logdir),
    ]
    logger = embodied.Logger(step, outputs)

    from embodied.envs import from_gym

    base_model_config = config.dreamerv3
    # HANSOME configs: overrides layered on top of the loaded model config.
    hansome_overrides = {
        "n_commands": config.env.action.n_commands,
        "actor_dist_disc": "twohot",
        "command_horizon": 16,
        "actor.unimix": [0.01, 0.01],
        "actor.maxstd": [1.0, 1.0],
        "actor.minstd": [15, 3],
    }
    model_config = base_model_config.update(hansome_overrides)
    config = config.update({"dreamerv3": model_config})

    gym_env = from_gym.FromGym(env)
    wrapped_env = wrap_env(gym_env, model_config)
    env = embodied.BatchEnv([wrapped_env], parallel=False)

    # Save the full resolved config under a timestamped name in the log dir.
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    config_path = logdir / f"config_{timestamp}.yaml"
    config.save(str(config_path))
    print(f"[Train] Config saved to {config_path}")

    agent = dreamerv3.Agent(env.obs_space, env.act_space, step, model_config)
    replay = embodied.replay.Uniform(
        model_config.batch_length,
        model_config.replay_size,
        logdir / "replay",
    )
    run_args = embodied.Config(
        **model_config.run,
        logdir=model_config.logdir,
        batch_steps=model_config.batch_size * model_config.batch_length,
        actor_dist_disc=model_config.actor_dist_disc,
    )
    embodied.run.train(agent, env, replay, logger, run_args)


# Script entry point: start training with argv left at its default (None).
if __name__ == "__main__":
    main()
