#!/usr/bin/env python3

from functools import partial
from typing import Any, Callable, NamedTuple

import chex
import equinox as eqx
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
from flax.training.train_state import TrainState

# Placeholder action played by every player in the dummy step that
# `IIGoofspiel.reset` takes from the root state to reveal the first prize card.
ILLEGAL_ACTION = -1
# Defaults for the `IIGoofspiel` constructor arguments.
DEFAULT_NUM_CARDS = 5
DEFAULT_NUM_PLAYERS = 2


class IIGoofspielObservation(eqx.Module):
    """Observation produced by one step of II Goofspiel.

    Attributes:
        winner: Index of the player who won the round, or -1 when the
            highest card was tied.
        prize_card: Prize card that was contested in the round.
        next_prize_card: Prize card selected for the following round.
    """

    winner: jnp.uint8  # type: ignore
    prize_card: jnp.int32  # type: ignore
    next_prize_card: jnp.int32  # type: ignore

    def __str__(self):
        """Return the three fields separated by single spaces."""
        fields = (self.winner, self.prize_card, self.next_prize_card)
        return "{} {} {}".format(*fields)


class IIGoofspiel(eqx.Module):
    """Game logic for imperfect-information Goofspiel.

    Each round every player plays one card from their hand. The single
    highest card wins the current prize card; a tie for the highest card
    awards it to nobody. Cards and prizes are identified by their 0-based
    index, and a prize's index is also its value.

    Attributes:
        num_cards: Number of cards in each hand and in the prize deck.
        num_players: Number of players.
    """

    num_cards: int
    num_players: int

    def __init__(
        self, num_players=DEFAULT_NUM_PLAYERS, num_cards=DEFAULT_NUM_CARDS, **kwargs
    ):
        """Store the game size; extra keyword arguments are accepted and ignored."""
        self.num_players = num_players
        self.num_cards = num_cards

    def name(self):
        """Return an identifier encoding the player and card counts."""
        return f"ii_goofspiel_{self.num_players}_{self.num_cards}"

    @eqx.filter_jit
    def max_utility(self):
        """Return the total value of all prize cards.

        Prize values are the card indices ``0 .. num_cards - 1``; a player can
        win every prize except the zero-valued one, so this sum bounds the
        winnings a single player can collect.
        """
        # `arange(num_cards)` covers every prize index. Summing
        # `arange(num_cards - 1)` would drop the most valuable card.
        return jnp.sum(jnp.arange(self.num_cards))

    @eqx.filter_jit
    def step(self, state, action, key):
        """Sample the next prize card and advance the game by one round.

        Args:
            state: Current ``IIGoofspielState``.
            action: Iterable holding one card index per player.
            key: JAX PRNG key used to draw the next prize card.

        Returns:
            The ``(new_state, observation, reward, terminal)`` tuple produced
            by ``next_state``.
        """
        next_prize = self.sample_prize_card(state, key)
        return self.next_state(state, (*action, next_prize))

    @partial(jax.jit, static_argnums=(0,))
    def next_state(self, state, joint_action):
        """Apply a joint action whose last entry is the next prize card.

        Args:
            state: Current ``IIGoofspielState``.
            joint_action: One card index per player followed by the index of
                the prize card for the next round (-1 when none is left).

        Returns:
            ``(new_state, observation, reward, terminal)`` where ``reward`` is
            computed from the pre-step state and ``terminal`` comes from
            ``new_state.is_terminal()``.
        """
        # assumes we have the next prize card already selected
        joint_action = jnp.array(joint_action).astype(jnp.int32)
        action = joint_action[:-1]
        prize_card = joint_action[-1]
        obs = self.observation(state, action, prize_card)
        # A winner of -1 (tie) one-hots to all zeros, so nobody is credited.
        # Use the winnings dtype rather than float16, which cannot represent
        # integers above 2048 exactly.
        step_winnings = state.prize_card * jax.nn.one_hot(
            obs.winner, state.hands.shape[0], dtype=state.winnings.dtype
        )
        # Remove the newly revealed prize card from the deck and the played
        # cards from the hands (a -1 index removes nothing).
        prize_deck = state.prize_deck - jax.nn.one_hot(
            obs.next_prize_card, state.prize_deck.shape[-1], dtype=jnp.uint8
        )
        hands = state.hands - jax.nn.one_hot(
            action, state.hands.shape[-1], dtype=jnp.uint8
        )
        new_state = IIGoofspielState(
            hands,
            prize_deck,
            state.winnings + step_winnings,
            obs.next_prize_card,
            state.round + 1,
            state.action_history.at[state.round].set(joint_action),
        )
        return new_state, obs, self.reward(state, action), new_state.is_terminal()

    @eqx.filter_jit
    def sample_prize_card(self, state, key):
        """Draw a prize card index from the remaining deck, or -1 if it is empty.

        The draw picks position ``idx`` in ``[0, num_rounds - round)`` among the
        indices still set in ``state.prize_deck``; this presumes the deck holds
        exactly that many cards.
        """
        mask = state.prize_deck
        # Fixed-size nonzero lookup: remaining card indices come first,
        # the tail is padded.
        legal = jnp.where(jnp.reshape(mask, -1), size=state.prize_deck.shape[0])[0]
        n = self.num_rounds() - state.round
        idx = jax.random.randint(key, shape=(), minval=0, maxval=n)
        return jax.lax.cond(jnp.sum(mask) > 0, lambda: legal[idx], lambda: -1)

    @eqx.filter_jit
    def reward(self, state, action):
        """Return the per-player reward vector for playing ``action`` in ``state``.

        The winner receives the current prize value and every other player
        receives ``-prize / (n - 1)``. On a tie the prize factor is 0, so the
        whole vector is zero.
        """
        w = self.winner(action)
        p = jax.lax.cond(w >= 0, lambda: state.prize_card, lambda: 0)
        # other players split negative reward evenly
        n = state.hands.shape[0]
        r = jnp.zeros(n) - 1 / (n - 1)
        # With w == -1 this writes the last slot, which is harmless because
        # p is 0 in that case.
        r = r.at[w].set(1.0)
        return r * p

    @eqx.filter_jit
    def winner(self, action):
        """Return the index of the unique highest action, or -1 on a tie."""
        w = jnp.argmax(action)
        max_val = action[w]
        return jax.lax.cond(jnp.sum(action == max_val) == 1, lambda: w, lambda: -1)

    @eqx.filter_jit
    def observation(self, state, action, next_prize_card):
        """Build the observation for playing ``action`` in ``state``."""
        return IIGoofspielObservation(
            self.winner(action), state.prize_card, next_prize_card
        )

    @eqx.filter_jit
    def reset(self, key):
        """Return the first playable state and its observation.

        Starts from the root state (no prize card yet) and takes one dummy
        step with ``ILLEGAL_ACTION`` for every player so that the first prize
        card is drawn using ``key``.
        """
        root = self.initial_state(-1)
        s, obs, _, _ = self.step(root, jnp.ones(self.num_players) * ILLEGAL_ACTION, key)
        return s, obs

    @eqx.filter_jit
    def initial_state(self, prize_card=-1):
        """Return the root state: full hands, full prize deck, zero winnings.

        Args:
            prize_card: Prize card to place in the state (-1 means none yet).
        """
        return IIGoofspielState(
            jnp.ones((self.num_players, self.num_cards), dtype=jnp.uint8),
            jnp.ones(self.num_cards, jnp.uint8),
            jnp.zeros(self.num_players, dtype=jnp.float32),
            prize_card,
            0,
            # One row per round (including the dummy reset step); each row
            # stores one action per player plus the next prize card, filled
            # with -1 until written.
            jnp.zeros((self.num_cards + 1, self.num_players + 1), dtype=jnp.int32) - 1,
        )

    @eqx.filter_jit
    def num_infostate_features(self):
        """Return the length of ``IIGoofspielState.infostate_features``."""
        return self.num_cards * 4 + self.num_players

    @eqx.filter_jit
    def num_actions(self):
        """Return the number of distinct actions (one per card)."""
        return self.num_cards

    @eqx.filter_jit
    def num_rounds(self):
        """Return the number of rounds in a game (one per card)."""
        return self.num_cards

    @eqx.filter_jit
    def num_rounds_remaining(self, state):
        """Return ``num_rounds()`` minus the state's round counter."""
        return self.num_rounds() - state.round


