import numpy as np
import torch
from torch import nn as nn
import d4rl
import rlkit.torch.pytorch_util as ptu


def d4rl_qlearning_dataset_with_next_actions(env):
    """Return ``d4rl.qlearning_dataset(env)`` augmented with ``"next_actions"``.

    ``next_actions[i]`` is the action stored in row ``i + 1`` of the flat
    dataset. The final row has no successor (``np.roll`` wraps it around to
    row 0), so every array in the returned dict is shortened by one row.

    NOTE(review): the pairing is purely positional. If the flat dataset
    concatenates several episodes, the ``next_actions`` entry at the last step
    of one episode is the first action of the following episode. Confirm that
    consumers mask or tolerate these rows (e.g. via ``terminals``).

    Args:
        env: Environment accepted by ``d4rl.qlearning_dataset``.

    Returns:
        dict: The D4RL Q-learning dict plus a ``"next_actions"`` entry, with
        every entry truncated by its last row.
    """
    ds = d4rl.qlearning_dataset(env)
    shifted_actions = np.roll(ds["actions"], -1, axis=0)
    ds["next_actions"] = shifted_actions

    # Build the truncated arrays first, then write them back into the same dict
    # object so the returned mapping (and its key order) is unchanged.
    truncated = {name: array[:-1] for name, array in ds.items()}
    ds.update(truncated)
    return ds


def load_hdf5(dataset, replay_buffer):
    """Copy an offline dataset into the first slots of ``replay_buffer``.

    Let ``N = dataset["observations"].shape[0]``. The first ``N`` rows of every
    field are converted with ``ptu.torch_ify`` and written into slots
    ``[0, N)`` of the buffer's preallocated storages. Rewards and terminals are
    stored as column vectors of shape ``(N, 1)``. The buffer's ``_size`` and
    ``_top`` are then set to ``N``, so every loaded transition is available
    for sampling.

    Args:
        dataset: Mapping with ``"observations"``, ``"actions"``,
            ``"next_observations"``, ``"rewards"`` and ``"terminals"`` arrays,
            each with at least ``N`` rows. Rewards/terminals may be shaped
            ``(N,)`` or ``(N, 1)``.
        replay_buffer: Buffer exposing ``_max_replay_buffer_size``, the storages
            ``_observations``, ``_next_obs``, ``_actions``, ``_rewards`` and
            ``_terminals``, and the ``_size`` / ``_top`` counters.

    Raises:
        AssertionError: If the dataset has more rows than the buffer can hold.
    """
    obs = dataset["observations"]
    n = obs.shape[0]
    assert (
        replay_buffer._max_replay_buffer_size >= n
    ), "dataset does not fit in replay buffer"

    # reshape(n, 1) handles both (N,) and (N, 1) inputs; expand_dims would turn
    # an (N, 1) array into (N, 1, 1) and fail to fit the buffer slot.
    rewards = np.asarray(dataset["rewards"][:n]).reshape(n, 1)
    terminals = np.asarray(dataset["terminals"][:n]).reshape(n, 1)

    replay_buffer._observations[:n] = ptu.torch_ify(obs[:n])
    replay_buffer._next_obs[:n] = ptu.torch_ify(dataset["next_observations"][:n])
    replay_buffer._actions[:n] = ptu.torch_ify(dataset["actions"][:n])
    replay_buffer._rewards[:n] = ptu.torch_ify(rewards)
    replay_buffer._terminals[:n] = ptu.torch_ify(terminals)

    # All n written rows are complete transitions (next_observations is stored
    # explicitly), so expose all of them. Using n - 1 would hide the last one
    # from sampling and let the next insert overwrite it.
    replay_buffer._size = n
    replay_buffer._top = replay_buffer._size


def load_hdf5_next_actions(dataset, replay_buffer):
    """Load ``dataset`` into ``replay_buffer``, including ``"next_actions"``.

    The standard transition fields are written by ``load_hdf5``. Afterwards,
    the first ``N`` rows of ``dataset["next_actions"]`` (``N`` being the number
    of observations) are converted with ``ptu.torch_ify`` and stored in
    ``replay_buffer._next_actions``.

    Args:
        dataset: Mapping accepted by ``load_hdf5`` that additionally contains a
            ``"next_actions"`` array with at least ``N`` rows.
        replay_buffer: Buffer accepted by ``load_hdf5`` that also has a
            ``_next_actions`` storage.
    """
    load_hdf5(dataset, replay_buffer)

    num_rows = dataset["observations"].shape[0]
    next_action_rows = dataset["next_actions"][:num_rows]
    replay_buffer._next_actions[:num_rows] = ptu.torch_ify(next_action_rows)


