# -*- coding: UTF-8 -*-
import numpy as np
import random

from agents.agent import Agent


class DQN_GPI(Agent):
    """
    DQN agent that keeps one online Q-network and one target Q-network per
    training task, and acts greedily with respect to the element-wise maximum
    of all task Q-networks (generalized policy improvement, GPI).
    """

    def __init__(self, model_lambda, buffer, *args, target_update_ev=1000, test_epsilon=0.03, include_target=True, **kwargs):
        """
        Creates a new DQN agent that supports universal value function approximation (UVFA)
        and generalized policy improvement (GPI).

        NOTE(review): supporting UVFA may not be meaningful when GPI is also used.

        Parameters
        ----------
        model_lambda : function
            called as model_lambda(include_target); returns a keras Model instance
        buffer : ReplayBuffer
            a replay buffer that implements randomized experience replay
        target_update_ev : integer
            how often (in training updates) to copy the online weights into the
            target networks (defaults to 1000)
        test_epsilon : float
            the exploration parameter for epsilon greedy used during testing
            (defaults to 0.03 as in the paper)
        include_target : bool
            whether the model input includes the goal locations of tasks;
            forwarded to model_lambda (defaults to True as in the paper)
        *args, **kwargs
            forwarded to Agent.__init__
        """
        super(DQN_GPI, self).__init__(*args, **kwargs)
        self.model_lambda = model_lambda
        self.buffer = buffer
        self.target_update_ev = target_update_ev
        self.test_epsilon = test_epsilon
        # allows DQN to be used without target locations in the model input
        self.include_target = include_target
        # one online and one target Q-network per training task (needed for GPI);
        # populated by add_training_task
        self.Qs = []
        self.target_Qs = []

    def reset(self):
        """
        Resets the agent: base-class state, the per-task Q-networks, the replay
        buffer and the target-update counter.
        """
        Agent.reset(self)
        # Q-networks are created per task in add_training_task (train() calls
        # reset() and then re-adds every task), so drop the networks of any
        # previous run; otherwise GPI would maximize over stale networks.
        # NOTE(review): assumes Agent.reset clears self.tasks so that self.Qs
        # stays aligned with self.tasks / self.n_tasks -- confirm.
        self.Qs = []
        self.target_Qs = []
        self.buffer.reset()
        self.updates_since_target_updated = 0

    def get_GPI_preds(self, s, s_enc):
        """
        Returns the predictions of every task's online Q-network on s_enc.

        Parameters
        ----------
        s : object
            unused; kept for interface compatibility
        s_enc : array
            the encoded state batch fed to each model's predict_on_batch

        Returns
        -------
        np.ndarray
            predictions stacked along a new leading axis (one entry per task
            Q-network, in the order the tasks were added)
        """
        return np.array([Q_.predict_on_batch(s_enc) for Q_ in self.Qs])

    def get_Q_values(self, s, s_enc):
        """
        Returns the GPI action values: the element-wise maximum over all task
        Q-networks of their predictions on s_enc (s is unused).
        """
        return np.max(self.get_GPI_preds(s, s_enc), axis=0)  # shape (n_batch, n_action)

    def update_agent(self, states, actions, rewards, next_states, next_actions, gammas, indices, task_index):
        """
        Performs one training step on the online Q-network of task task_index.

        Targets for the actions not taken are the network's current predictions
        (hence zero loss); the target for each taken action is
        rewards + gammas * Q_target(next_states)[next_actions].
        """
        online_Q = self.Qs[task_index]
        targets = online_Q.predict_on_batch(states)
        next_values = self.target_Qs[task_index].predict_on_batch(next_states)
        targets[indices, actions] = rewards + gammas * next_values[indices, next_actions]
        online_Q.train_on_batch(states, targets)

    def train_agent(self, s, s_enc, a, r, s1, s1_enc, gamma):
        """
        Stores one transition, samples a minibatch from the replay buffer and, if
        one is available, updates every task's Q-network on it; periodically
        syncs the target networks.
        """
        # remember this experience
        self.buffer.append(s_enc, a, r, s1_enc, gamma)

        # sample experience at random; replay() returns None when no batch is available
        batch = self.buffer.replay()
        if batch is None:
            return
        states, actions, rewards, next_states, gammas = batch
        indices = np.arange(self.buffer.n_batch)
        rewards = rewards.flatten()

        # next actions are greedy w.r.t. the GPI maximum over all task Q-networks
        # (0 is a dummy value for the unused raw-state argument)
        next_actions = np.argmax(self.get_Q_values(0, next_states), axis=1)

        # main update: for each task, the targets of the unchanged actions equal the
        # current Q predictions (hence zero loss), and the targets of the chosen
        # actions follow the usual Q-learning update.
        # NOTE(review): every task's network is trained on the same sampled rewards
        # (those stored via train_agent, presumably from the active task); for GPI
        # each Q_i would normally be trained on task i's own reward -- confirm intent.
        for task_index in range(self.n_tasks):
            self.update_agent(states, actions, rewards, next_states, next_actions, gammas, indices, task_index)

        # target update: copy every online network into its target network
        self.updates_since_target_updated += 1
        if self.updates_since_target_updated >= self.target_update_ev:
            for task_index in range(self.n_tasks):
                self.target_Qs[task_index].set_weights(self.Qs[task_index].get_weights())
            self.updates_since_target_updated = 0

    def train(self, train_tasks, n_samples, viewers=None, n_view_ev=None, test_tasks=None, n_test_ev=1000, test_task_idx=None):
        """
        Trains the agent on each training task in turn, testing it periodically.

        Parameters
        ----------
        train_tasks : list
            tasks trained one after another; each also gets its own Q-networks
        n_samples : int
            number of calls to next_sample per training task
        viewers : list or None
            per-task viewers passed to next_sample (defaults to None for every task)
        n_view_ev : int or None
            passed through to next_sample
        test_tasks : list or None
            tasks used for evaluation (defaults to an empty list)
        n_test_ev : int
            the agent is evaluated whenever the sample counter is a multiple of n_test_ev
        test_task_idx : int or None
            if None, every test task is evaluated, the per-task returns are printed
            and their mean is recorded; otherwise only test_tasks[test_task_idx]
            is evaluated (index 0 included) and its return is recorded

        Returns
        -------
        list
            the recorded test returns, in evaluation order
        """
        if viewers is None:
            viewers = [None] * len(train_tasks)
        if test_tasks is None:
            test_tasks = []

        # add tasks (each one gets a fresh pair of Q-networks)
        self.reset()
        for train_task in train_tasks:
            self.add_training_task(train_task)

        # train each one
        return_data = []
        for index, (_, viewer) in enumerate(zip(train_tasks, viewers)):
            self.set_active_training_task(index)
            for t in range(n_samples):

                # train
                self.next_sample(viewer, n_view_ev)

                # test
                if t % n_test_ev != 0:
                    continue
                # compare against None explicitly: test_task_idx == 0 is a valid index
                if test_task_idx is None:
                    Rs = [self.test_agent(test_task) for test_task in test_tasks]
                    return_data.append(np.mean(Rs))
                    print('test performance: {}'.format('\t'.join(map('{:.4f}'.format, Rs))))
                else:
                    R = self.test_agent(test_tasks[test_task_idx])
                    return_data.append(R)
                    print('test performance: {}'.format(R))
        return return_data

    def get_test_action(self, s_enc):
        """
        Epsilon-greedy action selection w.r.t. the GPI action values, using
        self.test_epsilon as the exploration probability.
        """
        if random.random() <= self.test_epsilon:
            a = random.randrange(self.n_actions)
        else:
            q = self.get_Q_values(s_enc, s_enc)
            a = np.argmax(q)
        return a

    def test_agent(self, task):
        """
        Runs one test episode of at most self.T steps on task and returns the
        sum of rewards collected.
        """
        R = 0.
        s_enc = self.encoding(task.initialize())
        for _ in range(self.T):
            a = self.get_test_action(s_enc)
            s1, r, done = task.transition(a)
            s_enc = self.encoding(s1)
            R += r
            if done:
                break
        return R

    def add_training_task(self, task):
        """
        Adds a training task to be trained by the agent.

        Overrides the default method so that GPI can be implemented: every added
        task gets its own online Q-network and a target Q-network initialized
        with the same weights.
        """
        self.tasks.append(task)
        self.n_tasks = len(self.tasks)
        Q = self.model_lambda(self.include_target)
        target_Q = self.model_lambda(self.include_target)
        target_Q.set_weights(Q.get_weights())
        self.Qs.append(Q)
        self.target_Qs.append(target_Q)
        # task.features generates phi at each state; per the original author, for
        # the reacher task it is independent of the target location (same for all tasks)
        self.phis.append(task.features)
        if self.n_tasks == 1:
            self.n_actions = task.action_count()
            self.n_features = task.feature_dim()
            if self.encoding == 'task':
                self.encoding = task.encode
