# Run the RMAB Simulator
import argparse
import tqdm
import concurrent.futures
import numpy as np
from env import *
from utils import *
from agent import *
from eval import *


def cli_bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` applies ``bool()`` to the raw string, so
    every non-empty token (including ``"False"`` and ``"0"``) parses as
    ``True``.  This converter recognises the usual spellings instead, which
    keeps the existing ``--flag True`` / ``--flag False`` interface working.

    Args:
        value: The raw command-line token, or an actual ``bool``.

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if token in {"false", "f", "no", "n", "0", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='RMAB Simulator')

    # Numeric simulation parameters.
    parser.add_argument('-N', type=int, default=900, help='Number of agents')
    parser.add_argument('-d', type=int, default=4, help='Number of states')
    parser.add_argument('-alpha', type=float, default=0.2, help='Fractional allocation')
    parser.add_argument('--gamma', type=float, default=0.5, help='Discount factor')
    parser.add_argument('--num_sims', type=int, default=3200, help='Length of simulations')
    parser.add_argument('--num_evals', type=int, default=10, help='Number of evaluations')
    parser.add_argument('--num_rho', type=int, default=50, help='rho weight')

    # Boolean switches selecting which experiments run.  They take an explicit
    # value (e.g. ``--plot False``) and are parsed by ``cli_bool`` so that
    # "False" really disables the switch.
    boolean_flags = (
        ('--Evaluation', False, 'Policy Evaluate'),
        ('--test_entanglement', False, 'Test Entanglement'),
        ('--ent_learning', False, 'Entanglement Error Curve'),
        ('--decomp_learning', True, 'Learning Error of Decomposition'),
        ('--test_circular', True, 'Test Circular Environment'),
        ('--store', True, 'Store Results'),
        ('--plot', True, 'Plot Results'),
    )
    for flag_name, flag_default, flag_help in boolean_flags:
        parser.add_argument(flag_name, type=cli_bool, default=flag_default, help=flag_help)

    args = parser.parse_args()

    # Pick the environment implementation; both take the same constructor arguments.
    if args.test_circular:
        env = Circular_Env(args.N, args.d, args.alpha, args.gamma)
    else:
        env = RMAB(args.N, args.d, args.alpha, args.gamma)

    # NOTE(review): this list has four entries, matching the default d=4;
    # presumably it must be adapted when -d is changed -- confirm against env.
    Index_POLICY = [2,1,0,3]

    m_star = env.find_m_star(Index_POLICY)


    # Test Entanglement
    if args.test_entanglement:

        def run_entanglement_test():
            """Run one simulation and fit a state-only policy to the samples.

            After a burn-in of 10% of ``args.num_sims`` steps, every step
            records the state at index 0 of the per-agent state vector together
            with the value ``get_policy(...)`` returns for that state.  A
            ``(d, 2)`` table ``pi_star`` (rows constrained to sum to 1) is then
            fitted with SLSQP to minimise the mean absolute deviation from the
            recorded values.

            Returns:
                The minimised loss (``result.fun``).

            Raises:
                ValueError: If ``args.num_sims`` is too small to leave any
                    sample after the burn-in.
            """
            from scipy.optimize import minimize

            # Forked pool workers start with a copy of the parent's NumPy RNG
            # state, which would make the replicas identical.  Reseed from OS
            # entropy.  NOTE(review): assumes env draws from NumPy's global
            # RNG -- harmless otherwise.
            np.random.seed()

            env.reset_to_m_star()

            curr_state = env.global_state_agent
            sampled_states = []
            sampled_probs = []
            burn_in = 0.1 * args.num_sims

            for i in range(args.num_sims):
                next_state, _, _ = env.step_by_agent(Index_POLICY)
                if i >= burn_in:
                    state_0 = int(curr_state[0])
                    sampled_states.append(state_0)
                    sampled_probs.append(
                        get_policy(env.config/args.N, Index_POLICY, args.alpha)[state_0]
                    )
                # Advance on every step (as the decomposition section does);
                # updating only after the burn-in paired the first sample with
                # the stale reset state.
                curr_state = next_state

            if not sampled_states:
                raise ValueError("num_sims is too small: no samples left after burn-in")

            state_idx = np.asarray(sampled_states, dtype=int)
            # Assumes each recorded policy value is a scalar; the reshape
            # fails loudly if it is not.
            probs = np.asarray(sampled_probs, dtype=float).reshape(len(sampled_states))

            def loss_function(pi_star_flat):
                """Mean L1 distance between the samples and the candidate table."""
                pi_star = pi_star_flat.reshape(args.d, 2)
                pi_star = pi_star / pi_star.sum(axis=1, keepdims=True)
                deviation = (np.abs(probs - pi_star[state_idx, 1])
                             + np.abs(1 - probs - pi_star[state_idx, 0]))
                return deviation.mean()

            # Uniform starting point; each row must stay a probability vector.
            initial_guess = np.ones(args.d * 2) / 2
            constraints = [
                {'type': 'eq', 'fun': lambda x, s=s: x[2*s] + x[2*s+1] - 1}
                for s in range(args.d)
            ]
            bounds = [(0, 1) for _ in range(args.d * 2)]

            result = minimize(loss_function, initial_guess, method='SLSQP',
                              constraints=constraints, bounds=bounds)
            return result.fun

        # NOTE: the worker function is defined under the __main__ guard, so
        # worker processes can only resolve it when they are forked from this
        # process (not with the "spawn" start method).
        with concurrent.futures.ProcessPoolExecutor() as executor:
            futures = [executor.submit(run_entanglement_test)
                       for _ in range(args.num_evals)]
            losses = [
                future.result()
                for future in tqdm.tqdm(concurrent.futures.as_completed(futures),
                                        total=args.num_evals,
                                        desc="Running entanglement tests")
            ]

        mean_loss = np.mean(losses)
        std_loss = np.std(losses) / np.sqrt(args.num_evals)  # standard error over replicas

        print(f"\nEntanglement Test Results:")
        print(f"Mean Loss: {mean_loss*args.N}")
        print(f"Standard Error: {std_loss*args.N}")

        if args.store:
            import os

            # Create the output directory so a long run is not lost to a
            # FileNotFoundError at the very end.
            os.makedirs("results_circular", exist_ok=True)
            with open(f"results_circular/results_{args.N}_ent", "w") as f:
                f.write(f"Mean Loss: {mean_loss*args.N}\n")
                f.write(f"Standard Error: {std_loss*args.N}\n")


    # Entanglement Error Curve
    if args.ent_learning:
        from scipy.optimize import minimize

        def loss_function(pi_star_flat, state_idx, probs):
            """Mean L1 distance between recorded policy values and ``pi_star``.

            Args:
                pi_star_flat: Flattened ``(d, 2)`` candidate table.
                state_idx: Integer array of recorded states.
                probs: Float array of the recorded policy values, same length.

            Returns:
                The average absolute deviation over the recorded samples.
            """
            pi_star = pi_star_flat.reshape(args.d, 2)
            pi_star = pi_star / pi_star.sum(axis=1, keepdims=True)
            deviation = (np.abs(probs - pi_star[state_idx, 1])
                         + np.abs(1 - probs - pi_star[state_idx, 0]))
            return deviation.mean()

        # The optimisation problem itself never changes between fits, so it is
        # built once here instead of on every refit.
        # Initial guess for pi_star (uniform policy)
        initial_guess = np.ones(args.d * 2) / 2
        constraints = [
            {'type': 'eq', 'fun': lambda x, s=s: x[2*s] + x[2*s+1] - 1}
            for s in range(args.d)
        ]
        bounds = [(0, 1) for _ in range(args.d * 2)]

        env.reset_to_m_star()

        curr_state = env.global_state_agent
        sampled_states = []
        sampled_probs = []
        losses = []
        burn_in = 0.1 * args.num_sims

        for i in range(args.num_sims):
            next_state, next_action, next_reward = env.step_by_agent(Index_POLICY)
            if i >= burn_in:
                state_0 = int(curr_state[0])
                sampled_states.append(state_0)
                sampled_probs.append(
                    get_policy(env.config/args.N, Index_POLICY, args.alpha)[state_0]
                )

                # Refit on all samples collected so far, every 10 new samples.
                if len(sampled_states) % 10 == 0:
                    state_idx = np.asarray(sampled_states, dtype=int)
                    # Assumes each recorded policy value is a scalar; the
                    # reshape fails loudly if it is not.
                    probs = np.asarray(sampled_probs, dtype=float).reshape(len(sampled_states))

                    # Optimize using scipy
                    result = minimize(loss_function, initial_guess, method='SLSQP',
                                      constraints=constraints, bounds=bounds,
                                      args=(state_idx, probs))

                    losses.append(result.fun)
                    print(f"Data points: {len(sampled_states)}, Loss: {result.fun*args.N}")

            # Advance on every step (as the decomposition section does);
            # updating only after the burn-in paired the first sample with the
            # stale reset state.
            curr_state = next_state

        # Save the losses
        if args.store:
            import os

            # Make sure the target directory exists before np.save is called.
            os.makedirs('learning_error', exist_ok=True)
            np.save('learning_error/entanglement_losses.npy', np.array(losses))
            print(f"Saved losses to learning_error/entanglement_losses.npy")


    # Policy Evaluation
    if args.Evaluation:
        agent = Agent(args.N, args.d, args.gamma)

        # Run num_evals independent Local_TD_Learning jobs in worker processes.
        # Forked workers start with a copy of the parent's NumPy RNG state, so
        # each worker is reseeded from OS entropy on start-up.
        # NOTE(review): assumes the simulation draws from NumPy's global RNG --
        # harmless otherwise.
        with concurrent.futures.ProcessPoolExecutor(initializer=np.random.seed) as executor:
            futures = [
                executor.submit(Local_TD_Learning, env, agent, Index_POLICY, args.num_sims)
                for _ in range(args.num_evals)
            ]
            Q_values_run = [
                future.result()
                for future in tqdm.tqdm(concurrent.futures.as_completed(futures),
                                        total=args.num_evals,
                                        desc="Local TD-learning")
            ]

        mean_Q_values = np.mean(Q_values_run, axis=0)
        # Standard error of the local Q-values; currently not folded into the
        # reported error bar (see the note inside the loop).
        std_Q_values = np.std(Q_values_run, axis=0)/np.sqrt(args.num_evals)

        # The learned Q-values do not change during the comparison loop, so
        # split them once: even entries and odd entries are passed separately
        # to get_global_reward.
        Q_values_0 = mean_Q_values[::2]
        Q_values_1 = mean_Q_values[1::2]

        # Compare the decomposed approximation with the Monte-Carlo estimate
        # at num_rho successive configurations.
        loss = []
        Q_std = 0
        global_Q_values = []
        env.reset_to_m_star()

        for _ in tqdm.tqdm(range(args.num_rho)):
            mean_Q_value_MC, std_Q_Value_MC = env.Value_Learning_MC(Index_POLICY, num_sims=args.N*3)
            global_Q_values.append(mean_Q_value_MC)

            Q_value_Approx = get_global_reward(env.config/env.N, Index_POLICY, env.alpha, Q_values_1, Q_values_0) * env.N
            abs_diff = np.abs(mean_Q_value_MC - Q_value_Approx)
            loss.append(abs_diff)

            # Accumulate only the Monte-Carlo std; for simplicity the std of
            # the local Q-values is not considered.
            Q_std += std_Q_Value_MC

            print(f"Q-value Approx: {Q_value_Approx}")
            print(f"Value_MC: {mean_Q_value_MC} +- {std_Q_Value_MC}")
            print(f"Diff: {abs_diff}")
            env.step(Index_POLICY)

        mean_loss = np.mean(loss)
        std_loss = np.std(loss)
        Q_std /= args.num_rho
        print(f"Loss: {mean_loss} +- {std_loss/np.sqrt(args.num_rho) + Q_std}")
        mean_global_Q_values = np.mean(global_Q_values)

        if args.store:
            import os

            # Create the output directory so the results of a long run are not
            # lost to a FileNotFoundError.
            os.makedirs("results_circular", exist_ok=True)
            with open(f"results_circular/results_{args.N}", "w") as f:
                f.write(f"Global Q-value: {mean_global_Q_values}\n")
                f.write(f"Loss: {mean_loss} +- {std_loss/np.sqrt(args.num_rho) + Q_std}")


    # Learning Curve of Decomposition Error
    if args.decomp_learning:
        # Phase 1: collect Monte-Carlo value estimates at num_rho successive
        # configurations; these serve as the reference for the loss below.
        env.reset_to_m_star()
        global_configs = []
        global_Q_means = []
        global_Q_stds = []

        for _ in tqdm.tqdm(range(args.num_rho)):
            mean_Q_value_MC, std_Q_Value_MC = env.Value_Learning_MC(Index_POLICY, num_sims=args.N*3)

            # Store a copy: if env.step updates env.config in place, appending
            # the object itself would leave every entry equal to the final
            # configuration.
            global_configs.append(np.array(env.config, copy=True))
            global_Q_means.append(mean_Q_value_MC)
            global_Q_stds.append(std_Q_Value_MC)

            env.step(Index_POLICY)

        global_configs = np.array(global_configs)
        global_Q_means = np.array(global_Q_means)
        global_Q_stds = np.array(global_Q_stds)

        # Phase 2: learn local Q-values online and track the decomposition error.
        env.reset_to_m_star()
        agent = Agent(args.N, args.d, args.gamma)
        agent.reset()
        losses = []

        curr_state = next_state = env.global_state_agent
        curr_action = next_action = env.global_action_agent
        # Placeholder only: it is replaced by the first step's reward before
        # any update indexes into it (updates start after the burn-in).
        curr_reward = next_reward = 0

        def compute_loss(Q_values_0, Q_values_1):
            """Return the mean absolute error against the Monte-Carlo references.

            Args:
                Q_values_0: Even-indexed entries of the agent's Q-values.
                Q_values_1: Odd-indexed entries of the agent's Q-values.

            Returns:
                The average over the stored configurations of
                ``|get_global_reward(...) * N - Monte-Carlo mean|``.
            """
            total = sum(
                np.abs(
                    get_global_reward(config/args.N, Index_POLICY, args.alpha, Q_values_1, Q_values_0) * args.N
                    - reference
                )
                for config, reference in zip(global_configs, global_Q_means)
            )
            return total/len(global_configs)

        burn_in = 0.1 * args.num_sims

        # Local Q-learning
        for i in range(args.num_sims):
            next_state, next_action, next_reward = env.step_by_agent(Index_POLICY)
            if i >= burn_in:
                agent.update(curr_state[0], curr_action[0], curr_reward[0], next_state[0], next_action[0]) # agent 0
            curr_state = next_state
            curr_action = next_action
            curr_reward = next_reward

            # Evaluate the current Q-values every 10 steps once past the burn-in.
            if i%10 == 0 and i > burn_in:
                Q_values = agent.get_Q_values()
                Q_values_0 = Q_values[::2]
                Q_values_1 = Q_values[1::2]
                loss = compute_loss(Q_values_0, Q_values_1)
                print(f"At round {i}, Loss: {loss}")
                losses.append(loss)

        if args.store:
            import os

            # Make sure the target directory exists before np.save is called.
            os.makedirs('learning_error', exist_ok=True)
            np.save('learning_error/decomp_losses.npy', np.array(losses))
            print(f"Saved losses to learning_error/decomp_losses.npy")
            

    if args.plot:
        # Placeholder series for the plots: all start empty and are meant to
        # be filled in by hand before enabling the plotting calls below.
        mean_loss, std_loss = [], []
        Q_value = []
        mean_ent, std_ent = [], []
        true_entangle = []  # Estimated using longer simulations

        # Plotting template (disabled); uncomment the calls that are needed.
        # plot_results(mean_loss, std_loss)
        # plot_relative_loss(mean_loss, std_loss, Q_value)
        # plot_results_entanglement(mean_ent, std_ent)
        # est_entangle = np.load('learning_error/entanglement_losses.npy')
        # est_entangle *= args.N
        # decomp_loss = np.load('learning_error/decomp_losses.npy')
        # plot_learning_curve(est_entangle, decomp_loss, true_entangle)
