import time
import os
import numpy as np
from solver.policy.array_policy import ArrayPolicy
from solver.policy.dirichlet_policy import DirichletArrayPolicy
from games.finite.beach import BeachGraphon
from tqdm import tqdm

def _tabulate_mean_field(game, mu):
    """Tabulate ``mu.mu_alphas[0]`` as an array with one row per time step.

    Entry ``[t][s]`` is ``mu.mu_alphas[0].evaluate_integral(t, lambda x: x == s)``
    for ``t`` in ``range(game.time_steps)`` and ``s`` in ``range(game.N_states)``.
    """
    mean_field = mu.mu_alphas[0]
    return np.array([
        [mean_field.evaluate_integral(step, lambda x: x == s) for s in range(game.N_states)]
        for step in range(game.time_steps)
    ])


def run_experiment(**config):
    """Run the outer learning loop of one experiment and save its results.

    Keys read from ``config``:
        game, simulator, evaluator, solver, eval_solver: callables; each is
            invoked with the matching ``<name>_config`` mapping as keyword
            arguments to build the component.
        seed: value passed to ``np.random.seed``.
        iterations: number of outer iterations to run.
        experiment_directory: directory the ``.npy`` result files are written to.

    Every outer iteration records the current mean field and policy, asks the
    solver for a new policy, re-simulates the mean field under it, computes a
    best response with ``eval_solver`` and evaluates both policies.
    ``Delta_J`` is the difference between the two ``'eval_mean_returns'``
    values (best response minus current policy).

    Files written (all inside ``experiment_directory``):
        Delta_J.npy: ``Delta_J`` for every iteration.
        Policy.npy / Mu.npy: final policy array and final tabulated mean field.
        Policy_list.npy / Mu_list.npy: the policy array and tabulated mean
            field as they were at the *start* of each iteration (so the final
            ones are only in ``Policy.npy`` / ``Mu.npy``).

    Returns:
        A list with one dict per iteration holding the keys ``"solver"``,
        ``"simulator"``, ``"best_response"``, ``"eval_pi"``, ``"eval_opt"``
        and ``"Delta_J"``.
    """
    # Build all components from their factories.
    game = config["game"](**config["game_config"])
    simulator = config["simulator"](**config["simulator_config"])
    evaluator = config["evaluator"](**config["evaluator_config"])
    solver = config["solver"](**config["solver_config"])
    eval_solver = config["eval_solver"](**config["eval_solver_config"])

    # NOTE(review): the seed is set after the components above are constructed,
    # so any randomness used inside their constructors is not controlled by
    # it. Left in place to keep existing runs comparable; confirm it is intended.
    np.random.seed(config["seed"])

    logs = []
    Delta_J_list = []
    mu_list = []
    policy_list = []

    # Initial mean field and policy.
    print("Initializing. ", flush=True)
    # Use DirichletArrayPolicy with alpha=0.5 for non-uniform initialization
    policy = DirichletArrayPolicy(game.time_steps, game.agent_observation_space, game.agent_action_space, alpha=0.5)
    mu, _ = simulator.simulate(game, policy)
    print("Initialized. ", flush=True)

    # Outer iterations.
    for i in tqdm(range(config["iterations"]), desc="Outer iterations"):
        log = {}
        start_time = time.time()

        # Snapshot the state *before* this iteration's update.
        mu_list.append(_tabulate_mean_field(game, mu))
        # Store a copy: if a solver updates ``policy_array`` in place, keeping
        # only a reference would make every history entry alias the latest
        # policy.
        policy_list.append(np.array(policy.policy_array))

        policy, info = solver.solve(game, mu, policy, iteration=i)
        log["solver"] = info

        mu, info = simulator.simulate(game, policy)
        log["simulator"] = info

        best_response, info = eval_solver.solve(game, mu, policy)
        log["best_response"] = info

        eval_results_pi = evaluator.evaluate(game, mu, policy)
        eval_results_opt = evaluator.evaluate(game, mu, best_response)
        log["eval_pi"] = eval_results_pi
        log["eval_opt"] = eval_results_opt

        # Gain of the best response over the current policy under ``mu``.
        Delta_J = eval_results_opt['eval_mean_returns'] - eval_results_pi['eval_mean_returns']
        log["Delta_J"] = Delta_J

        tqdm.write(f"Loop {i}: {time.time()-start_time:.2f} Nash Conv. = {Delta_J:.8f}")

        logs.append(log)
        Delta_J_list.append(Delta_J)

    # Save results
    delta_j_path = f"{config['experiment_directory']}/Delta_J.npy"
    os.makedirs(os.path.dirname(delta_j_path), exist_ok=True)
    np.save(delta_j_path, np.array(Delta_J_list))
    np.save(f"{config['experiment_directory']}/Policy.npy", policy.policy_array)
    np.save(f"{config['experiment_directory']}/Mu.npy", _tabulate_mean_field(game, mu))
    np.save(f"{config['experiment_directory']}/Mu_list.npy", np.array(mu_list))
    np.save(f"{config['experiment_directory']}/Policy_list.npy", np.array(policy_list))

    return logs
