# Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example algorithm to sample some states from a game."""

import random


def sample_some_states(game, max_states=100):
  """Samples some states in the game.

  This can be run for large games, in contrast to `get_all_states`. It is useful
  for tests that need to check a predicate only on a subset of the game, since
  generating the whole game is infeasible.

  Currently only works for sequential games.

  The algorithm maintains a list of states and repeatedly picks a random state
  from the list to expand until enough states have been sampled. Each expansion
  follows one randomly chosen, not-yet-expanded legal action of the picked
  state, so no (state, action) pair is expanded twice.

  Arguments:
    game: The game to analyze, as returned by `load_game`.
    max_states: The maximum number of states to return. Negative means no
      limit: sampling then continues until every action of every sampled
      non-terminal state has been expanded (which may not terminate for games
      of unbounded length). The initial state is always included, so at least
      one state is returned.

  Returns:
    A `list` of `pyspiel.State`, whose first element is the initial state.

  Raises:
    ValueError: If no state was sampled (defensive check; the initial state is
      always added).
  """
  # Honor the documented "negative means no limit" contract; previously a
  # negative bound made the loop condition false immediately.
  no_limit = max_states < 0

  states = []
  # unexplored_actions[i] is the set of legal actions of states[i] that have
  # not been expanded yet, or None when states[i] is terminal.
  unexplored_actions = []
  # Indexes into `states` whose entry still has at least one unexplored action.
  indexes_with_unexplored_actions = set()

  def add_state(state):
    states.append(state)
    if state.is_terminal():
      unexplored_actions.append(None)
    else:
      indexes_with_unexplored_actions.add(len(states) - 1)
      unexplored_actions.append(set(state.legal_actions()))

  def expand_random_state():
    index = random.choice(list(indexes_with_unexplored_actions))
    state = states[index]
    actions = unexplored_actions[index]
    action = random.choice(list(actions))
    actions.remove(action)
    # Once every action of this state is expanded it can no longer be picked.
    if not actions:
      indexes_with_unexplored_actions.remove(index)
    return state.child(action)

  add_state(game.new_initial_state())
  while ((no_limit or len(states) < max_states) and
         indexes_with_unexplored_actions):
    add_state(expand_random_state())

  if not states:
    raise ValueError("sample_some_states sampled 0 states!")

  return states
