// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "minispiel/games/nannon.h"

#include <algorithm>
#include <cstdlib>
#include <set>
#include <utility>
#include <vector>

#include "minispiel/abseil-cpp/absl/strings/str_cat.h"
#include "minispiel/game_parameters.h"
#include "minispiel/spiel.h"
#include "minispiel/spiel_utils.h"

namespace open_spiel::nannon {
namespace {

//const std::vector<std::pair<Action, double>> kChanceOutcomes = {
//    std::pair<Action, double>(0, 1.0 / 6),
//    std::pair<Action, double>(1, 1.0 / 6),
//    std::pair<Action, double>(2, 1.0 / 6),
//    std::pair<Action, double>(3, 1.0 / 6),
//    std::pair<Action, double>(4, 1.0 / 6),
//    std::pair<Action, double>(5, 1.0 / 6),
//};
//
//static const std::vector<int> kChanceOutcomeValues = {1, 2, 3, 4, 5, 6};


// Facts about the game
// Facts about the game: a two-player, zero-sum, perfect-information
// sequential game with explicit (enumerated) chance nodes for die rolls and
// rewards only at terminal states.
//
// Parameters (read back in the NannonGame constructor):
//   n_points: number of points on the board.
//   n_chex:   checkers per player (stored as n_chex_per_player_).
//   n_die:    number of faces on the die; faces are valued 1..n_die
//             (see GetChanceOutcomeValues).
static const GameType kGameType{
    /*short_name=*/"nannon",
    /*long_name=*/"Nannon",
    GameType::Dynamics::kSequential,
    GameType::ChanceMode::kExplicitStochastic,
    GameType::Information::kPerfectInformation,
    GameType::Utility::kZeroSum,
    GameType::RewardModel::kTerminal,
    /*min_num_players=*/2,
    /*max_num_players=*/2,
    /*provides_information_state_string=*/false,
    /*provides_information_state_tensor=*/false,
    /*provides_observation_string=*/true,
    /*provides_observation_tensor=*/true,
    /*parameter_specification=*/
    {
        {"n_points",
         GameParameter(static_cast<int>(kDefaultNumPoints))},
        {"n_chex",
         GameParameter(static_cast<int>(kDefaultNumCheckersPerPlayer))},
        {"n_die",
         GameParameter(static_cast<int>(kDefaultNumChanceOutcomes))}}};

std::shared_ptr<const Game> Factory(const GameParameters& params) {
  return std::shared_ptr<const Game>(new NannonGame(params));
}

REGISTER_SPIEL_GAME(kGameType, Factory);
}  // namespace

// Constructs the opening position: the chance player is to act, no dice have
// been rolled, and turns_ == -1 marks that the opening toss is still pending.
// InitalizeBoard() then places one checker of each side at home and the
// remaining n_chex_per_player - 1 checkers on the board.
//
// The chance tables and derived sizes are presumably precomputed by
// NannonGame (see its members of the same names). The sink parameters are
// taken by value and moved into place rather than copied a second time.
NannonState::NannonState(std::shared_ptr<const Game> game,
        int n_points,
        int n_chex_per_player,
        int n_die,
        std::vector<std::pair<Action, double>> chance_outcomes,
        std::vector<int> chance_outcome_values,
        int state_encoding_size,
        int home_pos,
        int score_pos) :
            State(std::move(game)),
            n_points_(n_points),
            n_chex_per_player_(n_chex_per_player),
            n_die_(n_die),
            chance_outcomes_(std::move(chance_outcomes)),
            chance_outcome_values_(std::move(chance_outcome_values)),
            state_encoding_size_(state_encoding_size),
            home_pos_(home_pos),
            score_pos_(score_pos),
            cur_player_(kChancePlayerId),
            prev_player_(kChancePlayerId),
            true_prev_player_(kChancePlayerId),
            turns_(-1),
            x_turns_(0),
            o_turns_(0),
            dice_({}),
            scores_({0, 0}),
            home_({0, 0}),
            board_({std::vector<int>(n_points, 0), std::vector<int>(n_points, 0)}),
            turn_history_info_({}) {

    InitalizeBoard();
}

void NannonState::InitalizeBoard() {
    // Each side starts with a single checker at home. The other
    // n_chex_per_player_ - 1 checkers occupy consecutive points: X fills from
    // point 0 upward, O fills from the last point downward.
    const int placed_on_board = n_chex_per_player_ - 1;
    home_[kXPlayerId] = 1;
    home_[kOPlayerId] = 1;
    for (int k = 0; k < placed_on_board; ++k) {
        board_[kXPlayerId][k] = 1;
        board_[kOPlayerId][n_points_ - 1 - k] = 1;
    }
}

std::string NannonState::ObservationString(Player player) const {
    // The game has perfect information, so every player observes the same
    // textual rendering of the full state; `player` does not affect it.
    return this->ToString();
}


// Applies a chance outcome: `chance_outcome` is an index into
// chance_outcome_values_ (the die-face table).
//
// While turns_ == -1 the game is in its opening toss:
//   * fewer than two dice shown: roll one more toss die;
//   * two dice shown: decide who starts. On a tie the dice are cleared and
//     `chance_outcome` becomes the first die of a fresh toss. Otherwise the
//     owner of the higher die moves first, using the difference of the two
//     dice as its roll, and turns_ becomes 0.
// Once turns_ >= 0, each chance node rolls one die for the opponent of the
// player who moved last (prev_player_).
//
// Every call first records the pre-action prev_player_ and dice_ in
// turn_history_info_ so that UndoAction can restore them.
void NannonState::ApplyChanceNodeAction(Action chance_outcome) {
    turn_history_info_.emplace_back(kChancePlayerId, prev_player_, dice_,
                                           chance_outcome, false);
    if (turns_ == -1 && dice_.size() < 2) {
        // Initial dice roll to determine who starts.
        RollDice(chance_outcome);
        return;
    } else if (turns_ == -1 && dice_.size() == 2) {
        // Start of game: see who won the toss (on a single dice).
        // NOTE(review): on a non-tie, `chance_outcome` is recorded in the
        // history but otherwise ignored; this node only resolves the toss.
        if (dice_[0] == dice_[1]) {
            // Tie. Start again!
            dice_.clear();
            RollDice(chance_outcome);
            return;
        }
        int winner_die, loser_die;
        // The dice_[0] vs dice_[1] will determine the starting player.
        if (dice_[0] > dice_[1]) {
            // X starts.
            cur_player_ = prev_player_ = kXPlayerId;
            winner_die = dice_[0];
            loser_die = dice_[1];
        } else {
            // O starts.
            cur_player_ = prev_player_ = kOPlayerId;
            winner_die = dice_[1];
            loser_die = dice_[0];
        }
        // The starting player's first move uses the difference of the toss.
        int winner_dieval = winner_die - loser_die;
        dice_.clear();
        dice_.push_back(winner_dieval);
        turns_ = 0;
        return;
    } else {
        // Normal chance node.
        SPIEL_CHECK_TRUE(dice_.empty());
        RollDice(chance_outcome);
        cur_player_ = Opponent(prev_player_);
        return;
    }
}

bool NannonState::IsTerminal() const {
    return (scores_[kXPlayerId] == n_chex_per_player_ ||
            scores_[kOPlayerId] == n_chex_per_player_);
}

Player NannonState::CurrentPlayer() const {
    // Terminal states report the terminal sentinel; otherwise the stored
    // player (which may be the chance player) is to act.
    if (IsTerminal()) {
        return kTerminalPlayerId;
    }
    return Player{cur_player_};
}

std::set<int> NannonState::GetBlockedPoints(std::vector<int> curr_board,
                                            std::vector<int> opp_board) const {
    // A point is closed to the mover if it already holds one of the mover's
    // checkers, or if it holds an opponent checker with another opponent
    // checker on an adjacent point (such a pair protects both points).
    std::set<int> blocked;
    for (int point = 0; point < n_points_; ++point) {
        const bool own_here = curr_board[point] != 0;
        const bool opp_here = opp_board[point] != 0;
        const bool opp_left = point > 0 && opp_board[point - 1] != 0;
        const bool opp_right = point + 1 < n_points_ && opp_board[point + 1] != 0;
        if (own_here || (opp_here && (opp_left || opp_right))) {
            blocked.insert(point);
        }
    }
    return blocked;
}

std::set<int> NannonState::GetHittablePoints(std::vector<int> board) const {
    // An occupied point can be hit only when neither neighbouring point holds
    // exactly one checker of the same side.
    std::set<int> hittable;
    for (int point = 0; point < n_points_; ++point) {
        if (board[point] == 0) {
            continue;
        }
        const bool guarded_left = point > 0 && board[point - 1] == 1;
        const bool guarded_right = point < n_points_ - 1 && board[point + 1] == 1;
        if (!(guarded_left || guarded_right)) {
            hittable.insert(point);
        }
    }
    return hittable;
}

// Returns every checker move available to `player` with the current roll
// (dice_[0]): entering from home, if a checker is there, and advancing each
// checker on the board. A move is legal when its destination is not blocked
// (see GetBlockedPoints). Its hit flag is set when the destination holds a
// hittable opponent checker.
std::set<CheckerMove> NannonState::LegalCheckerMoves(int player) const {
    std::set<CheckerMove> legal_moves;
    // Boards are taken for `player` (not CurrentPlayer(), which is
    // kTerminalPlayerId at terminal states) and bound by reference to avoid
    // extra copies.
    const std::vector<int>& own_board = board_[player];
    const std::vector<int>& opp_board = board_[Opponent(player)];
    const std::set<int> blocked_points = GetBlockedPoints(own_board, opp_board);
    const std::set<int> hittable_points = GetHittablePoints(opp_board);
    const int roll = dice_[0];

    // Records the move of the checker at `from` if `dest` is open.
    auto add_if_open = [&](int from, int dest) {
        if (blocked_points.count(dest) == 0) {
            const bool hit = hittable_points.count(dest) > 0;
            legal_moves.insert(CheckerMove(from, roll, hit));
        }
    };

    if (home_[player] > 0) {
        add_if_open(home_pos_, PositionFromHome(player, roll));
    }
    for (int i = 0; i < n_points_; ++i) {
        if (own_board[i] > 0) {
            add_if_open(i, PositionFrom(player, i, roll));
        }
    }
    return legal_moves;
}

int NannonState::PositionFromHome(int player, int spaces) const {
    // Entering checkers land `spaces` points in: X counts up from point 0,
    // O counts down from the last point.
    if (player == kOPlayerId) {
        return n_points_ - spaces;
    }
    if (player == kXPlayerId) {
        return spaces - 1;
    }
    SpielFatalError(absl::StrCat("Invalid player:", player));
}

int NannonState::PositionFrom(int player, int pos, int spaces) const {
    // Home is a special origin handled by PositionFromHome.
    if (pos == home_pos_) {
        return PositionFromHome(player, spaces);
    }
    // X moves toward higher indices, O toward lower ones. Running past the
    // edge of the board bears the checker off (score_pos_).
    if (player == kXPlayerId) {
        const int target = pos + spaces;
        return target < n_points_ ? target : score_pos_;
    }
    if (player == kOPlayerId) {
        const int target = pos - spaces;
        return target >= 0 ? target : score_pos_;
    }
    SpielFatalError(absl::StrCat("Invalid player:", player));
}

std::vector<Action> NannonState::ProcessLegalMoves(const std::set<CheckerMove>& legal_moves) const {
    if (legal_moves.empty()) {
        return {CheckerMoveToSpielMove(CheckerMove(kPassPos, -1, false))};
    }
    std::vector<Action> legal_actions;
    legal_actions.reserve(legal_moves.size());
    for (const auto& move : legal_moves) {
        legal_actions.push_back(CheckerMoveToSpielMove(move));
    }
    SPIEL_CHECK_FALSE(legal_actions.empty());
    return legal_actions;
}


std::vector<Action> NannonState::LegalActions() const {
    // Chance nodes expose die outcomes; terminal states expose nothing.
    if (IsChanceNode()) {
        return LegalChanceOutcomes();
    }
    if (IsTerminal()) {
        return {};
    }
    // Player nodes: encode the legal checker moves and return them sorted.
    auto actions = ProcessLegalMoves(LegalCheckerMoves(cur_player_));
    std::sort(actions.begin(), actions.end());
    return actions;
}

std::vector<std::pair<Action, double>> NannonState::ChanceOutcomes() const {
    // Every chance node uses the same die-outcome table supplied at
    // construction time.
    SPIEL_CHECK_TRUE(IsChanceNode());
    std::vector<std::pair<Action, double>> outcomes = chance_outcomes_;
    return outcomes;
}

std::string NannonState::ActionToString(Player player, Action action) const {
    // Chance actions are indices into the die-face table.
    if (player == kChancePlayerId) {
        return absl::StrCat("chance outcome ", action,
                            " (roll: ", chance_outcome_values_[action], ")");
    }
    const CheckerMove move = SpielMoveToCheckerMove(player, action);
    // A pass decodes to the sentinel move (-1, -1).
    if (move.pos == -1 && move.num == -1) {
        return absl::StrCat(kPassPos, " - ", "Pass");
    }

    // Points are labelled from the mover's own perspective: O labels index
    // `pos` as pos + 1, X labels it n_points_ - pos, and home sits one label
    // beyond the farthest point. The destination label is the origin label
    // minus the roll; a non-positive label means the checker bears off.
    const bool from_home = (move.pos == home_pos_);
    int from_label;
    if (from_home) {
        from_label = n_points_ + 1;
    } else if (player == kOPlayerId) {
        from_label = move.pos + 1;
    } else {
        from_label = n_points_ - move.pos;
    }
    const int to_label = from_label - move.num;

    const std::string origin = from_home ? "Home" : std::to_string(from_label);
    std::string destination;
    if (to_label <= 0) {
        destination = "Off";
    } else {
        destination = std::to_string(to_label) + (move.hit ? "*" : "");
    }
    return absl::StrCat(action, " - ", origin, "/", destination);
}


void NannonState::DoApplyAction(Action action) {
    // Reject anything that is not currently legal.
    const std::vector<Action> legal = LegalActions();
    const bool is_legal =
        std::find(legal.begin(), legal.end(), action) != legal.end();
    if (!is_legal) {
        SpielFatalError("Invalid action passed to DoApplyAction");
    }

    // Remember who had moved last before this action (used by the
    // observation tensor), then dispatch on the node type.
    true_prev_player_ = prev_player_;
    if (!IsChanceNode()) {
        ApplyNormalAction(action);
    } else {
        ApplyChanceNodeAction(action);
    }
}

void NannonState::RollDice(int outcome) {
    // Translate the outcome index into the face value shown and record it.
    const int face = chance_outcome_values_[outcome];
    dice_.push_back(face);
}

int NannonState::Opponent(int player) const {
    // Assumes the two seats are 0 and 1 (ObservationTensor range-checks
    // players to [0, 1]); the other seat is the complement.
    const int last_seat = 1;
    return last_seat - player;
}

// Applies the current player's move, records it for UndoAction, and hands
// control to the chance player for the next roll.
//
// The turn counters are advanced here because UndoAction decrements them.
// Without the increment, undoing a move could drive turns_ back to -1 and
// make the next chance node replay the opening toss.
void NannonState::ApplyNormalAction(Action action) {
    CheckerMove move = SpielMoveToCheckerMove(cur_player_, action);
    bool move_hit = ApplyCheckerMove(cur_player_, move);
    turn_history_info_.emplace_back(cur_player_,
            prev_player_, dice_, action, move_hit);

    ++turns_;
    if (cur_player_ == kXPlayerId) {
        ++x_turns_;
    } else if (cur_player_ == kOPlayerId) {
        ++o_turns_;
    }

    prev_player_ = cur_player_;
    cur_player_ = kChancePlayerId;
    // Sanity check: every checker is on the board, at home, or borne off.
    if (CountTotalCheckers(kXPlayerId) != n_chex_per_player_ ||
        CountTotalCheckers(kOPlayerId) != n_chex_per_player_) {
        SpielFatalError("Invalid number of checkers");
    }

    // The roll has been used; the next chance node rolls afresh.
    dice_.clear();
}

bool NannonState::ApplyCheckerMove(int player, const CheckerMove &move) {
    if (move.pos < 0) {
        return false;
    }

    int next_pos;
    if (move.pos == home_pos_) {
        home_[player]--;
        next_pos = PositionFromHome(player, move.num);
    } else {
        board_[player][move.pos]--;
        next_pos = PositionFrom(player, move.pos, move.num);
    }

    if (next_pos == score_pos_) {
        scores_[player]++;
    } else {
        board_[player][next_pos]++;
    }
    if (move.hit ||
        (next_pos != score_pos_ && board_[Opponent(player)][next_pos] == 1)) {
        board_[Opponent(player)][next_pos]--;
        home_[Opponent(player)]++;
    }
    return move.hit;
}

CheckerMove NannonState::SpielMoveToCheckerMove(Player player, Action action) const {
    // The pass action decodes to the sentinel move (-1, -1).
    if (action == EncodedPassMove()) {
        return CheckerMove(-1, -1, false);
    }
    // Any other action names its origin (a point index or home) and always
    // uses the single die currently shown.
    const int roll = dice_[0];
    const int dest_point = (action == EncodedHomeMove())
                               ? PositionFromHome(player, roll)
                               : PositionFrom(player, action, roll);
    const std::set<int> hittable = GetHittablePoints(board_[Opponent(player)]);
    const bool hit = hittable.count(dest_point) > 0;
    return CheckerMove(action, roll, hit);
}

Action NannonState::CheckerMoveToSpielMove(const CheckerMove &move) const {
    // Board points encode as their own index; pass and home use the two ids
    // just past the board.
    if (move.pos == kPassPos) {
        return EncodedPassMove();
    }
    if (move.pos == home_pos_) {
        return EncodedHomeMove();
    }
    return Action(move.pos);
}

Action NannonState::EncodedHomeMove() const {
    // "Enter from home" shares its id with the home position index.
    return static_cast<Action>(home_pos_);
}

Action NannonState::EncodedPassMove() const {
    // Pass is the id immediately after the home-entry id.
    return EncodedHomeMove() + 1;
}

void NannonState::SetState(int cur_player, const std::vector<int> &dice, const std::vector<int> &home,
                           const std::vector<int> &scores, const std::vector<std::vector<int>> &board,
                           const int turns) {
    // Overwrite the position wholesale. prev_player_, true_prev_player_ and
    // the undo history are not touched.
    board_ = board;
    scores_ = scores;
    home_ = home;
    dice_ = dice;
    turns_ = turns;
    cur_player_ = cur_player;

    // The new position must still account for every checker of both sides.
    SPIEL_CHECK_EQ(CountTotalCheckers(kXPlayerId), n_chex_per_player_);
    SPIEL_CHECK_EQ(CountTotalCheckers(kOPlayerId), n_chex_per_player_);
}

int NannonState::CountTotalCheckers(int player) const {
    // Sum the checkers on the board, at home and borne off, verifying that
    // none of the tallies has gone negative.
    int sum = 0;
    for (int i = 0; i < n_points_; ++i) {
        SPIEL_CHECK_GE(board_[player][i], 0);
        sum += board_[player][i];
    }
    SPIEL_CHECK_GE(home_[player], 0);
    SPIEL_CHECK_GE(scores_[player], 0);
    return sum + home_[player] + scores_[player];
}

std::vector<double> NannonState::Returns() const {
    // +1 / -1 to whoever has borne off all checkers; 0 / 0 while undecided.
    const bool x_won = scores_[kXPlayerId] == n_chex_per_player_;
    const bool o_won = scores_[kOPlayerId] == n_chex_per_player_;
    if (x_won) {
        return {1.0, -1.0};
    }
    if (o_won) {
        return {-1.0, 1.0};
    }
    return {0.0, 0.0};
}

std::unique_ptr<State> NannonState::Clone() const {
    return std::unique_ptr<State>(new NannonState(*this));
}

std::string NannonState::GetBoardString() const {
    std::string board(n_points_, '-');
    for (int pos = 0; pos < n_points_; ++pos) {
        if (board_[kXPlayerId][pos] > 0) {
            board[pos] = 'x';
        } else if (board_[kOPlayerId][pos] > 0) {
            board[pos] = 'o';
        }
    }
    return board;
}

std::string NannonState::GetXPlayerHomeString() const {
    // A right-aligned strip of width n_chex_per_player_ showing one 'x' per
    // X checker at home (clamped to the strip width).
    std::string x_home(n_chex_per_player_, ' ');
    const int shown =
        std::min(std::max(home_[kXPlayerId], 0), n_chex_per_player_);
    std::fill(x_home.end() - shown, x_home.end(), 'x');
    return x_home;
}

std::string NannonState::GetOPlayerHomeString() const {
    // A left-aligned strip of width n_chex_per_player_ showing one 'o' per
    // O checker at home (clamped to the strip width).
    std::string o_home(n_chex_per_player_, ' ');
    const int shown =
        std::min(std::max(home_[kOPlayerId], 0), n_chex_per_player_);
    std::fill(o_home.begin(), o_home.begin() + shown, 'o');
    return o_home;
}

std::string NannonState::ToString() const {
    // Layout:
    //   <X home>|<board>|<O home>
    //   Turn: <player>
    //   Dice: <roll>
    //   Scores, X: <n>, O: <n>
    const std::string board = GetBoardString();
    const std::string x_home = GetXPlayerHomeString();
    const std::string o_home = GetOPlayerHomeString();
    const std::string turn = CurrentPlayerHumanReadable();
    const std::string roll = DiceHumanReadable();
    return absl::StrCat(x_home, "|", board, "|", o_home, "\n",
                        "Turn: ", turn, "\n",
                        "Dice: ", roll, "\n",
                        "Scores, X: ", scores_[kXPlayerId],
                        ", O: ", scores_[kOPlayerId], "\n");
}

std::string NannonState::CurrentPlayerHumanReadable() const {
    // Display name of whoever is to act; any other value is a corrupt state.
    if (cur_player_ == kXPlayerId) {
        return "x";
    }
    if (cur_player_ == kOPlayerId) {
        return "o";
    }
    if (cur_player_ == kChancePlayerId) {
        return "Chance";
    }
    SpielFatalError("Invalid player");
}

std::string NannonState::DiceHumanReadable() const {
    // Zero, one or two dice can be showing (two only during the opening toss).
    switch (dice_.size()) {
        case 0:
            return "No roll";
        case 1:
            return std::to_string(dice_[0]);
        case 2:
            return std::to_string(dice_[0]) + "/" + std::to_string(dice_[1]);
        default:
            SpielFatalError("Invalid number of dice");
    }
}

// Writes the observation vector for `player` into *values (replacing its
// contents). The encoding is absolute rather than player-relative: `player`
// is only range-checked, and both players receive the same vector.
//
// Layout for n = n_points_ (total 2 * n + 7 entries, which matches
// GetStateEncodingSize for two players):
//   [0]            X checkers at home
//   [1, n]         X occupancy per point (1 iff exactly one X checker)
//   [n + 1]        X checkers borne off
//   [n + 2]        O checkers borne off
//   [n + 3, 2n+2]  O occupancy per point (1 iff exactly one O checker)
//   [2n + 3]       O checkers at home
//   [2n + 4]       die value shown, or -1 if none
//   [2n + 5]       true_prev_player_
//   [2n + 6]       cur_player_
void NannonState::ObservationTensor(Player player, std::vector<double> *values) const {
    SPIEL_CHECK_GE(player, 0);
    SPIEL_CHECK_LE(player, 1);
    values->clear();
    values->reserve(game_->ObservationTensorSize());

    values->push_back(home_[kXPlayerId]);
    for (int count : board_[kXPlayerId]) {
        values->push_back((count == 1) ? 1 : 0);
    }
    values->push_back(scores_[kXPlayerId]);
    values->push_back(scores_[kOPlayerId]);

    for (int count : board_[kOPlayerId]) {
        values->push_back((count == 1) ? 1 : 0);
    }
    values->push_back(home_[kOPlayerId]);

    // Only the first die is encoded; -1 marks that nothing has been rolled.
    if (dice_.empty()) {
        values->push_back(-1);
    } else {
        values->push_back(dice_[0]);
    }
    values->push_back(true_prev_player_);
    values->push_back(cur_player_);
    SPIEL_CHECK_EQ(state_encoding_size_, static_cast<int>(values->size()));
}

// Reverts the most recent action, which must be `action` taken by `player`.
//
// The history entry recorded when the action was applied supplies the player
// to act, prev_player_, the dice, and (for checker moves) whether a hit
// happened. Additional bookkeeping restored here:
//  * Undoing the toss-deciding chance node (the only chance node applied while
//    two dice are showing) returns the game to the opening, turns_ == -1.
//    Otherwise re-applying it would take the normal-chance path with two dice.
//  * Checker moves are only made after the opening (turns_ >= 0), so the turn
//    counters are never decremented below zero.
//  * true_prev_player_ equals the prev_player_ recorded by the now-latest
//    remaining history entry, or kChancePlayerId when the history is empty.
// NOTE(review): if the base State also tracks a move counter (e.g.
// move_number_), it likely needs decrementing here as well — confirm.
void NannonState::UndoAction(Player player, Action action) {
    SPIEL_CHECK_FALSE(turn_history_info_.empty());
    const TurnHistoryInfo& thi = turn_history_info_.back();
    SPIEL_CHECK_EQ(thi.player, player);
    SPIEL_CHECK_EQ(action, thi.action);
    cur_player_ = thi.player;
    prev_player_ = thi.prev_player;
    dice_ = thi.dice;
    if (player == kChancePlayerId) {
        if (dice_.size() == 2) {
            turns_ = -1;
        }
    } else {
        CheckerMove move = SpielMoveToCheckerMove(player, action);
        move.hit = thi.move_hit;
        UndoCheckerMove(player, move);
        if (turns_ > 0) {
            turns_--;
        }
        if (player == kXPlayerId) {
            if (x_turns_ > 0) x_turns_--;
        } else if (player == kOPlayerId) {
            if (o_turns_ > 0) o_turns_--;
        }
    }
    // `thi` refers into turn_history_info_ and must not be used after this.
    turn_history_info_.pop_back();
    true_prev_player_ = turn_history_info_.empty()
                            ? kChancePlayerId
                            : turn_history_info_.back().prev_player;
    history_.pop_back();
}

void NannonState::UndoCheckerMove(int player, const CheckerMove &move) {
    // Undoing a pass does nothing
    if (move.pos < 0) {
        return;
    }
    int next_pos = move.pos == home_pos_ ? PositionFromHome(player, move.num) :
                                           PositionFrom(player, move.pos, move.num);
    if (move.hit) {
        home_[Opponent(player)]--;
        board_[Opponent(player)][next_pos]++;
    }

    if (next_pos == score_pos_) {
        scores_[player]--;
    } else {
        board_[player][next_pos]--;
    }

    if (move.pos == score_pos_) {
        home_[player]++;
    } else {
        board_[player][move.pos]++;
    }
}

const std::vector<std::pair<Action, double>> GetChanceOutcomes(int n_die) {
    std::vector<std::pair<Action, double>> chance_outcomes;
    for (int i = 0; i < n_die; i++) {
        std::pair<Action, double> pair = {i, 1.0 / n_die};
        chance_outcomes.push_back(pair);
    }
    return chance_outcomes;
}

// Face values shown by each outcome index: index k shows k + 1, so the faces
// are 1..n_die. A non-positive n_die yields no faces.
std::vector<int> GetChanceOutcomeValues(int n_die) {
    std::vector<int> faces;
    for (int face = 0; face < n_die;) {
        faces.push_back(++face);
    }
    return faces;
}

// Length of the observation vector. Per player there is one board row of
// n_points entries plus three scalars. For the two-player layout these are
// two homes, two scores and the two player ids (3 * kNumPlayers in total).
// One extra entry holds the die value.
int GetStateEncodingSize(int n_points) {
    return kNumPlayers * (n_points + 3) + 1;
}

// Reads the game parameters and precomputes the chance tables, the
// observation size, and the sentinel positions (home = n_points,
// score = n_points + 1).
// Action ids are:
//   0..n_points-1  move the checker on that point
//   n_points       enter from home
//   n_points + 1   pass
// This gives n_points + 2 distinct actions.
//
// Parameters are validated because other configurations index outside the
// board:
//  * InitalizeBoard places n_chex - 1 checkers on distinct points, so that
//    count must fit on the board. Also, with n_chex == 0 the single home
//    checker breaks the checker-count invariant.
//  * Entering with roll d lands on point d - 1 (X) or n_points - d (O), so
//    every die face (1..n_die) must be at most n_points.
NannonGame::NannonGame(const GameParameters &params)
    : Game(kGameType, params),
    n_points_(ParameterValue<int>("n_points")),
    n_chex_per_player_(ParameterValue<int>("n_chex")),
    n_die_(ParameterValue<int>("n_die")),
    n_distinct_actions_(n_points_ + 2),
    chance_outcomes_(GetChanceOutcomes(n_die_)),
    chance_outcome_values_(GetChanceOutcomeValues(n_die_)),
    state_encoding_size_(GetStateEncodingSize(n_points_)),
    home_pos_(n_points_),
    score_pos_(n_points_ + 1) {
    SPIEL_CHECK_GE(n_points_, 1);
    SPIEL_CHECK_GE(n_chex_per_player_, 1);
    SPIEL_CHECK_LE(n_chex_per_player_ - 1, n_points_);
    SPIEL_CHECK_GE(n_die_, 1);
    SPIEL_CHECK_LE(n_die_, n_points_);
}

// A win is worth +1 (see NannonState::Returns).
double NannonGame::MaxUtility() const {
    return 1.0;
}
}  // namespace open_spiel
