import numpy as np
import tensorflow as tf
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from spark_env.env import Environment

from actor_agent import ActorAgent
from spark_env.canvas import *
from param import *
from utils import *
from log_obs import log_observation, convert_dict_for_json, CustomJSONEncoder
import json
# NOTE: no heuristic scheduling algorithms are imported in this file; only the 'learn' scheme gets an agent below.

# `os` is used below but is not imported explicitly anywhere in the visible
# import block (presumably it was only reaching this module through one of
# the star imports above), so import it here to make the dependency explicit.
import os

# Create the result and log folders. `exist_ok=True` makes each call a no-op
# when the folder already exists and avoids the check-then-create race of a
# separate `os.path.exists` test.
os.makedirs(args.result_folder, exist_ok=True)
os.makedirs('log', exist_ok=True)

# Seed TensorFlow's random number generation with the configured seed.
tf.set_random_seed(args.seed)

# Set up the environment; the same instance is re-seeded and reset for every
# (experiment, scheme) run further down.
env = Environment()

# Set up one agent per requested scheme, keyed by scheme name.
agents = {}

for scheme in args.test_schemes:
    if scheme == 'learn':
        # The learned policy is driven through its own TensorFlow session.
        sess = tf.Session()
        agents[scheme] = ActorAgent(
            sess, args.node_input_dim, args.job_input_dim,
            args.hid_dims, args.output_dim, args.max_depth,
            range(1, args.exec_cap + 1))
    else:
        # Fail fast with a clear message: without this, an unknown scheme is
        # silently skipped here and only surfaces later as a bare KeyError
        # when the experiment loop looks up `agents[scheme]`.
        raise ValueError(
            f"Unsupported test scheme {scheme!r}: "
            "this script only builds an agent for 'learn'")

# Per-scheme list of total rewards; one value is appended per experiment run.
all_total_reward = {name: [] for name in args.test_schemes}

# Cache the configured number of stream DAGs (taken from the parsed
# arguments); it is used when building the log file name.
num_stream_dags = args.num_stream_dags

def _build_log_record(node, use_exec, reward, obs_data):
    """Return the dictionary that is logged for one environment step.

    Args:
        node: The node chosen by the agent, or ``None`` when no action was
            taken. When not ``None`` its ``idx`` and ``job_dag.name`` are
            logged.
        use_exec: The second element of the agent's action, logged as-is.
        reward: The reward returned by ``env.step`` for this action.
        obs_data: The result of ``log_observation`` for this step.

    Returns:
        When ``obs_data`` is an empty dict, a compact record holding only the
        action (``{'action': (idx, name), 'use_exec': ...}`` or
        ``{'action': None}``). Otherwise a full record with the keys
        ``'obs'``, ``'action'`` and ``'reward'``, in that order.
    """
    if node is None:
        action = None
        compact = {'action': None}
    else:
        node_ref = (node.idx, node.job_dag.name)
        action = {"node": node_ref, 'use_exec': use_exec}
        compact = {'action': node_ref, 'use_exec': use_exec}

    # An empty observation dict collapses the record to the action alone
    # (no 'obs' and no 'reward' keys).
    if obs_data == {}:
        return compact
    return {'obs': obs_data, 'action': action, 'reward': reward}


for exp in range(args.num_exp):
    print('Experiment ' + str(exp + 1) + ' of ' + str(args.num_exp))

    for scheme in args.test_schemes:
        # The seed depends only on the experiment index, so every scheme is
        # run against an identically seeded environment.
        env.seed(args.num_ep + exp)
        env.reset()

        # load an agent
        agent = agents[scheme]

        # start experiment
        obs = env.observe()

        total_reward = 0
        done = False

        # Log file name encodes the scheme and the configuration parameters.
        # NOTE: the name does not include `exp` and the file is opened in 'w'
        # mode, so each experiment overwrites the previous experiment's log
        # for the same scheme.
        log_path = f'log/scheme_{scheme}_{args.exec_cap}_{args.num_init_dags}_{num_stream_dags}_{args.first_job_query_size}_{args.first_job_query_idx}_{args.second_job_query_size}_{args.second_job_query_idx}.log'

        # The context manager closes the log even if the agent, the
        # environment or JSON serialisation raises mid-episode.
        with open(log_path, 'w') as log_f:
            while not done:
                # Ask the agent for an action on the current observation.
                node, use_exec = agent.get_action(obs)

                # Step the environment: new observation, reward, done flag.
                new_obs, reward, done = env.step(node, use_exec)

                # Build the log entry from the observation the action was
                # chosen on (`obs`, not `new_obs`) plus the resulting reward.
                obs_data = log_observation(obs, detailed_task_info=False, reward=reward)
                to_record = _build_log_record(node, use_exec, reward, obs_data)
                log_f.write(json.dumps(convert_dict_for_json(to_record), indent=2, cls=CustomJSONEncoder) + '\n')

                # Advance to the next observation and accumulate the reward.
                obs = new_obs
                total_reward += reward

        all_total_reward[scheme].append(total_reward)

        # if args.canvs_visualization:
        #     visualize_dag_time_save_pdf(
        #         env.finished_job_dags, env.executors,
        #         args.result_folder + 'visualization_exp_' + \
        #         str(exp) + '_scheme_' + scheme + \
        #         '.png', plot_type='app')
        # else:
        #     visualize_executor_usage(env.finished_job_dags,
        #         args.result_folder + 'visualization_exp_' + \
        #         str(exp) + '_scheme_' + scheme + '.png')


    # # plot CDF of performance
    # fig = plt.figure()
    # ax = fig.add_subplot(111)

    # for scheme in args.test_schemes:
    #     x, y = compute_CDF(all_total_reward[scheme])
    #     ax.plot(x, y)

    # plt.xlabel('Total reward')
    # plt.ylabel('CDF')
    # plt.legend(args.test_schemes)
    # fig.savefig(args.result_folder + 'total_reward.png')

    # plt.close(fig)

# Report each scheme's mean total reward over all experiments.
# print("\nAverage rewards per scheme:")
for scheme in args.test_schemes:
    scheme_rewards = all_total_reward[scheme]
    mean_reward = np.mean(scheme_rewards)
    print(f"{scheme}: {mean_reward:.4f}")