import haiku as hk
import jax
import jax.numpy as jnp

from src.baselines.qrdqn.util import distort_value


class FeatureExtractor(hk.Module):
    """Embed a batch of observations into ``num_hidden`` features.

    Image observations (``is_state_vector=False``) first go through two 3x3
    convolutions (32 then 16 channels, ReLU between them) followed by a
    non-overlapping 2x2 average pool over the two middle axes. Every input is
    then flattened per example and passed through a two-layer MLP; the last
    layer uses a small uniform initialisation and has no output activation.
    """

    def __init__(
        self,
        num_hidden: int,
        is_state_vector: bool,
        name: str = "feature_extractor",
    ):
        super().__init__(name=name)
        self.num_hidden = num_hidden
        self.is_state_vector = is_state_vector

    def __call__(self, obs: jnp.ndarray) -> jnp.ndarray:
        h = obs.astype(jnp.float32)

        if not self.is_state_vector:
            # NOTE: modules are created in the same order as before so that
            # Haiku's auto-generated parameter names stay unchanged.
            h = jax.nn.relu(hk.Conv2D(32, kernel_shape=3)(h))
            h = hk.Conv2D(16, kernel_shape=3)(h)
            # 2x2 window, stride 2, no padding -> sum each window, then divide
            # by its 4 elements to get the mean. Requires a rank-4 input
            # (batch, H, W, C); the channel count here is 16.
            window = (1, 2, 2, 1)
            window_sums = jax.lax.reduce_window(
                h,
                init_value=0.0,
                computation=jax.lax.add,
                window_dimensions=window,
                window_strides=window,
                padding="VALID",
            )
            h = window_sums / (2 * 2)

        # Flatten everything except the leading (batch) axis.
        h = h.reshape((h.shape[0], -1))
        h = jax.nn.tanh(hk.Linear(self.num_hidden)(h))

        projection = hk.Linear(
            self.num_hidden,
            w_init=hk.initializers.UniformScaling(0.01),
            b_init=hk.initializers.UniformScaling(0.25),
        )
        return projection(h)  # (batch, num_hidden)


class ObsSeqEncoder(hk.Module):
    """Encode each observation history in a batch into one embedding.

    Every frame is embedded independently by a ``FeatureExtractor``. The
    per-frame embeddings are unrolled through a GRU (batch-major). The GRU
    output at the final time step goes through a tanh hidden layer and a
    small-init linear projection to ``num_hidden`` features.
    """

    def __init__(
        self,
        num_hidden: int,
        is_state_vector: bool,
        name: str = "obs_seq_encoder",
    ):
        super().__init__(name=name)
        self.num_hidden = num_hidden
        self.is_state_vector = is_state_vector
        self.feature_extractor = FeatureExtractor(
            num_hidden=num_hidden,
            is_state_vector=is_state_vector,
        )

    def __call__(self, obs: jnp.ndarray) -> jnp.ndarray:
        # Axis 0 is treated as batch and axis 1 as time: (batch, seq_len, *obs_shape).
        frames = obs.astype(jnp.float32)
        batch_size, seq_len = frames.shape[0], frames.shape[1]

        # Merge batch and time so every frame is embedded in one call, then
        # split them apart again: (batch, seq_len, num_hidden).
        per_frame = self.feature_extractor(
            frames.reshape((-1, *frames.shape[2:]))
        )
        sequence = per_frame.reshape((batch_size, seq_len, -1))

        gru = hk.GRU(self.num_hidden)
        outputs, _ = hk.dynamic_unroll(
            gru,
            sequence,
            gru.initial_state(batch_size),
            time_major=False,
        )
        # Keep only the output for the most recent step: (batch, num_hidden).
        last_output = outputs[:, -1, :]  # type: ignore

        hidden = jax.nn.tanh(hk.Linear(self.num_hidden)(last_output))
        projection = hk.Linear(
            self.num_hidden,
            w_init=hk.initializers.UniformScaling(0.01),
            b_init=hk.initializers.UniformScaling(0.25),
        )
        return projection(hidden)


