import os
from typing import Any, Iterator, Optional

from absl import logging
from acme import types
from acme import wrappers
import d4rl
import gym
import h5py
import jax
import jax.numpy as jnp
import numpy as np
import tqdm
import tree

def get_d4rl_dataset(env):
  """Builds a batched `types.Transition` from `d4rl.qlearning_dataset(env)`.

  Args:
    env: Environment passed unchanged to `d4rl.qlearning_dataset`.

  Returns:
    A `types.Transition` whose `observation`, `action`, `reward` and
    `next_observation` fields are the corresponding dataset arrays, and whose
    `discount` field is `1 - terminals` cast to float32.
  """
  raw = d4rl.qlearning_dataset(env)
  fields = {
      "observation": raw["observations"],
      "action": raw["actions"],
      "reward": raw["rewards"],
      "next_observation": raw["next_observations"],
  }
  # discount = 1 - terminal flag, computed in float32.
  fields["discount"] = 1.0 - raw["terminals"].astype(np.float32)
  return types.Transition(**fields)


def make_environment(name: str, seed: Optional[int] = None):
  """Creates a gym environment by name and wraps it with Acme wrappers.

  Args:
    name: Environment id handed to `gym.make`.
    seed: If not None, passed to the raw environment's `seed` method before
      any wrapper is applied.

  Returns:
    The environment wrapped, innermost first, in `GymWrapper`,
    `SinglePrecisionWrapper` and `CanonicalSpecWrapper`.
  """
  env = gym.make(name)
  if seed is not None:
    env.seed(seed)

  # Applied in order: each wrapper receives the result of the previous one.
  wrapper_chain = (
      wrappers.GymWrapper,
      wrappers.SinglePrecisionWrapper,
      wrappers.CanonicalSpecWrapper,
  )
  for wrap in wrapper_chain:
    env = wrap(env)
  return env


def split_into_trajectories(observations, actions, rewards, masks, dones_float,
                            next_observations):
  """Groups per-step arrays into a list of episodes.

  Args:
    observations: Per-step observations; its length sets the number of steps.
    actions: Per-step actions.
    rewards: Per-step rewards.
    masks: Per-step values stored in each transition's `discount` field.
    dones_float: Per-step flags; a value equal to 1.0 closes the current
      episode after that step.
    next_observations: Per-step next observations.

  Returns:
    A list of episodes, each a list of `types.Transition`. With no steps the
    result is a list holding one empty episode.
  """
  num_steps = len(observations)
  final_step = num_steps - 1

  episode = []
  episodes = [episode]
  for step in tqdm.tqdm(range(num_steps)):
    episode.append(
        types.Transition(
            observation=observations[step],
            action=actions[step],
            reward=rewards[step],
            discount=masks[step],
            next_observation=next_observations[step]))
    # Open a fresh episode after a done flag, except after the very last step
    # so that no trailing empty episode is produced.
    if dones_float[step] == 1.0 and step < final_step:
      episode = []
      episodes.append(episode)

  return episodes


def merge_trajectories(trajs):
  """Flattens a list of episodes into one batched structure.

  Args:
    trajs: Iterable of episodes, each an iterable of transitions that share
      the same nested structure.

  Returns:
    A structure of the same form as a single transition in which every leaf is
    the `np.stack` (new leading axis) of that leaf over all transitions, in
    episode order.
  """
  transitions = [step for episode in trajs for step in episode]

  def _stack_leaves(*leaves):
    return np.stack(leaves)

  return tree.map_structure(_stack_leaves, *transitions)


