import numpy as np
import copy
from utils.custom_metric_logger import CustomMetricLogger


class IndividualReward:
    """Buffer per-agent rewards seen in test mode and log them afterwards.

    While ``store_rewards`` is called with ``test_mode`` truthy, one reward
    per agent is appended to an in-memory buffer on every call. The first
    call with ``test_mode`` falsy after such a run writes the whole buffer
    through the logger and starts a new, empty buffer.
    """

    def __init__(self, n_agents, args):
        """Set up empty reward buffers and the metric logger.

        Args:
            n_agents: Number of agents; one buffer is kept per agent id in
                ``range(n_agents)``.
            args: Passed unchanged to ``CustomMetricLogger``.
        """
        self.n_agents = n_agents
        self.collected_samples = self._empty_samples()
        self.logger = CustomMetricLogger(args, "individual_reward")
        # True while the buffers hold test-mode rewards not yet written out.
        self._update = False

    def _empty_samples(self):
        """Return a new dict holding one empty reward list per agent."""
        samples = {}
        for idx in range(self.n_agents):
            samples[f"agent_{idx}_reward"] = []
        return samples

    def store_rewards(self, env, actions,t_env, test_mode):
        """Record rewards for a test step, or flush after a test run ends.

        Args:
            env: Environment object. It is deep-copied and ``step`` is called
                on the copy, so ``env`` itself is not stepped here.
            actions: Forwarded to the copied environment's ``step``.
            t_env: Key under which the buffered rewards are written.
            test_mode: Truthy to record rewards; falsy to flush any pending
                buffer (does nothing when there is nothing pending).
        """
        if not test_mode:
            if not self._update:
                return
            # Clear the flag before writing, so it is already reset even if
            # the logger call raises.
            self._update = False
            self.logger.write({t_env: self.collected_samples})
            self.collected_samples = self._empty_samples()
            return

        sandbox = copy.deepcopy(env)
        # ``step`` must return exactly four values; only the last one is
        # used and it is indexed by agent id.
        _, _, _, rewards = sandbox.step(actions)
        for idx in range(self.n_agents):
            bucket = self.collected_samples[f"agent_{idx}_reward"]
            bucket.append(rewards[idx])
        self._update = True
