# -*- coding: utf-8 -*-
import sys
import argparse
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
from experiments import args_parser
import numpy as np
import matplotlib.pyplot as plt
import time
from solver.policy.array_policy import ArrayPolicy
from solver.policy.dirichlet_policy import DirichletArrayPolicy
from solver.rmd_solver_array import RMDSolver
from solver.prox_solver import ProxRMDSolver
from solver.amd_solver_array import AMDSolver
from games.finite.beach import BeachGraphon

def parse_args(argv=None):
    """Parse command-line options for the Beach Graphon solver comparison.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) parses ``sys.argv[1:]``, matching the previous
            behaviour; passing a list makes the parser usable from tests.

    Returns:
        argparse.Namespace with ``iterations`` (int) and ``tolerance`` (float).
    """
    parser = argparse.ArgumentParser(
        description="Test RMD, APP (AMD) and proximal RMD solvers on Beach Graphon MFG"
    )
    parser.add_argument('--iterations', type=int, default=100, help='Number of iterations (default: 100)')
    # The tolerance is accepted for CLI compatibility only: none of the solvers
    # configured in this script performs early stopping.
    parser.add_argument(
        '--tolerance',
        type=float,
        default=1e-4,
        help='Early-stopping tolerance (default: 1e-4); currently unused because no adaptive solver is configured',
    )
    return parser.parse_args(argv)

def _build_solver(solver_type, solver_kwargs):
    """Instantiate the solver class registered for ``solver_type``.

    Raises:
        ValueError: if ``solver_type`` is not "rmd_array", "amd_array" or "prox_rmd".
    """
    solver_classes = {
        "rmd_array": RMDSolver,
        "amd_array": AMDSolver,
        "prox_rmd": ProxRMDSolver,
    }
    try:
        solver_cls = solver_classes[solver_type]
    except KeyError:
        raise ValueError(f"Unsupported solver type: {solver_type}") from None
    return solver_cls(**solver_kwargs)


def _plot_comparison(curve_specs, solver_configs, all_results, title, output_stem):
    """Plot exploitability curves on a log y-axis and save them as PDF and PNG.

    Args:
        curve_specs: Sequence of ``(solver_type, sigma_update_time, style)``.
            Every config whose ``"solver"`` equals ``solver_type`` and whose
            ``"sigma_update_time"`` equals ``sigma_update_time`` (or any value
            when it is ``None``) is plotted with ``style``, in spec order.
        solver_configs: The experiment configuration dicts.
        all_results: Mapping from config display name to its per-iteration
            exploitability list.
        title: Figure title (LaTeX-rendered).
        output_stem: Path without extension; ``.pdf`` and ``.png`` are appended.
    """
    fig = plt.figure(figsize=(18, 12))
    for solver_type, sigma_update_time, style in curve_specs:
        for cfg in solver_configs:
            if cfg["solver"] != solver_type:
                continue
            if sigma_update_time is not None and cfg["sigma_update_time"] != sigma_update_time:
                continue
            curve = all_results[cfg["name"]]
            plt.plot(range(len(curve)), curve, label=cfg["name"], **style)

    plt.xlabel(r'Iterations', fontsize=35)
    plt.ylabel(r'Exploitability', fontsize=35)
    plt.title(title, fontsize=40)
    plt.legend(fontsize=30)
    plt.grid(True, alpha=0.3)
    plt.yscale('log')  # exploitability spans several orders of magnitude
    plt.tight_layout()

    plt.savefig(output_stem.with_suffix(".pdf"))
    plt.savefig(output_stem.with_suffix(".png"), dpi=300)
    # Release the figure; otherwise every call keeps a live figure in pyplot's registry.
    plt.close(fig)


