from typing import Sequence

import jax
import jax.numpy as jnp
import distrax
import flax.linen as nn
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
import optax


# models (iql)
# Clip range for TanhGaussianActor's learned log standard deviation; the
# clipped value is exponentiated to obtain the policy's std.
LOG_STD_MIN, LOG_STD_MAX = -20, 2


def normalize(x, stats):
    """Standardize ``x`` using precomputed statistics.

    Args:
        x: value(s) to normalize.
        stats: mapping with ``"mean"`` and ``"std"`` entries broadcastable
            against ``x``.

    Returns:
        ``(x - mean) / (std + 1e-3)``; the small epsilon keeps the division
        finite when a component's std is zero.
    """
    centered = x - stats["mean"]
    scale = stats["std"] + 1e-3
    return centered / scale


class SoftQNetworkIQL(nn.Module):
    """State-action value network used inside the IQL critic ensemble.

    The observation and action are concatenated on the last axis and fed
    through two 256-unit hidden layers and a scalar head. The head's trailing
    singleton axis is squeezed away.

    Attributes:
        activation: ``"relu"`` selects ReLU for the hidden layers; any other
            value falls back to tanh.
        obs_stats: optional dict with ``"mean"`` and ``"std"`` entries; when
            truthy, observations are normalized before concatenation.
    """

    activation: str = "tanh"
    obs_stats: dict = None

    @nn.compact
    def __call__(self, obs, action):
        act_fn = nn.relu if self.activation == "relu" else nn.tanh
        if self.obs_stats:
            obs = normalize(obs, self.obs_stats)
        hidden = jnp.concatenate([obs, action], axis=-1)
        # Submodules are created in the same order as before, so the
        # auto-generated parameter names (Dense_0, Dense_1, Dense_2) match.
        for _ in range(2):
            hidden = act_fn(nn.Dense(256)(hidden))
        q_value = nn.Dense(1)(hidden)
        return jnp.squeeze(q_value, axis=-1)


class VectorCritic(nn.Module):
    """Ensemble of ``n_critics`` independently parameterized IQL Q-networks.

    ``SoftQNetworkIQL`` is lifted with ``nn.vmap``. Every member receives the
    same ``(obs, action)`` inputs, and their squeezed outputs are stacked on a
    new trailing axis of size ``n_critics``.

    Attributes:
        activation: forwarded to each ``SoftQNetworkIQL``.
        n_critics: number of ensemble members.
        obs_stats: forwarded to each ``SoftQNetworkIQL``.
    """

    activation: str = "tanh"
    n_critics: int = 2
    obs_stats: dict = None

    @nn.compact
    def __call__(self, obs, action):
        # Ensemble trick borrowed from https://github.com/perrin-isir/xpag
        # (PyTorch counterpart: https://github.com/tinkoff-ai/CORL).
        ensemble_cls = nn.vmap(
            SoftQNetworkIQL,
            in_axes=None,  # inputs are broadcast, not split, across critics
            out_axes=-1,  # critic outputs stacked on a trailing axis
            axis_size=self.n_critics,
            variable_axes={"params": 0},  # each critic owns its own params
            split_rngs={"params": True, "dropout": True},  # distinct inits
        )
        ensemble = ensemble_cls(activation=self.activation, obs_stats=self.obs_stats)
        return ensemble(obs, action)


class ValueFunction(nn.Module):
    """State-value network V(s) used by IQL.

    The observation goes through two 256-unit hidden layers and a scalar
    head. The head's trailing singleton axis is squeezed away.

    Attributes:
        activation: ``"relu"`` selects ReLU for the hidden layers; any other
            value falls back to tanh.
        obs_stats: optional dict with ``"mean"`` and ``"std"`` entries; when
            truthy, observations are normalized first.
        NO_ACTION_INPUT: marker only. ``create_agent_train_state`` checks for
            this attribute with ``hasattr`` and initializes the module from
            observations alone.
    """

    activation: str = "tanh"
    obs_stats: dict = None
    NO_ACTION_INPUT: None = None

    @nn.compact
    def __call__(self, x):
        act_fn = nn.tanh
        if self.activation == "relu":
            act_fn = nn.relu
        if self.obs_stats:
            x = normalize(x, self.obs_stats)
        # Same creation order as before -> Dense_0, Dense_1, Dense_2.
        for width in (256, 256):
            x = act_fn(nn.Dense(width)(x))
        value = nn.Dense(1)(x)
        return jnp.squeeze(value, axis=-1)


