import collections
from typing import Optional

import d4rl
import gym
import jax
import jax.numpy as jnp
import numpy as np
from tqdm import tqdm, trange

# Minibatch container returned by ``Dataset.sample``.
Batch = collections.namedtuple("Batch", "observations actions rewards masks next_observations images")


def split_into_trajectories(observations, actions, rewards, masks, dones_float, next_observations, images):
    """Group consecutive per-step transitions into trajectories.

    Every step becomes a 7-element list
    ``[observation, action, reward, mask, done_float, next_observation, image]``.
    A new trajectory is opened after each step whose ``dones_float`` entry equals
    1.0, except after the very last step, so no trailing empty trajectory is
    created. With empty inputs the result is ``[[]]``.
    """
    num_steps = len(observations)
    trajectories = [[]]
    for step in tqdm(range(num_steps), desc="split"):
        current = trajectories[-1]
        current.append(
            [
                observations[step],
                actions[step],
                rewards[step],
                masks[step],
                dones_float[step],
                next_observations[step],
                images[step],
            ]
        )
        ends_here = dones_float[step] == 1.0
        if ends_here and step < num_steps - 1:
            trajectories.append([])
    return trajectories


# NOT USED.
def merge_trajectories(trajs):
    """Flatten a list of trajectories back into stacked per-step arrays.

    Each step entry may be the 6-element form
    ``(obs, action, reward, mask, done_float, next_obs)`` or the 7-element form
    built by ``split_into_trajectories`` (which appends an image). Any fields
    after the sixth are ignored.

    Returns:
        ``(observations, actions, rewards, masks, dones_float, next_observations)``,
        each built with ``np.stack`` over all steps of all trajectories in order.

    Raises:
        ValueError: if ``trajs`` contains no steps (``np.stack`` needs at least
            one array).
    """
    observations = []
    actions = []
    rewards = []
    masks = []
    dones_float = []
    next_observations = []

    for traj in trajs:
        # ``*_extra`` absorbs the trailing image field of 7-element step entries.
        for obs, act, rew, mask, done, next_obs, *_extra in traj:
            observations.append(obs)
            actions.append(act)
            rewards.append(rew)
            masks.append(mask)
            dones_float.append(done)
            next_observations.append(next_obs)

    return (
        np.stack(observations),
        np.stack(actions),
        np.stack(rewards),
        np.stack(masks),
        np.stack(dones_float),
        np.stack(next_observations),
    )


def normalize(dataset, env_name, max_episode_steps=1000):
    """Rescale ``dataset.rewards`` in place using the range of trajectory returns.

    Trajectories come from ``split_into_trajectories``; a trajectory's return is
    the plain sum of its rewards. With ``min_return`` / ``max_return`` the
    smallest / largest such return, each reward ``r`` of a step belonging to a
    trajectory of length ``L`` becomes::

        (r - min_return / L) / (max_return - min_return) * max_episode_steps   # "antmaze" envs
        r / (max_return - min_return) * max_episode_steps                      # otherwise

    Only the first ``dataset.size`` rewards are kept. The result dtype is the
    promotion of the input reward dtype with float32.

    Args:
        dataset: dataset whose ``rewards`` attribute is replaced.
        env_name: environment name; only checked for the substring ``"antmaze"``.
        max_episode_steps: scale factor applied after dividing by the return range.

    Raises:
        ValueError: if the dataset is non-empty and every trajectory has the same
            return, since dividing by a zero range would produce inf/nan rewards.
    """
    trajs = split_into_trajectories(
        dataset.observations,
        dataset.actions,
        dataset.rewards,
        dataset.masks,
        dataset.dones_float,
        dataset.next_observations,
        dataset.images,
    )

    # Step entries are [obs, action, reward, mask, done, next_obs, image];
    # index 2 is the reward, so this works regardless of the entry length.
    returns = [sum(step[2] for step in traj) for traj in trajs]
    min_return, max_return = min(returns), max(returns)
    return_range = max_return - min_return
    if dataset.size > 0 and return_range == 0:
        raise ValueError("Cannot normalize rewards: all trajectories have the same return.")

    rewards = np.asarray(dataset.rewards)[: dataset.size]
    out_dtype = np.promote_types(rewards.dtype, np.float32)
    normalized = rewards.astype(np.float64)

    if "antmaze" in env_name:
        traj_lens = np.array([len(traj) for traj in trajs])
        # Length of the owning trajectory, repeated once for each of its steps.
        step_traj_lens = np.repeat(traj_lens, traj_lens)[: dataset.size]
        normalized = normalized - float(min_return) / step_traj_lens

    normalized = normalized / float(return_range) * max_episode_steps
    dataset.rewards = normalized.astype(out_dtype)


