import numpy as np

from games.graphon_mfg import FiniteGraphonMeanFieldGame
from simulator.mean_fields.base import MeanField
from solver.omd_graphon_solver import DiscretizedGraphonExactOMDSolverFinite, compute_values
from solver.policy.finite_policy import FiniteFeedbackPolicy, QMaxPolicy, QSoftMaxPolicy
from solver.policy.graphon_policy import DiscretizedGraphonFeedbackPolicy


class DiscretizedGraphonExactSolverFinite(DiscretizedGraphonExactOMDSolverFinite):
    """
    Solver for finite graphon mean field games that builds its policy directly
    from the Q-values returned by ``compute_values``.

    With ``eta == 0`` (the default here) the policy is a ``QMaxPolicy``;
    with any other ``eta`` it is a ``QSoftMaxPolicy``.
    """

    def __init__(self, eta=0, **kwargs):
        """
        Forward ``eta`` and all remaining keyword arguments to the parent solver.

        :param eta: selects the policy type in ``get_policy``; defaults to 0.
        :param kwargs: passed unchanged to the parent constructor.
        """
        # NOTE(review): get_policy reads self.eta, which is presumably stored
        # by the parent constructor -- confirm against the base class.
        super().__init__(eta=eta, **kwargs)

    def get_policy(self,
                   mfg: FiniteGraphonMeanFieldGame,
                   Qs: np.ndarray):
        """
        Build the per-agent policy induced by the given Q-values.

        :param mfg: game supplying the agent observation and action spaces.
        :param Qs: Q-values handed to the policy constructor.
        :return: a ``QSoftMaxPolicy`` when ``self.eta`` is non-zero,
                 otherwise a ``QMaxPolicy``.
        """
        if self.eta != 0:
            # The non-zero check above is what keeps this division safe.
            softmax_policy = QSoftMaxPolicy(
                mfg.agent_observation_space,
                mfg.agent_action_space,
                Qs,
                1 / self.eta,
            )
            return softmax_policy

        # eta == 0: fall back to the max-based policy.
        max_policy = QMaxPolicy(
            mfg.agent_observation_space,
            mfg.agent_action_space,
            Qs,
        )
        return max_policy

    def solve(self,
              mfg: FiniteGraphonMeanFieldGame,
              mu: MeanField,
              pi: FiniteFeedbackPolicy,
              **kwargs):
        """
        Compute Q-values for ``(mfg, mu, pi)`` and wrap the resulting policy.

        :param mfg: the finite graphon mean field game.
        :param mu: mean field passed to ``compute_values``.
        :param pi: policy passed to ``compute_values``.
        :param kwargs: accepted but not used by this method.
        :return: tuple of the ``DiscretizedGraphonFeedbackPolicy`` and an info
                 dict holding the Q-values under the key ``"Q"``.
        """
        # Only the first of the two values returned by compute_values is needed.
        q_values, _ = compute_values(mfg, mu, pi)

        observation_space = mfg.agent_observation_space
        action_space = mfg.agent_action_space
        inner_policy = self.get_policy(mfg, q_values)

        graphon_policy = DiscretizedGraphonFeedbackPolicy(
            observation_space,
            action_space,
            inner_policy,
        )
        info = {"Q": q_values}
        return graphon_policy, info