class TanhGaussianActor(nn.Module):
    """Diagonal Gaussian policy with a tanh-bounded mean (IQL actor).

    The MLP output is squashed with ``tanh`` and affinely rescaled into
    ``action_lims`` to form the mean. The standard deviation comes from a
    state-independent ``logstd`` parameter (initialized to zeros). It is
    clipped to ``[LOG_STD_MIN, LOG_STD_MAX]`` before exponentiation.

    Only the mean is bounded: samples from the Normal are not squashed and may
    fall outside ``action_lims``.

    Attributes:
        action_dim: number of action dimensions. This is the output width of
            the last ``Dense`` layer and the length of ``logstd``.
        activation: ``"relu"`` selects ReLU for the hidden layers; any other
            value falls back to tanh.
        action_lims: ``(low, high)`` range the mean is mapped into.
        obs_stats: optional dict with ``"mean"`` and ``"std"`` entries; when
            truthy, observations are normalized first.
        eval: if True, return ``distrax.Deterministic(mean)`` instead of the
            Normal, so sampling yields the mean.
        NO_ACTION_INPUT: marker attribute, matching ``ValueFunction`` and
            ``TanhDeterministicActor``. ``create_agent_train_state`` checks for
            it with ``hasattr``, so this observation-only module is initialized
            without a dummy action even when ``action_dim`` is supplied.
    """

    action_dim: int
    activation: str = "tanh"
    action_lims: Sequence[float] = (-1.0, 1.0)
    obs_stats: dict = None
    eval: bool = False
    NO_ACTION_INPUT: None = None

    @nn.compact
    def __call__(self, x):
        activation = nn.relu if self.activation == "relu" else nn.tanh
        if self.obs_stats:
            x = normalize(x, self.obs_stats)
        x = activation(nn.Dense(256)(x))
        x = activation(nn.Dense(256)(x))
        x = nn.tanh(nn.Dense(self.action_dim)(x))
        # Affine map from tanh's (-1, 1) range into (low, high).
        low, high = self.action_lims[0], self.action_lims[1]
        action_scale = (high - low) / 2
        action_bias = (high + low) / 2
        mean = action_bias + action_scale * x
        logstd = self.param(
            "logstd",
            init_fn=lambda key: jnp.zeros(self.action_dim, dtype=jnp.float32),
        )
        std = jnp.exp(jnp.clip(logstd, LOG_STD_MIN, LOG_STD_MAX))
        if self.eval:
            return distrax.Deterministic(mean)
        return distrax.Normal(mean, std)


# models (td3)
class SoftQNetwork(nn.Module):
    """State-action value network used by TD3+BC.

    The observation and action are concatenated on the last axis and fed
    through two 256-unit hidden layers and a scalar head. Unlike
    ``SoftQNetworkIQL``, the output keeps its trailing singleton axis
    (shape ``(..., 1)``).

    Attributes:
        activation: ``"relu"`` (default) selects ReLU for the hidden layers;
            any other value falls back to tanh.
        obs_stats: optional dict with ``"mean"`` and ``"std"`` entries; when
            truthy, observations are normalized before concatenation.
    """

    activation: str = "relu"
    obs_stats: dict = None

    @nn.compact
    def __call__(self, obs, action):
        act_fn = nn.relu if self.activation == "relu" else nn.tanh
        if self.obs_stats:
            obs = normalize(obs, self.obs_stats)
        features = jnp.concatenate([obs, action], axis=-1)
        # Same creation order as before -> Dense_0, Dense_1, Dense_2.
        for _ in range(2):
            features = act_fn(nn.Dense(256)(features))
        return nn.Dense(1)(features)


