import os
import sys
import time
import json
import random
import logging
from tqdm import tqdm

import hydra
import numpy as np
from dotenv import load_dotenv
from rich.pretty import pprint
import matplotlib.pyplot as plt

from vtamp.environments.utils import Environment, Task, Updater
from vtamp.policies.utils import Policy
from vtamp.utils import get_log_dir


# Load variables from a local .env file (python-dotenv) before anything reads
# the process environment.
load_dotenv()
# Module-level logger named after this module.
log = logging.getLogger(__name__)


class StreamToLogger:
    """Minimal write-only, file-like adapter that forwards text to a logger.

    Meant to stand in for a text stream such as ``sys.stdout`` or
    ``sys.stderr``: every non-empty chunk written to it is split into lines
    and each line becomes one log record at ``log_level``.
    """

    def __init__(self, logger, log_level):
        """Remember the target logger and the level used for every record.

        Args:
            logger: Object exposing ``log(level, message)``, e.g. a
                ``logging.Logger``.
            log_level: Level passed to ``logger.log`` for each written line.
        """
        self.logger = logger
        self.log_level = log_level
        # Not used by write(); kept so code that reads the attribute still works.
        self.linebuf = ""

    def write(self, buf):
        """Log every line contained in ``buf``.

        Trailing whitespace of the whole chunk and of each individual line is
        stripped, so a chunk made only of whitespace (for example a lone
        newline) produces no record.

        Args:
            buf: Text to forward to the logger.

        Returns:
            The number of characters in ``buf``, as the text-stream protocol
            expects from ``write``.
        """
        for line in buf.rstrip().splitlines():
            self.logger.log(self.log_level, line.rstrip())
        return len(buf)

    def flush(self):
        """Do nothing: ``write`` forwards text immediately, nothing is buffered."""

    def isatty(self):
        """Report that this stream is not an interactive terminal.

        Callers that probe a stream before emitting colours or progress bars
        would otherwise hit ``AttributeError`` on this object.
        """
        return False


