import flax.struct as struct
import jax
import jax.numpy as jnp

from nais.gym.base import (
    EnvironmentConfig,
    EnvState,
    LogRewardBase,
    LogRewardConfig,
)

# We represent the phylogenetic trees of a batch as an array of shape
# (B, T, 2), in which the t-th pair of the b-th batch element holds the two
# nodes that were merged at the t-th step of the generative process.
# This allows the process to be tracked both forward and backward.

# The forward policy could be, for instance, a GCN that takes the current tree
# as input and outputs a probability distribution over the possible merges.

# To implement this policy, we also need to keep track of the trees' adjacency
# matrices. The adjacency matrix grows as nodes are added, but since the number
# of nodes that can be added is bounded, we can simply pad it to its maximum size.


@struct.dataclass
class PhylogeneticMetadata:
    """Tree-specific metadata stored in ``EnvState.metadata``.

    The merge history itself is not stored here: it lives in the
    ``state`` attribute of ``EnvState`` (one ``(i, j)`` pair per step).
    ``flax.struct.dataclass`` makes instances immutable; use ``replace``
    or build a new instance to update them.
    """

    # Padded adjacency matrix of the tree built so far, indexed as
    # [batch, node, node] by ``add_node``.
    adjacency_matrix: jax.Array
    # Counter bumped by one on every merge; ``add_node`` uses the bumped
    # value as the index of the newly created internal node.
    node_count: int


def add_node(env_state: EnvState, nodes_to_merge: jax.Array, idx: int):
    """Apply one forward (merge) step to every element of the batch.

    Args:
        env_state: Current environment state. ``state`` holds the merge
            history, ``forward_mask``/``backward_mask`` the allowed actions,
            and ``metadata`` a ``PhylogeneticMetadata``.
        nodes_to_merge: Pairs ``(i, j)`` to merge, shape [B, 2]. Assumed to
            satisfy ``i < j`` (as produced by ``triu_indices``).
        idx: Step index at which the pair is recorded in ``env_state.state``.

    Returns:
        A new ``EnvState`` with the merge recorded, both masks updated, and
        the new internal node connected to its two children in the
        adjacency matrix.
    """
    batch_ids = env_state.batch_ids
    left = nodes_to_merge[:, 0]
    right = nodes_to_merge[:, 1]

    # Record the merged pair in the (B, T, 2) history.
    state = env_state.state.at[batch_ids, idx].set(nodes_to_merge)

    # The merged cluster keeps the left index i and the right index j becomes
    # inactive. In the upper-triangular pair layout (the policy expands the
    # mask via triu_indices), j appears both as the first element of (j, l),
    # l > j (row j), and as the second element of (k, j), k < j (column j).
    # Both must be masked, otherwise j could still be merged with some other
    # k < j. Zeroing column j also removes the (i, j) pair itself.
    forward_mask = env_state.forward_mask
    forward_mask = forward_mask.at[batch_ids, right, :].set(0)
    forward_mask = forward_mask.at[batch_ids, :, right].set(0)

    # The merged node may now be removed by a backward step; j no longer can.
    backward_mask = env_state.backward_mask
    backward_mask = backward_mask.at[batch_ids, left].set(1)
    backward_mask = backward_mask.at[batch_ids, right].set(0)

    # NOTE(review): the counter is incremented *before* being used as the new
    # node's index. If `node_count` is initialised to the number of leaves,
    # index n_leaves is never used and the last merge may fall outside a
    # (2n-1)-sized matrix (JAX drops out-of-bounds scatter updates silently).
    # Also confirm that the cluster indices in `nodes_to_merge` coincide with
    # adjacency-node indices after earlier merges. TODO confirm both.
    node_count = env_state.metadata.node_count + 1

    # Connect both children to the new node (symmetric adjacency).
    adjacency_matrix = env_state.metadata.adjacency_matrix
    for child in (left, right):
        adjacency_matrix = adjacency_matrix.at[batch_ids, child, node_count].set(1)
        adjacency_matrix = adjacency_matrix.at[batch_ids, node_count, child].set(1)

    return env_state.replace(
        state=state,
        forward_mask=forward_mask,
        backward_mask=backward_mask,
        metadata=PhylogeneticMetadata(
            adjacency_matrix=adjacency_matrix,
            node_count=node_count,
        ),
    )


def remove_node(env_state: EnvState, nodes_to_remove: jax.Array):
    """Placeholder for the backward step; currently a no-op.

    Presumably intended to undo a merge performed by ``add_node``; nothing
    is implemented yet.

    Args:
        env_state: Current environment state.
        nodes_to_remove: Per-batch node indices to remove, shape [B].

    Returns:
        ``env_state`` unchanged. State, masks, and metadata are not touched.
    """
    # TODO: implement the backward transition (restore masks / adjacency).
    del nodes_to_remove  # unused until the backward step is implemented
    return env_state


class LogRewardUniform(LogRewardBase):
    """Constant log-reward: every terminal tree receives the same value.

    A constant log-reward corresponds to a uniform target distribution. The
    particular constant (1.0) only shifts the log-partition function.
    """

    def __call__(self, env_state: EnvState):
        # Return one log-reward per batch element. `batch_ids` has one entry
        # per batch element (it indexes the batch axis in this module), whereas
        # `ones_like(batch_size)` would give a single integer scalar. A float
        # dtype keeps downstream log-space arithmetic in floating point.
        return jnp.ones_like(env_state.batch_ids, dtype=float)
