import numpy as np

from evaluator.base import PolicyEvaluator
from games.graphon_mfg import FiniteGraphonMeanFieldGame
from simulator.mean_fields.base import MeanField
from solver.policy.graphon_policy import DiscretizedGraphonFeedbackPolicy
from solver.omd_graphon_solver import compute_values
import pdb


class DiscretizedGraphonEvaluatorFinite(PolicyEvaluator):
    """
    Policy evaluator for a finite graphon mean-field game.

    Values are obtained from ``compute_values`` and summarized at a single
    graphon index, ``alpha = 0.5``.
    """

    def __init__(self, **kwargs):
        """
        Set up the single graphon index used for evaluation.

        Any keyword arguments are accepted but not used.
        """
        super().__init__()
        self.num_alphas = 1
        self.alpha = 0.5
        # One-element array holding the same index as ``self.alpha``.
        self.alphas = np.array([self.alpha])

    def evaluate(self, mfg: FiniteGraphonMeanFieldGame, mu: MeanField, pi: DiscretizedGraphonFeedbackPolicy):
        """
        Compute the expected return of ``pi`` under mean field ``mu``.

        :param mfg: game providing ``initial_state_distribution``
        :param mu: mean field forwarded to ``compute_values``
        :param pi: policy forwarded to ``compute_values``
        :return: dict with keys ``"eval_mean_returns"`` and
            ``"eval_mean_returns_alpha"``
        """
        # Only the second element returned by compute_values is needed here.
        _, value_table = compute_values(mfg, mu, pi)

        # Probabilities of the initial state distribution at self.alpha.
        # NOTE(review): ``.probs.numpy()`` suggests a torch-style distribution
        # object -- confirm against ``dist2``.
        initial_dist = mfg.initial_state_distribution.dist2(self.alpha)
        initial_probs = initial_dist.probs.numpy()

        # Weight the first entry of the value table by those probabilities.
        # NOTE(review): presumably ``value_table[0]`` holds the values at the
        # initial time step -- confirm against ``compute_values``.
        returns_at_alpha = initial_probs @ value_table[0]
        mean_returns = np.mean(returns_at_alpha)

        return {
            "eval_mean_returns": mean_returns,
            "eval_mean_returns_alpha": returns_at_alpha,
        }