class Dataset(object):
    """Fixed collection of transitions supporting uniform minibatch sampling.

    Args:
        observations, actions, rewards, masks, dones_float, next_observations:
            per-step arrays indexed along their first axis.
        size: number of steps; ``sample`` draws indices from ``[0, size)``.
        images: optional per-step images. When omitted, an array built from
            ``size`` ``None`` entries is stored instead.
    """

    def __init__(
        self,
        observations: np.ndarray,
        actions: np.ndarray,
        rewards: np.ndarray,
        masks: np.ndarray,
        dones_float: np.ndarray,
        next_observations: np.ndarray,
        size: int,
        images: np.ndarray = None,
    ):
        if images is None:
            # Placeholder so ``sample`` can index ``self.images`` like any other field.
            images = np.asarray([None for _ in range(size)])

        self.size = size
        self.observations = observations
        self.actions = actions
        self.rewards = rewards
        self.masks = masks
        self.dones_float = dones_float
        self.next_observations = next_observations
        self.images = images

    def sample(self, batch_size: int) -> Batch:
        """Draw ``batch_size`` transitions uniformly at random, with replacement."""
        rows = np.random.randint(self.size, size=batch_size)
        sources = dict(
            observations=self.observations,
            actions=self.actions,
            rewards=self.rewards,
            masks=self.masks,
            next_observations=self.next_observations,
            images=self.images,
        )
        return Batch(**{field: values[rows] for field, values in sources.items()})


class D4RLDataset(Dataset):
    """Offline dataset loaded with ``d4rl.qlearning_dataset(env)``.

    ``dones_float`` marks trajectory ends. A step closes a trajectory when any
    of these holds:

    * the next step's observation is more than 1e-5 (L2 norm) away from this
      step's ``next_observation``;
    * the step is terminal;
    * it is the final step.

    ``masks`` is ``1 - terminals``. Every array handed to ``Dataset`` is cast
    to float32.

    Args:
        env: environment passed to ``d4rl.qlearning_dataset``.
        clip_to_eps: if True, clip actions to ``[-(1 - eps), 1 - eps]``.
        eps: clipping margin.
    """

    def __init__(self, env: gym.Env, clip_to_eps: bool = True, eps: float = 1e-5):
        dataset = d4rl.qlearning_dataset(env)

        if clip_to_eps:
            lim = 1 - eps
            dataset["actions"] = np.clip(dataset["actions"], -lim, lim)

        dones_float = np.zeros_like(dataset["rewards"])
        num_steps = len(dones_float)
        if num_steps > 0:
            # Vectorized over consecutive step pairs (i, i + 1) instead of a
            # per-step Python loop, which is slow on large datasets.
            diffs = dataset["observations"][1:num_steps] - dataset["next_observations"][: num_steps - 1]
            gaps = np.sqrt(np.sum(np.square(diffs), axis=tuple(range(1, diffs.ndim))))
            terminal = np.reshape(dataset["terminals"][: num_steps - 1], (num_steps - 1,)) == 1.0
            dones_float[np.flatnonzero(np.logical_or(gaps > 1e-5, terminal))] = 1
            # The final step always closes a trajectory.
            dones_float[-1] = 1

        super().__init__(
            dataset["observations"].astype(np.float32),
            actions=dataset["actions"].astype(np.float32),
            rewards=dataset["rewards"].astype(np.float32),
            masks=1.0 - dataset["terminals"].astype(np.float32),
            dones_float=dones_float.astype(np.float32),
            next_observations=dataset["next_observations"].astype(np.float32),
            size=len(dataset["observations"]),
        )


