#!/usr/bin/env python3

from collections.abc import Callable
from functools import partial
from typing import Any

import chex
import equinox as eqx
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
from flax.training.train_state import TrainState

# Sentinel card index meaning "no card played". `IIGoofspiel.reset` feeds a
# joint action made only of this value to `step` to reveal the first prize.
ILLEGAL_ACTION: int = -1


class IIGoofspielInfostate(eqx.Module):
    """What a single player knows about an `IIGoofspielState`.

    Built by `IIGoofspielState.infostate`: the same information as the full
    state, except that only the observing player's own hand is kept and the
    action history is dropped.
    """

    hand: jax.Array  # observing player's row of the state's `hands`
    prize_deck: jax.Array  # prize cards not yet revealed (copied from the state)
    winnings: jax.Array  # accumulated winnings of every player
    prize_card: jnp.int32  # prize currently being played for
    round: jnp.uint8  # round counter copied from the state


class IIGoofspielState(eqx.Module):
    """Full state of an imperfect-information Goofspiel game.

    Fields are constructed positionally in the order declared below.
    """

    hands: jax.Array  # per-player 0/1 indicator of cards still held
    prize_deck: jax.Array  # 0/1 indicator of prize cards not yet revealed
    winnings: jax.Array  # accumulated prize value per player
    prize_card: jnp.int32  # prize currently being played for (-1 before the reveal)
    round: jnp.uint8  # number of steps taken so far (0 at the root)
    action_history: jax.Array  # row t: joint action of step t followed by the next prize

    @eqx.filter_jit
    def utility(self) -> jax.Array:
        """Winnings of each player minus the winnings in the mirrored slot.

        For two players this is ``winnings[p] - winnings[1 - p]``.
        """
        mirrored = jnp.flip(self.winnings)
        return self.winnings - mirrored

    @eqx.filter_jit
    def is_terminal(self) -> bool:
        """True once the step counter exceeds the number of cards per hand."""
        num_cards = self.hands.shape[-1]
        return self.round > num_cards

    @eqx.filter_jit
    def legal_actions(self, player_id: int, round: int) -> jax.Array:
        """Indices of the cards ``player_id`` still holds.

        ``round`` fixes the static output size: every player is assumed to
        hold exactly ``num_cards - round + 1`` cards, and missing entries are
        padded with -1.
        """
        num_players, num_cards = self.hands.shape
        per_player = num_cards - round + 1
        _, card_idx = jnp.nonzero(self.hands, size=num_players * per_player, fill_value=-1)
        return jnp.array(card_idx).reshape(num_players, -1)[player_id]

    @eqx.filter_jit
    def legal_actions_k_hot(self, player_id: int) -> jax.Array:
        """0/1 mask over cards, set where ``player_id`` still holds the card."""
        own_hand = self.hands[player_id]
        return own_hand

    @eqx.filter_jit
    def infostate_features(self, player_id: int) -> jax.Array:
        """Flat feature vector describing what ``player_id`` observes.

        Layout: own hand, remaining prize deck, one-hot current prize, all
        players' winnings, one-hot round counter. `jax.nn.one_hot` encodes
        out-of-range values (e.g. a prize of -1) as all zeros.
        """
        num_cards = self.prize_deck.shape[-1]
        parts = [
            self.hands[player_id],
            self.prize_deck,
            jax.nn.one_hot(self.prize_card, num_cards),
            self.winnings,
            jax.nn.one_hot(self.round, num_cards),
        ]
        return jnp.concatenate(parts)

    @eqx.filter_jit
    def infostate(self, player_id: int) -> IIGoofspielInfostate:
        """Project this state onto what ``player_id`` sees (own hand only)."""
        own_hand = self.hands[player_id]
        return IIGoofspielInfostate(
            own_hand, self.prize_deck, self.winnings, self.prize_card, self.round
        )

    def infostate_string(self, player_id: int) -> str:
        """Printable string form of `infostate_features` for ``player_id``."""
        features = self.infostate_features(player_id)
        return jnp.array_str(features)


