import copy
import numpy as np
from games.graphon_mfg import FiniteGraphonMeanFieldGame
from simulator.mean_fields.base import MeanField
from solver.policy.array_policy import ArrayPolicy
from solver.rmd_solver_array import RMDSolver
from solver.amd_solver_array import AMDSolver
from solver.omd_solver_array import normalize_at_last_axis

class ProxRMDSolver(AMDSolver):
    """Proximal RMD solver (Algorithm 5 in the paper).

    Each call to :meth:`solve` runs an inner loop of policy updates against a
    reference policy (sigma) that is held fixed for the whole call, then
    replaces the stored reference policy with the final inner-loop policy.

    NOTE(review): ``sigma_array``, ``sigma_update_time``, ``lambda_eta``,
    ``update_pi_in_solver``, ``compute_Qs`` and ``compute_exp_eta_Qs`` are not
    defined in this class; they are presumably provided by ``AMDSolver`` or
    one of its bases -- confirm.
    """

    def solve(self, mfg, mu, pi, simulator=None, iteration=0, **kwargs):
        """
        Implementation of the Proximal RMD algorithm (Algorithm 5 in the paper).

        For τ=1, this should be equivalent to APP (Algorithm 4).
        For τ>1, this implements the proximal algorithm.

        Args:
            mfg: Mean Field Game
            mu: Mean field used for the first inner step
            pi: Policy; it is copied, and the copy is what gets updated
            simulator: Simulator used to recompute the mean field after each
                inner policy update. Required despite the ``None`` default.
            iteration: Current iteration (not used by this method)
            **kwargs: Additional arguments (not used by this method)

        Returns:
            A tuple ``(pi_new, info)`` where ``pi_new`` is the updated policy
            and ``info`` is ``{"pi_new": pi_new.policy_array}``.

        Raises:
            ValueError: If ``simulator`` is ``None``.
        """
        # Validate first so that a failing call leaves the solver state
        # (pi array, sigma) untouched.
        if simulator is None:
            raise ValueError("simulator must be provided for ProxRMDSolver")

        # Update pi_array from the input policy
        self.update_pi_in_solver(pi)

        # Work on a copy of the policy (π^0 = σ^k in Algorithm 5)
        pi_new = pi.copy()

        # Snapshot the current sigma so the same reference policy is used
        # throughout the inner loop, independent of self.sigma_array.
        sigma_baseline = self.sigma_array.copy()

        # Start from a private copy of the supplied mean field
        mu_current = copy.deepcopy(mu)

        # Run the RMD update sigma_update_time (τ) times
        for _ in range(self.sigma_update_time):
            # Compute Q values based on the current mean field
            Qs = self.compute_Qs(mfg, mu_current, pi_new)

            # RMD update rule, using sigma_baseline (not self.sigma_array):
            #   π_new ∝ π^(1 - λη) · σ^(λη) · compute_exp_eta_Qs(Q)
            # The in-place multiplications keep the dtype of the first factor.
            likelihood = pi_new.policy_array ** (1 - self.lambda_eta)
            likelihood *= sigma_baseline ** self.lambda_eta
            likelihood *= self.compute_exp_eta_Qs(Qs)
            pi_new.policy_array = normalize_at_last_axis(likelihood)

            # Update the mean field after each policy update.
            # NOTE(review): the mean field produced on the final step is not
            # read again inside this method.
            mu_current, _ = simulator.simulate(mfg, pi_new)

        # Always update sigma_array with the final policy (σ^{k+1} ← π^{τ} in Algorithm 5)
        # This is different from APP (Algorithm 4) which only updates sigma when k%τ = 0
        self.sigma_array = pi_new.policy_array.copy()

        # Return the final policy
        return pi_new, {"pi_new": pi_new.policy_array}
