import random
from tf_agents.environments import tf_py_environment
from tf_agents.policies import py_tf_eager_policy
from tf_agents.trajectories import trajectory

from tfagents.distributed_locking.DistributedLocking import reset_envs, step_envs
from tfagents.utils import log


def collect_data(envs, agent, observer, steps, players, is_refine, failed_cases):
    """Run collection episodes and hand each player's last transition to ``observer[0]``.

    Each iteration of the outer loop is one episode: the environments are
    reset, the collect policy of ``agent[0]`` picks an action for every player
    until player 0's time step reports ``is_last()``, and the most recent
    trajectory of every player is passed to ``observer[0]``.

    Args:
        envs: Indexable collection of environments; ``envs[i]`` belongs to
            player ``i`` and must provide ``reset()`` and ``step()``.
        agent: Indexable collection of agents; only ``agent[0].collect_policy``
            is used.
        observer: Indexable collection of callables; only ``observer[0]`` is
            invoked, once per player per episode.
        steps: Number of episodes to run. The counter advances once per
            episode, not once per environment step.
        players: Number of players, i.e. how many entries of ``envs`` are
            driven.
        is_refine: When true, episodes in the second half of the run
            (``num_steps >= steps / 2``) are reset with a randomly chosen
            entry of ``failed_cases``.
        failed_cases: Iterable of inputs to replay in refine mode, or ``None``.
            ``None`` and empty iterables both fall back to a plain reset.

    Returns:
        None.
    """
    this_func = "DistributedLocking_driver: collect_data"

    # Nothing to collect: skip building the tf.function-wrapped policy.
    if steps <= 0:
        return

    # Every player is driven by the same collect policy, so one wrapper is
    # enough (previously `players` identical wrappers were built and only the
    # first one was ever used).
    policy = py_tf_eager_policy.PyTFEagerPolicy(agent[0].collect_policy, use_tf_function=True)

    num_steps = 0
    while num_steps < steps:
        # Refine mode replays known-failing inputs during the second half of
        # the run. The candidates are re-read every episode in case the caller
        # updates `failed_cases` while collection is running.
        replay_cases = []
        if is_refine and num_steps >= steps / 2 and failed_cases is not None:
            replay_cases = list(failed_cases)

        if replay_cases:
            reset_envs(envs, players, input=random.choice(replay_cases))
        else:
            # Also taken when `failed_cases` is empty: random.choice() raises
            # IndexError on an empty sequence.
            reset_envs(envs, players)

        time_steps = [envs[i].reset() for i in range(players)]
        trajs = [None] * players

        # NOTE(review): only player 0's time step decides when the episode
        # ends - presumably all players terminate together; confirm.
        while not time_steps[0].is_last():
            log(this_func, f"{time_steps}")
            action_steps = [policy.action(time_step) for time_step in time_steps]
            actions = [int(action_step.action) for action_step in action_steps]
            log(this_func, f"action selected: {actions}")
            step_envs(envs, actions, players)

            next_time_steps = []
            for i in range(players):
                # NOTE(review): the whole policy step (not just `.action`) is
                # handed to step(); kept as in the original - confirm the
                # environment expects this.
                next_time_step = envs[i].step(action_steps[i])
                # NOTE(review): each transition overwrites the previous one,
                # so only the last transition of the episode is stored below.
                trajs[i] = trajectory.from_transition(time_steps[i], action_steps[i], next_time_step)
                next_time_steps.append(next_time_step)
            time_steps = next_time_steps

        log(this_func, f"final time steps {list(time_steps)}")

        # store trajectories to replay buffer
        for i in range(players):
            if trajs[i] is None:
                # The episode was already terminal right after reset, so no
                # transition exists; do not push None to the observer.
                log(this_func, f"No trajectory collected for player {i}, skipping")
                continue
            log(this_func, f"Store {trajs[i]} into buffer 0")
            observer[0](trajs[i])
        num_steps += 1