class IIGoofspielObservation(eqx.Module):
    """Public outcome of one step, as produced by `IIGoofspiel.observation`."""

    # Index of the round winner, or -1 on a tie (see `IIGoofspiel.winner`);
    # annotated as a signed int because -1 is a valid value.
    winner: jnp.int32
    prize_card: jnp.int32  # prize the round was played for
    next_prize_card: jnp.int32  # prize revealed for the next round (-1 once the deck is empty)


class IIGoofspiel(eqx.Module):
    """Imperfect-information Goofspiel with a randomly ordered prize deck.

    Cards and prizes are identified by their index ``0 .. num_cards - 1`` and a
    prize is worth its index. Each round every player plays one card; the
    unique highest card wins the current prize and ties award nothing.

    Row ``t`` of ``action_history`` stores the joint action of step ``t``
    followed by the prize revealed after it. Row 0 is the dummy
    ``ILLEGAL_ACTION`` step issued by `reset` to reveal the first prize.
    """

    num_cards: int
    num_players: int

    def __init__(self, num_players: int = 2, num_cards: int = 5) -> None:
        self.num_players = num_players
        self.num_cards = num_cards

    @eqx.filter_jit
    def initial_state(self, prize_card: int = -1) -> IIGoofspielState:
        """Return the root state before any prize has been revealed.

        Every player holds every card, the prize deck is full, winnings are
        zero and the round counter is 0. The action history has one row per
        step (the reveal step plus ``num_cards`` rounds) and one column per
        player plus one for the revealed prize, all set to -1.
        """
        return IIGoofspielState(
            jnp.ones((self.num_players, self.num_cards), jnp.uint8),
            jnp.ones(self.num_cards, jnp.uint8),
            jnp.zeros(self.num_players, jnp.float32),
            prize_card,
            0,
            # Width must match the `(*actions, prize)` rows written by
            # `next_state`, i.e. one column per player plus the prize.
            jnp.full((self.num_cards + 1, self.num_players + 1), -1, jnp.int32),
        )

    @eqx.filter_jit
    def step(
        self, state: IIGoofspielState, action: Any, key: jax.Array
    ) -> tuple[IIGoofspielState, IIGoofspielObservation, jax.Array, bool]:
        """Play ``action`` (one card index per player) and sample the next prize.

        Returns ``(next_state, observation, rewards, is_terminal)``.
        """
        return self.next_state(state, (*action, self.sample_prize_card(state, key)))

    @eqx.filter_jit
    def sample_prize_card(self, state: IIGoofspielState, key: jax.Array) -> jax.Array:
        """Draw uniformly one prize card still in ``state.prize_deck``.

        Returns -1 when the prize deck is empty.
        """
        mask = state.prize_deck
        # Indices of remaining prizes first, padded with 0 to a static size.
        legal = jnp.where(jnp.reshape(mask, (-1,)), size=mask.shape[0])[0]
        # Count the deck itself so the draw range always matches `legal`,
        # independently of whether `state.round` is in sync with the deck.
        n = jnp.sum(mask, dtype=jnp.int32)
        idx = jax.random.randint(key, shape=(), minval=0, maxval=n)
        return jax.lax.cond(n > 0, lambda: legal[idx], lambda: jnp.int32(-1))

    @eqx.filter_jit
    def next_state(self, state, joint_action):
        """Apply ``joint_action = (*player_actions, next_prize_card)`` to ``state``.

        The next prize must already be chosen (see `step`). Returns
        ``(next_state, observation, rewards, is_terminal)``.
        """
        joint_action = jnp.array(joint_action).astype(jnp.int32)
        action = joint_action[:-1]
        prize_card = joint_action[-1]
        obs = self.observation(state, action, prize_card)
        # The round winner (if any) collects the current prize, worth its
        # index. Computed in the winnings dtype; float16 loses exactness for
        # large prize values.
        step_winnings = state.prize_card * jax.nn.one_hot(
            obs.winner, state.hands.shape[0], dtype=state.winnings.dtype
        )
        # `one_hot(-1)` is all zeros, so a -1 prize / ILLEGAL_ACTION leaves the
        # deck / hands unchanged.
        prize_deck = state.prize_deck - jax.nn.one_hot(
            obs.next_prize_card, state.prize_deck.shape[-1], dtype=jnp.uint8
        )
        hands = state.hands - jax.nn.one_hot(action, state.hands.shape[-1], dtype=jnp.uint8)
        new_state = IIGoofspielState(
            hands,
            prize_deck,
            state.winnings + step_winnings,
            obs.next_prize_card,
            state.round + 1,
            state.action_history.at[state.round].set(joint_action),
        )
        return new_state, obs, self.reward(state, action), new_state.is_terminal()

    @eqx.filter_jit
    def observation(
        self, state: IIGoofspielState, action: jax.Array, next_prize_card: int
    ) -> IIGoofspielObservation:
        """Public outcome of playing ``action`` in ``state``."""
        return IIGoofspielObservation(self.winner(action), state.prize_card, next_prize_card)

    @eqx.filter_jit
    def reward(self, state: IIGoofspielState, action: jax.Array) -> jax.Array:
        """Per-player reward for playing ``action`` in ``state``.

        The winner receives ``+prize`` and every other player receives
        ``-prize / (num_players - 1)``, so rewards sum to zero. On a tie all
        rewards are zero.
        """
        w = self.winner(action)
        p = jax.lax.cond(w >= 0, lambda: state.prize_card, lambda: 0)

        # Other players split negative reward evenly
        n = state.hands.shape[0]
        r = jnp.zeros(n) - 1 / (n - 1)
        # On a tie w == -1 writes the last slot, but p == 0 zeroes everything.
        r = r.at[w].set(1.0)
        return r * p

    @eqx.filter_jit
    def winner(self, action: jax.Array) -> int:
        """Index of the player with the unique highest card, or -1 on a tie."""
        w = jnp.argmax(action)
        max_val = action[w]
        return jax.lax.cond(jnp.sum(action == max_val) == 1, lambda: w, lambda: -1)

    @eqx.filter_jit
    def reset(self, key: jax.Array) -> tuple[IIGoofspielState, IIGoofspielObservation]:
        """Start a game: reveal the first prize via an all-ILLEGAL_ACTION step.

        Returns the resulting state (round 1) and its observation.
        """
        root = self.initial_state(-1)
        s, obs, _, _ = self.step(root, jnp.ones(self.num_players) * ILLEGAL_ACTION, key)
        return s, obs

    @eqx.filter_jit
    def num_actions(self) -> int:
        """Number of distinct cards a player can play."""
        return self.num_cards

    @eqx.filter_jit
    def num_rounds(self) -> int:
        """Number of card-playing rounds in a game."""
        return self.num_cards

    @eqx.filter_jit
    def num_remaining_rounds(self, state: IIGoofspielState) -> int:
        """``num_rounds() - state.round``."""
        return self.num_rounds() - state.round

    @eqx.filter_jit
    def num_infostate_features(self) -> int:
        """Length of `IIGoofspielState.infostate_features`."""
        return self.num_cards * 4 + self.num_players

    @eqx.filter_jit
    def max_utility(self) -> float:
        """Largest achievable two-player utility (own minus opponent winnings).

        Prizes are worth ``0 .. num_cards - 1``. A player cannot win every
        round (both hands sum to the same total), but can win every round
        except the one for prize 0, e.g. by playing ``c + 1`` against ``c`` and
        ``0`` against ``num_cards - 1``. The bound is therefore
        ``1 + 2 + ... + (num_cards - 1)``.
        """
        return jnp.sum(jnp.arange(self.num_cards))

    @eqx.filter_jit
    def construct_mc_initial_state(
        self, state: IIGoofspielState, player_id: int, round: int
    ) -> IIGoofspielState:
        """Rebuild a history consistent with ``player_id``'s observations.

        Only two-player games are supported. For every played step, collect
        the opponent cards compatible with the observed winner: lower than
        ``player_id``'s card if it won, higher if it lost, equal on a tie.
        Assign cards greedily ``round - 1`` times. Each time, pick the step
        with the fewest remaining candidates and give it the candidate card
        compatible with the fewest steps. Then remove that step and that card
        from consideration. ``player_id``'s own cards and the prizes are
        copied from ``state``, and the new history is replayed from the
        initial state.

        NOTE(review): ``round`` presumably equals ``state.round`` (steps
        ``1 .. round - 1`` played) — confirm with callers.
        """
        chex.assert_equal(self.num_players, 2)
        action_hist = state.action_history
        opp_id = (1 - player_id) % self.num_players
        # Reorders (player, opponent, prize) into history column order.
        action_indices = jnp.array([player_id, opp_id, 2])
        new_state = self.initial_state()
        # Replay the reveal step so the first prize matches `state`.
        new_state = self.next_state(new_state, action_hist[0])[0]

        def _win_or_loss_mask(action):
            # 0/1 mask of opponent cards compatible with one step's outcome.
            winner = self.winner(action)
            player_a = action[player_id]
            win_actions = jnp.where(jnp.arange(self.num_actions()) < player_a, 1, 0)
            loss_actions = jnp.where(jnp.arange(self.num_actions()) > player_a, 1, 0)
            return jax.lax.cond(
                winner < 0,
                lambda: jax.nn.one_hot(player_a, self.num_actions(), dtype=jnp.int32),
                lambda: jax.lax.cond(
                    winner == player_id, lambda: win_actions, lambda: loss_actions
                ),
            )

        def _most_constrained_action(win_loss_mask):
            # Take the argmin of rows (timesteps) with greater than 0 assignable actions
            row = jnp.argmin(
                jnp.where(
                    jnp.sum(win_loss_mask, axis=1) > 0,
                    jnp.sum(win_loss_mask, axis=1),
                    jnp.inf,
                )
            )

            # Take the most constrained card for that timestep
            col = jnp.argmin(
                jnp.where(
                    (win_loss_mask[row] * jnp.sum(win_loss_mask, axis=0)) > 0,
                    win_loss_mask[row] * jnp.sum(win_loss_mask, axis=0),
                    jnp.inf,
                )
            )

            # Mask row r corresponds to history row r + 1 (row 0 is the reveal).
            return row + 1, col

        def _step(i, args):
            mask, new_action_hist = args

            # Find most constrained action in history (with more than 0 possible assignments)
            t, opp_a = _most_constrained_action(mask)
            player_a = action_hist[t, player_id]
            prize = action_hist[t, -1]
            a = jnp.take(jnp.array([player_a, opp_a, prize]), action_indices)

            # Assign the action
            new_action_hist = new_action_hist.at[t].set(a)

            # Zero out timestep and assigned opponent action for all other timesteps
            mask = mask.at[:, opp_a].set(jnp.zeros_like(mask[:, opp_a]))
            mask = mask.at[t - 1].set(jnp.zeros_like(mask[t - 1]))

            return mask, new_action_hist

        mask = jnp.array(jax.vmap(_win_or_loss_mask)(action_hist[1:, :-1]))
        new_action_hist = jnp.full_like(action_hist, -1, jnp.int32)

        mask, new_action_hist = jax.lax.fori_loop(1, round, _step, (mask, new_action_hist))

        new_state = jax.lax.fori_loop(
            1, round, lambda i, state: self.next_state(state, new_action_hist[i])[0], new_state
        )

        return new_state

    @eqx.filter_jit
    def get_mc_neighbor(
        self, key, state: IIGoofspielState, player_id: int, round: int
    ) -> tuple[IIGoofspielState, float]:
        """Propose a neighbouring history by changing one opponent card.

        Only two-player games are supported. A step with a winner is chosen
        uniformly. Every card is then tried as the opponent's card there; see
        ``_try_swap``. Only candidates that reproduce the observed winners are
        kept, and one is picked uniformly. Returns ``(new_state, q)`` where
        ``q = 1 / (#consistent candidates * #steps with a winner)`` roughly
        approximates the proposal probability. If no step had a winner,
        returns ``(state, 1.0)``.
        """
        chex.assert_equal(self.num_players, 2)
        action_hist = state.action_history
        opp_id = (1 - player_id) % self.num_players
        # Observed winner of every step after the reveal step.
        winners = jax.vmap(self.winner)(action_hist[1:, :-1])
        win_or_loss = winners >= 0

        def _try_swap(chosen_index, candidate_action):
            # Give the opponent `candidate_action` at `chosen_index`. If that
            # card was played at another step, the two steps swap cards;
            # otherwise the old card returns to the opponent's hand and the
            # candidate leaves it. Returns the per-step winners and new state.
            swapped = action_hist.at[chosen_index, opp_id].set(candidate_action)
            played_index, swap_action = jax.lax.cond(
                jnp.any(candidate_action == action_hist[:, opp_id]),
                lambda: (
                    jnp.argmax(candidate_action == action_hist[:, opp_id]),
                    action_hist[chosen_index, opp_id],
                ),
                # `round + 1` is a sentinel index meaning "not played yet".
                lambda: (round + 1, -1),
            )
            swapped = swapped.at[played_index, opp_id].set(swap_action)
            opp_card_add = jax.lax.cond(
                played_index > round,
                lambda: action_hist[chosen_index, opp_id],
                lambda: -1,
            )
            opp_card_remove = jax.lax.cond(
                played_index > round,
                lambda: candidate_action,
                lambda: -1,
            )
            # uint8 arithmetic wraps, so adding (add - remove) also subtracts
            # the removed card from the hand.
            add_remove_mask = jax.nn.one_hot(
                opp_card_add, self.num_cards, dtype=jnp.uint8
            ) - jax.nn.one_hot(opp_card_remove, self.num_cards, dtype=jnp.uint8)
            hands = state.hands.at[opp_id].set(state.hands[opp_id] + add_remove_mask)
            return jax.vmap(self.winner)(swapped[:, :-1]), IIGoofspielState(
                hands,
                state.prize_deck,
                state.winnings,
                state.prize_card,
                state.round,
                swapped,
            )

        def _choose_neighbor():
            # Split into fresh names: rebinding `key` inside this closure would
            # make it local and raise UnboundLocalError on the first split.
            rng, choice_key = jax.random.split(key, 2)
            choice = jax.random.choice(
                choice_key,
                jnp.arange(1, win_or_loss.shape[0] + 1, dtype=jnp.int32),
                p=win_or_loss / jnp.sum(win_or_loss),
            )
            swapped_winners, states = jax.vmap(_try_swap, in_axes=(None, 0))(
                choice, jnp.arange(self.num_actions(), dtype=jnp.int32)
            )
            # filter out the ones that don't match the obs sequence
            mask = jnp.all(swapped_winners[:, 1:] == winners, axis=1)

            _, choice_key = jax.random.split(rng, 2)
            i = jax.random.choice(
                choice_key, jnp.arange(self.num_actions()), p=mask / jnp.sum(mask)
            )

            new_state = IIGoofspielState(
                states.hands[i],
                states.prize_deck[i],
                states.winnings[i],
                states.prize_card[i],
                states.round[i],
                states.action_history[i],
            )

            # The second output is an approximation of the transition
            # probability, a constant might be better?
            return new_state, 1.0 / (jnp.sum(mask) * jnp.sum(win_or_loss))

        return jax.lax.cond(win_or_loss.any(), _choose_neighbor, lambda: (state, 1.0))