class IIGoofspielInfostate(NamedTuple):
    """Per-player view of an ``IIGoofspielState``.

    Built by ``IIGoofspielState.infostate``: it carries the requesting
    player's own row of the hands array together with the state's prize
    deck, current prize card, winnings and round counter.
    """

    hand: jax.Array  # the requesting player's row of the state's `hands`
    prize_deck: jax.Array
    prize_card: int
    winnings: jax.Array
    round: int


class IIGoofspielState(eqx.Module):
    """Full game state of II Goofspiel.

    Attributes:
        hands: Per-player card mask, indexed as ``hands[player, card]``.
        prize_deck: Mask over the prize cards that have not been revealed.
        winnings: Accumulated prize value per player.
        prize_card: Index of the prize card currently contested.
        round: Round counter, incremented on every step.
        action_history: One row per round holding the joint action taken.
    """

    hands: jnp.array
    prize_deck: jnp.array
    winnings: jnp.array
    prize_card: jnp.int32  # type: ignore
    round: jnp.uint8  # type: ignore
    action_history: jax.Array

    def is_terminal(self):
        """Return whether the round counter has passed the number of cards."""
        num_cards = self.hands.shape[-1]
        return self.round > num_cards

    @eqx.filter_jit
    def utility(self):
        """Return each player's winnings minus the reversed winnings vector."""
        reversed_winnings = jnp.flip(self.winnings)
        return self.winnings - reversed_winnings

    @partial(jax.jit, static_argnums=(2,))
    def legal_actions(self, player_id, round):
        """Return the card indices still held by ``player_id``.

        ``round`` is static because it fixes the output size: every player
        is assumed to hold ``num_cards - round + 1`` cards.
        """
        num_players, num_cards = self.hands.shape
        cards_per_player = num_cards - round + 1
        # Only the column (card) indices of the non-zero entries are needed.
        _, card_indices = jnp.nonzero(
            self.hands, size=num_players * cards_per_player, fill_value=-1
        )
        per_player = jnp.reshape(jnp.array(card_indices), (num_players, -1))
        return per_player[player_id]

    @eqx.filter_jit
    def legal_actions_k_hot(self, player_id):
        """Return ``player_id``'s hand mask, which doubles as the action mask."""
        return self.hands[player_id]

    @eqx.filter_jit
    def infostate_features(self, player_id):
        """Return the flat feature vector visible to ``player_id``.

        Layout: own hand, prize deck, one-hot prize card, winnings, one-hot
        round.
        """
        deck_size = self.prize_deck.shape[-1]
        own_hand = self.hands[player_id]
        prize_one_hot = jax.nn.one_hot(self.prize_card, deck_size)
        round_one_hot = jax.nn.one_hot(self.round, deck_size)
        features = (
            own_hand,
            self.prize_deck,
            prize_one_hot,
            self.winnings,
            round_one_hot,
        )
        return jnp.concatenate(features)

    @eqx.filter_jit
    def infostate(self, player_id):
        """Return the structured information state of ``player_id``."""
        return IIGoofspielInfostate(
            hand=self.hands[player_id],
            prize_deck=self.prize_deck,
            prize_card=self.prize_card,
            winnings=self.winnings,
            round=self.round,
        )

    def infostate_string(self, player_id):
        """Return the feature vector of ``player_id`` rendered as a string."""
        features = self.infostate_features(player_id)
        return jnp.array_str(features)

    def __str__(self):
        """Return a multi-line dump of the state's fields (round excluded)."""
        lines = [
            f"hands:\n{self.hands}",
            f"prize deck: {self.prize_deck}",
            f"prize card: {self.prize_card}",
            f"winnings: {self.winnings}",
            f"action hist: {self.action_history}",
        ]
        return "\n".join(lines)