class TanhDeterministicActor(nn.Module):
    """Deterministic tanh-squashed policy (TD3+BC actor and target actor).

    The MLP's raw output is wrapped in ``distrax.Deterministic`` and pushed
    through a bijector chain: ``tanh`` first, then an affine map into
    ``action_lims``. ``distrax.Chain`` applies its bijectors last-to-first.
    ``sample`` on the returned distribution therefore yields the squashed,
    rescaled action.

    Attributes:
        action_dim: number of action dimensions (output width of the last
            ``Dense`` layer).
        activation: ``"relu"`` (default) selects ReLU for the hidden layers;
            any other value falls back to tanh.
        action_lims: ``(low, high)`` range actions are mapped into.
        obs_stats: optional dict with ``"mean"`` and ``"std"`` entries; when
            truthy, observations are normalized first.
        NO_ACTION_INPUT: marker attribute. ``create_agent_train_state`` checks
            for it with ``hasattr`` and initializes the module from
            observations alone.
    """

    action_dim: int
    activation: str = "relu"
    action_lims: Sequence[float] = (-1.0, 1.0)
    obs_stats: dict = None
    NO_ACTION_INPUT: None = None

    @nn.compact
    def __call__(self, x):
        activation = nn.relu if self.activation == "relu" else nn.tanh
        if self.obs_stats:
            x = normalize(x, self.obs_stats)
        x = activation(nn.Dense(256)(x))
        x = activation(nn.Dense(256)(x))
        action = nn.Dense(self.action_dim)(x)
        low, high = self.action_lims[0], self.action_lims[1]
        action_scale = (high - low) / 2
        action_bias = (high + low) / 2
        squash_and_rescale = distrax.Chain(
            [
                # Note: chained bijectors are applied in reverse order, so
                # Tanh runs first and the affine rescale second.
                distrax.ScalarAffine(action_bias, action_scale),
                distrax.Tanh(),
            ]
        )
        return distrax.Transformed(distrax.Deterministic(action), squash_and_rescale)


# agents (iql)
# Upper clip on the IQL actor's advantage weights exp(iql_beta * adv).
EXP_ADV_MAX = 1e2

