import torch
import uuid
import numpy as np
import d4rl
import d4rl_ext
import gym 

def _episode_boundaries(dataset, hiql_dones):
    """Return a per-step array that is truthy on the last step of each episode.

    If the dataset has a ``"timeouts"`` entry it is returned unchanged.
    Otherwise boundaries are reconstructed: step ``i`` ends an episode when
    ``observations[i + 1]`` differs from ``next_observations[i]`` by more than
    1e-6 (L2 norm), and the final step always ends an episode.  When
    ``hiql_dones`` is false, steps whose ``terminals`` flag equals 1 are
    boundaries as well (the IQL rule).
    """
    if "timeouts" in dataset:
        return dataset["timeouts"]

    dones = np.zeros_like(dataset['rewards'], dtype=bool)
    n_steps = len(dones)
    if n_steps == 0:
        return dones

    if n_steps > 1:
        # A jump between the stored next observation and the following
        # observation means the environment was reset in between.
        jumps = (np.asarray(dataset['observations'][1:n_steps])
                 - np.asarray(dataset['next_observations'][:n_steps - 1]))
        jumps = jumps.reshape(n_steps - 1, -1)
        dones[:-1] = np.linalg.norm(jumps, axis=1) > 1e-6
        if not hiql_dones:
            # terminals are erroneous on their own; combine them with the
            # discontinuity test like in the IQL codebase
            dones[:-1] |= np.asarray(dataset['terminals'][:n_steps - 1]) == 1.0
        # NOTE(review): with hiql_dones=True and no "timeouts" key this
        # function used to leave `dones` undefined (NameError).  Using the
        # discontinuity test alone is presumed to be the intended HIQL
        # behaviour -- confirm against the HIQL reference implementation.
    dones[-1] = True
    return dones


def build_d4rl_dataset(env_name, db, split_obs=True, clip_to_eps=False, hiql_dones=True):
    """Split a D4RL dataset into episodes and write each one to ``db``.

    The dataset is obtained from ``gym.make(env_name).get_dataset()`` and cut
    at every episode boundary (see ``_episode_boundaries``).  Frames after the
    last boundary, if any, are not written.

    Args:
        env_name: Gym id of the environment.  If it contains ``"antmaze"``,
            1 is subtracted from every reward.
        db: Object with a ``write(app, key, field, value)`` method; it is
            called once per episode as
            ``db.write("MyApp", <uuid string>, "action", episode_dict)``.
        split_obs: If true, observations are stored as ``'obs/pos'`` (first
            two columns) and ``'obs/other'`` (remaining columns); otherwise as
            a single ``'obs'`` entry.
        clip_to_eps: If true, actions are clipped to
            ``[-(1 - 1e-5), 1 - 1e-5]``, as in the IQL codebase.
        hiql_dones: Only used when the dataset has no ``"timeouts"`` entry.
            If false, ``terminals`` flags also mark episode boundaries
            (IQL-style); if true, only observation discontinuities do.

    Returns:
        None.  The only effect is the sequence of ``db.write`` calls.
    """
    env = gym.make(env_name)

    # dataset contains observations, actions, rewards, terminals, and infos
    dataset = env.get_dataset()
    dones = _episode_boundaries(dataset, hiql_dones)

    if clip_to_eps:
        lim = 1 - 1e-5
        dataset['actions'] = np.clip(dataset['actions'], -lim, lim)

    shift_rewards = "antmaze" in env_name
    has_goal = 'infos/goal' in dataset

    tensor_dataset = []
    ep_first_frame = 0
    # Visit only the boundary indices instead of every frame.
    for idx in np.flatnonzero(dones):
        frames = slice(ep_first_frame, int(idx) + 1)

        rewards = dataset['rewards'][frames].copy()
        if shift_rewards:
            rewards -= 1
        reward = torch.tensor(rewards).float()
        mask = 1.0 - torch.tensor(dataset['terminals'][frames].copy()).float()

        # Singular and plural keys are both stored; each gets its own tensor
        # so an in-place edit of one cannot leak into the other.
        episode_frames = {
            'action': torch.tensor(dataset['actions'][frames].copy()).float(),
            'rewards': reward,
            'reward': reward.clone(),
            'done': torch.tensor(dones[frames].copy()),
            'masks': mask,
            'mask': mask.clone(),
        }
        if has_goal:
            episode_frames['goal'] = torch.tensor(dataset['infos/goal'][frames].copy())

        observations = dataset["observations"][frames]
        if split_obs:
            episode_frames['obs/pos'] = torch.tensor(observations[:, :2].copy()).float()
            episode_frames['obs/other'] = torch.tensor(observations[:, 2:].copy()).float()
        else:
            episode_frames['obs'] = torch.tensor(observations.copy()).float()

        tensor_dataset.append(episode_frames)
        ep_first_frame = int(idx) + 1

    # Write only after every episode was built, so a failure above leaves the
    # database untouched.
    for episode in tensor_dataset:
        # 16 random bytes give 122 effective random bits per key; an integer
        # below 1e7 collides with noticeable probability once there are about
        # a thousand episodes.  Drawing from np.random keeps the keys
        # reproducible under np.random.seed.
        key = uuid.UUID(bytes=np.random.bytes(16), version=4)
        db.write("MyApp", str(key), "action", episode)


    
