import numpy as np
# import numba as nb
from soccer import SoccerEnv

class SoccerBuiltinAgent:
    """Rule-based (built-in) player for the grid soccer game.

    On every call to :meth:`act` the agent accumulates a score for each of
    the ``NACTION`` moves from four weighted behaviours:

    * with the ball: ``ADVANCE`` toward the opponent's goal and ``AVOID``
      the opponent;
    * without the ball: ``DEFEND`` its own goal and ``INTERCEPT`` the
      opponent.

    The best-scoring move is played.  Among moves tied for the best score
    the previous move is kept when possible, otherwise one is drawn
    uniformly with ``np.random.randint``.  Call :meth:`reset_agent` before
    every new game.
    """

    # Behaviour indices into ``self.weights``.
    ADVANCE = 0
    DEFEND = 1
    INTERCEPT = 2
    AVOID = 3

    # Action indices; ``NACTION`` is the size of the action space.
    UP = 0
    DOWN = 1
    LEFT = 2
    RIGHT = 3
    NOOP = 4
    NACTION = 5

    # Weight of the defensive behaviours for each player kind:
    # kind 0 is the offensive player (0.1), kind 1 the defensive one (0.9).
    defensiveness = (0.1, 0.9)

    def __init__(self):
        # Starts at 0 so that the first alternating reset
        # (defend_percentage=0.5) selects the defensive kind 1.
        self.player_kind = 0
        self.weights = [0, 0, 0, 0]

    def reset_agent(self, player_a_or_b, defend_percentage=0.5):
        """Prepare the agent for a new game (call at the start of every game).

        Args:
            player_a_or_b: 0 to play as player A, 1 to play as player B.
            defend_percentage: fraction of games played as the defensive
                kind.  Only three values are supported: 0 (always
                offensive), 1 (always defensive) and 0.5 (alternate the two
                kinds on successive calls instead of sampling, to reduce
                variance).

        Raises:
            ValueError: if ``defend_percentage`` is not 0, 0.5 or 1.
        """
        if defend_percentage == 0.5:
            self.player_kind = 1 - self.player_kind
        elif defend_percentage == 0:
            # Never defend -> offensive kind (matches the sampling rule
            # ``player_kind = int(rand < defend_percentage)``).
            self.player_kind = 0
        elif defend_percentage == 1:
            # Always defend -> defensive kind.
            self.player_kind = 1
        else:
            raise ValueError(
                "defend_percentage must be 0, 0.5 or 1, got %r"
                % (defend_percentage,))

        defensive = self.defensiveness[self.player_kind]
        offensive = 1 - defensive
        self.weights[self.ADVANCE] = offensive
        self.weights[self.INTERCEPT] = offensive
        self.weights[self.AVOID] = defensive
        self.weights[self.DEFEND] = defensive

        self.player_a_or_b = player_a_or_b
        # Field layout: columns x = 0..8, rows y = 0..5, goals on rows 2-3.
        #   0 1 2 3 4 5 6 7 8
        # 0
        # 1
        # 2 A   a a   b b   B
        # 3 A   a a   b b   B
        # 4
        # 5
        # ``goal_area`` is the goal this player defends and ``opp_goal_area``
        # the one it attacks, both as ``(column, (first_row, last_row))``.
        # The objective for A is to reach its opp_goal_area.
        if player_a_or_b == 0:
            self.goal_area = (0, (2, 3))
            self.opp_goal_area = (8, (2, 3))
        else:
            self.goal_area = (8, (2, 3))
            self.opp_goal_area = (0, (2, 3))
        self.prev_action = None

    # NOTE: throughout this class UP is the move paired with increasing y and
    # RIGHT the move paired with increasing x (see the comparisons below).

    def act_advance(self, actions, coord, goal_area):
        """Add the ADVANCE weight to moves heading toward ``goal_area``."""
        score = self.weights[self.ADVANCE]
        if coord[0] < goal_area[0]:
            actions[self.RIGHT] += score
        elif coord[0] > goal_area[0]:
            actions[self.LEFT] += score
        if coord[1] < goal_area[1][0]:
            actions[self.UP] += score
        elif coord[1] > goal_area[1][1]:
            actions[self.DOWN] += score

    def act_defend(self, actions, coord, goal_area):
        """Add the DEFEND weight to moves heading toward (or guarding) ``goal_area``."""
        score = self.weights[self.DEFEND]
        if abs(coord[0] - goal_area[0]) == 1:
            # In the column right in front of the goal: patrol up/down
            # across the goal rows.
            if coord[1] <= goal_area[1][0]:
                actions[self.UP] += score
            elif coord[1] >= goal_area[1][1]:
                actions[self.DOWN] += score
            else:
                actions[self.UP] += score
                actions[self.DOWN] += score
        else:
            # Otherwise move toward the goal.
            if coord[0] < goal_area[0]:
                actions[self.RIGHT] += score
            elif coord[0] > goal_area[0]:
                actions[self.LEFT] += score
            if coord[1] <= goal_area[1][0]:
                actions[self.UP] += score
            elif coord[1] >= goal_area[1][1]:
                actions[self.DOWN] += score

    def is_adjacent(self, coord_a, coord_b):
        """Return True when the two cells share an edge (4-neighbourhood)."""
        return (abs(coord_a[0] - coord_b[0]) == 1 and
                abs(coord_a[1] - coord_b[1]) == 0) or \
               (abs(coord_a[0] - coord_b[0]) == 0 and
                abs(coord_a[1] - coord_b[1]) == 1)

    def act_intercept(self, actions, coord, opp_coord):
        """Add the INTERCEPT weight: wait when adjacent, else chase the opponent."""
        score = self.weights[self.INTERCEPT]
        if self.is_adjacent(coord, opp_coord):
            actions[self.NOOP] += score
        else:
            if coord[0] < opp_coord[0]:
                actions[self.RIGHT] += score
            elif coord[0] > opp_coord[0]:
                actions[self.LEFT] += score
            if coord[1] < opp_coord[1]:
                actions[self.UP] += score
            elif coord[1] > opp_coord[1]:
                actions[self.DOWN] += score

    def act_avoid(self, actions, coord, opp_coord):
        """Add the AVOID weight to moves away from the opponent.

        When both players share a row/column, both directions on that axis
        are rewarded.
        """
        score = self.weights[self.AVOID]
        if coord[0] <= opp_coord[0]:
            actions[self.LEFT] += score
        if coord[0] >= opp_coord[0]:
            actions[self.RIGHT] += score
        if coord[1] <= opp_coord[1]:
            actions[self.DOWN] += score
        if coord[1] >= opp_coord[1]:
            actions[self.UP] += score

    def act_(self, has_ball, coord, goal_area, opp_coord, opp_goal_area, prev_action):
        """Score all moves for the given situation and return the chosen one.

        Returns:
            An action index (numpy integer, or ``prev_action`` itself when it
            is among the best-scoring moves).
        """
        actions = np.zeros(self.NACTION)
        if has_ball:
            self.act_advance(actions, coord, opp_goal_area)
            self.act_avoid(actions, coord, opp_coord)
        else:
            self.act_defend(actions, coord, goal_area)
            self.act_intercept(actions, coord, opp_coord)
        # Actions from best to worst score.
        sorted_actions = np.argsort(actions)[::-1]
        scores = actions[sorted_actions]
        # n = number of actions tied for the best score.
        n = 1
        for i in range(1, self.NACTION):
            if scores[i] == scores[i - 1]:
                n += 1
            else:
                break
        # Prefer repeating the previous move to avoid dithering between ties.
        for i in range(n):
            if sorted_actions[i] == prev_action:
                return prev_action
        return sorted_actions[np.random.randint(n)]

    def act(self, obs):
        """Choose an action from an observation and remember it.

        Two observation layouts are accepted:

        * length 5: ``obs[0:2]`` is A's cell, ``obs[2:4]`` is B's cell and
          ``obs[-1]`` is compared with this player's index to decide whether
          it holds the ball;
        * length 32: one-hot blocks ``obs[0:9]`` (A's x), ``obs[9:15]``
          (A's y), ``obs[15:24]`` (B's x), ``obs[24:30]`` (B's y), with
          ``obs[30]`` compared with this player's index for ball possession.

        Raises:
            ValueError: if ``len(obs)`` is neither 5 nor 32.
        """
        n_obs = len(obs)
        if n_obs == 5:
            has_ball = obs[-1] == self.player_a_or_b
            if self.player_a_or_b == 0:
                coord = obs[:2]
                opp_coord = obs[2:4]
            else:
                coord = obs[2:4]
                opp_coord = obs[:2]
        elif n_obs == 32:
            # NOTE(review): if obs[30:32] is a one-hot owner encoding, this
            # comparison looks inverted and should probably be
            # ``np.argmax(obs[30:32]) == self.player_a_or_b`` -- confirm
            # against the environment's observation layout.
            has_ball = obs[30] == self.player_a_or_b
            coord_a = (np.argmax(obs[:9]), np.argmax(obs[9:15]))
            coord_b = (np.argmax(obs[15:24]), np.argmax(obs[24:30]))
            if self.player_a_or_b == 0:
                coord, opp_coord = coord_a, coord_b
            else:
                coord, opp_coord = coord_b, coord_a
        else:
            raise ValueError(
                "observation must have length 5 or 32, got %d" % n_obs)
        self.prev_action = self.act_(has_ball, coord, self.goal_area,
                                     opp_coord, self.opp_goal_area, self.prev_action)
        return self.prev_action

###############################################################################

class BimatrixGame:
    """One-shot general-sum game between a row player and a column player.

    ``A`` holds the row player's payoffs and ``B`` the column player's; both
    are indexed ``[row_action, column_action]`` and must expose ``.shape``.
    """

    def __init__(self, A, B):
        self.A, self.B = A, B
        self.nact_x, self.nact_y = A.shape

    def reset(self):
        """Nothing to reset for a one-shot game; returns None."""
        return None

    def step(self, action):
        """Play one round.

        Args:
            action: pair ``(row_action, column_action)`` of discrete actions.

        Returns:
            ``(None, (row_reward, column_reward), True, {})`` -- every round
            ends the episode.
        """
        row, col = action
        rewards = (self.A[row, col], self.B[row, col])
        return None, rewards, True, {}

# class ZeroSumGame(BimatrixGame):
    # def __init__(self, A):
        # BimatrixGame.__init__(self, A, -A)
class ZeroSumGame:
    def __init__(self, A):
        self.A = np.array(A, np.float32)
        self.nact_x, self.nact_y = self.A.shape
    def reset(self):
        return np.empty(0), np.float32(0), False, {}
    def step(self, action):
        return np.empty(0), self.A[action[0], action[1]], True, {}

class MatchingPenny(ZeroSumGame):
    """Matching pennies: the row player gets +1 on a match and -1 otherwise."""

    def __init__(self, **kwargs):
        payoff = np.array([[1, -1],
                           [-1, 1]], dtype=np.float32)
        super().__init__(payoff, **kwargs)

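# For the payoff matrix below the game has two extreme Nash equilibria,
# both with value 0 (EP = expected payoff):
#  EE  1  P1:  1/2  1/2   P2:  1/2  1/2    0   EP=  0
#  EE  2  P1:  1/2  1/2   P2:  2/3    0  1/3   EP=  0
# (P1 is indifferent between its rows iff P2 plays q1 = q2 + 2*q3.)
# Any convex combination (interpolation) of the two is also a NE.
# Reference solver: http://cgi.csc.liv.ac.uk/~rahul/bimatrix_solver/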
class MatchingPennyAugmented(ZeroSumGame):
    """Matching pennies with an extra third column action for the column player."""

    def __init__(self, **kwargs):
        payoff = np.array(
            [[1, -1, -2],
             [-1, 1, 2]],
            dtype=np.float32,
        )
        super().__init__(payoff, **kwargs)

class RockPaperScissors(ZeroSumGame):
    """One-shot rock-paper-scissors; row player's payoff is +1 win, -1 loss, 0 draw."""

    ROCK, PAPER, SCISSORS = 0, 1, 2

    def __init__(self):
        payoff = np.array(
            [[0, -1, 1],
             [1, 0, -1],
             [-1, 1, 0]],
            dtype=np.float32,
        )
        super().__init__(payoff)

class BimatrixGame1Memory():
    """Iterated bimatrix game whose state is the previous joint action (memory-1).

    State 0 means "no history yet".  After the joint action ``(x, y)`` the
    state becomes ``1 + x * nact + y``, so there are ``1 + nact**2`` states.

    Args:
        A: row player's payoff matrix (array-like), shape ``(nact, nact)``.
        B: column player's payoff matrix (array-like), same shape as ``A``.
        state_is_ind: stored on the instance; not used by this class.

    Raises:
        ValueError: if ``A`` or ``B`` is not ``nact x nact``.
    """

    def __init__(self, A, B, state_is_ind=True):
        # asarray: accept nested lists too; arrays are kept without copying.
        self.A = np.asarray(A)
        self.B = np.asarray(B)
        self.state_is_ind = state_is_ind
        self.nact = len(self.A)
        self.NACTION = self.nact
        self.nstate = 1 + self.nact ** 2
        square = (self.nact, self.nact)
        if self.A.shape[:2] != square or self.B.shape[:2] != square:
            raise ValueError(
                "payoff matrices must be %dx%d, got A%s and B%s"
                % (self.nact, self.nact, self.A.shape, self.B.shape))
        self.reset()

    def reset(self):
        """Return to state 0; returns ``(state, (0, 0), False, {})``."""
        self.state = np.array(0, dtype=np.int32)
        return self.state, (0, 0), False, {}

    def step(self, action):
        """Play joint action ``(x, y)`` (row and column player, discrete).

        Returns:
            ``(previous_state, (A[x, y], B[x, y]), False, {})``; the game never
            terminates on its own.
        """
        x, y = action
        ra = self.A[x, y]
        rb = self.B[x, y]
        last_state = self.state
        self.state = np.array(1 + x * self.nact + y, dtype=np.int32)
        return last_state, (ra, rb), False, {}

class ZeroSumGame1Memory(BimatrixGame1Memory):
    """Memory-1 iterated zero-sum game built from the row player's payoffs.

    The column player's payoff matrix is ``-A``.  Unlike the general-sum base
    class, rewards are returned as a single scalar for the row player.
    """

    def __init__(self, A, **kwargs):
        # Convert first so nested lists are accepted; ``-A`` needs an array.
        A = np.asarray(A)
        BimatrixGame1Memory.__init__(self, A, -A, **kwargs)

    def reset(self):
        """Return to the no-history state 0; returns ``(state, 0, False, {})``."""
        self.state = np.array(0, dtype=np.int32)
        return self.state, 0, False, {}

    def step(self, action):
        """Play joint action ``(x, y)`` (row and column player, discrete).

        Moves to state ``1 + x * nact + y`` and returns
        ``(previous_state, A[x, y], False, {})``.
        """
        x, y = action
        ra = self.A[x, y]
        last_state = self.state
        self.state = np.array(1 + x * self.nact + y, dtype=np.int32)
        return last_state, ra, False, {}

class MatchingPenny1Memory(ZeroSumGame1Memory):
    """Iterated matching pennies whose state is the previous joint action."""

    def __init__(self, **kwargs):
        pennies = np.array([[1, -1],
                            [-1, 1]], dtype=np.float32)
        super().__init__(pennies, **kwargs)

# NOTE: the one-shot ``RockPaperScissors`` game is defined once, earlier in
# this module next to the other ZeroSumGame subclasses; it is intentionally
# not redefined here (a redefinition would silently rebind the name).

class RockPaperScissors1Memory(ZeroSumGame1Memory):
    """Iterated rock-paper-scissors whose state is the previous joint action."""

    ROCK, PAPER, SCISSORS = 0, 1, 2

    def __init__(self):
        payoff = np.array(
            [[0, -1, 1],
             [1, 0, -1],
             [-1, 1, 0]],
            dtype=np.float32,
        )
        super().__init__(payoff)

if __name__ == "__main__":
    # Smoke test: play a single round of rock-paper-scissors.
    game = RockPaperScissors()
    joint_action = [RockPaperScissors.PAPER, RockPaperScissors.SCISSORS]
    game.step(joint_action)


###############################################################################
# Optional dependencies used only by RenjuEnv / GoEnv below.
try:
    import gym
    import gym_renju
    from gym_renju.envs.core.domain.player import PlayerColor
except Exception:
    # Best effort: a missing or broken gym / gym_renju install leaves the
    # matrix games usable; the environments that reference these names will
    # fail when constructed or stepped.  ``except Exception`` (rather than a
    # bare ``except``) still lets KeyboardInterrupt / SystemExit propagate.
    pass

class RenjuEnv:
    """Two-player Renju environment wrapping ``gym.make('RenjuNxN-learning-v0')``."""

    # Players move in turn.
    sequential = True

    def __init__(self, board_size=9):
        self.env = gym.make('Renju{0}x{0}-learning-v0'.format(board_size))
        # Initialised here as well as in reset() so that step() does not
        # raise AttributeError if it is called before the first reset().
        self.swap = False

    def reset(self):
        """Start a new game; returns ``(observation, 0)``."""
        # TODO: implement swap2 rule?
        self.swap = False
        o = self.env.reset()
        return o, 0

    def step(self, a):
        """Play move ``a`` and return ``(obs, reward, done, player)``.

        ``a`` must be a plain Python int, not a tensor; presumably a flat board
        index in ``[0, board_size**2)`` (``[0, 81)`` for the default 9x9 board)
        -- TODO confirm against gym_renju.

        On a terminal move ``player`` is 0 and ``reward`` is the first entry of
        ``env.get_rewards()``, negated when ``self.swap`` is set.  Otherwise
        ``player`` is 1 if the next player is WHITE and 0 if not, flipped when
        ``self.swap`` is set, and ``reward`` is passed through from the env.
        """
        o, r, t, i = self.env.step(a)
        if t:
            i = 0
            r = self.env.get_rewards()[0]
            if self.swap:
                r = -r
        else:
            i = self.swap ^ int(
                i['state'].get_next_player() == PlayerColor.WHITE)
        return o, r, t, i


class GoEnv:
    """Go environment wrapping ``gym.make('GoNxN-learning-v0')`` with flat observations."""

    # Players move in turn.
    sequential = True

    def __init__(self, board=9):
        env_id = 'Go%dx%d-learning-v0' % (board, board)
        self.env = gym.make(env_id)

    def reset(self):
        """Start a new game; returns ``(flat_obs, 0, False, 0)``.

        The raw observation (3 x 9 x 9 on the default board: black stones,
        white stones, valid places) is flattened to 1-D.  The last element
        is the player index: 0 - black, 1 - white.
        """
        observation = self.env.reset()
        flat = observation.reshape(-1)
        return flat, 0, False, 0

    def step(self, a):
        """Play move ``a``; returns ``(flat_obs, reward, done, player)``.

        ``player`` is 1 when ``info['state'].color != 1``, otherwise 0.
        """
        observation, reward, done, info = self.env.step(a)
        player = int(info['state'].color != 1)
        return observation.reshape(-1), reward, done, player