def load_hdf5_next_actions_and_val_data(
    dataset, replay_buffer, train_raio=0.95, fold_idx=1
):
    """Randomly split ``dataset`` into a training replay buffer and validation data.

    A random permutation of the ``N`` transitions is split at
    ``int(N * train_raio)``. The training part is converted with
    ``ptu.torch_ify`` and written into slots ``[0, size)`` of the buffer's
    storages, including ``_next_actions``. The validation part is returned as
    observation/action tensors.

    Args:
        dataset: Mapping with ``"observations"``, ``"actions"``,
            ``"next_observations"``, ``"next_actions"``, ``"rewards"`` and
            ``"terminals"`` arrays of ``N`` rows. ``next_actions[i]`` must equal
            ``actions[i + 1]`` for ``i < N - 1``. Rewards/terminals may be
            shaped ``(N,)`` or ``(N, 1)``.
        replay_buffer: Buffer exposing ``_max_replay_buffer_size``, the storages
            ``_observations``, ``_next_obs``, ``_actions``, ``_rewards``,
            ``_terminals`` and ``_next_actions``, and ``_size`` / ``_top``.
        train_raio: Fraction of transitions used for training. The misspelled
            name is kept for compatibility with keyword callers.
        fold_idx: Which split to use. The split comes from the ``fold_idx``-th
            permutation drawn from NumPy's global RNG, so different values give
            different (not necessarily disjoint) validation sets. Must be >= 1.

    Returns:
        tuple: ``(replay_buffer, val_observations, val_actions)``, where the
        validation entries are ``ptu.torch_ify`` conversions of the held-out
        rows.

    Raises:
        ValueError: If ``fold_idx < 1``, since no permutation would be drawn.
        AssertionError: If the dataset does not fit in the buffer, or if
            ``next_actions`` is not ``actions`` shifted by one row.
    """
    if fold_idx < 1:
        raise ValueError(f"fold_idx must be >= 1, got {fold_idx}")

    # The random (unsorted) fancy indexing below needs in-memory arrays; h5py
    # datasets only accept increasing index lists. This is a no-op for ndarrays.
    obs = np.asarray(dataset["observations"])
    actions = np.asarray(dataset["actions"])
    next_obs = np.asarray(dataset["next_observations"])
    next_actions = np.asarray(dataset["next_actions"])
    rewards = np.asarray(dataset["rewards"])
    terminals = np.asarray(dataset["terminals"])

    n = obs.shape[0]
    assert (
        replay_buffer._max_replay_buffer_size >= n
    ), "dataset does not fit in replay buffer"

    assert np.array_equal(
        next_actions[: n - 1],
        actions[1:n],
    ), "next_actions must equal actions shifted by one row"

    # Draw fold_idx permutations and keep the last one. This keeps the same
    # consumption of the global RNG, so a given seed and fold_idx reproduce
    # the same split.
    indices = None
    for _ in range(fold_idx):
        indices = np.random.permutation(n)
    train_indices, val_indices = np.split(indices, [int(n * train_raio)])

    size = len(train_indices)
    replay_buffer._observations[:size] = ptu.torch_ify(obs[train_indices])
    replay_buffer._next_obs[:size] = ptu.torch_ify(next_obs[train_indices])
    replay_buffer._actions[:size] = ptu.torch_ify(actions[train_indices])
    replay_buffer._rewards[:size] = ptu.torch_ify(
        rewards[train_indices].reshape(size, 1)
    )
    replay_buffer._terminals[:size] = ptu.torch_ify(
        terminals[train_indices].reshape(size, 1)
    )
    replay_buffer._next_actions[:size] = ptu.torch_ify(next_actions[train_indices])

    # Every one of the `size` written rows is a valid (shuffled) transition,
    # so expose all of them to sampling.
    replay_buffer._size = size
    replay_buffer._top = replay_buffer._size

    val_observations = ptu.torch_ify(obs[val_indices])
    val_actions = ptu.torch_ify(actions[val_indices])

    return replay_buffer, val_observations, val_actions