class RelabeledDataset(Dataset):
    """Dataset built from caller-supplied transition arrays.

    ``dones_float`` marks trajectory ends. A step closes a trajectory when any
    of these holds:

    * the next step's observation differs from this step's
      ``next_observation`` by more than 1e-6 in L2 norm;
    * the step is terminal;
    * it is the final step.

    ``masks`` is ``1.0 - terminals``.

    Args:
        observations, actions, rewards, terminals, next_observations: per-step
            arrays sharing their first dimension (``terminals`` is assumed to
            hold one value per step — TODO confirm for all callers).
        images: optional per-step images passed through to ``Dataset``.
        clip_to_eps: if True, clip actions to ``[-(1 - eps), 1 - eps]``.
        eps: clipping margin.
    """

    def __init__(
        self,
        observations,
        actions,
        rewards,
        terminals,
        next_observations,
        images: np.ndarray = None,
        clip_to_eps: bool = True,
        eps: float = 1e-5,
    ):
        if clip_to_eps:
            lim = 1 - eps
            actions = np.clip(actions, -lim, lim)

        dones_float = np.zeros_like(rewards)
        num_steps = len(dones_float)
        if num_steps > 0:
            # Vectorized over consecutive step pairs (i, i + 1) instead of a
            # per-step Python loop, which is slow on large datasets.
            diffs = np.asarray(observations[1:num_steps]) - np.asarray(next_observations[: num_steps - 1])
            gaps = np.sqrt(np.sum(np.square(diffs), axis=tuple(range(1, diffs.ndim))))
            terminal = np.reshape(np.asarray(terminals[: num_steps - 1]), (num_steps - 1,)) == 1.0
            dones_float[np.flatnonzero(np.logical_or(gaps > 1e-6, terminal))] = 1
            # The final step always closes a trajectory.
            dones_float[-1] = 1

        super().__init__(
            observations=observations,
            actions=actions,
            rewards=rewards,
            masks=1.0 - terminals,
            dones_float=dones_float.astype(np.float32),
            next_observations=next_observations,
            images=images,
            size=len(observations),
        )


# NOT USED.
class ReplayBuffer(Dataset):
    """Fixed-capacity circular buffer of transitions.

    Storage for ``capacity`` transitions is allocated up front. ``size`` is the
    number of valid entries and ``insert_index`` is the slot written next. Once
    ``insert_index`` reaches ``capacity`` it wraps to 0 and earlier entries are
    overwritten.
    """

    def __init__(self, observation_space: gym.spaces.Box, action_dim: int, capacity: int):
        obs_shape = (capacity, *observation_space.shape)
        super().__init__(
            observations=np.empty(obs_shape, dtype=observation_space.dtype),
            actions=np.empty((capacity, action_dim), dtype=np.float32),
            rewards=np.empty((capacity,), dtype=np.float32),
            masks=np.empty((capacity,), dtype=np.float32),
            dones_float=np.empty((capacity,), dtype=np.float32),
            next_observations=np.empty(obs_shape, dtype=observation_space.dtype),
            size=0,
            # One placeholder per slot. With ``size=0``, the Dataset default would
            # build a zero-length ``images`` array, and ``sample`` would then fail
            # indexing it once the buffer holds data.
            images=np.full(capacity, None, dtype=object),
        )

        self.insert_index = 0
        self.capacity = capacity

    def initialize_with_dataset(self, dataset: Dataset, num_samples: Optional[int]):
        """Fill an empty buffer with transitions copied from ``dataset``.

        Args:
            dataset: source of transitions (images are not copied).
            num_samples: number of transitions to copy, capped at the dataset
                size; ``None`` copies all of them. When fewer than all are
                requested, a random subset is taken.
        """
        assert self.insert_index == 0 and self.size == 0, "Can only insert a batch into an empty replay buffer."

        dataset_size = len(dataset.observations)

        if num_samples is None:
            num_samples = dataset_size
        else:
            num_samples = min(dataset_size, num_samples)
        assert self.capacity >= num_samples, "Dataset cannot be larger than the replay buffer capacity."

        if num_samples < dataset_size:
            perm = np.random.permutation(dataset_size)
            indices = perm[:num_samples]
        else:
            indices = np.arange(num_samples)

        self.observations[:num_samples] = dataset.observations[indices]
        self.actions[:num_samples] = dataset.actions[indices]
        self.rewards[:num_samples] = dataset.rewards[indices]
        self.masks[:num_samples] = dataset.masks[indices]
        self.dones_float[:num_samples] = dataset.dones_float[indices]
        self.next_observations[:num_samples] = dataset.next_observations[indices]

        self.insert_index = num_samples
        self.size = num_samples

    def insert(
        self,
        observation: np.ndarray,
        action: np.ndarray,
        reward: float,
        mask: float,
        done_float: float,
        next_observation: np.ndarray,
    ):
        """Write one transition at ``insert_index`` and advance it circularly."""
        slot = self.insert_index
        self.observations[slot] = observation
        self.actions[slot] = action
        self.rewards[slot] = reward
        self.masks[slot] = mask
        self.dones_float[slot] = done_float
        self.next_observations[slot] = next_observation

        self.insert_index = (slot + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)


