import numpy as np
import random

"""This tutorial shows how to train a DQN agent on the connect four environment, using curriculum learning and self play.

Author: Nick (https://github.com/nicku-a), Jaime (https://github.com/jaimesabalbermudez)
"""

class Opponent:
    """Connect 4 opponent to train and/or evaluate against.

    The strategy is selected by ``difficulty`` and exposed as ``self.getAction``.
    Note that the callable's signature depends on the difficulty:
    ``random`` expects ``(action_mask, last_opp_move=None, block_vert_coef=1)``,
    while ``weak`` and ``strong`` expect ``(player)``.

    :param env: Environment to learn in
    :type env: PettingZoo-style environment
    :param difficulty: Difficulty level of opponent, 'random', 'weak' or 'strong'
    :type difficulty: str
    """

    def __init__(self, env, difficulty):
        # The board is read later through ``self.env.env.board``, so ``env`` must
        # expose the raw game two ``.env`` levels below the object stored here.
        self.env = env.env
        self.difficulty = difficulty
        if self.difficulty == "random":
            self.getAction = self.random_opponent
        elif self.difficulty == "weak":
            self.getAction = self.weak_rule_based_opponent
        else:
            # Any other value falls back to the strong opponent.
            self.getAction = self.strong_rule_based_opponent
        self.num_cols = 7
        self.num_rows = 6
        self.length = 4  # number of aligned pieces needed to win
        self.top = [0] * self.num_cols

    def _board(self):
        """Returns the current board as a (num_rows, num_cols) array; row 0 is the top row."""
        return np.array(self.env.env.board).reshape(self.num_rows, self.num_cols)

    def update_top(self):
        """Updates self.top, a list which tracks the row on top of the highest piece in each column.

        For every column, ``self.top[col]`` is the row index in which the next piece
        would land: ``num_rows - 1`` for an empty column, ``0`` when only the top cell
        is free, and ``num_rows`` (an out-of-range sentinel) when the column is full.
        """
        occupied = self._board() != 0
        # argmax over a boolean column yields the first occupied row (0 if none are).
        highest_piece = np.argmax(occupied, axis=0)
        top = np.where(occupied.any(axis=0), highest_piece - 1, self.num_rows - 1)
        top[occupied.all(axis=0)] = self.num_rows
        self.top = top

    def random_opponent(self, action_mask, last_opp_move=None, block_vert_coef=1):
        """Takes move for random opponent. If the lesson aims to randomly block vertical wins with a higher probability, this is done here too.

        The supplied ``action_mask`` is never modified: the sampling weights are built
        on a float copy, so fractional coefficients are honoured even when the mask is
        an integer array.

        :param action_mask: Mask of legal actions: 1=legal, 0=illegal
        :type action_mask: List
        :param last_opp_move: Most recent action taken by agent against this opponent
        :type last_opp_move: int
        :param block_vert_coef: How many times more likely to block vertically
        :type block_vert_coef: float
        :return: Column sampled proportionally to the (re-weighted) mask
        :rtype: int
        :raises ValueError: If the mask has the wrong length or all weights are zero
        """
        weights = [float(flag) for flag in action_mask]
        if last_opp_move is not None:
            # Playing on top of the agent's last move is what blocks a vertical line.
            weights[last_opp_move] *= block_vert_coef
        return random.choices(range(self.num_cols), weights)[0]

    def weak_rule_based_opponent(self, player):
        """Takes move for weak rule-based opponent.

        Chooses, uniformly at random, among the legal columns that maximise the total
        number of the player's own pieces directly connected to the new piece.

        :param player: Player who we are checking, 0 or 1
        :type player: int
        :return: Chosen column
        :rtype: int
        :raises IndexError: If every column is full, so there is no legal move
        """
        self.update_top()
        max_length = -1
        best_actions = []
        for action in range(self.num_cols):
            possible, _reward, _ended, lengths = self.outcome(
                action, player, return_length=True
            )
            if not possible:
                continue
            total = lengths.sum()
            if total > max_length:
                max_length = total
                best_actions = [action]
            elif total == max_length:
                best_actions.append(action)
        return random.choice(best_actions)

    def strong_rule_based_opponent(self, player):
        """Takes move for strong rule-based opponent.

        Priority order: play a move that ends the game for ``player``, otherwise occupy
        a cell with which the other player would end the game, otherwise fall back to
        the weak rule-based choice.

        :param player: Player who we are checking, 0 or 1
        :type player: int
        :return: Chosen column
        :rtype: int
        """
        self.update_top()

        winning_actions = self._ending_actions(player)
        if winning_actions:
            return random.choice(winning_actions)

        opp = 1 if player == 0 else 0
        loss_avoiding_actions = self._ending_actions(opp)
        if loss_avoiding_actions:
            return random.choice(loss_avoiding_actions)

        return self.weak_rule_based_opponent(player)  # take best possible move

    def _ending_actions(self, player):
        """Returns the legal columns in which a move by ``player`` would end the game."""
        actions = []
        for action in range(self.num_cols):
            possible, _reward, ended = self.outcome(action, player)
            if possible and ended:
                actions.append(action)
        return actions

    def outcome(self, action, player, return_length=False):
        """Evaluates what would happen if ``player`` dropped a piece in column ``action``.

        Nothing is played in the environment. The landing row is taken from
        ``self.top``, so ``update_top`` must have been called for the current board.

        :param action: Action to take in environment
        :type action: int
        :param player: Player who we are checking, 0 or 1
        :type player: int
        :param return_length: Return length of outcomes, defaults to False
        :type return_length: bool, optional
        :return: ``(possible, reward, ended)``, plus ``lengths`` when ``return_length``
            is True. ``possible`` is False for a full column, in which case every other
            element is None. ``reward`` is ``(-1) ** player`` for a winning move and 0
            otherwise. ``ended`` is True for a winning move or a move that fills the
            board. ``lengths`` is a 4x2 array holding, for each axis and each of its
            two opposite directions, the number of the player's pieces directly
            adjacent to the landing cell.
        :rtype: tuple
        """
        if not (self.top[action] < self.num_rows):  # action column is full
            return (False, None, None) + ((None,) if return_length else ())

        row, col = self.top[action], action
        piece = player + 1  # board encoding: 0 = empty, 1 = player 0, 2 = player 1

        # One entry per axis (vertical, horizontal, two diagonals); each entry holds
        # the two opposite (d_row, d_col) unit steps along that axis.
        directions = np.array(
            [
                [[-1, 0], [1, 0]],
                [[0, -1], [0, 1]],
                [[-1, -1], [1, 1]],
                [[-1, 1], [1, -1]],
            ]
        )  # |4x2x2|

        steps = np.arange(1, self.length).reshape(1, 1, -1, 1)
        origin = np.array([row, col]).reshape(1, 1, 1, -1)
        positions = origin + np.expand_dims(directions, -2) * steps  # |4x2x3x2|
        pos_rows = positions[..., 0]
        pos_cols = positions[..., 1]
        valid_positions = (
            (pos_rows >= 0)
            & (pos_rows < self.num_rows)
            & (pos_cols >= 0)
            & (pos_cols < self.num_cols)
        )  # |4x2x3|

        # Off-board cells are read from (0, 0) to keep indexing legal, then zeroed.
        safe_rows = np.where(valid_positions, pos_rows, 0)
        safe_cols = np.where(valid_positions, pos_cols, 0)
        board = self._board()
        board_values = np.where(valid_positions, board[safe_rows, safe_cols], 0)

        own = (board_values == piece).astype(int)
        padded = np.concatenate(
            (own, np.zeros_like(own[:, :, :1])), axis=-1
        )  # padding with zeros to compute length
        # Index of the first non-own cell == length of the run touching the new piece.
        lengths = np.argmin(padded, -1)  # |4x2|

        # Any winning line through the new piece is the run on one side, the piece
        # itself and the run on the opposite side, so the two runs must total at
        # least ``length - 1``.
        won = bool(np.any(lengths.sum(axis=1) >= self.length - 1))

        # The board is full after this move when the piece lands in the top row
        # (row 0) and every other column already holds the "full" sentinel.
        others_full = all(
            v == self.num_rows for c, v in enumerate(self.top) if c != col
        )
        draw = bool(row == 0) and others_full and not won

        ended = won or draw
        reward = (-1) ** (player) if won else 0

        return (True, reward, ended) + ((lengths,) if return_length else ())