def test_beach_graphon_simplified(iterations=100, tolerance=1e-4):
    """Compare RMD, APP (AMD) and proximal-RMD solvers on the Beach Graphon MFG.

    Every solver configuration starts from the same Dirichlet(alpha=0.5)
    random policy and its simulated mean field. The per-iteration
    exploitability ("Nash Conv.") curves are plotted into two figures saved
    under ``./results/`` (created if missing).

    Args:
        iterations: Number of outer solver iterations per configuration.
        tolerance: Accepted for CLI compatibility; none of the solvers
            configured here uses it.
    """
    print(f"Testing Beach Graphon MFG with RMD and AMD solvers for {iterations} iterations...")
    print("Game parameters: horizon=3")

    # Set random seed for reproducibility
    np.random.seed(0)

    # Define solver configurations
    solver_configs = [
        {
            "name": r"RMD ($\eta=0.1$, $\lambda=0.1$)",
            "solver": "rmd_array",
            "reg_param": 0.1,
            "sigma_update_time": 10,
            "eta": 0.1
        },
        {
            "name": r"APP ($\eta=0.1$, $\lambda=0.1$, $\tau=1$)",
            "solver": "amd_array",
            "reg_param": 0.1,
            "sigma_update_time": 1,
            "eta": 0.1
        },
        {
            "name": r"APP ($\eta=0.1$, $\lambda=0.1$, $\tau=2$)",
            "solver": "amd_array",
            "reg_param": 0.1,
            "sigma_update_time": 2,
            "eta": 0.1
        },
        {
            "name": r"Proximal ($\eta=0.1$, $\lambda=0.1$, $\tau=1$)",
            "solver": "prox_rmd",
            "reg_param": 0.1,
            "sigma_update_time": 1,
            "eta": 0.1
        },
        {
            "name": r"Proximal ($\eta=0.1$, $\lambda=0.1$, $\tau=2$)",
            "solver": "prox_rmd",
            "reg_param": 0.1,
            "sigma_update_time": 2,
            "eta": 0.1
        },
        {
            "name": r"Proximal ($\eta=0.1$, $\lambda=0.1$, $\tau=4$)",
            "solver": "prox_rmd",
            "reg_param": 0.1,
            "sigma_update_time": 4,
            "eta": 0.1
        },
        {
            "name": r"Proximal ($\eta=0.1$, $\lambda=0.1$, $\tau=10$)",
            "solver": "prox_rmd",
            "reg_param": 0.1,
            "sigma_update_time": 10,
            "eta": 0.1
        },
    ]

    # Display name -> per-iteration exploitability list
    all_results = {}

    # Base configuration: only the game/simulator/evaluator/eval_solver
    # entries are used below; each solver is built from its own config.
    base_config = args_parser.generate_config_from_kw(
        game="Beach-Graphon",
        graphon="power",
        solver="rmd_array",
        simulator="exact",
        evaluator="exact",
        eval_solver="exact",
        iterations=iterations,
        total_iterations=50,
        eta=0.1,
        reg_param=0.0,
        sigma_update_time=10,
        results_dir="./results/",
        exp_name="beach_graphon_simplified",
        env_params={"time_steps": 3},  # H=3
        verbose=0,
        seed=0
    )

    game = base_config["game"](**base_config["game_config"])
    simulator = base_config["simulator"](**base_config["simulator_config"])
    evaluator = base_config["evaluator"](**base_config["evaluator_config"])
    eval_solver = base_config["eval_solver"](**base_config["eval_solver_config"])

    # One shared initial policy so that all runs start from the same point.
    print("Creating initial policy with Dirichlet distribution (alpha=0.5)...")
    np.random.seed(0)  # reseed so the policy does not depend on the setup above
    initial_policy = DirichletArrayPolicy(game.time_steps, game.agent_observation_space, game.agent_action_space, alpha=0.5)

    initial_mu, _ = simulator.simulate(game, initial_policy)
    print("Initial policy created and simulated.")

    for solver_config in solver_configs:
        solver_name = solver_config["name"]
        print(f"\nRunning {solver_name}...")

        solver = _build_solver(solver_config["solver"], {
            "total_iterations": 50,
            "eta": solver_config["eta"],
            "reg_param": solver_config["reg_param"],
            "sigma_update_time": solver_config["sigma_update_time"],
            "verbose": 0,
            "num_alphas": 1,
        })

        # The policy is copied so one run cannot alter another's start point.
        # NOTE(review): initial_mu is shared, not copied; this is only safe if
        # solvers/simulator never mutate a mean field in place -- confirm.
        policy = initial_policy.copy()
        _logs, nash_conv = run_experiment_with_solver(
            game, simulator, evaluator, eval_solver, solver, policy, initial_mu, iterations
        )
        all_results[solver_name] = nash_conv

    # Set up plot with LaTeX
    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "serif",
        "font.serif": ["Computer Modern Roman"],
        "font.size": 20,
    })

    # savefig does not create missing directories.
    results_dir = Path("./results")
    results_dir.mkdir(parents=True, exist_ok=True)

    # Figure 1: Proximal (tau=1) vs APP (tau=1)
    _plot_comparison(
        [
            ("prox_rmd", 1, {'color': 'orange', 'linestyle': '-', 'linewidth': 6, 'alpha': 0.9}),
            ("amd_array", 1, {'color': 'red', 'linestyle': '--', 'linewidth': 4, 'alpha': 0.9}),
        ],
        solver_configs,
        all_results,
        'Exploitability Comparison for Beach Graphon MFG\nProximal ($\\tau=1$) vs APP ($\\tau=1$) ($H=3$)',
        results_dir / "beach_graphon_comparison_fig1",
    )

    # Figure 2: RMD vs Proximal (tau=1, 2). The RMD lambda in the title is read
    # from its config so the label cannot drift from the plotted data.
    rmd_lambda = next(cfg["reg_param"] for cfg in solver_configs if cfg["solver"] == "rmd_array")
    _plot_comparison(
        [
            ("rmd_array", None, {'color': 'cyan', 'linestyle': '-', 'linewidth': 6, 'alpha': 0.9}),
            ("prox_rmd", 1, {'color': 'orange', 'linestyle': '--', 'linewidth': 4, 'alpha': 0.9}),
            ("prox_rmd", 2, {'color': 'darkorange', 'linestyle': ':', 'linewidth': 5, 'alpha': 0.9}),
        ],
        solver_configs,
        all_results,
        f'Exploitability Comparison for Beach Graphon MFG\nRMD ($\\lambda={rmd_lambda:g}$) vs Proximal ($\\tau=1,2$) ($H=3$)',
        results_dir / "beach_graphon_comparison_fig2",
    )

    print("\nExperiment completed successfully!")

