import torch
import numpy as np
import copy
from Algorithms.learner import Learner


class DQN_gradient_alignment(Learner):
    """DQN-style learner that scores auxiliary heads against the main task.

    For every replay update the gradient of the main-task loss and of each
    auxiliary-task loss with respect to the shared weights is folded into an
    exponential moving average ("trace"). A per-auxiliary-task score is then
    computed from the main trace and that task's trace and handed to
    ``self.gvf.gen_and_test``, which may nominate an auxiliary head whose
    weights are reset.

    Attributes such as ``network``, ``target_network``, ``optimizer``,
    ``criterion``, ``replay_buffer``, ``gvf``, ``epsilon`` and ``step`` are
    not created here; presumably ``Learner.__init__`` sets them up -- verify
    against the base class.
    """

    def __init__(self, config):
        """Build the learner from ``config`` and allocate the gradient traces.

        Args:
            config: Mapping that must contain ``'obs_size'``; it is also
                forwarded unchanged to ``Learner.__init__``.
        """
        # Set before super().__init__ -- presumably the base constructor
        # reads these when building the network; confirm in Learner.
        self.obs_size = config['obs_size']
        self.input_size = self.obs_size
        super().__init__(config)

        # Row 0 is the main-task trace, row i + 1 belongs to auxiliary task i.
        self.gradient_traces = torch.zeros(
            (self.num_aux_tasks + 1, self.network.get_shared_weight_size()))
        # Step size of the exponential moving average over gradients.
        self.trace_param = 0.01
        self.aux_score = np.zeros(self.num_aux_tasks)

    def learn(self, obs_t, a_t, r, terminal, obs_tp1):
        """Store one transition and run ``self.num_replay`` replay updates.

        The transition is pushed to the replay buffer under the keys
        ``'obs_t'``, ``'a_t'``, ``'r'``, ``'terminal'`` and ``'obs_tp1'``.
        Nothing else happens until ``self.replay_buffer.replay_start()`` is
        truthy. Note that ``self.step`` is only incremented once replay has
        started, because the early return skips the increment.

        Args:
            obs_t: Observation at time t.
            a_t: Action taken at time t.
            r: Reward received for the transition.
            terminal: Terminal flag; entries equal to 0 are bootstrapped.
            obs_tp1: Observation at time t + 1.

        Returns:
            None.
        """
        if self.step % self.target_net_update_frequency == 0:
            self.target_network.load_state_dict(self.network.state_dict())

        self.replay_buffer.push(
            {'obs_t': obs_t, 'a_t': a_t, 'r': r, 'terminal': terminal, 'obs_tp1': obs_tp1})
        if not self.replay_buffer.replay_start():
            return

        # Anneal exploration by a fixed decrement while epsilon is above 0.1.
        if self.epsilon > 0.1:
            self.epsilon -= self.epsilon_annealing

        cos = torch.nn.CosineSimilarity(dim=0)

        # The loop variable is deliberately not named ``r`` so the reward
        # argument is not shadowed.
        for _ in range(self.num_replay):
            batch_list = self.replay_buffer.sample()
            batch = {k: np.asarray([dic[k] for dic in batch_list]) for k in batch_list[0]}

            network_output = self.network(batch['obs_t'])
            # Bootstrap targets are never differentiated, so skip building an
            # autograd graph through the target network.
            with torch.no_grad():
                network_output_tp1 = self.target_network(batch['obs_tp1'])

            non_terminal_indices = np.where(batch['terminal'] == 0)[0]

            # None (rather than 0) marks "main task disabled" so the backward
            # pass below can be skipped instead of failing on an int.
            main_task_loss = None
            if self.main_task_on:
                batch_a_t_tensor = torch.tensor(batch['a_t'], dtype=torch.int64).unsqueeze(1)
                batch_q_sa = network_output.gather(1, batch_a_t_tensor)

                target = torch.tensor(batch['r'], dtype=torch.float32)
                max_q_tp1, _ = torch.max(network_output_tp1[:, :self.num_actions], 1)
                target[non_terminal_indices] += self.gamma * max_q_tp1[non_terminal_indices]

                main_task_loss = self.criterion(batch_q_sa, target.unsqueeze(1).detach())

            aux_losses = []
            for aux_ind in range(self.num_aux_tasks):
                # First output column of this auxiliary head.
                head_start = (aux_ind + self.main_task_ind) * self.num_actions

                batch_a_t_aux_tensor = torch.tensor(
                    head_start + batch['a_t'], dtype=torch.int64).unsqueeze(1)
                aux_q_sa = network_output.gather(1, batch_a_t_aux_tensor)

                cumulants = self.gvf.cumulant(aux_ind, batch['obs_tp1'],
                                              batch_obs_t=batch['obs_t'],
                                              reward=batch['r'],
                                              )
                aux_target = torch.tensor(cumulants, dtype=torch.float32)
                aux_max_q_tp1, _ = torch.max(
                    network_output_tp1[:, head_start:head_start + self.num_actions], 1)
                gamma = self.gvf.continuation_function(aux_ind, batch['obs_tp1'])
                aux_target[non_terminal_indices] += (
                    torch.tensor(gamma, dtype=torch.float32)[non_terminal_indices]
                    * aux_max_q_tp1[non_terminal_indices])
                aux_losses.append(self.criterion(aux_q_sa, aux_target.unsqueeze(1).detach()))

            self.optimizer.zero_grad()
            if main_task_loss is not None:
                main_task_loss.backward(retain_graph=True)
                main_gradient = self.network.get_shared_gradient()
            else:
                # No main loss to differentiate: treat its shared gradient as
                # zero so auxiliary gradients can still be separated below.
                # With an all-zero main trace the score below is 0 / 0 = NaN.
                main_gradient = torch.zeros_like(self.gradient_traces[0])
            self.gradient_traces[0, :] = ((1 - self.trace_param) * self.gradient_traces[0, :]
                                          + self.trace_param * main_gradient)

            # Gradients accumulate across backward() calls, so each task's own
            # gradient is the accumulated total minus everything seen so far.
            # The copy keeps the in-place "+=" from touching main_gradient.
            sum_gradients = copy.deepcopy(main_gradient)
            last_aux = self.num_aux_tasks - 1
            for i in range(self.num_aux_tasks):
                # The graph is shared by all losses; free it on the last one.
                if i < last_aux:
                    aux_losses[i].backward(retain_graph=True)
                else:
                    aux_losses[i].backward()
                aux_gradient = self.network.get_shared_gradient() - sum_gradients
                sum_gradients += aux_gradient
                self.gradient_traces[i + 1, :] = ((1 - self.trace_param) * self.gradient_traces[i + 1, :]
                                                  + self.trace_param * aux_gradient)

            # NOTE(review): dot / cos equals the product of the two trace
            # norms whenever cos != 0 and is NaN/inf when cos == 0. Confirm
            # this is the intended score rather than the plain cosine
            # similarity cos(main_trace, aux_trace).
            main_trace = self.gradient_traces[0]
            cosine_similarity = np.zeros(self.num_aux_tasks)
            for i in range(self.num_aux_tasks):
                aux_trace = self.gradient_traces[i + 1]
                cosine_similarity[i] = torch.dot(main_trace, aux_trace) / cos(main_trace, aux_trace)

            self.optimizer.step()
            self.aux_score = cosine_similarity
            i_replace = self.gvf.gen_and_test(cosine_similarity, direct=False)
            if i_replace is not None:
                self.network.reset_aux_weights(i_replace)

        self.step += 1

    def get_features_binary(self, obs_t):
        """Return ``self.network.get_features_binary(obs_t)``."""
        return self.network.get_features_binary(obs_t)

    def get_features(self, obs_t):
        """Return ``self.network.get_features(obs_t)``."""
        return self.network.get_features(obs_t)

    def select_action(self, obs_t, follow_main=True, aux=0):
        """Pick an epsilon-greedy action from one head of the network output.

        The main head (offset 0) is used only when the main task is enabled
        and ``follow_main`` is true; otherwise the block of ``num_actions``
        outputs starting at ``aux * num_actions`` is used.

        NOTE(review): ``learn`` places auxiliary head k at offset
        ``(k + main_task_ind) * num_actions``, whereas this method uses
        ``aux * num_actions``; confirm callers pass ``aux`` accordingly.

        Args:
            obs_t: Observation fed to the network.
            follow_main: Whether to act on the main head when it is enabled.
            aux: Head index used when the main head is not followed.

        Returns:
            The chosen action index within ``[0, num_actions)``.
        """
        use_main_head = self.main_task_on and follow_main
        index = 0 if use_main_head else aux * self.num_actions

        q_vec = self.network(obs_t).cpu().detach().numpy().flatten()
        q_vec = q_vec[index:index + self.num_actions]
        if np.random.random() > self.epsilon:
            return np.argmax(q_vec)
        return np.random.randint(self.num_actions)

    def reset_history(self, obs_t):
        """Do nothing; this learner keeps no observation history."""
        return

    def update_history(self, obs_t, a_t):
        """Do nothing; this learner keeps no observation history."""
        return

    def predict(self, obs_t):
        """Return the network output for ``obs_t`` as a NumPy array."""
        prediction = self.network(obs_t)
        return prediction.cpu().detach().numpy()