#######################################################################
# Copyright (C) 2017 Shangtong Zhang(zhangshangtong.cpp@gmail.com)    #
# Permission given to modify the code as long as you keep this        #
# declaration at the top                                              #
#######################################################################

from ..network import *
from ..component import *
from ..utils import *
from .BaseAgent import *
from .DQN_agent import *
from geomloss import SamplesLoss
from main import args

class SinkhornRegressionDQNActor(DQNActor):
    """DQN actor that scores actions by averaging the network's return samples."""

    def __init__(self, config):
        super().__init__(config)

    def compute_q(self, prediction):
        """Return Q-values as a NumPy array.

        ``prediction`` must support ``prediction['sample']``; the last axis of
        that tensor (the samples) is averaged to obtain one value per action.
        """
        sample_returns = prediction['sample']
        return to_np(sample_returns.mean(-1))

def sinkhorn_loss(x, y, epsilon, n, niter):
    r"""Approximate the entropy-regularised OT cost between two batches of samples.

    Given two empirical measures with ``n`` points each, located at ``x`` and
    ``y``, run log-domain Sinkhorn iterations on ``cost_matrix(x, y)`` with
    uniform marginals (every point has weight ``1 / n``).

    Args:
        x: sample locations of the first measure, passed to ``cost_matrix``
            (the caller in this file supplies ``[bs, n, 1]``).
        y: sample locations of the second measure, same layout as ``x``.
        epsilon: entropic regularisation parameter.
        n: number of points per measure.
        niter: maximum number of steps in the Sinkhorn loop.

    Returns:
        A single-element tensor: ``sum(Gamma * C)`` taken over the whole
        batch, where ``Gamma = exp(M(u, v))`` is the transport plan.
    """
    # Wasserstein cost function, [bs, n, n]; kept on the configured GPU.
    C = cost_matrix(x, y).cuda(args.gpu)
    bs = C.shape[0]

    # Both marginals are fixed with equal weights, so their logs coincide and
    # can be computed once instead of on every iteration.
    mu = torch.full((bs, n), 1. / n, device=C.device)
    log_mu = torch.log(mu)
    log_nu = log_mu

    thresh = 10 ** (-1)  # stopping criterion on the total change of u

    def M(u, v):
        r"""Modified cost for logarithmic updates.

        $M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$, batched over axis 0.
        """
        return (-C + u.unsqueeze(2) + v.unsqueeze(1)) / epsilon

    def lse(A):
        """Log-sum-exp over the last axis (1e-6 is added to prevent NaN)."""
        return torch.log(torch.exp(A).sum(2, keepdim=True) + 1e-6)

    # Actual Sinkhorn loop: alternate the dual updates until the change in u
    # falls below `thresh` or `niter` steps have been taken.
    u = torch.zeros_like(mu)
    v = torch.zeros_like(mu)
    for _ in range(niter):
        u_prev = u  # kept to measure the size of the update
        # squeeze(-1) removes only the reduced axis; a bare squeeze() would
        # also drop the point axis when n == 1 and mis-broadcast to [bs, bs].
        u = epsilon * (log_mu - lse(M(u, v)).squeeze(-1)) + u
        v = epsilon * (log_nu - lse(M(u, v).permute(0, 2, 1)).squeeze(-1)) + v
        err = (u - u_prev).abs().sum()
        if bool(err < thresh):
            break

    Gamma = torch.exp(M(u, v))  # transport plan pi = diag(a) * K * diag(b)
    cost = torch.sum(Gamma * C)  # Sinkhorn cost, summed over the batch
    return cost


def cost_matrix(x, y, p=2):
    """Return the batched matrix of $|x_i - y_j|^p$.

    For inputs of shape [bs, N, d] the result has shape [bs, N, N]: entry
    (b, i, j) is the p-th power of the absolute difference between point i of
    ``x`` and point j of ``y`` in batch b, summed over the trailing axis d.
    """
    # Insert singleton axes so the subtraction broadcasts every x_i against
    # every y_j: [bs, N, 1, d] - [bs, 1, N, d] -> [bs, N, N, d].
    pairwise_diff = x[:, :, None] - y[:, None]
    return (pairwise_diff.abs() ** p).sum(3)

