// Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "open_spiel/algorithms/expected_returns.h"

#include <algorithm>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include "open_spiel/simultaneous_move_game.h"
#include "open_spiel/spiel.h"

namespace open_spiel {
namespace algorithms {
namespace {

// Implements the recursive traversal using a general way to access the
// players' policies: a function that takes the player id and the information
// state as arguments.
// We special-case policies that can be looked up from the
// InformationStateString alone, as that gives us a 2x speedup.
// NOTE: these implementations are currently commented out.
//std::vector<double> ExpectedReturnsImpl(
//    const State &state,
//    const std::function<ActionsAndProbs(Player, const std::string &)> &
//    policy_func,
//    int depth_limit) {
//  if (state.IsTerminal() || depth_limit == 0) {
//    return state.Rewards();
//  }
//
//  int num_players = state.NumPlayers();
//  std::vector<double> values(num_players, 0.0);
//  if (state.IsChanceNode()) {
//    ActionsAndProbs action_and_probs = state.ChanceOutcomes();
//    for (const auto &action_and_prob : action_and_probs) {
//      std::unique_ptr<State> child = state.Child(action_and_prob.first);
//      std::vector<double> child_values =
//          ExpectedReturnsImpl(*child, policy_func, depth_limit - 1);
//      for (auto p = Player{0}; p < num_players; ++p) {
//        values[p] += action_and_prob.second * child_values[p];
//      }
//    }
//  } else if (state.IsSimultaneousNode()) {
//    // Walk over all the joint actions, and weight by the product of
//    // probabilities to choose them.
//    values = state.Rewards();
//    auto smstate = dynamic_cast<const SimMoveState *>(&state);
//    SPIEL_CHECK_TRUE(smstate != nullptr);
//    std::vector<ActionsAndProbs> state_policies(num_players);
//    for (auto p = Player{0}; p < num_players; ++p) {
//      state_policies[p] = policy_func(p, state.InformationStateString(p));
//      if (state_policies[p].empty()) {
//        SpielFatalError("Error in ExpectedReturnsImpl; infostate not found.");
//      }
//    }
//    for (const Action flat_action : smstate->LegalActions()) {
//      std::vector<Action> actions =
//          smstate->FlatJointActionToActions(flat_action);
//      double joint_action_prob = 1.0;
//      for (auto p = Player{0}; p < num_players; ++p) {
//        double player_action_prob = GetProb(state_policies[p], actions[p]);
//        SPIEL_CHECK_GE(player_action_prob, 0.0);
//        SPIEL_CHECK_LE(player_action_prob, 1.0);
//        joint_action_prob *= player_action_prob;
//        if (player_action_prob == 0.0) {
//          break;
//        }
//      }
//
//      if (joint_action_prob > 0.0) {
//        std::unique_ptr<State> child = state.Clone();
//        child->ApplyActions(actions);
//        std::vector<double> child_values =
//            ExpectedReturnsImpl(*child, policy_func, depth_limit - 1);
//        for (auto p = Player{0}; p < num_players; ++p) {
//          values[p] += joint_action_prob * child_values[p];
//        }
//      }
//    }
//  } else {
//    // Turn-based decision node.
//    Player player = state.CurrentPlayer();
//    ActionsAndProbs state_policy =
//        policy_func(player, state.InformationStateString());
//    if (state_policy.empty()) {
//      SpielFatalError("Error in ExpectedReturnsImpl; infostate not found.");
//    }
//    values = state.Rewards();
//    for (const Action action : state.LegalActions()) {
//      std::unique_ptr<State> child = state.Child(action);
//      double action_prob = GetProb(state_policy, action);
//      SPIEL_CHECK_GE(action_prob, 0.0);
//      SPIEL_CHECK_LE(action_prob, 1.0);
//      if (action_prob > 0.0) {
//        std::vector<double> child_values =
//            ExpectedReturnsImpl(*child, policy_func, depth_limit - 1);
//        for (auto p = Player{0}; p < num_players; ++p) {
//          values[p] += action_prob * child_values[p];
//        }
//      }
//    }
//  }
//  SPIEL_CHECK_EQ(values.size(), state.NumPlayers());
//  return values;
//}
//
//// Same as above, but the policy_func now takes a State as input rather
//// than a string.
//std::vector<double> ExpectedReturnsImpl(
//    const State &state,
//    const std::function<ActionsAndProbs(Player, const State &)> &policy_func,
//    int depth_limit) {
//  if (state.IsTerminal() || depth_limit == 0) {
//    return state.Rewards();
//  }
//
//  int num_players = state.NumPlayers();
//  std::vector<double> values(num_players, 0.0);
//  if (state.IsChanceNode()) {
//    ActionsAndProbs action_and_probs = state.ChanceOutcomes();
//    for (const auto &action_and_prob : action_and_probs) {
//      std::unique_ptr<State> child = state.Child(action_and_prob.first);
//      std::vector<double> child_values =
//          ExpectedReturnsImpl(*child, policy_func, depth_limit - 1);
//      for (auto p = Player{0}; p < num_players; ++p) {
//        values[p] += action_and_prob.second * child_values[p];
//      }
//    }
//  } else if (state.IsSimultaneousNode()) {
//    // Walk over all the joint actions, and weight by the product of
//    // probabilities to choose them.
//    values = state.Rewards();
//    auto smstate = dynamic_cast<const SimMoveState *>(&state);
//    SPIEL_CHECK_TRUE(smstate != nullptr);
//    std::vector<ActionsAndProbs> state_policies(num_players);
//    for (auto p = Player{0}; p < num_players; ++p) {
//      state_policies[p] = policy_func(p, state);
//      if (state_policies[p].empty()) {
//        SpielFatalError("Error in ExpectedReturnsImpl; infostate not found.");
//      }
//    }
//    for (const Action flat_action : smstate->LegalActions()) {
//      std::vector<Action> actions =
//          smstate->FlatJointActionToActions(flat_action);
//      double joint_action_prob = 1.0;
//      for (auto p = Player{0}; p < num_players; ++p) {
//        double player_action_prob = GetProb(state_policies[p], actions[p]);
//        SPIEL_CHECK_GE(player_action_prob, 0.0);
//        SPIEL_CHECK_LE(player_action_prob, 1.0);
//        joint_action_prob *= player_action_prob;
//        if (player_action_prob == 0.0) {
//          break;
//        }
//      }
//
//      if (joint_action_prob > 0.0) {
//        std::unique_ptr<State> child = state.Clone();
//        child->ApplyActions(actions);
//        std::vector<double> child_values =
//            ExpectedReturnsImpl(*child, policy_func, depth_limit - 1);
//        for (auto p = Player{0}; p < num_players; ++p) {
//          values[p] += joint_action_prob * child_values[p];
//        }
//      }
//    }
//  } else {
//    // Turn-based decision node.
//    Player player = state.CurrentPlayer();
//    ActionsAndProbs state_policy = policy_func(player, state);
//    if (state_policy.empty()) {
//      SpielFatalError("Error in ExpectedReturnsImpl; infostate not found.");
//    }
//    values = state.Rewards();
//    for (const Action action : state.LegalActions()) {
//      std::unique_ptr<State> child = state.Child(action);
//      double action_prob = GetProb(state_policy, action);
//      SPIEL_CHECK_GE(action_prob, 0.0);
//      SPIEL_CHECK_LE(action_prob, 1.0);
//      if (action_prob > 0.0) {
//        std::vector<double> child_values =
//            ExpectedReturnsImpl(*child, policy_func, depth_limit - 1);
//        for (auto p = Player{0}; p < num_players; ++p) {
//          values[p] += action_prob * child_values[p];
//        }
//      }
//    }
//  }
//  SPIEL_CHECK_EQ(values.size(), state.NumPlayers());
//  return values;
//}
}  // namespace
//
//std::vector<double> ExpectedReturns(const State &state,
//                                    const std::vector<const Policy *> &policies,
//                                    int depth_limit,
//                                    bool use_infostate_get_policy) {
//  if (use_infostate_get_policy) {
//    return ExpectedReturnsImpl(
//        state,
//        [&policies](Player player, const std::string &info_state) {
//          return policies[player]->GetStatePolicy(info_state);
//        },
//        depth_limit);
//  } else {
//    return ExpectedReturnsImpl(
//        state,
//        [&policies](Player player, const State &state) {
//          return policies[player]->GetStatePolicy(state, player);
//        },
//        depth_limit);
//  }
//}
//
//std::vector<double> ExpectedReturns(const State &state,
//                                    const Policy &joint_policy, int depth_limit,
//                                    bool use_infostate_get_policy) {
//  if (use_infostate_get_policy) {
//    return ExpectedReturnsImpl(
//        state,
//        [&joint_policy](Player player, const std::string &info_state) {
//          return joint_policy.GetStatePolicy(info_state);
//        },
//        depth_limit);
//  } else {
//    return ExpectedReturnsImpl(
//        state,
//        [&joint_policy](Player player, const State &state) {
//          return joint_policy.GetStatePolicy(state, player);
//        },
//        depth_limit);
//  }
//}

// Combines per-hand Monte Carlo win counts into overall win probabilities.
//
// `win_probs[i]` holds the counts for the i-th candidate hand laid out as
// {player 0 wins, player 1 wins, simulations} (the layout produced by
// VectorWinProbability), and `range[i]` is the belief weight of that hand.
// Returns {P(player 0 wins), P(player 1 wins)}: the range-weighted average of
// the per-hand win frequencies. Hands with no recorded simulations contribute
// nothing (instead of producing NaN), and only indices present in both inputs
// are read.
std::vector<double> CreateWinProbabilityFromVectors(std::vector<std::vector<double>> win_probs,
                                                    std::vector<double> range) {
  std::vector<double> wp = {0., 0.};
  const std::size_t hands = std::min(range.size(), win_probs.size());
  for (std::size_t i = 0; i < hands; ++i) {
    const std::vector<double> &counts = win_probs[i];
    const double simulations = counts[2];
    if (simulations == 0.) {
      continue;  // No rollouts for this hand: its win frequency is undefined.
    }
    for (int player = 0; player < 2; ++player) {
      wp[player] += counts[player] / simulations * range[i];
    }
  }
  return wp;
}

void VectorWinProbabilityStep(const std::vector<std::unique_ptr<State>> &state_vector,
                              std::vector<std::vector<double>> &wins,
                              std::mt19937 &rng) {
  // In terminal state we check who won and increment win counter.
  if (state_vector[0]->IsTerminal()) {
    for (int i = 0; i < wins.size(); i++) {
      wins[i][2]++;
      if (state_vector[i]->Rewards()[0] > 0) {
        wins[i][0]++;
      } else if (state_vector[i]->Rewards()[1] > 0) {
        wins[i][1]++;
      }
    }
    return;
  }
  if (state_vector[0]->IsPlayerNode()) {
    for (const std::unique_ptr<State> &state : state_vector) {
      state->ApplyAction(1);
    }
    VectorWinProbabilityStep(state_vector, wins, rng);
  } else if (state_vector[0]->IsChanceNode()) {
    for (const std::unique_ptr<State> &state : state_vector) {
      std::vector<std::pair<Action, double>> outcomes = state->ChanceOutcomes();
      Action action_two =
          open_spiel::SampleAction(outcomes, std::uniform_real_distribution<double>(0.0, 1.0)(rng)).first;
      state->ApplyAction(action_two);
    }
    VectorWinProbabilityStep(state_vector, wins, rng);
  } else {
    SpielFatalError("Weird node type in poker");
  }
}

std::vector<std::vector<double>> VectorWinProbability(std::vector<std::unique_ptr<State>> &state_vector,
                                                      int simulations,
                                                      int seed) {
  std::mt19937 rng(seed);

  int num_states = state_vector.size();

  std::vector<std::vector<double>> wins(num_states, std::vector<double>(3, 0.));

  for (int i = 0; i < simulations; i++) {
    std::vector<std::unique_ptr<State>> new_state_vector;
    new_state_vector.reserve(state_vector.size());
    for (auto &state : state_vector) {
      new_state_vector.push_back(state->Clone());
    }
    VectorWinProbabilityStep(new_state_vector, wins, rng);
  }

  return wins;
}

void WinProbabilityStep(const std::unique_ptr<State> &state,
                        const std::shared_ptr<std::vector<double>> &wins,
                        std::mt19937 &rng) {
  // In terminal state we check who won and increment win counter.
  if (state->IsTerminal()) {
    (*wins)[2]++;
    if (state->Rewards()[0] > 0) {
      (*wins)[0]++;
    } else if (state->Rewards()[1] > 0) {
      (*wins)[1]++;
    }
    return;
  }
  if (state->IsPlayerNode()) {
    state->ApplyAction(1);
    WinProbabilityStep(state, wins, rng);
  } else if (state->IsChanceNode()) {
    std::vector<std::pair<Action, double>> outcomes = state->ChanceOutcomes();
    Action action_two =
        open_spiel::SampleAction(outcomes, std::uniform_real_distribution<double>(0.0, 1.0)(rng)).first;
    state->ApplyAction(action_two);
    WinProbabilityStep(state, wins, rng);
  } else {
    SpielFatalError("Weird node type in poker");
  }
}

// Estimates win probabilities from `state` by Monte Carlo simulation.
//
// Each of the `simulations` rollouts clones `state` and plays the clone out
// with WinProbabilityStep (decision nodes take action 1, chance outcomes are
// sampled from an RNG seeded with `seed`). Returns {fraction of rollouts won
// by player 0, fraction won by player 1}; rollouts where neither reward is
// positive count toward neither. Returns {0, 0} when no rollouts ran instead
// of dividing by zero.
std::pair<double, double> WinProbability(std::unique_ptr<State> state, int simulations, int seed) {
  std::mt19937 rng(seed);

  // Counters: {player 0 wins, player 1 wins, completed rollouts}.
  std::shared_ptr<std::vector<double>> wins = std::make_shared<std::vector<double>>(3, 0.);

  for (int i = 0; i < simulations; ++i) {
    WinProbabilityStep(state->Clone(), wins, rng);
  }

  const double rollouts = (*wins)[2];
  if (rollouts == 0.) {
    return {0., 0.};
  }
  return {(*wins)[0] / rollouts, (*wins)[1] / rollouts};
}

// Work-in-progress local best response for a poker game in which each player
// is dealt two private chance outcomes.
//
// NOTE(review): this is an unfinished prototype. As written it:
//   * only supports `player == 0`; any other player is a fatal error;
//   * samples two chance outcomes (seeded by `seed`) as `player`'s hand, then
//     enumerates every pair of follow-up chance outcomes into `state_vector`
//     (presumably the opponent's possible hands -- confirm the deal order);
//   * estimates win counts with a hard-coded 1000 simulations and seed 0;
//   * never reads `policy`, `my_hand` or the computed `wp`;
//   * applies action 2 to state_vector[0] and prints it to stdout;
//   * always returns 0.
double LocalBestResponseUniversal(const std::shared_ptr<const Game> &game,
                                  const TabularPolicy &policy,
                                  Player player,
                                  int seed) {
  std::mt19937 rng(seed);
  std::string my_hand;
  std::unique_ptr<State> state = game->NewInitialState();
  std::vector<std::unique_ptr<State>> state_vector;
  std::vector<double> range;
  // Deal two cards to create my hand.
  if (player == 0) {
    std::vector<std::pair<Action, double>> outcomes = state->ChanceOutcomes();
    Action action_one =
        open_spiel::SampleAction(outcomes, std::uniform_real_distribution<double>(0.0, 1.0)(rng)).first;
    state->ApplyAction(action_one);

    outcomes = state->ChanceOutcomes();
    Action action_two =
        open_spiel::SampleAction(outcomes, std::uniform_real_distribution<double>(0.0, 1.0)(rng)).first;
    state->ApplyAction(action_two);
    // Save my hand as "<lower card> <higher card>" (currently unused).
    if (action_one < action_two) {
      my_hand = std::to_string(action_one) + " " + std::to_string(action_two);
    } else {
      my_hand = std::to_string(action_two) + " " + std::to_string(action_one);
    }
    // Branch over every pair of follow-up chance outcomes.
    for (Action new_action_one : state->LegalActions()) {
      std::unique_ptr<State> intermediate_child = state->Child(new_action_one);
      for (Action new_action_two : intermediate_child->LegalActions()) {
        state_vector.push_back(intermediate_child->Child(new_action_two));
      }
    }
  } else {
    SpielFatalError("Not implemented yet");
  }
  // Initialize a uniform range over the enumerated states.
  int hands = state_vector.size();
  range.resize(hands);
  std::fill(range.begin(), range.end(), 1. / hands);

  // NOTE(review): simulation count and seed are hard-coded here.
  std::vector<std::vector<double>> win_probs = VectorWinProbability(state_vector, 1000, 0);

  // NOTE(review): `wp` is computed but not used yet.
  std::vector<double> wp = CreateWinProbabilityFromVectors(win_probs, range);

  // NOTE(review): debugging scaffolding -- advances only the first state and
  // prints it.
  state_vector[0]->ApplyAction(2);

  std::cout << state_vector[0]->ToString() << "\n";

  return 0;
}

// Rescales `to_normalize` in place so that its entries sum to one.
//
// If the entries sum to zero (an empty vector, or every weight eliminated) the
// vector is left unchanged rather than being filled with NaN/inf by a division
// by zero.
void NormalizeVector(std::vector<double> &to_normalize) {
  const double sum = std::accumulate(to_normalize.begin(), to_normalize.end(), 0.0);
  if (sum == 0.0) {
    return;
  }
  for (double &part : to_normalize) {
    part /= sum;
  }
}

// Plays a single hand of Leduc poker as `player` against the fixed opponent
// `policy`. Each of `player`'s actions is chosen with a one-step-lookahead
// heuristic. Returns `player`'s terminal reward averaged over the final belief
// ("range") about the opponent's private card.
//
// The hand is simulated as follows:
//   * `player`'s own card and, later, the public card are sampled with an RNG
//     seeded by `seed`. Opponent actions are sampled from `policy`, weighted
//     by the current range.
//   * `state_vector` holds one State per card the opponent may hold, all
//     advanced in lockstep, and `range[i]` is the belief weight of
//     `state_vector[i]`.
//   * Win probabilities per opponent card come from VectorWinProbability.
//     They are recomputed once the public card is dealt.
//   * At `player`'s decisions, call and raise utilities are estimated from the
//     pot, the amount to call and the win probabilities against the current
//     range. An option with non-positive utility turns into a fold, when
//     folding is legal.
//
// NOTE(review): assumes OpenSpiel Leduc conventions: action ids 0 = fold,
// 1 = call/check and 2 = raise; an ante of 1 per player; a raise size of
// 2 * round; and card-dealing chance actions equal to card ids in increasing
// order, so a card id maps to its index in `state_vector`. Confirm these if
// the game parameters change.
double LocalBestResponseLeduc(const std::shared_ptr<const Game> &game,
                              const TabularPolicy &policy,
                              Player player,
                              int seed) {
  constexpr Action kFold = 0;
  constexpr Action kCall = 1;
  constexpr Action kRaise = 2;
  constexpr int kNumSimulations = 1000;

  std::mt19937 rng(seed);
  auto uniform_draw = [&rng]() {
    return std::uniform_real_distribution<double>(0.0, 1.0)(rng);
  };

  // Rescales weights to sum to one. If every weight is zero (all belief mass
  // was eliminated), it falls back to a uniform vector instead of producing NaNs.
  auto normalize_range = [](std::vector<double> &weights) {
    double total = 0.;
    for (double w : weights) {
      total += w;
    }
    if (total > 0.) {
      for (double &w : weights) {
        w /= total;
      }
    } else if (!weights.empty()) {
      std::fill(weights.begin(), weights.end(), 1. / weights.size());
    }
  };

  int my_hand = 0;
  std::unique_ptr<State> state = game->NewInitialState();
  std::vector<std::unique_ptr<State>> state_vector;
  // Chips committed by each player so far; both start with an ante of 1.
  std::vector<int> pot_contributions = {1, 1};
  // Betting round: 1 before the public card, 2 after it.
  int round = 1;

  // Updates pot_contributions after `actor` takes `action`. A call matches the
  // other player's commitment; a raise matches it and adds 2 * round.
  auto record_action = [&](Player actor, Action action) {
    if (action == kCall) {
      pot_contributions[actor] = pot_contributions[1 - actor];
    } else if (action == kRaise) {
      pot_contributions[actor] = pot_contributions[1 - actor] + 2 * round;
    }
  };

  if (player == 0) {
    // Player 0's card is dealt first: sample it as my hand.
    std::vector<std::pair<Action, double>> outcomes = state->ChanceOutcomes();
    Action action = open_spiel::SampleAction(outcomes, uniform_draw()).first;
    state->ApplyAction(action);
    my_hand = action;
    // Branch over every card the opponent can be dealt next.
    for (Action new_action : state->LegalActions()) {
      state_vector.push_back(state->Child(new_action));
    }
  } else {
    // The opponent is dealt first: branch over every card they could receive.
    for (Action new_action : state->LegalActions()) {
      state_vector.push_back(state->Child(new_action));
    }
    // Sample my card. `state` is advanced with it so that, like in the branch
    // above, it holds exactly one dealt card: mine.
    std::vector<std::pair<Action, double>> outcomes = state->ChanceOutcomes();
    Action action = open_spiel::SampleAction(outcomes, uniform_draw()).first;
    my_hand = action;
    state->ApplyAction(action);
    // The opponent cannot hold my card; deal my card in every other branch.
    state_vector.erase(state_vector.begin() + my_hand);
    for (const std::unique_ptr<State> &iter_state : state_vector) {
      iter_state->ApplyAction(action);
    }
  }

  // Start from a uniform belief over the opponent's possible cards.
  const int hands = static_cast<int>(state_vector.size());
  std::vector<double> range(hands, 1. / hands);
  std::vector<std::vector<double>> win_probs = VectorWinProbability(state_vector, kNumSimulations, 0);

  while (!state_vector[0]->IsTerminal()) {
    if (state_vector[0]->IsPlayerActing(player)) {
      const std::vector<Action> legal_actions = state_vector[0]->LegalActions();
      auto is_legal = [&legal_actions](Action candidate) {
        return std::find(legal_actions.begin(), legal_actions.end(), candidate) != legal_actions.end();
      };
      const int asked = pot_contributions[1 - player] - pot_contributions[player];
      const int pot = pot_contributions[0] + pot_contributions[1];
      const int raise_size = 2 * round;

      // Win probabilities against the opponent's *current* range.
      const std::vector<double> wp = CreateWinProbabilityFromVectors(win_probs, range);
      const double call_utility = wp[player] * pot - wp[1 - player] * asked;
      double bet_utility = -1000;
      if (is_legal(kRaise)) {
        // Probability that the opponent folds to a raise, and the range that
        // remains when they continue.
        double fold_probability = 0;
        std::vector<double> new_range = range;
        for (std::size_t i = 0; i < state_vector.size(); ++i) {
          std::unique_ptr<State> child = state_vector[i]->Child(kRaise);
          const ActionsAndProbs state_policy = policy.GetStatePolicy(child->InformationStateString());
          // Reset for every hand: without a fold entry this hand never folds.
          double fp = 0;
          for (const auto &action_and_prob : state_policy) {
            if (action_and_prob.first == kFold) {
              fp = action_and_prob.second;
              break;
            }
          }
          fold_probability += range[i] * fp;
          new_range[i] *= (1 - fp);
        }
        // If the opponent always folds, the continuation term below is
        // weighted by ~0, so the uniform fallback is harmless.
        normalize_range(new_range);
        const std::vector<double> bet_wp = CreateWinProbabilityFromVectors(win_probs, new_range);
        bet_utility = fold_probability * pot
            + (1 - fold_probability)
                * (bet_wp[player] * (pot + raise_size) - bet_wp[1 - player] * (asked + raise_size));
      }

      Action action;
      if (bet_utility > call_utility) {
        action = bet_utility > 0 ? kRaise : kFold;
      } else {
        action = call_utility > 0 ? kCall : kFold;
      }
      // If the chosen action is not legal here (presumably a fold when there
      // is no bet to call), check/call instead.
      if (!is_legal(action)) {
        action = kCall;
      }

      for (const std::unique_ptr<State> &iter_state : state_vector) {
        iter_state->ApplyAction(action);
      }
      record_action(player, action);
    } else if (state_vector[0]->IsPlayerActing(1 - player)) {
      const std::vector<Action> legal_actions = state_vector[0]->LegalActions();
      const std::size_t num_actions = legal_actions.size();

      // Opponent action distribution marginalized over the range (`probs`), and
      // each hand's own conditional action probabilities (`divided_probs`).
      ActionsAndProbs probs;
      probs.reserve(num_actions);
      for (Action legal_action : legal_actions) {
        probs.emplace_back(legal_action, 0.);
      }
      std::vector<std::vector<double>> divided_probs(state_vector.size(), std::vector<double>(num_actions, 0.));
      for (std::size_t s = 0; s < state_vector.size(); ++s) {
        const ActionsAndProbs state_policy = policy.GetStatePolicy(state_vector[s]->InformationStateString());
        for (std::size_t a = 0; a < num_actions; ++a) {
          for (const auto &action_and_prob : state_policy) {
            if (action_and_prob.first == legal_actions[a]) {
              // Weight by the belief in hand `s`, not by the action index.
              probs[a].second += range[s] * action_and_prob.second;
              divided_probs[s][a] = action_and_prob.second;
              break;
            }
          }
        }
      }

      const Action action = open_spiel::SampleAction(probs, uniform_draw()).first;
      const std::size_t action_index =
          std::find(legal_actions.begin(), legal_actions.end(), action) - legal_actions.begin();
      // Bayesian update: reweight each hand by how likely it was to take `action`.
      for (std::size_t s = 0; s < state_vector.size(); ++s) {
        state_vector[s]->ApplyAction(action);
        range[s] *= divided_probs[s][action_index];
      }
      record_action(1 - player, action);
      normalize_range(range);
    } else {
      SPIEL_CHECK_TRUE(state_vector[0]->IsChanceNode());
      // Deal the public card. `state` holds only my card, so (assuming the
      // Leduc deck) its chance outcomes are the cards other than mine.
      std::vector<std::pair<Action, double>> outcomes = state->ChanceOutcomes();
      const Action action = open_spiel::SampleAction(outcomes, uniform_draw()).first;
      // state_vector and range skip my card, so cards above it shift down by one.
      const int index = action < my_hand ? static_cast<int>(action) : static_cast<int>(action) - 1;
      state_vector.erase(state_vector.begin() + index);
      range.erase(range.begin() + index);
      for (const std::unique_ptr<State> &iter_state : state_vector) {
        iter_state->ApplyAction(action);
      }
      normalize_range(range);
      ++round;
      // Showdown odds change once the public card is known, and the previous
      // estimates no longer line up with the shrunken state_vector/range.
      win_probs = VectorWinProbability(state_vector, kNumSimulations, 0);
    }
  }

  // Expected reward under the final belief over the opponent's card.
  double reward = 0.;
  for (std::size_t s = 0; s < state_vector.size(); ++s) {
    reward += state_vector[s]->Rewards()[player] * range[s];
  }
  return reward;
}
}  // namespace algorithms
}  // namespace open_spiel
