
import datetime
import io
import random
import traceback
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import IterableDataset


def episode_len(episode):
    """Return the number of real transitions stored in ``episode``.

    Every array in the episode carries one extra leading entry for the dummy
    first transition, so the length is the first array's leading dimension
    minus one.
    """
    first_array = next(iter(episode.values()))
    return first_array.shape[0] - 1


def save_episode(episode, fn):
    """Write ``episode`` (name -> array mapping) to path ``fn`` as a compressed .npz.

    The archive is assembled in memory first, so the target file is only
    opened once the compressed payload is complete.
    """
    with io.BytesIO() as buffer:
        np.savez_compressed(buffer, **episode)
        payload = buffer.getvalue()
    with fn.open('wb') as out:
        out.write(payload)


def load_episode(fn):
    """Load an episode saved by ``save_episode`` from path ``fn``.

    Returns a plain ``dict`` mapping each array name in the .npz archive to a
    fully materialized numpy array. Both the file handle and the ``NpzFile``
    wrapper (which holds an open zip archive) are closed before returning.
    """
    with fn.open('rb') as f, np.load(f) as archive:
        # Materialize every array while the archive is still open.
        return {name: archive[name] for name in archive.files}


class ReplayBufferStorage:
    """Accumulate per-step data and write each finished episode to ``replay_dir``.

    Episodes are saved as ``{timestamp}_{episode_index}_{episode_length}.npz``.
    On construction, existing ``*.npz`` files in the directory are counted (from
    their file names) so that ``len(storage)`` and new episode indices continue
    from what is already on disk.
    """

    def __init__(self, data_specs, replay_dir):
        self._data_specs = data_specs
        self._replay_dir = replay_dir
        replay_dir.mkdir(exist_ok=True)
        self._current_episode = defaultdict(list)
        self._preload()

    def __len__(self):
        """Total number of transitions stored on disk (in-progress episode excluded)."""
        return self._num_transitions

    def add(self, time_step):
        """Append one step; persist the episode when ``time_step.last()`` is true.

        ``time_step`` must support ``time_step[spec.name]`` for every spec and
        a ``last()`` method. Scalar values are broadcast to the spec's shape
        and dtype; every value must then match the spec's shape and dtype.
        """
        for spec in self._data_specs:
            item = time_step[spec.name]
            if np.isscalar(item):
                item = np.full(spec.shape, item, spec.dtype)
            assert spec.shape == item.shape and spec.dtype == item.dtype
            self._current_episode[spec.name].append(item)

        if not time_step.last():
            return

        finished = {
            spec.name: np.array(self._current_episode[spec.name], spec.dtype)
            for spec in self._data_specs
        }
        self._current_episode = defaultdict(list)
        self._store_episode(finished)

    def _preload(self):
        """Recount episodes/transitions from the names of existing .npz files."""
        lengths = []
        for path in self._replay_dir.glob('*.npz'):
            _stamp, _index, length_text = path.stem.split('_')
            lengths.append(int(length_text))
        self._num_episodes = len(lengths)
        self._num_transitions = sum(lengths)

    def _store_episode(self, episode):
        """Save ``episode`` under a timestamped name and update the counters."""
        index = self._num_episodes
        length = episode_len(episode)
        self._num_episodes = index + 1
        self._num_transitions = self._num_transitions + length
        stamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
        save_episode(episode, self._replay_dir / f'{stamp}_{index}_{length}.npz')


