import torch
import numpy as np
from Algorithms.learner import Learner


class DQN(Learner):
    """Q-learning agent with a main value head and optional auxiliary heads.

    The network output is treated as consecutive slices of ``num_actions``
    values: the main task reads columns ``[0, num_actions)`` and auxiliary
    task ``k`` reads the slice starting at
    ``(k + main_task_ind) * num_actions``.

    Attributes such as ``network``, ``target_network``, ``replay_buffer``,
    ``optimizer``, ``criterion``, ``gvf``, ``epsilon`` and ``step`` are not
    defined in this class; presumably ``Learner.__init__`` creates them
    (TODO confirm against the base class).
    """

    def __init__(self, config):
        """Record the observation size and run the base ``Learner`` setup.

        Args:
            config: Mapping that must contain ``'obs_size'``. It is also
                forwarded unchanged to ``Learner.__init__``.
        """
        # Assigned before super().__init__ -- presumably the base constructor
        # reads input_size while building the networks (TODO confirm).
        self.obs_size = config['obs_size']
        self.input_size = self.obs_size
        super().__init__(config)

    def learn(self, obs_t, a_t, r, terminal, obs_tp1):
        """Store one transition and run ``num_replay`` replay updates.

        Args:
            obs_t: Observation the action was taken from.
            a_t: Index of the action taken.
            r: Reward received.
            terminal: Terminal flag; entries equal to 0 are bootstrapped.
            obs_tp1: Observation that followed the action.

        Returns:
            None. Nothing beyond storing the transition (and the periodic
            target-network sync) happens until
            ``replay_buffer.replay_start()`` is true; ``step`` is only
            advanced after that point.
        """
        # Periodically copy the online weights into the target network.
        if self.step % self.target_net_update_frequency == 0:
            self.target_network.load_state_dict(self.network.state_dict())

        self.replay_buffer.push(
            {'obs_t': obs_t, 'a_t': a_t, 'r': r, 'terminal': terminal, 'obs_tp1': obs_tp1})
        if not self.replay_buffer.replay_start():
            return

        if self.epsilon > self.end_epsilon:
            self.epsilon -= self.epsilon_annealing

        # The loop variable is deliberately not named `r`, which would
        # shadow the reward parameter.
        for _ in range(self.num_replay):
            batch_list = self.replay_buffer.sample()
            # Turn the list of transition dicts into one array per field.
            batch = {k: np.asarray([dic[k] for dic in batch_list]) for k in batch_list[0]}

            network_output = self.network(batch['obs_t'])
            # Bootstrap values are targets only, so no autograd graph is
            # needed for the target network's forward pass.
            with torch.no_grad():
                network_output_tp1 = self.target_network(batch['obs_tp1'])

            non_terminal_indices = np.where(batch['terminal'] == 0)[0]

            # Stays the int 0 until at least one loss term has been added.
            loss = 0
            if self.main_task_on:
                batch_a_t_tensor = torch.tensor(batch['a_t'], dtype=torch.int64).unsqueeze(1)
                batch_q_sa = network_output.gather(1, batch_a_t_tensor)

                # Target is r, plus gamma * max_a Q_target(s', a) for
                # non-terminal transitions.
                target = torch.tensor(batch['r'], dtype=torch.float32)
                max_q_tp1, _ = torch.max(network_output_tp1[:, :self.num_actions], 1)
                target[non_terminal_indices] += self.gamma * max_q_tp1[non_terminal_indices]

                loss = self.criterion(batch_q_sa, target.unsqueeze(1).detach())

            for aux_ind in range(self.num_aux_tasks):
                # Column range of this auxiliary head in the network output.
                head_start = (aux_ind + self.main_task_ind) * self.num_actions
                head_end = head_start + self.num_actions

                aux_actions = torch.tensor(
                    head_start + batch['a_t'], dtype=torch.int64).unsqueeze(1)
                aux_q_sa = network_output.gather(1, aux_actions)

                cumulants = self.gvf.cumulant(aux_ind, batch['obs_tp1'],
                                              batch_obs_t=batch['obs_t'],
                                              reward=batch['r'])
                aux_target = torch.tensor(cumulants, dtype=torch.float32)
                aux_max_q_tp1, _ = torch.max(network_output_tp1[:, head_start:head_end], 1)
                # Per-sample continuation value supplied by the GVF.
                aux_gamma = torch.tensor(
                    self.gvf.continuation_function(aux_ind, batch['obs_tp1']),
                    dtype=torch.float32)
                aux_target[non_terminal_indices] += (
                    aux_gamma[non_terminal_indices] * aux_max_q_tp1[non_terminal_indices])

                aux_loss = self.criterion(aux_q_sa, aux_target.unsqueeze(1).detach())
                # Out-of-place add so the main-task loss tensor is not mutated.
                loss = loss + self.aux_weight_loss * aux_loss

            if not torch.is_tensor(loss):
                # Main task disabled and no auxiliary tasks: nothing to
                # optimise, and an int has no backward().
                continue

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        self.step += 1

    def select_action(self, obs_t, follow_main=True, aux=0):
        """Pick an epsilon-greedy action from one head of the network.

        Args:
            obs_t: Observation passed to the network.
            follow_main: When true and the main task is enabled, act on the
                main head (columns ``[0, num_actions)``).
            aux: Head index used otherwise; the slice starts at
                ``aux * num_actions``.

        Returns:
            The greedy action index with probability ``1 - epsilon``,
            otherwise a uniformly random action index.
        """
        # NOTE(review): learn() places auxiliary head k at
        # (k + main_task_ind) * num_actions, whereas this uses
        # aux * num_actions directly -- confirm callers pass the absolute
        # head index.
        if self.main_task_on and follow_main:
            index = 0
        else:
            index = aux * self.num_actions

        with torch.no_grad():
            q_vec = self.network(obs_t).cpu().detach().numpy().flatten()
        q_vec = q_vec[index:index + self.num_actions]

        if np.random.random() > self.epsilon:
            return np.argmax(q_vec)
        return np.random.randint(self.num_actions)

    def reset_history(self, obs_t):
        """Do nothing; this learner keeps no observation history."""
        return

    def update_history(self, obs_t, a_t):
        """Do nothing; this learner keeps no observation history."""
        return

    def predict(self, obs_t):
        """Return the full network output for ``obs_t`` as a NumPy array."""
        with torch.no_grad():
            prediction = self.network(obs_t)
        return prediction.cpu().detach().numpy()