import os
import sys
import time
import json
import random
import logging
from tqdm import tqdm

import hydra
import numpy as np
from dotenv import load_dotenv
from rich.pretty import pprint
import matplotlib.pyplot as plt

from vtamp.environments.utils import Environment, Task, Updater
from vtamp.policies.utils import Policy
from vtamp.utils import get_log_dir


# Module-level logger. NOTE(review): not referenced elsewhere in this file.
log = logging.getLogger(__name__)

# Load variables from a local .env file into the process environment at
# import time, before any config or policy is instantiated.
load_dotenv()


class StreamToLogger:
    """File-like object that forwards written text to a logger, one record per line.

    Intended as a replacement for ``sys.stdout`` / ``sys.stderr`` so that
    ``print()`` output and tracebacks are turned into log records.

    Text is buffered until a line terminator arrives. A single ``print("a", "b")``
    (which issues several ``write()`` calls) therefore becomes one record
    ``"a b"`` instead of several fragments. Trailing whitespace is stripped from
    each line, and blank lines are not logged.

    Args:
        logger: Logger that receives the records.
        log_level: Level used for every record (e.g. ``logging.INFO``).
    """

    def __init__(self, logger, log_level):
        self.logger = logger
        self.log_level = log_level
        # Text written since the last line terminator (an incomplete line).
        self.linebuf = ""

    def _log_line(self, text):
        """Log ``text`` with trailing whitespace/terminators removed, unless blank."""
        line = text.rstrip()
        if line:
            self.logger.log(self.log_level, line)

    def write(self, buf):
        """Buffer ``buf`` and log every line it completes.

        Returns:
            The number of characters accepted, as ``io.TextIOBase.write`` does.
        """
        pieces = (self.linebuf + buf).splitlines(keepends=True)
        self.linebuf = ""
        # A final piece without a line terminator is an incomplete line:
        # keep it buffered until more text (or a flush) arrives.
        if pieces and pieces[-1].splitlines()[0] == pieces[-1]:
            self.linebuf = pieces.pop()
        for piece in pieces:
            self._log_line(piece)
        return len(buf)

    def flush(self):
        """Log any buffered incomplete line (e.g. after ``print(..., end="")``)."""
        pending, self.linebuf = self.linebuf, ""
        self._log_line(pending)


