import numpy as np
from scipy.stats import entropy
from scipy.special import softmax
from games.graphon_mfg import FiniteGraphonMeanFieldGame
from simulator.mean_fields.base import MeanField
from solver.policy.array_policy import ArrayPolicy
from solver.omd_solver_array import MDSolver, normalize_at_last_axis
import pdb

def compute_regularized_values(mfg: FiniteGraphonMeanFieldGame,
                               mu: MeanField,
                               pi: ArrayPolicy,
                               sigma: np.ndarray,
                               reg_param: float):
    """
    Evaluate policy ``pi`` under mean field ``mu`` by backward induction.

    Starting from the last time step and moving towards ``t = 0``, the
    recursion is::

        Q[t, x, u] = r(t, x, u, mu) + sum_x' P(x' | t, x, u, mu) * V[t + 1, x']
        V[t, x]    = sum_u pi[t, x, u] * Q[t, x, u]
                     - reg_param * entropy(pi[t, x, :], sigma[t, x, :])

    with no continuation term at the final step. ``entropy`` is
    ``scipy.stats.entropy`` called with two distributions, i.e. the relative
    entropy of the policy with respect to ``sigma`` (SciPy normalizes both
    arguments to sum to one along the last axis).

    :param mfg: game providing ``time_steps``, ``agent_observation_space[1].n``
        (number of states), ``agent_action_space.n`` (number of actions),
        ``reward(t, x, u, mu)`` and ``transition_probs(t, x, u, mu)``.
    :param mu: mean field, passed through unchanged to ``mfg.reward`` and
        ``mfg.transition_probs``.
    :param pi: policy whose ``policy_array[t]`` gives the per-state action
        weights used at time ``t``.
    :param sigma: reference policy, indexed as ``sigma[t]``; only read when
        ``reg_param != 0``. Assumed to have the same layout as
        ``pi.policy_array`` -- TODO confirm against callers.
    :param reg_param: weight of the relative-entropy penalty; ``0`` disables
        the penalty entirely.
    :return: tuple ``(Qs, Vs)`` with shapes
        ``(time_steps, num_states, num_actions)`` and
        ``(time_steps, num_states)``.
    """
    n_states = mfg.agent_observation_space[1].n
    n_actions = mfg.agent_action_space.n
    horizon = mfg.time_steps

    def tabulate(model_fn, t):
        # Query the model for every (state, action) pair, states outermost.
        return np.array([[model_fn(t, x, u, mu) for u in range(n_actions)]
                         for x in range(n_states)])

    action_values = np.zeros((horizon, n_states, n_actions))
    state_values = np.zeros((horizon, n_states))
    next_value = np.zeros(n_states)

    for t in range(horizon - 1, -1, -1):
        reward_table = tabulate(mfg.reward, t)
        transition_table = tabulate(mfg.transition_probs, t)

        # The final step has no successor, so its continuation is masked out.
        keep_future = 0 if t >= horizon - 1 else 1
        expected_next = np.dot(transition_table, next_value)
        action_values[t] = reward_table + keep_future * expected_next

        # Average the Q values over the policy's action weights.
        policy_t = pi.policy_array[t]
        next_value = np.einsum('sa,sa->s', action_values[t], policy_t)
        if reg_param != 0:
            next_value -= reg_param * entropy(policy_t, sigma[t], axis=-1)
        state_values[t] = next_value

    return action_values, state_values

class RMDSolver(MDSolver):
    """
    MD solutions for regularized MFGs.

    Extends ``MDSolver`` with a relative-entropy regularizer of weight
    ``reg_param`` measured against the reference policy ``self.sigma_array``.
    ``self.eta``, ``self.sigma_array``, ``update_pi_in_solver`` and
    ``compute_exp_eta_Qs`` are not defined in this class and are expected to
    come from ``MDSolver``.
    """
    def __init__(self, *args, reg_param=0., **kwargs):
        """
        :param args: positional arguments forwarded to ``MDSolver.__init__``.
        :param reg_param: regularization weight (lambda); ``0`` means no
            regularization.
        :param kwargs: keyword arguments forwarded to ``MDSolver.__init__``.
        :raises ValueError: if ``reg_param * self.eta`` is not ``<= 1``
            (this includes NaN).
        """
        super().__init__(*args, **kwargs)
        self.reg_param = reg_param
        self.lambda_eta = reg_param * self.eta
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        # With lambda_eta > 1 the exponent (1 - lambda_eta) used in ``solve``
        # is negative, which turns zero policy entries into inf.
        if not (self.lambda_eta <= 1):
            raise ValueError(
                "reg_param * eta must be <= 1, got "
                f"{self.lambda_eta!r} (reg_param={reg_param!r}, eta={self.eta!r})"
            )

    def compute_Qs(self,
                   mfg: FiniteGraphonMeanFieldGame,
                   mu: MeanField,
                   pi: ArrayPolicy,
                   **kwargs):
        """
        Return the regularized Q values of ``pi`` under mean field ``mu``.

        Delegates to ``compute_regularized_values`` with
        ``self.sigma_array`` and ``self.reg_param`` and discards the state
        values. Extra keyword arguments are accepted but ignored.
        """
        Qs, _ = compute_regularized_values(mfg,
                                           mu,
                                           pi,
                                           self.sigma_array,
                                           self.reg_param)
        return Qs

    def solve(self,
              mfg: FiniteGraphonMeanFieldGame,
              mu: MeanField,
              pi: ArrayPolicy,
              **kwargs):
        """
        Perform one regularized mirror-descent policy update.

        The new policy is proportional to::

            pi ** (1 - lambda_eta) * sigma ** lambda_eta * compute_exp_eta_Qs(Qs)

        and is normalized over the last axis. ``pi.policy_array`` is replaced
        on the policy object that was passed in, and that same object is
        returned.

        :param mfg: game passed on to ``compute_Qs``.
        :param mu: mean field passed on to ``compute_Qs``.
        :param pi: current policy; its ``policy_array`` is overwritten.
        :param kwargs: accepted for interface compatibility; unused here.
        :return: tuple ``(pi, {"Qs": Qs})``.
        """
        # Inherited hook; presumably refreshes solver state (e.g. the
        # reference policy) from ``pi`` -- confirm in MDSolver.
        self.update_pi_in_solver(pi)
        Qs = self.compute_Qs(mfg,mu,pi)
        # Weighted geometric mix of the current and reference policies,
        # then tilted by the inherited Q-dependent factor.
        likelihood = pi.policy_array ** (1 - self.lambda_eta)
        likelihood *= self.sigma_array ** self.lambda_eta
        likelihood *= self.compute_exp_eta_Qs(Qs)
        pi.policy_array = normalize_at_last_axis(likelihood)
        return pi, {"Qs": Qs}
