from collections import namedtuple
from functools import partial

import gym
import jax
import jax.numpy as jnp
import numpy as np

from agents import ControlAgent
import extractors
from extractors import vector_ops as vect
import optimizers
import schedules


class WQL(ControlAgent):
    """Watkins' Q(lambda) with a linear (optionally dueling) head on top of a feature extractor.

    The agent learns online, one transition at a time, using eligibility traces.
    TD errors are computed exclusively from the target network's parameters; with
    ``dual_traces`` a secondary trace re-centres the correction on the online
    network's prediction. With ``cut_traces`` the traces are truncated whenever the
    action being credited was non-greedy (Watkins' rule).

    Args:
        observation_space: ``gym.spaces.Box`` describing observations.
        action_space: ``gym.spaces.Discrete`` describing actions.
        seed: Seed for the JAX PRNG key used to initialise the extractor (also
            forwarded to ``ControlAgent``).
        discount: Discount factor gamma.
        extractor: Name passed to ``extractors.make``.
        opt: Name of an optimizer factory in ``optimizers``; it is called with ``lr``
            and must return an ``(init, update, get_params)`` triple.
        lr: Learning rate (must be > 0).
        epsilon: Name passed to ``schedules.make``; the schedule maps the number of
            steps after pre-population to an exploration rate in [0, 1].
        target_period: Target-network sync period in agent steps (>= 1). With 1 the
            target always equals the online parameters.
        dueling: If True, Q = V + A - max(A); otherwise the advantage head is used as Q.
        lambd: Trace-decay parameter lambda in [0, 1].
        cut_traces: Zero the traces after non-greedy actions (Watkins' rule).
        dual_traces: Maintain the secondary correction trace described above.
    """
    Parameters = namedtuple('Parameters', ['theta', 'w'])

    def __init__(self, observation_space, action_space, seed, discount, extractor='none', opt='adam', lr=3e-4,
                 epsilon='toy-control', target_period=1, dueling=True, lambd=0.0, cut_traces=True, dual_traces=True):
        assert isinstance(observation_space, gym.spaces.Box)
        assert isinstance(action_space, gym.spaces.Discrete)
        assert lr > 0.0
        assert 0.0 <= lambd <= 1.0
        # A period of 0 would make `t % target_period` divide by zero.
        assert target_period >= 1

        super().__init__(observation_space, action_space, seed, discount)
        self.extractor = extractor = extractors.make(extractor)
        self.opt_cls = getattr(optimizers, opt)
        self.lr = lr
        self.epsilon_schedule = schedules.make(epsilon)
        self.target_period = target_period
        assert isinstance(dueling, bool)
        self.dueling = dueling
        self.lambd = lambd
        assert isinstance(cut_traces, bool)
        self.cut_traces = cut_traces
        assert isinstance(dual_traces, bool)
        self.dual_traces = dual_traces

        input_shape = observation_space.shape
        prng_key = jax.random.PRNGKey(seed)

        theta, features, prng_key = extractor.generate_parameters(input_shape, prng_key)
        # Row `features` multiplies the appended bias feature; column 0 is the V head
        # (used only when dueling) and columns 1..n are the per-action heads.
        w = jnp.zeros([features + 1, action_space.n + 1])
        self.init_params = self.Parameters(theta, w)
        # Ensure a target exists even if `reinforce` runs before the first periodic sync.
        self.target_params = self.init_params

        # (main trace, dual/correction trace), both shaped like the parameters.
        z = vect.zeros_like(self.init_params)
        z2 = vect.zeros_like(self.init_params)
        self.etraces = (z, z2)

        self._define_forward(extractor)
        self._define_update()
        self.t = 0

    def _define_forward(self, extractor):
        """Create the optimizer state and the forward-pass closures (features, Q-values)."""
        # Make optimizer
        self.opt_init, self.opt_update, self.get_params = self.opt_cls(self.lr)
        self.opt_state = self.opt_init(self.init_params)

        def features(params, obs):
            return extractor.forward(params.theta, obs)
        self.features = features

        def q_values(params, obs):
            """Return Q-values of shape (batch, n_actions) for a batch of observations."""
            x = features(params, obs)
            # Append bias term to features (assumes features are (batch, n_features))
            x = jnp.append(x, jnp.ones_like(x[:, 0, None]), axis=1)

            y = x.dot(params.w)
            if not self.dueling:
                Q = y[:, 1:]
            else:
                V = y[:, 0, None]
                A = y[:, 1:]
                A_max = jnp.max(A, axis=-1, keepdims=True)
                Q = V + A - A_max
            return Q
        self.q_values = q_values
        self.jit_q_values = jax.jit(q_values)

        def get_q_value(params, obs, action):
            """Return Q(obs[0], action) as a scalar (obs carries a batch of size 1)."""
            q = q_values(params, obs)
            return q[0, action]
        self.get_q_value = get_q_value

    def _define_update(self):
        """Build the jitted single-transition Q(lambda) update."""
        discount = self.discount
        lambd = self.lambd
        cut_traces = self.cut_traces
        dual_traces = self.dual_traces

        @partial(jax.jit, static_argnames=['terminated', 'truncated'])
        def update(opt_state, etraces, target_params, obs, action, next_obs, reward, terminated, truncated, t):
            params = self.get_params(opt_state)
            z, z2 = etraces

            def taken_q_value(p):
                q = self.q_values(p, obs)[0]  # Remove batch dimension
                return q[action], q

            # One forward pass yields Q(s,a), its gradient and all action-values at s.
            (q_main, q_all_main), grads = jax.value_and_grad(taken_q_value, has_aux=True)(params)
            q_target = self.get_q_value(target_params, obs, action)

            # Exclusively use target parameters for TD error
            td_error = reward - q_target
            if not terminated:
                # Index the single batch row so the TD error stays a scalar.
                next_v_target = jnp.max(self.q_values(target_params, next_obs)[0])
                td_error += discount * next_v_target

            if cut_traces:
                # Watkins' rule: the TD error of step t may only flow back to earlier
                # (s, a) pairs if a_t was greedy. Cutting here, *before* grad Q(s_t, a_t)
                # is accumulated, is equivalent to the textbook end-of-step cut keyed on
                # the next action. `params` are the ones a_t was selected with.
                not_greedy = q_main != jnp.max(q_all_main)
                z = vect.conditional_zeros_like(z, not_greedy)
                z2 = vect.conditional_zeros_like(z2, not_greedy)

            z = vect.etrace(z, discount * lambd, grads)
            step = vect.scale(td_error, z)

            if dual_traces:
                # Secondary trace z2 corrects the lambda-error by making it relative to the
                # main network's prediction instead of the target network's prediction
                z2 = vect.etrace(z2, discount * lambd, vect.scale(q_target - q_main, grads))
                step = vect.add(step, z2)  # Add correction

            step = vect.scale(-1, step)  # Flip sign: the optimizer minimises
            opt_state = self.opt_update(t, step, opt_state)

            if terminated or truncated:
                # Episode boundary: no credit may flow across episodes.
                z = vect.zeros_like(z)
                z2 = vect.zeros_like(z2)

            return opt_state, (z, z2)

        self.update = update

    @property
    def params(self):
        """Current online parameters extracted from the optimizer state."""
        return self.get_params(self.opt_state)

    def reinforce(self, obs, action, next_obs, reward, terminated, truncated, b_prob):
        """Apply one online Q(lambda) update for a single transition.

        ``b_prob`` (behaviour-policy probability) is accepted for interface
        compatibility but is not used by this agent.
        """
        self.update_target_network()
        minibatch = (obs[None], action, next_obs[None], reward, terminated, truncated)
        self.opt_state, self.etraces = self.update(self.opt_state, self.etraces, self.target_params, *minibatch, self.t)

    def update_target_network(self):
        """Sync the target parameters every ``target_period`` steps (always when it is 1)."""
        if self.target_period == 1:
            self.target_params = self.params
            return

        if (self.t % self.target_period) == 1:
            self.target_params = vect.copy(self.params)

    def act(self, obs):
        """Select an epsilon-greedy action.

        Returns:
            ``(action, prob)`` where ``prob`` is the probability of ``action`` under
            the epsilon-greedy behaviour policy: ``1 - eps + eps/n`` for the greedy
            action (whichever branch selected it) and ``eps/n`` otherwise.
        """
        self.t += 1

        # `self.prepop` and `self.np_random` are presumably provided by ControlAgent.
        if self.t < self.prepop:
            epsilon = 1.0
        else:
            epsilon = self.epsilon_schedule(self.t - self.prepop)
        assert 0.0 <= epsilon <= 1.0

        explore_prob = epsilon / self.action_space.n
        greedy_prob = 1 - epsilon + explore_prob

        # random() is in [0, 1), so a strict `<` never explores when epsilon == 0.
        if self.np_random.random() < epsilon:
            action = self.action_space.sample()
            if epsilon == 1.0:
                # Uniform policy: every action has probability 1/n, no Q evaluation needed.
                return action, explore_prob
            # The uniform sample may coincide with the greedy action.
            prob = greedy_prob if action == self._greedy_action(obs) else explore_prob
            return action, prob

        return self._greedy_action(obs), greedy_prob

    def _greedy_action(self, obs):
        """Return the index of the highest online Q-value for a single observation."""
        q = self.jit_q_values(self.params, obs[None])  # Add batch dimension
        return self._argmax(q[0])  # Remove batch dimension

    def _argmax(self, q):
        """Return argmax of ``q`` as a Python int (first index on ties)."""
        # np.argmax returns the first NaN's index if any entry is NaN, so reject any NaN.
        assert not np.isnan(q).any(), "cannot have NaN inputs"
        return np.argmax(q).item()
