import datetime
import io
import random
import traceback
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import IterableDataset
import copy


def episode_len(episode):
    """Return the number of transitions stored in ``episode``.

    ``episode`` maps field names to arrays; the length is read from the
    leading axis of the first array in the mapping. One entry is subtracted
    because every episode starts with a dummy first transition.
    """
    first_field = next(iter(episode.values()))
    stored_steps = first_field.shape[0]
    return stored_steps - 1


def save_episode(episode, fn):
    """Write ``episode`` to ``fn`` as a compressed ``.npz`` archive.

    Args:
        episode: mapping of field name to array; each entry becomes one
            array in the archive.
        fn: path-like object exposing ``open`` (it is opened in ``'wb'``
            mode).
    """
    # Compress fully in memory first, then write the finished archive to
    # the target file in a single call.
    with io.BytesIO() as buffer:
        np.savez_compressed(buffer, **episode)
        payload = buffer.getvalue()
    with fn.open('wb') as out:
        out.write(payload)


def load_episode(fn):
    """Read an ``.npz`` episode from ``fn`` into a plain dict.

    Args:
        fn: path-like object exposing ``open`` (it is opened in ``'rb'``
            mode).

    Returns:
        A dict mapping every array name in the archive to its array. All
        arrays are read while the file is still open.
    """
    with fn.open('rb') as handle:
        archive = np.load(handle)
        loaded = {}
        for name in archive.keys():
            loaded[name] = archive[name]
        return loaded


class ReplayBufferStorage:
    """Accumulates steps from ``n_envs`` environments and saves episodes.

    Every call to :meth:`add` appends one step (data fields plus meta
    fields) to an in-memory buffer. When the step reports ``last()``, the
    buffer is split into one episode per environment and each episode is
    written to ``replay_dir`` as ``<timestamp>_<index>_<length>.npz``.
    If hindsight relabelling is enabled (``cfg.her``), a relabelled copy of
    the episodes is stored as well, before the original ones.
    """

    def __init__(self, data_specs, meta_specs, replay_dir, n_envs, cfg):
        """Set up the storage and count episodes already in ``replay_dir``.

        Args:
            data_specs: iterable of specs (``name``, ``shape``, ``dtype``)
                for the fields read from each time step.
            meta_specs: iterable of specs for the fields passed as ``meta``.
            replay_dir: directory for the ``.npz`` files; created if missing
                (parents are not created).
            n_envs: number of environments batched in every time step.
            cfg: configuration; ``cfg.her`` toggles hindsight relabelling and
                ``cfg.agent.*`` is read while relabelling.
        """
        self._data_specs = data_specs
        self._meta_specs = meta_specs
        self._replay_dir = replay_dir
        self.n_envs = n_envs
        # Relabelling is off unless the config explicitly carries "her".
        self.her = cfg.her if "her" in cfg else False
        self.cfg = cfg  # kept because HER relabelling reads cfg.agent.*
        replay_dir.mkdir(exist_ok=True)
        self._current_episode = defaultdict(list)
        self._preload()

    def __len__(self):
        """Return the total number of transitions accounted for so far."""
        return self._num_transitions

    def add(self, time_step, meta):
        """Append one batched step; flush to disk when the episode ends.

        Args:
            time_step: indexable by spec name and exposing ``last()``. Each
                data field is reshaped to ``(n_envs, spec.shape[0])`` and
                cast to ``spec.dtype``.
            meta: mapping of meta field name to value; values are stored
                as given.
        """
        pending = self._current_episode
        for name, entry in meta.items():
            pending[name].append(entry)
        for spec in self._data_specs:
            batch = time_step[spec.name]
            batch = batch.reshape((self.n_envs, spec.shape[0])).astype(spec.dtype)
            assert (spec.shape[0] == batch.shape[1]
                    and self.n_envs == batch.shape[0]
                    and spec.dtype == batch.dtype)
            pending[spec.name].append(batch)

        if not time_step.last():
            return

        # Split the (time, env, ...) stacks into one dict per environment.
        episode = [dict() for _ in range(self.n_envs)]
        for specs in (self._data_specs, self._meta_specs):
            for spec in specs:
                stacked = np.array(pending[spec.name])
                for env_id, env_episode in enumerate(episode):
                    env_episode[spec.name] = np.array(stacked[:, env_id],
                                                      spec.dtype)

        if self.her:
            # The relabelled copy is written before the original episodes.
            self._store_episode(self._relabel_hindsight(episode))

        self._current_episode = defaultdict(list)
        self._store_episode(episode)

    def _relabel_hindsight(self, episode):
        """Return a deep copy of ``episode`` with hindsight-relabelled skills.

        Every ``sk_trans`` steps a new skill is derived from an observation
        near the end of that segment, and it overwrites the leading entries
        of ``'skill'`` for each step until the next segment starts.
        """
        relabelled = copy.deepcopy(episode)
        eps_len = episode[0]['step'][-1] + 1  # would be 500
        sk_trans = 200  # skill frequency

        for env_id in range(self.n_envs):
            relabel_skill = None
            for t in range(eps_len):
                if t % sk_trans == 0:
                    # Pick a random step up to 49 steps before the segment
                    # end (or the episode end, whichever comes first).
                    offset = np.random.randint(50)
                    rlb_idx = np.minimum(eps_len, t + sk_trans) - offset
                    relabel_skill, sk_total_dim = self._hindsight_skill(
                        episode[env_id], rlb_idx)
                # Index t + 1 skips the dummy first transition.
                relabelled[env_id]['skill'][t + 1][:sk_total_dim] = relabel_skill
        return relabelled

    def _hindsight_skill(self, env_episode, rlb_idx):
        """Build the flat one-hot skill from the observation at ``rlb_idx``.

        Returns:
            ``(skill, sk_total_dim)``: the flattened per-channel one-hot
            skill, and the number of leading observation/skill entries it
            covers.
        """
        agent_cfg = self.cfg.agent
        sk_dim = int(agent_cfg.skill_dim)
        sk_ch = int(agent_cfg.gc_skill_channel)
        sk_total_dim = int(sk_dim * sk_ch)
        skill_state = env_episode['observation'][rlb_idx][:sk_total_dim]
        skill_state = skill_state.reshape([sk_ch, sk_dim])

        # With return_dist the smallest entry per channel wins, otherwise
        # the largest one does.
        pick = np.argmin if agent_cfg.return_dist else np.argmax
        sk_idx = pick(skill_state, axis=1)

        one_hot = np.zeros_like(skill_state)
        one_hot[np.arange(sk_idx.shape[0]), sk_idx] = 1.0
        return one_hot.flatten(), sk_total_dim

    def _preload(self):
        """Initialise the counters from ``.npz`` files already on disk.

        The episode length is taken from the third ``_``-separated part of
        each file stem.
        """
        self._num_episodes = 0
        self._num_transitions = 0
        for path in self._replay_dir.glob('*.npz'):
            _stamp, _index, length = path.stem.split('_')
            self._num_episodes += 1
            self._num_transitions += int(length)

    def _store_episode(self, episode):
        """Write one episode per environment and update the counters.

        Args:
            episode: list with one field-name-to-array dict per environment.
        """
        first_idx = self._num_episodes
        # TODO: We assume each episode has the same length
        eps_len = episode_len(episode[0])
        self._num_episodes += self.n_envs
        self._num_transitions += eps_len * self.n_envs
        ts = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')

        for env_id in range(self.n_envs):
            target = self._replay_dir / f'{ts}_{first_idx + env_id}_{eps_len}.npz'
            save_episode(episode[env_id], target)


