from functools import partial
import itertools
import math

import jax
import jax.numpy as jnp
import numpy as np

from agents.control.dqn import DQN, fast_nstep_return


class ALR(DQN):
    """DQN trained on a weighted average of two n-step returns.

    The two horizons and the mixing weight are supplied by
    ``best_approximation`` and are chosen to approximate a lambda-return.
    """

    def _define_update(self):
        """Build the jitted parameter update and store it as ``self.update``.

        ``self.estimator`` must have the form ``'ql-pilar-<n>'``. The ``<n>``
        part is parsed as a float and passed, together with ``self.discount``,
        to ``best_approximation`` to obtain ``(n1, n2, w)``. The longer horizon
        is stored on ``self.n2`` because ``reinforce`` uses it to decide how
        long the sampled trajectories must be.
        """
        gamma = self.discount
        spec = self.estimator

        assert spec.startswith('ql-pilar-')
        # Exactly three dash-separated fields are expected; only the last is used.
        _, _, n_text = spec.split('-')
        target_n = float(n_text)
        (n1, n2, w), error = best_approximation(target_n, gamma)

        self.n2 = n2
        assert 1 <= n1 <= n2
        print("n={} --> (n1, n2, w)={}, error={}".format(target_n, (n1, n2, w), round(error, 4)))

        @jax.jit
        def trajectory_loss(params, target_params, obs, actions, rewards, terminateds, truncateds):
            """Half squared error between the mixed return and Q(obs[0], actions[0])."""

            def target_q(states):
                # Bootstrap values always come from the target parameters.
                return self.q_values(target_params, states)

            def nstep(n, episode_over):
                # NOTE(review): the last reward/flag entry is dropped before the call,
                # presumably because obs carries one more element -- confirm
                # against fast_nstep_return.
                return fast_nstep_return(
                    n, target_q, obs, rewards[:-1], terminateds[:-1], episode_over[:-1], gamma
                )

            # Only the first state-action value of the sequence uses the main parameters.
            online_q = self.q_values(params, obs[0, None])[0, actions[0]]

            # An episode is over whether it terminated or was truncated.
            episode_over = jnp.logical_or(terminateds, truncateds)

            mixed = (1-w) * nstep(n1, episode_over) + w * nstep(n2, episode_over)
            # The target is treated as a constant when differentiating.
            target = jax.lax.stop_gradient(mixed)

            return 0.5 * jnp.square(target - online_q)

        # Parameters are shared across the batch; every data argument is mapped on axis 0.
        batched_loss = jax.vmap(trajectory_loss, in_axes=(None, None, 0, 0, 0, 0, 0))

        @jax.jit
        def update(opt_state, target_params, obs, actions, rewards, terminateds, truncateds, t):
            """Take one optimizer step on the batch-mean loss and return the new state."""

            def mean_loss(params):
                per_trajectory = batched_loss(
                    params, target_params, obs, actions, rewards, terminateds, truncateds
                )
                return jnp.mean(per_trajectory)

            grads = jax.grad(mean_loss)(self.get_params(opt_state))
            return self.opt_update(t, grads, opt_state)

        self.update = update

    def reinforce(self, obs, action, next_obs, reward, terminated, truncated, b_prob):
        """Store one transition, call the target-network update, and maybe train.

        ``next_obs`` is accepted but not used by this method. No training
        happens while ``self.t <= self.prepop``. After that, a training step
        runs when ``self.train_period`` is 1 or ``self.t % self.train_period``
        equals 1.
        """
        self.replay_memory.save(obs, action, reward, terminated, truncated, b_prob)
        self.update_target_network()

        if self.t <= self.prepop:
            return

        every_step = self.train_period == 1
        if not (every_step or (self.t % self.train_period) == 1):
            return

        # Trajectories need n2 + 1 entries to cover the longer of the two returns.
        batch = self.replay_memory.sample_trajectories(self.batch_size, length=self.n2+1)
        fields = batch[:-1]  # Slice off behavior probabilities
        self.opt_state = self.update(self.opt_state, self.target_params, *fields, self.t)


