import re
import warnings

import embodied
import numpy as np
import ruamel.yaml as yaml

import car_dreamer
import dreamerv3

warnings.filterwarnings("ignore", ".*truncated to dtype int32.*")


def wrap_env(env, config):
    args = config.wrapper
    env = embodied.wrappers.SetEWrapper(env)
    env = embodied.wrappers.InfoWrapper(env)
    for name, space in env.act_space.items():
        if name == "reset":
            continue
        elif config.actor_dist_disc == "twohot":
            # TODO: special case for hierarcy
            env = embodied.wrappers.TwoHotActionMA(env, name)
        elif args.discretize:
            env = embodied.wrappers.DiscretizeAction(env, name, args.discretize)
        else:
            env = embodied.wrappers.NormalizeAction(env, name)
    if args.length:
        env = embodied.wrappers.TimeLimit(env, args.length, args.reset)
    if args.checks:
        env = embodied.wrappers.CheckSpaces(env)
    for name, space in env.act_space.items():
        if not space.discrete:
            env = embodied.wrappers.ClipAction(env, name)
    return env


def eval_only(agents, env, logger, args):
    print("Start evaluation.")
    print("args:", args)
    logdir = embodied.Path(args.logdir)
    logdir.mkdirs()
    print("Logdir", logdir)
    step = logger.step
    metrics = embodied.Metrics()
    print("Observation space:", env.obs_space)
    print("Action space:", env.act_space)

    timer = embodied.Timer()
    for agent in agents:
        timer.wrap("agent", agent, ["policy"])
    timer.wrap("env", env, ["step"])
    timer.wrap("logger", logger, ["write"])

    nonzeros = set()

    def per_episode(ep, ep_info, i):
        length = len(ep["reward"]) - 1
        score = float(ep["reward"].astype(np.float64).sum())
        logger.add({"length": length, "score": score}, prefix="episode")
        print(f"Episode has {length} steps and return {score:.1f}.")
        stats = {}
        for key in args.log_keys_video:
            if key in ep:
                stats[f"policy_{key}{i}"] = ep[key]

        def log(key, value):
            if re.search(args.log_keys_sum, key):
                stats[f"sum_{key}"] = value.sum()
            if re.search(args.log_keys_mean, key):
                stats[f"mean_{key}"] = value.mean()
            if re.search(args.log_keys_max, key):
                stats[f"max_{key}"] = value.max(0).mean()

        for key, value in ep.items():
            if not args.log_zeros and key not in nonzeros and (value == 0).all():
                continue
            nonzeros.add(key)
            log(key, value)
        for key, value in ep_info.items():
            log(key, value)

        logger.add(metrics.result())
        logger.add(timer.stats(), prefix="timer")
        logger.write(fps=True)

        metrics.add(stats, prefix="stats")

    def per_step(tran, info, agent_idx):
        if agent_idx == 0:
            step.increment()

    env.set_e_parameters(
        command_repeat_counts=16,
        enable_replan=True,
        num_vehicles=200,
    )
    agent.set_m_parameters(
        m_horizon=1,
    )

    driver = embodied.MADriver(env, args.num_agents)
    driver.on_episode(lambda ep, ep_info, agent_idx: per_episode(ep, ep_info, agent_idx))
    driver.on_step(lambda tran, info, agent_idx: per_step(tran, info, agent_idx))

    checkpoint = embodied.Checkpoint()
    checkpoint.agent = agent
    if args.from_checkpoint:
        checkpoint.load(args.from_checkpoint, keys=["agent"])
    else:
        raise ValueError("No checkpoint specified.")

    print("Start evaluation loop.")
    policy = lambda *args: agent.policy(*args, mode="eval")
    while step < args.steps:
        driver(policy, steps=100)
    logger.write()


def main(argv=None):

    model_configs = yaml.YAML(typ="safe").load((embodied.Path(__file__).parent / "dreamerv3.yaml").read())
    config = embodied.Config({"dreamerv3": model_configs["defaults"]})
    config = config.update({"dreamerv3": model_configs["small"]})

    parsed, other = embodied.Flags(task=["carla_message_ma"]).parse_known(argv)
    for name in parsed.task:
        print("Using task: ", name)
        from car_dreamer import CarlaMessageMaEnv

        env_config = car_dreamer.load_task_configs(name)
        env_config, _ = car_dreamer.toolkit.Flags(env_config).parse_known(argv)
        env = CarlaMessageMaEnv(env_config.env)
    config = config.update(env_config)
    config = embodied.Flags(config).parse(other)

    logdir = embodied.Path(config.dreamerv3.logdir)
    step = embodied.Counter()
    logger = embodied.Logger(
        step,
        [
            embodied.logger.TerminalOutput(),
            embodied.logger.JSONLOutput(logdir, "metrics.jsonl"),
            embodied.logger.TensorBoardOutput(logdir),
        ],
    )

    dreamerv3_config = config.dreamerv3
    dreamerv3_config = dreamerv3_config.update(
        {
            "run.log_keys_sum": ".*travel_distance.*|.*destination_reached.*|.*out_of_lane.*|.*time_exceeded.*|.*is_collision.*|.*timesteps.*",
            "run.log_keys_mean": ".*travel_distance.*|.*speed_norm.*|.*wpt_dis.*",
            "run.log_keys_max": ".*travel_distance.*|.*speed_norm.*|.*wpt_dis.*",
            "run.steps": 3e4,
        }
    )
    print("Config:", dreamerv3_config)

    from embodied.envs import from_gym

    env = from_gym.FromGym(env)
    env = wrap_env(env, dreamerv3_config)
    env = embodied.BatchAgentEnv([env], num_agents=config.env.num_agents)

    agents = []
    for agent_idx in range(config.env.num_agents):
        agent_config = dreamerv3_config.update(
            {
                "jax.policy_devices": [agent_idx],
            }
        )
        agents.append(dreamerv3.Agent(env.obs_space, env.act_space, step, agent_config))
    args = embodied.Config(
        **dreamerv3_config.run,
        logdir=dreamerv3_config.logdir,
        batch_steps=dreamerv3_config.batch_size * dreamerv3_config.batch_length,
        num_agents=config.env.num_agents,
    )
    eval_only(agents, env, logger, args)


if __name__ == "__main__":
    main()
