import copy
import numpy as np
from games.graphon_mfg import FiniteGraphonMeanFieldGame
from simulator.mean_fields.base import MeanField
from solver.policy.array_policy import ArrayPolicy
from solver.prox_solver import ProxRMDSolver
from solver.omd_solver_array import normalize_at_last_axis

class AdaptiveProxRMDSolver(ProxRMDSolver):
    """
    Proximal RMD solver whose inner loop can terminate early.

    The solver repeats the RMD policy update up to ``self.sigma_update_time``
    times per call to :meth:`solve`. After every update it measures the
    largest absolute change of any policy entry; once that change drops below
    ``tolerance`` the inner loop stops instead of running the remaining
    iterations.

    Attributes such as ``sigma_array``, ``lambda_eta`` and
    ``sigma_update_time``, and helpers such as ``compute_Qs``, are not defined
    here; they are expected to be provided by ``ProxRMDSolver``.
    """

    def __init__(self, *args, tolerance: float = 1e-6, **kwargs):
        """
        Initialize the AdaptiveProxRMDSolver.

        Args:
            *args: Positional arguments forwarded unchanged to the parent class.
            tolerance: Early-stopping threshold on the maximum absolute
                element-wise policy change between two consecutive inner
                iterations (default: 1e-6).
            **kwargs: Keyword arguments forwarded unchanged to the parent class.
        """
        super().__init__(*args, **kwargs)
        self.tolerance = tolerance

    def solve(self, mfg, mu, pi, simulator=None, iteration: int = 0, **kwargs):
        """
        Run one outer step of the Adaptive Proximal RMD algorithm.

        Starting from ``pi``, the policy is updated repeatedly against a fixed
        baseline (a snapshot of ``self.sigma_array`` taken at the start of the
        call), and the mean field is re-simulated after every update. The loop
        ends after ``self.sigma_update_time`` iterations or as soon as the
        policy change falls below ``self.tolerance``. Finally
        ``self.sigma_array`` is overwritten with the resulting policy array.

        Args:
            mfg: Mean Field Game, passed through to ``compute_Qs`` and to the
                simulator.
            mu: Initial mean field; it is deep-copied and used for the first
                Q-value computation only.
            pi: Starting policy. It must offer ``copy()`` and a
                ``policy_array`` attribute; the caller's object is not
                reassigned here (only the copy is).
            simulator: Object whose ``simulate(mfg, policy)`` returns a pair
                with the induced mean field as first element. Required.
            iteration: Current outer iteration. Not read by this method;
                presumably kept for signature compatibility with other solvers.
            **kwargs: Additional arguments; ignored by this method.

        Returns:
            A tuple ``(pi_new, info)`` where ``pi_new`` is the updated policy
            and ``info`` is a dict with the keys ``"pi_new"`` (the final policy
            array), ``"actual_iterations"`` (inner iterations executed) and
            ``"early_stopped"`` (True when fewer than
            ``self.sigma_update_time`` iterations were executed; note this is
            False if the tolerance is first met on the last allowed iteration).

        Raises:
            ValueError: If ``simulator`` is None.
        """
        # Validate before touching any solver state, so that a call which
        # cannot succeed leaves the solver exactly as it was.
        if simulator is None:
            raise ValueError("simulator must be provided for AdaptiveProxRMDSolver")

        # Update pi_array from the input policy
        self.update_pi_in_solver(pi)

        # Work on a copy of the policy (π^0 = σ^k in Algorithm 5)
        pi_new = pi.copy()

        # Snapshot sigma so the same baseline is used for every inner
        # iteration of this call.
        sigma_baseline = self.sigma_array.copy()

        # Private copy of the mean field; rebound after every simulation below.
        mu_current = copy.deepcopy(mu)

        # Number of inner iterations actually executed
        actual_iterations = 0

        # Run the RMD update sigma_update_time (τ) times or until convergence
        for _step in range(self.sigma_update_time):
            # Keep the pre-update policy for the convergence check
            pi_prev = pi_new.policy_array.copy()

            # Q values under the current mean field and policy
            Qs = self.compute_Qs(mfg, mu_current, pi_new)

            # RMD update: unnormalised weights are
            #   π^(1 - λη) * σ_baseline^(λη) * compute_exp_eta_Qs(Qs)
            # using the snapshot sigma_baseline rather than self.sigma_array.
            likelihood = pi_new.policy_array ** (1 - self.lambda_eta)
            likelihood *= sigma_baseline ** self.lambda_eta
            likelihood *= self.compute_exp_eta_Qs(Qs)
            pi_new.policy_array = normalize_at_last_axis(likelihood)

            # Mean field induced by the updated policy; consumed by the next
            # iteration's compute_Qs call.
            # NOTE(review): the mean field simulated on the last executed
            # iteration is never read again. The call is kept so the number of
            # simulate() invocations is unchanged in case it has side effects.
            mu_current, _ = simulator.simulate(mfg, pi_new)

            actual_iterations += 1

            # Largest absolute element-wise change produced by this update
            policy_diff = np.max(np.abs(pi_new.policy_array - pi_prev))

            # Stop early once the update is below the tolerance
            if policy_diff < self.tolerance:
                break

        # Always update sigma_array with the final policy (σ^{k+1} ← π^{τ} in Algorithm 5)
        self.sigma_array = pi_new.policy_array.copy()

        # Return the final policy and info including the actual number of iterations
        return pi_new, {
            "pi_new": pi_new.policy_array,
            "actual_iterations": actual_iterations,
            "early_stopped": actual_iterations < self.sigma_update_time
        }
