import json
import logging
import os
import pathlib
import random
import sys
import time

import hydra
import omegaconf
import numpy as np
from dotenv import load_dotenv
from rich.pretty import pprint

from vtamp.environments.utils import Environment, Task, Updater
from vtamp.policies.utils import Policy
from vtamp.utils import get_log_dir


log = logging.getLogger(__name__)  # module-level logger used throughout this script

# Load variables from a local .env file into the process environment, if one exists.
load_dotenv()

# Path to a cached LLM output file; no code visible in this script reads it.
llm_out_path = "cashed_llm_out/proc3s_push_line.txt"

class StreamToLogger:
    """File-like object that forwards written text to a logger, one record per line.

    Meant to replace ``sys.stdout`` / ``sys.stderr``. ``print`` issues several
    ``write`` calls per line (each argument, each separator, then ``end``), so
    text is accumulated in ``linebuf`` and only *completed* lines are logged.
    A trailing unfinished line stays buffered until a later write terminates
    it or ``flush()`` is called.

    Args:
        logger: Object with a ``log(level, msg)`` method (normally a
            ``logging.Logger``).
        log_level: Level every forwarded line is logged at.
    """

    def __init__(self, logger, log_level):
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ""  # text written since the last line terminator

    def write(self, buf):
        """Buffer ``buf`` and log every line it completes.

        Each logged line has trailing whitespace (including the line
        terminator) stripped. Returns ``len(buf)``, as ``io.TextIOBase.write``
        does.
        """
        self.linebuf += buf
        pieces = self.linebuf.splitlines(keepends=True)
        # With keepends=True a piece is an unfinished line exactly when it has
        # no terminator, i.e. splitting it again leaves it unchanged.
        if pieces and pieces[-1].splitlines()[0] == pieces[-1]:
            self.linebuf = pieces.pop()
        else:
            self.linebuf = ""
        for piece in pieces:
            self.logger.log(self.log_level, piece.rstrip())
        return len(buf)

    def flush(self):
        """Log any buffered, not-yet-terminated text as its own record."""
        pending, self.linebuf = self.linebuf, ""
        if pending.rstrip():
            self.logger.log(self.log_level, pending.rstrip())


