import random
import numpy as np
from symbols.domain.domain import Domain
from symbols.logger.precondition_logger import PreconditionLogger
from symbols.logger.transition_logger import TransitionLogger




def _uniform_sample_policy(domain: Domain):
    """
    Default sampling policy: draw an action from the domain's action space.

    NOTE: the action is drawn from the whole action space via ``sample()``, not from
    ``domain.admissible_actions``, so admissibility is not checked here. Whether the draw is
    uniform depends on the action space's ``sample()`` implementation (presumably uniform -
    TODO confirm against the action space class).

    :param domain: the domain whose action space is sampled
    :return: whatever ``domain.action_space.sample()`` returns
    """
    action_space = domain.action_space
    sampled_action = action_space.sample()
    return sampled_action


def _collect_episode(domain: Domain,
                     episode,
                     random_starts,
                     max_episode_length,
                     sampling_policy,
                     transition_logger,
                     precondition_logger):
    """
    Run a single episode in the domain, logging precondition and transition samples as it goes
    :param domain: the environment
    :param episode: the index of the episode, passed through to both loggers
    :param random_starts: forwarded to domain.reset
    :param max_episode_length: the maximum number of steps to take in this episode
    :param sampling_policy: callable taking the domain and returning the action to execute
    :param transition_logger: logger receiving (state, observation, action, reward, next state, next observation)
    :param precondition_logger: logger receiving, for every option, whether it is currently admissible
    """
    state = domain.reset(random_starts=random_starts)  # restart the domain at the start state
    observation = domain.current_observation

    for _step in range(max_episode_length):
        # Record, for every option in the action space, whether it can be executed right now.
        admissible_actions = domain.admissible_actions
        for option in domain.action_space:
            precondition_logger.log_sample(episode, state, observation, option, option in admissible_actions)

        action = sampling_policy(domain)

        next_state, reward, done, _info = domain.step(action)
        next_observation = domain.current_observation

        # Only log transitions that actually changed the state; steps that leave the
        # state untouched are skipped.
        if not np.array_equal(state, next_state):
            transition_logger.log_sample(episode, state, observation, action, reward, next_state, next_observation)
        if done:
            break
        state, observation = next_state, next_observation


def gather_trajectories(domain: Domain,
                        output_directory,
                        random_starts=True,
                        num_episodes=40,
                        max_episode_length=300,
                        sampling_policy=_uniform_sample_policy,
                        verbose=False):
    """
    Gather trajectories from the given environment and write them out through a
    TransitionLogger and a PreconditionLogger. Both loggers are closed before this
    function returns or raises, so an error in the middle of data collection does not
    leave them open.
    :param random_starts: whether or not to start each episode from a random location
    :param domain: the environment
    :param output_directory: the directory to write the data to
    :param num_episodes: the number of episodes to collect data over
    :param max_episode_length: the maximum episode length allowed
    :param sampling_policy: the policy to be used when collecting data. It is called with the domain and must
    return the action to execute. By default, _uniform_sample_policy is used
    :param verbose: whether to print information to screen
    """

    # The action space is assumed to expose its size as `.n`. An earlier variant read
    # `domain.action_space.spaces[0].n` (action x object space) - confirm which one the domain uses.
    n_options = domain.action_space.n

    transition_logger = TransitionLogger(output_directory, n_options)
    try:
        precondition_logger = PreconditionLogger(output_directory, n_options)
        try:
            for episode in range(num_episodes):
                if verbose:
                    print('Running episode {}...'.format(episode))
                _collect_episode(domain, episode, random_starts, max_episode_length, sampling_policy,
                                 transition_logger, precondition_logger)
        finally:
            # Close order matches the original: preconditions first, then transitions.
            precondition_logger.close()
    finally:
        transition_logger.close()

    if verbose:
        print('Completed data collection')