class PredictionNetwork(hk.Module):
    """Shared feature torso with a policy-logit head and a quantile value head.

    Observations are embedded by a ``FeatureExtractor``. Two independent
    two-layer MLPs on top of that embedding produce ``num_actions`` policy
    logits and ``num_quantiles`` value quantiles.
    """

    def __init__(
        self,
        num_hidden: int,
        num_actions: int,
        num_quantiles: int,
        is_state_vector: bool,
        activation: str = "tanh",
    ):
        """
        Args:
            num_hidden: width of the torso embedding and of each head's
                hidden layer.
            num_actions: number of policy logits.
            num_quantiles: number of value quantiles.
            is_state_vector: forwarded to ``FeatureExtractor``; ``False``
                enables its convolutional path.
            activation: hidden-layer activation for both heads, either
                ``"relu"`` or ``"tanh"``.

        Raises:
            ValueError: if ``activation`` is not ``"relu"`` or ``"tanh"``.
        """
        super().__init__()
        # Validate explicitly and before use: an `assert` is stripped under
        # `python -O`, which would silently fall back to tanh.
        if activation not in ("relu", "tanh"):
            raise ValueError(
                f"activation must be 'relu' or 'tanh', got {activation!r}"
            )
        self.num_hidden = num_hidden
        self.num_actions = num_actions
        self.is_state_vector = is_state_vector
        self.feature_extractor = FeatureExtractor(
            num_hidden=num_hidden, is_state_vector=is_state_vector
        )
        self.num_quantiles = num_quantiles
        self.activation = jax.nn.relu if activation == "relu" else jax.nn.tanh

    def __call__(
        self, x: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Run the torso and both heads.

        Returns:
            ``(logits, value, embedding)`` with shapes ``(b, num_actions)``,
            ``(b, num_quantiles)`` and ``(b, num_hidden)``.
        """
        x = x.astype(jnp.float32)
        x = self.feature_extractor(x)  # (b, num_hidden)

        # Value head. It is built before the policy head on purpose: Haiku
        # auto-names modules (linear, linear_1, ...) in construction order,
        # so reordering would rename the parameters.
        value = hk.Linear(self.num_hidden)(x)
        value = self.activation(value)
        value = hk.Linear(self.num_quantiles)(value)  # (b, num_quantiles)

        # Policy head.
        logits = hk.Linear(self.num_hidden)(x)
        logits = self.activation(logits)
        logits = hk.Linear(self.num_actions)(logits)  # (b, num_actions)
        return logits, value, x


class RewardHead(hk.Module):
    """Predict reward quantiles for a (state embedding, action) pair.

    The action is one-hot encoded and embedded with a ReLU layer. It is then
    concatenated with the state embedding and mixed by a tanh layer whose
    output is added back onto the state embedding (residual). The result is
    projected to ``num_quantiles`` outputs.
    """

    def __init__(
        self,
        num_hidden: int,
        num_actions: int,
        num_quantiles: int,
        name: str = "reward_head",
    ):
        super().__init__(name=name)
        self.num_hidden = num_hidden
        self.num_actions = num_actions
        self.num_quantiles = num_quantiles

    def __call__(self, obs_embedding: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
        state = obs_embedding.astype(jnp.float32)

        # Discrete action -> one-hot -> learned action embedding.
        onehot = jax.nn.one_hot(
            action.astype(jnp.float32), num_classes=self.num_actions
        )
        action_feat = jax.nn.relu(hk.Linear(self.num_hidden)(onehot))

        # Mix state and action, then add the state embedding back. The
        # residual requires the embedding's last dim to equal num_hidden.
        joint = jnp.concatenate([state, action_feat], axis=-1)
        mixed = jax.nn.tanh(hk.Linear(self.num_hidden)(joint)) + state

        output_layer = hk.Linear(
            self.num_quantiles,
            w_init=hk.initializers.UniformScaling(0.01),
            b_init=hk.initializers.UniformScaling(0.25),
        )
        return output_layer(mixed)  # (b, num_quantiles)


class RewardHistoryHead(hk.Module):
    """Predict reward quantiles from a history of observations.

    The history is encoded by an ``ObsSeqEncoder``, refined by a tanh
    residual layer, and projected to ``num_quantiles`` outputs.
    """

    def __init__(
        self,
        num_hidden: int,
        num_quantiles: int,
        name: str = "reward_history_head",
        *,
        is_state_vector: bool = True,
    ):
        """
        Args:
            num_hidden: width of the encoder and of the residual layer.
            num_quantiles: number of reward quantiles produced.
            name: Haiku module name.
            is_state_vector: forwarded to the inner ``ObsSeqEncoder``.
                Defaults to ``True`` (per-frame MLP without the CNN). Pass
                ``False`` for image histories so the convolutional path is
                used.
        """
        super().__init__(name=name)
        self.num_hidden = num_hidden
        self.num_quantiles = num_quantiles
        self.is_state_vector = is_state_vector
        self.obs_encoder = ObsSeqEncoder(
            num_hidden=num_hidden, is_state_vector=is_state_vector
        )

    def __call__(self, obs: jnp.ndarray) -> jnp.ndarray:
        """Map a batch of observation histories to reward quantiles.

        Returns:
            Array of shape ``(b, num_quantiles)``.
        """
        obs_embedding = self.obs_encoder(obs)  # (b, num_hidden)
        obs_embedding = obs_embedding + jax.nn.tanh(
            hk.Linear(self.num_hidden)(obs_embedding)
        )
        reward = hk.Linear(
            self.num_quantiles,
            w_init=hk.initializers.UniformScaling(0.01),
            b_init=hk.initializers.UniformScaling(0.25),
        )(obs_embedding)
        return reward  # (b, num_quantiles)


def get_helpers(args, tau_hats: jnp.ndarray):
    """Return ``(extract_value, value_at_action)`` for the configured value head.

    When ``args.use_q_value_head`` is truthy, both helpers take a Q
    distribution laid out as ``(b, num_quantiles, num_actions)``:

    * ``extract_value(q_dist)`` vmaps ``distort_value`` over the batch (with
      ``tau_hats`` and ``args.cvar_alpha``). It takes the argmax over axis 1
      of that result as the greedy action and returns that action's
      quantiles, ``(b, num_quantiles)``.
    * ``value_at_action(q_dist, action)`` returns the quantiles of the given
      ``(b,)`` actions, ``(b, num_quantiles)``.

    Otherwise both helpers return the V distribution unchanged, ignoring the
    action.
    """

    def q_value_at_action(q_dist: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
        """Select ``q_dist[b, :, action[b]]`` for every batch row."""
        index = action[:, None, None]  # (b, 1, 1), broadcast over quantiles
        picked = jnp.take_along_axis(q_dist, index, axis=-1)
        return picked.squeeze(axis=-1)

    def q_extract_value(q_dist: jnp.ndarray) -> jnp.ndarray:
        """Quantiles of the action that is greedy w.r.t. the distorted value."""
        batched_distort = jax.vmap(distort_value, in_axes=(0, None, None))
        distorted = batched_distort(q_dist, tau_hats, args.cvar_alpha)
        greedy = jnp.argmax(distorted, axis=1)  # (b,)
        return q_value_at_action(q_dist, greedy)

    def v_extract_value(v_dist: jnp.ndarray) -> jnp.ndarray:
        """V head: the distribution is already the value; pass it through."""
        return v_dist

    def v_value_at_action(v_dist: jnp.ndarray, _action: jnp.ndarray) -> jnp.ndarray:
        """V head: the value does not depend on the action."""
        return v_dist

    if not args.use_q_value_head:
        return v_extract_value, v_value_at_action
    return q_extract_value, q_value_at_action


def make_network_apply_fns(args):
    """Build Haiku-transformed functions for the prediction and reward networks.

    Args:
        args: configuration object; this function reads ``num_hidden``,
            ``num_actions``, ``num_quantiles`` and ``is_state_vector``.

    Returns:
        ``(prediction_apply, reward_apply, reward_history_apply, init_model)``,
        each produced by ``hk.without_apply_rng(hk.transform(...))``.
        ``init_model`` builds all three networks in a single transform, so its
        ``init`` yields one parameter tree that covers the modules used by the
        other three apply functions.
    """

    # Each network is configured in exactly one place, so the init transform
    # and the standalone apply transforms always construct identically
    # configured (and therefore identically named) modules.
    def _prediction_net():
        # NOTE(review): `args.use_q_value_head` is not forwarded; this network
        # has no such option and its value head outputs (b, num_quantiles).
        return PredictionNetwork(
            num_hidden=args.num_hidden,
            num_actions=args.num_actions,
            is_state_vector=args.is_state_vector,
            num_quantiles=args.num_quantiles,
        )

    def _reward_net():
        return RewardHead(
            num_hidden=args.num_hidden,
            num_actions=args.num_actions,
            num_quantiles=args.num_quantiles,
        )

    def _reward_history_net():
        return RewardHistoryHead(
            num_hidden=args.num_hidden,
            num_quantiles=args.num_quantiles,
        )

    def prediction_apply_fn(
        x: jnp.ndarray,
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Apply the prediction network; returns ``(logits, value, embedding)``."""
        return _prediction_net()(x)

    def reward_apply_fn(obs_embedding: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
        """Apply the reward head to the observation embedding and action."""
        return _reward_net()(obs_embedding, action)

    def reward_history_apply_fn(obs_history: jnp.ndarray) -> jnp.ndarray:
        """Apply the reward-history head to a batch of observation histories."""
        return _reward_history_net()(obs_history)

    def init_model_fn(
        obs: jnp.ndarray, obs_history: jnp.ndarray, action: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """Build and run all three networks so ``init`` creates every parameter.

        Returns:
            ``(obs_embed, logits, value, reward, reward_history)``.
        """
        logits, value, obs_embed = _prediction_net()(obs)
        reward = _reward_net()(obs_embed, action)
        reward_history = _reward_history_net()(obs_history)
        return obs_embed, logits, value, reward, reward_history

    prediction_apply = hk.without_apply_rng(hk.transform(prediction_apply_fn))
    reward_apply = hk.without_apply_rng(hk.transform(reward_apply_fn))
    reward_history_apply = hk.without_apply_rng(hk.transform(reward_history_apply_fn))
    init_model = hk.without_apply_rng(hk.transform(init_model_fn))
    return prediction_apply, reward_apply, reward_history_apply, init_model
