from collections.abc import Hashable
from dataclasses import dataclass

from src.searchlight.headers import *



@dataclass(frozen=True)
class HiddenState:
    player_0_played_cards: tuple[int] # cards played by player 0
    player_1_played_cards: tuple[int] # cards played by player 1
    played_prize_cards: tuple[int] # cards played as prizes
    prize_deck: frozenset[int] # cards that are still in the prize deck
    player_0_hand: frozenset[int] # cards in player 0's hand
    player_1_hand: frozenset[int] # cards in player 1's hand
    player_0_cumulative_score: int # player 0's cumulative score
    player_1_cumulative_score: int # player 1's cumulative score
    is_environment_turn: bool # whether it is the environment's turn to draw a prize card or the players' turn to each play a card
    contested_points: int # points that are contested from previous rounds
    player_0_pending_card: Optional[int] = None # pending card played by player 0 (hidden information)
    player_1_pending_card: Optional[int] = None # pending card played by player 1 (hidden information)

    @staticmethod
    def init_from_num_cards(num_cards: int):
        return HiddenState(tuple(), tuple(), tuple(), frozenset(range(1, num_cards+1)), frozenset(range(1, num_cards+1)), frozenset(range(1, num_cards+1)), 0, 0, True, 0)
    
@dataclass(frozen=True)
class ObservedState:
    '''
    A single player's view of the game (an information set).

    Carries the same fields as HiddenState except the two pending cards,
    which are hidden information and therefore have no counterpart here.

    We assume cards in deck and hand are list(range(1, num_cards+1))
    '''
    player_0_played_cards: tuple[int,...] # cards played by player 0
    player_1_played_cards: tuple[int,...] # cards played by player 1
    played_prize_cards: tuple[int,...] # cards played as prizes
    prize_deck: frozenset[int] # cards that are still in the prize deck
    player_0_hand: frozenset[int] # cards in player 0's hand (GOPSInformationFunction passes an empty frozenset here when the observer is player 1)
    player_1_hand: frozenset[int] # cards in player 1's hand (GOPSInformationFunction passes an empty frozenset here when the observer is player 0)
    player_0_cumulative_score: int # player 0's cumulative score
    player_1_cumulative_score: int # player 1's cumulative score
    is_environment_turn: bool # whether it is the environment's turn to draw a prize card or the players' turn to each play a card
    contested_points: int # points that are contested from previous rounds

class GOPSForwardTransitor(ForwardTransitor):
    '''
    Transition function for GOPS.

    A round consists of three transitions: the environment (actor -1) draws
    a prize card, then each player (actors 0 and 1) commits a card from
    their hand. The round is scored when the second player commits.
    '''

    def _transition(self, state: HiddenState, action: int, actor: int) ->tuple[HiddenState, dict[int, float]]:  # type: ignore
        '''
        Applies `action` taken by `actor` to `state`.

        Args:
            state: the current hidden state
            action: the card drawn (environment turn) or played (player turn)
            actor: -1 for the environment, 0 or 1 for the players

        Returns:
            (next state, reward per player). The reward is always 0.0 for
            both players; scores are tracked inside the state.

        Raises:
            AssertionError: on an environment turn, if `action` is not in the
                prize deck or `actor` is not -1
            ValueError: on a player turn, if `actor` is neither 0 nor 1
        '''
        reward = {0: 0.0, 1: 0.0}

        if state.is_environment_turn: # random state
            # make sure the action is in the prize deck
            assert action in state.prize_deck
            assert actor == -1
            # move the drawn card from the prize deck to the played prize cards
            return HiddenState(
                player_0_played_cards=state.player_0_played_cards,
                player_1_played_cards=state.player_1_played_cards,
                played_prize_cards=state.played_prize_cards + (action,),
                prize_deck=state.prize_deck - {action},
                player_0_hand=state.player_0_hand,
                player_1_hand=state.player_1_hand,
                player_0_cumulative_score=state.player_0_cumulative_score,
                player_1_cumulative_score=state.player_1_cumulative_score,
                is_environment_turn=False,
                contested_points=state.contested_points,
            ), reward

        # simultaneous state: the acting player commits a card face down
        hand_0 = state.player_0_hand
        hand_1 = state.player_1_hand
        pending_0 = state.player_0_pending_card
        pending_1 = state.player_1_pending_card
        if actor == 0:
            hand_0 = hand_0 - {action}
            pending_0 = action
        elif actor == 1:
            hand_1 = hand_1 - {action}
            pending_1 = action
        else:
            raise ValueError('Invalid actor: '+str(actor))

        if pending_0 is None or pending_1 is None:
            # Only one player has committed so far. The pending cards MUST be
            # stored in the new state: otherwise the committed card is lost,
            # the round can never be resolved, and the same player would be
            # asked to move again.
            return HiddenState(
                player_0_played_cards=state.player_0_played_cards,
                player_1_played_cards=state.player_1_played_cards,
                played_prize_cards=state.played_prize_cards,
                prize_deck=state.prize_deck,
                player_0_hand=hand_0,
                player_1_hand=hand_1,
                player_0_cumulative_score=state.player_0_cumulative_score,
                player_1_cumulative_score=state.player_1_cumulative_score,
                is_environment_turn=state.is_environment_turn,
                contested_points=state.contested_points,
                player_0_pending_card=pending_0,
                player_1_pending_card=pending_1,
            ), reward

        # both players have committed: resolve the round
        # the pot is the current prize plus anything carried over from ties
        pot = state.contested_points + state.played_prize_cards[-1]
        score_0 = state.player_0_cumulative_score
        score_1 = state.player_1_cumulative_score
        if pending_0 > pending_1:
            score_0 += pot
            pot = 0 # the pot has been claimed, nothing carries over
        elif pending_0 < pending_1:
            score_1 += pot
            pot = 0 # the pot has been claimed, nothing carries over
        # on a tie nobody scores and the pot stays contested for the next round

        # reveal both cards, clear the pending cards (left at their None
        # default) and hand the turn back to the environment
        return HiddenState(
            player_0_played_cards=state.player_0_played_cards + (pending_0,),
            player_1_played_cards=state.player_1_played_cards + (pending_1,),
            played_prize_cards=state.played_prize_cards,
            prize_deck=state.prize_deck,
            player_0_hand=hand_0,
            player_1_hand=hand_1,
            player_0_cumulative_score=score_0,
            player_1_cumulative_score=score_1,
            is_environment_turn=True,
            contested_points=pot,
        ), reward