def iql_train_step(args, network, aux_networks):
    """Build the IQL (Implicit Q-Learning) update step.

    Args:
        args: config namespace. Reads ``polyak_step_size``, ``iql_tau``
            (expectile), ``gamma`` (discount) and ``iql_beta`` (advantage
            temperature).
        network: actor module. ``network.apply(params, obs)`` must return a
            distribution exposing ``log_prob``.
        aux_networks: ``(q_network, target_q_network, value_network)``. The
            middle module is ignored: target Q params are evaluated with
            ``q_network.apply``.

    Returns:
        ``_update_step(train_state, aux_train_states, traj_batch, rng)``.

        - ``aux_train_states`` is ``(q_state, q_target_state, value_state)``.
        - The step returns ``(train_state, aux_train_states, loss, metric)``,
          where ``loss`` is a dict of scalar losses and ``metric`` is
          ``traj_batch.info``.
        - ``rng`` is accepted for interface parity but not used.
    """
    q_network, _, value_network = aux_networks

    def _update_step(train_state, aux_train_states, traj_batch, rng):
        q_train_state, q_target_train_state, value_train_state = aux_train_states
        # Collapse all leading axes so every leaf becomes (num_transitions, last_dim).
        # NOTE(review): a 1-D leaf of shape (N,) would become (1, N); presumably
        # every leaf (reward/done included) carries a trailing feature axis.
        traj_batch = jax.tree_util.tree_map(
            lambda x: x.reshape((-1, x.shape[-1])), traj_batch
        )

        # --- Update target networks ---
        # Done first so the target's first update (step == 0) hard-copies the
        # online Q params; later updates apply Polyak averaging.
        new_target_params = jax.tree_util.tree_map(
            lambda x, y: jnp.where(q_target_train_state.step == 0, x, y),
            q_train_state.params,
            optax.incremental_update(
                q_train_state.params,
                q_target_train_state.params,
                args.polyak_step_size,
            ),
        )
        q_target_train_state = q_target_train_state.replace(
            step=q_target_train_state.step + 1,
            params=new_target_params,
        )

        # --- Compute value targets ---
        # Pessimistic target: minimum over the critic ensemble (last axis).
        q_target_preds = q_network.apply(
            q_target_train_state.params, traj_batch.obs, traj_batch.action
        )
        value_targets = jnp.min(q_target_preds, axis=-1)
        # V(s') is taken with the value params from *before* this step's update.
        next_v = value_network.apply(value_train_state.params, traj_batch.next_obs)

        # --- Update value function ---
        def _value_loss_fn(params):
            value_pred = value_network.apply(params, traj_batch.obs)
            adv = value_targets - value_pred
            # Asymmetric L2 (expectile) loss: |tau - 1[adv < 0]| * adv^2.
            value_loss = jnp.abs(
                args.iql_tau - jnp.where(adv < 0.0, 1.0, 0.0)
            ) * jnp.square(adv)
            return jnp.mean(value_loss), adv

        # ``adv`` is computed with the pre-update value params; it is reused
        # below as the actor's advantage weight.
        (value_loss, adv), value_grad = jax.value_and_grad(
            _value_loss_fn, has_aux=True
        )(value_train_state.params)
        value_train_state = value_train_state.apply_gradients(grads=value_grad)

        # --- Compute q targets ---
        # vmap pairs each transition with its own scalar next_v, avoiding an
        # accidental (N, 1) x (N,) broadcast.
        def _compute_q_target(transition, next_v):
            return transition.reward + args.gamma * (1 - transition.done) * next_v

        q_targets = jax.vmap(_compute_q_target)(traj_batch, next_v)

        # --- Update q functions ---
        def _q_loss_fn(params):
            # Both critics regress onto the same target (broadcast over the
            # ensemble axis).
            q_pred = q_network.apply(params, traj_batch.obs, traj_batch.action)
            q_loss = jnp.square(q_pred - q_targets).mean()
            return q_loss

        q_loss, q_grad = jax.value_and_grad(_q_loss_fn)(q_train_state.params)
        q_train_state = q_train_state.apply_gradients(grads=q_grad)

        # --- Update actor ---
        # Advantage-weighted behavior cloning with clipped weights.
        exp_adv = jnp.exp(adv * args.iql_beta).clip(max=EXP_ADV_MAX)

        def _actor_loss_function(params):
            def _compute_loss(transition, exp_adv):
                pi = network.apply(params, transition.obs)
                bc_losses = -pi.log_prob(transition.action)
                return exp_adv * bc_losses.sum()

            actor_loss = jax.vmap(_compute_loss)(traj_batch, exp_adv)
            return actor_loss.mean()

        actor_loss, actor_grad = jax.value_and_grad(_actor_loss_function)(
            train_state.params
        )
        train_state = train_state.apply_gradients(grads=actor_grad)

        loss = {
            "value_loss": value_loss,
            "q_loss": q_loss,
            "actor_loss": actor_loss,
        }
        metric = traj_batch.info
        loss = jax.tree_util.tree_map(lambda x: x.mean(), loss)
        return (
            train_state,
            (q_train_state, q_target_train_state, value_train_state),
            loss,
            metric,
        )

    return _update_step


