import numpy as np
from scipy.special import softmax
from games.graphon_mfg import FiniteGraphonMeanFieldGame
from simulator.mean_fields.base import MeanField
from solver.omd_graphon_solver import DiscretizedGraphonExactOMDSolverFinite
from solver.policy.array_policy import ArrayPolicy
import pdb


def normalize_at_last_axis(x):
    """Rescale ``x`` so that its entries along the last axis sum to one.

    ``x`` must expose a NumPy-style ``sum(axis=..., keepdims=...)`` method.
    The totals keep a trailing axis of length one so they broadcast against
    ``x``. There is no guard against a zero total: such a slice is divided
    by zero.
    """
    last_axis_totals = x.sum(axis=-1, keepdims=True)
    return x / last_axis_totals

class FTRLSolver(DiscretizedGraphonExactOMDSolverFinite):
    """
    FTRL-style policy update built on the discretized graphon OMD solver.

    Every call to ``solve`` adds ``eta * Qs`` to the running total ``self.y``
    and replaces the policy with a temperature-scaled softmax of that total.
    ``self.y``, ``self.eta``, ``self.temperature_param`` and
    ``self.sigma_array`` are not defined in this class; they are expected to
    be set up elsewhere (presumably by the base class -- TODO confirm).
    """

    def update_pi_in_solver(self, pi):
        """
        Store a copy of ``pi.policy_array`` in ``self.pi_array``.

        ``self.sigma_array`` is filled with its own copy only while it is
        still ``None``; afterwards it is left untouched.
        """
        self.pi_array = pi.policy_array.copy()
        if self.sigma_array is not None:
            return
        self.sigma_array = self.pi_array.copy()

    def solve(self,
              mfg: FiniteGraphonMeanFieldGame,
              mu: MeanField,
              pi: ArrayPolicy,
              **kwargs):
        """
        Run one FTRL step and write the new policy into ``pi`` in place.

        :param mfg: game forwarded to ``compute_Qs``
        :param mu: mean field forwarded to ``compute_Qs``
        :param pi: policy whose ``policy_array`` is read and then overwritten
        :param kwargs: accepted but not used here
        :return: tuple ``(pi, {"Q": self.y})``; note that the ``"Q"`` entry
            is the accumulated total, not the Q-values of this step alone
        """
        self.update_pi_in_solver(pi)
        q_values = self.compute_Qs(mfg, mu, pi)

        # FTRL: accumulate the scaled Q-values over successive calls.
        scaled_step = self.eta * q_values
        self.y = scaled_step if self.y is None else self.y + scaled_step

        # squeeze() drops length-one axes before the softmax over the last axis.
        logits = self.y.squeeze() / self.temperature_param
        pi.policy_array = softmax(logits, axis=-1)

        return pi, {"Q": self.y}

class MDSolver(FTRLSolver):
    """
    Multiplicative policy update (presumably mirror descent, going by the
    class name): the current policy is reweighted by ``exp(eta * Qs)`` and
    renormalized along the last axis.
    """

    def compute_exp_eta_Qs(self, Qs):
        """Return the element-wise exponential of ``self.eta * Qs``."""
        # NOTE(review): np.exp can overflow for large eta * Qs -- confirm the
        # value range used by callers.
        scaled = self.eta * Qs
        return np.exp(scaled)

    def solve(self,
              mfg: FiniteGraphonMeanFieldGame,
              mu: MeanField,
              pi: ArrayPolicy,
              **kwargs):
        """
        Run one multiplicative update and write the new policy into ``pi``.

        :param mfg: game forwarded to ``compute_Qs``
        :param mu: mean field forwarded to ``compute_Qs``
        :param pi: policy whose ``policy_array`` is read and then overwritten
        :param kwargs: accepted but not used here
        :return: tuple ``(pi, {"Qs": Qs})`` with the Q-values of this step
        """
        self.update_pi_in_solver(pi)
        q_values = self.compute_Qs(mfg, mu, pi)

        # Unnormalized weights: current policy times exp(eta * Q).
        weights = self.compute_exp_eta_Qs(q_values)
        unnormalized = pi.policy_array * weights
        pi.policy_array = normalize_at_last_axis(unnormalized)

        return pi, {"Qs": q_values}