def _evaluate_nash_conv(game, evaluator, eval_solver, mu, policy):
    """Evaluate ``policy`` against a fixed mean field ``mu``.

    Computes a best response with ``eval_solver`` and then evaluates both the
    policy and the best response with ``evaluator``.

    Returns:
        Tuple ``(br_info, eval_pi, eval_opt, delta_j)``. ``delta_j`` is
        ``eval_opt['eval_mean_returns'] - eval_pi['eval_mean_returns']``,
        i.e. the quantity printed as "Nash Conv.".
    """
    best_response, br_info = eval_solver.solve(game, mu, policy)
    eval_pi = evaluator.evaluate(game, mu, policy)
    eval_opt = evaluator.evaluate(game, mu, best_response)
    delta_j = eval_opt['eval_mean_returns'] - eval_pi['eval_mean_returns']
    return br_info, eval_pi, eval_opt, delta_j


def run_experiment_with_solver(game, simulator, evaluator, eval_solver, solver, policy, mu, iterations):
    """Run ``iterations`` solver/simulate/evaluate rounds starting from ``policy`` and ``mu``.

    Each round updates the policy with ``solver``, re-simulates the mean
    field, and measures the gap between a best response and the current
    policy under that mean field.

    Returns:
        Tuple ``(logs, delta_j_list)``. Both lists have ``iterations + 1``
        entries; index 0 is the evaluation of the initial policy. Each log
        dict holds ``"eval_pi"``, ``"eval_opt"`` and ``"Delta_J"``; the
        per-iteration logs additionally hold ``"solver"``, ``"simulator"``
        and ``"best_response"`` info.
    """
    _, eval_pi, eval_opt, delta_j = _evaluate_nash_conv(game, evaluator, eval_solver, mu, policy)
    print(f"Initial Nash Conv. = {delta_j:.8f}")

    logs = [{
        "eval_pi": eval_pi,
        "eval_opt": eval_opt,
        "Delta_J": delta_j
    }]
    delta_j_list = [delta_j]

    for i in range(iterations):
        log = {}
        t_start = time.time()

        # The proximal solver additionally needs the simulator for its inner updates.
        if isinstance(solver, ProxRMDSolver):
            policy, info = solver.solve(game, mu, policy, simulator=simulator, iteration=i)
        else:
            policy, info = solver.solve(game, mu, policy, iteration=i)
        log["solver"] = info

        mu, info = simulator.simulate(game, policy)
        log["simulator"] = info

        br_info, eval_pi, eval_opt, delta_j = _evaluate_nash_conv(game, evaluator, eval_solver, mu, policy)
        log["best_response"] = br_info
        log["eval_pi"] = eval_pi
        log["eval_opt"] = eval_opt
        log["Delta_J"] = delta_j

        print(f"Loop {i}: {time.time()-t_start:.2f} Nash Conv. = {delta_j:.8f}")

        logs.append(log)
        delta_j_list.append(delta_j)

    return logs, delta_j_list

if __name__ == "__main__":
    # Script entry point: read CLI options and run the full comparison.
    cli_options = parse_args()
    test_beach_graphon_simplified(
        iterations=cli_options.iterations,
        tolerance=cli_options.tolerance,
    )
