import numpy as np

from games.graphon_mfg import FiniteGraphonMeanFieldGame
from simulator.mean_fields.base import MeanField
from solver.base import Solver
from solver.policy.finite_policy import FiniteFeedbackPolicy, QSoftMaxPolicy
from solver.policy.graphon_policy import DiscretizedGraphonFeedbackPolicy
import pdb

def compute_values(mfg: FiniteGraphonMeanFieldGame,
                   mu: MeanField,
                   pi: FiniteFeedbackPolicy,
                   alpha=.5):
    """
    Evaluate a feedback policy by backward induction over the time horizon.

    Walking from the last time step to the first, the state-action values are
    ``Q[t, x, u] = reward(t, x, u, mu) + transition_probs(t, x, u, mu) . V[t + 1]``
    (the second term is multiplied by zero on the final step), and the state
    values are ``V[t, x] = sum_u Q[t, x, u] * pi.pmf(t, x)[u]``.

    :param mfg: game supplying ``time_steps``, ``reward``, ``transition_probs``
        and the state / action counts (``agent_observation_space[1].n`` and
        ``agent_action_space.n``)
    :param mu: mean field, passed through unchanged to ``mfg.reward`` and
        ``mfg.transition_probs``
    :param pi: policy queried through ``pi.pmf(t, x)``
    :param alpha: accepted for interface compatibility; not read here
    :return: tuple ``(Qs, Vs)`` of arrays shaped
        ``(time_steps, num_states, num_actions)`` and ``(time_steps, num_states)``
    """
    n_states = mfg.agent_observation_space[1].n
    n_actions = mfg.agent_action_space.n
    horizon = mfg.time_steps

    state_values = np.zeros((horizon, n_states))
    action_values = np.zeros((horizon, n_states, n_actions))
    next_v = np.zeros(n_states)

    def tabulate(fn, t):
        # Evaluate fn(t, x, u, mu) on the whole state-action grid, states
        # along the first axis and actions along the second.
        return np.array([[fn(t, x, u, mu) for u in range(n_actions)]
                         for x in range(n_states)])

    for t in range(horizon - 1, -1, -1):
        # Weight of the continuation term: 0 on the final step, 1 before it.
        keep_future = 0 if t >= horizon - 1 else 1

        immediate = tabulate(mfg.reward, t)
        kernel = tabulate(mfg.transition_probs, t)

        # The dot product contracts the last axis of the kernel with the
        # values of the following time step.
        action_values[t] = immediate + keep_future * kernel.dot(next_v)

        # Row x holds the action distribution the policy uses in state x.
        policy_t = np.array([pi.pmf(t, x) for x in range(n_states)])

        # Policy-weighted average of Q over actions, per state.
        next_v = np.einsum('sa,sa->s', action_values[t], policy_t)
        state_values[t] = next_v

    return action_values, state_values

class DiscretizedGraphonExactOMDSolverFinite(Solver):
    """
    Exact OMD solutions for finite state spaces

    Every call to :meth:`solve` evaluates the supplied policy exactly with
    ``compute_values``, adds ``eta * Q`` to a running sum ``y`` kept on the
    solver, and returns a ``QSoftMaxPolicy`` built from that sum, wrapped in a
    ``DiscretizedGraphonFeedbackPolicy``. The solver is therefore stateful:
    ``y``, ``pi_array`` and ``sigma_array`` persist between calls.
    """

    def __init__(self, eta=1, **kwargs):
        """
        :param eta: factor applied to the Q-values before they are added to
            the running sum ``y``
        :param kwargs: forwarded unchanged to ``Solver.__init__``
        """
        super().__init__(**kwargs)
        # A single fixed graphon index; these three attributes are only set
        # here and are not read by the methods of this class.
        self.num_alphas = 1
        self.alpha = .5
        self.alphas = np.array([self.alpha])
        # Running sum of eta * Q over all solve() calls (None until the first).
        self.y = None
        # Its inverse is handed to QSoftMaxPolicy in solve().
        self.temperature_param = 1.
        self.eta = eta
        # Tabulated pmf of the first policy ever seen by compute_Qs, indexed
        # [t, x]; filled lazily.
        self.pi_array = None
        # Copy of pi_array taken when it is first needed; only the commented
        # MD variant in solve() refers to it.
        self.sigma_array = None

    def compute_Qs(self,
                   mfg: FiniteGraphonMeanFieldGame,
                   mu: MeanField,
                   pi: FiniteFeedbackPolicy,):
        """
        Return the state-action values of ``pi`` under the mean field ``mu``.

        On the first call this also caches ``pi``'s pmf table in
        ``self.pi_array`` and a copy of it in ``self.sigma_array``; later
        calls leave both caches untouched.

        :param mfg: game to evaluate on
        :param mu: mean field passed through to ``compute_values``
        :param pi: policy to evaluate
        :return: the ``Qs`` array produced by ``compute_values``
        """
        if self.pi_array is None:
            num_states = mfg.agent_observation_space[1].n
            self.pi_array = np.array([[pi.pmf(t, x)
                                       for x in range(num_states)]
                                      for t in range(mfg.time_steps)])
        if self.sigma_array is None:
            # Store under the attribute the guard checks. Writing to a
            # different name left sigma_array permanently None and repeated
            # the copy on every call.
            self.sigma_array = self.pi_array.copy()
        Qs, _ = compute_values(mfg, mu, pi)
        return Qs

    def solve(self,
              mfg: FiniteGraphonMeanFieldGame,
              mu: MeanField,
              pi: FiniteFeedbackPolicy,
              **kwargs):
        """
        Perform one FTRL update and return the resulting policy.

        :param mfg: game to solve
        :param mu: mean field the current policy is evaluated against
        :param pi: current policy
        :param kwargs: accepted and ignored
        :return: tuple ``(policy, info)`` where ``info["Q"]`` is the
            accumulated sum ``self.y`` (not the Q-values of this call alone)
        """

        Qs = self.compute_Qs(mfg, mu, pi)

        # FTRL implementation: accumulate the scaled Q-values across calls.
        if self.y is None:
            self.y = self.eta * Qs
        else:
            self.y = self.y + self.eta * Qs

        # NOTE(review): squeeze() drops every length-1 axis, so a game with a
        # single time step, state or action changes the rank of the array
        # passed on -- confirm QSoftMaxPolicy tolerates that.
        qsoftmax = QSoftMaxPolicy(
            mfg.agent_observation_space,
            mfg.agent_action_space,
            self.y.squeeze(),
            1 / self.temperature_param,
            policy_array=self.pi_array)

        policy = DiscretizedGraphonFeedbackPolicy(mfg.agent_observation_space,
                                                  mfg.agent_action_space,
                                                  qsoftmax)

        # MD implementation
        # etaQs = self.eta * Qs
        # policy = ModifiedSoftmaxPolicy(mfg.agent_observation_space,
        #                                mfg.agent_action_space,
        #                                etaQs,
        #                                1 / self.temperature_param,
        #                                self.pi_array,
        #                                self.sigma_array)

        return policy, {"Q": self.y}
