import math
from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp
import optax
from haiku import Params

from medium_rl.config import Config
from medium_rl.data.transition import Transition
from medium_rl.envs.sequence_env import SequenceEnv
from medium_rl.init import RunState


@partial(jax.jit, static_argnames=["forward", "policy_fn", "env"])
def gen_transition(run_state: RunState, forward: Callable, policy_fn: Callable, env: SequenceEnv):
    """Roll the current policy out for ``env.max_len`` steps, then reset the envs.

    Args:
        run_state: Carries ``params``, ``env_state`` and ``rng``; updated
            functionally through ``run_state.replace``.
        forward: Object exposing ``apply(params, rng, obs, is_training=...)``;
            only element ``[0]`` of its output is handed to the policy.
        policy_fn: Called as ``policy_fn(logits, legal_action_mask, step, rng)``
            and returns the action to play.
        env: Provides ``step_fn``, ``reset_fn`` and ``max_len``.

    Returns:
        ``(run_state, transition_batch)`` where every leaf of
        ``transition_batch`` has had its first two axes swapped after the scan
        (step-major -> batch-major), and ``run_state`` holds freshly reset
        environments and an advanced rng.
    """

    def _rollout_step(carry: RunState, t):
        next_rng, act_rng = jax.random.split(carry.rng)

        # Snapshot the pre-step observation and mask; they are stored as the
        # "current" side of the transition.
        prev_env_state = carry.env_state
        prev_obs = prev_env_state.obs
        prev_mask = prev_env_state.legal_action_mask

        # Pick an action from the first output of the network.
        policy_logits = forward.apply(carry.params, None, prev_obs, is_training=False)[0]
        chosen = policy_fn(policy_logits, prev_mask, t, act_rng)

        # Advance the environment.
        new_env_state = env.step_fn(prev_env_state, chosen)

        num_envs = prev_obs.shape[0]
        record = Transition(
            prev_obs,
            new_env_state.obs,
            prev_mask,
            new_env_state.legal_action_mask,
            chosen,
            new_env_state.terminating.astype(int),  # Whether the action terminated the env
            new_env_state.terminated.astype(int),
            jnp.zeros(num_envs),  # Filler
            jnp.zeros(num_envs),  # Filler
            jnp.zeros(num_envs),  # Filler
            t * jnp.ones(num_envs, dtype=jnp.int32),
        )

        return carry.replace(env_state=new_env_state, rng=next_rng), record

    # Scanning stacks the per-step transitions along a new leading (step) axis.
    step_ids = jnp.arange(env.max_len)
    run_state, stacked = jax.lax.scan(_rollout_step, run_state, step_ids)

    # Move the batch axis in front of the step axis: [T, B, ...] -> [B, T, ...].
    transition_batch = jax.tree_util.tree_map(lambda leaf: jnp.swapaxes(leaf, 0, 1), stacked)

    # Reset every environment with its own key.
    rng, reset_rng = jax.random.split(run_state.rng)
    num_envs = run_state.env_state.obs.shape[0]
    reset_keys = jax.random.split(reset_rng, num_envs)
    fresh_env_state = env.reset_fn(run_state.env_state, reset_keys)
    run_state = run_state.replace(env_state=fresh_env_state, rng=rng)

    return run_state, transition_batch