# agents (td3_bc)
def td3_bc_train_step(args, network, aux_networks):
    """Build the TD3+BC update step.

    Args:
        args: config namespace. Reads ``polyak_step_size``, ``td3_alpha``,
            ``policy_noise``, ``noise_clip``, ``a_max``, ``gamma`` and
            ``num_critic_updates_per_step``.
        network: deterministic actor module. ``network.apply(params, obs)``
            returns a distribution whose ``sample`` gives the action.
        aux_networks: 5-tuple whose second entry is the Q-network module. That
            module is used to apply both online and target critic params.

    Returns:
        ``_update_step(train_state, aux_train_states, traj_batch, rng)``, where
        ``aux_train_states`` is ``(actor_target, q1, q2, q1_target, q2_target)``.

        Each call performs, in order:

        1. Update all target networks.
        2. Take one actor step against the pre-update ``q1`` critic.
        3. Run ``num_critic_updates_per_step`` critic steps in a
           ``jax.lax.scan``, using the freshly updated targets.
    """
    _, q_network, _, _, _ = aux_networks

    def _update_step(train_state, aux_train_states, traj_batch, rng):
        (
            actor_target_state,
            q1_state,
            q2_state,
            q1_target_state,
            q2_target_state,
        ) = aux_train_states
        # Collapse all leading axes so every leaf becomes (num_transitions, last_dim).
        # NOTE(review): a 1-D leaf of shape (N,) would become (1, N); presumably
        # every leaf (reward/done included) carries a trailing feature axis.
        traj_batch = jax.tree_util.tree_map(
            lambda x: x.reshape((-1, x.shape[-1])), traj_batch
        )

        # --- Update target networks ---
        def _update_target(state, target_state):
            # Done first: the target's first update (step == 0) copies the
            # online params verbatim; afterwards Polyak averaging applies.
            new_target_params = jax.tree_util.tree_map(
                lambda x, y: jnp.where(target_state.step == 0, x, y),
                state.params,
                optax.incremental_update(
                    state.params,
                    target_state.params,
                    args.polyak_step_size,
                ),
            )
            return target_state.replace(
                step=target_state.step + 1,
                params=new_target_params,
            )

        q1_target_state = _update_target(q1_state, q1_target_state)
        q2_target_state = _update_target(q2_state, q2_target_state)
        actor_target_state = _update_target(train_state, actor_target_state)

        # --- Update actor ---
        def _actor_loss_function(params, rng):
            def _transition_loss(rng, transition):
                pi = network.apply(params, transition.obs)
                sampled_action = pi.sample(seed=rng)
                q = q_network.apply(q1_state.params, transition.obs, sampled_action)
                bc_loss = jnp.square(sampled_action - transition.action).mean()
                return q, bc_loss

            rng, _rng = jax.random.split(rng)
            _rng = jax.random.split(_rng, len(traj_batch.reward))
            q, bc_loss = jax.vmap(_transition_loss)(_rng, traj_batch)
            # TD3+BC weighting: scale the Q term by alpha / mean|Q|, treated as
            # a constant for differentiation.
            lmbda = args.td3_alpha / (jnp.abs(q).mean() + 1e-7)
            lmbda = jax.lax.stop_gradient(lmbda)
            actor_loss = (-lmbda * q.mean()) + bc_loss.mean()
            return actor_loss.mean(), (q.mean(), lmbda.mean(), bc_loss.mean())

        rng, _rng = jax.random.split(rng)
        (actor_loss, (q_mean, lmbda, bc_loss)), actor_grad = jax.value_and_grad(
            _actor_loss_function, has_aux=True
        )(train_state.params, _rng)
        train_state = train_state.apply_gradients(grads=actor_grad)

        def _update_critics(runner_state, _):
            rng, q1_state, q2_state = runner_state

            # --- Compute targets ---
            def _compute_target(rng, transition):
                next_pi = network.apply(
                    actor_target_state.params, transition.next_obs
                )
                rng, _rng = jax.random.split(rng)
                next_action = next_pi.sample(seed=_rng)
                # Target policy smoothing: clipped Gaussian noise on the target action.
                rng, _rng = jax.random.split(rng)
                rand_action = (
                    jax.random.normal(_rng, shape=next_action.shape) * args.policy_noise
                )
                # Bounds passed positionally: works with both the old
                # (a_min/a_max) and new (min/max) jnp.clip signatures.
                rand_action = jnp.clip(rand_action, -args.noise_clip, args.noise_clip)
                next_action = jnp.clip(
                    next_action + rand_action, -args.a_max, args.a_max
                )

                # Minimum of the target Q-values (clipped double-Q).
                target_q1 = q_network.apply(
                    q1_target_state.params, transition.next_obs, next_action
                )
                target_q2 = q_network.apply(
                    q2_target_state.params, transition.next_obs, next_action
                )
                next_q_value = jnp.minimum(target_q1, target_q2)
                assert next_q_value.shape == transition.reward.shape
                return (
                    transition.reward
                    + args.gamma * (1 - transition.done) * next_q_value
                )

            rng, _rng = jax.random.split(rng)
            _rng = jax.random.split(_rng, len(traj_batch.reward))
            targets = jax.vmap(_compute_target)(_rng, traj_batch)

            # --- Update critics ---
            @jax.value_and_grad
            def _q_loss_fn(params):
                q_pred = q_network.apply(params, traj_batch.obs, traj_batch.action)
                assert q_pred.shape == targets.shape
                return jnp.square(q_pred - targets).mean()

            q1_loss, q1_grad = _q_loss_fn(q1_state.params)
            q1_state = q1_state.apply_gradients(grads=q1_grad)
            q2_loss, q2_grad = _q_loss_fn(q2_state.params)
            q2_state = q2_state.apply_gradients(grads=q2_grad)
            return (rng, q1_state, q2_state), (q1_loss, q2_loss)

        (rng, q1_state, q2_state), (q1_loss, q2_loss) = jax.lax.scan(
            _update_critics,
            (rng, q1_state, q2_state),
            None,
            length=args.num_critic_updates_per_step,
        )

        loss = {
            "q1_loss": q1_loss.mean(),
            "q2_loss": q2_loss.mean(),
            "actor_loss": actor_loss,
            "q_mean": q_mean,
            "lmbda": lmbda,
            "bc_loss": bc_loss,
        }
        metric = traj_batch.info
        return (
            train_state,
            (actor_target_state, q1_state, q2_state, q1_target_state, q2_target_state),
            loss,
            metric,
        )

    return _update_step


