import io
import sys
import traceback
from collections import OrderedDict, defaultdict, deque
from typing import List, Dict
from copy import deepcopy
from concurrent.futures import ThreadPoolExecutor
import numpy as np

import elements
from elements import UUID
from utils.tools import merge_dict_list, UniformSampler

class Episode:
    def __init__(self):
        self.time: str = elements.timestamp(millis=True)
        self.uuid: UUID = UUID()
        self.data: Dict[str, List[np.ndarray]] = defaultdict(list)
        self.length: int = 0

    @property
    def filename(self):
        return f'{self.time}-{str(self.uuid)}-{self.length}.npz'

    def add(self, step_dict: Dict[str, np.ndarray]):
        for key, value in step_dict.items():
            if not key.startswith("log_"):
                self.data[key].append(value)
        self.length += 1

    def stats(self, rewards_reduce: str) -> Dict[str, float | int]:
        rewards: np.ndarray = np.array(self.data["rewards"], dtype=np.float32).sum(axis=0).reshape(-1)
        agent_mask: np.ndarray = np.array(self.data["agent_mask"], dtype=np.float32).mean(axis=0).reshape(-1)
        # rewards.shape = (n_agents, )
        metrics = {}
        # calculate the rewards for each agent
        for i in range(rewards.shape[0]):
            metrics[f"agent_{i}/rewards"] = rewards[i]
            metrics[f"agent_{i}/death_ratio"] = 1 - agent_mask[i]

        if rewards_reduce == "sum":
            rewards = np.sum(rewards)
        elif rewards_reduce == "mean":
            rewards = np.mean(rewards)
        else:
            raise ValueError(f"Rewards reduce {rewards_reduce} not supported")
        metrics["rewards"] = rewards
        metrics["death_ratio"] = 1 - agent_mask.mean()
        metrics["length"] = len(self) - 1
        return metrics

    def save(self, directory) -> Dict:
        directory = elements.Path(directory)
        filename = directory / self.filename
        with io.BytesIO() as stream:
            np.savez_compressed(stream, **self.data)
            stream.seek(0)
            filename.write(stream.read(), mode='wb')

    @classmethod
    def load(cls, filename, error="raise"):
        assert error in ("raise", "none")
        time, uuid, length = filename.stem.split("-")[:3]
        length = int(length)
        try:
            with filename.open("rb") as f:
                data = np.load(f)
                data = {k: data[k] for k in data.keys()}
        except:
            tb = ''.join(traceback.format_exception(sys.exception()))
            print(f'Error loading chunk {filename}:\n{tb}')
            if error == 'raise':
                raise
            else:
                return None

        episode = cls()
        episode.time = time
        episode.uuid = UUID(uuid)
        episode.data = data
        episode.length = length
        return episode

    def __len__(self) -> int:
        return self.length