class GOPSActorActionEnumerator(ActorActionEnumerator):
    '''
    Determines which actor moves next in a GOPS state and which actions are
    available to that actor.
    '''

    def __init__(self, default_player_order: tuple[int, ...] = (0,1)):
        '''
        Args:
            default_player_order: order in which the players are asked to
                commit a card within a round; forwarded to the base class as
                `player_order`
        '''
        super().__init__(player_order=default_player_order)

    def _enumerate(self, state: HiddenState) -> tuple[int | None, frozenset]: # type: ignore
        '''
        Returns (actor, legal actions) for `state`:
          - (None, empty frozenset) once both hands are empty (game over)
          - (-1, prize deck) when the environment has to draw a prize card
          - otherwise the first player in `self.player_order` without a
            pending card, together with that player's hand

        Raises:
            ValueError: if `self.player_order` contains anything but 0 or 1
        '''
        hands_exhausted = not state.player_0_hand and not state.player_1_hand
        if hands_exhausted:
            return None, frozenset()

        if state.is_environment_turn:
            return -1, state.prize_deck

        for candidate in self.player_order:
            if candidate == 0:
                pending, hand = state.player_0_pending_card, state.player_0_hand
            elif candidate == 1:
                pending, hand = state.player_1_pending_card, state.player_1_hand
            else:
                raise ValueError('Invalid actor: '+str(candidate))
            # a player who has not committed a card yet is the one to move
            if pending is None:
                return candidate, hand

        # NOTE(review): every player in the order already has a pending card.
        # This falls through with a bare None rather than the annotated tuple.
        return None

    @staticmethod
    def parse_str_to_action(action_str: str) -> int:
        '''
        Converts the textual form of an action (a card number) to an int.
        '''
        card = int(action_str)
        return card
    
class GOPSInformationFunction(InformationFunction):
    '''
    Maps a hidden GOPS state to what a given player is allowed to observe.
    '''

    def _get_information_set(self, state: HiddenState, actor: int) -> ObservedState: # type: ignore
        '''
        Builds the observation of `state` for player `actor` (0 or 1).

        The observer sees their own hand; the opponent's hand is replaced by
        an empty frozenset. Pending cards are not part of the observation.

        Raises:
            ValueError: if `actor` is neither 0 nor 1
        '''
        if actor == 0:
            visible_hand_0, visible_hand_1 = state.player_0_hand, frozenset()
        elif actor == 1:
            visible_hand_0, visible_hand_1 = frozenset(), state.player_1_hand
        else:
            raise ValueError('Invalid actor: '+str(actor))

        return ObservedState(
            player_0_played_cards=state.player_0_played_cards,
            player_1_played_cards=state.player_1_played_cards,
            played_prize_cards=state.played_prize_cards,
            prize_deck=state.prize_deck,
            player_0_hand=visible_hand_0,
            player_1_hand=visible_hand_1,
            player_0_cumulative_score=state.player_0_cumulative_score,
            player_1_cumulative_score=state.player_1_cumulative_score,
            is_environment_turn=state.is_environment_turn,
            contested_points=state.contested_points,
        )
    
# Ready-made ground-truth components for a 6-card game of GOPS.
ground_truth_models = {
    "transitor": GOPSForwardTransitor(),
    "actor_action_enumerator": GOPSActorActionEnumerator(),
    "start_state": HiddenState.init_from_num_cards(6),
    "information_function": GOPSInformationFunction(),
}