from functools import partial

import gym
import jax
import jax.numpy as jnp

from agents.control.wql import WQL
from data_structures import ReplayMemory
import extractors
import optimizers
import schedules


class DQN(WQL):
    """Deep Q-Network with n-step returns.

    Every call to ``reinforce`` stores the transition in a replay memory. Once
    ``self.t`` exceeds ``prepop``, a minibatch of ``batch_size`` trajectories of
    length n+1 is sampled every ``train_period`` steps. One optimizer step is
    then taken on the batch mean of ``0.5 * (G - Q(s_0, a_0))**2``, where ``G``
    is an n-step return that bootstraps from ``self.target_params``.

    ``estimator`` selects the return. Only ``'ql-nstep-<n>'`` with integer
    ``n >= 1`` is supported (see ``get_n``).
    """

    def __init__(self, observation_space, action_space, seed, discount, extractor='none', opt='adam', lr=3e-4, train_period=4,
                 epsilon=0.05, prepop=50_000, target_period=1, dueling=True, rmem_size=500_000, estimator='ql-nstep-1', batch_size=32):
        """Configure the agent and initialize its parameters.

        Args:
            observation_space: must be a ``gym.spaces.Box``.
            action_space: must be a ``gym.spaces.Discrete``.
            seed: seeds both the replay memory and the JAX PRNG key used to
                initialize the feature-extractor parameters.
            discount: forwarded to the base-class initializer. It is later
                read back as ``self.discount`` when building the update.
            extractor: name passed to ``extractors.make``.
            opt: attribute name looked up on the ``optimizers`` module.
            lr: learning rate; must be > 0.
            train_period: train once every ``train_period`` steps (>= 1).
            epsilon: spec passed to ``schedules.make``.
            prepop: training is skipped while ``self.t <= prepop``.
            target_period: stored only. Presumably consumed by
                ``update_target_network`` (defined elsewhere) — TODO confirm.
            dueling: must be a bool; stored only.
            rmem_size: first argument to ``ReplayMemory`` (presumably its
                capacity — TODO confirm).
            estimator: return-estimator spec, parsed lazily by ``get_n``.
            batch_size: trajectories per minibatch (>= 1).
        """
        assert isinstance(observation_space, gym.spaces.Box)
        assert isinstance(action_space, gym.spaces.Discrete)
        assert lr > 0.0

        # Deliberately skips WQL.__init__: this runs the initializer of the
        # class that follows WQL in the MRO, and the rest is set up here.
        super(WQL, self).__init__(observation_space, action_space, seed, discount)
        self.extractor = extractor = extractors.make(extractor)
        self.opt_cls = getattr(optimizers, opt)
        self.lr = lr
        assert train_period >= 1
        self.train_period = train_period
        self.epsilon_schedule = schedules.make(epsilon)
        self.prepop = prepop

        self.target_period = target_period
        assert isinstance(dueling, bool)
        self.dueling = dueling

        self.replay_memory = ReplayMemory(rmem_size, seed)
        self.estimator = estimator
        assert batch_size >= 1
        self.batch_size = batch_size

        input_shape = observation_space.shape
        prng_key = jax.random.PRNGKey(seed)
        theta, features, prng_key = extractor.generate_parameters(input_shape, prng_key)
        # Linear head: one extra row and one extra column beyond
        # features x actions. Presumably these are a bias input and a
        # state-value stream for the dueling head — TODO confirm against
        # _define_forward.
        w = jnp.zeros([features + 1, action_space.n + 1])
        self.init_params = self.Parameters(theta, w)

        self._define_forward(extractor)
        self._define_update()
        self.t = 0

    def _define_update(self):
        """Build the jitted training step and store it as ``self.update``.

        ``self.update(opt_state, target_params, obs, actions, rewards,
        terminateds, truncateds, nstep, t)`` takes one optimizer step on the
        batch-mean trajectory loss and returns the new optimizer state. The
        trajectory arrays carry a leading batch axis, which is vmapped over.
        ``nstep`` is a static Python int, so each distinct value compiles its
        own version of the function.
        """
        discount = self.discount

        @partial(jax.jit, static_argnames=['nstep'])
        def trajectory_loss(params, target_params, obs, actions, rewards, terminateds, truncateds, nstep):
            # Only the first state-action pair of the trajectory is regressed,
            # so the online network is evaluated on obs[0] alone.
            Q_main = self.q_values(params, obs[0, None])
            q_main_taken = Q_main[0, actions[0]]

            # The episode ends on either termination or truncation. Only
            # termination zeroes the bootstrap inside fast_nstep_return.
            dones = jnp.logical_or(terminateds, truncateds)

            Q_func = lambda s: self.q_values(target_params, s)
            # Drop the last per-step entry so exactly `nstep` rewards are summed.
            # obs keeps its last entry because it may be the bootstrap state.
            G = fast_nstep_return(nstep, Q_func, obs, rewards[:-1], terminateds[:-1], dones[:-1], discount)
            G = jax.lax.stop_gradient(G)  # semi-gradient: no gradient through the target

            return 0.5 * jnp.square(G - q_main_taken)

        # Batch over trajectories; the parameters and the static nstep are shared.
        vmap_trajectory_loss = jax.vmap(trajectory_loss, in_axes=[None, None, 0, 0, 0, 0, 0, None])

        @partial(jax.jit, static_argnames=['nstep'])
        def update(opt_state, target_params, obs, actions, rewards, terminateds, truncateds, nstep, t):
            def loss(params):
                losses = vmap_trajectory_loss(params, target_params, obs, actions, rewards, terminateds, truncateds, nstep)
                return jnp.mean(losses)

            params = self.get_params(opt_state)
            grads = jax.grad(loss)(params)
            return self.opt_update(t, grads, opt_state)

        self.update = update

    def reinforce(self, obs, action, next_obs, reward, terminated, truncated, b_prob):
        """Store one transition and, when scheduled, take one training step.

        ``next_obs`` is not used. Only ``obs`` is stored, and successors come
        from later entries of the sampled trajectory.

        NOTE(review): on a truncation, ``fast_nstep_return`` bootstraps from the
        next stored observation. Whether that belongs to the same episode
        depends on ``ReplayMemory`` — confirm.

        NOTE(review): ``self.t`` is not advanced here. Presumably the base
        class increments it — confirm.
        """
        self.replay_memory.save(obs, action, reward, terminated, truncated, b_prob)
        self.update_target_network()

        if self.t <= self.prepop:
            return

        # With train_period == 1 every step trains. Otherwise train when
        # t % train_period == 1.
        if self.train_period == 1 or (self.t % self.train_period) == 1:
            n = self.get_n()
            minibatch = self.replay_memory.sample_trajectories(self.batch_size, length=n+1)
            minibatch = minibatch[:-1]  # Slice off behavior probabilities
            self.opt_state = self.update(self.opt_state, self.target_params, *minibatch, n, self.t)

    def get_n(self):
        """Return the n-step horizon encoded in ``self.estimator``.

        The estimator must have the form ``'ql-nstep-<n>'`` with integer
        ``n >= 1``. The parsed value is cached on the first successful call.

        Raises:
            ValueError: if the estimator does not start with ``'ql-nstep-'``,
                if the remainder is not an integer, or if ``n < 1``.
        """
        if not hasattr(self, '_get_n'):
            est = self.estimator
            prefix = 'ql-nstep-'

            if not est.startswith(prefix):
                raise ValueError("unsupported return estimator '{}'".format(est))

            # Validate explicitly rather than with `assert`, which is stripped under -O.
            try:
                n = int(est[len(prefix):])
            except ValueError:
                raise ValueError("invalid n-step value in return estimator '{}'".format(est)) from None
            if n < 1:
                raise ValueError("n-step value must be >= 1 in return estimator '{}'".format(est))

            self._get_n = lambda: n

        return self._get_n()


