import gym
import torch
import uuid
import numpy as np
import d4rl
import d4rl_ext

def build_d4rl_dataset_v2(env_name, db, split_obs=True, clip_to_eps=False, hiql_dones=False, split_idx=2):
    """Load the D4RL dataset of ``env_name``, split it into episodes and write them to ``db``.

    Episode boundaries:
      * If the dataset has a ``timeouts`` field, every timeout ends an episode.
        Outside AntMaze a ``terminals`` flag ends the episode too. AntMaze
        terminals do not abort the episode, so they are ignored there.
        Steps after the last boundary are not written.
      * Otherwise boundaries are recovered from discontinuities between
        ``next_observations[i]`` and ``observations[i + 1]``:
          - with ``hiql_dones=False`` a ``terminals`` flag also ends the episode;
          - with ``hiql_dones=True`` only the discontinuities are used.
        The last step always closes an episode in this branch.

    Each episode is a dict of tensors with keys:
      * ``action``, ``reward``, ``done``;
      * ``mask`` (``1 - terminals``);
      * ``goal``, only when ``infos/goal`` exists;
      * ``obs/pos`` and ``obs/other`` when ``split_obs`` is True, otherwise ``obs``.

    Reward adjustments:
      * AntMaze rewards are shifted by -1.
      * halfcheetah/walker2d/hopper rewards are rescaled by
        ``1000 / (max_return - min_return)``, with returns taken over episodes
        before any shift. This is skipped when that range is zero.

    Args:
        env_name: gym id of the D4RL environment.
        db: storage object; each episode is written via
            ``db.write("MyApp", <random uuid string>, "action", episode)``.
        split_obs: if True, split observations at column ``split_idx`` into
            ``obs/pos`` and ``obs/other``.
        clip_to_eps: clip actions to ``[-(1 - 1e-5), 1 - 1e-5]``.
        hiql_dones: see "Episode boundaries" (only used without ``timeouts``).
        split_idx: column at which observations are split when ``split_obs``.

    Returns:
        The first observation of the dataset. This is useful in HIQL settings,
        where G == S and a placeholder state is needed to concatenate with goals
        at test time.

    Raises:
        ValueError: if the dataset contains no steps.
    """
    env = gym.make(env_name)

    # dataset contains observations, actions, rewards, terminals, and infos
    dataset = env.get_dataset()
    n_steps = len(dataset['rewards'])
    if n_steps == 0:
        raise ValueError(f"dataset for {env_name!r} is empty")

    if 'timeouts' in dataset:
        dones = np.asarray(dataset['timeouts'], dtype=bool)
        if 'antmaze' not in env_name:
            # Outside AntMaze a terminal state really ends the episode; without
            # this, episodes ending in a terminal would be merged with the next one.
            dones = dones | (np.asarray(dataset['terminals']) == 1.0)
    else:
        # Terminals are unreliable here: detect boundaries from observation gaps
        # (IQL-style also uses terminals; HIQL-style ignores them).
        dones = np.zeros(n_steps, dtype=bool)
        if n_steps > 1:
            gap = dataset['observations'][1:] - dataset['next_observations'][:-1]
            dones[:-1] = np.linalg.norm(gap.reshape(n_steps - 1, -1), axis=1) > 1e-6
            if not hiql_dones:
                dones[:-1] |= np.asarray(dataset['terminals'][:-1]) == 1.0
        dones[-1] = True

    if clip_to_eps:
        lim = 1 - 1e-5
        dataset['actions'] = np.clip(dataset['actions'], -lim, lim)

    tensor_dataset = []
    returns = []
    start = 0
    for end in (np.flatnonzero(dones) + 1).tolist():
        sl = slice(start, end)
        start = end

        rewards = dataset['rewards'][sl].copy()
        # Returns are recorded before the AntMaze shift.
        returns.append(np.sum(rewards))
        if 'antmaze' in env_name:
            rewards -= 1
        episode = {
            'action': torch.tensor(dataset['actions'][sl].copy()).float(),
            'reward': torch.tensor(rewards).float(),
            'done': torch.tensor(dones[sl].copy()),
            'mask': 1.0 - torch.tensor(dataset['terminals'][sl].copy()).float(),
        }
        if 'infos/goal' in dataset:
            episode['goal'] = torch.tensor(dataset['infos/goal'][sl].copy())
        observations = dataset['observations'][sl]
        if split_obs:
            episode['obs/pos'] = torch.tensor(observations[:, :split_idx].copy()).float()
            episode['obs/other'] = torch.tensor(observations[:, split_idx:].copy()).float()
        else:
            episode['obs'] = torch.tensor(observations.copy()).float()
        tensor_dataset.append(episode)

    if any(name in env_name for name in ('halfcheetah', 'walker2d', 'hopper')):
        return_range = max(returns) - min(returns) if returns else 0.0
        if return_range > 0:
            print("normalizing rewards in dataset...")
            for episode in tensor_dataset:
                episode["reward"] /= return_range
                episode["reward"] *= 1000.0
        else:
            # Dividing by a zero range would fill rewards with inf/nan.
            print("skipping reward normalization: episode returns have zero range")

    for episode in tensor_dataset:
        # 128 random bits (122 after version/variant bits) from NumPy's global RNG.
        # Practically collision-free, unlike randint(1e7), and still
        # reproducible under np.random.seed.
        key = uuid.UUID(int=int.from_bytes(np.random.bytes(16), 'big'), version=4)
        db.write("MyApp", str(key), "action", episode)

    return dataset["observations"][0]


