"""
Created for comparing the mean episode reward of different decision tree algorithms on the RL maze example:
1. Decision Tree with SVM hyperplane
2. CART Decision Tree
3. Random Forest
4. GBDT
5. Extra Trees

"""

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

from QLearning import QLearningTable
from joblib import load

from arguments.args import get_args
from arguments.utils import make_env

import argparse
import os


import random

# Parsed once at import time; everything below is configured from it.
args = get_args()

# Seed both the stdlib and the NumPy generators with the configured value
# so repeated runs are reproducible.
random_seed = args.random_seed
random.seed(random_seed)
np.random.seed(random_seed)

# Directory holding the trained models, keyed by the configured maximum depth.
model_dir_RGMDT = os.path.join(
    'outputs',
    'Step2_RGMDTModels',
    'MaxDepth_{max_depth_RGMDT}'.format(max_depth_RGMDT=args.max_depth_RGMDT),
)

# Load one saved model per agent (files are read in agent order 1, 2, 3).
RGMDT_agent_1, RGMDT_agent_2, RGMDT_agent_3 = (
    load(os.path.join(model_dir_RGMDT, model_file))
    for model_file in (
        'DT_agent_1_level_2.joblib',
        'DT_agent_2_level_2.joblib',
        'DT_agent_3_level_2.joblib',
    )
)

# Evaluation budget: number of episodes and the step cap for each episode.
evaluate_episodes = args.evaluate_episodes
evaluate_episode_len = args.evaluate_episode_len

# Module-level accumulators that run_maze_with_model() appends to.
episode_total_reward = []
mean_episode_reward_list = []
task_completed_step = []

# Module-level placeholder; run_maze_with_model() does not read it.
Total_Rewards = 0

def run_maze_with_model():
    """Roll out the three per-agent models in the environment and save the results.

    Runs ``evaluate_episodes`` episodes of at most ``evaluate_episode_len``
    steps each. At every step each model in ``(model_0, model_1, model_2)``
    predicts an action from ``observation[i][0]`` (reshaped to ``(1, -1)``),
    and the list of the three predictions is passed to ``env.step``, which is
    expected to return ``(next_observation, reward, done)``.

    Reads the module-level names ``env``, ``model_0``, ``model_1``,
    ``model_2``, ``evaluate_episodes``, ``evaluate_episode_len`` and ``args``.

    Side effects:
        * appends each episode's summed reward to ``episode_total_reward``;
        * appends the running mean to ``mean_episode_reward_list``;
        * appends the zero-based index of the step on which ``done`` was
          returned to ``task_completed_step`` (nothing is appended for
          episodes that hit the step cap);
        * writes the three lists as CSV files under
          ``outputs/Step3_Evaluate_RGMDT/Max_Depth_<max_depth_RGMDT>/
          RGMDT_MeanEpisodeRewardList_RandomSeed_<random_seed>/``.

    Returns:
        None.
    """
    # Bound up front so the summary print below cannot hit an unbound local
    # when ``evaluate_episodes`` is 0.
    mean_episode_reward = 0

    for episode in range(evaluate_episodes):
        current_episode_reward = 0
        observation = env.reset()
        print("Start epsiode", episode)

        for s in range(evaluate_episode_len):
            # Refresh the environment display.
            env.render()

            # Each agent's model sees only its own entry of the joint
            # observation, flattened into a single-sample 2-D input.
            action_n = [
                model.predict(observation[agent_idx][0].reshape(1, -1))
                for agent_idx, model in enumerate((model_0, model_1, model_2))
            ]

            # Execute the joint action and move on to the next observation.
            observation, reward, done = env.step(action_n)
            current_episode_reward += reward
            print('current episode reward:', current_episode_reward)

            # Stop this episode as soon as the environment reports completion.
            if done:
                print("task achieved: YES!!!!!!!!!!!!")
                print("task achieved after ", s, " steps")
                task_completed_step.append(s)
                break

        print('current episode reward when this episode ends:', current_episode_reward)
        episode_total_reward.append(current_episode_reward)
        print('Total reward List after', episode, "episode is:", episode_total_reward)

        # Running mean over the episodes played so far in this call.
        total_reward = np.sum(episode_total_reward)
        mean_episode_reward = total_reward / (episode + 1)
        mean_episode_reward_list.append(mean_episode_reward)

        print("Mean Episode Rewards after", episode, "episode is:", mean_episode_reward)
        print("Mean Episode Rewards List:", mean_episode_reward_list)
        print("End Episode :", episode)
        print("\n")

    print("training process over mean episode rewards:",
          mean_episode_reward)

    # end of game
    print('game over')

    # Results directory is keyed by the tree depth and the random seed.
    save_dir = os.path.join(
        'outputs',
        'Step3_Evaluate_RGMDT',
        'Max_Depth_{max_depth_RGMDT}'.format(max_depth_RGMDT=args.max_depth_RGMDT),
        "RGMDT_MeanEpisodeRewardList_RandomSeed_{RandomSeed}".format(RandomSeed=args.random_seed),
    )
    os.makedirs(save_dir, exist_ok=True)

    # One CSV per accumulator; file names carry the evaluation budget.
    csv_outputs = (
        ("EpisodeTotalReward_{episode}_{step}.csv", episode_total_reward),
        ("MeanEpisodeReward_{episode}_{step}.csv", mean_episode_reward_list),
        ("TaskCompleteStep_{episode}_{step}.csv", task_completed_step),
    )
    for name_template, values in csv_outputs:
        file_name = name_template.format(episode=evaluate_episodes,
                                         step=evaluate_episode_len)
        pd.DataFrame(values).to_csv(os.path.join(save_dir, file_name))


if __name__ == "__main__":
    # Parse the arguments again and let make_env build the environment;
    # make_env also hands back the args object used from here on.
    args = get_args()
    env, args = make_env(args)

    # run_maze_with_model() looks these three names up at module level.
    model_0, model_1, model_2 = RGMDT_agent_1, RGMDT_agent_2, RGMDT_agent_3

    # The env.after(...) / env.mainloop() scheduling stays disabled; the
    # evaluation is invoked directly instead.
    run_maze_with_model()

    # One figure per recorded series, shown in turn, indexed by list position.
    for curve, y_label in (
        (episode_total_reward, 'Total reward'),
        (task_completed_step, 'Task Completed Steps in current episode'),
        (mean_episode_reward_list, 'Mean Episode reward'),
    ):
        plt.plot(np.arange(len(curve)), curve)
        plt.xlabel('Episode')
        plt.ylabel(y_label)
        plt.show()