#######################################################################
# Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com)    #
# Permission given to modify the code as long as you keep this        #
# declaration at the top                                              #
#######################################################################

from ..network import *
from ..component import *
from ..utils import *
from .BaseAgent import *
from .DQN_agent import *
from geomloss import SamplesLoss
from main import args

class MMDActor(DQNActor):
    """Actor whose action values are the mean of the network's predicted return samples."""

    def __init__(self, config):
        # No extra state beyond the DQN actor.
        super().__init__(config)

    def compute_q(self, prediction):
        """Collapse ``prediction['sample']`` over its last axis (the samples) into Q-values.

        Returns the averaged values converted with ``to_np``.
        """
        return_samples = prediction['sample']
        expected_returns = return_samples.mean(-1)
        return to_np(expected_returns)


class MMDAgent(DQNAgent):
    """DQN-style agent that learns sample-based return distributions with an MMD loss.

    The networks return a ``(prediction, feature)`` pair whose ``prediction['sample']``
    is indexed as ``[batch, action, sample]``. Training minimises the squared MMD
    between the samples of the taken action and the Bellman-target samples of the
    greedy next action, using a mixture of Gaussian kernels.
    """

    # Bandwidths of the Gaussian kernels mixed in the MMD loss (see GaussianKernal).
    MMD_BANDWIDTHS = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

    def __init__(self, config):
        # Calls BaseAgent.__init__ directly rather than DQNAgent.__init__;
        # all agent state is built below.
        BaseAgent.__init__(self, config)
        self.config = config
        config.lock = mp.Lock()

        self.replay = config.replay_fn()
        self.actor = MMDActor(config)

        self.network = config.network_fn()
        self.network.share_memory()
        self.target_network = config.network_fn()
        self.target_network.load_state_dict(self.network.state_dict())
        self.optimizer = config.optimizer_fn(self.network.parameters())

        # Not referenced elsewhere in this class (the kernel follows the device of its
        # inputs); kept because other code may read ``agent.device``.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.actor.set_network(self.network)

        self.total_steps = 0
        self.batch_indices = range_tensor(config.batch_size)

        self.epsilon = config.epsilon
        self.niter_sink = config.niter_sink

    def eval_step(self, state):
        """Pick the greedy action for ``state`` without updating normalizer statistics.

        Returns ``([action], feature)``, where ``feature`` is the second output of the network.
        The normalizer is always restored to writable mode, even if the forward pass raises.
        """
        normalizer = self.config.state_normalizer
        normalizer.set_read_only()
        try:
            state = normalizer(state)
            prediction, feature = self.network(state)
            # Q(s, a) is the mean of the predicted return samples.
            q = prediction['sample'].mean(-1)
            action = np.argmax(to_np(q).flatten())
        finally:
            normalizer.unset_read_only()
        return [action], feature

    def GaussianKernal(self, v1, v2, sigma):
        """Mixture-of-Gaussians kernel between every sample pair of ``v1`` and ``v2``.

        Args:
            v1: tensor ``[bs, N1]``.
            v2: tensor ``[bs, N2]``.
            sigma: a bandwidth or a sequence of bandwidths ``h_k``.

        Returns:
            Tensor ``[bs, N1, N2]`` whose entries are ``sum_k exp(-(v1_i - v2_j)^2 / h_k)``.
            The result is computed on the device and in the dtype of the inputs.
        """
        v1 = v1.unsqueeze(2)  # [bs, N1] -> [bs, N1, 1]
        v2 = v2.unsqueeze(1)  # [bs, N2] -> [bs, 1, N2]
        d = (v1 - v2) ** 2  # [bs, N1, N2]
        # Follow the inputs' device/dtype instead of a hard-coded CUDA device.
        inv_bandwidth = 1.0 / torch.as_tensor(sigma, dtype=d.dtype, device=d.device).view(-1, 1)
        temp = torch.matmul(inv_bandwidth, d.reshape(1, -1))  # [k, 1] x [1, bs*N1*N2]
        return torch.sum(torch.exp(-temp), dim=0).reshape(d.shape)

    def compute_loss(self, transitions):
        """Return the elementwise MMD^2 kernel terms between predicted and target samples.

        The result is a ``[bs, N, N]`` tensor. ``reduce_loss`` averages it, which gives the
        batch mean of the biased (V-statistic) squared-MMD estimate.
        """
        config = self.config
        states = config.state_normalizer(transitions.state)
        next_states = config.state_normalizer(transitions.next_state)

        # Bellman target: the greedy next action under the target network, then its samples.
        target_prediction, _ = self.target_network(next_states)
        samples_next = target_prediction['sample'].detach()  # [bs, A, N]
        a_next = torch.argmax(samples_next.sum(-1), dim=-1)  # same argmax as the mean
        samples_next = samples_next[self.batch_indices, a_next, :]  # Z(s', a*): [bs, N]

        rewards = tensor(transitions.reward).unsqueeze(-1)
        masks = tensor(transitions.mask).unsqueeze(-1)
        samples_next = rewards + config.discount ** config.n_step * masks * samples_next

        prediction, _ = self.network(states)
        actions = tensor(transitions.action).long()
        samples = prediction['sample'][self.batch_indices, actions, :]  # Z(s, a): [bs, N]

        x, y = samples, samples_next
        bandwidths = self.MMD_BANDWIDTHS
        xixj = self.GaussianKernal(x, x, bandwidths)
        yiyj = self.GaussianKernal(y, y, bandwidths)
        xiyj = self.GaussianKernal(x, y, bandwidths)
        # k(x_i, x_j) + k(y_i, y_j) - 2 k(x_i, y_j); its mean over (i, j) is MMD^2.
        mmd_terms = xixj + yiyj - 2 * xiyj
        return mmd_terms

    def reduce_loss(self, loss):
        """Average the elementwise MMD terms from ``compute_loss`` into a scalar loss."""
        return loss.mean()
