// Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#include <utility>
#include <open_spiel/algorithms/expected_returns.h>
#include "bot.h"

namespace open_spiel {
namespace papers_with_code {

// Creates a Sherlock bot playing as `player_id`, with its random generator
// seeded by `seed`.
//
// `response_type` must be one of:
//   "none" - no opponent modelling;
//   "sbr"  - during re-solving the opponent's strategy is fixed to
//            `fixed_policy` at all of its decision infostates;
//   "srnr" - like "sbr", but only at infostates prefixed with "fixed".
// `fixed_policy` is copied into a shared TabularPolicy; for "sbr"/"srnr" it
// must contain at least one entry.
SherlockBot::SherlockBot(std::unique_ptr<SubgameFactory> subgame_factory,
                         std::unique_ptr<SolverFactory> solver_factory,
                         Player player_id, int seed, std::string response_type, const TabularPolicy &fixed_policy)
    : subgame_factory_(std::move(subgame_factory)),
      solver_factory_(std::move(solver_factory)),
      player_id_(player_id),
      rnd_gen_(seed),
      solver_(nullptr),
      fixed_policy_(std::make_shared<TabularPolicy>(fixed_policy)) {
  // Begin with a depth-one trunk and a budget of one step, so the first
  // StepWithPolicy() call exhausts the budget and builds a real subgame.
  subgame_ = subgame_factory_->MakeTrunk(1);
  current_depth_limit_ = 1;

  // Decode the response type; at most one of the two flags ends up set.
  sbr_ = (response_type == "sbr");
  srnr_ = (response_type == "srnr");
  if (sbr_ || srnr_) {
    // Both response types pin the opponent to `fixed_policy_`, so it must not
    // be empty.
    SPIEL_CHECK_GE(fixed_policy_->PolicyTable().size(), 1);
  } else {
    SPIEL_CHECK_EQ(response_type, "none");
  }
}

// Copy constructor.
//
// The subgame factory, solver factory, solver and fixed policy are shared
// with `bot`: the pointers are copied, and these members appear to be
// shared_ptrs. The subgame is deep-copied, so each bot owns its own
// `subgame_`. The random generator and the remaining depth budget are copied
// as well. For standard engines this means both bots presumably continue
// with identical random sequences.
//
// NOTE(review): `solver_` is shared with `bot` and was built on `bot`'s
// subgame, not on this copy's `subgame_`. Until this bot rebuilds (when its
// depth budget runs out in StepWithPolicy), its AveragePolicy() queries go
// through that shared solver. Confirm this is intended.
// NOTE(review): dereferences `bot.subgame_`; assumes it is non-null.
SherlockBot::SherlockBot(SherlockBot const &bot)
    : subgame_factory_(bot.subgame_factory_),
      solver_factory_(bot.solver_factory_),
      player_id_(bot.player_id_),
      rnd_gen_(bot.rnd_gen_),
      subgame_(std::make_shared<Subgame>(*bot.subgame_)),
      solver_(bot.solver_),
      sbr_(bot.sbr_),
      srnr_(bot.srnr_),
      fixed_policy_(bot.fixed_policy_),
      current_depth_limit_(bot.current_depth_limit_) {
  // Sanity check: the copy must not alias the original's subgame.
  SPIEL_CHECK_NE(subgame_.get(), bot.subgame_.get());
}

// Chooses an action for `state` by delegating to StepWithPolicy() and
// discarding the policy part of its result. Yields -1 when `player_id_` is
// not acting in `state`.
Action SherlockBot::Step(const State &state) {
  const std::pair<ActionsAndProbs, Action> policy_and_action =
      StepWithPolicy(state);
  return policy_and_action.second;
}

// Re-seeds `rnd_gen_`. StepWithPolicy() passes this generator to
// MakeParticleSetPartition and also uses it to sample actions. The current
// subgame and solver are left untouched.
void SherlockBot::SetSeed(int seed) {
  auto &generator = rnd_gen_;
  generator.seed(seed);
}

// Resets the search state to what the constructor sets up: a fresh depth-one
// trunk from the subgame factory and a depth budget of one. The next
// StepWithPolicy() call therefore rebuilds and re-solves a subgame.
// `solver_` and `rnd_gen_` are not reset here.
void SherlockBot::Restart() {
  constexpr int kTrunkDepth = 1;
  subgame_ = subgame_factory_->MakeTrunk(kTrunkDepth);
  current_depth_limit_ = kTrunkDepth;
}

std::pair<ActionsAndProbs, Action> SherlockBot::StepWithPolicy(const State &state) {
  if (sbr_) {
    SPIEL_CHECK_TRUE(fixed_policy_);
  }
  SPIEL_CHECK_TRUE(subgame_);
  // We are doing step so we decrement depth limit and if it is zero we create new subgame.
  current_depth_limit_--;
  if (current_depth_limit_ == 0) {
    // Here we go through public states to check if the current public state is constructed.
    Observation public_observation(*subgame_factory_->game, subgame_factory_->public_observer);
    public_observation.SetFrom(state, 0);
    PublicState *publicState = nullptr;
    for (PublicState &pubState : subgame_->public_states) {
      if (pubState.public_tensor == public_observation) {
        publicState = &pubState;
        break;
      }
    }
    // There is a public state in the tree, we are at the end (or in root) and we create particles for the public state
    SPIEL_CHECK_TRUE(publicState);
    std::unique_ptr<ParticleSetPartition>
        partition = MakeParticleSetPartition(*publicState, pow(10, 7), pow(10, -9), false, rnd_gen_);
    std::unique_ptr<ParticleSet> set = std::make_unique<ParticleSet>(partition->primary);
    SPIEL_CHECK_FALSE(set->particles.empty());
    // We create new subgame using gadget if necessary.
    if (state.MoveNumber() > 0 and !sbr_) {
      if (state.IsPlayerActing(player_id_)) {
        subgame_ = subgame_factory_->MakeSubgameSafeResolving(*set,
                                                              player_id_,
                                                              publicState->GetCFVs(1 - player_id_),
                                                              subgame_factory_->max_move_ahead_limit + 1);
      } else {
        subgame_ = subgame_factory_->MakeSubgameSafeResolving(*set, player_id_, publicState->GetCFVs(1 - player_id_));
      }
    } else {
      if (state.IsPlayerActing(player_id_)) {
        subgame_ = subgame_factory_->MakeSubgame(*set, subgame_factory_->max_move_ahead_limit + 1);
      } else {
        subgame_ = subgame_factory_->MakeSubgame(*set);
      }
    }
    current_depth_limit_ = subgame_factory_->max_move_ahead_limit + (state.IsPlayerActing(player_id_) ? 1 : 0);
    // This creates new solver for the newly created subgame.
    solver_ = solver_factory_->MakeSolver(subgame_, nullptr, "", true);
    // When we are doing sbr or srnr we fix the opponent strategy.
    if (sbr_ or srnr_) {
      int opponent = 1 - player_id_;
      algorithms::BanditVector &opponent_bandits = solver_->bandits()[opponent];
      for (algorithms::DecisionId id : opponent_bandits.range()) {
        algorithms::InfostateNode *node = subgame_->trees[opponent]->decision_infostate(id);
        std::string infostate = node->infostate_string();
        if (srnr_ and infostate.rfind("free", 0) == 0) {
          continue;
        }
        if (srnr_) {
          SPIEL_CHECK_TRUE(infostate.rfind("fixed", 0) == 0);
          infostate.erase(0, 6);
        }
        ActionsAndProbs infostate_policy = fixed_policy_->GetStatePolicy(infostate);
        std::vector<double> probs = GetProbs(infostate_policy);
        auto fixable_bandit = std::make_unique<algorithms::bandits::FixableStrategy>(probs);
        opponent_bandits[id] = std::move(fixable_bandit);
      }
    }
    // Here we run cfr-d iterations.
    solver_->RunSimultaneousIterations(solver_factory_->cfr_iterations);
    solver_->SetAverageBeliefsInLeaves();
  }


  // TODO: proper management of beliefs between steps. This is just
  //       a dummy initialization. (Not needed when I initialize from public state.)

  // When it is players turn we return action and strategy otherwise we return empty strategy and -1 action.
  if (state.IsPlayerActing(player_id_)) {
    // Here we create infostate from observation
    Observation infostate_observation(*subgame_factory_->game, subgame_factory_->infostate_observer);
    infostate_observation.SetFrom(state, player_id_);
    const std::string infostate =
        subgame_factory_->infostate_observer->StringFrom(state, player_id_);

    // Here we create policy from infostate trees and subgame solver.
    auto policy = solver_->AveragePolicy();
    ActionsAndProbs actions_and_probs = policy->GetStatePolicy(infostate);
    SPIEL_CHECK_FALSE(actions_and_probs.empty());

    // Here we convert it to the output format.
    double p = std::uniform_real_distribution<>(0., 1.)(rnd_gen_);
    std::pair<Action, double> outcome = SampleAction(actions_and_probs, p);
    return {actions_and_probs, outcome.first};
  } else {
    // And we return empty actions and probs and -1 action
    ActionsAndProbs actions_and_probs;
    return {actions_and_probs, Action(-1)};
  }
}

// Creates a SherlockBot for `player_id` seeded with `seed`. It is returned
// through the generic Bot interface.
// The remaining SherlockBot constructor parameters (response type and fixed
// policy) take their defaults, which are declared outside this file.
// NOTE(review): presumably response type "none"; confirm in bot.h.
std::unique_ptr<Bot> MakeSherlockBot(
    std::unique_ptr<SubgameFactory> subgame_factory,
    std::unique_ptr<SolverFactory> solver_factory,
    Player player_id, int seed) {
  std::unique_ptr<Bot> bot = std::make_unique<SherlockBot>(
      std::move(subgame_factory), std::move(solver_factory), player_id, seed);
  return bot;
}

// Creates a SherlockBot for `player_id` seeded with `seed` and returns it as
// the concrete SherlockBot type. The response type and fixed policy take the
// constructor's defaults, which are declared outside this file.
// NOTE(review): `sherlock_type` is accepted but currently ignored.
std::unique_ptr<SherlockBot> MakeSherlockBot(
    std::unique_ptr<SubgameFactory> subgame_factory,
    std::unique_ptr<SolverFactory> solver_factory,
    Player player_id, int seed, bool sherlock_type) {
  auto bot = std::make_unique<SherlockBot>(std::move(subgame_factory),
                                           std::move(solver_factory),
                                           player_id, seed);
  return bot;
}

std::unique_ptr<Bot> MakeSherlockBot(
    std::unique_ptr<SubgameFactory> subgame_factory,
    std::unique_ptr<SolverFactory> solver_factory,
    Player player_id, int seed, std::string response_type, const TabularPolicy &fixed_policy) {
  return std::make_unique<SherlockBot>(std::move(subgame_factory),
                                       std::move(solver_factory),
                                       player_id, seed, response_type, fixed_policy);
}

std::unique_ptr<SherlockBot> MakeSherlockBot(
    std::unique_ptr<SubgameFactory> subgame_factory,
    std::unique_ptr<SolverFactory> solver_factory,
    Player player_id, int seed, bool sherlock_type, std::string response_type, const TabularPolicy &fixed_policy) {
  return std::make_unique<SherlockBot>(std::move(subgame_factory),
                                       std::move(solver_factory),
                                       player_id, seed, response_type, fixed_policy);
}

}  // namespace papers_with_code
}  // namespace open_spiel

