import json
import os
import time

from gym import error
from gym.utils import atomic_write
from gym.utils.json_utils import json_encode_np


class StatsRecorder:
    def __init__(self, directory, file_prefix, autoreset=False, env_id=None):
        self.autoreset = autoreset
        self.env_id = env_id

        self.initial_reset_timestamp = None
        self.directory = directory
        self.file_prefix = file_prefix
        self.episode_lengths = []
        self.episode_rewards = []
        self.episode_types = []  # experimental addition
        self._type = "t"
        self.timestamps = []
        self.steps = None
        self.total_steps = 0
        self.rewards = None

        self.done = None
        self.closed = False

        filename = f"{self.file_prefix}.stats.json"
        self.path = os.path.join(self.directory, filename)

    @property
    def type(self):
        return self._type

    @type.setter
    def type(self, type):
        if type not in ["t", "e"]:
            raise error.Error(
                "Invalid episode type {}: must be t for training or e for evaluation",
                type,
            )
        self._type = type

    def before_step(self, action):
        assert not self.closed

        if self.done:
            raise error.ResetNeeded(
                "Trying to step environment which is currently done. While the monitor is active for {}, you cannot step beyond the end of an episode. Call 'env.reset()' to start the next episode.".format(
                    self.env_id
                )
            )
        elif self.steps is None:
            raise error.ResetNeeded(
                "Trying to step an environment before reset. While the monitor is active for {}, you must call 'env.reset()' before taking an initial step.".format(
                    self.env_id
                )
            )

    def after_step(self, observation, reward, done, info):
        self.steps += 1
        self.total_steps += 1
        self.rewards += reward
        self.done = done

        if done:
            self.save_complete()

        if done:
            if self.autoreset:
                self.before_reset()
                self.after_reset(observation)

    def before_reset(self):
        assert not self.closed

        if self.done is not None and not self.done and self.steps > 0:
            raise error.Error(
                "Tried to reset environment which is not done. While the monitor is active for {}, you cannot call reset() unless the episode is over.".format(
                    self.env_id
                )
            )

        self.done = False
        if self.initial_reset_timestamp is None:
            self.initial_reset_timestamp = time.perf_counter()

    def after_reset(self, observation):
        self.steps = 0
        self.rewards = 0
        # We write the type at the beginning of the episode. If a user
        # changes the type, it's more natural for it to apply next
        # time the user calls reset().
        self.episode_types.append(self._type)

    def save_complete(self):
        if self.steps is not None:
            self.episode_lengths.append(self.steps)
            self.episode_rewards.append(float(self.rewards))
            self.timestamps.append(time.perf_counter())

    def close(self):
        self.flush()
        self.closed = True

    def flush(self):
        if self.closed:
            return

        with atomic_write.atomic_write(self.path) as f:
            json.dump(
                {
                    "initial_reset_timestamp": self.initial_reset_timestamp,
                    "timestamps": self.timestamps,
                    "episode_lengths": self.episode_lengths,
                    "episode_rewards": self.episode_rewards,
                    "episode_types": self.episode_types,
                },
                f,
                default=json_encode_np,
            )