class DualModeReplayBuffer(IterableDataset):
    """Infinite iterable dataset that samples from episode files in ``replay_dir``.

    Episode files are named ``{timestamp}_{episode_index}_{episode_length}.npz``
    (the same naming ``ReplayBufferStorage`` writes). Every ``fetch_every``
    samples, new files are loaded into an in-memory cache capped at
    ``max_size`` transitions. When the cap would be exceeded, the oldest cached
    episodes (by file name) are dropped and their files deleted.

    Sampling modes:

    * ``'standard'``: each sample is one n-step transition
      ``(obs, action, reward, discount, next_obs)``. Episodes are sharded
      across DataLoader workers by ``episode_index % num_workers``.
    * ``'aux'``: each sample is a segment of ``aux_segment_length`` n-step
      transitions spaced ``aux_stride`` steps apart, returned as
      ``(observations, actions, episode_total_reward, start_idx,
      next_observations, rewards, discounts)``. Episodes are not sharded.
    """

    def __init__(self, replay_dir, max_size, num_workers, nstep, discount,
                 fetch_every, save_snapshot, mode='standard', aux_segment_length=64, aux_stride=6):
        self._replay_dir = replay_dir
        self._size = 0  # transitions currently held in self._episodes
        self._max_size = max_size
        self._num_workers = max(1, num_workers)
        self._episode_fns = []  # cached episode paths, kept sorted
        self._episodes = dict()  # episode path -> episode dict
        self._nstep = nstep
        self._discount = discount
        self._fetch_every = fetch_every
        # Start "due" so the very first sample triggers a fetch.
        self._samples_since_last_fetch = fetch_every
        self._save_snapshot = save_snapshot

        self.mode = mode
        self.aux_segment_length = aux_segment_length
        self.aux_stride = aux_stride
        self.aux_required_length = aux_segment_length * aux_stride

    def _nstep_transition(self, episode, idx):
        """Build the n-step transition whose action is ``episode['action'][idx]``.

        Returns ``(obs, action, reward, discount, next_obs)``. ``obs`` is the
        observation at ``idx - 1`` and ``next_obs`` the one at
        ``idx + nstep - 1``. ``reward`` is the sum of the ``nstep`` rewards
        from ``idx``, each weighted by the discount accumulated before it.
        ``discount`` is the product of the per-step discounts, each multiplied
        by ``self._discount``.
        """
        obs = episode['observation'][idx - 1]
        action = episode['action'][idx]
        next_obs = episode['observation'][idx + self._nstep - 1]
        reward = np.zeros_like(episode['reward'][idx])
        discount = np.ones_like(episode['discount'][idx])
        for k in range(self._nstep):
            step_reward = episode['reward'][idx + k]
            reward += discount * step_reward
            discount *= episode['discount'][idx + k] * self._discount
        return (obs, action, reward, discount, next_obs)

    def _sample_standard(self, episode):
        """Sample one n-step transition uniformly from ``episode``."""
        idx = np.random.randint(0, episode_len(episode) - self._nstep + 1) + 1
        return self._nstep_transition(episode, idx)

    def _aux_min_episode_len(self):
        """Smallest episode length for which an aux segment can be sampled.

        ``_sample_aux`` draws ``start_idx`` from
        ``[1, len - aux_required_length + 1 - nstep]``. That range is
        non-empty only when ``len >= aux_required_length + nstep``.
        """
        return self.aux_required_length + self._nstep

    def _sample_aux(self, episode):
        """Sample a strided segment of n-step transitions from ``episode``.

        Episodes shorter than ``_aux_min_episode_len()`` fall back to
        ``_sample_standard``, which returns a 5-tuple rather than the 7-tuple
        below. ``_sample`` avoids this by preferring long-enough episodes.
        """
        ep_len = episode_len(episode)
        if ep_len < self._aux_min_episode_len():
            return self._sample_standard(episode)

        max_start_idx = ep_len - self.aux_required_length + 1
        # +1 skips the dummy first transition (see episode_len).
        start_idx = np.random.randint(0, max_start_idx - self._nstep) + 1

        observations, actions, next_observations = [], [], []
        rewards, discounts = [], []
        for j in range(self.aux_segment_length):
            idx = start_idx + j * self.aux_stride
            obs, action, reward, discount, next_obs = self._nstep_transition(episode, idx)
            observations.append(obs)
            actions.append(action)
            next_observations.append(next_obs)
            rewards.append(reward)
            discounts.append(discount)

        episode_total_reward = np.sum(episode['reward'][1:])

        return (np.array(observations), np.array(actions), episode_total_reward,
                start_idx, np.array(next_observations), np.array(rewards),
                np.array(discounts))

    def _sample(self):
        """Refresh the cache if due, then draw one sample for ``self.mode``."""
        is_aux = self.mode == 'aux'
        try:
            if is_aux:
                self._try_fetch_aux()
            else:
                self._try_fetch()
        except Exception:
            # Fetching is best-effort: report and keep sampling from the cache.
            traceback.print_exc()
        self._samples_since_last_fetch += 1

        if is_aux:
            episode = self._sample_episode(min_len=self._aux_min_episode_len())
            return self._sample_aux(episode)
        return self._sample_standard(self._sample_episode())

    def _sample_episode(self, min_len=0):
        """Pick a cached episode uniformly at random.

        When ``min_len > 0``, only episodes with at least ``min_len``
        transitions are considered. If none qualify, all cached episodes are
        used instead.
        """
        candidates = self._episode_fns
        if min_len > 0:
            long_enough = [fn for fn in candidates
                           if episode_len(self._episodes[fn]) >= min_len]
            if long_enough:
                candidates = long_enough
        eps_fn = random.choice(candidates)
        return self._episodes[eps_fn]

    def _store_episode(self, eps_fn):
        """Load ``eps_fn`` into the cache, evicting oldest episodes as needed.

        Returns False if the file could not be loaded, True otherwise.
        """
        try:
            episode = load_episode(eps_fn)
        except Exception:
            # Unreadable file (presumably still being written): tell the
            # caller to stop this fetch round.
            return False
        eps_len = episode_len(episode)
        # NOTE(review): in 'aux' mode episodes are read without worker
        # sharding, so deleting files here (eviction, or save_snapshot=False
        # below) may remove episodes that other loaders reading the same
        # replay_dir have not fetched yet -- confirm this is intended.
        while eps_len + self._size > self._max_size:
            early_eps_fn = self._episode_fns.pop(0)
            early_eps = self._episodes.pop(early_eps_fn)
            self._size -= episode_len(early_eps)
            early_eps_fn.unlink(missing_ok=True)
        self._episode_fns.append(eps_fn)
        self._episode_fns.sort()
        self._episodes[eps_fn] = episode
        self._size += eps_len
        if not self._save_snapshot:
            eps_fn.unlink(missing_ok=True)
        return True

    def _try_fetch(self, shard_by_worker=True):
        """Load new episode files if ``fetch_every`` samples have elapsed.

        Files are scanned newest-name first. The scan stops at the first file
        already cached, at the first load failure, or once ``max_size``
        transitions have been fetched this round. With ``shard_by_worker``,
        only files whose episode index maps to this DataLoader worker are
        loaded.
        """
        if self._samples_since_last_fetch < self._fetch_every:
            return
        self._samples_since_last_fetch = 0
        worker_info = torch.utils.data.get_worker_info()
        # get_worker_info() is None outside DataLoader worker processes.
        worker_id = 0 if worker_info is None else worker_info.id
        eps_fns = sorted(self._replay_dir.glob('*.npz'), reverse=True)
        fetched_size = 0
        for eps_fn in eps_fns:
            eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]]
            if shard_by_worker and eps_idx % self._num_workers != worker_id:
                continue
            if eps_fn in self._episodes:
                break
            if fetched_size + eps_len > self._max_size:
                break
            fetched_size += eps_len
            if not self._store_episode(eps_fn):
                break

    def _try_fetch_aux(self):
        """Like ``_try_fetch`` but loads every episode (no worker sharding)."""
        self._try_fetch(shard_by_worker=False)

    def __iter__(self):
        while True:
            yield self._sample()


