from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp
import optax
from haiku import Params
from jax import Array

from medium_rl.config import Config
from medium_rl.data.trajectory import SubTrajectory
from medium_rl.envs.sequence_env import SequenceEnv
from medium_rl.init import RunState


@partial(jax.jit, static_argnames=["alpha", "omega", "q", "q_fn"])
def tgm_loss(
    policy_logits: Array,
    legal_action_mask: Array,
    done_mask: Array,
    action: Array,
    reward: Array,
    alpha: float = 1,
    omega: float = 1,
    q: float = 1,
    q_fn: Callable = lambda x: x,
):
    """Return ``reward`` minus the scaled, masked sum of log-ratio terms.

    For every step two masked log-softmaxes over the last (action) axis are
    formed from the same logits::

        mixed = log_softmax(q * alpha * q_fn(logits) + omega * logits)
        base  = log_softmax(alpha * q_fn(logits))

    and the per-step term is ``mixed[action] - q * base[action]``. Terms are
    weighted by ``done_mask``, summed over the last remaining axis and the
    result is ``reward - (1 / omega) * sum``. No reduction over the leading
    axes is performed here; the caller decides how to aggregate.

    Args:
        policy_logits: Logits with the action dimension last.
        legal_action_mask: Same shape as ``policy_logits``; entries equal to 1
            mark legal actions, everything else is excluded from the softmax.
        done_mask: Per-step weights, one fewer dimension than
            ``policy_logits``. Steps whose weight is exactly 0 contribute
            exactly 0, even if their log-ratio is NaN or infinite.
        action: Integer indices into the action axis, one per step.
        reward: Value the summed term is subtracted from; must broadcast
            against the per-step sum (presumably one entry per trajectory).
        alpha: Scale applied to ``q_fn(policy_logits)``. Static under jit.
        omega: Scale applied to the raw logits and divisor of the summed term.
            Static under jit.
        q: Weight of the transformed logits / of the ``base`` log-softmax.
            Static under jit.
        q_fn: Elementwise transform of the logits. Static under jit, so it
            must be hashable; it is evaluated once and assumed to be pure.

    Returns:
        ``reward - (1 / omega) * sum_t done_mask[t] * term[t]``.
    """
    is_legal = legal_action_mask == 1
    transformed_logits = q_fn(policy_logits)

    lsm_mixed = jax.nn.log_softmax(
        jnp.where(is_legal, q * alpha * transformed_logits + omega * policy_logits, -jnp.inf),
        axis=-1,
    )
    lsm_base = jax.nn.log_softmax(
        jnp.where(is_legal, alpha * transformed_logits, -jnp.inf),
        axis=-1,
    )
    # Illegal entries are -inf in both terms, so their difference is NaN.
    # That is harmless as long as such an entry is never gathered on an
    # unmasked step.
    lsm_diff = lsm_mixed - q * lsm_base

    selected_lsm_diff = jnp.take_along_axis(lsm_diff, action[..., None], axis=-1).squeeze(axis=-1)

    # A plain `selected * done_mask` would turn NaN/inf on a masked-out step
    # into NaN (NaN * 0 == NaN) and poison the whole sum. Force masked steps
    # to exactly 0 instead; unmasked steps keep the multiplicative weighting.
    # NOTE(review): this only sanitises the forward value. A step with no
    # legal action at all may still yield non-finite gradients -- confirm
    # such rows cannot occur.
    selected_lsm_diff = jnp.where(done_mask == 0, 0.0, selected_lsm_diff * done_mask)

    return reward - (1 / omega) * selected_lsm_diff.sum(axis=-1)


def make_tgm_loss_fn(cfg: Config):
    """Build a five-argument loss closure around ``tgm_loss``.

    The returned function forwards its array arguments unchanged and fills in
    ``alpha``, ``omega``, ``q`` and ``q_fn`` from ``cfg.alg``.

    Args:
        cfg: Configuration object exposing ``alg.alpha``, ``alg.omega``,
            ``alg.q`` and ``alg.q_fn``.

    Returns:
        ``loss_fn(policy_logits, legal_action_mask, done_mask, action, reward)``.
    """

    def loss_fn(policy_logits, legal_action_mask, done_mask, action, reward):
        # Hyper-parameters are looked up on every call rather than captured
        # once, so the closure always reflects the current contents of cfg.alg.
        alg = cfg.alg
        return tgm_loss(
            policy_logits,
            legal_action_mask,
            done_mask,
            action,
            reward,
            alpha=alg.alpha,
            omega=alg.omega,
            q=alg.q,
            q_fn=alg.q_fn,
        )

    return loss_fn


