import numpy as np
import random

from agents.agent import Agent


class SFQL(Agent):
    """Agent that acts and learns through a successor-feature table.

    All value estimates are delegated to ``lookup_table`` (stored as
    ``self.sf``), which must provide ``GPI``, ``GPE``, ``GPI_w``,
    ``update_reward``, ``update_successor``, ``add_training_task``,
    ``reset``, ``GPI_usage_percent``, ``fit_w`` and ``true_w``.

    Attributes set by this class:
        sf: the successor-feature table passed to the constructor.
        use_gpi: whether the policy index returned by ``sf.GPI`` is used
            (True) or the current task's own policy is forced (False).
        key: identifier string, always ``'sfql'``.
        c: index of the policy chosen by the most recent ``get_Q_values``
            call. It is only defined after that method has run once.
        visits: zero-filled grid of shape ``GRID_SHAPE``, recreated each
            time ``add_training_task`` is called.
    """

    # Shape of every visit-count grid built by this class. Kept in one place
    # so a subclass (or an instance) can override it for another grid size.
    GRID_SHAPE = (13, 13)

    def __init__(self, lookup_table, *args, use_gpi=True, **kwargs):
        """Store the table and forward all remaining arguments to ``Agent``.

        Args:
            lookup_table: successor-feature table used for every estimate.
            *args: positional arguments forwarded to ``Agent.__init__``.
            use_gpi: keyword-only flag, see the class docstring.
            **kwargs: keyword arguments forwarded to ``Agent.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.sf = lookup_table
        self.use_gpi = use_gpi
        self.key = 'sfql'

    def get_Q_values(self, s, s_enc):
        """Return the Q-values of the selected policy for ``s_enc``.

        Side effect: stores the selected policy index in ``self.c``, which
        ``train_agent`` reads afterwards.

        Args:
            s: raw state (unused here; kept for the ``Agent`` interface).
            s_enc: encoded state handed to ``sf.GPI``.

        Returns:
            ``q[:, self.c, :]`` where ``q`` is the first value returned by
            ``sf.GPI``.
        """
        # NOTE(review): the indexing suggests q is (batch, task, action) -
        # confirm against the table implementation.
        q, self.c = self.sf.GPI(s_enc, self.task_index, update_counters=self.use_gpi)
        if not self.use_gpi:
            # Without GPI the agent always follows the current task's policy.
            self.c = self.task_index
        return q[:, self.c, :]

    def train_agent(self, s, s_enc, a, r, s1, s1_enc, gamma, noise):
        """Update reward weights and successor features from one transition.

        Reads ``self.c``, so ``get_Q_values`` must have been called before.

        Args:
            s, s_enc: raw and encoded state before the action.
            a: action taken.
            r: reward received.
            s1, s1_enc: raw and encoded state after the action.
            gamma: discount value stored in the transition tuple.
            noise: forwarded to ``self.phi`` together with ``s``, ``a``, ``s1``.
        """
        # update w
        t = self.task_index
        phi = self.phi(s, a, s1, noise)
        self.sf.update_reward(phi, r, t)

        # update SF for the current task t
        if self.use_gpi:
            q1, _ = self.sf.GPI(s1_enc, t)
            # Maximum over the first remaining axis of the first batch entry,
            # leaving one value per index of the last axis.
            q1 = np.max(q1[0, :, :], axis=0)
        else:
            q1 = self.sf.GPE(s1_enc, t, t)[0, :]
        a1 = np.argmax(q1)
        transitions = [(s_enc, a, phi, s1_enc, a1, gamma)]
        self.sf.update_successor(transitions, t)

        # update SF for source task c
        if self.c != t:
            # argmax without an axis works on the flattened array, so no
            # explicit batch index is taken here.
            q1 = self.sf.GPE(s1_enc, self.c, self.c)
            next_action = np.argmax(q1)
            transitions = [(s_enc, a, phi, s1_enc, next_action, gamma)]
            self.sf.update_successor(transitions, self.c)

    def next_sample(self, viewer=None, n_view_ev=None):
        """Delegate to ``Agent.next_sample`` with the same arguments."""
        super().next_sample(viewer, n_view_ev)

    def reset(self):
        """Reset the base agent and then the successor-feature table."""
        super().reset()
        self.sf.reset()

    def add_training_task(self, task):
        """Register ``task`` with the base agent and with the table.

        Also replaces ``self.visits`` with a zero grid of ``GRID_SHAPE``.
        """
        super().add_training_task(task)
        self.sf.add_training_task(task, -1)
        self.visits = np.zeros(self.GRID_SHAPE)

    def get_progress_strings(self):
        """Return the base progress strings plus a GPI / weight-error string.

        Returns:
            Tuple ``(sample_str, reward_str, gpi_str)``. ``gpi_str`` reports
            ``sf.GPI_usage_percent`` for the current task and the norm of
            ``fit_w - true_w`` for the current task.
        """
        sample_str, reward_str = super().get_progress_strings()
        gpi_percent = self.sf.GPI_usage_percent(self.task_index)
        w_error = np.linalg.norm(self.sf.fit_w[self.task_index] - self.sf.true_w[self.task_index])
        gpi_str = 'GPI% \t {:.4f} \t WERR \t {:.4f}'.format(gpi_percent, w_error)
        return sample_str, reward_str, gpi_str

    def get_test_action(self, s_enc, w):
        """Pick an epsilon-greedy action for reward weights ``w``.

        With probability ``self.epsilon`` a uniformly random action is
        returned; otherwise the argmax of the Q-values from ``sf.GPI_w``.
        """
        if random.random() <= self.epsilon:
            return random.randrange(self.n_actions)
        q, c = self.sf.GPI_w(s_enc, w)
        return np.argmax(q[:, c, :])

    def rollout(self, w):
        """Run one episode on ``self.active_task`` following ``get_test_action``.

        The episode stops after ``self.T`` steps, when a state repeats, or
        when the task reports ``done``.

        Returns:
            Tuple ``(visits, states)``. ``visits`` is a ``GRID_SHAPE`` array
            of counts indexed by ``s[0][0], s[0][1]``. ``states`` lists the
            distinct states in visiting order, starting with the initial one.
        """
        s = self.active_task.initialize()
        visits = np.zeros(self.GRID_SHAPE)
        seen = {s}
        states = [s]
        for _ in range(self.T):
            # NOTE(review): the raw state is passed where get_test_action
            # names its parameter s_enc - confirm the two coincide here.
            a = self.get_test_action(s, w)
            s, _, done, _ = self.active_task.transition(a)
            # The count is taken before the repeat check, so the state that
            # ends the episode by repeating is still counted.
            visits[s[0][0], s[0][1]] += 1
            if s in seen:
                break
            seen.add(s)
            states.append(s)
            if done:
                break
        return visits, states

    def rollouts(self, w, n_eps=1):
        """Average the visit grids of ``n_eps`` rollouts.

        Returns:
            Tuple ``(visits, all_states)``: the mean visit grid and the
            concatenated state lists of all rollouts.
        """
        visits = np.zeros(self.GRID_SHAPE)
        all_states = []
        for _ in range(n_eps):
            new_visits, new_states = self.rollout(w)
            visits += new_visits / float(n_eps)
            all_states.extend(new_states)
        return visits, all_states