def setup_logger():
    """Send this process's console output through the root logger.

    Adds a handler that writes ERROR-and-above records to ``output.log``
    inside ``get_log_dir()``. It then replaces ``sys.stdout`` and
    ``sys.stderr`` with ``StreamToLogger`` wrappers, so ``print`` output is
    logged at INFO and stderr output at ERROR.

    As a result, ``output.log`` receives stderr text and other ERROR records
    but not ``print`` output. INFO records only reach whatever other handlers
    the root logger already has. Each call adds another file handler.
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    log_file = os.path.join(get_log_dir(), "output.log")

    # FileHandler: only log errors to file (INFO/print output is filtered out
    # by this handler's level).
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.ERROR)
    file_handler.setFormatter(logging.Formatter("%(message)s"))
    root.addHandler(file_handler)

    # Do NOT add a StreamHandler here!
    # NOTE(review): presumably because console handlers are already configured
    # elsewhere (e.g. by Hydra), so another one would duplicate output — confirm.

    # Redirect stdout and stderr into the root logger.
    sys.stdout = StreamToLogger(root, logging.INFO)  # print() as INFO
    sys.stderr = StreamToLogger(root, logging.ERROR)  # errors as ERROR

def _record_step_stats(statistics, step, values):
    """Copy every ``values`` entry into ``statistics`` as ``step_<step>_<key>``."""
    for key, value in values.items():
        statistics[f"step_{step}_{key}"] = value


def _run_komo(cfg, env, updater, policy, obs, statistics):
    """Plan once with the policy and execute the result via ``env.step_komo``.

    Accumulates ``planning_time`` / ``execution_time`` into ``statistics``.
    """
    goal = env.task.get_goal()
    log.info("Goal: " + str(goal))
    belief = updater.update(obs)
    log.info("Scene: " + str(belief))

    start = time.time()
    komo, _step_statistics = policy.get_action(belief, goal)
    statistics["planning_time"] += time.time() - start
    # log.info("KOMO: " + str(komo))

    start = time.time()
    obs, reward, done, info = env.step_komo(komo, vis=True)
    statistics["execution_time"] += time.time() - start

    if cfg.render:
        env.render()

    log.info("Reward: " + str(reward))
    log.info("Done: " + str(done))
    log.info("Info: " + str(info))


def _run_stepwise(cfg, env, updater, policy, obs, statistics):
    """Alternate planning and single-action execution for up to ``max_env_steps`` steps.

    Stops early when the policy returns ``None``. Per-step policy statistics
    and env ``info`` entries are stored in ``statistics`` as
    ``step_<i>_<key>``.
    """
    for i in range(cfg.get("max_env_steps")):
        log.info("Step " + str(i))
        goal = env.task.get_goal()
        log.info("Goal: " + str(goal))
        belief = updater.update(obs)
        # log.info("Scene: " + str(belief))

        start = time.time()
        # action, step_statistics = policy.get_action(belief, goal, "cashed_llm_out/deneck_push_line.txt")
        action, step_statistics = policy.get_action(belief, goal)
        statistics["planning_time"] += time.time() - start
        _record_step_stats(statistics, i, step_statistics)
        log.info("Action: " + str(action))
        if action is None:
            break

        start = time.time()
        obs, reward, done, info = env.step(action, vis=False)
        statistics["execution_time"] += time.time() - start
        _record_step_stats(statistics, i, info)

        if cfg.render:
            env.render(False)

        log.info("Reward: " + str(reward))
        # NOTE(review): `done` is only logged; the loop does not stop on it.
        # Termination relies on the policy returning None or on max_env_steps.
        log.info("Done: " + str(done))
        log.info("Info: " + str(info))


@hydra.main(
    version_base=None,
    config_path=str(pathlib.Path(__file__).parent.joinpath("vtamp", "config")),
)
def main(cfg: omegaconf.DictConfig):
    """Build task, updater, env, twin env and policy from ``cfg``, then run them.

    Runs either one KOMO plan (``use_komo``) or a step-by-step plan/execute
    loop. Config keys read:

    - ``seed`` (optional)
    - ``use_komo`` (optional)
    - ``task``, ``updater``, ``env``, ``policy``
    - ``render``, ``vis_debug``
    - ``max_env_steps`` (stepwise mode only)

    The environment is always closed, even if planning or execution raises.
    """
    log.info(" ".join(sys.argv))

    setup_logger()

    # Read the seed once so seeding and the policy agree, and so a missing
    # key yields None instead of raising.
    seed = cfg.get("seed")
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    use_komo = cfg.get("use_komo")

    log.info("Setting up environment and policy...")
    pprint(cfg)
    task: Task = hydra.utils.instantiate(cfg.task)
    updater: Updater = hydra.utils.instantiate(cfg.updater)
    env: Environment = hydra.utils.instantiate(
        cfg.env, task=task, render=cfg.render and not cfg.vis_debug, use_komo=use_komo,
    )
    try:
        obs = env.reset()
        belief = updater.update(obs)

        twin_env: Environment = hydra.utils.get_class(cfg.env._target_).sample_twin(
            env, belief, task, render=cfg.vis_debug
        )
        policy: Policy = hydra.utils.instantiate(
            cfg.policy,
            cost_thresh=task.cost_threshold,
            twin=twin_env,
            seed=seed,
            use_komo=use_komo,
        )

        statistics = {"execution_time": 0, "planning_time": 0}
        if use_komo:
            _run_komo(cfg, env, updater, policy, obs, statistics)
        else:
            _run_stepwise(cfg, env, updater, policy, obs, statistics)

        # NOTE(review): rendered regardless of cfg.render — confirm intended.
        env.render()
    finally:
        env.close()
    # log.info("Statistics: " + str(json.dumps(statistics)))


if __name__ == "__main__":
    # The hydra.main decorator parses CLI overrides and supplies `cfg`.
    main()
