from typing import Optional, NamedTuple, Tuple

import itertools

import dm_env

import jax

import numpy as np

from rl import rollout
from rl import policy


class Dataset(NamedTuple):
    """Transition data stored as five parallel NumPy arrays.

    The class performs no validation: it does not check that the five
    arrays agree in their leading dimensions. `shape` and `num_samples`
    are derived from `r_t` only.
    """

    # Rewards; this array's shape is what `shape` reports.
    r_t: np.ndarray
    # Discounts paired with the rewards.
    discount_t: np.ndarray
    # Observations before each transition.
    obs_t: np.ndarray
    # Actions (presumably the ones taken at `obs_t` -- confirm with the
    # code that builds the dataset).
    a_t: np.ndarray
    # Observations after each transition.
    obs_tp1: np.ndarray

    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape of the reward array `r_t`."""
        return self.r_t.shape

    def num_samples(self) -> int:
        """Return the number of entries in `r_t` (product of `shape`)."""
        return int(np.prod(self.shape))

    def save(self, file, compress: bool = True) -> None:
        """Write the five arrays to an `.npz` archive keyed by field name.

        Args:
            file: Path or writable binary file object, forwarded to NumPy.
                NumPy appends ``.npz`` to a path that lacks that suffix, so
                the name to pass to `load` may differ from `file`.
            compress: Use `np.savez_compressed` when true, plain `np.savez`
                otherwise.
        """
        writer = np.savez_compressed if compress else np.savez
        writer(file, **self._asdict())

    @classmethod
    def load(cls, file) -> "Dataset":
        """Read a dataset previously written by `save`.

        Args:
            file: Path or readable binary file object, forwarded to
                `np.load`.

        Returns:
            A new instance holding the arrays stored in the archive.

        Raises:
            TypeError: If the archive's keys do not match the field names.
        """
        # np.load returns a lazy NpzFile that keeps the archive open. Read
        # every array while it is open, then close it deterministically
        # instead of leaking the handle until garbage collection.
        with np.load(file) as archive:
            return cls(**{name: archive[name] for name in archive.files})


def generate_dataset(seed: int,
                     env: dm_env.Environment,
                     policy: policy.Policy,
                     num_trajs: int,
                     max_steps_per_traj: Optional[int] = None,
                     ) -> Dataset:
    """Roll out `num_trajs` trajectories and pack them into a `Dataset`.

    Args:
        seed: Integer used to create the initial JAX PRNG key.
        env: Environment handed to `rollout.generate_trajectory`.
        policy: Policy handed to `rollout.generate_trajectory`.
        num_trajs: Number of trajectories to generate.
        max_steps_per_traj: Forwarded as `max_steps` to the rollout helper;
            `None` is passed through unchanged.

    Returns:
        The result of `pack_trajectories` on the generated trajectories.
    """
    rng = jax.random.PRNGKey(seed)
    collected = []
    for _ in range(num_trajs):
        # The first half of the split drives this rollout; the second half
        # is carried forward so every trajectory gets its own key.
        rollout_rng, rng = jax.random.split(rng)
        episode = rollout.generate_trajectory(
            rollout_rng, env, policy, max_steps=max_steps_per_traj)
        collected.append(episode)

    return pack_trajectories(collected)


def pack_trajectories(trajs) -> Dataset:
    """Flatten a collection of trajectories into one `Dataset` of transitions.

    Args:
        trajs: Iterable of `(timesteps, actions)` pairs. Each timestep must
            unpack into four fields; the last three are used as reward,
            discount and observation (presumably `dm_env.TimeStep` --
            confirm against the rollout code). `actions` is iterated once
            per trajectory.

    Returns:
        A `Dataset` whose arrays concatenate all trajectories in order.
        For each trajectory the first reward and discount are dropped,
        `obs_t` takes every observation but the last and `obs_tp1` every
        observation but the first.

    Raises:
        ValueError: If `trajs` is empty, if no trajectory has two or more
            timesteps, or if an entry does not have the expected structure.

    NOTE(review): nothing checks that `len(actions) == len(timesteps) - 1`;
    if that does not hold, `a_t` will not line up with the other arrays.
    """
    trajs = list(trajs)
    if not trajs:
        # Fail with a clear message instead of an obscure unpacking error.
        raise ValueError("pack_trajectories() requires at least one trajectory")

    r_t, discount_t, obs_t, obs_tp1, a_t = [], [], [], [], []
    for timesteps, actions in trajs:
        # Transpose the per-step tuples into one sequence per field.
        _, rewards, discounts, observations = zip(*timesteps)
        # Entry 0 belongs to the initial timestep, which closes no transition.
        r_t.extend(rewards[1:])
        discount_t.extend(discounts[1:])
        obs_t.extend(observations[:-1])
        obs_tp1.extend(observations[1:])
        a_t.extend(actions)

    if not r_t:
        raise ValueError(
            "pack_trajectories() found no transitions: every trajectory "
            "has fewer than two timesteps")

    return Dataset(
        np.stack(r_t),
        np.stack(discount_t),
        np.stack(obs_t),
        np.stack(a_t),
        np.stack(obs_tp1),
    )


def save_dataset(file, dataset: Dataset, compress: bool = True):
    """Write `dataset` to `file` by delegating to `dataset.save`.

    Args:
        file: Destination passed straight through to `dataset.save`.
        dataset: The dataset to store.
        compress: Forwarded as the `compress` keyword of `dataset.save`.
    """
    dataset.save(file, compress=compress)


def load_dataset(file) -> Dataset:
    """Read a `Dataset` from `file` by delegating to `Dataset.load`.

    Args:
        file: Source passed straight through to `Dataset.load`.

    Returns:
        The dataset returned by `Dataset.load`.
    """
    dataset = Dataset.load(file)
    return dataset