@jax.jit
def batch_to_jax(batch):
    """Return ``batch`` with every leaf passed through ``jax.device_put``.

    ``batch`` may be any pytree (for example a dict of NumPy arrays); its
    structure is preserved. The function is wrapped in ``jax.jit``.
    """
    leaves, treedef = jax.tree_util.tree_flatten(batch)
    placed = [jax.device_put(leaf) for leaf in leaves]
    return jax.tree_util.tree_unflatten(treedef, placed)


def reward_from_preference(
    env_name: str,
    dataset: D4RLDataset,
    reward_model,
    batch_size: int = 256,
):
    """Relabel ``dataset.rewards`` with a per-transition learned reward model.

    The dataset is scored in consecutive chunks of at most ``batch_size``
    transitions. Each chunk's observations, actions and next observations (plus
    images, when the dataset's first image is not ``None``) are moved to device
    with ``batch_to_jax`` and passed to ``reward_model.get_reward``.

    Args:
        env_name: unused; kept for interface compatibility.
        dataset: dataset to relabel; its ``rewards`` attribute is replaced.
        reward_model: object whose ``get_reward(batch)`` returns one reward per
            transition, shaped ``(B,)`` or ``(B, 1)``.
        batch_size: maximum number of transitions per model call.

    Returns:
        The same ``dataset`` object, with rewards replaced by model predictions.
    """
    use_image = len(dataset.images) > 0 and dataset.images[0] is not None
    data_size = dataset.rewards.shape[0]
    new_r = np.zeros_like(dataset.rewards)

    # Iterate over chunk start offsets so that no empty trailing batch is sent
    # to the model when data_size is an exact multiple of batch_size.
    for start_pt in trange(0, data_size, batch_size):
        end_pt = min(start_pt + batch_size, data_size)

        model_input = dict(
            observations=dataset.observations[start_pt:end_pt],
            actions=dataset.actions[start_pt:end_pt],
            next_observations=dataset.next_observations[start_pt:end_pt],
        )
        if use_image:
            model_input["images"] = dataset.images[start_pt:end_pt]

        new_reward = reward_model.get_reward(batch_to_jax(model_input))
        # Flatten so both (B,) and (B, 1) model outputs fit the 1-D reward slice.
        new_r[start_pt:end_pt] = np.asarray(new_reward).reshape(-1)

    dataset.rewards = new_r
    return dataset


