import torch
from time import sleep


def create_empty_traj_dict():
    """Return a fresh dict mapping each trajectory field name to its own new empty list."""
    field_names = ("observations", "actions", "rewards", "dones", "options", "terminations")
    # The comprehension evaluates `[]` once per key, so no two fields share a list.
    return {name: [] for name in field_names}


def generate_samples(env, policy, episodes, params=None, render=True):
    """
    Roll out ``policy`` in ``env`` for a number of complete episodes and collect the transitions.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` (returning a 4-tuple
            ``(next_obs, reward, done, info)``) and ``render()``.
        policy: Object whose ``get_action(obs, current_option, params)`` returns
            ``(action, current_option, termination)``; ``action`` must be a tensor
            (it is converted with ``.numpy()`` and concatenated with ``torch.cat``).
        episodes (int): Number of complete episodes to collect.
        params (OrderedDict): policy parameters (used in Meta-Updates)
        render (bool): Whether to call ``env.render()`` after every step. Defaults to
            True, which keeps the previous behaviour. Pass False for headless
            collection.

    Returns:
        (dict): Maps each of "observations", "actions", "rewards", "dones", "options"
        and "terminations" to a list holding one tensor per finished episode:
        observations/actions are the per-step tensors concatenated along dim 0,
        rewards and options are default-dtype tensors, and dones and terminations
        are uint8 tensors.
    """
    dtype = torch.get_default_dtype()
    traj_data = create_empty_traj_dict()
    traj = create_empty_traj_dict()

    # Wrapping the observation in a list gives it a leading dimension of size 1,
    # so the per-step observations can be concatenated along dim 0 per episode.
    obs = torch.tensor([env.reset()], dtype=dtype)
    episodes_done = 0
    current_option = None

    while episodes_done < episodes:
        # Sampling only: no autograd graph is needed for the rollout.
        with torch.no_grad():
            action, current_option, termination = policy.get_action(obs, current_option, params)

        next_obs, reward, done, _ = env.step(action.numpy())
        if render:
            env.render()

        # Buffer this step's transition for the current episode.
        traj["observations"].append(obs)
        traj["actions"].append(action)
        traj["rewards"].append(reward)
        traj["dones"].append(1 if done else 0)
        traj["options"].append(current_option)
        traj["terminations"].append(termination)

        if done:
            obs = torch.tensor([env.reset()], dtype=dtype)
            # Hand the policy None at the start of the next episode instead of the
            # last option. What None triggers is up to policy.get_action.
            current_option = None
            episodes_done += 1

            # Convert the finished episode's buffers into tensors and start a new buffer.
            traj_data["observations"].append(torch.cat(traj["observations"], dim=0))
            traj_data["actions"].append(torch.cat(traj["actions"], dim=0))
            traj_data["rewards"].append(torch.tensor(traj["rewards"], dtype=dtype))
            traj_data["dones"].append(torch.tensor(traj["dones"], dtype=torch.uint8))
            traj_data["options"].append(torch.tensor(traj["options"], dtype=dtype))
            traj_data["terminations"].append(torch.tensor(traj["terminations"], dtype=torch.uint8))

            traj = create_empty_traj_dict()
        else:
            obs = torch.tensor([next_obs], dtype=dtype)

    return traj_data


def visualize(env, policy, episodes=None, params=None, steps_per_second=10):
    """
    Play ``policy`` in ``env``, rendering every step at a throttled rate.

    Args:
        env: Environment exposing ``reset()``, ``step(action)`` (returning a 4-tuple
            ``(next_obs, reward, done, info)``) and ``render()``.
        policy: Object whose ``get_action(obs, current_option, params)`` returns
            ``(action, current_option, termination)``. ``action`` must be a
            single-element tensor because it is converted with ``.item()``.
        episodes (int or None): Number of episodes to play. ``None`` (the default)
            plays indefinitely, until interrupted.
        params (OrderedDict): policy parameters (used in Meta-Updates)
        steps_per_second (float): Upper bound on the step rate. The loop sleeps
            ``1 / steps_per_second`` seconds before each step.
    """
    # TODO Move this somewhere else?
    delay = 1 / steps_per_second
    dtype = torch.get_default_dtype()
    episodes_done = 0

    obs = torch.tensor([env.reset()], dtype=dtype)
    current_option = None
    env.render()

    # episodes=None means "run forever". Comparing an int with None raises TypeError,
    # so the None case must be checked explicitly.
    while episodes is None or episodes_done < episodes:
        # TODO maybe fix this ugly sleep?
        sleep(delay)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            action, current_option, _ = policy.get_action(obs, current_option, params)
        # .item() passes the action as a Python scalar (single-element tensor required).
        next_obs, _, done, _ = env.step(action.item())

        if done:
            obs = torch.tensor([env.reset()], dtype=dtype)
            # Hand the policy None at the start of the next episode instead of the last option.
            current_option = None
            episodes_done += 1
        else:
            obs = torch.tensor([next_obs], dtype=dtype)
        env.render()
