import numpy as np
import random
import tensorflow as tf

from agents.agent import Agent


class BaseC51(Agent):
    """Categorical (C51) distributional DQN agent trained on a sequence of tasks.

    The return distribution of each action is represented as probabilities over
    ``n_atoms`` evenly spaced atoms ``self.z`` spanning ``[v_min, v_max]``. For
    every training task a fresh online network ``self.Q``, a target network
    ``self.target_Q`` and an emptied replay buffer are set up
    (see ``add_training_task``).

    Args:
        model_lambda: zero-argument callable building a new Keras-style model
            whose ``predict_on_batch`` output is indexed as
            ``[n_batch, n_actions, n_atoms]``.
        buffer: replay buffer exposing ``reset()``,
            ``append(s_enc, a, r, s1_enc, gamma)``, ``replay()`` (returns
            ``None`` when no batch is available) and ``n_batch``.
        v_min, v_max: bounds of the atom support.
        target_update_ev: number of gradient updates between copies of the
            online weights into the target network.
        n_atoms: number of atoms in the support.
        test_epsilon: exploration rate used by ``get_test_action``.
        test_frequency: run test rollouts every this many training samples.
        test_rollouts: number of test episodes averaged per evaluation.
        *args, **kwargs: forwarded to ``Agent.__init__``.
    """

    def __init__(self, model_lambda, buffer, v_min, v_max, target_update_ev, *args,
                 n_atoms=51, test_epsilon=0.03, test_frequency=10000, test_rollouts=1, **kwargs):
        super(BaseC51, self).__init__(*args, **kwargs)
        self.model_lambda = model_lambda
        self.buffer = buffer
        self.v_min = v_min
        self.v_max = v_max
        self.target_update_ev = target_update_ev
        self.n_atoms = n_atoms
        self.test_epsilon = test_epsilon
        self.test_frequency = test_frequency
        self.test_rollouts = test_rollouts
        self.key = 'base'

        # evenly spaced support: z_i = v_min + i * delta_z, i = 0 .. n_atoms - 1
        self.delta_z = (self.v_max - self.v_min) / float(self.n_atoms - 1)
        self.z = self.v_min + np.arange(self.n_atoms) * self.delta_z  # [n_atoms]

    def reset(self):
        """Reset base-agent state, empty the replay buffer and clear per-task training histories."""
        super(BaseC51, self).reset()
        self.buffer.reset()
        self.episode_reward_hist_per_task = []
        self.episode_fails_hist_per_task = []

    def add_training_task(self, task):
        """Register ``task`` and start it with freshly built online/target networks.

        The Keras session is cleared and the replay buffer emptied, so nothing
        learned on earlier tasks carries over into the new networks.
        """
        super(BaseC51, self).add_training_task(task)
        tf.keras.backend.clear_session()
        self.buffer.reset()
        self.Q = self.model_lambda()
        self.target_Q = self.model_lambda()
        self.target_Q.set_weights(self.Q.get_weights())
        self.updates_since_target_updated = 0
        self.episode_reward_hist_per_task.append([])
        self.episode_fails_hist_per_task.append([])

    def next_sample(self, viewer=None, n_view_ev=None):
        """Advance one environment step via ``Agent.next_sample``.

        Every ``save_ev`` steps, the current episode reward and failure count
        are recorded for the active task.
        """
        super(BaseC51, self).next_sample(viewer, n_view_ev)
        if self.steps % self.save_ev == 0:
            self.episode_reward_hist_per_task[self.task_index].append(self.episode_reward)
            self.episode_fails_hist_per_task[self.task_index].append(self.episode_fails)

    def get_Q_values(self, s, s_enc):
        """Return expected returns sum_i p_i(s, a) * z_i for every action.

        ``s`` is unused; only the encoded batch ``s_enc`` is fed to the online
        network. The result has shape ``[n_batch, n_actions]``.
        """
        p = self.Q.predict_on_batch(s_enc)  # [n_batch, n_actions, n_atoms]
        atoms = np.expand_dims(self.z, axis=(0, 1))
        return np.sum(p * atoms, axis=-1)

    def train_agent(self, s, s_enc, a, r, s1, s1_enc, gamma, failed):
        """Store one transition and perform one C51 update on a replayed batch.

        The target is built Double-DQN style: the greedy next action is chosen
        with the online network, and its return distribution is read from the
        target network. That distribution, shifted to ``r + gamma * z`` and
        clipped to ``[v_min, v_max]``, is projected back onto the fixed support
        ``self.z``. It becomes the training target for the action actually
        taken. All other actions get an all-zero target row. (Assuming the model
        is compiled with a cross-entropy loss, those rows contribute no
        gradient — TODO confirm against ``model_lambda``.)

        Only ``s_enc``, ``a``, ``r``, ``s1_enc`` and ``gamma`` are used.
        """
        # remember this experience
        self.buffer.append(s_enc, a, r, s1_enc, gamma)

        # sample experience at random; nothing to learn until a batch is available
        batch = self.buffer.replay()
        if batch is None:
            return
        states, actions, rewards, next_states, gammas = batch
        rewards = np.reshape(rewards, (-1, 1))
        gammas = np.reshape(gammas, (-1, 1))
        # size from the sampled batch itself rather than trusting buffer.n_batch
        n_batch = rewards.shape[0]
        idx = np.arange(n_batch)

        # greedy next actions from the online network
        q_next = self.get_Q_values(next_states, next_states)  # [n_batch, n_actions]
        next_actions = np.argmax(q_next, axis=1)

        # distribution of the greedy next action under the target network
        p_next = self.target_Q.predict_on_batch(next_states)  # [n_batch, n_actions, n_atoms]
        p_next_a = p_next[idx, next_actions]  # [n_batch, n_atoms]

        # Bellman-shift the support and express it as fractional atom indices b
        Tz = np.clip(rewards + gammas * self.z.reshape((1, -1)), self.v_min, self.v_max)
        # clip b as well: float rounding in (Tz - v_min) / delta_z can land a hair
        # above n_atoms - 1, whose ceil would index past the last atom
        bj = np.clip((Tz - self.v_min) / self.delta_z, 0, self.n_atoms - 1)
        m_l = np.floor(bj).astype(int)
        m_u = np.ceil(bj).astype(int)

        # linear interpolation onto the two neighbouring atoms. When b is exactly
        # an atom (m_l == m_u, e.g. Tz clipped to v_min/v_max) both weights would
        # be zero and the mass would be lost, so give all of it to that atom.
        on_atom = (m_l == m_u)
        mass_l = p_next_a * (m_u - bj + on_atom)
        mass_u = p_next_a * (bj - m_l)

        # scatter into the target row of the action taken; np.add.at accumulates
        # correctly when several shifted atoms map onto the same index
        m_prob = np.zeros_like(p_next)
        rows = np.repeat(idx, self.n_atoms)
        cols = np.repeat(actions, self.n_atoms)
        np.add.at(m_prob, (rows, cols, m_l.ravel()), mass_l.ravel())
        np.add.at(m_prob, (rows, cols, m_u.ravel()), mass_u.ravel())

        # fit model
        self.Q.train_on_batch(states, m_prob)

        # periodically copy the online weights into the target network
        self.updates_since_target_updated += 1
        if self.updates_since_target_updated >= self.target_update_ev:
            self.target_Q.set_weights(self.Q.get_weights())
            self.updates_since_target_updated = 0

    def train(self, train_tasks, n_samples, viewers=None, n_view_ev=None, test_tasks=None, **kwargs):
        """Train on every task in ``train_tasks + test_tasks`` in turn.

        Each task is trained for ``n_samples`` samples. Every
        ``test_frequency`` samples, ``test_rollouts`` test episodes are run on
        a clone of the current task. Their mean reward, reward variance and
        mean failure count are appended to ``test_reward_hist_per_task``,
        ``test_reward_var_hist_per_task`` and ``test_fails_hist_per_task``,
        indexed by task position.

        Args:
            train_tasks: tasks to train on, in order.
            n_samples: number of samples per task.
            viewers: optional per-task viewer passed to ``next_sample``.
            n_view_ev: forwarded to ``next_sample``.
            test_tasks: optional extra tasks, trained after ``train_tasks``.
            **kwargs: accepted for interface compatibility; unused.
        """
        # None instead of a shared mutable [] default
        tasks = list(train_tasks) + ([] if test_tasks is None else list(test_tasks))
        if viewers is None:
            viewers = [None] * len(tasks)

        self.reset()

        # per-task test statistics; evaluation runs on clones of the tasks
        # (presumably so test rollouts do not disturb the training task's state)
        eval_tasks = [task.clone() for task in tasks]
        self.test_reward_hist_per_task = [[] for _ in eval_tasks]
        self.test_reward_var_hist_per_task = [[] for _ in eval_tasks]
        self.test_fails_hist_per_task = [[] for _ in eval_tasks]

        for index, (train_task, viewer) in enumerate(zip(tasks, viewers)):
            # add and activate task
            self.add_training_task(train_task)
            self.set_active_training_task(index)

            for t in range(n_samples):
                self.next_sample(viewer, n_view_ev)
                if t % self.test_frequency == 0:
                    self._run_test_rollouts(index, eval_tasks[index])

    def _run_test_rollouts(self, index, task):
        """Run ``test_rollouts`` test episodes on ``task``, record the stats under ``index`` and print them."""
        returns, fails = [], []
        for _ in range(self.test_rollouts):
            R, fail = self.test_agent(task)
            returns.append(R)
            fails.append(fail)
        R_mean = np.mean(returns)
        failed_mean = np.mean(fails)
        self.test_reward_hist_per_task[index].append(R_mean)
        self.test_reward_var_hist_per_task[index].append(np.var(returns))
        self.test_fails_hist_per_task[index].append(failed_mean)
        print('TR \t {} \t TF \t {}'.format('{:.4f}'.format(R_mean), '{:.2f}'.format(failed_mean)))

    def get_test_action(self, s_enc):
        """Epsilon-greedy action with rate ``test_epsilon`` w.r.t. the online network's expected returns."""
        if random.random() <= self.test_epsilon:
            return random.randrange(self.n_actions)
        q = self.get_Q_values(s_enc, s_enc)
        return np.argmax(q)

    def test_agent(self, task):
        """Run one test episode of at most ``self.T`` steps on ``task``.

        Returns:
            ``(R, failed)``: the undiscounted sum of rewards, and the number of
            steps whose ``noise[0]`` from ``task.transition`` was truthy.
        """
        R = 0.0
        s = task.initialize()
        s_enc = self.encoding(s)
        failed = 0
        for _ in range(self.T):
            a = self.get_test_action(s_enc)
            s1, r, done, noise = task.transition(a)
            s1_enc = self.encoding(s1)
            s, s_enc = s1, s1_enc
            R += r
            if noise[0]:
                failed += 1
            if done:
                break
        return R, failed
    