def reward_from_preference_transformer(
    env_name: str,
    dataset: D4RLDataset,
    reward_model,
    seq_len: int,
    batch_size: int = 256,
    use_diff: bool = False,
    label_mode: str = "last",
    with_attn_weights: bool = False,
    skip_frame: int = 1,  # Option for attention analysis.
):
    """Relabel ``dataset.rewards`` with a sequence (transformer) reward model.

    For every dataset step the model receives the window of ``seq_len`` steps
    ending at that step, taken from its own trajectory and subsampled with
    ``[::skip_frame]``. Trajectories are left-padded with ``seq_len - 1`` zero
    observations/actions (and zero images). Padding positions get attention
    mask 0; real positions use the step's ``masks`` value.

    The per-position outputs are reduced to one scalar per step:

    * ``"last"``: the output at the final window position.
    * ``"mean"``: the sum of outputs divided by the sum of the attention mask.

    Args:
        env_name: unused; kept for interface compatibility.
        dataset: dataset to relabel; its ``rewards`` attribute is replaced.
        reward_model: object whose ``get_reward(batch)`` returns a pair
            ``(rewards, attn_weights)``.
        seq_len: window length.
        batch_size: number of windows per model call.
        use_diff: if True, subtract the prediction for the window truncated to
            its first ``seq_len - 1`` positions (images are not passed on this path).
        label_mode: ``"last"`` or ``"mean"``.
        with_attn_weights: also return the attention weights of each batch and
            the dataset indices ``pt - seq_len + 1 .. pt`` of each window.
        skip_frame: stride used to subsample each window.

    Returns:
        ``dataset``, or ``(dataset, (attn_weights, pts))`` when ``with_attn_weights``.

    Raises:
        ValueError: if ``label_mode`` is neither ``"last"`` nor ``"mean"``.
    """
    if label_mode not in ("last", "mean"):
        raise ValueError(f"label_mode must be 'last' or 'mean', got {label_mode!r}")

    use_image = dataset.images[0] is not None
    trajs = split_into_trajectories(
        dataset.observations,
        dataset.actions,
        dataset.rewards,
        dataset.masks,
        dataset.dones_float,
        dataset.next_observations,
        dataset.images,
    )
    trajectories = []
    trj_mapper = []
    observation_dim = dataset.observations.shape[-1]
    action_dim = dataset.actions.shape[-1]

    if use_image:
        image_dim = dataset.images.shape[1:]

    # Build one left-padded array per trajectory (seq_len - 1 zero steps first)
    # so the window of seq_len steps ending at any step is a plain slice.
    for trj_idx, traj in tqdm(enumerate(trajs), total=len(trajs), desc="chunk trajectories"):
        _obs, _act, _mask = [], [], []
        if use_image:
            _img = []

        for _ in range(seq_len - 1):
            _obs.append(np.zeros(observation_dim))
            _act.append(np.zeros(action_dim))
            _mask.append(0.0)
            if use_image:
                _img.append(np.zeros(image_dim, dtype=np.uint8))

        for _o, _a, _, _m, _, _, _im in traj:
            _obs.append(_o)
            _act.append(_a)
            _mask.append(_m)
            if use_image:
                _img.append(_im)

        traj_len = len(traj)
        if use_image:
            _obs, _act, _attn_mask, _img = np.asarray(_obs), np.asarray(_act), np.asarray(_mask), np.asarray(_img)
            trajectories.append((_obs, _act, _attn_mask, _img))
        else:
            _obs, _act, _attn_mask = np.asarray(_obs), np.asarray(_act), np.asarray(_mask)
            trajectories.append((_obs, _act, _attn_mask))

        # Dataset step -> (trajectory index, position within that trajectory).
        for seg_idx in range(traj_len):
            trj_mapper.append((trj_idx, seg_idx))

    data_size = dataset.rewards.shape[0]
    # Ceil division: no empty trailing batch when data_size is a multiple of batch_size.
    interval = (data_size + batch_size - 1) // batch_size
    # Positions kept per window by ``[::skip_frame]``, i.e. ceil(seq_len / skip_frame).
    # This differs from seq_len // skip_frame when seq_len is not a multiple of skip_frame.
    window_len = len(range(0, seq_len, skip_frame))
    new_r = np.zeros_like(dataset.rewards)
    pts = []
    attn_weights = []
    for i in trange(interval, desc="relabel reward"):
        start_pt = i * batch_size
        end_pt = min((i + 1) * batch_size, data_size)

        _input_obs, _input_act, _input_timestep, _input_attn_mask, _input_pt = [], [], [], [], []
        if use_image:
            _input_img = []
        for pt in range(start_pt, end_pt):
            _trj_idx, _seg_idx = trj_mapper[pt]
            __input_obs = trajectories[_trj_idx][0][_seg_idx : _seg_idx + seq_len, :][::skip_frame]
            __input_act = trajectories[_trj_idx][1][_seg_idx : _seg_idx + seq_len, :][::skip_frame]
            if use_image:
                __input_img = trajectories[_trj_idx][-1][_seg_idx : _seg_idx + seq_len, ...][::skip_frame]
            # NOTE(review): timesteps past position 500 are zeroed; this presumably
            # mirrors the timestep scheme the reward model was trained with — confirm.
            if _seg_idx < seq_len:
                __input_timestep = np.concatenate(
                    [np.zeros(seq_len - _seg_idx, dtype=np.int32), np.arange(_seg_idx, dtype=np.int32)], axis=0
                )
            elif _seg_idx <= 500:
                __input_timestep = np.arange(_seg_idx - seq_len, _seg_idx, dtype=np.int32)
            elif 0 < _seg_idx - 500 < seq_len:
                __input_timestep = np.concatenate(
                    [np.arange(seq_len - _seg_idx + 500, dtype=np.int32), np.zeros(_seg_idx - 500, dtype=np.int32)],
                    axis=0,
                )
            else:
                __input_timestep = np.zeros(seq_len, dtype=np.int32)
            __input_timestep = __input_timestep[::skip_frame]
            __input_attn_mask = trajectories[_trj_idx][2][_seg_idx : _seg_idx + seq_len, ...][::skip_frame]
            __input_pt = np.arange(pt - seq_len + 1, pt + 1)[::skip_frame]

            _input_obs.append(__input_obs)
            _input_act.append(__input_act)
            _input_timestep.append(__input_timestep)
            _input_attn_mask.append(__input_attn_mask)
            _input_pt.append(__input_pt)
            if use_image:
                _input_img.append(__input_img)

        _input_obs, _input_act, _input_timestep, _input_attn_mask, _input_pt = map(
            lambda x: np.asarray(x), [_input_obs, _input_act, _input_timestep, _input_attn_mask, _input_pt]
        )
        if use_image:
            _input_img = np.asarray(_input_img)

        model_input = dict(
            observations=_input_obs,
            actions=_input_act,
            timestep=_input_timestep,
            attn_mask=_input_attn_mask,
        )
        if use_image:
            model_input.update(images=_input_img)

        jax_input = batch_to_jax(model_input)
        if with_attn_weights:
            new_reward, attn_weight = reward_model.get_reward(jax_input)
            attn_weights.append(np.array(attn_weight))
            pts.append(_input_pt)
        else:
            new_reward, _ = reward_model.get_reward(jax_input)
        new_reward = new_reward.reshape(end_pt - start_pt, window_len)

        # Unused option (``use_diff`` defaults to False); kept for experimentation.
        if use_diff:
            prev_input = dict(
                observations=_input_obs[:, : seq_len - 1, :],
                actions=_input_act[:, : seq_len - 1, :],
                timestep=_input_timestep[:, : seq_len - 1],
                attn_mask=_input_attn_mask[:, : seq_len - 1],
            )
            jax_prev_input = batch_to_jax(prev_input)
            prev_reward, _ = reward_model.get_reward(jax_prev_input)
            prev_reward = prev_reward.reshape(end_pt - start_pt, seq_len - 1) * prev_input["attn_mask"]
            if label_mode == "mean":
                new_reward = jnp.sum(new_reward, axis=1).reshape(-1, 1)
                prev_reward = jnp.sum(prev_reward, axis=1).reshape(-1, 1)
            elif label_mode == "last":
                new_reward = new_reward[:, -1].reshape(-1, 1)
                prev_reward = prev_reward[:, -1].reshape(-1, 1)
            new_reward -= prev_reward
        else:
            if label_mode == "mean":
                new_reward = jnp.sum(new_reward, axis=1) / jnp.sum(_input_attn_mask, axis=1)
                new_reward = new_reward.reshape(-1, 1)
            elif label_mode == "last":
                new_reward = new_reward[:, -1].reshape(-1, 1)

        # new_reward is (B, 1) here; flatten into the 1-D reward slice.
        new_r[start_pt:end_pt, ...] = np.asarray(new_reward).reshape(-1)

    dataset.rewards = new_r

    if with_attn_weights:
        return dataset, (attn_weights, pts)
    return dataset
