from typing import NamedTuple

import chex
import haiku as hk
import jax
import jax.numpy as jnp
from haiku_geometric.nn import SAGEConv

from src.baselines.qrdqn.util import distort_value


class FeatureExtractor(hk.Module):
    """Graph encoder: node embedding, one learned register node, four SAGEConv layers.

    The module operates on a single (un-batched) graph; callers in this file
    wrap it in ``jax.vmap`` to process a batch.

    Args:
        hidden_channels: Width of every linear / convolution layer.
    """

    def __init__(self, hidden_channels: int):
        super().__init__()
        self.hidden_channels = hidden_channels

        self.conv1 = SAGEConv(out_channels=hidden_channels)
        self.conv2 = SAGEConv(out_channels=hidden_channels)
        self.conv3 = SAGEConv(out_channels=hidden_channels)
        self.conv4 = SAGEConv(out_channels=hidden_channels)

        self.node_features_proj = hk.Linear(hidden_channels)

        self.linear_node_features = hk.Linear(hidden_channels)
        self.linear_pool = hk.Linear(hidden_channels)
        self.linear_node = hk.Linear(hidden_channels)
        self.linear_aux = hk.Linear(hidden_channels)

        # NOTE: the edge construction in __call__ only wires the register node
        # at index ``n_nodes``, i.e. it assumes exactly one register.
        self.num_registers = 1
        self.register_node_init = hk.get_parameter(
            "register_node_init",
            shape=(self.num_registers, hidden_channels),
            init=hk.initializers.RandomNormal(),
        )

    def __call__(self, nodes, senders, receivers, aux):
        """Encode one graph.

        Args:
            nodes: Node features, ``(n_nodes, node_features)``.
            senders: 1-D array of source node indices, one per edge.
            receivers: 1-D array of target node indices, one per edge.
            aux: 1-D auxiliary feature vector for the whole graph.

        Returns:
            A tuple ``(node_reps, pool_aux_concat)`` where ``node_reps`` has
            shape ``(n_nodes, 3 * hidden_channels)`` (pooled, per-node and aux
            embeddings side by side, register row removed) and
            ``pool_aux_concat`` has shape ``(hidden_channels,)``.
        """
        num_nodes = nodes.shape[0]

        # Two-stage node embedding. A block of ones as wide as the hidden
        # state is appended between the two projections.
        nodes = jax.nn.relu(self.linear_node_features(nodes))  # (n_nodes, hidden)
        ones_block = jnp.ones((num_nodes, self.hidden_channels))
        nodes = jnp.concatenate([nodes, ones_block], axis=-1)  # (n_nodes, 2 * hidden)
        nodes = jax.nn.relu(self.node_features_proj(nodes))  # (n_nodes, hidden)

        # Append the learned register node; it gets index ``num_nodes``.
        nodes = jnp.concatenate(
            [nodes, self.register_node_init], axis=0
        )  # (n_nodes + 1, hidden)

        # Edges register -> every real node, appended to the input edges.
        register_senders = jnp.full(num_nodes, num_nodes, dtype=jnp.int32)
        register_receivers = jnp.arange(num_nodes)
        forward_senders = jnp.concatenate([senders, register_senders], axis=0)
        forward_receivers = jnp.concatenate([receivers, register_receivers], axis=0)

        # Make the graph undirected by also adding every edge reversed.
        senders = jnp.concatenate([forward_senders, forward_receivers], axis=0)
        receivers = jnp.concatenate([forward_receivers, forward_senders], axis=0)

        # First layer has no residual connection; the remaining three do.
        x = self.conv1(nodes, senders, receivers)  # (n_nodes + 1, hidden)
        x = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(
            jax.nn.relu(x)
        )
        for conv in (self.conv2, self.conv3, self.conv4):
            # The LayerNorm is constructed before the convolution is called so
            # that module creation order (and thus naming) is unchanged.
            norm = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
            x = x + norm(
                jax.nn.relu(conv(x, senders, receivers))
            )  # (n_nodes + 1, hidden)

        # Graph-level summary: sum-pool over all rows (register included).
        pool_rep = self.linear_pool(jnp.sum(x, axis=0))  # (hidden,)
        aux_rep = self.linear_aux(aux)  # (hidden,)
        graph_input = jnp.concatenate([pool_rep, aux_rep], axis=-1)  # (2 * hidden,)
        pool_aux_concat = jax.nn.relu(
            hk.Linear(self.hidden_channels)(graph_input)
        )  # (hidden,)

        # Per-node representation: [pooled | node] -> ReLU, then aux appended
        # without an activation.
        total_rows = x.shape[0]
        node_rep = self.linear_node(x)  # (n_nodes + 1, hidden)
        aux_tiled = jnp.repeat(aux_rep[jnp.newaxis, :], total_rows, axis=0)
        pool_tiled = jnp.repeat(pool_rep[jnp.newaxis, :], total_rows, axis=0)
        node_pool = jax.nn.relu(
            jnp.concatenate([pool_tiled, node_rep], axis=-1)
        )  # (n_nodes + 1, 2 * hidden)
        node_pool_aux = jnp.concatenate(
            [node_pool, aux_tiled], axis=-1
        )  # (n_nodes + 1, 3 * hidden)

        # Keep only the real nodes; slicing by count (rather than a negative
        # offset) stays correct for any number of trailing register rows.
        node_reps = node_pool_aux[:num_nodes]  # (n_nodes, 3 * hidden)
        return node_reps, pool_aux_concat