def fast_nstep_return(n, Q_func, obs, rewards, terminateds, dones, discount):
    """Compute an n-step return that stops at the first episode end.

    Let ``b`` be the smallest index in ``[0, n)`` with ``dones[b]`` true, or
    ``n - 1`` if there is none. The result is::

        sum_{i <= b} discount**i * rewards[i]
            + discount**(b + 1) * (0 if terminateds[b] else max(Q_func(obs[b + 1, None]), axis=-1))

    Only a single observation is passed through ``Q_func``, hence "fast".

    Args:
        n: number of reward steps; a Python int because it drives a Python loop.
        Q_func: maps a batch of observations (leading axis of size 1 here) to
            Q-values; the max is taken over the last axis.
        obs: indexed at ``b + 1``, so it needs at least ``n + 1`` entries
            along axis 0.
        rewards, terminateds, dones: indexed at ``0 .. n-1``.
        discount: per-step discount factor.

    Returns:
        The return, shaped like ``max(Q_func(...), axis=-1)``.
    """
    def _bootstrap_value(k):
        # Greedy target value of the state after step k; zero if step k terminated.
        greedy = jnp.max(Q_func(obs[k + 1, None]), axis=-1)
        return jnp.where(terminateds[k], 0, greedy)

    cut = n - 1
    ret = 0.0
    for step in range(n - 1, -1, -1):
        ended = dones[step]
        # Walking backwards, any episode end discards the rewards after it
        # and moves the bootstrap point to that step.
        cut = jnp.where(ended, step, cut)
        ret = rewards[step] + discount * jnp.where(ended, 0.0, ret)

    tail_scale = jnp.power(discount, cut + 1)
    return ret + tail_scale * _bootstrap_value(cut)
