"""Various flavors of Q-learning algorithms."""
import collections
import copy

import numpy as np
from scipy import signal
from .explorers import *

import gym_minigrid
import gym

# A single step of experience. ``return_`` carries a trailing underscore
# because ``return`` is a reserved word; ``update_return`` builds a copy of a
# transition with only this field replaced.
Transition = collections.namedtuple(
    "Transition",
    [
        "state",
        "action",
        "reward",
        "next_state",
        "done",
        "return_",
        "episode_num",
        "episode_step",
    ],
)

# Maps the ``explore_method`` names accepted by ``q_learning`` to the explorer
# classes (brought in by the star import from ``.explorers``).
EXPLORER_DICTIONARY = dict(
    classic=NaiveExplorer,
    EASEE=GraphExplorer,
)


def act_epsilon_greedy(state, q_values, epsilon, explorer):
    """Acts epsilon-greedily w.r.t. Q-values.

    With probability ``epsilon`` the action is delegated to ``explorer``
    (first element of ``explorer.get_action()``). Otherwise a greedy action is
    taken from ``q_values[state]``, with ties broken uniformly at random.
    Either way the chosen action is reported back through
    ``explorer.add_action(action, exploration=...)``.

    Args:
        state: Row index into ``q_values``.
        q_values: Table whose row ``q_values[state]`` holds per-action values.
        epsilon: Probability of taking an exploratory action.
        explorer: Object providing ``get_action()`` and
            ``add_action(action, exploration)``.

    Returns:
        The selected action.

    Raises:
        ValueError: If no action attains the row maximum (e.g. the row contains
            NaN). The offending values are printed before re-raising.
    """
    if np.random.random() < epsilon:
        action = explorer.get_action()[0]
        explorer.add_action(action, exploration=True)
        return action

    action_values = q_values[state]
    max_value = np.max(action_values)
    # Guard only the tie-breaking draw: it fails when nothing compares equal to
    # the max (NaN values). Errors from the explorer must not be misreported.
    try:
        action = np.random.choice(np.flatnonzero(action_values == max_value))
    except ValueError:
        print("Action values: ", action_values)
        print("Max value: ", max_value)
        raise
    explorer.add_action(action, exploration=False)
    return action


def linear_update_epsilon(
    epsilon_base,
    episode_num,
    num_episodes_total,
    anneal_epsilon_until,
    epsilon_min=0.1,
):
    """Updates epsilon for linear annealing.

    Epsilon moves linearly from ``epsilon_base`` (at ``episode_num == 0``)
    towards ``epsilon_min``, reaching it when ``episode_num`` equals
    ``num_episodes_total * anneal_epsilon_until``. It is clamped from below
    at ``epsilon_min``. Only the ratio ``episode_num / num_episodes_total``
    matters, so step counts may be passed instead of episode counts.

    Args:
        epsilon_base: Epsilon at the start of annealing.
        episode_num: Current progress counter.
        num_episodes_total: Total length of the run, in the same units as
            ``episode_num``.
        anneal_epsilon_until: Fraction in [0, 1] of the run over which to
            anneal. 0.0 means epsilon is already fully annealed.
        epsilon_min: Lower bound on epsilon.

    Returns:
        The annealed epsilon, never below ``epsilon_min``.
    """
    assert 0.0 <= anneal_epsilon_until <= 1.0
    anneal_horizon = num_episodes_total * anneal_epsilon_until
    if anneal_horizon <= 0:
        # An empty annealing window (allowed by the assert above) would
        # otherwise divide by zero. It means annealing is already complete.
        return epsilon_min
    slope = epsilon_min - epsilon_base
    progress = float(episode_num) / anneal_horizon
    return max(epsilon_min, slope * progress + epsilon_base)