class ReplayBuffer(IterableDataset):
    """Endless stream of n-step transitions sampled from stored episodes.

    Episodes are periodically loaded from the storage's replay directory
    into memory (each DataLoader worker takes the files whose episode index
    maps to it) and transitions are then sampled uniformly from a randomly
    chosen in-memory episode.
    """

    def __init__(self, storage, max_size, num_workers, nstep, discount,
                 fetch_every, save_snapshot):
        """Create an empty buffer.

        Args:
            storage: object exposing ``_replay_dir`` (directory of ``.npz``
                episodes) and ``_meta_specs``.
            max_size: maximum number of transitions held in memory.
            num_workers: number of DataLoader workers; values below 1 are
                treated as 1 when partitioning episode files.
            nstep: number of steps accumulated into each sampled reward.
            discount: extra discount factor applied per accumulated step.
            fetch_every: number of samples between scans of the directory.
            save_snapshot: when false, episode files are deleted from disk
                once loaded into memory.
        """
        self._storage = storage
        self._size = 0
        self._max_size = max_size
        self._num_workers = max(1, num_workers)
        self._episode_fns = []
        self._episodes = dict()
        self._nstep = nstep
        self._discount = discount
        self._fetch_every = fetch_every
        # Start at the threshold so that the very first sample triggers a fetch.
        self._samples_since_last_fetch = fetch_every
        self._save_snapshot = save_snapshot

    def _sample_episode(self):
        """Return a uniformly chosen in-memory episode."""
        eps_fn = random.choice(self._episode_fns)
        return self._episodes[eps_fn]

    def _store_episode(self, eps_fn):
        """Load ``eps_fn`` into memory, evicting old episodes if needed.

        Returns:
            ``True`` when the episode was loaded, ``False`` when the file
            could not be read.
        """
        try:
            episode = load_episode(eps_fn)
        except Exception:
            # Best effort: an unreadable file is skipped. Only ordinary
            # errors are handled so that KeyboardInterrupt / SystemExit
            # still propagate.
            return False
        eps_len = episode_len(episode)
        # Drop the earliest file names until the new episode fits.
        while eps_len + self._size > self._max_size:
            early_eps_fn = self._episode_fns.pop(0)
            early_eps = self._episodes.pop(early_eps_fn)
            self._size -= episode_len(early_eps)
            early_eps_fn.unlink(missing_ok=True)
        self._episode_fns.append(eps_fn)
        self._episode_fns.sort()
        self._episodes[eps_fn] = episode
        self._size += eps_len

        if not self._save_snapshot:
            eps_fn.unlink(missing_ok=True)
        return True

    def _try_fetch(self):
        """Scan the replay directory for new episodes belonging to this worker.

        Does nothing until ``fetch_every`` samples have been drawn since the
        previous scan. Files are visited in reverse sorted order and the scan
        stops at the first file that is already loaded, that would exceed
        ``max_size`` for this scan, or that fails to load.
        """
        if self._samples_since_last_fetch < self._fetch_every:
            return
        self._samples_since_last_fetch = 0
        # get_worker_info() returns None in the main process (no workers).
        worker_info = torch.utils.data.get_worker_info()
        worker_id = 0 if worker_info is None else worker_info.id
        eps_fns = sorted(self._storage._replay_dir.glob('*.npz'), reverse=True)
        fetched_size = 0
        for eps_fn in eps_fns:
            # File stems look like "<timestamp>_<index>_<length>".
            eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]]
            if eps_idx % self._num_workers != worker_id:
                continue
            if eps_fn in self._episodes:
                break
            if fetched_size + eps_len > self._max_size:
                break
            fetched_size += eps_len
            if not self._store_episode(eps_fn):
                break

    def _sample(self):
        """Draw one n-step transition.

        Returns:
            ``(obs, action, reward, discount, next_obs, *meta)`` where
            ``reward`` is the discounted sum of the ``nstep`` rewards
            starting at the sampled index, ``discount`` is the accumulated
            product of the per-step discounts and the buffer's discount
            factor, and ``meta`` holds one value per meta spec.
        """
        try:
            self._try_fetch()
        except Exception:
            # Fetching is best effort: report the problem and keep sampling
            # from what is already in memory. BaseException subclasses such
            # as KeyboardInterrupt are deliberately not swallowed.
            traceback.print_exc()
        self._samples_since_last_fetch += 1
        episode = self._sample_episode()
        # add +1 for the first dummy transition
        idx = np.random.randint(0, episode_len(episode) - self._nstep + 1) + 1
        meta = [episode[spec.name][idx] for spec in self._storage._meta_specs]
        obs = episode['observation'][idx - 1]
        action = episode['action'][idx]
        next_obs = episode['observation'][idx + self._nstep - 1]
        reward = np.zeros_like(episode['reward'][idx])
        discount = np.ones_like(episode['discount'][idx])
        for i in range(self._nstep):
            step_reward = episode['reward'][idx + i]
            reward += discount * step_reward
            discount *= episode['discount'][idx + i] * self._discount
        return (obs, action, reward, discount, next_obs, *meta)

    def __iter__(self):
        """Yield sampled transitions forever."""
        while True:
            yield self._sample()