class ReplayBuffer:
    """
    Replay buffer that stores whole episodes and samples fixed-length sequences.

    Completed episodes are always saved to ``<logdir>/replay``. With
    ``config.replay.offload`` enabled, completed episodes are additionally
    dropped from memory (their entry in ``self.episodes`` becomes ``None``) and
    re-read from disk when sampled.

    The sampler from ``utils.tools.UniformSampler`` is assumed to support
    registering a key by calling it, ``in``, ``del`` and ``sample()``, as used
    throughout this class.
    """

    def __init__(self, config, n_rollout_threads, agg=None):
        # parameters for saving and loading
        self.directory = elements.Path(config.logdir) / "replay"
        self.directory.mkdir()

        self.config = config
        # length of sampled episodes
        self.length = config.train.burn_in_length + config.train.batch_length
        # number of rollout threads
        self.n_rollout_threads = n_rollout_threads
        # current episodes indexed by worker id
        self.current: List[Episode] = [Episode() for _ in range(n_rollout_threads)]
        # buffer for episodes available for training (insertion order = age)
        self.episodes: Dict[UUID, Episode] = OrderedDict((episode.uuid, episode) for episode in self.current)
        # efficient sampler for sampling an episode
        self.episode_sampler = UniformSampler()
        # count the number of samples in the buffer
        self.num_steps_for_training = 0
        # mapping from uuid to the filepath of a completed episode
        self.uuid_to_filepath: Dict[UUID, str] = dict()
        # episode stats aggregator
        self._agg = agg

        # concurrent executor
        self.thread_pool = ThreadPoolExecutor(16)
        self.jobs = deque()

    @elements.timer.section("add")
    def add(self, step_dict: Dict[str, np.ndarray], worker: int):
        """
        Append one step to the current episode of ``worker``.

        Evicts the oldest completed episodes while the buffer exceeds
        ``config.replay.capacity`` (a falsy capacity means unlimited). When the
        step terminates or truncates the episode, the episode is saved
        asynchronously, optionally offloaded, replaced by a fresh one, and its
        stats are reported to the aggregator if one was given.
        """
        # add step_dict to the current episode
        episode = self.current[worker]
        episode.add(step_dict)
        self.num_steps_for_training += 1

        # remove old samples exceeding the capacity
        with elements.timer.section("remove"):
            capacity = self.config.replay.capacity
            if capacity:
                while self.num_steps_for_training > capacity:
                    if not self._evict_oldest():
                        # only in-progress episodes are left; nothing to evict
                        break

        # episode having one transition is ready for training
        if episode.uuid not in self.episode_sampler and len(episode) >= 2:
            self.episode_sampler(episode.uuid)

        # check if current episode is done
        done = step_dict["terminated"] or step_dict["truncated"]
        if done:
            self.jobs.append(self.thread_pool.submit(episode.save, self.directory))
            if self.config.replay.offload:
                filepath = self.directory / episode.filename
                self.uuid_to_filepath[episode.uuid] = str(filepath)
                # drop the in-memory copy; it is re-read from disk when sampled
                self.episodes[episode.uuid] = None
            # create new episode
            new_episode = Episode()
            self.current[worker] = new_episode
            self.episodes[new_episode.uuid] = new_episode
            # log the stat of completed episode
            if self._agg:
                self._agg.add(episode.stats(rewards_reduce=self.config.logging.rewards_reduce))

    def _episode_length(self, uuid: UUID) -> int:
        """Return the number of steps of a buffered episode, in memory or offloaded."""
        episode = self.episodes[uuid]
        if episode is not None:
            return len(episode)
        # offloaded episode: the length is the third component of the file stem
        stem = elements.Path(self.uuid_to_filepath[uuid]).stem
        return int(stem.split("-")[2])

    def _evict_oldest(self) -> bool:
        """
        Remove the oldest episode that is not currently being written.

        In-progress episodes are skipped: evicting one would leave its uuid
        re-registered in the sampler by `add` without a matching entry in
        ``self.episodes``.

        :return: True if an episode was removed, False if none was evictable.
        """
        in_progress = {episode.uuid for episode in self.current}
        victim = next((uuid for uuid in self.episodes if uuid not in in_progress), None)
        if victim is None:
            return False
        self.num_steps_for_training -= self._episode_length(victim)
        del self.episodes[victim]
        # episodes shorter than two steps were never registered in the sampler
        if victim in self.episode_sampler:
            del self.episode_sampler[victim]
        # only offloaded episodes have a recorded filepath
        self.uuid_to_filepath.pop(victim, None)
        return True

    @elements.timer.section("create_dataset")
    def create_dataset(self) -> Dict[str, np.ndarray]:
        """
        Sample ``config.train.batch_size`` sequences and merge them into one batch.

        Blocks until all pending episode saves have finished, so offloaded
        episodes are readable. With offloading the sequences are sampled in the
        thread pool, since each may require reading a file.
        """
        # wait for all done episodes to be saved
        while self.jobs:
            job = self.jobs.popleft()
            job.result()

        if self.config.replay.offload:
            samples = [self.thread_pool.submit(self._sample) for _ in range(self.config.train.batch_size)]
            samples: List[Dict[str, np.ndarray]] = [future.result() for future in samples]
        else:
            samples: List[Dict[str, np.ndarray]] = [self._sample() for _ in range(self.config.train.batch_size)]
        samples: Dict[str, np.ndarray] = merge_dict_list(samples, axis=1)
        # samples[key].shape = (ts, bs, ...)
        return samples

    def _sample(self) -> Dict[str, np.ndarray]:
        """
        Assemble one sequence of ``self.length`` steps.

        The sequence starts at a random step of a sampled episode; if that
        episode ends early, further sampled episodes are appended from their
        first step until the sequence is full.
        """
        sample, size = defaultdict(list), 0
        while size < self.length:
            episode_uuid = self.episode_sampler.sample()
            # self.episodes[episode_uuid] is None if the episode is offloaded
            episode = self.episodes[episode_uuid]
            if episode is None:
                episode = self.load_episode_from_disk(episode_uuid)
            # first trajectory begins from a random step, subsequent trajectories begin from the first step
            idx = np.random.randint(len(episode) - 1) if size == 0 else 0
            length = min(idx + (self.length - size), len(episode)) - idx
            for key in episode.data.keys():
                # copy so that later in-place edits do not touch the stored episode
                sample[key].extend(deepcopy(episode.data[key][idx: idx+length]))
            size += length
            # set the initial and terminal masks
            if "is_first" in sample.keys():
                sample["is_first"][-length] = np.array([True], dtype=bool)
            if "prev_actions" in sample.keys():
                sample["prev_actions"][-length] = np.zeros_like(sample["prev_actions"][-length])
            if "truncated" in sample.keys():
                sample["truncated"][-1] = np.array([True], dtype=bool)

        # assemble the sample
        for key in sample.keys():
            assert len(sample[key]) == self.length, (key, len(sample[key]))
        sample = {k: np.stack(v) for k, v in sample.items()}
        return sample

    def load_episode_from_disk(self, uuid: UUID):
        """Read an offloaded episode back from its recorded filepath."""
        filepath = elements.Path(self.uuid_to_filepath[uuid])
        episode = Episode.load(filepath)
        return episode

    def clear(self):
        """Drop all buffered episodes and start with fresh in-progress episodes."""
        self.current = [Episode() for _ in range(self.n_rollout_threads)]
        self.episodes = OrderedDict((episode.uuid, episode) for episode in self.current)
        self.episode_sampler = UniformSampler()
        self.num_steps_for_training = 0
        # no buffered episode refers to the recorded filepaths any more
        self.uuid_to_filepath = dict()

    def save(self):
        """No-op: episodes are already written to disk when they complete."""
        pass

    def load(self, data=None, capacity=None):
        """
        Load previously saved episodes to the in-memory buffer.

        The newest episodes found in the ``replay`` directory next to the
        checkpoint are restored until the step budget is used up. A falsy
        budget means unlimited, matching the capacity check in `add`.

        :param data: unused.
        :param capacity: step budget; defaults to ``config.replay.capacity``.
        """
        # load saved episode names from disk
        directory = elements.Path(self.config.train.checkpoint.from_checkpoint).parent.parent / "replay"
        names = [x.stem for x in directory.glob("*.npz")]
        # names start with the timestamp, so this orders newest first
        names = sorted(names, reverse=True)
        if not names:
            return

        # count the number of episodes to load
        capacity = capacity or self.config.replay.capacity
        if capacity:
            num_episodes = 0
            for name in names:
                length = int(name.split("-")[2])
                if capacity < length:
                    break
                num_episodes += 1
                capacity -= length
            names = names[:num_episodes]

        # restore the episodes and sampler, oldest first
        if self.config.replay.offload:
            for name in reversed(names):
                _, uuid, length = name.split("-")
                uuid = UUID(uuid)
                filepath = directory / (name + ".npz")
                self.uuid_to_filepath[uuid] = str(filepath)
                self._restore_episode(uuid, None, int(length))
        else:
            jobs = []
            for name in reversed(names):
                jobs.append(self.thread_pool.submit(Episode.load, directory / (name + ".npz")))
            for job in jobs:
                episode = job.result()
                self._restore_episode(episode.uuid, episode, len(episode))

        # restored episodes are older than the in-progress ones; keep the
        # insertion order of self.episodes consistent with age for eviction
        for episode in self.current:
            if episode.uuid in self.episodes:
                self.episodes.move_to_end(episode.uuid)

    def _restore_episode(self, uuid: UUID, episode, length: int):
        """Register a restored episode (``None`` if offloaded) and account for its steps."""
        if uuid in self.episodes:
            # already buffered; do not count its steps twice
            return
        self.episodes[uuid] = episode
        self.num_steps_for_training += length
        # same readiness rule as in `add`: sampling needs at least two steps
        if length >= 2:
            self.episode_sampler(uuid)

