#######################################################################
# Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com)    #
# Permission given to modify the code as long as you keep this        #
# declaration at the top                                              #
#######################################################################

from ..network import *
from ..component import *
from ..utils import *
import time
from .BaseAgent import *
from .DQN_agent import *

from copy import deepcopy

class CategoricalDQNActor(DQNActor):
    """DQN actor whose Q-values are expectations of a categorical distribution.

    The network's ``'prob'`` output is weighted by the support ``config.atoms``
    and summed over the last axis, giving one expected value per action.
    """

    def __init__(self, config):
        super().__init__(config)

    def _set_up(self):
        # Replace the support stored on the shared config with its tensor form
        # so compute_q can multiply it directly with the network output.
        cfg = self.config
        cfg.atoms = tensor(cfg.atoms)

    def compute_q(self, prediction):
        """Return the per-action expected values as a numpy array."""
        support = self.config.atoms
        weighted = prediction['prob'] * support
        expected = weighted.sum(-1)
        return to_np(expected)


class CategoricalDQNAgent(DQNAgent):
    """Categorical (C51-style) DQN agent with several loss variants.

    ``config.method`` selects the per-sample loss returned by ``compute_loss``:

    * ``'C51'``     -- KL divergence between the projected target distribution
      and the online network's distribution for the taken action.
    * ``'C51_CE'``  -- cross-entropy against the projected target distribution.
    * ``'C51_ent'`` -- cross-entropy against a reweighted target in which the
      support atom immediately below the target mean keeps only a
      ``config.varepsilon`` fraction of its probability and the remaining
      atoms are rescaled so the target still sums to one.
    """

    _METHODS = ('C51', 'C51_CE', 'C51_ent')

    def __init__(self, config):
        BaseAgent.__init__(self, config)
        self.config = config
        config.lock = mp.Lock()
        # Fixed support of the return distribution: n_atoms evenly spaced
        # values from v_min to v_max (inclusive).
        config.atoms = np.linspace(config.categorical_v_min, config.categorical_v_max, config.categorical_n_atoms)

        self.replay = config.replay_fn()
        self.actor = CategoricalDQNActor(config)

        self.network = config.network_fn()
        self.network.share_memory()
        self.target_network = config.network_fn()
        self.target_network.load_state_dict(self.network.state_dict())
        self.optimizer = config.optimizer_fn(self.network.parameters())

        self.actor.set_network(self.network)
        self.total_steps = 0
        self.batch_indices = range_tensor(config.batch_size)
        self.atoms = tensor(config.atoms)
        # Spacing between neighbouring support atoms; used by the projection.
        self.delta_atom = (config.categorical_v_max - config.categorical_v_min) / float(config.categorical_n_atoms - 1)

    def eval_step(self, state):
        """Pick greedy action(s) using the online network's expected Q-values.

        The state normalizer is put in read-only mode for the duration of the
        call. It is restored afterwards even if the forward pass raises.

        Returns:
            ``(action, feature)``: the argmax action(s) as a numpy array, and
            the second element returned by ``self.network``.
        """
        normalizer = self.config.state_normalizer
        normalizer.set_read_only()
        try:
            state = normalizer(state)
            prediction, feature = self.network(state)
            q = (prediction['prob'] * self.atoms).sum(-1)
            action = to_np(q.argmax(-1))
        finally:
            normalizer.unset_read_only()
        return action, feature

    def CEloss(self, p, logq):
        """Cross-entropy ``-sum(p * logq)`` over the last (atom) dimension.

        This omits the ``p * log p`` term of the KL divergence. That term
        does not depend on the network when ``p`` is a fixed target.
        """
        return (-p * logq).sum(-1)

    def compute_loss(self, transitions):
        """Per-sample loss for a batch of transitions, chosen by ``config.method``.

        Returns:
            A tensor with one loss value per transition; ``reduce_loss``
            averages it.

        Raises:
            ValueError: if ``config.method`` is not 'C51', 'C51_CE' or 'C51_ent'.
        """
        method = self.config.method
        if method not in self._METHODS:
            # Fail loudly: previously this printed 'error!' and returned the
            # int 0, which later crashed in reduce_loss with an AttributeError.
            raise ValueError('Unknown CategoricalDQN method: {!r}'.format(method))

        # Normalize states before next_states (same order as before), since
        # the normalizer may update running statistics.
        states = self.config.state_normalizer(transitions.state)
        next_states = self.config.state_normalizer(transitions.next_state)
        target_prob = self._projected_target(transitions, next_states)
        log_prob = self._taken_action_log_prob(states, transitions.action)

        if method == 'C51':
            # KL(target || online); the target self-entropy term has no gradient.
            return (target_prob * target_prob.add(1e-5).log() - target_prob * log_prob).sum(-1)
        if method == 'C51_CE':
            # Only log_prob is updated; target_prob comes from the target network.
            return self.CEloss(target_prob.detach(), log_prob)
        # 'C51_ent'
        return self.CEloss(self._reweighted_target(target_prob).detach(), log_prob)

    def _projected_target(self, transitions, next_states):
        """Project the n-step Bellman target distribution onto ``self.atoms``.

        Returns a ``[batch, n_atoms]`` tensor of target probabilities.
        """
        config = self.config
        with torch.no_grad():
            prob_next, _ = self.target_network(next_states)
            prob_next = prob_next['prob']
            if config.double_q:
                # Double DQN: the online network selects the next action and
                # the target network's distribution is used for it.
                # self.network returns a (prediction, feature) tuple.
                online_next, _ = self.network(next_states)
                a_next = torch.argmax((online_next['prob'] * self.atoms).sum(-1), dim=-1)
            else:
                a_next = torch.argmax((prob_next * self.atoms).sum(-1), dim=-1)
            prob_next = prob_next[self.batch_indices, a_next, :]

        rewards = tensor(transitions.reward).unsqueeze(-1)
        masks = tensor(transitions.mask).unsqueeze(-1)
        atoms_target = rewards + config.discount ** config.n_step * masks * self.atoms.view(1, -1)
        atoms_target.clamp_(config.categorical_v_min, config.categorical_v_max)
        atoms_target = atoms_target.unsqueeze(1)
        # Linear-interpolation weights: each shifted atom spreads its mass over
        # the support atoms within one delta_atom of it.
        weights = (1 - (atoms_target - self.atoms.view(1, -1, 1)).abs() / self.delta_atom).clamp(0, 1)
        return (weights * prob_next.unsqueeze(1)).sum(-1)

    def _taken_action_log_prob(self, states, actions):
        """Return the online network's log-probabilities over atoms for the taken actions."""
        prediction, _ = self.network(states)
        actions = tensor(actions).long()
        return prediction['log_prob'][self.batch_indices, actions, :]

    def _reweighted_target(self, target_prob):
        """Scale the flagged atom's mass by ``varepsilon``; rescale the rest.

        The flagged atom is the support atom with
        ``atom < E[Z] < atom + delta_atom``, where ``E[Z]`` is the mean of
        ``target_prob``. Its probability is multiplied by
        ``config.varepsilon``. The other atoms are divided by
        ``p_rest / (1 - p_flag * varepsilon)`` so that the result sums to one.
        """
        varepsilon = self.config.varepsilon
        atoms = self.atoms.view(1, -1)
        target_q = (target_prob * self.atoms).sum(-1, keepdim=True)
        flag = ((atoms < target_q) & (atoms + self.delta_atom > target_q)).to(target_prob.dtype)
        p_flag = (target_prob * flag).sum(-1, keepdim=True)
        p_rest = (target_prob * (1 - flag)).sum(-1, keepdim=True)
        # If all of a sample's mass is on the flagged atom, p_rest == 0 and the
        # unflagged entries are all 0. Clamping the numerator makes them
        # 0 / positive = 0 instead of 0 / 0 = NaN, which would otherwise poison
        # the batch loss and gradients. Values >= finfo.tiny are unchanged.
        true_epsilon = p_rest.clamp_min(torch.finfo(p_rest.dtype).tiny) / (1 - p_flag * varepsilon)
        return target_prob * (1 - flag) / true_epsilon + target_prob * flag * varepsilon

    def reduce_loss(self, loss):
        return loss.mean()