# agents
# Agent names whose actor returns a Deterministic distribution
# (see TanhDeterministicActor).
DETERMINISTIC_ACTORS = [
    "td3_bc",
]


def get_agent(args, action_dim, action_lims, obs_stats=None):
    """Construct the actor network(s) and auxiliary networks for ``args.agent``.

    Args:
        args: config namespace; reads ``agent`` and ``activation``.
        action_dim: number of action dimensions.
        action_lims: ``(low, high)`` action bounds for the actor(s).
        obs_stats: optional normalization statistics passed to every network.

    Returns:
        ``(actor, aux_networks)``:

        - ``"td3_bc"``: ``actor`` is a ``TanhDeterministicActor`` and
          ``aux_networks`` is ``(target_actor, q1, q2, target_q1, target_q2)``.
        - ``"iql"``: ``actor`` is a dict with a ``"train"`` Gaussian policy and
          an ``"eval"`` policy (built with ``eval=True``), and ``aux_networks``
          is ``(q_ensemble, target_q_ensemble, value_function)``.

    Raises:
        ValueError: if ``args.agent`` is not a known agent.
    """
    if args.agent == "td3_bc":
        actor_kwargs = dict(
            activation=args.activation,
            action_lims=action_lims,
            obs_stats=obs_stats,
        )
        q_kwargs = dict(activation=args.activation, obs_stats=obs_stats)
        # Order: target actor, Q1, Q2, target Q1, target Q2.
        aux_networks = (TanhDeterministicActor(action_dim, **actor_kwargs),) + tuple(
            SoftQNetwork(**q_kwargs) for _ in range(4)
        )
        return TanhDeterministicActor(action_dim, **actor_kwargs), aux_networks

    if args.agent == "iql":
        actor_kwargs = dict(
            activation=args.activation,
            action_lims=action_lims,
            obs_stats=obs_stats,
        )
        agent_networks = {
            "train": TanhGaussianActor(action_dim, **actor_kwargs),
            "eval": TanhGaussianActor(action_dim, eval=True, **actor_kwargs),
        }
        critic_kwargs = dict(
            activation=args.activation, n_critics=2, obs_stats=obs_stats
        )
        aux_networks = (
            VectorCritic(**critic_kwargs),  # Q ensemble
            VectorCritic(**critic_kwargs),  # target Q ensemble
            ValueFunction(activation=args.activation, obs_stats=obs_stats),
        )
        return agent_networks, aux_networks

    raise ValueError(f"Unknown agent {args.agent}.")


