import logging
import os

from typing import Dict, Sequence
from torch.utils import tensorboard
from expground.logger import Log

import shutil
import yaml


class ExpConfig:
    """Wrap a parsed experiment YAML and derive the log paths from it.

    The wrapped mapping is read with the keys ``learner_config`` (``type`` and
    ``params.seed``), ``env_config`` (``env_id``), ``algorithm`` (``name``)
    and the optional ``exp_prefix``.
    """

    def __init__(self, raw_yaml, base_path, seed=None):
        """Store the configuration and resolve the seed.

        Args:
            raw_yaml: The parsed configuration mapping. It is kept by
                reference and modified in place when `seed` is given.
            base_path: Root directory under which log directories live.
            seed: Optional seed override. When given it is written into
                ``raw_yaml["learner_config"]["params"]["seed"]``; otherwise
                the seed is read from that same entry.
        """
        self.yaml = raw_yaml
        self.base_path = base_path
        if seed is not None:
            self.yaml["learner_config"]["params"]["seed"] = seed
            self.seed = seed
        else:
            self.seed = self.yaml["learner_config"]["params"]["seed"]

    @property
    def exp_name(self) -> str:
        """Return the name of current experiment, grouped by `prefix_learner_type_env_algo`.

        The prefix part is omitted when the config has no ``exp_prefix``.

        Returns:
            str: A str-like experiment name.
        """

        # {exp_prefix}_{learner_type}_{env}_{algo}
        exp_prefix = self.yaml.get("exp_prefix")
        learner_type = self.yaml["learner_config"]["type"]
        env = self.yaml["env_config"]["env_id"]
        algo = self.yaml["algorithm"]["name"]
        if exp_prefix is None:
            return "{}_{}_{}".format(learner_type, env, algo)
        return "{}_{}_{}_{}".format(exp_prefix, learner_type, env, algo)

    @property
    def log_path(self):
        """Return ``<base_path>/<exp_name>/<seed>``; the directory is not created."""
        return os.path.join(self.base_path, self.exp_name, str(self.seed))

    def init_logpath(
        self,
    ):
        """Create a fresh log directory and dump the configuration into it.

        Any directory already present at `log_path` is deleted first, then the
        directory is recreated and the configuration is written to
        ``config.yaml`` inside it.
        """
        base_path = self.log_path
        if os.path.exists(base_path):
            # Start from a clean directory: drop everything from earlier runs.
            Log.warning("Found existed log data in %s, remove it.", base_path)
            shutil.rmtree(base_path)
        Log.info("Save log data in %s", base_path)
        os.makedirs(base_path)

        # Keep a copy of the configuration next to the logs it produced.
        with open(os.path.join(base_path, "config.yaml"), "w") as f:
            yaml.dump(self.yaml, f)

    def get_path(self, sub_name=None):
        """Return the log directory, or a sub-directory of it.

        Args:
            sub_name: Optional sub-directory name relative to `log_path`.

        Returns:
            `log_path` itself when `sub_name` is None (nothing is created),
            otherwise the path of the sub-directory, which is created if it
            does not exist yet.
        """
        log_path = self.log_path
        if sub_name is None:
            return log_path
        log_path = os.path.join(log_path, sub_name)
        # makedirs also creates missing parents (log_path may not exist yet,
        # and sub_name may be nested); exist_ok avoids a check-then-create race.
        os.makedirs(log_path, exist_ok=True)
        return log_path


def get_logger(name, log_dir, log_level):
    """Return the logger called `name`, set up for console and file output.

    Calling this again with the same `name` does not attach duplicate
    handlers, so every record is still emitted once per destination.

    Args:
        name: Logger name; also used as the stem of the log file.
        log_dir: Directory receiving ``<name>.log``, created when missing.
            Pass None to skip file logging.
        log_level: Level set on the logger.

    Returns:
        logging.Logger: The configured logger.
    """
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )

    logger = logging.getLogger(name)
    logger.setLevel(log_level)

    # logging.getLogger returns the same object for the same name, so only add
    # a console handler when none is attached yet. FileHandler subclasses
    # StreamHandler and therefore has to be excluded explicitly.
    has_console = any(
        isinstance(h, logging.StreamHandler)
        and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )
    if not has_console:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    # init file
    if log_dir is not None:
        os.makedirs(log_dir, exist_ok=True)
        f_name = os.path.join(log_dir, f"{name}.log")
        # FileHandler stores its target as an absolute path in baseFilename.
        target = os.path.abspath(f_name)
        has_file = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == target
            for h in logger.handlers
        )
        if not has_file:
            file_handler = logging.FileHandler(filename=f_name)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

    return logger


def append_to_table(
    file_name: str,
    title: list,
    row: list,
):
    """Append one row to a ``", "``-separated text table.

    When `file_name` does not exist yet it is created and `title` is written
    as the header line first.

    Args:
        file_name: Path of the table file.
        title: Column names; only used when the file is created.
        row: Values of the row to append, or None to append nothing.
    """
    # Decide before opening: append mode creates the file when it is missing.
    is_new_file = not os.path.exists(file_name)

    # The context manager closes the file even when formatting or writing
    # fails, and every line is built in full before it is written.
    with open(file_name, "a") as f:
        if is_new_file:
            f.write(", ".join("{}".format(t) for t in title) + "\n")
        if row is not None:
            f.write(", ".join("{}".format(d) for d in row) + "\n")


def write_to_tensorboard(
    writer: tensorboard.SummaryWriter, info: Dict, global_step: int, prefix: str
):
    """Log every scalar of a (possibly nested) info dict to tensorboard.

    Nested dicts are walked recursively and their keys are joined with ``/``
    to build the tag. Entries whose value is None are skipped.

    Args:
        writer (tensorboard.SummaryWriter): The summary writer instance. When
            it is None the call does nothing.
        info (Dict): The information dict.
        global_step (int): The global step indicator.
        prefix (str): Prefix added to keys in the info dict; an empty string
            means no prefix.

    Raises:
        NotImplementedError: If a value is a `Sequence`.
    """
    if writer is None:
        return

    tag_root = "" if len(prefix) == 0 else f"{prefix}/"
    for key, value in info.items():
        tag = f"{tag_root}{key}"
        if isinstance(value, dict):
            # The full tag becomes the prefix of the nested entries.
            write_to_tensorboard(writer, value, global_step, tag)
            continue
        if isinstance(value, Sequence):
            raise NotImplementedError(
                f"Sequence value cannot be logged currently: {value}."
            )
        if value is not None:
            writer.add_scalar(tag, value, global_step=global_step)