def aux_collate_fn(batch):
    """Collate aux-mode samples, flattening their segments into one batch.

    Each sample must be the 7-tuple produced by ``DualModeReplayBuffer`` in
    ``'aux'`` mode: ``(observations, actions, episode_total_reward, start_idx,
    next_observations, rewards, discounts)``. The per-step fields are arrays
    whose first axis indexes the segment.

    Returns a 7-tuple of numpy arrays in the same order. The per-step fields
    are concatenated along axis 0, so B samples of S steps give B*S rows.
    ``episode_rewards`` and ``start_indices`` have one entry per sample.

    Raises:
        ValueError: if ``batch`` is empty or any sample does not have exactly
            7 fields (e.g. the 5-field standard transition the buffer falls
            back to for episodes too short to hold a segment).
    """
    if not batch:
        raise ValueError('aux_collate_fn received an empty batch')
    for i, sample in enumerate(batch):
        if len(sample) != 7:
            raise ValueError(
                f'aux_collate_fn expects 7-field aux samples; sample {i} has '
                f'{len(sample)} fields (episode too short for an aux segment?)')

    # Transpose: list of per-sample tuples -> one tuple per field.
    (observations, actions, episode_rewards, start_indices,
     next_observations, rewards, discounts) = zip(*batch)

    return (
        np.concatenate(observations, axis=0),
        np.concatenate(actions, axis=0),
        np.array(episode_rewards),
        np.array(start_indices),
        np.concatenate(next_observations, axis=0),
        np.concatenate(rewards, axis=0),
        np.concatenate(discounts, axis=0),
    )