def q_learning(
    env,
    learning_rate=0.01,
    epsilon=0.1,
    num_steps=1000,
    use_target_network=False,
    update_target_every=1000,
    anneal_epsilon_until=None,
    augment_reward_func=None,
    explore_method="classic",
    depth=4,
):
    """Tabular Q-learning algorithm.

    Runs ``num_steps`` environment steps of one-step Q-learning with
    epsilon-greedy action selection. Exploratory actions are proposed by the
    explorer class ``EXPLORER_DICTIONARY[explore_method]``.

    Args:
        env: Environment exposing ``num_states``, ``num_actions``, ``gamma``,
            ``reset()`` and ``step(action)``. ``step`` returns
            ``(next_state, reward, done, info)``. ``env.equalities`` is also
            required when ``explore_method == "EASEE"``. States are used
            directly as row indices into the Q-table.
        learning_rate: Step size of the TD update.
        epsilon: Initial exploration probability.
        num_steps: Total number of environment steps.
        use_target_network: If True, bootstrap from a frozen copy of the
            Q-table that is refreshed every ``update_target_every`` steps.
        update_target_every: Refresh period, in steps, of the frozen copy.
        anneal_epsilon_until: If truthy, epsilon is linearly annealed at the
            end of each episode, reaching its floor after this fraction of
            ``num_steps``.
        augment_reward_func: Optional callable
            ``(reward, env, q_values, state, action, next_state, episodic_returns)``
            returning the reward used for learning. The tracked episode return
            always uses the unmodified environment reward.
        explore_method: Key into ``EXPLORER_DICTIONARY``.
        depth: Forwarded to the explorer constructor.

    Returns:
        Tuple ``(q_values, episodic_returns, episode_steps)``:
        - the learned ``(num_states, num_actions)`` float32 table;
        - the undiscounted return of every completed episode;
        - the global step index at which each episode ended.
        A trailing unfinished episode is not recorded.
    """
    actions = list(np.arange(env.num_actions))
    explorer_cls = EXPLORER_DICTIONARY[explore_method]
    if explore_method == "EASEE":
        explorer = explorer_cls(actions, depth=depth, equalities=env.equalities)
    else:
        explorer = explorer_cls(actions, depth)

    num_states = env.num_states
    num_actions = env.num_actions
    gamma = env.gamma
    epsilon_base = epsilon
    # Zero initialisation. NOTE(review): an earlier comment called this
    # "optimistic"; that only holds if all rewards are non-positive. Confirm.
    q_values = np.zeros((num_states, num_actions), dtype=np.float32)
    # Frozen copy used as the bootstrap table when use_target_network is set.
    target_q_values = q_values.copy()

    state = env.reset()
    episodic_returns = []
    episode_steps = []
    episode_return = 0.0
    for step in range(num_steps):
        action = act_epsilon_greedy(state, q_values, epsilon, explorer)
        next_state, reward, done, _ = env.step(action)
        # Track the raw environment reward, before any augmentation.
        episode_return += reward
        if augment_reward_func:
            reward = augment_reward_func(
                reward, env, q_values, state, action, next_state, episodic_returns
            )

        # One-step TD target; terminal transitions do not bootstrap.
        if done:
            td_target = reward
        else:
            bootstrap_table = target_q_values if use_target_network else q_values
            td_target = reward + gamma * np.max(bootstrap_table[next_state])
        td_error = td_target - q_values[state, action]
        q_values[state, action] = q_values[state, action] + learning_rate * td_error

        if done:
            state = env.reset()
            episodic_returns.append(episode_return)
            episode_return = 0.0
            episode_steps.append(step)
            if anneal_epsilon_until:
                epsilon = linear_update_epsilon(
                    epsilon_base, step, num_steps, anneal_epsilon_until
                )
        else:
            state = next_state

        # The frozen copy is only read when a target network is in use.
        if use_target_network and (step + 1) % update_target_every == 0:
            target_q_values = q_values.copy()

    return q_values, episodic_returns, episode_steps


# pylint:disable=line-too-long
# See https://stackoverflow.com/questions/47970683/vectorize-a-numpy-discount-calculation
def episode_returns(transitions, gamma):
    """Discounted return-to-go for every step of an episode.

    Element ``t`` of the result is ``sum_k gamma**k * transitions[t + k].reward``.
    The backward recursion ``G_t = r_t + gamma * G_{t+1}`` is a first-order IIR
    filter. It is therefore evaluated with ``scipy.signal.lfilter`` on the
    time-reversed rewards, and the filtered output is reversed back.
    """
    rewards = [step.reward for step in transitions]
    rewards.reverse()
    backward_returns = signal.lfilter([1], [1, -gamma], x=rewards)
    return np.flip(backward_returns)


# pylint:enable=line-too-long