class PredictionNetwork(hk.Module):
    """Feature extractor followed by a quantile value head and a per-node policy head.

    Args:
        hidden_channels: Width of the feature extractor and of both heads'
            hidden layers.
        num_quantiles: Number of outputs of the value head.
        name: Haiku module name.
    """

    def __init__(
        self,
        hidden_channels: int,
        num_quantiles: int,
        name: str = "prediction_network",
    ):
        super().__init__(name=name)
        self.num_quantiles = num_quantiles
        self.hidden_channels = hidden_channels
        self.feature_extractor = FeatureExtractor(hidden_channels=self.hidden_channels)

    def __call__(self, nodes, senders, receivers, aux):
        """Run the network on one graph.

        Returns:
            ``(logits, v_dist, node_reps, pool_aux_concat)``: one logit per
            node, ``num_quantiles`` value outputs, and the two feature
            extractor outputs passed through unchanged.
        """
        node_reps, pool_aux_concat = self.feature_extractor(
            nodes, senders, receivers, aux
        )

        # Value head on the graph-level embedding.
        value_hidden_layer = hk.Linear(self.hidden_channels)
        value_hidden = jax.nn.relu(value_hidden_layer(pool_aux_concat))
        value_out_layer = hk.Linear(
            self.num_quantiles,
            w_init=hk.initializers.UniformScaling(0.01),
            b_init=hk.initializers.UniformScaling(0.25),
        )
        v_dist = value_out_layer(value_hidden)  # (num_quantiles,)

        # Policy head: one scalar logit per node.
        policy_hidden_layer = hk.Linear(self.hidden_channels)
        policy_hidden = jax.nn.relu(
            policy_hidden_layer(node_reps)
        )  # (n_nodes, hidden_channels)
        policy_out_layer = hk.Linear(1)
        logits = policy_out_layer(policy_hidden).reshape(-1)  # (n_nodes,)

        return logits, v_dist, node_reps, pool_aux_concat


class RewardHead(hk.Module):
    """Two-layer tanh MLP with a skip connection, ending in a quantile output layer.

    Args:
        hidden_channels: Width of the two hidden layers.
        num_quantiles: Number of outputs.
    """

    def __init__(self, hidden_channels: int, num_quantiles: int):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.num_quantiles = num_quantiles

    def __call__(
        self,
        selected_node: jnp.ndarray,
    ):
        """Map the features of a selected node to ``num_quantiles`` outputs."""
        first_layer = hk.Linear(self.hidden_channels)
        first_hidden = jax.nn.tanh(first_layer(selected_node))  # (hidden_channels,)

        second_layer = hk.Linear(self.hidden_channels)
        second_hidden = jax.nn.tanh(second_layer(first_hidden))  # (hidden_channels,)

        # Skip connection from the first hidden activation.
        merged = second_hidden + first_hidden  # (hidden_channels,)

        output_layer = hk.Linear(
            self.num_quantiles,
            w_init=hk.initializers.UniformScaling(0.01),
            b_init=hk.initializers.UniformScaling(0.25),
        )
        return output_layer(merged)  # (num_quantiles,)


class RewardHistoryHead(hk.Module):
    """Residual tanh block over an observation embedding, ending in a quantile layer.

    Args:
        hidden_channels: Width of the hidden projection. Because of the
            residual sum, the input's last dimension must equal this value.
        num_quantiles: Number of outputs.
    """

    def __init__(self, hidden_channels: int, num_quantiles: int):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.num_quantiles = num_quantiles

    def __call__(
        self,
        obs_embedding: jnp.ndarray,
    ):
        """Map an observation embedding to ``num_quantiles`` outputs."""
        projected = hk.Linear(self.hidden_channels)(obs_embedding)
        # Apply the non-linearity to the projection. Previously tanh was taken
        # of the raw input, which discarded the linear layer's output and left
        # its parameters without any effect on the result.
        hidden = jax.nn.tanh(projected)  # (hidden_channels,)
        hidden = obs_embedding + hidden  # residual, (hidden_channels,)
        r_dist = hk.Linear(
            self.num_quantiles,
            w_init=hk.initializers.UniformScaling(0.01),
            b_init=hk.initializers.UniformScaling(0.25),
        )(hidden)
        return r_dist  # (num_quantiles,)


