""":mod:``gpugt.games.exploitability` defines exploitability."""

from collections.abc import Callable, Hashable
from dataclasses import dataclass, field
from typing import Generic, TypeVar
from statistics import mean

from gpugt.algorithms.best_response import BestResponse
from gpugt.algorithms.expected_payoffs import ExpectedPayoffs
from gpugt.games.finite_extensive_form_game import FiniteExtensiveFormGame

# Hashable-bound type variables shared by ``Exploitability`` and the
# ``FiniteExtensiveFormGame``, ``ExpectedPayoffs``, and ``BestResponse``
# generics it is parameterized with.
#
# Within this module, ``_V`` and ``_A`` are the two argument types of the
# strategy-profile callable and ``_I`` is the type of a player (the key of
# the per-player dictionaries).
# NOTE(review): presumably ``_V`` is a node, ``_A`` an action, and ``_H``
# an information set or history -- confirm against
# ``FiniteExtensiveFormGame``; this module does not show it.
_V = TypeVar('_V', bound=Hashable)
_H = TypeVar('_H', bound=Hashable)
_A = TypeVar('_A', bound=Hashable)
_I = TypeVar('_I', bound=Hashable)


@dataclass
class Exploitability(Generic[_V, _H, _A, _I]):
    """Exploitability of a strategy profile in an extensive-form game.

    On construction, every player of the game other than nature gets a
    :class:`BestResponse` built against the strategy profile. That
    player's exploitability is the best response's expected payoff at
    the game's initial node minus the player's expected payoff under
    the profile at the same node. The overall exploitability is the
    arithmetic mean of the per-player values.

    All derived attributes are filled in by :meth:`__post_init__`;
    only ``game`` and ``strategy_profile`` are constructor arguments.

    :param game: The finite extensive-form game.
    :param strategy_profile: The strategy profile, a callable taking
                             two arguments and returning a ``float``.
    :raises statistics.StatisticsError: If the game has no player other
                                        than nature, since the mean of
                                        an empty collection is
                                        undefined.
    """

    game: FiniteExtensiveFormGame[_V, _H, _A, _I]
    strategy_profile: Callable[[_V, _A], float]
    # Expected-payoff evaluator for ``strategy_profile``; it is called
    # below as ``expected_payoffs(node, player)``.
    expected_payoffs: ExpectedPayoffs[_V, _H, _A, _I] = field(init=False)
    # Best-response object of each non-nature player, keyed by player.
    best_responses: dict[_I, BestResponse[_V, _H, _A, _I]] = field(
        init=False,
        default_factory=dict,
    )
    # Per-player exploitability, keyed by player.
    exploitabilities: dict[_I, float] = field(init=False, default_factory=dict)
    # Mean of the values stored in ``exploitabilities``.
    exploitability: float = field(init=False)

    def __post_init__(self) -> None:
        """Compute the per-player and the mean exploitability."""
        game = self.game
        profile = self.strategy_profile
        self.expected_payoffs = ExpectedPayoffs(game, profile)

        for player in game.players:
            # Nature is not a strategic player, so it is skipped.
            if player == game.nature:
                continue

            response = BestResponse(game, profile, player)
            self.best_responses[player] = response

            # Value of best-responding versus value of following the
            # profile, both measured at the initial node.
            best_value = response.get_expected_payoff(game.initial_node)
            profile_value = self.expected_payoffs(game.initial_node, player)
            self.exploitabilities[player] = best_value - profile_value

        self.exploitability = mean(self.exploitabilities.values())