@partial(jax.jit, static_argnames=["forward", "policy_fn", "env"])
def gen_traj(run_state: RunState, forward: Callable, policy_fn: Callable, env: SequenceEnv):
    """Roll the policy out for ``env.max_len`` steps, then reset the environments.

    Args:
        run_state: Carries ``params``, ``env_state`` and ``rng``; updated
            functionally through ``replace``.
        forward: Object whose ``apply(params, rng, obs, is_training=...)``
            returns a sequence; its first element is used as the action
            logits. Static under jit.
        policy_fn: Called as ``policy_fn(logits, legal_action_mask, step, rng)``
            and returns the action to take. Static under jit.
        env: Environment providing ``max_len``, ``step_fn`` and ``reset_fn``.
            Static under jit.

    Returns:
        ``(run_state, sub_traj_batch)`` where ``run_state`` holds freshly reset
        environments and an advanced rng, and ``sub_traj_batch`` holds the
        recorded steps with the step axis moved to position 1 ([B, T, ...]).
    """

    def _rollout_step(carry: RunState, t):
        next_rng, action_rng = jax.random.split(carry.rng)

        # Observation and legality mask as they are *before* acting.
        state_before = carry.env_state
        obs = state_before.obs
        legal = state_before.legal_action_mask

        # Inference-mode forward pass; no rng is handed to the network here.
        logits = forward.apply(carry.params, None, obs, is_training=False)[0]
        chosen = policy_fn(logits, legal, t, action_rng)

        # Advance the environment with the chosen action.
        state_after = env.step_fn(state_before, chosen)

        record = SubTrajectory(
            obs[:, t],
            legal,
            chosen,
            state_after.terminated.astype(int),  # Termination flag after this action
            jnp.zeros(1),  # Placeholder for the reward field
        )
        return carry.replace(env_state=state_after, rng=next_rng), record

    # scan stacks the per-step records along a new leading axis: T x B x ...
    step_indices = jnp.arange(env.max_len)
    run_state, stacked = jax.lax.scan(_rollout_step, run_state, step_indices)
    # Move the step axis behind the batch axis: [B, T, ...]
    sub_traj_batch = jax.tree_util.tree_map(lambda leaf: jnp.swapaxes(leaf, 0, 1), stacked)

    # Reset every environment, one key per entry of the leading obs axis.
    rng, reset_rng = jax.random.split(run_state.rng)
    num_envs = run_state.env_state.obs.shape[0]
    reset_keys = jax.random.split(reset_rng, num_envs)
    fresh_env_state = env.reset_fn(run_state.env_state, reset_keys)

    return run_state.replace(env_state=fresh_env_state, rng=rng), sub_traj_batch


def make_tgm_train_step_fn(
    env: SequenceEnv,
    buffer,
    forward: Callable,
    policy_fn: Callable,
    loss_fn: Callable,
    optimizer,
    cfg: Config,
):
    """Build the per-iteration training function.

    Args:
        env: Environment used for rollouts (via ``gen_traj``) and for scoring
            samples through ``env.get_rewards``.
        buffer: Object providing ``add(state, batch)`` and
            ``sample(state, rng)``; the sampled object's ``experience``
            attribute is used as the training batch.
        forward: Network wrapper with an ``apply`` method; the first element
            of its output is used as the policy logits.
        policy_fn: Action-selection function forwarded to ``gen_traj``.
        loss_fn: Called as ``loss_fn(policy_logits, legal_action_mask,
            done_mask, action, reward)``; the variance of its output is the
            quantity that gets minimised.
        optimizer: Object with an optax-style ``update(grads, opt_state,
            params=...)`` method.
        cfg: Configuration; only ``cfg.reward_exp`` is read here.

    Returns:
        ``train_step(run_state, buffer_state)`` returning ``(run_state,
        buffer_state, total_loss, new_samples, rewards, extra_oracle_info)``.
    """

    def _objective(params: Params, batch: SubTrajectory, rng: Array):
        # Training-mode forward pass over the stored observations.
        logits = forward.apply(params, rng, batch.obs, is_training=True)[0]
        # Steps flagged as done get weight 0 in the loss.
        keep_mask = 1 - batch.done
        # Log-reward scaled by the configured exponent.
        log_reward = jnp.log(batch.reward).squeeze(-1) * cfg.reward_exp
        residual = loss_fn(
            logits,
            batch.legal_action_mask,
            keep_mask,
            batch.action,
            log_reward,
        )
        # The scalar objective is the variance of the values returned by loss_fn.
        return residual.var()

    @jax.jit
    def update_network(run_state: RunState, sub_traj_batch: SubTrajectory):
        """Take one optimiser step on ``sub_traj_batch``; return the new state and loss."""
        rng, network_rng = jax.random.split(run_state.rng)

        # Differentiate with respect to the parameters (first argument) only.
        total_loss, grads = jax.value_and_grad(_objective)(
            run_state.params, sub_traj_batch, network_rng
        )
        updates, new_opt_state = optimizer.update(
            grads, run_state.opt_state, params=run_state.params
        )
        new_params = optax.apply_updates(run_state.params, updates)

        return run_state.replace(params=new_params, opt_state=new_opt_state, rng=rng), total_loss

    def train_step(run_state: RunState, buffer_state):
        """Run one iteration: roll out, score, store, sample and update."""
        # Generate fresh trajectories and score the collected observations.
        run_state, rollout = gen_traj(run_state, forward, policy_fn, env)
        new_samples = rollout.obs
        rewards, extra_oracle_info = env.get_rewards(new_samples)
        rollout = rollout.replace(reward=rewards)

        # Store the rollout, then draw the batch to train on.
        buffer_state = buffer.add(buffer_state, rollout)
        rng, sample_rng = jax.random.split(run_state.rng)
        run_state = run_state.replace(rng=rng)
        train_batch = buffer.sample(buffer_state, sample_rng).experience

        # Gradient step on the sampled batch.
        run_state, total_loss = update_network(run_state, train_batch)

        return run_state, buffer_state, total_loss, new_samples, rewards, extra_oracle_info

    return train_step