def make_sac_loss_fn(cfg: Config):
    """Build the combined critic + actor loss for discrete-action SAC.

    Args:
        cfg: Configuration; this function reads ``cfg.minibatch_size``,
            ``cfg.reward_exp`` and ``cfg.alg.omega`` (the entropy term is
            scaled by ``1 / omega``).

    Returns:
        ``loss_fn(all_curr_logits, all_next_logits, target_next_logits,
        transitions)`` returning the scalar ``q_loss + actor_loss``.
        Each ``*_logits`` argument is indexed as ``x[0]`` (policy logits),
        ``x[1]`` and ``x[2]`` (the two Q heads), and every head is gathered
        with ``[batch_index, step]`` to obtain one row of per-action values
        per sample.
    """

    def loss_fn(all_curr_logits, all_next_logits, target_next_logits, transitions):
        action = transitions.action
        step = transitions.step
        batch_idx = jnp.arange(cfg.minibatch_size)
        inv_omega = 1 / cfg.alg.omega
        not_done = 1 - transitions.done

        # Reward is only paid on the terminating action. Select with `where`
        # instead of multiplying by the 0/1 flag so that log(reward) of a
        # non-positive reward on a non-terminating row cannot leak a nan
        # (0 * -inf) into the loss.
        reward = jnp.where(
            transitions.terminating != 0,
            jnp.log(transitions.reward) * cfg.reward_exp,
            0.0,
        )

        # ---- Critic training ----
        q1_next = jax.lax.stop_gradient(target_next_logits[1][batch_idx, step + 1])
        q2_next = jax.lax.stop_gradient(target_next_logits[2][batch_idx, step + 1])
        next_logits = all_next_logits[0][batch_idx, step + 1]
        next_logits = jnp.where(transitions.next_legal_action_mask, next_logits, -jnp.inf)
        pi_next = jax.lax.stop_gradient(jax.nn.softmax(next_logits, axis=-1))
        min_q_target = pi_next * (jnp.minimum(q1_next, q2_next) - inv_omega * jnp.log(pi_next))
        # Ensure the nans from illegal actions (0 * log 0) aren't used in the sum
        min_q_target = jnp.where(transitions.next_legal_action_mask, min_q_target, 0)

        # If it is a terminating action, there is no next q value
        q_target_next = reward + (1 - transitions.terminating) * min_q_target.sum(axis=-1)

        q1_all = all_curr_logits[1][batch_idx, step]
        q2_all = all_curr_logits[2][batch_idx, step]
        q1_pred_selected = jnp.take_along_axis(q1_all, action[..., None], axis=-1).squeeze(axis=-1)
        q2_pred_selected = jnp.take_along_axis(q2_all, action[..., None], axis=-1).squeeze(axis=-1)

        q_loss = (q1_pred_selected - q_target_next) ** 2 + (q2_pred_selected - q_target_next) ** 2
        # Average over not-done samples only; clamp the count so an all-done
        # minibatch yields 0 rather than 0/0.
        q_loss = (not_done * q_loss).sum() / jnp.maximum(not_done.sum(), 1)

        # ---- Actor training ----
        curr_logits = all_curr_logits[0][batch_idx, step]
        curr_logits = jnp.where(transitions.legal_action_mask, curr_logits, -jnp.inf)
        # Probabilities and log-probabilities must both be derived from the
        # masked *logits*; taking log_softmax of the probabilities would not
        # give log(pi).
        pi_curr = jax.nn.softmax(curr_logits, axis=-1)
        log_pi_curr = jax.nn.log_softmax(curr_logits, axis=-1)

        # Necessary to avoid nans (-inf log-probs on illegal actions)
        pi_curr = jnp.where(transitions.legal_action_mask, pi_curr, 0)
        log_pi_curr = jnp.where(transitions.legal_action_mask, log_pi_curr, 0)

        # The actor treats the critics as constants.
        min_pred = jnp.minimum(jax.lax.stop_gradient(q1_all), jax.lax.stop_gradient(q2_all))

        # CleanRL just does mean at the end?
        actor_loss = pi_curr * (inv_omega * log_pi_curr - min_pred)
        actor_loss = jnp.where(transitions.legal_action_mask, actor_loss, 0)
        # Normalise by the number of legal actions; clamp so a row with no
        # legal action contributes 0 instead of 0/0.
        num_legal = jnp.maximum(transitions.legal_action_mask.sum(axis=-1), 1)
        actor_loss = actor_loss.sum(axis=-1) / num_legal
        actor_loss = (not_done * actor_loss).mean()

        return q_loss + actor_loss

    return loss_fn


