import gym
import d4rl
import chex
import jax
import jax.numpy as jnp

import numpy as np
from typing import Dict, Tuple

import tree

def random_dataset(env, num):
    """Draw ``num`` samples from the environment's observation and action spaces.

    Args:
        env: Object exposing ``observation_space`` and ``action_space``, each
            providing a ``sample()`` method.
        num: Number of samples to draw from each space.

    Returns:
        A ``(states, actions)`` tuple of JAX arrays, each stacking ``num``
        samples along the leading axis.
    """
    # Every observation is drawn before the first action is drawn; the two
    # spaces are sampled in separate passes.
    sampled_states = jnp.array(
        [env.observation_space.sample() for _ in range(num)]
    )
    sampled_actions = jnp.array(
        [env.action_space.sample() for _ in range(num)]
    )
    return sampled_states, sampled_actions

# source: https://github.com/rail-berkeley/d4rl/blob/d842aa194b416e564e54b0730d9f934e3e32f854/d4rl/__init__.py#L63
# modified to also return next_action (needed for logging and in general useful to have)
def qlearning_dataset(env, dataset=None, terminate_on_end=False, **kwargs):
    """
    Convert a flat D4RL-style dataset into per-transition arrays for Q-learning.

    Entry ``i`` of the flat dataset is paired with entry ``i + 1`` to form one
    transition (observation, action, reward, terminal flag, next observation,
    next action). The last entry is therefore only ever used as a "next" element.

    Args:
        env: An OfflineEnv object. Used for ``env.get_dataset()`` when no
            dataset is passed, and for ``env._max_episode_steps`` when the
            dataset has no ``timeouts`` field.
        dataset: An optional dataset to process. If None, the dataset
            defaults to ``env.get_dataset(**kwargs)``.
        terminate_on_end (bool): If False (default), the transition at the
            final timestep of each trajectory (timeout or step limit) is
            discarded. If True, that transition is kept; its ``terminals``
            entry still comes from the dataset's ``terminals`` field.
        **kwargs: Arguments to pass to env.get_dataset().
    Returns:
        A dictionary containing keys:
            observations: An N x dim_obs array of observations.
            actions: An N x dim_action array of actions.
            next_observations: An N x dim_obs array of next observations.
            next_actions: An N x dim_action array of next actions.
            rewards: An N-dim float array of rewards.
            terminals: An N-dim boolean array of "done" or episode termination flags.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    total_steps = dataset['rewards'].shape[0]

    # Newer datasets ship an explicit 'timeouts' field. Older ones do not, so
    # the end of an episode is detected with the step counter below instead.
    has_timeouts = 'timeouts' in dataset

    # One list per output field; insertion order is the order of the returned keys.
    columns = {
        'observations': [],
        'actions': [],
        'next_observations': [],
        'next_actions': [],
        'rewards': [],
        'terminals': [],
    }

    steps_in_episode = 0
    for idx in range(total_steps - 1):
        nxt = idx + 1
        state = dataset['observations'][idx].astype(np.float32)
        next_state = dataset['observations'][nxt].astype(np.float32)
        act = dataset['actions'][idx].astype(np.float32)
        next_act = dataset['actions'][nxt].astype(np.float32)
        rew = dataset['rewards'][idx].astype(np.float32)
        is_terminal = bool(dataset['terminals'][idx])

        if has_timeouts:
            episode_ends = dataset['timeouts'][idx]
        else:
            episode_ends = (steps_in_episode == env._max_episode_steps - 1)

        if episode_ends and not terminate_on_end:
            # Drop the last step of the episode entirely and start counting afresh.
            steps_in_episode = 0
            continue

        if is_terminal or episode_ends:
            # NOTE: the reset precedes the unconditional increment at the end of
            # the loop body, so the counter reads 1 on the following iteration.
            steps_in_episode = 0

        columns['observations'].append(state)
        columns['actions'].append(act)
        columns['next_observations'].append(next_state)
        columns['next_actions'].append(next_act)
        columns['rewards'].append(rew)
        columns['terminals'].append(is_terminal)
        steps_in_episode += 1

    return {field: np.array(values) for field, values in columns.items()}


@chex.dataclass
class ReplayBuffer:
    """Offline replay buffer backed by a dictionary of JAX arrays.

    Every array in ``data`` is indexed along its leading axis by transition;
    ``sample_batch`` applies the same random indices to all of them.
    """

    # Field name (e.g. "states", "actions") -> array with one row per transition.
    data: Dict[str, jax.Array]

    @staticmethod
    def create_from_d4rl(dataset_name: str, normalize_reward: bool = False) -> "ReplayBuffer":
        """Build a buffer from the D4RL dataset registered under ``dataset_name``.

        Args:
            dataset_name: Gym id passed to ``gym.make``.
            normalize_reward: If True, rewards are passed through
                ``ReplayBuffer.normalize_reward``.

        Returns:
            A buffer with float32 fields ``states``, ``actions``, ``rewards``,
            ``next_states``, ``next_actions`` and ``dones``.

        Raises:
            NotImplementedError: If ``normalize_reward`` is True and the
                dataset is not supported by ``normalize_reward``.
        """
        d4rl_data = qlearning_dataset(gym.make(dataset_name))
        buffer = {
            "states": jnp.asarray(d4rl_data["observations"], dtype=jnp.float32),
            "actions": jnp.asarray(d4rl_data["actions"], dtype=jnp.float32),
            "rewards": jnp.asarray(d4rl_data["rewards"], dtype=jnp.float32),
            "next_states": jnp.asarray(d4rl_data["next_observations"], dtype=jnp.float32),
            "next_actions": jnp.asarray(d4rl_data["next_actions"], dtype=jnp.float32),
            # Boolean terminal flags are stored as 0.0 / 1.0 floats.
            "dones": jnp.asarray(d4rl_data["terminals"], dtype=jnp.float32),
        }
        if normalize_reward:
            buffer["rewards"] = ReplayBuffer.normalize_reward(dataset_name, buffer["rewards"])

        return ReplayBuffer(data=buffer)

    @staticmethod
    def create_from_d4rl_with_random_data(
        dataset_name: str, num_random_data: int = 100000, label_dim: int = 1
    ) -> "ReplayBuffer":
        """Build a buffer mixing dataset (state, action) pairs with random ones.

        Dataset rows come first and are labelled 0 in ``random_or_not``; they
        are followed by ``num_random_data`` rows sampled from the environment's
        observation and action spaces, labelled 1.

        Args:
            dataset_name: Gym id passed to ``gym.make``.
            num_random_data: Number of randomly sampled (state, action) pairs.
            label_dim: Width of the ``random_or_not`` label array.

        Returns:
            A buffer with float32 fields ``states``, ``actions`` and
            ``random_or_not``.
        """
        # A single environment serves both the dataset and the random samples.
        env = gym.make(dataset_name)
        d4rl_data = qlearning_dataset(env)
        num_dataset_rows = d4rl_data["terminals"].shape[0]
        buffer = {
            "states": jnp.asarray(d4rl_data["observations"], dtype=jnp.float32),
            "actions": jnp.asarray(d4rl_data["actions"], dtype=jnp.float32),
            "random_or_not": jnp.zeros((num_dataset_rows, label_dim), dtype=jnp.float32),
        }
        random_state, random_action = random_dataset(env=env, num=num_random_data)
        random_buffer = {
            # Cast explicitly so the random rows share the dataset rows' dtype.
            "states": jnp.asarray(random_state, dtype=jnp.float32),
            "actions": jnp.asarray(random_action, dtype=jnp.float32),
            "random_or_not": jnp.ones((num_random_data, label_dim), dtype=jnp.float32),
        }
        # Append the random rows after the dataset rows, field by field.
        buffer = tree.map_structure(lambda x, y: jnp.concatenate([x, y]), buffer, random_buffer)

        return ReplayBuffer(data=buffer)

    @property
    def size(self):
        """Number of rows in the buffer, read from the ``states`` field."""
        # WARN: do not use __len__ here! It will use len of the dataclass, i.e. number of fields.
        return self.data["states"].shape[0]

    def sample_batch(self, key: jax.Array, batch_size: int) -> Dict[str, jax.Array]:
        """Sample ``batch_size`` rows uniformly at random, with replacement.

        Args:
            key: JAX PRNG key used to draw the row indices.
            batch_size: Number of rows to draw.

        Returns:
            A dictionary with the same keys as ``data``, each holding the
            selected rows.
        """
        indices = jax.random.randint(key, shape=(batch_size,), minval=0, maxval=self.size)
        # jax.tree_util.tree_map is used instead of the top-level jax.tree_map
        # alias, which recent JAX releases no longer provide.
        batch = jax.tree_util.tree_map(lambda arr: arr[indices], self.data)
        return batch

    def get_moments(self, modality: str) -> Tuple[jax.Array, jax.Array]:
        """Return the mean and standard deviation of one field over axis 0.

        Args:
            modality: Key into ``data`` (e.g. ``"states"``).

        Returns:
            A ``(mean, std)`` tuple.
        """
        mean = self.data[modality].mean(0)
        std = self.data[modality].std(0)
        return mean, std

    @staticmethod
    def normalize_reward(dataset_name: str, rewards: jax.Array) -> jax.Array:
        """Rescale rewards according to the dataset family.

        Args:
            dataset_name: Dataset id; only names containing ``"antmaze"``
                are supported.
            rewards: Reward array to rescale.

        Returns:
            The rewards multiplied by 100 for AntMaze datasets.

        Raises:
            NotImplementedError: For any other dataset.
        """
        if "antmaze" in dataset_name:
            return rewards * 100.0  # like in LAPO
        else:
            raise NotImplementedError("Reward normalization is implemented only for AntMaze yet!")
