import logging
import pathlib
import random
import sys

import hydra
import numpy as np
from omegaconf import OmegaConf

from vtamp.environments.utils import Environment, Task, Updater
from vtamp.policies.utils import Policy

log = logging.getLogger(__name__)


class StreamToLogger:
    """Minimal write-only, file-like object that forwards text to a logger.

    Used as a replacement for ``sys.stdout`` / ``sys.stderr`` so that text
    written to those streams ends up as log records at a fixed level.
    """

    def __init__(self, logger, log_level):
        """Store the target logger and the level used for every record.

        Args:
            logger: Object exposing ``log(level, msg)``, e.g. a
                ``logging.Logger``.
            log_level: Level passed to ``logger.log`` for each line.
        """
        self.logger = logger
        self.log_level = log_level
        # Not read by write(); kept so existing attribute access still works.
        self.linebuf = ""

    def write(self, buf):
        """Log each line of ``buf`` and return the number of characters taken.

        Trailing whitespace is stripped from the chunk before it is split
        into lines, and again from every line, so a chunk consisting only of
        a newline (as ``print`` emits after its text) produces no record.

        Args:
            buf: Text chunk written to the stream.

        Returns:
            ``len(buf)``, mirroring the text-stream ``write`` contract so
            callers that inspect the return value see the chunk as accepted.
        """
        for line in buf.rstrip().splitlines():
            self.logger.log(self.log_level, line.rstrip())
        return len(buf)

    def flush(self):
        """Do nothing; every write is handed to the logger immediately."""

    def isatty(self):
        """Report that this stream is not an interactive terminal."""
        return False


def setup_logger():
    """Echo root-logger records to the console and capture stdout/stderr.

    Attaches a plain ``%(message)s`` stream handler to the root logger and
    then replaces ``sys.stdout`` / ``sys.stderr`` with ``StreamToLogger``
    wrappers (DEBUG and ERROR level respectively).

    The function is safe to call more than once in the same process: the
    handler is attached at most once, it never targets a ``StreamToLogger``
    (which would feed records back into the logger), and streams that are
    already redirected are left alone.
    """
    log_level = logging.DEBUG
    logger = logging.getLogger()
    handler_name = "setup_logger.stdout"

    # Pick the stream the handler writes to. If stdout has already been
    # redirected by a previous call, fall back to the interpreter's original
    # stdout instead of wrapping our own redirect.
    console = sys.stdout
    if isinstance(console, StreamToLogger):
        console = sys.__stdout__

    already_attached = any(h.get_name() == handler_name for h in logger.handlers)
    # console can be None when the interpreter started without a stdout;
    # StreamHandler(None) would then fall back to sys.stderr, so skip it.
    if console is not None and not already_attached:
        ch = logging.StreamHandler(console)
        ch.set_name(handler_name)
        ch.setLevel(log_level)
        ch.setFormatter(logging.Formatter("%(message)s"))
        logger.addHandler(ch)

    # Redirect stdout and stderr (only once per stream).
    if not isinstance(sys.stdout, StreamToLogger):
        sys.stdout = StreamToLogger(logger, log_level)
    if not isinstance(sys.stderr, StreamToLogger):
        sys.stderr = StreamToLogger(logger, logging.ERROR)


@hydra.main(
    version_base=None,
    config_path=str(pathlib.Path(__file__).parent.joinpath("vtamp", "config")),
)
def main(cfg: OmegaConf):
    """Run one rollout of the configured policy in the configured environment.

    Instantiates the task, updater, environment and policy from ``cfg``,
    builds a twin environment from the initial belief, and then steps the
    real environment for at most ``cfg.max_env_steps`` iterations or until
    the policy returns ``None`` as its action.

    Args:
        cfg: Hydra-provided configuration. Reads ``task``, ``updater``,
            ``env``, ``policy``, ``render``, ``vis_debug``,
            ``max_env_steps`` and the optional ``seed``.
    """
    log.info(" ".join(sys.argv))

    setup_logger()

    # Read the optional seed once so seeding and the policy agree on it, and
    # a config without a seed does not fail on a direct key lookup later.
    seed = cfg.get("seed")
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    log.info("Setting up environment and policy...")
    task: Task = hydra.utils.instantiate(cfg.task)
    updater: Updater = hydra.utils.instantiate(cfg.updater)
    env: Environment = hydra.utils.instantiate(
        cfg.env, task=task, render=cfg.render and not cfg.vis_debug
    )

    # From here on the environment exists, so make sure it is closed even if
    # setup or a step raises.
    try:
        obs = env.reset()

        belief = updater.update(obs)

        twin_env: Environment = hydra.utils.get_class(
            cfg.env._target_
        ).sample_twin(env, belief, task, render=cfg.vis_debug)
        policy: Policy = hydra.utils.instantiate(
            cfg.policy, twin=twin_env, seed=seed
        )

        for i in range(cfg.get("max_env_steps")):
            log.info("Step %s", i)
            goal = env.task.get_goal()
            log.info("Goal: %s", goal)
            belief = updater.update(obs)
            log.info("Scene: %s", belief)
            action = policy.get_action(belief, goal)
            log.info("Action: %s", action)
            if action is None:
                # The policy has no further action to propose; stop early.
                break
            obs, reward, done, info = env.step(action)

            if cfg.render:
                env.render()

            # NOTE(review): `done` is only logged; the loop keeps stepping
            # until max_env_steps or a None action. Confirm this is intended.
            log.info("Reward: %s", reward)
            log.info("Done: %s", done)
            log.info("Info: %s", info)
    finally:
        env.close()


# Script entry point. main is wrapped by @hydra.main, which supplies the cfg
# argument, so it is called without arguments here.
if __name__ == "__main__":
    main()