def get_helpers(args, tau_hats: jnp.ndarray):
    """Return ``(extract_value, value_at_action)`` for the configured value head.

    Args:
        args: Config object. ``args.use_q_value_head`` selects the Q-head or
            V-head variants when this function is called; ``args.cvar_alpha``
            is read only when the Q-head ``extract_value`` is invoked.
        tau_hats: Array forwarded to ``distort_value``; ``(num_quantiles,)``.

    Returns:
        Two callables. With the V head both return their distribution input
        unchanged. With the Q head they reduce a
        ``(b, num_quantiles, num_actions)`` array to ``(b, num_quantiles)`` by
        selecting one action per batch element.
    """

    def _select_action(dist: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
        # dist: (b, num_quantiles, num_actions), action: (b,)
        picked = jnp.take_along_axis(dist, action[:, None, None], axis=-1)
        return picked.squeeze(axis=-1)  # (b, num_quantiles)

    if not args.use_q_value_head:

        def v_extract_value(v_dist: jnp.ndarray) -> jnp.ndarray:
            # Return value for expansion
            return v_dist

        def v_value_at_action(
            v_dist: jnp.ndarray, _action: jnp.ndarray
        ) -> jnp.ndarray:
            """Return the value distribution; the action is ignored."""
            # v_dist: (b, num_quantiles)
            return v_dist

        return v_extract_value, v_value_at_action

    def q_extract_value(q_dist: jnp.ndarray) -> jnp.ndarray:
        """Return the quantiles of the greedy action under the distorted values."""
        # distort_value is applied per batch element; tau_hats and cvar_alpha
        # are shared across the batch.
        distort_batch = jax.vmap(distort_value, in_axes=(0, None, None))
        q_values = distort_batch(q_dist, tau_hats, args.cvar_alpha)
        greedy_actions = jnp.argmax(q_values, axis=1)  # (b,)
        return _select_action(q_dist, greedy_actions)

    def q_value_at_action(q_dist: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
        """Extracts the value at the given action from the Q distribution."""
        return _select_action(q_dist, action)

    return q_extract_value, q_value_at_action


def make_network_apply_fns(args):
    """Build the Haiku-transformed functions for the model's networks.

    Every wrapped module is applied through ``jax.vmap``, so all inputs carry
    a leading batch dimension.

    Args:
        args: Config object providing ``num_hidden`` and ``num_quantiles``.

    Returns:
        ``(prediction_apply, reward_apply, reward_history_apply, init_model)``,
        each an ``hk.transform``-ed function wrapped with
        ``hk.without_apply_rng``. ``init_model`` instantiates all three
        networks in one function so that a single ``init`` call yields
        parameters for all of them.
    """

    # The builders must be called inside an hk.transform-ed function. Each one
    # is used at most once per transformed function, so module names are the
    # same in the individual apply functions and in ``init_model``.
    def _make_prediction_network():
        return PredictionNetwork(
            hidden_channels=args.num_hidden,
            num_quantiles=args.num_quantiles,
        )

    def _make_reward_head():
        return RewardHead(
            hidden_channels=args.num_hidden,
            num_quantiles=args.num_quantiles,
        )

    def _make_reward_history_head():
        return RewardHistoryHead(
            hidden_channels=args.num_hidden,
            num_quantiles=args.num_quantiles,
        )

    def prediction_apply_fn(nodes, senders, receivers, aux):
        return jax.vmap(_make_prediction_network())(nodes, senders, receivers, aux)

    def reward_apply_fn(selected_node):
        return jax.vmap(_make_reward_head())(selected_node)

    def reward_history_apply_fn(obs_embedding):
        return jax.vmap(_make_reward_history_head())(obs_embedding)

    def init_model_fn(nodes, senders, receivers, aux, action):
        logits, v_dist, _node_reps, pool_aux_concat = jax.vmap(
            _make_prediction_network()
        )(nodes, senders, receivers, aux)

        # Pick one node per batch element along the node axis. The previous
        # ``nodes[action]`` indexed the batch axis with a node index instead.
        # The last (feature) dimension is the same either way, so the
        # parameters produced by ``init`` are unaffected.
        # assumes nodes is (b, n_nodes, node_features) and action is an
        # integer array of shape (b,) -- TODO confirm against callers.
        # NOTE(review): the reward head is initialised on raw node features,
        # not on the extractor's node_reps; confirm this matches what callers
        # pass to ``reward_apply``.
        selected_node = jnp.take_along_axis(
            nodes, action[:, None, None], axis=1
        ).squeeze(axis=1)  # (b, node_features)
        r_dist = jax.vmap(_make_reward_head())(selected_node)

        r_history_dist = jax.vmap(_make_reward_history_head())(pool_aux_concat)
        return logits, v_dist, r_dist, r_history_dist

    prediction_apply = hk.without_apply_rng(hk.transform(prediction_apply_fn))
    reward_apply = hk.without_apply_rng(hk.transform(reward_apply_fn))
    reward_history_apply = hk.without_apply_rng(hk.transform(reward_history_apply_fn))
    init_model = hk.without_apply_rng(hk.transform(init_model_fn))
    return prediction_apply, reward_apply, reward_history_apply, init_model


if __name__ == "__main__":
    # Smoke test: initialise every network on a single two-node graph and
    # print the output shapes of each apply function.
    demo_nodes = jnp.array([[[1.0, 2.0], [3.0, 4.0]]])
    demo_senders = jnp.array([[0, 1]])
    demo_receivers = jnp.array([[1, 0]])
    demo_aux = jnp.array([[0.1, 0.2]])
    demo_action = jnp.array([0])

    demo_args_type = NamedTuple(
        "Args", [("num_hidden", int), ("num_quantiles", int)]
    )
    demo_args = demo_args_type(num_hidden=16, num_quantiles=10)
    apply_fns = make_network_apply_fns(args=demo_args)
    prediction_apply, reward_apply, reward_history_apply, init_model = apply_fns

    # One init call creates the parameters for all three networks.
    graph_inputs = (demo_nodes, demo_senders, demo_receivers, demo_aux)
    params = init_model.init(jax.random.PRNGKey(42), *graph_inputs, demo_action)

    prediction_out = prediction_apply.apply(params, *graph_inputs)
    logits, v_dist, node_reps, pool_aux_concat = prediction_out
    print(logits.shape, v_dist.shape, node_reps.shape, pool_aux_concat.shape)

    # Example selected node
    selected_node = demo_nodes[demo_action]
    print(reward_apply.apply(params, selected_node).shape)

    print(reward_history_apply.apply(params, pool_aux_concat).shape)

    # def feature_extractor_fn(nodes, edges, senders, receivers, aux):
    #     extractor = FeatureExtractor(hidden_channels=16)
    #     return jax.vmap(extractor)(nodes, edges, senders, receivers, aux)

    # def prediction_network_fn(nodes, edges, senders, receivers, aux):
    #     network = PredictionNetwork(hidden_channels=16, num_quantiles=10)
    #     return jax.vmap(network)(nodes, edges, senders, receivers, aux)

    # def reward_head_fn(node_reps, action):
    #     reward_head = RewardHead(hidden_channels=16, num_quantiles=10)
    #     return jax.vmap(reward_head)(node_reps, action)

    # def reward_history_head_fn(obs_embedding):
    #     reward_history_head = RewardHistoryHead(hidden_channels=16, num_quantiles=10)
    #     return jax.vmap(reward_history_head)(obs_embedding)

    # # feature_extractor = hk.without_apply_rng(hk.transform(feature_extractor_fn))
    # rng = jax.random.PRNGKey(42)
    # # params = feature_extractor.init(rng, nodes, edges, senders, receivers, aux)
    # # output = feature_extractor.apply(params, nodes, edges, senders, receivers, aux)
    # # print(output.shape)

    # prediction_network = hk.without_apply_rng(hk.transform(prediction_network_fn))
    # params = prediction_network.init(rng, nodes, edges, senders, receivers, aux)
    # logits, q_dist, node_reps, _pool_aux_concat = prediction_network.apply(
    #     params, nodes, edges, senders, receivers, aux
    # )
    # print(logits.shape, q_dist.shape, _pool_aux_concat.shape)

    # reward_head = hk.without_apply_rng(hk.transform(reward_head_fn))
    # params = reward_head.init(rng, node_reps, jnp.array([0]))
    # r_dist = reward_head.apply(params, node_reps, jnp.array([0]))
    # print(r_dist.shape)

    # reward_history_head = hk.without_apply_rng(hk.transform(reward_history_head_fn))
    # params = reward_history_head.init(rng, _pool_aux_concat)
    # r_dist = reward_history_head.apply(params, _pool_aux_concat)
    # print(r_dist.shape)

    # q_extract_value, q_value_at_action = get_helpers(
    #     args=NamedTuple("Args", [("use_q_value_head", bool), ("cvar_alpha", float)])(
    #         use_q_value_head=True, cvar_alpha=0.1
    #     ),
    #     tau_hats=jnp.linspace(0.1, 0.9, 10),
    # )
    # q_values = q_extract_value(q_dist)
    # print("ex", q_values.shape)  # Should be (1, num_quantiles)

    # action = jnp.array([0])  # Example action
    # q_value = q_value_at_action(q_dist, action)
    # print("v", q_value.shape)  # Should be (1, num_quantiles)