class OnPolicyBuffer:
    """
    Buffer for on-policy training that keeps, per rollout worker, the latest
    episode (for stats) and the concatenation of all steps collected since the
    last reset (for building the training batch).
    """

    def __init__(self, config, n_rollout_threads, agg=None):
        self.config = config
        # number of rollout threads
        self.n_rollout_threads = n_rollout_threads
        # the latest episode for each worker
        self.current: List[Episode] = [Episode() for _ in range(n_rollout_threads)]
        # the concatenated episodes sampled in one epoch for each worker
        self.completed: List[Episode] = [Episode() for _ in range(n_rollout_threads)]
        # episode stats aggregator
        self._agg = agg

    @elements.timer.section("add")
    def add(self, step_dict: Dict[str, np.ndarray], worker: int):
        """
        Append one step for ``worker``.

        When the step terminates or truncates the episode, its stats are
        reported to the aggregator (if one was given) and a fresh latest
        episode is started. The concatenated data is not affected by this.
        """
        # add step_dict to episode
        self.current[worker].add(step_dict)
        self.completed[worker].add(step_dict)

        # check if current episode is done
        done = step_dict["terminated"] or step_dict["truncated"]
        if done:
            # save done episode
            if self._agg:
                self._agg.add(self.current[worker].stats(rewards_reduce=self.config.logging.rewards_reduce))
            self.current[worker] = Episode()

    @elements.timer.section("create_dataset")
    def create_dataset(self) -> Dict[str, np.ndarray]:
        """
        Stack the collected steps of all workers into one batch.

        Each worker's steps are stacked along axis 0 and the workers along
        axis 1, so all workers must hold the same number of steps per key.
        """
        samples = defaultdict(list)
        for episode in self.completed:
            for key in episode.data.keys():
                # per-worker array of shape (ts, ...)
                sample = np.stack(episode.data[key], axis=0)
                samples[key].append(sample)
        samples: Dict[str, np.ndarray] = {k: np.stack(v, axis=1) for k, v in samples.items()}
        # samples[key].shape = (ts, bs, ...)
        return samples

    def reset(self):
        """
        Reset the buffer by leaving only the latest data. This method is called by the training buffer.

        Drops, from every key of every worker, as many leading entries as the
        shortest key of the first worker holds; keys that hold more entries
        keep their trailing ones. An empty buffer is left unchanged.
        """
        # default=0 makes the reset a no-op when nothing was collected yet
        min_len = min((len(v) for v in self.completed[0].data.values()), default=0)
        for episode in self.completed:
            for key, value in episode.data.items():
                episode.data[key] = value[min_len:]

    def clear(self):
        """
        Completely clear the buffers by reinitializing them. This method is called by the evaluation buffer.
        """
        self.current: List[Episode] = [Episode() for i in range(self.n_rollout_threads)]
        self.completed: List[Episode] = [Episode() for i in range(self.n_rollout_threads)]
