import glob
import os

import h5py
import numpy as np
import ujson as json
from rich import print
from tqdm import tqdm, trange


def qlearning_ant_dataset(env, dataset=None, terminate_on_end=False, **kwargs):
    """
    Flatten an offline dataset into per-transition arrays for Q-learning.

    Step ``t`` is paired with step ``t + 1`` as its successor, so the very
    last step of the dataset is never emitted.

    Args:
        env: Environment object. Used for ``env.get_dataset(**kwargs)`` when
            ``dataset`` is None, and for ``env._max_episode_steps`` when the
            dataset carries no "timeouts" field.
        dataset: Optional mapping holding "observations", "actions",
            "rewards", "terminals", "infos/goal" and "infos/qpos" (and,
            optionally, "timeouts"). Defaults to ``env.get_dataset()``.
        terminate_on_end (bool): When False (default), the transition at the
            final step of an episode is dropped. When True it is kept; the
            stored terminal flag is still the one read from the dataset.
        **kwargs: Forwarded to ``env.get_dataset()``.
    Returns:
        A dictionary containing keys:
            observations: float32 observations, one row per kept transition.
            actions: float32 actions.
            next_observations: float32 observation of the following step.
            rewards: float32 rewards.
            terminals: boolean flags copied from the dataset's "terminals".
            goals: float32 entries of "infos/goal".
            xys: float32 first two components of "infos/qpos".
            dones_bef: boolean flags, True when the *following* step is the
                final step of an episode.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    def as_float32(field, index):
        # Looked up lazily, per step, exactly when the value is needed.
        return dataset[field][index].astype(np.float32)

    # Insertion order here is the key order of the returned dictionary.
    columns = {
        "observations": [],
        "actions": [],
        "next_observations": [],
        "rewards": [],
        "terminals": [],
        "goals": [],
        "xys": [],
        "dones_bef": [],
    }

    num_steps = dataset["rewards"].shape[0]

    # Newer datasets ship an explicit "timeouts" field; for older ones the
    # episode boundary is reconstructed from a running step counter.
    has_timeouts = "timeouts" in dataset

    steps_in_episode = 0
    for t in range(num_steps - 1):
        current_obs = as_float32("observations", t)
        following_obs = as_float32("observations", t + 1)
        act = as_float32("actions", t)
        rew = as_float32("rewards", t)
        is_terminal = bool(dataset["terminals"][t])
        goal_pos = as_float32("infos/goal", t)
        xy_pos = dataset["infos/qpos"][t][:2].astype(np.float32)

        if has_timeouts:
            at_timeout = dataset["timeouts"][t]
            next_at_timeout = dataset["timeouts"][t + 1]
        else:
            horizon = env._max_episode_steps
            at_timeout = steps_in_episode == horizon - 1
            next_at_timeout = steps_in_episode == horizon - 2

        timeout_follows = bool(next_at_timeout)

        if not terminate_on_end and at_timeout:
            # Drop the final step of the episode and restart the counter.
            steps_in_episode = 0
            continue

        columns["observations"].append(current_obs)
        columns["next_observations"].append(following_obs)
        columns["actions"].append(act)
        columns["rewards"].append(rew)
        columns["terminals"].append(is_terminal)
        columns["goals"].append(goal_pos)
        columns["xys"].append(xy_pos)
        columns["dones_bef"].append(timeout_follows)

        # A terminal/timeout step resets the counter before the usual
        # increment, so the step after it is counted as 1.
        if is_terminal or at_timeout:
            steps_in_episode = 1
        else:
            steps_in_episode += 1

    return {name: np.array(values) for name, values in columns.items()}


def qlearning_robosuite_dataset(dataset_path, terminate_on_end=False, **kwargs):
    """
    Load an HDF5 demonstration file as flat transitions for Q-learning.

    The file is expected to contain a "data" group with one sub-group per
    demonstration, each holding "actions", "rewards", "dones" and an "obs"
    group. Within a demonstration, step ``i`` is paired with step ``i + 1``
    as its successor, so the last step of every demonstration is dropped.

    Args:
        dataset_path: Path of the HDF5 file to read.
        terminate_on_end (bool): Accepted for signature parity with the
            other loaders; it is not used by this function.
        **kwargs: Only ``obs_key`` is read: a list of entries of the "obs"
            group which are concatenated, in order, into one flat
            observation vector. Defaults to the object state plus the
            ``robot0_*`` proprioceptive entries listed below.
    Returns:
        A dictionary containing keys:
            observations: One flat observation vector per transition.
            actions: The action stored for each transition.
            next_observations: Flat observation of the following step.
            rewards: The reward stored for each transition.
            terminals: Boolean flags read from "dones".
            env_meta: The decoded JSON stored in the "env_args" attribute
                of the "data" group.
            traj_indices: Integer parsed from the demonstration's group name
                (characters after the first five) for each transition.
            seg_indices: Step index of each transition inside its
                demonstration.
    """
    obs_keys = kwargs.get(
        "obs_key",
        [
            "object",
            "robot0_joint_pos",
            "robot0_joint_pos_cos",
            "robot0_joint_pos_sin",
            "robot0_joint_vel",
            "robot0_eef_pos",
            "robot0_eef_quat",
            "robot0_gripper_qpos",
            "robot0_gripper_qvel",
        ],
    )

    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []
    traj_idx_ = []
    seg_idx_ = []

    # The ``with`` block releases the HDF5 handle on return and on error.
    with h5py.File(dataset_path, "r") as f:
        demos = list(f["data"].keys())
        for ep in tqdm(demos, desc="load robosuite demonstrations", ncols=0):
            ep_grp = f[f"data/{ep}"]
            traj_len = ep_grp["actions"].shape[0]
            if traj_len < 2:
                # No (step, next step) pair can be formed from this demo.
                continue

            # Read every per-demo array into memory once; indexing the HDF5
            # datasets element by element inside the step loop is far slower.
            obs_grp = ep_grp["obs"]
            obs_parts = [obs_grp[key][()] for key in obs_keys]
            actions = ep_grp["actions"][()]
            rewards = ep_grp["rewards"][()]
            dones = ep_grp["dones"][()]

            # ``tolist()`` is kept so the flat vectors are built from Python
            # scalars and therefore keep the dtype they had before.
            flat_obs = [
                np.concatenate([part[i].tolist() for part in obs_parts], axis=0)
                for i in range(traj_len)
            ]

            # NOTE(review): presumably group names look like "demo_<n>";
            # verify against the files actually used.
            traj_idx = int(ep[5:])

            for i in range(traj_len - 1):
                obs_.append(flat_obs[i])
                next_obs_.append(flat_obs[i + 1])
                action_.append(actions[i])
                reward_.append(rewards[i])
                done_.append(bool(dones[i]))
                traj_idx_.append(traj_idx)
                seg_idx_.append(i)

        # Attributes must be read while the file is still open.
        env_meta = json.loads(f["data"].attrs["env_args"])

    return {
        "observations": np.array(obs_),
        "actions": np.array(action_),
        "next_observations": np.array(next_obs_),
        "rewards": np.array(reward_),
        "terminals": np.array(done_),
        "env_meta": env_meta,
        "traj_indices": np.array(traj_idx_),
        "seg_indices": np.array(seg_idx_),
    }


def qlearning_metaworld_dataset(dataset_path, terminate_on_end=False, max_episode_steps=500, **kwargs):
    """
    Load a flat HDF5 dataset as per-transition arrays for Q-learning.

    Every field is indexed as ``[step, -1]``, i.e. the last entry along the
    second axis of each step is used. Step ``i`` is paired with step
    ``i + 1`` as its successor, so the very last step is never emitted.

    Args:
        dataset_path: Path of the HDF5 file. It must contain "observations",
            "actions", "rewards" and "terminals"; "timeouts" and "images"
            are optional.
        terminate_on_end (bool): When False (default), the transition at the
            final step of an episode is dropped. When True it is kept; the
            stored terminal flag is still the one read from the file.
        max_episode_steps (int): Episode length used to locate episode ends
            when the file has no "timeouts" field.
        **kwargs: Ignored.
    Returns:
        A dictionary containing keys:
            observations: float32 observations, one row per kept transition.
            actions: float32 actions.
            next_observations: float32 observation of the following step.
            rewards: Rewards, cast to float32 and stored as Python floats.
            terminals: Boolean flags read from "terminals".
            dones_bef: Boolean flags, True when the *following* step is the
                final step of an episode.
            images: uint8 images (only present if the file has "images").
    """
    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []
    done_bef_ = []
    images_ = []

    # The ``with`` block releases the HDF5 handle on return and on error.
    with h5py.File(dataset_path, "r") as dataset:
        N = dataset["rewards"].shape[0]

        use_image = "images" in dataset
        # The newer version of the dataset adds an explicit timeouts field.
        # Keep the step-counter method for backwards compatibility.
        use_timeouts = "timeouts" in dataset

        # Read the needed column of every field in one go; indexing the HDF5
        # datasets element by element inside the loop is far slower.
        observations = dataset["observations"][:, -1]
        actions = dataset["actions"][:, -1]
        rewards = dataset["rewards"][:, -1]
        terminals = dataset["terminals"][:, -1]
        timeouts = dataset["timeouts"][:, -1] if use_timeouts else None
        # Images stay on disk and are read per kept step, so the whole image
        # array is not held in memory twice.
        image_ds = dataset["images"] if use_image else None

        episode_step = 0
        for i in trange(N - 1, desc="load metaworld data", ncols=0):
            if use_timeouts:
                final_timestep = timeouts[i]
                next_final_timestep = timeouts[i + 1]
            else:
                final_timestep = episode_step == max_episode_steps - 1
                next_final_timestep = episode_step == max_episode_steps - 2

            if (not terminate_on_end) and final_timestep:
                # Skip this transition and don't apply terminals on the last step of an episode
                episode_step = 0
                continue

            done_bool = bool(terminals[i])
            if done_bool or final_timestep:
                episode_step = 0

            obs_.append(observations[i].astype(np.float32))
            next_obs_.append(observations[i + 1].astype(np.float32))
            action_.append(actions[i].astype(np.float32))
            reward_.append(rewards[i].astype(np.float32).item())
            done_.append(done_bool)
            done_bef_.append(bool(next_final_timestep))
            if use_image:
                images_.append(image_ds[i].astype(np.uint8))
            episode_step += 1

    ret = {
        "observations": np.array(obs_),
        "actions": np.array(action_),
        "next_observations": np.array(next_obs_),
        "rewards": np.array(reward_),
        "terminals": np.array(done_),
        "dones_bef": np.array(done_bef_),
    }

    if use_image:
        ret["images"] = np.array(images_)

    return ret


def qlearning_factorworld_dataset(
    dataset_path,
    terminate_on_end=False,
    max_episode_steps=500,
    # camera_keys=("corner", "corner2", "corner3", "topview"),
    camera_keys=("corner2",),
    **kwargs,
):
    """
    Load a directory of per-episode ``.npz`` files as flat transitions.

    Files matching ``*.npz`` under ``dataset_path`` are read in sorted order
    and each file is treated as one episode. Inside a file, step ``i`` is
    paired with step ``i + 1``; the last step of a file has no successor, so
    its next observation and next images are the step itself.

    Args:
        dataset_path: Directory containing the episode ``.npz`` files. Each
            file must hold "states", "actions", "task_rewards", "rewards"
            and one array per entry of ``camera_keys``.
        terminate_on_end (bool): When False (default), the transition at step
            ``max_episode_steps - 1`` of an episode is dropped. When True it
            is kept.
        max_episode_steps (int): Episode length used to locate the final
            step of an episode.
        camera_keys: Names of the image arrays to extract from every file.
        **kwargs: Ignored.
    Returns:
        A dictionary containing keys:
            observations: float32 entries of "states".
            actions: float32 actions.
            next_observations: float32 state of the following step.
            rewards: Entries of "task_rewards", cast to float32 and stored
                as Python floats.
            terminals: ``bool()`` of the entries of "rewards".
            dones_bef: Boolean flags, True when the *following* step is step
                ``max_episode_steps - 1`` of the episode.
            <camera>: uint8 images for every name in ``camera_keys``.
            next_<camera>: uint8 images of the following step.
    """
    episode_files = sorted(glob.glob(os.path.join(dataset_path, "*.npz")))

    obs_ = []
    next_obs_ = []
    action_ = []
    reward_ = []
    done_ = []
    done_bef_ = []
    images_ = {ck: [] for ck in camera_keys}
    next_images_ = {ck: [] for ck in camera_keys}

    # Only these arrays are used below; other entries of a file (for example
    # unused camera views) are never decoded.
    needed_keys = ("states", "actions", "task_rewards", "rewards", *camera_keys)

    for ep_path in tqdm(episode_files, desc="load factorworld data", ncols=0):
        # NOTE: allow_pickle=True lets numpy unpickle object arrays, which can
        # execute arbitrary code; only load episode files from a trusted source.
        # The ``with`` block closes the underlying archive once the arrays
        # have been copied into memory.
        with np.load(ep_path, allow_pickle=True) as npz:
            ep = {key: npz[key] for key in needed_keys}

        N = ep["rewards"].shape[0]

        # Each file is one episode, so the step counter restarts per file.
        # Carrying it over would place the "final step" of later episodes on
        # the wrong transition.
        episode_step = 0

        for i in range(N):
            # The last step has no successor; clamp so it points at itself.
            next_i = min(i + 1, N - 1)

            final_timestep = episode_step == max_episode_steps - 1
            next_final_timestep = episode_step == max_episode_steps - 2

            if (not terminate_on_end) and final_timestep:
                # Skip this transition and don't apply terminals on the last step of an episode
                episode_step = 0
                continue

            # NOTE(review): the terminal flag is taken from "rewards" while
            # the reward comes from "task_rewards" - presumably "rewards"
            # holds a success indicator; confirm against the data writer.
            done_bool = bool(ep["rewards"][i])
            if done_bool or final_timestep:
                episode_step = 0

            obs_.append(ep["states"][i].astype(np.float32))
            next_obs_.append(ep["states"][next_i].astype(np.float32))
            action_.append(ep["actions"][i].astype(np.float32))
            reward_.append(ep["task_rewards"][i].astype(np.float32).item())
            done_.append(done_bool)
            done_bef_.append(bool(next_final_timestep))
            for ck in camera_keys:
                images_[ck].append(ep[ck][i].astype(np.uint8))
                next_images_[ck].append(ep[ck][next_i].astype(np.uint8))
            episode_step += 1

    ret = {
        "observations": np.array(obs_),
        "actions": np.array(action_),
        "next_observations": np.array(next_obs_),
        "rewards": np.array(reward_),
        "terminals": np.array(done_),
        "dones_bef": np.array(done_bef_),
    }

    # Current-step images first, then the "next_" variants, per camera.
    for key, val in images_.items():
        ret[key] = np.asarray(val)
    for key, val in next_images_.items():
        ret[f"next_{key}"] = np.asarray(val)
    return ret


if __name__ == "__main__":
    # Smoke test: load one factorworld episode directory and print the shape
    # of every returned array.
    # dataset_path = "/home/pref_data/pick-place-v2/light_object_pos_goal_pos_table_pos_floor_texture_table_texture/episodes"
    dataset_path = "/home/factor-world/individual_data/pick-place-v2/episodes"
    image_key = "corner2"
    # The loader's parameter is ``camera_keys`` (a sequence of names). A
    # ``camera_name=`` keyword would be swallowed by **kwargs and ignored.
    ds = qlearning_factorworld_dataset(dataset_path, camera_keys=[image_key])
    for key, val in ds.items():
        print(f"[INFO] {key}: {val.shape}")
