import os
import pickle


class Test:
    """Run a single evaluation episode with a shared agent and a central agent.

    On every step the central agent is asked for features based on the
    on-disk history file, the current state is appended to that file, and the
    shared agent picks the action that is sent to the environment.
    """

    # Name of the pickle stream (inside ``history_dir``) holding the states
    # observed during the episode currently being run.
    _HISTORY_FILE_NAME = "current_samples.pkl"

    def __init__(self, cnt_gen, dic_path, dic_agent_conf, agent, central_agent, env):
        """Store the collaborators and make sure the work directories exist.

        Args:
            cnt_gen: Stored as-is on the instance; not used by this class.
            dic_path: Mapping that must contain "PATH_TO_WORK_DIRECTORY".
            dic_agent_conf: Stored as-is on the instance; not used by this class.
            agent: Object providing ``choose_action_with_features(state_list,
                selected_features)``; kept as ``shared_agent``.
            central_agent: Object providing ``choose_history_length(
                history_file_path, step_num)``.
            env: Environment whose ``reset()`` returns a 2-tuple and whose
                ``step(action)`` returns a 4-tuple
                ``(next_state, reward, done, extra)``.
        """
        self.cnt_gen = cnt_gen
        self.dic_path = dic_path
        self.dic_agent_conf = dic_agent_conf
        self.env = env

        self.nb_agents = 3

        work_dir = self.dic_path["PATH_TO_WORK_DIRECTORY"]

        # exist_ok=True avoids the check-then-create race of
        # ``if not os.path.exists(...): os.makedirs(...)``.
        self.path_to_log = os.path.join(work_dir, "train_logs")
        os.makedirs(self.path_to_log, exist_ok=True)

        self.shared_agent = agent
        self.central_agent = central_agent

        self.history_dir = os.path.join(work_dir, "history")
        os.makedirs(self.history_dir, exist_ok=True)

    def test(self):
        """Play one episode and report whether its total reward was positive.

        Returns:
            1 if the rewards summed over the episode are strictly greater
            than 0, otherwise 0.
        """
        done = False
        state_list, _ = self.env.reset()
        total_reward = 0
        step_num = 0

        self._clear_history_file()
        # The path does not change during an episode, so build it once.
        history_file_path = self._history_file_path()

        while not done:
            # The central agent is queried BEFORE the current state is
            # logged, so at step 0 the history file is still empty.
            # Only the first of its four return values is used here.
            selected_features, _, _, _ = self.central_agent.choose_history_length(
                history_file_path, step_num)
            self._log_state(state_list)

            # The second return value of the shared agent is not used here.
            action, _ = self.shared_agent.choose_action_with_features(
                state_list, selected_features)

            next_state_list, reward, done, _ = self.env.step(action)

            total_reward += reward

            state_list = next_state_list
            step_num += 1

        # Collapse the accumulated reward into a binary success flag.
        return 1 if total_reward > 0 else 0

    def _history_file_path(self):
        """Return the path of the current episode's history file."""
        return os.path.join(self.history_dir, self._HISTORY_FILE_NAME)

    def _clear_history_file(self):
        """Truncate the history file to zero bytes, creating it if missing."""
        # Opening in 'wb' mode truncates; nothing needs to be written.
        with open(self._history_file_path(), 'wb'):
            pass

    def _log_state(self, state_list):
        """Append ``{"state": state_list}`` to the history file as one pickle record."""
        with open(self._history_file_path(), 'ab') as f:
            pickle.dump({"state": state_list}, f)