def update_return(transition, return_):
    """Returns a new ``Transition`` equal to ``transition`` except for ``return_``.

    Fields are read from ``transition`` by attribute name. Any object exposing
    the ``Transition`` field names is therefore accepted, and the result is
    always a ``Transition``.
    """
    return Transition(
        state=transition.state,
        action=transition.action,
        reward=transition.reward,
        next_state=transition.next_state,
        done=transition.done,
        return_=return_,
        episode_num=transition.episode_num,
        episode_step=transition.episode_step,
    )


def td_lambda_target(
    transition,
    transitions_next,
    target_q_values,
    n_step,
    lambda_,
    gamma,
    reward=None,
    monte_carlo_updates=False,
):
    """Calculates the TD(lambda) target for Q-value updates.

    Let ``r_0`` be ``reward`` (defaulting to ``transition.reward``) and
    ``r_j = transitions_next[j - 1].reward``. The k-step target is

        G^(k) = r_0 + gamma r_1 + ... + gamma^(k-1) r_(k-1)
                + gamma^k * max_a target_q_values[s_k, a]

    where ``s_k`` is the state reached k steps after ``transition``. Once a
    transition with ``done`` is consumed, the bootstrap term is dropped and
    accumulation stops. ``transitions_next[j]`` is assumed to be the
    transition ``j + 1`` steps after ``transition`` in the same episode; this
    matches how the indices were used previously. TODO confirm with callers.

    Args:
        transition: The transition whose Q-value is being updated.
        transitions_next: Subsequent transitions, as described above. At
            least ``n_step - 1`` entries are needed unless the episode ends
            sooner. ``s_k`` is read from ``transitions_next[k - 1].state``
            when available, otherwise from the ``next_state`` of the last
            consumed transition.
        target_q_values: Table indexed as ``target_q_values[state]``.
        n_step: Maximum number of rewards to accumulate (>= 1).
        lambda_: If > 0, return the truncated lambda-return
            ``(1 - lambda) * sum_{k<M} lambda^(k-1) G^(k) + lambda^(M-1) G^(M)``,
            where M is the number of targets available. If 0, return the
            plain n-step target ``G^(M)``.
        gamma: Discount factor.
        reward: Optional override for ``transition.reward`` (e.g. an augmented
            reward). ``None`` means "use ``transition.reward``". An explicit
            0.0 is honoured.
        monte_carlo_updates: If True, return ``transition.return_`` as is.

    Returns:
        The scalar TD target.
    """
    # Test for None, not truthiness: an override of 0.0 is a valid reward.
    if reward is None:
        reward = transition.reward
    if monte_carlo_updates:
        return transition.return_
    if transition.done:
        # Terminal step: use the (possibly overridden) reward, exactly as the
        # non-terminal path does.
        return reward
    assert n_step >= 1
    targets = _kstep_targets(
        transition, reward, transitions_next, target_q_values, n_step, gamma
    )
    if lambda_ > 0.0:
        return _lambda_return(targets, lambda_)
    return targets[-1]


def _kstep_targets(transition, reward, transitions_next, target_q_values, n_step, gamma):
    """Returns [G^(1), ..., G^(M)]; M < n_step only if the episode ends first."""
    targets = []
    partial_return = reward  # discounted sum of the rewards consumed so far
    discount = 1.0
    consumed = transition  # the k-th transition whose reward is included
    for k in range(1, n_step + 1):
        discount *= gamma  # now gamma ** k
        if consumed.done:
            # Episode ended after k steps: G^(k) is the full return, no bootstrap.
            targets.append(partial_return)
            break
        state_k = _bootstrap_state(transitions_next, consumed, k)
        targets.append(partial_return + discount * np.max(target_q_values[state_k]))
        if k < n_step:
            consumed = transitions_next[k - 1]
            partial_return += discount * consumed.reward
    return targets


def _bootstrap_state(transitions_next, consumed, k):
    """Returns the state reached k steps after the transition being updated."""
    if k <= len(transitions_next):
        return transitions_next[k - 1].state
    # Not enough look-ahead transitions: use the last consumed one's successor.
    return consumed.next_state


def _lambda_return(targets, lambda_):
    """Weights G^(k) by (1 - lambda) lambda^(k-1); the last takes the remaining lambda^(M-1)."""
    result = 0.0
    weight = 1.0
    for target in targets[:-1]:
        result += (1.0 - lambda_) * weight * target
        weight *= lambda_
    return result + weight * targets[-1]
