import numpy as np
from games.graphon_mfg import FiniteGraphonMeanFieldGame
from simulator.mean_fields.base import MeanField
from solver.policy.array_policy import ArrayPolicy
from solver.rmd_solver_array import RMDSolver, compute_regularized_values
from solver.omd_solver_array import normalize_at_last_axis
import pdb


class AMDSolver(RMDSolver):
    """
    Implementation of Algorithm 4: APP for MFG.

    Builds on ``RMDSolver``'s regularized mirror-descent step and additionally
    refreshes the reference policy ``sigma_array`` with the newest policy
    iterate every ``sigma_update_time`` iterations.

    Constructor args:
        *args, **kwargs: Forwarded unchanged to ``RMDSolver.__init__``.
        sigma_update_time: Refresh period (tau) for ``sigma_array``. A value
            of ``0`` disables refreshing; negative values raise ``ValueError``.
    """

    def __init__(self, *args, sigma_update_time=10., **kwargs):
        # Let the base solver finish its own setup before validating ours.
        super().__init__(*args, **kwargs)
        if sigma_update_time < 0:
            raise ValueError("sigma_update_time must be non-negative")
        self.sigma_update_time = sigma_update_time

    def solve(
        self,
        mfg: FiniteGraphonMeanFieldGame,
        mu: MeanField,
        pi: ArrayPolicy,
        iteration: int,
        **kwargs,
    ):
        """
        Run one iteration of Algorithm 4 (APP for MFG).

        The incoming policy is registered via ``self.update_pi_in_solver`` and
        then copied. Q values are obtained from ``self.compute_Qs`` using
        ``mfg``, ``mu`` and the copy. The copy's ``policy_array`` is replaced
        by ``normalize_at_last_axis`` applied to

            pi ** (1 - lambda_eta) * sigma ** lambda_eta * E,

        where ``E = self.compute_exp_eta_Qs(Qs)``. ``E`` is presumably
        ``exp(eta * Q)``; confirm against ``RMDSolver``.

        Args:
            mfg: Mean Field Game.
            mu: Mean field.
            pi: Current policy; the update is written to ``pi.copy()``.
            iteration: Current iteration index ``k``; decides whether
                ``sigma_array`` is refreshed on this call.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            ``(pi_new, info)`` where ``info`` is ``{"Qs": Qs}``.
        """
        self.update_pi_in_solver(pi)

        candidate = pi.copy()
        q_values = self.compute_Qs(mfg, mu, candidate)

        # Unnormalized update weights. The in-place products keep the dtype
        # of the first factor, and the multiplication order is preserved.
        weights = candidate.policy_array ** (1 - self.lambda_eta)
        weights *= self.sigma_array ** self.lambda_eta
        weights *= self.compute_exp_eta_Qs(q_values)
        candidate.policy_array = normalize_at_last_axis(weights)

        # Algorithm 4: if k % tau == 0, refresh sigma. The positivity test
        # comes first so that tau == 0 never reaches the modulo. Note that
        # iteration 0 satisfies k % tau == 0 and therefore refreshes sigma.
        period = self.sigma_update_time
        refresh_due = period > 0 and iteration % period == 0
        if refresh_due:
            self.sigma_array = candidate.policy_array.copy()

        return candidate, {"Qs": q_values}
