import jax
import jax.numpy as jnp
import numpy as np
import chex

import functools
from games.jax_game import JaxGame, GameState

# Action ids. They index the one-hot action masks of size `num_actions`
# (invalid, fold, call, raise) used throughout this module.
INVALID_ID, FOLD_ID, CALL_ID, RAISE_ID = range(4)

@chex.dataclass(frozen=True)
class LeducGameState(GameState):
    """Immutable state of one Leduc poker game.

    The shapes noted below are the ones produced by
    JaxLeduc.initialize_structures.
    """
    # (max_turns - 1, num_actions - 1): row t is the one-hot (fold, call, raise)
    # action taken at turn t; rows of unplayed turns stay all-zero.
    action_history: chex.Array
    # (1,): 0 while the public card is not dealt yet, otherwise card index + 1.
    public_card: chex.Array
    # (2,): card index (0 .. total_cards - 1) held by each player.
    private_cards: chex.Array
    # (players,): chips each player has committed; both start at 1.
    current_chips: chex.Array
    # (1,): actions taken in the current betting round; parity = player to act.
    turns_this_round: chex.Array
    # Scalar bool. apply_action ORs this flag with the newly computed terminal
    # condition, so a state stays terminal when more actions are applied to it.
    terminal: chex.Array


class JaxLeduc(JaxGame):
  def __init__(self):
    # --- Action space: invalid, fold, call, raise (see the *_ID constants).
    self.num_actions = 4
    # Mask enabling only the invalid action; it is the legal-action mask
    # handed to the player who is not acting.
    self.invalid_action_mask = jax.nn.one_hot(INVALID_ID, self.num_actions)

    # --- Players and deck: two suits of three ranks each.
    self.players = 2
    self.total_cards = 6
    # Ordered deals of two distinct private cards (cards of different suits
    # count as distinct): 6 * 5.
    self.private_chance_outcomes = 30

    # --- Betting structure.
    # Longest trajectory: two rounds of Call, Raise, Raise, [Call or Fold].
    self.max_turns = 8
    self.max_raises_per_round = 2
    self.raise_amount = 2
    # Largest possible commitment: 1 to start, raises of 2 + 2 in the first
    # round and 4 + 4 in the second.
    self.max_bet_amount = 13

  
  def num_distinct_actions(self):
    """Size of the action space: invalid, fold, call and raise."""
    return self.num_actions
  
  def max_chance_outcomes(self):
    """Number of outcomes of the private-card chance node (ordered deals)."""
    return self.private_chance_outcomes
  
  def max_trajectory_length(self):
    """Maximum number of player actions in a single game."""
    return self.max_turns
  
  def information_state_tensor_shape(self):
    """Length of a per-player information-state vector (see get_info)."""
    # One-hot of the player the vector belongs to.
    player_bits = self.players
    # One-hot of that player's private card.
    private_bits = self.total_cards
    # One-hot of the public card plus one slot meaning "not revealed yet".
    public_bits = self.total_cards + 1
    # Fold/call/raise one-hot for every recorded (non-terminal) turn.
    history_bits = (self.max_turns - 1) * (self.num_actions - 1)
    return player_bits + private_bits + public_bits + history_bits
  
  def public_state_tensor_shape(self):
    """Length of the public-state vector returned by get_info."""
    # Public-card one-hot with an extra slot meaning "not revealed yet".
    public_card_bits = self.total_cards + 1
    # Fold/call/raise one-hot per recorded turn (the invalid action is excluded).
    history_bits = (self.max_turns - 1) * (self.num_actions - 1)
    return public_card_bits + history_bits
  
  def generate_all_private_card_nodes(self):
    """Build one root state per outcome of the private-card chance node.

    All root states share the same legal actions, so a single legal-action
    mask is returned alongside them.

    Returns:
      (outcomes, legals):
        outcomes: list of LeducGameStates, one per ordered pair of distinct
          cards (p1_card, p2_card), in lexicographic order; every other field
          is taken from the root produced by `initialize_structures`.
        legals: the root legal-action mask from `initialize_structures`.
    """
    # This is a full enumeration, so the PRNG key only influences the sampled
    # deal of the template state, which is overwritten below.
    template, legals = self.initialize_structures(jax.random.key(0))
    outcomes = [
        LeducGameState(
            action_history=template.action_history,
            public_card=template.public_card,
            private_cards=jnp.array([p1_card, p2_card], dtype=int),
            current_chips=template.current_chips,
            turns_this_round=template.turns_this_round,
            terminal=template.terminal,
        )
        for p1_card in range(self.total_cards)
        for p2_card in range(self.total_cards)
        if p1_card != p2_card
    ]
    return outcomes, legals
  
  def generate_all_public_card_nodes(self, state):
    """Return one state per outcome of the public-card chance node.

    `state` must already be past the private-card chance node (e.g. a state
    produced by `apply_action`), i.e. have both private cards set. The private
    cards must be concrete values because they are read on the host.

    Returns:
      A list with one LeducGameState per card not held by either player, in
      increasing card order. Each is `state` with `public_card` set to
      `[card + 1]`, since 0 is reserved for "not dealt yet".
    """
    # Read the held cards once as Python ints rather than issuing two device
    # comparisons (and host syncs) for every candidate card.
    held_cards = {int(state.private_cards[0]), int(state.private_cards[1])}
    return [
        LeducGameState(action_history=state.action_history,
                       current_chips=state.current_chips,
                       private_cards=state.private_cards,
                       public_card=jnp.array([pc + 1]),
                       turns_this_round=state.turns_this_round,
                       terminal=state.terminal)
        for pc in range(self.total_cards)
        if pc not in held_cards
    ]
  
  def generate_pc_nodes_and_mask(self, state:LeducGameState):
    """Generate one state per public card, including impossible ones.

    The result has a fixed length of `self.total_cards` regardless of the
    dealt private cards, which makes the outcomes easy to batch (e.g. for
    vmap). The accompanying mask marks which entries are real chance
    outcomes. Note that the sampled outcome is NOT prepended; the list
    contains exactly one entry per card.

    Args:
      state: state whose `private_cards` holds the two dealt cards.

    Returns:
      (outcomes, valid):
        outcomes: list of `self.total_cards` LeducGameStates; outcomes[i] is
          `state` with the public card set to card i (stored as i + 1,
          because 0 means "not dealt yet").
        valid: float array of shape (total_cards,), 1.0 where card i is not
          one of the players' private cards and 0.0 where it is.
    """
    p1_card_oh = jax.nn.one_hot(state.private_cards[0], self.total_cards)
    p2_card_oh = jax.nn.one_hot(state.private_cards[1], self.total_cards)
    valid = jnp.ones(self.total_cards) - p1_card_oh - p2_card_oh
    outcomes = [
        LeducGameState(action_history=state.action_history,
                       current_chips=state.current_chips,
                       private_cards=state.private_cards,
                       public_card=jnp.array([pc + 1]),
                       turns_this_round=state.turns_this_round,
                       terminal=state.terminal)
        for pc in range(self.total_cards)
    ]
    return outcomes, valid
       
  
  @functools.partial(jax.jit, static_argnums=(0))
  def initialize_structures(self, key):
    """Deal private cards at random; return the root state and legal masks.

    !!! If you want to vmap this function, send in a vector of PRNG keys !!!

    Args:
      key: PRNG key used to pick one ordered deal of two distinct cards.

    Returns:
      (game_state, legals):
        game_state: root LeducGameState with an empty action history, no
          public card, 1 chip committed by each player, round turn counter 0
          and not terminal.
        legals: (players, num_actions) mask. Player 1 acts first and may call
          or raise (no fold at the start); player 2 only has the invalid
          action.
    """
    num_cards = self.total_cards
    # Every ordered (player 1 card, player 2 card) pair of distinct cards, in
    # lexicographic order: (0, 1), (0, 2), ..., then (1, 0), (1, 2), ...
    deals = jnp.array([[first, second]
                       for first in range(num_cards)
                       for second in range(num_cards)
                       if first != second])
    dealt_cards = jax.random.choice(key, deals, axis=0)

    fold_oh = jax.nn.one_hot(FOLD_ID, self.num_actions)
    opening_legals = jnp.ones(self.num_actions) - self.invalid_action_mask - fold_oh
    legals = jnp.stack([opening_legals, self.invalid_action_mask], axis=0)

    root_state = LeducGameState(
        # [H - 1, A - 1]: the last action is never stored because it always
        # takes the game to a terminal state.
        action_history=jnp.zeros([self.max_turns - 1, self.num_actions - 1]),
        # 0 encodes "public card not dealt yet".
        public_card=jnp.zeros(1, dtype=int),
        private_cards=dealt_cards,
        current_chips=jnp.ones(self.players),
        turns_this_round=jnp.zeros(1, dtype=int),
        terminal=jnp.array(0, dtype=bool),
    )
    return root_state, legals
  
  @functools.partial(jax.jit, static_argnums=(0))
  def get_info(self, game_state:LeducGameState):
    """Encode a game state as flat tensors.

    Returns:
      (state_tensor, p1_iset_tensor, p2_iset_tensor, public_state_tensor):
        public_state_tensor: public-card one-hot (total_cards + 1 slots, slot
          0 = not dealt yet) followed by the flattened action history.
        pN_iset_tensor: player one-hot, that player's private-card one-hot,
          then the public state tensor.
        state_tensor: both private-card one-hots, then the public state tensor.
    """
    public_part = jnp.concatenate(
        [jax.nn.one_hot(game_state.public_card, self.total_cards + 1).ravel(),
         game_state.action_history.ravel()],
        axis=0)
    cards_oh = jax.nn.one_hot(game_state.private_cards, self.total_cards)

    # Player indicators: [1, 0] for player 1 and [0, 1] for player 2.
    first_player_flag = jax.nn.one_hot(0, 2)
    second_player_flag = 1 - first_player_flag

    first_iset = jnp.concatenate([first_player_flag, cards_oh[0], public_part], axis=0)
    second_iset = jnp.concatenate([second_player_flag, cards_oh[1], public_part], axis=0)
    full_state = jnp.concatenate([cards_oh.ravel(), public_part], axis=0)
    return full_state, first_iset, second_iset, public_part
  
  def reconstruct_state_from_isets(self, p1_iset_tensor, p2_iset_tensor):
    """Rebuild the full game state from both players' information-state tensors.

    Both tensors are expected in the layout produced by `get_info`:
    [player one-hot (players) | private-card one-hot (total_cards) |
     public-card one-hot (total_cards + 1, slot 0 = not dealt) |
     flattened action history ((max_turns - 1) x (num_actions - 1))].
    The public part is read from `p1_iset_tensor`; only the private card is
    read from `p2_iset_tensor`.

    The recorded action history is replayed turn by turn to recover each
    player's chips, the round, whose turn it is, terminality and the legal
    actions (using the same rules as `apply_action`).

    Returns:
      (game_state, legals):
        game_state: LeducGameState with integer card/turn fields and float
          chips, matching the dtypes produced by `initialize_structures`.
        legals: (players, num_actions) mask; only the player to act has
          actions other than the invalid one.
    """
    fold_oh = jax.nn.one_hot(FOLD_ID, self.num_actions)
    raise_oh = jax.nn.one_hot(RAISE_ID, self.num_actions)
    invalid_legals = jax.nn.one_hot(INVALID_ID, self.num_actions)

    # Offsets of the segments of an information-state tensor.
    private_start = self.players
    public_start = private_start + self.total_cards
    history_start = public_start + self.total_cards + 1

    def decode_one_hot(segment, ids):
      # Value of `ids` at the hot entry (0 for an all-zero segment), as an int.
      return jnp.sum(segment * ids).astype(int)

    card_ids = jnp.arange(self.total_cards)
    p1_priv_card = decode_one_hot(p1_iset_tensor[private_start:public_start], card_ids)
    p2_priv_card = decode_one_hot(p2_iset_tensor[private_start:public_start], card_ids)
    # 0 while the public card is not dealt yet, otherwise card index + 1.
    public_card = decode_one_hot(p1_iset_tensor[public_start:history_start],
                                 jnp.arange(self.total_cards + 1))
    action_history = jnp.reshape(p1_iset_tensor[history_start:],
                                 (self.max_turns - 1, self.num_actions - 1))
    # History row t is one_hot(action_id - 1), so weighting by the 1-based ids
    # recovers FOLD_ID / CALL_ID / RAISE_ID and maps an empty row to
    # INVALID_ID (0). Weighting by 0-based ids would make fold rows look empty.
    action_ids = jnp.arange(1, self.num_actions)

    def process_turn(info, action_id):
      # Apply one recorded action (FOLD_ID, CALL_ID or RAISE_ID).
      chips = info["chips"]
      acting = info["turns_this_round"] % 2
      max_chips = jnp.max(chips)
      # Acting player's total after fold (unchanged), call (match) or raise.
      options = jnp.stack([
          chips[acting],
          max_chips,
          max_chips + self.raise_amount * (info["round"] + 1),
      ])
      acting_total = jnp.sum(jax.nn.one_hot(action_id - 1, self.num_actions - 1) * options)
      # Only the acting player's contribution changes.
      chips = jnp.where(jnp.arange(self.players) == acting, acting_total, chips)
      # A call that is not the first action of a round closes the round.
      round_over = jnp.logical_and(info["turns_this_round"] > 0, action_id == CALL_ID)
      new_round = jnp.where(round_over, info["round"] + 1, info["round"])
      return {
          "chips": chips,
          "turns_this_round": jnp.where(round_over, 0, info["turns_this_round"] + 1),
          "round": new_round,
          # Raises only accumulate while consecutive; any other action resets.
          "num_raises": jnp.where(action_id == RAISE_ID, info["num_raises"] + 1, 0),
          # Evaluated after this action's chips are applied.
          "bets_equal": chips[0] == chips[1],
          # Closing the second (last) betting round, or folding, ends the game.
          "terminal": jnp.logical_or(new_round == 2, action_id == FOLD_ID),
      }

    def skip_turn(info, action_id):
      # Empty history row: nothing happened at this turn.
      return info

    # A plain dict is a JAX pytree, so it can be threaded through lax.cond.
    info = {
        "chips": jnp.ones(self.players),
        "turns_this_round": jnp.array(0),
        "round": jnp.array(0),
        "num_raises": jnp.array(0),
        "bets_equal": jnp.array(True),
        "terminal": jnp.array(False),
    }
    for turn in range(action_history.shape[0]):
      action_id = decode_one_hot(action_history[turn], action_ids)
      info = jax.lax.cond(action_id == INVALID_ID, skip_turn, process_turn, info, action_id)

    # Legal actions for the player to act, derived from the replayed state.
    state_legals = jnp.ones(self.num_actions) - invalid_legals
    state_legals = jnp.where(info["bets_equal"], state_legals - fold_oh, state_legals)
    state_legals = jnp.where(info["num_raises"] >= self.max_raises_per_round,
                             state_legals - raise_oh, state_legals)
    to_act = info["turns_this_round"] % 2
    player_ids = jnp.arange(self.players)[:, None]
    legals = jnp.where(player_ids == to_act, state_legals[None, :], invalid_legals[None, :])

    game_state = LeducGameState(
        action_history=action_history,
        public_card=jnp.reshape(public_card, (1,)),
        private_cards=jnp.stack([p1_priv_card, p2_priv_card]),
        current_chips=info["chips"],
        turns_this_round=jnp.reshape(info["turns_this_round"], (1,)),
        terminal=info["terminal"],
    )
    return game_state, legals


  
  @functools.partial(jax.jit, static_argnums=(0))
  def apply_action(self, game_state : LeducGameState , key, turn, actions):
    """Advance the game by one player action.

    Args:
      game_state: current LeducGameState.
      key: PRNG key, used only to sample the public card when the first
        betting round closes.
      turn: global turn index (across both rounds); selects the action-history
        row that is written.
      actions: action ids indexed by player. Only the acting player's entry is
        recorded in the history and checked for a fold, but every entry feeds
        the chip computation. The non-acting player is presumably expected to
        submit INVALID_ID, the only action in its legal mask (TODO confirm
        against callers).

    Returns:
      (new_game_state, terminal, reward, new_legals):
        terminal: True if this action ended the game or the state already was
          terminal.
        reward: from the perspective of player index 0, i.e. +loser's chips
          when player 0 wins and -loser's chips when player 1 wins, divided
          by max_bet_amount. It is 0 for non-terminal steps and for
          non-fold ties.
        new_legals: (players, num_actions) mask for the next player to act;
          the other player only has the invalid action.
    """
    oh_actions = jax.nn.one_hot(actions, self.num_actions)
    #action history does not contain the last action as that
    #will always take the game to a terminal state
    # (for turn == max_turns - 1 this one-hot is all zeros, so that action is
    # simply not written into the history)
    oh_turn = jax.nn.one_hot(turn, self.max_turns - 1)
    fold_oh = jax.nn.one_hot(FOLD_ID, self.num_actions)
    raise_oh = jax.nn.one_hot(RAISE_ID, self.num_actions)

    #We assume that player 1 is the starting one
    # turns_this_round has shape (1,) (see initialize_structures), so quantities
    # derived from current_player below keep a leading axis of size 1.
    current_player = game_state.turns_this_round % 2
    max_chips = jnp.max(game_state.current_chips)
    # History rows exclude the invalid action, hence the "- 1" shift.
    oh_valid_action = jax.nn.one_hot(actions[current_player] - 1, self.num_actions - 1)

    #Integer division by 2 places cards into the [J1, J2], [Q1, Q2], [K1, K2] buckets
    player_card_types = jnp.floor_divide(game_state.private_cards, 2)
    # Cards not held by either player are the candidate public cards.
    possible_public_card_mask = 1 - jnp.sum(jax.nn.one_hot(game_state.private_cards, self.total_cards), axis=0)
    possible_public_cards = jnp.flatnonzero(possible_public_card_mask, size=self.total_cards - 2)
    # False in the first betting round (public card not dealt), True in the second.
    round = game_state.public_card > 0


    folded = jnp.any(oh_actions[current_player] * fold_oh)
    # Per-player flag: 1 where that player's submitted action is a raise.
    raised = jnp.sum(oh_actions * raise_oh, axis=1)

    # Showdown: a player pairing the public card wins, otherwise the higher rank.
    tie = jnp.all(jnp.isclose(player_card_types[0], player_card_types[1]))
    card_matched = jnp.any(jnp.floor_divide(game_state.public_card - 1, 2) == player_card_types)
    winner = jnp.where(card_matched, jnp.argmin(jnp.abs(jnp.floor_divide(game_state.public_card - 1, 2) - player_card_types)), jnp.argmax(player_card_types))
    # On a fold the other player wins regardless of cards.
    winner = jnp.where(folded, 1 - current_player, winner)

    # Broadcast the acting player's one-hot action into row `turn` of the history.
    this_turn_played = oh_valid_action[..., None, :] * oh_turn[None, :, None]
    action_history = (game_state.action_history + this_turn_played)[0]

    #Taking advantage of the fact, that raises can only happen after each other
    # (counts the previous turn's raise plus this one; with at most two raises
    # per round this equals the number of raises in the current round)
    num_raises = jnp.where(turn > 0, action_history[turn - 1, RAISE_ID - 1] + raised[current_player], 0)

    # Per player, chips after each action id: invalid/fold keep the current
    # chips, call/raise first match the maximum ...
    action_chips = jnp.concatenate([jnp.repeat(game_state.current_chips[..., None], 2, axis =1), jnp.array([max_chips, max_chips])[..., None] * jnp.ones(2)], axis=1)
    current_chips = jnp.sum(action_chips * oh_actions, axis=1)
    # ... and a raise adds raise_amount, doubled in the second round.
    current_chips = current_chips + (raised * (round + 1) * self.raise_amount)

    bets_equal = jnp.all(jnp.isclose(current_chips[0], current_chips[1]))
    # First round closes when bets are equal after at least two actions.
    play_chance = jnp.logical_and(jnp.logical_and(turn >= 1, round == 0), bets_equal)
    public_card = jnp.where(play_chance, jax.random.choice(key, possible_public_cards) + 1, game_state.public_card)
    #make sure to properly reset to ready for the new round
    # (-1 here becomes 0 after the "+ 1" applied when building the new state)
    turns_this_round = jnp.where(play_chance, -1, game_state.turns_this_round)
    next_player = (turns_this_round + 1) % 2

    new_acting_legals = jnp.ones(self.num_actions) - self.invalid_action_mask
    #whether raise is still possible
    new_acting_legals = jnp.where(num_raises < self.max_raises_per_round, new_acting_legals, new_acting_legals - raise_oh)
    #whether fold is possible
    new_acting_legals = jnp.where(bets_equal, new_acting_legals - fold_oh, new_acting_legals)
    new_legals = jnp.where(next_player == 0, jnp.stack([new_acting_legals, self.invalid_action_mask], axis=0), jnp.stack([self.invalid_action_mask, new_acting_legals], axis=0))



    # Game ends on a fold, or when bets are equal in the second round after
    # at least two actions in that round.
    terminal = jnp.logical_or(folded, jnp.logical_and(jnp.logical_and(round > 0, turns_this_round >= 1), bets_equal))
    terminal = jnp.squeeze(terminal)


    #The division by self.max_bet_amount to make sure the rewards is normalized to [-1, 1] range
    # current_chips[1 - winner] is the loser's commitment; the sign is + when
    # player index 0 wins and - when player index 1 wins.
    reward = jnp.where(terminal, jnp.where(jnp.logical_and(tie, ~folded), 0, ((1 - 2 * winner) * current_chips[1-winner]) / self.max_bet_amount), 0)

    #If terminal state was already reached, mark this as terminal as well
    terminal = jnp.logical_or(game_state.terminal, terminal)

    new_game_state = LeducGameState(action_history=action_history,
                           public_card = public_card,
                           private_cards=game_state.private_cards,
                           current_chips = current_chips,
                           turns_this_round = turns_this_round + 1,
                           terminal = terminal)

    # reward has a leading axis of size 1 (from current_player), hence [0].
    return new_game_state, terminal, reward[0], new_legals