def setup_logger():
    """Configure the root logger and route ``stdout``/``stderr`` through it.

    The root logger is set to INFO and receives a single file handler that
    writes ERROR-and-above records, message text only, to ``output.log``
    inside the directory returned by ``get_log_dir()``.

    Afterwards ``sys.stdout`` and ``sys.stderr`` are replaced by
    ``StreamToLogger`` instances, so ``print()`` output becomes INFO records
    and anything written to stderr becomes ERROR records. Because the file
    handler only accepts ERROR and above, the INFO records coming from stdout
    are not written by it (unless some other handler is attached to the root
    logger elsewhere).
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # File handler: keeps errors only.
    log_path = os.path.join(get_log_dir(), "output.log")
    file_handler = logging.FileHandler(log_path)
    file_handler.setLevel(logging.ERROR)
    file_handler.setFormatter(logging.Formatter("%(message)s"))
    root.addHandler(file_handler)

    # Deliberately no StreamHandler on the root logger.

    # From here on, writes to stdout/stderr are turned into log records.
    sys.stdout = StreamToLogger(root, logging.INFO)   # print() -> INFO
    sys.stderr = StreamToLogger(root, logging.ERROR)  # stderr  -> ERROR


# Directory handed to hydra.initialize(config_path=...).
CONFIG_DIR = "vtamp/config"

# Config files that are run, in order.
CONFIG_NAMES = ["cap_push_line.yaml", "cap_push_circle.yaml"]

# One text-file path per config, derived from the config file name.
# NOTE(review): "cashed" looks like a misspelling of "cached"; it is a runtime
# path, so it is left as is -- confirm against the directory on disk.
CACHED_LLM_OUTPUTS = [
    ("cashed_llm_out/" + name).replace(".yaml", ".txt") for name in CONFIG_NAMES
]

# Number of repetitions of every config.
COUNT_PER_CONFIG = 5


def run_all_configs():
    """Run every config ``COUNT_PER_CONFIG`` times and save the outcomes.

    For each repetition ``i`` and each entry of ``CONFIG_NAMES`` the config is
    composed with Hydra and executed through ``run_with_cfg`` together with
    the matching entry of ``CACHED_LLM_OUTPUTS``. The outputs are written to
    ``results/<config name without .yaml>/``:

    * ``<i>_img.jpg`` -- the image returned by ``run_with_cfg``
    * ``<i>.json``    -- the statistics dictionary returned by ``run_with_cfg``

    Files from a previous invocation with the same names are overwritten.
    """
    results_dir = "results"
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(results_dir, exist_ok=True)

    for i in range(COUNT_PER_CONFIG):
        # tqdm cannot infer a length from a generator, so the total is passed
        # explicitly to get a real progress bar.
        runs = zip(CONFIG_NAMES, CACHED_LLM_OUTPUTS)
        for config_name, llm_out in tqdm(runs, total=len(CONFIG_NAMES)):

            with hydra.initialize(config_path=CONFIG_DIR):
                cfg = hydra.compose(config_name=config_name)
                # results: statistics dict, img: last rendered frame.
                results, img = run_with_cfg(cfg, llm_out)

            run_name = config_name.replace(".yaml", "")
            path = f"{results_dir}/{run_name}"
            os.makedirs(path, exist_ok=True)

            file_name = f"{path}/{i}"
            plt.imsave(f"{file_name}_img.jpg", img)
            with open(f"{file_name}.json", "w") as f:
                json.dump(results, f, indent=4)
                

def run_with_cfg(cfg, llm_out):
    """Run one episode described by ``cfg`` and collect statistics.

    Args:
        cfg: Hydra/OmegaConf config. The keys ``task``, ``updater``, ``env``,
            ``policy``, ``render``, ``vis_debug`` and ``max_env_steps`` are
            read; ``seed`` is overwritten with a fresh random value.
        llm_out: Path of the cached LLM output belonging to this config.
            Not used by this function at the moment; kept so existing callers
            keep working.

    Returns:
        A ``(statistics, img)`` tuple. ``statistics`` maps
        ``"planning_time"``, ``"execution_time"``, ``"cost"`` and
        ``"step_<i>_<key>"`` entries (from the policy's step statistics and
        the environment's ``info``) to their values. ``img`` is the result of
        the last ``env.render(False)`` call.
    """
    # Integer bound instead of the float 1e4; int() keeps the seed a plain
    # Python int regardless of what numpy returns.
    cfg["seed"] = int(np.random.randint(0, 10_000))

    pprint(cfg)
    task: Task = hydra.utils.instantiate(cfg.task)
    updater: Updater = hydra.utils.instantiate(cfg.updater)
    env: Environment = hydra.utils.instantiate(
        cfg.env, task=task, render=cfg.render and not cfg.vis_debug,
    )

    # Close the environment even when the episode raises part-way through.
    try:
        obs = env.reset()

        belief = updater.update(obs)

        twin_env: Environment = hydra.utils.get_class(cfg.env._target_).sample_twin(
            env, belief, task, cost_thresh=1e6, render=cfg.vis_debug
        )

        policy: Policy = hydra.utils.instantiate(
            cfg.policy, twin=twin_env, seed=cfg["seed"]
        )

        statistics = {"execution_time": 0, "planning_time": 0}
        stepped = False  # becomes True once at least one action was executed

        for i in range(cfg.get("max_env_steps")):
            goal = env.task.get_goal()
            belief = updater.update(obs)

            # perf_counter is monotonic; time only the planner call itself.
            st = time.perf_counter()
            action, step_statistics = policy.get_action(belief, goal)
            statistics["planning_time"] += time.perf_counter() - st
            print(action)
            for k, v in step_statistics.items():
                statistics[f"step_{i}_{k}"] = v
            if action is None:
                # The policy has nothing (more) to execute.
                break

            st = time.perf_counter()
            obs, reward, done, info = env.step(action, vis=False)
            statistics["execution_time"] += time.perf_counter() - st
            for k, v in info.items():
                statistics[f"step_{i}_{k}"] = v

            cost = env.compute_cost()
            img = env.render(False)
            stepped = True

        if not stepped:
            # No action was executed (policy returned None immediately or
            # max_env_steps is 0): report the state right after reset instead
            # of failing on unbound ``cost`` / ``img``.
            cost = env.compute_cost()
            img = env.render(False)

        statistics["cost"] = cost
    finally:
        env.close()

    return statistics, img


if __name__ == "__main__":
    # Script entry point: run all configured experiments.
    # Note: setup_logger() is not called here, so stdout/stderr stay untouched.
    run_all_configs()
