"""
    This module defines the agent. It holds one Q-function per level (see q_function.py); cycle-functions (cycle_function.py) are referenced by save_loop_counter but are not constructed here.
"""

from q_function import QFunction

import numpy as np


class Agent:
    """Multi-level agent holding one QFunction and one transition batch per level.

    Action values are combined across levels as a weighted sum of the
    per-level Q-vectors (see ``evaluate``).
    """

    def __init__(self, num_levels, num_actions):
        """Create one QFunction per level.

        Args:
            num_levels: number of levels ("views") the agent keeps.
            num_actions: number of actions; passed to every QFunction.
        """
        self.num_levels = num_levels
        self.num_actions = num_actions

        self.qs = [QFunction(num_actions) for _ in range(num_levels)]
        # One independent list per level. ``[[]] * num_levels`` would alias a
        # single list, so every level's transitions would end up in one batch.
        self.batch = self._new_batch()

    def _new_batch(self):
        """Return a fresh list of ``num_levels`` independent empty lists."""
        return [[] for _ in range(self.num_levels)]

    def evaluate(self, state, alphas, beta=1.0):
        """Evaluate the agent at the given state and return action values.

        Args:
            state: indexable per-level state; ``state[i]`` is looked up in
                level ``i``'s Q-table via ``qtable.get_value``.
            alphas: per-level weights; ``alphas[i]`` scales level ``i``'s
                Q-vector.
            beta: currently unused; kept for interface compatibility.

        Returns:
            A numpy array equal to the sum over levels of
            ``alphas[i] * q_vec_i``, starting from a zero vector of length
            ``num_actions``.
        """
        action_values = np.zeros(self.num_actions, dtype=int)
        for level in range(self.num_levels):
            # NOTE(review): if get_value falls back to the None default
            # (presumably for an unseen state), the multiplication below
            # raises TypeError — confirm whether such states should count
            # as zeros instead.
            q_vec = self.qs[level].qtable.get_value(state[level], None)
            action_values = action_values + alphas[level] * q_vec
        return action_values

    def store_transition(self, transition):
        """Append one transition to every level's current-episode batch.

        The batches are consumed later by ``update``.

        Args:
            transition: sequence ``(states, action, rewards, next_states,
                dones, infos)``. Every element except ``action`` is indexed
                per level; ``action`` is shared by all levels.
        """
        states, action, rewards = transition[0], transition[1], transition[2]
        next_states, dones, infos = transition[3], transition[4], transition[5]
        for level in range(self.num_levels):
            self.batch[level].append((
                states[level],
                action,
                rewards[level],
                next_states[level],
                dones[level],
                infos[level],
            ))

    def clear_batch(self):
        """Empty every level's batch (intended to run after each episode)."""
        for level in range(self.num_levels):
            self.batch[level] = []

    def update(self, global_steps, successful_reward, alpha, gamma, rho):
        """Update each level's Q-function from that level's batch.

        ``global_steps`` and ``successful_reward`` are accepted but unused.
        ``alpha``, ``gamma`` and ``rho`` are forwarded unchanged to each
        ``QFunction.update`` call.
        """
        for level in range(self.num_levels):
            self.qs[level].update(self.batch[level], alpha, gamma, rho)

    def save_loop_counter(self, steps):
        """Save each level's loop counter to ``loop_hashtable_s<i>.p``.

        ``steps`` is unused.
        """
        # NOTE(review): ``self.cs`` (cycle functions) is never assigned in
        # this class, so this raises AttributeError unless it is set
        # elsewhere — confirm whether cycle functions are still supported.
        for level in range(self.num_levels):
            self.cs[level].loop_counter.save("loop_hashtable_s" + str(level) + ".p")

    def get_transfer_table(self):
        """Return each level's ``ttable`` (one entry per level) for transfer learning."""
        # NOTE(review): this reads ``.ttable`` while set_transfer_table writes
        # ``.qtable`` — confirm the asymmetry is intended.
        return [self.qs[level].ttable for level in range(self.num_levels)]

    def set_transfer_table(self, ttables):
        """Install ``ttables[i]`` as level ``i``'s Q-table when loading a model."""
        for level in range(self.num_levels):
            self.qs[level].qtable = ttables[level]
            
            
