import numpy as np
import time
import collections
import pickle

from scipy.special import softmax
from loggers.logger import Logger

"""
    This class wraps the main loop and interventional loop of the agent. 
"""

class Runner:
    """Drive an agent through episodes of a multi-view environment.

    Wraps the per-episode interaction loop (action selection, environment
    stepping, transition storage and agent updates) together with a per-view
    visit count of every observed state, which is used to weight the views'
    value functions when choosing greedy actions.
    """

    def __init__(self, num_views, logger=None):
        """Create a runner for an environment exposing ``num_views`` views.

        Args:
            num_views: Number of views; states are indexed
                ``state[0] .. state[num_views - 1]``.
            logger: Optional object with a ``log(log_name, total_steps,
                reward, total_eps)`` method. Defaults to a new ``Logger()``.
        """
        # When True, explore_and_navigate may take epsilon-random actions.
        self.explore = True

        self.total_ep_successful = 0
        # Value of total_ep_successful at the last 10-episode report.
        self.last_ep_successful = 0

        # Both must be assigned (self.env directly, the agent via
        # change_agent) before explore_and_navigate is called.
        self.agent = None
        self.env = None

        # One independent counter per view. ``[{}] * num_views`` would alias a
        # single dict across all views and mix their visit counts together.
        self.tc_counts = [collections.Counter() for _ in range(num_views)]
        self.num_views = num_views
        # Largest visit count seen so far in each view (normalisation factor).
        self.max_values = [0] * num_views

        # reward[0] of each episode's final step since the last report.
        self.mean_rewards = []

        self.logger = Logger() if logger is None else logger

    def change_agent(self, agent):
        """Replace the agent used by subsequent episodes."""
        self.agent = agent

    def inc_tc_count(self, state):
        """Record one visit of ``state[i]`` in view ``i``'s counter, for every view.

        Also keeps ``max_values[i]`` equal to the largest count in view ``i``
        so that ``calc_alphas`` can normalise counts.
        """
        for i, counts in enumerate(self.tc_counts):
            key = state[i]
            counts[key] += 1
            if counts[key] > self.max_values[i]:
                self.max_values[i] = counts[key]

    def calc_alphas(self, state):
        """Return one weight per view: ``1 - count(state[i]) / max_values[i]``.

        The most-visited state of a view gets 0 and rarely visited states
        approach 1. A state never counted in a view gets 1.0, as does any view
        with no recorded visits (avoiding a division by zero).
        """
        alphas = []
        for i, counts in enumerate(self.tc_counts):
            max_count = self.max_values[i]
            if max_count == 0:
                alphas.append(1.0)
            else:
                # .get keeps lookups side-effect free and tolerates unseen states.
                alphas.append(1 - counts.get(state[i], 0) / max_count)
        return alphas

    def _select_action(self, state, epsilon):
        """Pick an epsilon-greedy action over the view-weighted value estimate."""
        # The uniform draw happens on every step (before checking self.explore),
        # so the global NumPy RNG stream is consumed identically either way.
        if np.random.uniform(0, 1) < epsilon and self.explore:
            return np.random.randint(self.env.get_action_space())
        # TODO switch between hierarchical and normal (e.g. fixed view weights
        # [1.0, 0, 0, 0, 0] instead of count-based ones).
        weights = softmax(self.calc_alphas(state))
        return np.argmax(np.array(self.agent.evaluate(state, weights)))

    def _report_progress(self, log_name, total_steps, total_ep_successful):
        """Print the 10-episode summary and reset the reporting window."""
        print(self.mean_rewards)
        print("UPDATE EVERY 10 EPSODES - EPISODES SUCCESSFUL: ", str(total_ep_successful-self.last_ep_successful) +
              " out of 10" + "  |  " +
              "MEAN REWARD: " + str(np.mean(self.mean_rewards)) + "  |  " +
              "TOTAL STEPS: " + str(total_steps) + "  |  " +
              "LOG PATH: " + str(log_name))
        self.mean_rewards = []
        self.last_ep_successful = total_ep_successful

    def explore_and_navigate(self, log_name, total_steps, total_eps, total_ep_successful, alpha, gamma, epsilon, rho,
                             exp_thresh, run_num):
        """Run one episode, updating the agent after every step.

        Args:
            log_name: Destination passed to ``self.logger.log``; ``None``
                disables file logging.
            total_steps: Step count reported in the log and console summary.
            total_eps: Episode counter; logging and the console summary happen
                when it is a multiple of 10.
            total_ep_successful: Running count of successful episodes; passed
                to ``agent.update`` and compared against ``exp_thresh``.
            alpha, gamma, rho: Forwarded unchanged to ``agent.update``.
            epsilon: Probability of a uniformly random action while
                ``self.explore`` is True.
            exp_thresh: Exploration is switched off once
                ``total_ep_successful`` exceeds this value.
            run_num: Unused in this method.

        Returns:
            ``(global_steps, ep_successful)``: the number of steps taken in
            this episode, and 1 if the final ``reward[0]`` was positive else 0.
        """
        self.total_ep_successful = total_ep_successful
        global_steps = 0

        # Reset the agent's history and the environment.
        self.agent.clear_batch()
        state, _obs = self.env.reset()

        done = 0
        while np.sum(done) == 0:
            # Count the visit first so calc_alphas never sees a zero maximum.
            self.inc_tc_count(state)
            action = self._select_action(state, epsilon)

            next_state, reward, done, info, _next_obs = self.env.step(action)

            self.agent.store_transition((state, action, reward, next_state, done, info))
            self.agent.update(global_steps, total_ep_successful, alpha, gamma, rho)

            state = next_state
            global_steps += 1

        self.mean_rewards.append(reward[0])

        # Extra end-of-episode transition: no successor state, every view done.
        self.agent.store_transition((state, action, reward, [None] * self.num_views, [True] * self.num_views, info))

        if log_name is not None and total_eps % 10 == 0:
            self.logger.log(log_name, total_steps, reward[0], total_eps)

        # Stop exploring after threshold.
        # NOTE(review): nothing here re-enables exploration once it is off. A
        # former `elif total_ep_successful > exp_thresh and reward[0] <= 0`
        # branch meant to do so was unreachable (its condition implies this
        # `if`) and was removed; settle the intended rule before adding one.
        if total_ep_successful > exp_thresh:
            self.explore = False

        ep_successful = 1 if reward[0] > 0 else 0

        if total_eps % 10 == 0:
            self._report_progress(log_name, total_steps, total_ep_successful)
        return global_steps, ep_successful
