from collections import defaultdict
import numpy as np
import random
from scipy.special import softmax

from agents.agent import Agent


class PRQL(Agent):
    """Tabular Q-learning agent that reuses the greedy policies of earlier tasks.

    One Q-table and one "controllability" table are kept per training task.
    At the start of each episode a source policy index ``c`` is sampled from a
    softmax over per-policy scores (mean undiscounted episode return observed
    while that index was selected). While acting on the active task, if ``c``
    refers to a different task, that task's greedy action is taken with
    probability ``eta``; otherwise the agent acts epsilon-greedily on its own
    values.

    Args:
        alpha: learning rate of the Q-value TD update.
        beta: multiplier on ``alpha`` giving the controllability learning rate.
        omega: weight of the controllability table added to Q in
            ``get_Q_values``.
        tau: inverse temperature of the softmax over policy scores.
        eta: per-step probability of following the sampled past policy.
        noise_init: callable ``size -> array`` producing initial Q-values for
            states not yet seen.
        *args, **kwargs: forwarded to ``Agent``.
    """

    def __init__(self, alpha, beta, omega, tau, eta, *args,
                 noise_init=lambda size: np.random.uniform(low=-0.01, high=0.01, size=size), **kwargs):
        super().__init__(*args, **kwargs)
        # learning rates
        self.alpha = alpha
        self.beta = beta
        self.alpha2 = alpha * beta
        # action-value shaping / policy-reuse knobs
        self.omega = omega
        self.tau = tau
        self.eta = eta
        self.noise_init = noise_init
        self.key = 'prql'

    def reset(self):
        """Reset base-agent state and drop all per-task value tables."""
        super().reset()
        self.Qs, self.Cs = [], []

    def add_training_task(self, task):
        """Register a task and allocate fresh Q / controllability tables for it.

        Policy scores and usage counts are re-created (zeroed) for all tasks
        known so far.
        """
        super().add_training_task(task)

        # Defaults are computed lazily on first access to a state, so
        # noise_init and n_actions are read at that moment.
        def _fresh_q_row():
            return self.noise_init((self.n_actions,))

        def _fresh_c_row():
            return np.zeros((self.n_actions,))

        self.Qs.append(defaultdict(_fresh_q_row))
        self.Cs.append(defaultdict(_fresh_c_row))
        self.score = np.zeros(self.n_tasks, dtype=float)
        self.used = np.zeros(self.n_tasks, dtype=int)

    def set_active_training_task(self, index):
        """Switch to task ``index``; the first episode uses that task's own policy."""
        super().set_active_training_task(index)
        self.curr_score = 0.
        self.c = index

    def train_agent(self, s, s_enc, a, r, s1, s1_enc, gamma, noise):
        """Apply one Q-learning update plus the controllability update.

        ``s_enc``, ``s1_enc`` and ``noise`` are accepted for interface
        compatibility and are not used here.
        """
        q_table = self.Qs[self.task_index]
        c_table = self.Cs[self.task_index]

        # Q-learning TD step (the next state's row is looked up first).
        next_best = np.max(q_table[s1])
        q_row = q_table[s]
        td_error = (r + gamma * next_best) - q_row[a]
        q_row[a] += self.alpha * td_error

        # Exponential moving average of -|TD error| at rate alpha * beta.
        c_row = c_table[s]
        c_row[a] += self.alpha2 * (-abs(td_error) - c_row[a])

    def get_Q_values(self, s, s_enc, index):
        """Return ``Q_index[s] + omega * C_index[s]``.

        Unseen states are inserted into both tables of task ``index`` on
        access (defaultdict semantics).
        """
        q_vals = self.Qs[index][s]
        bonus = self.Cs[index][s]
        return q_vals + self.omega * bonus

    def _prql_start_episode(self):
        """Reset the environment and episode counters, then re-sample the reused policy."""
        self.s = self.active_task.initialize()
        self.s_enc = self.encoding(self.s)
        self.new_episode = False
        self.episode += 1
        self.steps_since_last_episode = 0
        self.episode_fails, self.fails_since_last_episode = self.fails_since_last_episode, 0
        self.episode_reward, self.reward_since_last_episode = self.reward_since_last_episode, 0.

        if self.episode > 1:
            self.episode_reward_hist.append(self.episode_reward)

            # Fold the finished episode's return into the running mean score of
            # the policy that was in use, then draw the next policy index.
            prev = self.c
            n_prev = self.used[prev]
            self.score[prev] = (self.score[prev] * n_prev + self.curr_score) / (n_prev + 1)
            probs = softmax(self.tau * self.score)
            self.c = np.random.choice(np.arange(self.score.size), p=probs)
            self.used[prev] += 1

        self.curr_score = 0.

    def _prql_choose_action(self):
        """Follow the sampled past policy greedily w.p. eta, else act epsilon-greedily."""
        # random.random() is only drawn when a different task's policy is selected.
        reuse = self.task_index != self.c and random.random() <= self.eta
        if reuse:
            return np.argmax(self.get_Q_values(self.s, self.s_enc, self.c))
        return self._epsilon_greedy(self.get_Q_values(self.s, self.s_enc, self.task_index))

    def next_sample(self, viewer=None, n_view_ev=None):
        """Run a single environment step on the active task and learn from it.

        Args:
            viewer: optional object whose ``update()`` is called every
                ``n_view_ev`` episodes.
            n_view_ev: episode period for viewer updates (required when
                ``viewer`` is given).
        """
        if self.new_episode:
            self._prql_start_episode()

        action = self._prql_choose_action()

        # Interact with the environment.
        s1, r, terminal, noise = self.active_task.transition(action)
        s1_enc = self.encoding(s1)
        gamma = 0. if terminal else self.gamma
        if terminal:
            self.new_episode = True

        self.train_agent(self.s, self.s_enc, action, r, s1, s1_enc, gamma, noise)
        self.s, self.s_enc = s1, s1_enc

        # Step and reward bookkeeping.
        self.steps += 1
        self.global_steps += 1
        self.steps_since_last_episode += 1
        self.reward += r
        self.reward_since_last_episode += r
        self.cum_reward += r

        # Time-limit cut-off: gamma was not zeroed above, so the last
        # transition still bootstrapped.
        if self.steps_since_last_episode >= self.T:
            self.new_episode = True

        if noise[0]:
            self.fails += 1
            self.cum_fails += 1
            self.fails_since_last_episode += 1

        if self.steps % self.save_ev == 0:
            self.reward_hist.append(self.reward)
            self.fails_hist.append(self.fails)
            self.cum_reward_hist.append(self.cum_reward)
            self.cum_fails_hist.append(self.cum_fails)

        if viewer is not None and self.episode % n_view_ev == 0:
            viewer.update()

        if self.steps % self.print_ev == 0:
            print('\t'.join(self.get_progress_strings()))

        # Undiscounted return of the current episode, consumed at the next
        # episode start to update the reused policy's score.
        self.curr_score += r