def setup_logger():
    """Route stdout/stderr through the root logger into ``<log dir>/output.log``.

    Sets the root logger to INFO and attaches a FileHandler (message-only
    format) at ERROR level. It then replaces ``sys.stdout`` with a
    StreamToLogger at INFO and ``sys.stderr`` with one at ERROR.

    With no other handlers on the root logger, only stderr text (ERROR)
    reaches the file; printed stdout text (INFO) is filtered out by the
    handler level.

    NOTE(review): not idempotent — every call adds another FileHandler.
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # Only ERROR records are written to the file.
    file_handler = logging.FileHandler(os.path.join(get_log_dir(), "output.log"))
    file_handler.setLevel(logging.ERROR)
    file_handler.setFormatter(logging.Formatter("%(message)s"))
    root.addHandler(file_handler)

    # Deliberately no StreamHandler here. NOTE(review): a console handler
    # bound to the redirected sys.stderr would feed records back into this
    # logger.
    sys.stdout, sys.stderr = (
        StreamToLogger(root, logging.INFO),   # print() -> INFO
        StreamToLogger(root, logging.ERROR),  # stderr/tracebacks -> ERROR
    )


# Hydra config directory handed to hydra.initialize().
CONFIG_DIR = "vtamp/config"

# Config files (inside CONFIG_DIR) to evaluate. Other configs that have been
# run from here: deneck_push_circle.yaml, proc3s_push_circle.yaml,
# proc3s_push_line.yaml.
CONFIG_NAMES = ["deneck_push_line_reprompt.yaml"]

# Cached LLM output file for each config, index-aligned with CONFIG_NAMES:
# "<name>.yaml" -> "cashed_llm_out/<name>.txt".
# NOTE(review): "cashed" looks like a typo for "cached", but the path must
# match the existing directory, so it is left unchanged.
CACHED_LLM_OUTPUTS = [
    "cashed_llm_out/" + cfg_file.replace(".yaml", ".txt") for cfg_file in CONFIG_NAMES
]

# Run indices start, start + 1, ..., start + COUNT_PER_CONFIG - 1. Each index
# becomes the output file stem for that run.
start = 270
COUNT_PER_CONFIG = 1


def run_all_configs():
    """Run every config in CONFIG_NAMES once per run index and save the outputs.

    Run indices are ``start .. start + COUNT_PER_CONFIG - 1``. For each index,
    config and optimizer, one episode is run with ``run_with_cfg``, and two
    files are written to ``results/<run_name>/``:

    * ``<index>_img.jpg`` — the image returned by ``run_with_cfg``;
    * ``<index>.json``    — the statistics dict returned by ``run_with_cfg``.

    ``<run_name>`` is the config file name without ``.yaml``. When an
    optimizer is used, ``_<optim>`` is appended, plus a further ``03`` for
    ``cma``.
    """
    results_dir = "results"
    os.makedirs(results_dir, exist_ok=True)

    for i in range(start, start + COUNT_PER_CONFIG):
        # zip() pairs each config with its cached LLM output. zip() has no
        # len(), so total= is passed for tqdm to show a real progress bar.
        configs = zip(CONFIG_NAMES, CACHED_LLM_OUTPUTS)
        for config_name, llm_out in tqdm(configs, total=len(CONFIG_NAMES)):
            with hydra.initialize(config_path=CONFIG_DIR):
                cfg = hydra.compose(config_name=config_name)

                # "deneck" configs are run with the CMA optimizer, all others
                # without an optimizer.
                optims = ["cma"] if "deneck" in config_name else [None]

                for optim in optims:
                    # results: statistics dict; img: rendered image of the run.
                    results, img = run_with_cfg(cfg, llm_out, optim)

                    run_name = config_name.replace(".yaml", "")
                    if optim is not None:
                        run_name += f"_{optim}"
                        if optim == "cma":
                            # Fixed suffix so existing CMA result directories
                            # keep their names.
                            # NOTE(review): the meaning of "03" is not visible here.
                            run_name += "03"

                    run_dir = os.path.join(results_dir, run_name)
                    os.makedirs(run_dir, exist_ok=True)

                    file_stem = os.path.join(run_dir, str(i))
                    plt.imsave(f"{file_stem}_img.jpg", img)
                    with open(f"{file_stem}.json", "w") as f:
                        json.dump(results, f, indent=4)
                

def run_with_cfg(cfg, llm_out, optim=None):
    """Run one closed-loop episode for a composed Hydra config.

    The function:

    1. Draws a fresh random seed and stores it in ``cfg["seed"]``, mutating
       the caller's config.
    2. Instantiates the task, updater, environment, a twin environment (via
       ``sample_twin``) and the policy.
    3. Alternates planning (``policy.get_action``) and execution
       (``env.step``) for at most ``cfg.max_env_steps`` steps. It stops early
       when the policy returns ``None``.

    The environment is always closed, even if a step raises.

    Args:
        cfg: Composed Hydra config. Reads ``task``, ``updater``, ``env``,
            ``policy``, ``render``, ``vis_debug`` and ``max_env_steps``.
        llm_out: Path of a cached LLM output file.
            NOTE(review): currently unused by this function.
        optim: Optional optimizer name forwarded to the policy constructor.
            When falsy, no ``optim`` argument is passed.

    Returns:
        A tuple ``(statistics, img)``:

        * ``statistics`` holds cumulative ``planning_time`` /
          ``execution_time`` in seconds, plus per-step entries
          ``step_<i>_<key>`` copied from the policy's step statistics and
          from the env ``info`` dict.
        * ``img`` is ``env.render(False)`` of the final environment state.
          It is defined even when no step is executed.
    """
    # int() keeps a plain Python int (not a NumPy scalar) in the config.
    cfg["seed"] = int(np.random.randint(0, 10_000))

    pprint(cfg)
    task: Task = hydra.utils.instantiate(cfg.task)
    updater: Updater = hydra.utils.instantiate(cfg.updater)
    env: Environment = hydra.utils.instantiate(
        cfg.env, task=task, render=cfg.render and not cfg.vis_debug,
    )
    try:
        obs = env.reset()
        belief = updater.update(obs)

        # Twin environment built by the env class from the real env and the
        # current belief; it is handed to the policy as `twin`.
        twin_env: Environment = hydra.utils.get_class(cfg.env._target_).sample_twin(
            env, belief, task, cost_thresh=1e6, render=cfg.vis_debug
        )

        # `optim` is forwarded only when truthy; otherwise the policy's own
        # default applies.
        extra_kwargs = {"optim": optim} if optim else {}
        policy: Policy = hydra.utils.instantiate(
            cfg.policy, twin=twin_env, seed=cfg["seed"], **extra_kwargs
        )

        statistics = {"execution_time": 0, "planning_time": 0}

        for i in range(cfg.get("max_env_steps")):
            goal = env.task.get_goal()
            belief = updater.update(obs)

            # perf_counter is monotonic and meant for measuring durations.
            t0 = time.perf_counter()
            action, step_statistics = policy.get_action(belief, goal)
            print(action)
            for key, value in step_statistics.items():
                statistics[f"step_{i}_{key}"] = value
            statistics["planning_time"] += time.perf_counter() - t0
            if action is None:
                # No action returned: end the episode.
                break

            t0 = time.perf_counter()
            # NOTE(review): `reward` and `done` are ignored; the episode ends
            # only when the policy returns None or max_env_steps is reached.
            obs, reward, done, info = env.step(action, vis=False)
            for key, value in info.items():
                statistics[f"step_{i}_{key}"] = value
            statistics["execution_time"] += time.perf_counter() - t0

        # Render once, after the loop. Only the final frame is returned, and
        # rendering here keeps `img` defined when zero steps were executed.
        img = env.render(False)
    finally:
        env.close()
    return statistics, img

if __name__ == "__main__":
    # Script entry point: run every configured experiment. setup_logger() is
    # not called here, so sys.stdout / sys.stderr are left unredirected.
    run_all_configs()