def build_d4rl_dataset(
    env_name: str,
    db,
    max_episode_length: int,
    split_obs: bool = True,
    split_idx: int = 2,
    clip_to_eps: bool = False,
    eps: float = 1e-5,
    filter_terminals: bool = False,
    terminate_on_end: bool = False,
    kitchen_obs_mask=None,
):
    """Build zero-padded episodes from a D4RL dataset and write them to ``db``.

    Steps:
      1. Transitions are selected as in ``d4rl.qlearning_dataset``. The next
         observation is ``observations[i + 1]``, and the very last raw step is
         dropped. Timeout transitions are skipped unless ``terminate_on_end``
         is set. All per-step arrays, including ``infos/goal``, are indexed
         with the same kept indices so they stay aligned.
      2. Episode boundaries are gaps where ``next_observations[i]`` differs
         from ``observations[i + 1]`` (i.e. a skipped transition).
         Non-AntMaze envs also split on kept final timesteps and on terminals.
         AntMaze terminals are zeroed and ignored. The last step always closes
         an episode.
      3. Each episode becomes a dict of tensors with keys:
           * ``action``, ``reward``, ``done``;
           * ``mask`` (``1 - terminals``);
           * ``goal``, only when ``infos/goal`` exists;
           * ``obs/pos`` and ``obs/other`` (split at ``split_idx``) or ``obs``.
         Every field is zero-padded along dim 0 to ``max_episode_length``.

    Reward adjustments:
      * AntMaze rewards are shifted by -1.
      * halfcheetah/walker2d/hopper rewards are rescaled by
        ``1000 / (max_return - min_return)``. This is skipped when that range
        is zero.

    Args:
        env_name: gym id of the D4RL environment.
        db: storage object; each episode is written via
            ``db.write("MyApp", <random uuid string>, "action", episode)``.
        max_episode_length: length every episode is padded to.
        split_obs: split observations at ``split_idx`` into ``obs/pos`` / ``obs/other``.
        split_idx: column at which observations are split.
        clip_to_eps: clip actions to ``[-(1 - eps), 1 - eps]``.
        eps: clipping margin used with ``clip_to_eps``.
        filter_terminals: drop terminal transitions, marking the preceding
            step as terminal instead.
        terminate_on_end: keep timeout transitions instead of skipping them.
        kitchen_obs_mask: for kitchen envs, observations are cut to their first
            30 dims. If given, this index/boolean mask further selects columns
            among those 30.

    Returns:
        The first (processed) observation of the dataset.

    Raises:
        ValueError: if no transitions remain, or if an episode is longer than
            ``max_episode_length``.
    """
    env = gym.make(env_name)
    raw = env.get_dataset()
    n_raw = raw['rewards'].shape[0]

    # The newer version of the dataset adds an explicit timeouts field;
    # otherwise count steps against the env's time limit (old method).
    use_timeouts = 'timeouts' in raw
    kept = []
    kept_is_final = []
    episode_step = 0
    for i in range(n_raw - 1):
        done_bool = bool(raw['terminals'][i])
        if use_timeouts:
            final_timestep = bool(raw['timeouts'][i])
        else:
            final_timestep = episode_step == env._max_episode_steps - 1
        if final_timestep and not terminate_on_end:
            # Skip this transition and don't apply terminals on the last step of an episode.
            episode_step = 0
            continue
        if done_bool or final_timestep:
            episode_step = 0
        kept.append(i)
        kept_is_final.append(final_timestep)
        episode_step += 1

    if not kept:
        raise ValueError(f"no transitions left in dataset for {env_name!r}")
    kept = np.asarray(kept)
    final_flags = np.asarray(kept_is_final, dtype=bool)

    # Build a fresh dict so every per-step array is indexed by the same kept transitions.
    dataset = {
        'observations': raw['observations'][kept].astype(np.float32),
        'next_observations': raw['observations'][kept + 1].astype(np.float32),
        'actions': raw['actions'][kept].astype(np.float32),
        'rewards': raw['rewards'][kept].astype(np.float32),
        'terminals': raw['terminals'][kept].astype(bool),
    }
    if 'infos/goal' in raw:
        dataset['infos/goal'] = raw['infos/goal'][kept]

    if clip_to_eps:
        lim = 1 - eps
        dataset['actions'] = np.clip(dataset['actions'], -lim, lim)

    dataset['terminals'][-1] = True

    if filter_terminals:
        terminals = dataset['terminals']
        terminal_idx = np.flatnonzero(terminals)
        non_terminal_idx = np.flatnonzero(~terminals)
        # The step preceding each dropped terminal transition becomes terminal.
        terminals[terminal_idx[terminal_idx > 0] - 1] = True
        dataset = {key: value[non_terminal_idx] for key, value in dataset.items()}
        final_flags = final_flags[non_terminal_idx]
        if len(final_flags) == 0:
            raise ValueError(f"no transitions left in dataset for {env_name!r} after filter_terminals")

    # A gap between next_observations[i] and observations[i + 1] means a
    # transition was skipped there, i.e. an episode boundary.
    n_steps = len(dataset['rewards'])
    boundary = np.zeros(n_steps, dtype=bool)
    if n_steps > 1:
        gap = dataset['observations'][1:] - dataset['next_observations'][:-1]
        boundary[:-1] = np.linalg.norm(gap.reshape(n_steps - 1, -1), axis=1) > 1e-6

    if 'antmaze' in env_name:
        # antmaze: terminals are incorrect for GCRL
        dataset['terminals'][:] = False
        dones = boundary.astype(np.float32)
    else:
        # Split on timeouts (skipped or kept) as well as terminals. Otherwise
        # datasets without terminals collapse into one huge episode.
        dones = boundary | final_flags | dataset['terminals']
    dones[-1] = 1

    if 'kitchen' in env_name:
        for key in ('observations', 'next_observations'):
            obs = dataset[key][:, :30]
            if kitchen_obs_mask is not None:
                obs = obs[:, kitchen_obs_mask]
            dataset[key] = obs.copy()

    tensor_dataset = []
    returns = []
    start = 0
    for end in (np.flatnonzero(dones) + 1).tolist():
        sl = slice(start, end)
        length = end - start
        start = end
        if length > max_episode_length:
            raise ValueError(
                f"episode of length {length} exceeds max_episode_length={max_episode_length}"
            )

        rewards = dataset['rewards'][sl].copy()
        # Returns are recorded before the AntMaze shift.
        returns.append(np.sum(rewards))
        if 'antmaze' in env_name:
            rewards -= 1
        episode = {
            'action': torch.tensor(dataset['actions'][sl].copy()).float(),
            'reward': torch.tensor(rewards).float(),
            'done': torch.tensor(dones[sl].copy()),
            'mask': 1.0 - torch.tensor(dataset['terminals'][sl].copy()).float(),
        }
        if 'infos/goal' in dataset:
            episode['goal'] = torch.tensor(dataset['infos/goal'][sl].copy())
        observations = dataset['observations'][sl]
        if split_obs:
            episode['obs/pos'] = torch.tensor(observations[:, :split_idx].copy()).float()
            episode['obs/other'] = torch.tensor(observations[:, split_idx:].copy()).float()
        else:
            episode['obs'] = torch.tensor(observations.copy()).float()

        # Zero-pad every field along dim 0 up to max_episode_length.
        rest = max_episode_length - length
        for key, value in list(episode.items()):
            padding = torch.zeros(rest, *value.shape[1:], dtype=value.dtype)
            episode[key] = torch.cat([value, padding], dim=0).contiguous()

        tensor_dataset.append(episode)

    if any(name in env_name for name in ('halfcheetah', 'walker2d', 'hopper')):
        return_range = max(returns) - min(returns) if returns else 0.0
        if return_range > 0:
            print("normalizing rewards in dataset...")
            for episode in tensor_dataset:
                episode["reward"] /= return_range
                episode["reward"] *= 1000.0
        else:
            # Dividing by a zero range would fill rewards with inf/nan.
            print("skipping reward normalization: episode returns have zero range")

    for episode in tensor_dataset:
        # 128 random bits (122 after version/variant bits) from NumPy's global RNG.
        # Practically collision-free, unlike randint(1e7), and still
        # reproducible under np.random.seed.
        key = uuid.UUID(int=int.from_bytes(np.random.bytes(16), 'big'), version=4)
        db.write("MyApp", str(key), "action", episode)

    return dataset["observations"][0]