def best_approximation(effective_n, discount):
    """Find the two-return mixture (n1, n2, w) that best matches a lambda-return.

    The reference is the weight sequence ``(discount * lambd) ** i`` with
    ``lambd = (1 - discount**(effective_n - 1)) / (1 - discount**effective_n)``.
    Every candidate pair ``n1 <= effective_n < n2`` gets its weight function and
    mixing weight ``w`` from ``get_pilar_weight_func``. The candidate with the
    smallest maximum absolute deviation from the reference sequence wins.

    Args:
        effective_n: Target effective horizon; must be at least 1.
        discount: Discount factor, strictly between 0 and 1.

    Returns:
        ``((n1, n2, w), error)``: the best candidate and its maximum absolute
        weight error.

    Raises:
        AssertionError: If the arguments are out of range, or if the returned
            candidate's contraction rate does not match ``discount**effective_n``.
    """
    assert effective_n >= 1
    assert 0.0 < discount < 1.0
    lambd = (1 - pow(discount, effective_n - 1)) / (1 - pow(discount, effective_n))
    decay = discount * lambd  # Per-step decay of the reference weights, in [0, 1)

    def error_func(n1, n2):
        N = 10_000  # Upper bound on the number of terms compared
        pilar_weight, w = get_pilar_weight_func(effective_n, discount, n1, n2)
        # pilar_weight(i) is 0 for every i >= n2, so the deviation there is just
        # decay**i, which never grows with i. The largest tail term is the one
        # at i == n2, and scanning further cannot change the maximum.
        last = min(n2, N)
        error = max(abs(pilar_weight(i) - pow(decay, i)) for i in range(last + 1))
        return error, w

    best_values = None
    best_error = float('inf')
    largest_n1 = math.floor(effective_n)

    for n1 in range(1, largest_n1 + 1):
        prev_error = float('inf')

        # Grow n2 until the error stops improving for this n1.
        for n2 in itertools.count(start=largest_n1 + 1):
            error, w = error_func(n1, n2)

            if error < best_error:
                best_values = (n1, n2, w)
                best_error = error

            if error >= prev_error:
                break
            prev_error = error

    # Sanity check: make sure contraction rates match. Unpack the candidate that
    # is actually returned; the loop variables hold the last candidate tried.
    n1, n2, w = best_values
    cr = (1-w) * pow(discount, n1) + w * pow(discount, n2)
    expected_cr = pow(discount, effective_n)
    assert np.allclose(cr, expected_cr), f"contraction rate sanity check failed: {cr} != {expected_cr}"

    return best_values, best_error


def get_pilar_weight_func(effective_n, discount, n1, n2):
    """Return the weight function and mixing weight for an (n1, n2) return pair.

    The mixing weight ``w`` is chosen so that
    ``(1 - w) * discount**n1 + w * discount**n2 == discount**effective_n``.

    Args:
        effective_n: Target horizon; must satisfy ``n1 <= effective_n < n2``.
        discount: Discount factor, strictly between 0 and 1.
        n1: Shorter horizon.
        n2: Longer horizon.

    Returns:
        ``(pilar_weight, w)`` where ``pilar_weight(i)`` is ``discount**i`` for
        ``i < n1``, ``w * discount**i`` for ``n1 <= i < n2``, and ``0.0`` otherwise.
    """
    assert n1 <= effective_n < n2
    assert 0.0 < discount < 1.0

    short_rate = discount**n1
    numerator = short_rate - discount**effective_n
    denominator = short_rate - discount**n2
    w = numerator / denominator

    def pilar_weight(i):
        before_n1 = i < n1
        if before_n1 or i < n2:
            decayed = pow(discount, i)
            # Steps before n1 belong to both returns; later ones only to the n2-step return.
            return decayed if before_n1 else w * decayed
        return 0.0

    return pilar_weight, w