def _worker_init_fn(worker_id):
    """Give each DataLoader worker its own NumPy and ``random`` seed.

    The seed is the first word of the current NumPy global RNG state plus
    ``worker_id``.

    Args:
        worker_id: integer id of the worker being initialised.
    """
    # The state word is a numpy.uint32. Convert it to a plain int because
    # random.seed() rejects NumPy integer scalars on Python >= 3.11, and do
    # the addition in Python ints so it cannot overflow uint32.
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed() only accepts values in [0, 2**32).
    seed = (base_seed + worker_id) % (2 ** 32)
    np.random.seed(seed)
    random.seed(seed)


def make_replay_loader(storage, max_size, batch_size, num_workers,
                       save_snapshot, nstep, discount):
    """Build a ``DataLoader`` that streams batches from a ``ReplayBuffer``.

    Args:
        storage: replay storage handed to the ``ReplayBuffer``.
        max_size: total in-memory capacity in transitions; it is divided
            evenly between the workers (at least one).
        batch_size: number of transitions per batch.
        num_workers: number of DataLoader worker processes.
        save_snapshot: forwarded to the ``ReplayBuffer``.
        nstep: forwarded to the ``ReplayBuffer``.
        discount: forwarded to the ``ReplayBuffer``.

    Returns:
        The configured ``torch.utils.data.DataLoader``.
    """
    worker_count = max(1, num_workers)
    per_worker_capacity = max_size // worker_count

    dataset = ReplayBuffer(
        storage,
        per_worker_capacity,
        num_workers,
        nstep,
        discount,
        fetch_every=1000,
        save_snapshot=save_snapshot,
    )

    # Fixed-seed generator handed to the DataLoader.
    seeded_generator = torch.Generator()
    seeded_generator.manual_seed(0)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        worker_init_fn=_worker_init_fn,
        generator=seeded_generator,
    )
