import reverb
from tf_agents.specs import tensor_spec

from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils


from utils import log


def init_tf_replay_buffer(agent):
    """Placeholder initializer; currently does nothing.

    The ``agent`` argument is accepted but ignored, no buffer is created,
    and the function returns ``None``.
    NOTE(review): presumably reserved for a non-Reverb (TF-native) replay
    buffer -- confirm the intent before relying on it.
    """
    return None


def init_reverb_replay_buffer(
    agent, table_name, *, max_size=30000, sequence_length=1
):
    """Create a local Reverb replay buffer plus a trajectory observer.

    A single-table Reverb server is started in-process; the table signature
    is derived from ``agent.collect_data_spec`` with one outer dimension
    added. The returned replay buffer is bound to that server through its
    ``local_server`` argument.

    Args:
        agent: Object exposing ``collect_data_spec`` (e.g. a TF-Agents agent).
        table_name: Name of the Reverb table to create and write into.
        max_size: Value passed as ``max_size`` to ``reverb.Table``.
            Defaults to 30000, the previously hard-coded value.
        sequence_length: Sequence length used for BOTH the replay buffer and
            the trajectory observer, so the two cannot drift apart.
            Defaults to 1, the previously hard-coded value.

    Returns:
        A ``(rb_observer, replay_buffer)`` tuple: the
        ``ReverbAddTrajectoryObserver`` that writes into the table and the
        ``ReverbReplayBuffer`` that reads from it.
    """
    # Log tag is kept unchanged so existing log output stays the same.
    this_func = "init_replay_buffer"
    data_spec = agent.collect_data_spec
    log(this_func, f"data_spec: {data_spec}")

    # Table signature: the collect spec as tensor specs with an outer
    # dimension added.
    replay_buffer_signature = tensor_spec.add_outer_dim(
        tensor_spec.from_spec(data_spec)
    )
    log(this_func, f"replay_buffer_signature: {replay_buffer_signature}")

    table = reverb.Table(
        table_name,
        max_size=max_size,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=replay_buffer_signature,
    )

    reverb_server = reverb.Server([table])

    replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
        data_spec,
        table_name=table_name,
        sequence_length=sequence_length,
        local_server=reverb_server,
    )

    # The observer writes through the buffer's own client, using the same
    # sequence length the buffer was configured with.
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        replay_buffer.py_client, table_name, sequence_length=sequence_length
    )

    return rb_observer, replay_buffer


def init_reverb_episode_buffer(
    agent, table_name, *, max_size=2000, max_sequence_length=2000
):
    """Create a local Reverb replay buffer plus an episode observer.

    A single-table Reverb server is started in-process; the table signature
    is derived from ``agent.collect_data_spec`` with one outer dimension
    added. The replay buffer is created with ``sequence_length=None`` and is
    bound to that server through its ``local_server`` argument.

    Args:
        agent: Object exposing ``collect_data_spec`` (e.g. a TF-Agents agent).
        table_name: Name of the Reverb table to create and write into.
        max_size: Value passed as ``max_size`` to ``reverb.Table``.
            Defaults to 2000, the previously hard-coded value.
        max_sequence_length: Third positional argument given to
            ``ReverbAddEpisodeObserver``. Defaults to 2000, the previously
            hard-coded value. It is independent of ``max_size``; the two
            only happened to share the same literal.

    Returns:
        A ``(rb_observer, replay_buffer)`` tuple: the
        ``ReverbAddEpisodeObserver`` that writes into the table and the
        ``ReverbReplayBuffer`` that reads from it.
    """
    data_spec = agent.collect_data_spec

    # Table signature: the collect spec as tensor specs with an outer
    # dimension added.
    replay_buffer_signature = tensor_spec.add_outer_dim(
        tensor_spec.from_spec(data_spec)
    )

    table = reverb.Table(
        table_name,
        max_size=max_size,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=replay_buffer_signature,
    )

    reverb_server = reverb.Server([table])

    replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
        data_spec,
        table_name=table_name,
        sequence_length=None,
        local_server=reverb_server,
    )

    rb_observer = reverb_utils.ReverbAddEpisodeObserver(
        replay_buffer.py_client, table_name, max_sequence_length
    )

    return rb_observer, replay_buffer