class GoofspielGymEnv(gym.Env):
    """Gymnasium wrapper around two-player `IIGoofspiel`.

    The agent controls player 0. Player 1 is driven by ``opponent``:

    * ``None``: plays a uniformly random legal card.
    * A callable: it is called on a batch of one feature vector. The
      ``.logits`` of the first returned element are masked to legal cards and
      sampled.
    * A flax ``TrainState``: its ``apply_fn`` bound to its ``params`` is used
      as the callable above.

    All randomness (prize draws and opponent sampling) comes from a JAX key
    derived from ``seed``.
    """

    metadata = {'jax': True}

    def __init__(self, num_cards: int, opponent: TrainState | Callable | None, seed: int) -> None:
        super().__init__()

        # Initialize observation and action spaces
        # (4 * num_cards + 2 == IIGoofspiel.num_infostate_features() for 2 players)
        self.observation_space = gym.spaces.MultiBinary(4 * num_cards + 2, seed=seed)
        self.action_space = gym.spaces.Discrete(num_cards, seed=seed)

        # Initialize two-player Goofspiel with `num_cards` cards
        self._env = IIGoofspiel(2, num_cards)
        self._opponent = (
            opponent
            if isinstance(opponent, Callable | None)
            else partial(opponent.apply_fn, opponent.params)
        )

        self._key = jax.random.key(seed)
        self._state = None

    def _require_state(self) -> IIGoofspielState:
        """Return the current state, failing clearly if `reset` was not called."""
        if self._state is None:
            raise RuntimeError('Call reset() before step() or action_masks().')
        return self._state

    def action_masks(self) -> jax.Array:
        """0/1 mask of the cards player 0 may currently play."""
        return self._require_state().legal_actions_k_hot(0)

    def step(self, action: int) -> tuple[jax.Array, float, bool, bool, dict[str, Any]]:
        """Play ``action`` for player 0 while the opponent plays for player 1.

        Returns ``(features, reward of player 0, terminated, False, {})``.
        """
        state = self._require_state()
        opponent_action = self._opponent_predict(
            state.infostate_features(1), state.legal_actions_k_hot(1)
        )

        # Make sure that both players execute valid actions. Comment this out
        # during real training as it decreases the performance considerably.
        # assert action in np.nonzero(self._state.legal_actions_k_hot(0))[0], \
        #     f'{action} not in {np.nonzero(self._state.legal_actions_k_hot(0))[0]}'
        # assert opponent_action in np.nonzero(self._state.legal_actions_k_hot(1))[0], \
        #     f'{opponent_action} not in {np.nonzero(self._state.legal_actions_k_hot(1))[0]}'

        self._key, step_key = jax.random.split(self._key, 2)
        self._state, _, reward, done = self._env.step(
            state, [action, opponent_action], step_key
        )

        return self._state.infostate_features(0), float(reward[0]), done, False, {}

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[jax.Array, dict[str, Any]]:
        """Start a new game, optionally reseeding spaces and the JAX key."""
        if seed is not None:
            self.observation_space.seed(seed)
            self.action_space.seed(seed)
            self._key = jax.random.key(seed)

        self._key, reset_key = jax.random.split(self._key, 2)
        self._state, _ = self._env.reset(reset_key)

        return self._state.infostate_features(0), {}

    def _opponent_predict(self, obs: jax.Array, mask: jax.Array) -> int:
        """Sample player 1's card from ``obs`` restricted to the legal ``mask``."""
        self._key, sample_key = jax.random.split(self._key, 2)
        if self._opponent is None:
            # Uniform over legal cards, drawn from the env's seeded key (not
            # NumPy's unseeded global RNG) so `seed` makes episodes reproducible.
            dist = jnp.where(mask > 0, 0.0, -jnp.inf)
        else:
            dist = self._opponent(obs[jnp.newaxis, :])[0].logits
            dist = jnp.where(mask > 0, dist, -jnp.inf)

        return jax.random.categorical(sample_key, dist)