class GoofspielGymEnv(gym.Env):
    """Gymnasium environment for two-player II Goofspiel.

    The agent controls player 0. Player 1 is driven by ``opponent``: a
    callable policy, a ``TrainState`` whose ``apply_fn`` is bound to its
    ``params``, or ``None`` for a uniformly random legal move.
    """

    metadata = {'jax': True}

    def __init__(self, num_cards: int, opponent: TrainState | Callable | None, seed: int) -> None:
        """Create the environment.

        Args:
            num_cards: Number of cards per hand and in the prize deck.
            opponent: Policy for player 1 (see the class docstring).
            seed: Seed for the spaces and for the environment's JAX PRNG key.
        """
        super().__init__()

        # Initialize observation and action spaces
        # NOTE(review): the feature vector also contains the players' winnings,
        # which are not binary -- confirm MultiBinary is the intended space.
        self.observation_space = gym.spaces.MultiBinary(4 * num_cards + 2, seed=seed)
        self.action_space = gym.spaces.Discrete(num_cards, seed=seed)

        # Initialize two-player Goofspiel with `num_cards` cards
        self._env = IIGoofspiel(2, num_cards)
        if opponent is None or callable(opponent):
            self._opponent = opponent
        else:
            # A TrainState is not callable itself; bind its parameters so the
            # opponent can be invoked with observations only.
            self._opponent = partial(opponent.apply_fn, opponent.params)

        self._key = jax.random.PRNGKey(seed)
        self._state = None

    def action_masks(self) -> jnp.ndarray:
        """Return the agent's (player 0) legal-action mask."""
        return self._state.legal_actions_k_hot(0)

    def step(self, action: int) -> tuple[jnp.ndarray, float, bool, bool, dict[str, Any]]:
        """Play ``action`` for the agent and a sampled action for the opponent.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` from the
            agent's point of view; ``truncated`` is always ``False`` and
            ``info`` is always empty.
        """
        opponent_action = self._opponent_predict(
            self._state.infostate_features(1), self._state.legal_actions_k_hot(1)
        )

        # Make sure that both players execute valid actions. Comment this out
        # during real training as it decreases the performance considerably.
        # assert action in np.nonzero(self._state.legal_actions_k_hot(0))[0], \
        #     f'{action} not in {np.nonzero(self._state.legal_actions_k_hot(0))[0]}'
        # assert opponent_action in np.nonzero(self._state.legal_actions_k_hot(1))[0], \
        #     f'{opponent_action} not in {np.nonzero(self._state.legal_actions_k_hot(1))[0]}'

        self._key, step_key = jax.random.split(self._key, 2)
        # Pass the joint action as one array so the jitted step traces the
        # values instead of receiving Python ints as non-array arguments.
        joint_action = jnp.asarray([action, opponent_action], dtype=jnp.int32)
        self._state, _, reward, done = self._env.step(
            self._state, joint_action, step_key
        )

        # Gymnasium expects plain Python scalars for reward and termination.
        return self._state.infostate_features(0), float(reward[0]), bool(done), False, {}

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[jnp.ndarray,  dict[str, Any]]:
        """Start a new game and return the agent's first observation.

        Args:
            seed: When given, reseeds both spaces and the JAX PRNG key.
            options: Unused.
        """
        if seed is not None:
            self.observation_space.seed(seed)
            self.action_space.seed(seed)
            self._key = jax.random.PRNGKey(seed)

        self._key, reset_key = jax.random.split(self._key, 2)
        self._state, _ = self._env.reset(reset_key)

        return self._state.infostate_features(0), {}

    def _opponent_predict(self, obs: jnp.ndarray, mask: jnp.ndarray) -> jax.Array:
        """Sample the opponent's action among the cards allowed by ``mask``.

        All randomness comes from the environment's own PRNG key, so the
        opponent's moves are reproducible for a given seed.
        """
        self._key, sample_key = jax.random.split(self._key, 2)

        if self._opponent is None:
            # Equal logits on legal cards give a uniform random legal move.
            logits = jnp.where(mask > 0, 0.0, -jnp.inf)
        else:
            # presumably the policy returns a tuple whose first element
            # exposes `.logits` -- verify against the opponent implementation
            logits = self._opponent(obs[jnp.newaxis, :])[0].logits
            logits = jnp.where(mask > 0, logits, -jnp.inf)

        return jax.random.categorical(sample_key, logits)