def make_sac_train_step_fn(
    env: SequenceEnv,
    buffer,
    forward: Callable,
    policy_fn: Callable,
    loss_fn: Callable,
    optimizer,
    cfg: Config,
):
    """Build the per-iteration SAC training function.

    Args:
        env: Environment; ``env.get_rewards`` scores generated samples and the
            env is forwarded to ``gen_transition`` for the rollout.
        buffer: Replay buffer exposing ``add(state, transitions)`` and
            ``sample(state, rng).experience``.
        forward: Network with ``apply(params, rng, obs, is_training=...)``.
        policy_fn: Action-selection function forwarded to ``gen_transition``.
        loss_fn: Called as ``loss_fn(curr_logits, next_logits,
            target_next_logits, transitions)``.
        optimizer: Optimizer exposing ``update(grads, opt_state, params=...)``.
        cfg: Configuration; reads ``cfg.env.max_len``, ``cfg.num_envs`` and
            ``cfg.minibatch_size`` to decide how many updates to run.

    Returns:
        ``train_step(run_state, buffer_state)`` returning ``(run_state,
        buffer_state, total_loss, gen_samples, rewards, extra_oracle_info)``.
    """

    @jax.jit
    def update_network(run_state: RunState, transitions: Transition):
        """Take one gradient step on a sampled minibatch of transitions."""

        def _grad_loss_fn(params: Params, target_params: Params, transitions: Transition, rng: jax.random.PRNGKey):
            # The same rng key is passed to all three forward passes.
            all_curr_logits = forward.apply(params, rng, transitions.obs, is_training=True)
            all_next_logits = forward.apply(params, rng, transitions.next_obs, is_training=True)
            target_next_logits = forward.apply(target_params, rng, transitions.next_obs, is_training=True)
            loss = loss_fn(all_curr_logits, all_next_logits, target_next_logits, transitions)

            return loss.mean()

        params, opt_state, rng = run_state.params, run_state.opt_state, run_state.rng
        rng, network_rng = jax.random.split(rng)

        # Differentiate w.r.t. the online params only (argnums=0); the target
        # params are an input, not a differentiated argument.
        grad_fn = jax.value_and_grad(_grad_loss_fn, argnums=0, has_aux=False)
        total_loss, grads = grad_fn(params, run_state.target_params, transitions, network_rng)
        updates, opt_state = optimizer.update(grads, opt_state, params=params)
        params = optax.apply_updates(params, updates)

        run_state = run_state.replace(params=params, opt_state=opt_state, rng=rng)
        return run_state, total_loss

    def train_step(run_state: RunState, buffer_state):
        """Run one iteration: collect a rollout, store it, then train on replay samples.

        Returns the updated ``run_state`` and ``buffer_state``, the loss of the
        last minibatch update, the generated samples, their rewards, and the
        extra info returned by ``env.get_rewards``.
        """
        # Collect data
        run_state, transitions = gen_transition(run_state, forward, policy_fn, env)
        B, T = transitions.obs.shape[:2]
        gen_samples = transitions.obs[:, -1]
        rewards, extra_oracle_info = env.get_rewards(gen_samples)

        # Add rewards to all transitions (in the loss only included for terminal).
        # Repeat along the step axis T so the result flattens to B * T below,
        # independent of the observation's own last dimension.
        # NOTE(review): assumes `rewards` has shape [B, 1] — confirm against env.get_rewards.
        transitions_rewards = jnp.repeat(rewards, T, axis=1)
        transitions = transitions.replace(reward=transitions_rewards)
        flattened_transitions = jax.tree_util.tree_map(lambda x: x.reshape((B * T, *x.shape[2:])), transitions)
        buffer_state = buffer.add(buffer_state, flattened_transitions)

        # Sample and train: enough minibatches to cover one rollout's worth of transitions.
        num_batches = math.ceil((cfg.env.max_len * cfg.num_envs) / cfg.minibatch_size)
        for _ in range(num_batches):
            rng, sample_rng = jax.random.split(run_state.rng)
            run_state = run_state.replace(rng=rng)
            sampled_transitions = buffer.sample(buffer_state, sample_rng).experience
            run_state, total_loss = update_network(run_state, sampled_transitions)

        return run_state, buffer_state, total_loss, gen_samples, rewards, extra_oracle_info

    return train_step
