import os
import pickle
import numpy as np
from config import *
from central_agent import compute_weighted_reward, get_frequency_features
import torch


class Generator:
    """Roll out one environment episode and collect training transitions.

    Each call to :meth:`generate` drives ``env`` until it reports ``done`` and
    returns two transition buffers: one for the shared per-agent policy
    (``agent``) and one for ``central_agent``, whose per-step output
    (``selected_features``) is fed to the shared policy. States observed during
    the rollout are appended to a pickle file under
    ``<PATH_TO_WORK_DIRECTORY>/history`` whose path is handed to
    ``central_agent`` and ``get_frequency_features``.
    """

    # Name of the per-episode state log inside ``history_dir``.
    # NOTE(review): every Generator pointed at the same work directory shares
    # this file; confirm generators with different ``cnt_gen`` never run
    # concurrently against one directory.
    _HISTORY_FILE_NAME = "current_samples.pkl"

    def __init__(self, cnt_round, cnt_gen, dic_path, dic_agent_conf, agent, central_agent, env,
                 *, nb_agents=3):
        """Store collaborators and make sure the log/history directories exist.

        Args:
            cnt_round: Training round index; written into ``mean_rewards.pkl``.
            cnt_gen: Generator index (stored; not used inside this class).
            dic_path: Path config; must contain ``"PATH_TO_WORK_DIRECTORY"``.
            dic_agent_conf: Agent config (stored; not used inside this class).
            agent: Shared policy exposing ``choose_action_with_features``.
            central_agent: Object exposing ``choose_history_length``.
            env: Environment with ``reset()`` -> ``(states, _)`` and
                ``step(action)`` -> ``(next_states, reward, done, _)``.
            nb_agents: Number of agents per environment step. Defaults to 3,
                the value previously hard-coded.
        """
        self.cnt_round = cnt_round
        self.cnt_gen = cnt_gen
        self.dic_path = dic_path
        self.dic_agent_conf = dic_agent_conf
        self.env = env

        self.nb_agents = nb_agents

        work_dir = self.dic_path["PATH_TO_WORK_DIRECTORY"]

        # exist_ok=True avoids the check-then-create race when several
        # generator processes create the same directory at the same time.
        self.path_to_log = os.path.join(work_dir, "train_logs")
        os.makedirs(self.path_to_log, exist_ok=True)

        self.shared_agent = agent
        self.central_agent = central_agent

        self.history_dir = os.path.join(work_dir, "history")
        os.makedirs(self.history_dir, exist_ok=True)

    def generate(self):
        """Run one episode and return ``(transition_dict, transition_central_dict)``.

        ``transition_dict`` holds one entry per agent per step (agents
        ``0..nb_agents-1`` in order within each step) under the keys
        ``states``, ``features``, ``actions``, ``next_states``, ``rewards``
        and ``dones``. Every agent is credited with the same ``reward`` value
        returned by ``env.step``.

        ``transition_central_dict`` holds one entry per step. Each entry has the
        central agent's input ``features`` and its ``central_actions``. It also
        has ``get_frequency_features(history_file, step_num + 1)['features']``
        as the next state, and a reward equal to the
        ``compute_weighted_reward(...)`` weights dotted with the per-agent
        rewards. Finally it records ``done``.

        The history file is truncated before the rollout starts.
        """
        done = False
        state_list, _ = self.env.reset()
        step_num = 0
        history_file_path = self._history_file_path()
        transition_dict = {
            'states': [],
            'features': [],
            'actions': [],
            'next_states': [],
            'rewards': [],
            'dones': []
        }
        transition_central_dict = {
            'states': [],
            'actions': [],
            'next_states': [],
            'rewards': [],
            'dones': []
        }

        self._clear_history_file()

        while not done:
            # Queried before the current state is logged, so the central agent
            # sees the history only up to the previous step.
            selected_features, central_actions, query, features = self.central_agent.choose_history_length(
                history_file_path, step_num)
            self._log_state(state_list)

            action, agent_attention_inputs = self.shared_agent.choose_action_with_features(
                state_list, selected_features)
            next_state_list, reward, done, _ = self.env.step(action)

            # Every agent is credited with the same env reward.
            reward_list = [reward] * self.nb_agents
            # NOTE(review): assumes compute_weighted_reward returns one weight
            # per agent so the dot product lines up with reward_list.
            weights = compute_weighted_reward(agent_attention_inputs, query)
            central_reward = np.dot(weights, reward_list)

            # Read after _log_state above, so the file now includes this
            # step's state.
            next_central_states = get_frequency_features(history_file_path, step_num + 1)['features']

            for i in range(self.nb_agents):
                transition_dict['states'].append(state_list[i])
                transition_dict['features'].append(selected_features)
                transition_dict['actions'].append(action[i])
                transition_dict['next_states'].append(next_state_list[i])
                transition_dict['rewards'].append(reward_list[i])
                transition_dict['dones'].append(done)

            transition_central_dict['states'].append(features)
            transition_central_dict['actions'].append(central_actions)
            transition_central_dict['next_states'].append(next_central_states)
            transition_central_dict['rewards'].append(central_reward)
            transition_central_dict['dones'].append(done)

            state_list = next_state_list
            print(step_num)  # progress output
            step_num += 1

        return transition_dict, transition_central_dict

    def _history_file_path(self):
        """Return the path of the per-episode state log inside ``history_dir``."""
        return os.path.join(self.history_dir, self._HISTORY_FILE_NAME)

    def _clear_history_file(self):
        """Truncate (or create) the history file so it starts empty."""
        with open(self._history_file_path(), 'wb'):
            pass

    def _log_state(self, state_list):
        """Append ``{"state": state_list}`` as one pickle record to the history file."""
        with open(self._history_file_path(), 'ab') as f:
            pickle.dump({"state": state_list}, f)

    def _log_mean_reward(self, total_reward, mean_reward):
        """Append this round's reward summary as one pickle record to ``train_logs/mean_rewards.pkl``."""
        reward_file_path = os.path.join(self.path_to_log, "mean_rewards.pkl")
        with open(reward_file_path, 'ab') as f:
            pickle.dump({"round": self.cnt_round, "total_reward": total_reward, "mean_reward": mean_reward}, f)