import random
import numpy as np

from agents.agent import Agent


class SFDQN(Agent):
    """Successor-feature DQN agent.

    Learning is delegated to ``deep_sf``: per-task reward weights are updated
    with ``update_reward`` and successor features with ``update_successor``.
    Actions are selected through ``deep_sf.GPI`` during training and
    ``deep_sf.GPI_w`` during testing. Transitions are stored in ``buffer``, and
    the batch returned by ``buffer.replay()`` is used to update the successor
    features of every registered training task.
    """

    def __init__(self, deep_sf, buffer, *args, use_gpi=True, test_epsilon=0.03,
                 test_frequency=10000, plot_cov_frequency=20000, test_rollouts=1, **kwargs):
        """Create the agent.

        Args:
            deep_sf: successor-feature model. This class uses its ``key``,
                ``GPI``, ``GPI_w``, ``update_reward``, ``update_successor``,
                ``reset``, ``add_training_task``, ``GPI_usage_percent``,
                ``fit_w`` and ``true_w`` members.
            buffer: replay buffer providing ``append``, ``replay`` and ``reset``.
            *args, **kwargs: forwarded unchanged to ``Agent.__init__``.
            use_gpi: if False, training actions use the active task's own
                index instead of the index selected by ``deep_sf.GPI``.
            test_epsilon: probability of a uniformly random action at test time.
            test_frequency: evaluate on all tasks every this many training
                samples (the counter restarts for each training task).
            plot_cov_frequency: call ``plot_var`` every this many global steps.
            test_rollouts: number of test episodes averaged per task per
                evaluation.
        """
        super(SFDQN, self).__init__(*args, **kwargs)
        self.sf = deep_sf
        self.buffer = buffer
        self.use_gpi = use_gpi
        self.test_epsilon = test_epsilon
        self.test_frequency = test_frequency
        self.plot_cov_frequency = plot_cov_frequency
        self.test_rollouts = test_rollouts
        self.key = deep_sf.key

    def get_Q_values(self, s, s_enc):
        """Return Q-values for the encoded state ``s_enc`` under the active task.

        ``self.sf.GPI`` returns Q-values plus an index ``c``. The second axis of
        the Q-values is selected with ``c``. When ``use_gpi`` is False, ``c`` is
        replaced by the active task index. The index used is stored in
        ``self.c``. ``s`` is unused here; it is presumably kept to match the
        base ``Agent`` interface.
        """
        q, c = self.sf.GPI(s_enc, self.task_index, update_counters=self.use_gpi)
        if not self.use_gpi:
            c = self.task_index
        self.c = c
        return q[:, c, :]

    def train_agent(self, s, s_enc, a, r, s1, s1_enc, gamma, noise):
        """Update the model from a single transition.

        Steps:
            1. Compute features ``phi`` with ``self.phi``.
            2. Update the active task's reward weights with ``(phi, r)``.
            3. Append the transition to the replay buffer.
            4. Update the successor features of every training task on the
               transitions returned by ``buffer.replay()``.
        """
        # update the active task's reward weights
        phi = self.phi(s, a, s1, noise)
        self.sf.update_reward(phi, r, self.task_index)

        # remember this experience; noise[1] is stored as the buffer's ``other`` field
        self.buffer.append(s_enc, a, phi, s1_enc, gamma, other=noise[1])

        # update the successor features of all tasks from the replayed transitions
        transitions = self.buffer.replay()
        for index in range(self.n_tasks):
            self.sf.update_successor(transitions, index)

    def reset(self):
        """Reset the base agent, the SF model, the buffer and per-task episode histories."""
        super(SFDQN, self).reset()
        self.sf.reset()
        self.buffer.reset()
        self.episode_reward_hist_per_task = []
        self.episode_fails_hist_per_task = []

    def add_training_task(self, task):
        """Register ``task`` with the base agent and the SF model, and allocate its histories."""
        super(SFDQN, self).add_training_task(task)
        self.sf.add_training_task(task, source=None)
        self.episode_reward_hist_per_task.append([])
        self.episode_fails_hist_per_task.append([])

    def get_progress_strings(self):
        """Return the base progress strings plus a GPI/weight-error summary.

        The extra string reports two values for the active task:
            - ``self.sf.GPI_usage_percent``.
            - The L2 norm of ``fit_w - true_w``.
        """
        sample_str, reward_str = super(SFDQN, self).get_progress_strings()
        gpi_percent = self.sf.GPI_usage_percent(self.task_index)
        w_error = np.linalg.norm(self.sf.fit_w[self.task_index] - self.sf.true_w[self.task_index])
        gpi_str = 'GPI% \t {:.4f} \t WERR \t {:.4f}'.format(gpi_percent, w_error)
        return sample_str, reward_str, gpi_str

    def next_sample(self, viewer=None, n_view_ev=None):
        """Take one training step, then record episode stats every ``save_ev`` steps.

        The stats are ``episode_reward`` and ``episode_fails``, appended to the
        active task's history lists.
        """
        super(SFDQN, self).next_sample(viewer, n_view_ev)
        if self.steps % self.save_ev == 0:
            self.episode_reward_hist_per_task[self.task_index].append(self.episode_reward)
            self.episode_fails_hist_per_task[self.task_index].append(self.episode_fails)

    def train(self, train_tasks, n_samples, viewers=None, n_view_ev=None, test_tasks=None, plot_var=None):
        """Train sequentially on each task in ``train_tasks``.

        The agent is reset and every training task is registered first. Each
        task then becomes active in turn for ``n_samples`` training samples.

        Args:
            train_tasks: list of training tasks (each must support ``clone``).
            n_samples: number of training samples per training task.
            viewers: optional per-task viewers passed to ``next_sample``.
                Defaults to ``None`` for every task. Note that ``zip`` stops at
                the shorter of ``train_tasks`` and ``viewers``.
            n_view_ev: forwarded to ``next_sample``.
            test_tasks: optional extra tasks evaluated alongside clones of the
                training tasks. ``None`` (the default) means no extra tasks.
            plot_var: optional callable, invoked as
                ``plot_var(self.sf, self.buffer, test_ws, index, self.global_steps)``
                every ``plot_cov_frequency`` global steps.

        Evaluation happens every ``test_frequency`` samples. Results are
        appended to ``test_reward_hist_per_task``,
        ``test_reward_var_hist_per_task`` and ``test_fails_hist_per_task``,
        with one list per evaluated task.
        """
        # ``None`` sentinel instead of a shared mutable ``[]`` default; copying
        # also lets any iterable be passed and iterated more than once below.
        test_tasks = [] if test_tasks is None else list(test_tasks)
        if viewers is None:
            viewers = [None] * len(train_tasks)

        # add tasks to a freshly reset agent
        self.reset()
        for train_task in train_tasks:
            self.add_training_task(train_task)

        # saving test performance: clones of the training tasks, then the test tasks
        train_and_test_tasks = [task.clone() for task in train_tasks] + test_tasks
        self.test_reward_hist_per_task = [[] for _ in train_and_test_tasks]
        self.test_reward_var_hist_per_task = [[] for _ in train_and_test_tasks]
        self.test_fails_hist_per_task = [[] for _ in train_and_test_tasks]

        # train each one
        for index, (_, viewer) in enumerate(zip(train_tasks, viewers)):
            self.set_active_training_task(index)
            for t in range(n_samples):

                # train
                self.next_sample(viewer, n_view_ev)

                # test
                if t % self.test_frequency == 0:
                    self._evaluate_tasks(train_and_test_tasks)

                # plot covariance
                if plot_var is not None and \
                        self.global_steps > 0 and self.global_steps % self.plot_cov_frequency == 0:
                    test_ws = [test_task.get_w() for test_task in test_tasks]
                    plot_var(self.sf, self.buffer, test_ws, index, self.global_steps)

    def _evaluate_tasks(self, tasks):
        """Run ``test_rollouts`` test episodes on every task, record stats and print them.

        For each task, the mean return, return variance and mean fail count are
        appended to the histories at the task's position in ``tasks``.
        """
        for i, task in enumerate(tasks):
            returns, fails = [], []
            for _ in range(self.test_rollouts):
                R, fail = self.test_agent(task)
                returns.append(R)
                fails.append(fail)
            self.test_reward_hist_per_task[i].append(np.mean(returns))
            self.test_reward_var_hist_per_task[i].append(np.var(returns))
            self.test_fails_hist_per_task[i].append(np.mean(fails))

        # print test results
        Rs = [arr[-1] for arr in self.test_reward_hist_per_task]
        failed = [arr[-1] for arr in self.test_fails_hist_per_task]
        print('TR \t {} \t TF \t {}'.format(
            '\t'.join(map('{:.4f}'.format, Rs)),
            '\t'.join(map('{:.2f}'.format, failed))))

    def get_test_action(self, s_enc, w):
        """Choose an epsilon-greedy test action for reward weights ``w``.

        With probability ``test_epsilon`` a uniformly random action in
        ``range(self.n_actions)`` is returned. Otherwise the action is the
        argmax of the Q-values from ``self.sf.GPI_w``, taken at the index ``c``
        it returns.
        """
        # random.random() is in [0, 1); a strict ``<`` makes test_epsilon == 0 purely greedy.
        if random.random() < self.test_epsilon:
            return random.randrange(self.n_actions)
        q, c = self.sf.GPI_w(s_enc, w)
        # NOTE(review): np.argmax flattens q[:, c, :]; this equals the action index
        # only if the leading axis has size 1 (one encoded state) — confirm.
        return np.argmax(q[:, c, :])

    def test_agent(self, task):
        """Run one test episode of at most ``self.T`` steps on ``task``.

        Returns:
            A tuple ``(R, failed)``:
                - ``R``: undiscounted sum of rewards over the episode.
                - ``failed``: number of steps whose ``noise[0]`` was truthy.
        """
        R = 0.0
        w = task.get_w()
        s_enc = self.encoding(task.initialize())
        failed = 0
        for _ in range(self.T):
            a = self.get_test_action(s_enc, w)
            s1, r, done, noise = task.transition(a)
            s_enc = self.encoding(s1)
            R += r
            if noise[0]:
                failed += 1
            if done:
                break
        return R, failed

    def test_agent_rollouts(self, tasks, n_rollouts=1):
        """Collect ``n_rollouts`` test episodes (at most ``self.T`` steps each) per task.

        For every step, ``noise[1][1]`` is recorded. Each episode's records are
        stacked with ``np.vstack``.

        Returns:
            A nested list indexed as ``rollouts[task_index][rollout_index]``.
        """
        rollouts = []
        for i, task in enumerate(tasks):
            task_rollouts = []
            print('rollouts for task {}'.format(i))
            w = task.get_w()
            for _ in range(n_rollouts):
                episode = []
                s_enc = self.encoding(task.initialize())
                for _ in range(self.T):
                    a = self.get_test_action(s_enc, w)
                    s1, r, done, noise = task.transition(a)
                    episode.append(np.array(noise[1][1]))
                    s_enc = self.encoding(s1)
                    if done:
                        break
                task_rollouts.append(np.vstack(episode))
            rollouts.append(task_rollouts)
        return rollouts
    