def make_train_step(args, network, aux_networks):
    """Return the update-step function for the agent named by ``args.agent``.

    Raises:
        ValueError: if ``args.agent`` is neither ``"td3_bc"`` nor ``"iql"``.
    """
    agent = args.agent
    if agent == "td3_bc":
        builder = td3_bc_train_step
    elif agent == "iql":
        builder = iql_train_step
    else:
        raise ValueError(f"Unknown agent {args.agent}.")
    return builder(args, network, aux_networks)


def make_lr_schedule(args):
    """Build the learning-rate schedule selected by ``args.lr_schedule``.

    Options:
        ``"constant"``: return ``args.lr`` unchanged.
        ``"cosine"``: cosine decay from ``args.lr`` over ``args.num_train_steps``
            steps.
        ``"exponential"``: two phases.

            - Warmup from ``0.1 * args.lr`` to ``args.lr`` over the first
              ``num_train_steps // 10`` steps.
            - Exponential decay with ``decay_rate=0.1`` and
              ``transition_steps`` equal to the remaining steps.

    Raises:
        ValueError: for an unrecognized schedule name.
    """
    peak_lr = args.lr
    schedule_name = args.lr_schedule
    if schedule_name == "constant":
        return peak_lr

    num_steps = args.num_train_steps
    num_warmup = num_steps // 10
    if schedule_name == "cosine":
        return optax.cosine_decay_schedule(
            init_value=peak_lr,
            decay_steps=num_steps,
        )
    if schedule_name == "exponential":
        return optax.warmup_exponential_decay_schedule(
            init_value=peak_lr * 0.1,
            peak_value=peak_lr,
            warmup_steps=num_warmup,
            transition_steps=num_steps - num_warmup,
            decay_rate=0.1,
        )
    raise ValueError(f"Unknown learning rate schedule {args.lr_schedule}.")


def _create_optimizer(args):
    """Return an Adam optimizer (``eps=1e-5``) driven by ``make_lr_schedule(args)``."""
    schedule = make_lr_schedule(args)
    optimizer = optax.adam(learning_rate=schedule, eps=1e-5)
    return optimizer


def create_agent_train_state(rng, network, args, obs_shape=None, action_dim=None):
    """Initialize ``network`` and wrap it in a Flax ``TrainState``.

    The dummy inputs passed to ``network.init`` are chosen as follows:

    * ``obs_shape`` is None, or the module has a ``NO_INPUT`` attribute:
      ``init(rng)`` with no inputs (bare parameter).
    * ``action_dim`` is None, or the module has a ``NO_ACTION_INPUT``
      attribute: ``init(rng, zeros(obs_shape))`` (actor / value network).
    * Otherwise: ``init(rng, zeros(obs_shape), zeros(action_dim))``
      (Q network).

    The optimizer comes from ``_create_optimizer(args)``.
    """
    if obs_shape is None or hasattr(network, "NO_INPUT"):
        dummy_inputs = ()
    elif action_dim is None or hasattr(network, "NO_ACTION_INPUT"):
        dummy_inputs = (jnp.zeros(obs_shape),)
    else:
        dummy_inputs = (jnp.zeros(obs_shape), jnp.zeros(action_dim))
    network_params = network.init(rng, *dummy_inputs)
    optimizer = _create_optimizer(args)
    return TrainState.create(
        apply_fn=network.apply,
        params=network_params,
        tx=optimizer,
    )