
import os
import logging
import datetime
import json
import time
import numpy as np
import contextlib
import sys
from typing import Any, Dict, Generator, Optional, Sequence, Tuple, Union, Mapping, List,  Iterator

import stable_baselines3.common.logger as sb_logger

from imitation.data import types
from d3rlpy.logger import default_json_encoder
import structlog

# Module-level structlog logger; D3RLPyLogger below uses it for status messages.
LOG: structlog.BoundLogger = structlog.get_logger(__name__)


def default_json_encoder(obj: Any) -> Any:
    """Convert NumPy values into JSON-serializable Python objects.

    Intended to be passed as the ``default`` argument of ``json.dumps``, which
    calls it for every object the standard encoder cannot handle.

    Args:
        obj: The object that could not be serialized natively.

    Returns:
        A plain ``bool``, ``int``, ``float`` or (possibly nested) ``list``
        holding the same value as ``obj``.

    Raises:
        ValueError: If ``obj`` is not one of the supported NumPy types.
    """
    # NumPy booleans are neither ``np.integer`` nor ``np.floating``, so they
    # need their own branch to avoid falling through to the error below.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise ValueError(f"invalid object type: {type(obj)}")

def make_log_dir(args, log_dir=None):
    """Create the logging directory (if needed) and return its path.

    Args:
        args: Parsed arguments. Only read when ``log_dir`` is ``None``; the
            attributes ``env``, ``command_name``, ``extra_log_rep`` and ``seed``
            are then used to assemble the path.
        log_dir: Explicit directory to use. When ``None``, the path
            ``output/<command_name>/<env>/[<extra_log_rep>/]seed_<seed>/<timestamp>``
            is generated, with ``/`` in the environment name replaced by ``_``.

    Returns:
        The path of the directory, which exists after this call.
    """
    if log_dir is None:
        env_name = args.env.replace("/", "_")
        components = ["output", args.command_name, env_name]
        # The optional extra level is skipped entirely when it is empty.
        if args.extra_log_rep != "":
            components.append(args.extra_log_rep)
        components.append(f'seed_{args.seed}')
        components.append(datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
        log_dir = os.path.join(*components)

    os.makedirs(log_dir, exist_ok=True)
    return log_dir



def make_output_format(
    _format: str,
    log_dir: str,
    log_suffix: str = "",
    max_length: int = 40,
) -> sb_logger.KVWriter:
    """Create the writer for one output format.

    The 'stdout' and 'log' formats are built here so that ``max_length`` can be
    applied; every other format string is handed to Stable Baselines'
    ``make_output_format`` unchanged (``max_length`` is not forwarded there).

    Args:
        _format: the requested format to log to
            ('stdout', 'log', 'json' or 'csv' or 'tensorboard').
        log_dir: the logging directory; created if it does not exist.
        log_suffix: the suffix for the log file.
        max_length: the maximum length beyond which the keys get truncated.

    Returns:
        the writer for the requested format.
    """
    os.makedirs(log_dir, exist_ok=True)

    if _format == "stdout":
        return sb_logger.HumanOutputFormat(sys.stdout, max_length=max_length)

    if _format == "log":
        log_path = os.path.join(log_dir, f"log{log_suffix}.txt")
        return sb_logger.HumanOutputFormat(log_path, max_length=max_length)

    return sb_logger.make_output_format(_format, log_dir, log_suffix)

def _build_output_formats(
    folder: str,
    format_strs: Sequence[str],
) -> Sequence[sb_logger.KVWriter]:
    """Create one writer per requested format string.

    Args:
        folder: Path to directory that logs are written to; created if missing.
        format_strs: Output format strings. "wandb" yields a
            `WandbOutputFormat`; any other value is passed to
            `make_output_format` together with `folder`.

    Returns:
        A list of writers, in the same order as `format_strs`.
    """
    os.makedirs(folder, exist_ok=True)
    return [
        WandbOutputFormat() if fmt == "wandb" else make_output_format(fmt, folder)
        for fmt in format_strs
    ]






class HierarchicalLogger(sb_logger.Logger):
    """A logger supporting contexts for accumulating mean values.

    `self.accumulate_means` creates a context manager. While in this context,
    values are logged to a sub-logger, with only mean values recorded in the
    top-level (root) logger.
    """

    def __init__(
        self,
        default_logger: sb_logger.Logger,
        format_strs: Sequence[str] = ("stdout", "log", "csv"),
    ):
        """Builds HierarchicalLogger.

        Args:
            default_logger: The default logger when not in an `accumulate_means`
                context. Also the logger to which mean values are written to after
                exiting from a context.
            format_strs: A list of output format strings that should be used by
                every Logger initialized by this class during an `AccumulatingMeans`
                context. For details on available output formats see
                `stable_baselines3.logger.make_output_format`.
        """
        self.default_logger = default_logger
        self.current_logger = None
        # Sub-loggers keyed by the *string* form of their subdir.
        self._cached_loggers = {}
        self._subdir = None
        self.format_strs = format_strs

        # Step counter passed to `dump` by `log_batch`, and the epoch number
        # most recently announced via `log_epoch`.
        self._tensorboard_step = 0
        self._current_epoch = 0

        super().__init__(folder=self.default_logger.dir, output_formats=[])

    def _update_name_to_maps(self) -> None:
        """Point this object's key/value maps at those of the active logger."""
        self.name_to_value = self._logger.name_to_value
        self.name_to_count = self._logger.name_to_count
        self.name_to_excluded = self._logger.name_to_excluded

    @contextlib.contextmanager
    def accumulate_means(self, subdir: types.AnyPath) -> Generator[None, None, None]:
        """Temporarily modifies this HierarchicalLogger to accumulate means values.

        During this context, `self.record(key, value)` writes the "raw" values to
        a sub-logger whose folder is "<default logger dir>/raw/{subdir}", under
        the key "raw/{subdir}/{key}". At the same time, any call to `self.record`
        also accumulates mean values on the default logger by calling
        `self.default_logger.record_mean(f"mean/{subdir}/{key}", value)`.

        After the context exits, calling `self.dump()` will write the means
        of all the "raw" values accumulated during this context to
        `self.default_logger` under keys with the prefix `mean/{subdir}/`

        Note that the behavior of other logging methods, `log` and `record_mean`
        are unmodified and will go straight to the default logger.

        Args:
            subdir: A key which determines the `folder` where raw data is
                written and temporary logging prefixes for raw and mean data.
                Entering an `accumulate_means` context in the future with the
                same `subdir` reuses the sub-logger created the first time.

        Yields:
            None when the context is entered.

        Raises:
            RuntimeError: If this context is entered into while already in
                an `accumulate_means` context.
        """
        if self.current_logger is not None:
            raise RuntimeError("Nested `accumulate_means` context")

        # Normalize before the cache lookup: the cache is keyed by string, so a
        # path-like `subdir` would otherwise never match and a fresh logger
        # (with freshly opened output files) would be built on every entry.
        subdir = types.path_to_str(subdir)

        if subdir in self._cached_loggers:
            logger = self._cached_loggers[subdir]
        else:
            folder = os.path.join(self.default_logger.dir, "raw", subdir)
            os.makedirs(folder, exist_ok=True)
            output_formats = _build_output_formats(folder, self.format_strs)
            logger = sb_logger.Logger(folder, list(output_formats))
            self._cached_loggers[subdir] = logger

        try:
            self.current_logger = logger
            self._subdir = subdir
            self._update_name_to_maps()
            yield
        finally:
            # Always restore the default logger, even if the body raised.
            self.current_logger = None
            self._subdir = None
            self._update_name_to_maps()

    def record(self, key, val, exclude=None):
        """Record a value, routing it according to the current context.

        Inside `accumulate_means` the raw value goes to the sub-logger under
        "raw/{subdir}/{key}" and is also accumulated as a mean on the default
        logger under "mean/{subdir}/{key}". Outside, it is recorded directly on
        the default logger under `key`.
        """
        if self.current_logger is not None:  # In accumulate_means context.
            assert self._subdir is not None
            raw_key = "/".join(["raw", self._subdir, key])
            self.current_logger.record(raw_key, val, exclude)

            mean_key = "/".join(["mean", self._subdir, key])
            self.default_logger.record_mean(mean_key, val, exclude)
        else:  # Not in accumulate_means context.
            self.default_logger.record(key, val, exclude)

    @property
    def _logger(self):
        """The active logger: the sub-logger inside a context, else the default."""
        if self.current_logger is not None:
            return self.current_logger
        else:
            return self.default_logger

    def dump(self, step=0):
        """Dump the active logger's recorded values at `step`."""
        self._logger.dump(step)

    def get_dir(self) -> str:
        """Return the directory of the active logger."""
        return self._logger.get_dir()

    def log(self, *args, **kwargs):
        """Forward to the default logger's `log`, regardless of context."""
        self.default_logger.log(*args, **kwargs)

    def set_level(self, level: int) -> None:
        """Set the logging level on the default logger."""
        self.default_logger.set_level(level)

    def record_mean(self, key, val, exclude=None):
        """Forward to the default logger's `record_mean`, regardless of context."""
        self.default_logger.record_mean(key, val, exclude)

    def close(self):
        """Close the default logger and every cached sub-logger."""
        self.default_logger.close()
        for logger in self._cached_loggers.values():
            logger.close()

    def reset_tensorboard_steps(self):
        """Restart the step counter used by `log_batch` at zero."""
        self._tensorboard_step = 0

    def log_epoch(self, epoch_number):
        """Remember `epoch_number`; `log_batch` records it under "epoch"."""
        self._current_epoch = epoch_number

    def log_batch(
        self,
        batch_num: int,
        batch_size: int,
        num_samples_so_far: int,
        training_metrics: dict,
        train_stats=None,
        test_stats=None,
        rollout_stats=None,
    ):
        """Record the statistics of one training batch and dump them.

        Args:
            batch_num: Recorded under "batch".
            batch_size: Recorded under "batch_size".
            num_samples_so_far: Recorded under "samples_so_far".
            training_metrics: Mapping of metric name to value; each value is
                converted with `float` and recorded under its name.
            train_stats: Optional mapping; each entry is recorded as a float
                under "{key}_train".
            test_stats: Optional mapping; each entry is recorded as a float
                under "{key}_test".
            rollout_stats: Optional mapping; only entries whose key contains
                "return" and does not contain "monitor" are recorded, under
                "rollout/{key}".
        """
        self.record("batch_size", batch_size)
        self.record("epoch", self._current_epoch)
        self.record("batch", batch_num)
        self.record("samples_so_far", num_samples_so_far)
        for k, v in training_metrics.items():
            self.record(k, float(v))

        if train_stats is not None:
            for k, v in train_stats.items():
                self.record(f"{k}_train", float(v))

        if test_stats is not None:
            for k, v in test_stats.items():
                self.record(f"{k}_test", float(v))

        if rollout_stats is not None:
            for k, v in rollout_stats.items():
                if "return" in k and "monitor" not in k:
                    self.record("rollout/" + k, v)

        # Each call gets its own, monotonically increasing step.
        self.dump(self._tensorboard_step)
        self._tensorboard_step += 1

    def __getstate__(self):
        """Return the instance state used for pickling.

        `_logger` is a read-only property and therefore normally absent from
        the instance dict; it is dropped if present, and its absence is
        tolerated instead of raising `KeyError`.
        """
        state = self.__dict__.copy()
        state.pop("_logger", None)
        return state


class WandbOutputFormat(sb_logger.KVWriter):
    """Key/value writer for stable-baselines loggers that forwards to wandb.

    Users need to call `wandb.init()` before initializing `WandbOutputFormat`.
    """

    def __init__(self):
        """Imports wandb lazily and keeps a handle to the module.

        Raises:
            ModuleNotFoundError: wandb is not installed.
        """
        try:
            import wandb
        except ModuleNotFoundError as e:
            message = (
                "Trying to log data with `WandbOutputFormat` "
                "but `wandb` not installed: try `pip install wandb`."
            )
            raise ModuleNotFoundError(message) from e
        self.wandb_module = wandb

    def write(
        self,
        key_values: Dict[str, Any],
        key_excluded: Dict[str, Union[str, Tuple[str, ...]]],
        step: int = 0,
    ) -> None:
        """Log every non-excluded key/value pair to wandb at `step`.

        Both mappings are walked in sorted order and are expected to have
        matching keys. Entries whose exclusion value contains "wandb" are
        skipped. A final empty `log` call with `commit=True` follows the loop.
        """
        value_items = sorted(key_values.items())
        exclusion_items = sorted(key_excluded.items())

        for value_item, exclusion_item in zip(value_items, exclusion_items):
            key, value = value_item
            key_ex, excluded = exclusion_item
            assert key == key_ex, f"key mismatch between {key} and {key_ex}."

            is_excluded = excluded is not None and "wandb" in excluded
            if not is_excluded:
                self.wandb_module.log({key: value}, step=step)

        self.wandb_module.log({}, commit=True)

    def close(self) -> None:
        """Finish the wandb run."""
        self.wandb_module.finish()




def setup_logging(args, log_dir=None, log_format_strs=("tensorboard", "stdout")):
    """Set up logging and return a hierarchical logger plus its directory.

    Args:
        args: Parsed arguments, forwarded to `make_log_dir`.
        log_dir: Optional explicit log directory, forwarded to `make_log_dir`.
        log_format_strs: Output formats for the default logger. The same
            formats, minus "wandb", are given to the `HierarchicalLogger` for
            the sub-loggers it creates.

    Returns:
        A tuple `(hier_logger, log_dir)`: the `HierarchicalLogger` wrapping a
        default logger that writes to "<log_dir>/log", and the log directory.
    """
    log_dir = make_log_dir(args, log_dir)

    logging.basicConfig(level=logging.INFO)

    folder = types.path_to_str(os.path.join(log_dir, "log"))
    output_formats = _build_output_formats(folder, log_format_strs)
    default_logger = sb_logger.Logger(folder, list(output_formats))
    hier_format_strs = [f for f in log_format_strs if f != "wandb"]
    hier_logger = HierarchicalLogger(default_logger, hier_format_strs)
    return hier_logger, log_dir


######################################################################

class D3RLPyLogger:
    """Adapter offering a d3rlpy-style logger interface over a wrapped logger.

    Metrics are buffered with `add_metric` and, on `commit`, averaged and
    written through the wrapped logger's `record`/`dump` methods. Parameters
    and model files are saved in the parent of the wrapped logger's directory.
    """

    def __init__(
        self, logger,
    ):
        """Wrap `logger`.

        Args:
            logger: Object exposing `dir`, `record`, `dump` and `close`.
        """
        self.logger = logger
        # Files produced by this class go one level above the wrapped
        # logger's own directory.
        self._logdir = os.path.dirname(self.logger.dir)
        LOG.info(f"Logging directory: {self._logdir}")
        self._params = None
        self._metrics_buffer = {}

    def add_params(self, params: Dict[str, Any]) -> None:
        """Save `params` to "params.json" in the log directory.

        Args:
            params: Parameters to save. NumPy values are converted with
                `default_json_encoder`.
        """
        assert self._params is None, "add_params can be called only once."

        # Serialize before opening the file so that a serialization error
        # does not leave an empty params.json behind.
        json_str = json.dumps(params, default=default_json_encoder, indent=2)

        params_path = os.path.join(self._logdir, "params.json")
        with open(params_path, "w") as f:
            f.write(json_str)

        LOG.info(f"Parameters are saved to {params_path}", params=params)

        # remove non-scalar values for HParams
        self._params = {k: v for k, v in params.items() if np.isscalar(v)}

    def add_metric(self, name: str, value: float) -> None:
        """Buffer one observation of metric `name` until the next `commit`."""
        self._metrics_buffer.setdefault(name, []).append(value)

    def commit(self, epoch: int, step: int) -> Dict[str, float]:
        """Average the buffered metrics, write them out and clear the buffer.

        Args:
            epoch: Epoch number, included in the status message.
            step: Step passed to the wrapped logger's `dump`.

        Returns:
            Mapping of metric name to the mean of its buffered values.
        """
        metrics = {}
        for name, buffer in self._metrics_buffer.items():
            metric = np.mean(buffer)
            self.logger.record(name, metric)
            metrics[name] = metric

        # Dump once per commit so all metrics of this step land in a single
        # record; skipped when nothing was buffered.
        if metrics:
            self.logger.dump(step=step)

        LOG.info(
            f"epoch={epoch} step={step}",
            epoch=epoch,
            step=step,
            metrics=metrics,
        )

        self._metrics_buffer = {}
        return metrics

    def save_model(self, epoch, algo) -> None:
        """Save `algo` via its `save_model` method to "model_{epoch}.pt"."""
        model_path = os.path.join(self._logdir, f"model_{epoch}.pt")
        algo.save_model(model_path)
        LOG.info(f"Model parameters are saved to {model_path}")

    def close(self) -> None:
        """Close the wrapped logger."""
        self.logger.close()

    @contextlib.contextmanager
    def measure_time(self, name: str) -> Iterator[None]:
        """Context manager buffering the elapsed seconds as "time_{name}"."""
        name = "time_" + name
        start = time.time()
        try:
            yield
        finally:
            # Recorded even when the body raises.
            self.add_metric(name, time.time() - start)

    @property
    def logdir(self) -> str:
        """Directory where parameters and models are saved."""
        return self._logdir

    @property
    def experiment_name(self) -> str:
        """Last path component of `logdir`."""
        # NOTE(review): derived from the directory name because no explicit
        # experiment name is stored on this class - confirm this matches what
        # callers expect.
        return os.path.basename(self._logdir)