def load_trajectories(
    data_load: str,
    score_lambda: float,
    dataset_dir: str = '/home/zulipeng/implicit_q_learning/datasets/'):
  """Loads a scored HDF5 dataset and splits it into episodes.

  The file `<dataset_dir>/<data_load>.hdf5` must contain the keys
  `observations`, `actions`, `next_observations`, `rewards`, `terminals` and
  `scores`. The stored `rewards` are only used for the shape and dtype of the
  episode-boundary array; the reward placed in each transition is
  `exp(-score_lambda * score) - 1`.

  A step ends an episode when the following stored observation differs from
  its `next_observation` by more than 1e-6 in Euclidean norm, or when its
  `terminals` entry equals 1. The last step always ends an episode.

  Args:
    data_load: Dataset path relative to `dataset_dir`, without the `.hdf5`
      suffix, e.g.
      `datasets_merge_full_split_1/antmaze-medium-play-v2-oriscores`.
    score_lambda: Scale applied to the scores inside the exponential.
    dataset_dir: Directory that contains the dataset files. The default is the
      location that was previously hard-coded.

  Returns:
    A list of episodes, each a list of `types.Transition`, as produced by
    `split_into_trajectories`.
  """
  dataset_path = os.path.join(dataset_dir, data_load + '.hdf5')
  keys = ('observations', 'actions', 'next_observations', 'rewards',
          'terminals', 'scores')
  # The context manager closes the file even if one of the reads raises.
  with h5py.File(dataset_path, 'r') as data_file:
    dataset = {key: np.array(data_file[key][:]) for key in keys}

  terminals = dataset['terminals']
  dones_float = np.zeros_like(dataset['rewards'])
  num_steps = len(dones_float)

  if num_steps > 1:
    # Vectorised boundary detection; assumes `rewards` and `terminals` hold
    # one scalar per step (1-D arrays) -- TODO confirm for new dataset files.
    gaps = (dataset['observations'][1:num_steps] -
            dataset['next_observations'][:num_steps - 1])
    gap_norms = np.linalg.norm(gaps.reshape(num_steps - 1, -1), axis=1)
    dones_float[:-1] = np.logical_or(gap_norms > 1e-6,
                                     terminals[:num_steps - 1] == 1.0)
  dones_float[-1] = 1

  # NOTE: 'realterminals' is never read from the file, so the former
  # `'realterminals' in dataset` branch could not be taken; the masks have
  # always been derived from 'terminals'.
  masks = 1.0 - terminals.astype(np.float32)

  rewards = np.exp(-score_lambda * dataset['scores']) - 1
  return split_into_trajectories(
      observations=dataset['observations'].astype(np.float32),
      actions=dataset['actions'].astype(np.float32),
      rewards=rewards.astype(np.float32),
      masks=masks,
      dones_float=dones_float.astype(np.float32),
      next_observations=dataset['next_observations'].astype(np.float32))


def load_demonstrations(name: str,
                        num_top_episodes: int = 10,
                        score_lambda: float = 1.0):
  """Load expert demonstrations.

  Args:
    name: Dataset name forwarded to `load_trajectories`.
    num_top_episodes: Number of highest-return episodes to keep. A negative
      value returns every episode; zero returns an empty list.
    score_lambda: Forwarded to `load_trajectories`, which requires it. The
      default of 1.0 is a placeholder so existing two-argument calls keep
      working -- NOTE(review): pass the value used elsewhere in the
      experiment, since it changes the rewards and therefore the ranking.

  Returns:
    A list of episodes (lists of transitions). Unless the whole dataset is
    requested, episodes are ordered by ascending return.
  """
  # Load trajectories from the given dataset
  trajs = load_trajectories(name, score_lambda)
  if num_top_episodes < 0:
    logging.info("Loading the entire dataset as demonstrations")
    return trajs
  if num_top_episodes == 0:
    # `trajs[-0:]` would be the whole list, not an empty selection.
    return []

  def compute_returns(traj):
    return sum(transition.reward for transition in traj)

  # Sort by episode return and keep the best ones.
  return sorted(trajs, key=compute_returns)[-num_top_episodes:]


class JaxInMemorySampler(Iterator[Any]):
  """Endless iterator that draws random batches from an in-memory dataset.

  The dataset is a pytree of arrays sharing the same leading dimension. It is
  transferred with `jax.device_put` once at construction, and every call to
  `next` gathers `batch_size` rows at indices drawn independently (i.e. with
  replacement) from `[0, dataset_size)`. The same indices are used for every
  leaf, so rows stay aligned across the pytree.
  """

  def __init__(
      self,
      dataset,
      key: jnp.ndarray,
      batch_size: int,
  ):
    """Initializes the sampler.

    Args:
      dataset: Pytree of arrays; the first leaf's leading dimension is taken
        as the dataset size.
      key: JAX PRNG key that seeds the sampling sequence.
      batch_size: Number of rows in every sampled batch.
    """
    self._dataset_size = jax.tree_util.tree_leaves(dataset)[0].shape[0]
    # `jax.tree_map` is a deprecated alias that newer JAX releases drop;
    # `jax.tree_util.tree_map` is the stable spelling.
    self._jax_dataset = jax.tree_util.tree_map(jax.device_put, dataset)

    def sample(data, key: jnp.ndarray):
      # One half of the key draws this batch, the other is carried forward.
      key1, key2 = jax.random.split(key)
      indices = jax.random.randint(
          key1, (batch_size,), minval=0, maxval=self._dataset_size)
      data_sample = jax.tree_util.tree_map(
          lambda d: jnp.take(d, indices, axis=0), data)
      return data_sample, key2

    self._sample = jax.jit(lambda key: sample(self._jax_dataset, key))
    self._key = key

  def __next__(self) -> Any:
    """Returns the next random batch and advances the internal PRNG key."""
    data, self._key = self._sample(self._key)
    return data