class SinkhornDQNAgent(DQNAgent):
    """DQN agent trained with a Sinkhorn divergence between return samples.

    The online and target networks are called as ``network(states)`` and are
    expected to return a pair ``(prediction, feature)`` where
    ``prediction['sample']`` holds return samples indexed as
    ``[batch, action, sample]``.
    """

    def __init__(self, config):
        BaseAgent.__init__(self, config)
        self.config = config
        config.lock = mp.Lock()

        self.replay = config.replay_fn()
        self.actor = SinkhornRegressionDQNActor(config)

        # Online network (shared across processes) and its target copy.
        online_net = config.network_fn()
        online_net.share_memory()
        self.network = online_net
        target_net = config.network_fn()
        target_net.load_state_dict(online_net.state_dict())
        self.target_network = target_net
        self.optimizer = config.optimizer_fn(online_net.parameters())

        # Embedding network (phi) and its optimizer. They are built here, but
        # no method of this class currently feeds data through them.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        phi_net = config.network_fn_phi().cuda(args.gpu)
        self.network_phi = phi_net
        self.optimizer_phi = config.optimizer_fn_phi(phi_net.parameters())

        self.actor.set_network(self.network)

        self.total_steps = 0
        self.batch_indices = range_tensor(config.batch_size)

        # Sinkhorn hyper-parameters: regularisation strength and iteration cap.
        self.epsilon = config.epsilon
        self.niter_sink = config.niter_sink

    def eval_step(self, state):
        """Pick the greedy action for ``state``.

        Returns a pair ``([action], feature)`` where ``feature`` is the second
        output of the online network.
        """
        normalizer = self.config.state_normalizer
        normalizer.set_read_only()
        prediction, feature = self.network(normalizer(state))
        # Q-value of each action is the mean over its return samples.
        q_values = to_np(prediction['sample'].mean(-1)).flatten()
        greedy_action = np.argmax(q_values)
        normalizer.unset_read_only()
        return [greedy_action], feature

    def GaussianKernal(self, v1, v2, sigma):
        """Sum of Gaussian kernels between every pair of entries of v1 and v2.

        ``v1`` and ``v2`` are indexed as [bs, N]; ``sigma`` is a collection of
        bandwidths. The result is reshaped to [bs, N, N].
        """
        col = v1.unsqueeze(2)  # [bs, N] -> [bs, N, 1]
        row = v2.unsqueeze(1)  # [bs, N] -> [bs, 1, N]
        sq_dist = (col - row) ** 2
        # One row per bandwidth: [k, 1] @ [1, bs*N*N] -> [k, bs*N*N].
        inv_sigma = 1.0 / torch.tensor(sigma).float().view(-1, 1).cuda(args.gpu)
        scaled = torch.matmul(inv_sigma, sq_dist.view(1, -1))
        kernel_sum = torch.sum(torch.exp(-scaled), dim=0)
        batch, n_points = col.shape[0], col.shape[1]
        return kernel_sum.reshape(batch, n_points, n_points)

    def compute_loss(self, transitions):
        """Sinkhorn divergence between current and target return samples.

        Returns ``2 * W(x, y) - W(x, x) - W(y, y)`` where ``x`` are the online
        samples of the taken actions and ``y`` the bootstrapped target samples.
        """
        config = self.config
        states = config.state_normalizer(transitions.state)
        next_states = config.state_normalizer(transitions.next_state)

        # Target samples Z(s', a*): a* maximises the summed samples of the
        # target network (same ranking as the mean).
        target_out, _ = self.target_network(next_states)
        next_samples = target_out['sample'].detach()
        greedy_next = torch.argmax(next_samples.sum(-1), dim=-1)
        next_samples = next_samples[self.batch_indices, greedy_next, :]

        rewards = tensor(transitions.reward).unsqueeze(-1)
        masks = tensor(transitions.mask).unsqueeze(-1)
        n_step_discount = config.discount ** config.n_step
        target_samples = rewards + n_step_discount * masks * next_samples

        # Online samples Z(s, a) for the actions that were actually taken.
        online_out, _ = self.network(states)
        actions = tensor(transitions.action).long()
        current_samples = online_out['sample'][self.batch_indices, actions, :]

        # Add a trailing feature axis: [bs, N] -> [bs, N, 1].
        x = current_samples.unsqueeze(2)
        y = target_samples.unsqueeze(2)

        eps, n_samples, max_iter = self.epsilon, config.num_samples, self.niter_sink
        w_xy = sinkhorn_loss(x, y, eps, n_samples, max_iter)
        w_xx = sinkhorn_loss(x, x, eps, n_samples, max_iter)
        w_yy = sinkhorn_loss(y, y, eps, n_samples, max_iter)
        return 2 * w_xy - w_xx - w_yy

    def reduce_loss(self, loss):
        """Collapse the loss tensor to its mean."""
        return loss.mean()

    def step(self):
        """Collect actor transitions, then run one training update if warmed up."""
        config = self.config

        # Store every transition produced by the actor in the replay buffer.
        for states, actions, rewards, _next_states, dones, info in self.actor.step():
            self.record_online_return(info)
            self.total_steps += 1
            latest_frames = [s[-1] if isinstance(s, LazyFrames) else s for s in states]
            normalized_rewards = [config.reward_normalizer(r) for r in rewards]
            self.replay.feed(dict(
                state=np.array(latest_frames),
                action=actions,
                reward=normalized_rewards,
                mask=1 - np.asarray(dones, dtype=np.int32),
            ))

        # Learn only once the exploration phase is over.
        if self.total_steps > config.exploration_steps:
            batch = self.replay.sample()
            if config.noisy_linear:
                self.target_network.reset_noise()
                self.network.reset_noise()
            loss = self.compute_loss(batch)

            if isinstance(batch, PrioritizedTransition):
                # Refresh priorities from the loss, then apply normalised
                # importance-sampling weights.
                priorities = (loss.abs() + config.replay_eps) ** config.replay_alpha
                idxs = tensor(batch.idx).long()
                self.replay.update_priorities(zip(to_np(idxs), to_np(priorities)))
                probs = tensor(batch.sampling_prob)
                weights = (probs * probs.size(0) + 1e-6) ** (-config.replay_beta())
                weights = weights / weights.max()
                loss = loss * weights

            loss = self.reduce_loss(loss)
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.network.parameters(), config.gradient_clip)
            with config.lock:
                self.optimizer.step()  # update the online (sample) network
            # NOTE: no update of the phi network is performed here.

        # Copy online weights into the target network on the update schedule.
        updates_done = self.total_steps / config.sgd_update_frequency
        if updates_done % config.target_network_update_freq == 0:
            self.target_network.load_state_dict(self.network.state_dict())