def _worker_init_fn(worker_id):
    seed = np.random.get_state()[1][0] + worker_id
    np.random.seed(seed)
    random.seed(seed)

def make_replay_loader(replay_dir, max_size, batch_size, num_workers,
                       save_snapshot, nstep, discount, mode='standard',
                       aux_segment_length=64, aux_stride=6,
                       aux_max_size=100000, aux_batch_size=8):
    """Build a ``torch.utils.data.DataLoader`` streaming samples from ``replay_dir``.

    Args:
        replay_dir: directory of episode ``.npz`` files.
        max_size: total transition budget for the standard loader, split
            evenly across its workers.
        batch_size: batch size of the standard loader.
        num_workers: DataLoader workers for the standard loader; also used
            as the buffer's shard count.
        save_snapshot: whether the aux buffer keeps files after loading them.
            The standard buffer always keeps them (``save_snapshot=True``).
        nstep, discount: n-step return parameters.
        mode: ``'standard'`` or ``'aux'``.
        aux_segment_length, aux_stride: aux segment shape.
        aux_max_size: transition cap of the aux buffer (default 100000).
        aux_batch_size: batch size of the aux loader (default 8).

    Returns:
        The DataLoader for the requested mode. The aux loader uses a single
        worker and ``aux_collate_fn``.

    Raises:
        ValueError: if ``mode`` is not one of the accepted values.
        NotImplementedError: if ``mode == 'both'``.
    """
    if mode not in ('standard', 'aux', 'both'):
        raise ValueError(
            f"mode must be one of 'standard', 'aux', 'both'; got {mode!r}")

    if mode == 'standard':
        # Each worker loads a disjoint shard of episodes, so split the budget.
        max_size_per_worker = max_size // max(1, num_workers)
        standard_iterable = DualModeReplayBuffer(
            replay_dir, max_size_per_worker, num_workers, nstep, discount,
            fetch_every=1000, save_snapshot=True,
            mode='standard'
        )
        return torch.utils.data.DataLoader(
            standard_iterable,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            worker_init_fn=_worker_init_fn
        )

    if mode == 'aux':
        aux_iterable = DualModeReplayBuffer(
            replay_dir, aux_max_size, num_workers, nstep, discount,
            fetch_every=1000, save_snapshot=save_snapshot,
            mode='aux', aux_segment_length=aux_segment_length, aux_stride=aux_stride
        )
        return torch.utils.data.DataLoader(
            aux_iterable,
            batch_size=aux_batch_size,
            num_workers=1,
            pin_memory=True,
            collate_fn=aux_collate_fn,
            worker_init_fn=_worker_init_fn
        )

    # mode == 'both' used to fall through and silently return None.
    raise NotImplementedError(
        "make_replay_loader: mode='both' is not implemented")