import torch
import numpy as np
from Algorithms.learner import Learner


class DQN_master_slave(Learner):
    def __init__(self, config):
        """Set up the learner and the per-feature bookkeeping arrays.

        Args:
            config: Mapping that must contain ``'obs_size'`` and
                ``'feature_trace'``; it is also handed unchanged to the base
                class initialiser.
        """
        # Everything in this first group is assigned before the base
        # initialiser runs (presumably it reads some of these -- confirm
        # against Learner before moving them).
        self.obs_size = config['obs_size']
        self.input_size = self.obs_size
        self.feature_trace = config['feature_trace']
        self.eval_type = 'activation_weight'
        self.preserved_features_ind = None
        self.preserve_features = False

        super().__init__(config)

        # self.network and self.num_aux_tasks are read only after the base
        # initialiser has returned.
        hidden_units = self.network.hidden_size
        head_count = self.num_aux_tasks + 1  # main head plus one per auxiliary task
        self.feature_per_aux = int(hidden_units / head_count)
        self.feature_scores = np.zeros(hidden_units)
        self.feature_variances = np.ones(hidden_units)
        self.aux_score = np.zeros(self.num_aux_tasks)

    def learn(self, obs_t, a_t, r, terminal, obs_tp1):

        if self.step % self.target_net_update_frequency == 0:
            self.target_network.load_state_dict(self.network.state_dict())

        td_error_squared = self.compute_td_error_squared(obs_t, a_t, r, terminal, obs_tp1)
        x_tp1_detached = self.network.get_features(obs_tp1).detach()

        self.replay_buffer.push(
            {'obs_t': obs_t, 'a_t': a_t, 'r': r, 'terminal': terminal, 'obs_tp1': obs_tp1,
             'td_error_squared': td_error_squared})  # , 'feature_tp1': x_tp1_detached})

        if not self.replay_buffer.replay_start():
            return

        if self.epsilon > self.end_epsilon:
            self.epsilon -= self.epsilon_annealing

        if self.gvf.generator == 'visited_goals' and self.step == 0:
            indices = np.arange(self.num_aux_tasks)
            visited_goals, num_goals = self.replay_buffer.sample_states(self.num_aux_tasks,
                                                                        np.zeros((self.num_aux_tasks, self.obs_size)))
            self.gvf.reset_visited_goals(indices, visited_goals, num_goals)
        if self.gvf.generator == 'feature_attainment' and self.step == 0:
            indices = np.arange(self.num_aux_tasks)
            target_features, num_target_features = self.replay_buffer.sample_features(self.num_aux_tasks,
                                                                                      np.zeros((self.num_aux_tasks,
                                                                                                self.gvf.feature_size)))
            self.gvf.reset_target_features(indices, target_features, num_target_features)

        for r in range(self.num_replay):

            batch_list = self.replay_buffer.sample()
            batch = {k: np.asarray([dic[k] for dic in batch_list]) for k in batch_list[0]}

            x_t = self.network.get_features(batch['obs_t'])
            x_tp1 = self.network.get_features(batch['obs_tp1'])

            x_t_detached = torch.cat((x_t[:, :self.feature_per_aux], x_t[:, self.feature_per_aux:].detach()), axis=1)
            x_tp1_detached = torch.cat((x_tp1[:, :self.feature_per_aux], x_tp1[:, self.feature_per_aux:].detach()),
                                       axis=1)

            network_output_main = self.network.forward(x_t_detached, 0)
            network_output_main_tp1 = self.target_network.forward(x_tp1_detached, 0)

            main_task_loss = 0
            non_terminal_indices = np.where(batch['terminal'] == 0)[0]
            if self.main_task_on:
                batch_a_t_tensor = torch.tensor(batch['a_t'], dtype=torch.int64).unsqueeze(1)
                batch_q_sa = network_output_main.gather(1, batch_a_t_tensor)
                target = torch.tensor(batch['r'], dtype=torch.float32)
                max_q_tp1, _ = torch.max(network_output_main_tp1, 1)
                target[non_terminal_indices] += self.gamma * max_q_tp1[non_terminal_indices]
                main_task_loss = self.criterion(batch_q_sa, target.unsqueeze(1).detach())

            aux_losses = []
            batch_aux_q_sa = []
            x_t_detached_list = []
            x_tp1_detached_list = []
            if self.num_aux_tasks > 0:

                for aux_ind in np.arange(self.num_aux_tasks):
                    x_t_detached_list.append(torch.cat((x_t[:, :(aux_ind + 1) * self.feature_per_aux].detach(),
                                                        x_t[:, (aux_ind + 1) * self.feature_per_aux:(
                                                                                                                aux_ind + 2) * self.feature_per_aux:],
                                                        x_t[:, (aux_ind + 2) * self.feature_per_aux:].detach()),
                                                       axis=1))
                    x_tp1_detached_list.append(torch.cat((x_tp1[:, :(aux_ind + 1) * self.feature_per_aux].detach(),
                                                          x_tp1[:, (aux_ind + 1) * self.feature_per_aux:(
                                                                                                                    aux_ind + 2) * self.feature_per_aux:],
                                                          x_tp1[:, (aux_ind + 2) * self.feature_per_aux:].detach()),
                                                         axis=1))

                    network_output_aux = self.network.forward(x_t_detached_list[aux_ind], aux_ind + 1)
                    network_output_tp1_aux = self.target_network.forward(x_tp1_detached_list[aux_ind], aux_ind + 1)

                    batch_a_t_aux_tensor = torch.tensor(batch['a_t'], dtype=torch.int64).unsqueeze(1)
                    batch_aux_q_sa.append(network_output_aux.gather(1, batch_a_t_aux_tensor))

                    cumulants = self.gvf.cumulant(aux_ind, batch['obs_tp1'],
                                                  batch_obs_t=batch['obs_t'],
                                                  reward=batch['r'],
                                                  )

                    aux_target = torch.tensor(cumulants, dtype=torch.float32)
                    aux_max_q_tp1, _ = torch.max(network_output_tp1_aux, 1)
                    gamma = self.gvf.continuation_function(aux_ind, batch['obs_tp1'])
                    aux_target[non_terminal_indices] += torch.tensor(gamma, dtype=torch.float32)[non_terminal_indices] * \
                                                        aux_max_q_tp1[non_terminal_indices]
                    aux_losses.append(self.criterion(batch_aux_q_sa[aux_ind], aux_target.unsqueeze(1).detach()))

            loss = main_task_loss
            for i in np.arange(self.num_aux_tasks):
                loss += self.aux_weight_loss * aux_losses[i]
            self.optimizer.zero_grad()

            loss.backward()
            # self.network.zero_grad_preserved_features(self.preserved_features_ind)
            self.optimizer.step()

            if self.generate_and_test:

                features_numpy = x_t_detached.cpu().detach().numpy()
                eval_metric = self.evaluate_aux(features_numpy)
                self.aux_score = eval_metric

                i_replace = self.gvf.gen_and_test(eval_metric, direct=False)
                if i_replace is not None:
                    self.reset_feature_scores(i_replace)
                    if self.gvf.generator == 'visited_goals':
                        visited_goals, num_goals = self.replay_buffer.sample_states(self.gvf.replace_number,
                                                                                    self.gvf.visited_goals)
                        self.gvf.reset_visited_goals(i_replace, visited_goals, num_goals)

        self.step += 1

    def evaluate_features(self, features):
        if self.eval_type == 'variance':
            self.feature_variances = np.std(features, axis=0)
            new_score = np.sum(np.abs(self.network.hidden_output[0].weight.cpu().detach().numpy()), axis=0)
        else:
            sum_weight_mag = np.sum(self.network.hidden_output[0].weight.cpu().detach().numpy(), axis=0)
            sum_weight_mag = np.tile(sum_weight_mag, (features.shape[0], 1))
            new_score = np.mean(np.abs(sum_weight_mag * features), axis=0)

        self.feature_scores = self.feature_scores * (1 - self.feature_trace) + self.feature_trace * new_score

        self.preserved_features_ind = np.argpartition(self.feature_scores, self.feature_scores.shape[0] // 5)[
                                      :self.feature_scores.shape[0] // 5]

    def reset_feature_scores(self, i_replace):
        preserved_scores = np.copy(self.feature_scores)
        i_replace_feature_ind = np.arange(0, self.feature_per_aux)
        for i in np.arange(i_replace.shape[0]):
            i_replace_feature_ind = np.concatenate((i_replace_feature_ind,
                                                    np.arange((i_replace[i] + 1) * self.feature_per_aux,
                                                              (i_replace[i] + 2) * self.feature_per_aux)))
        preserved_scores = np.delete(preserved_scores, i_replace_feature_ind)
        median_scores = np.percentile(preserved_scores, 50)
        for i in i_replace:
            self.feature_scores[(i + 1) * self.feature_per_aux:(i + 2) * self.feature_per_aux] = median_scores

    def evaluate_aux(self, features):
        self.evaluate_features(features)
        aux_scores = np.zeros(self.num_aux_tasks)
        if self.eval_type == 'variance':
            for i in np.arange(self.num_aux_tasks):
                aux_scores[i] = np.sum(
                    self.feature_scores[(i + 1) * self.feature_per_aux:(i + 2) * self.feature_per_aux] * \
                    self.feature_variances[(i + 1) * self.feature_per_aux:(i + 2) * self.feature_per_aux])
        else:
            for i in np.arange(self.num_aux_tasks):
                aux_scores[i] = np.sum(
                    self.feature_scores[(i + 1) * self.feature_per_aux:(i + 2) * self.feature_per_aux])

        return aux_scores

    def select_action(self, obs_t, follow_main=True, aux=0):

        if follow_main:
            x_t_detached = self.network.get_features(obs_t).detach()
            q_vec = self.network.forward(x_t_detached, 0).cpu().detach().numpy().flatten()
        else:
            obs_t = np.array(obs_t)
            obs_t = np.expand_dims(obs_t, 0)
            q_vec = self.network.forward(obs_t)[aux].cpu().detach().numpy().flatten()

        if np.random.random() > self.epsilon:
            a = np.argmax(q_vec)
        else:
            a = np.random.randint(self.num_actions)
        return a

    def compute_td_error_squared(self, obs_t, a_t, r, terminal, obs_tp1):
        # x_t_detached = self.network.get_features(obs_t).detach()
        # q_vec = self.network.forward(x_t_detached, 0).cpu().detach().numpy().flatten()
        #
        # # return q_vec[a_t]

        # x_tp1_detached = self.network.get_features(obs_tp1).detach()
        # q_tp1_vec = self.network.forward(x_tp1_detached, 0).cpu().detach().numpy().flatten()
        # td_error = r - q_vec[a_t]
        # if terminal == False:
        #     td_error += np.max(q_tp1_vec)
        # return td_error ** 2

        x_tp1_detached = self.network.get_features(obs_tp1).detach()
        q_tp1_vec = self.network.forward(x_tp1_detached, 0).cpu().detach().numpy().flatten()
        td_error = r
        if terminal == False:
            td_error += np.max(q_tp1_vec)
        return td_error

    def reset_history(self, obs_t):
        return

    def update_history(self, obs_t, a_t):
        return

    def predict(self, obs_t):
        prediction = self.network(obs_t)
        return prediction.cpu().detach().numpy()
