"""Datasets with labels generated by a neural network with a graphical structure."""
import collections
import dataclasses
import itertools
import random
from typing import Dict, List, Optional, Sequence, Tuple, Union
import uuid

import numpy as np
import tensorflow as tf


@dataclasses.dataclass
class Node:
    """A vertex of the DAG, identified by a random UUID.

    Besides the two dataclass fields, every instance receives a ``uuid``
    attribute (a fresh ``uuid.uuid4()``) in ``__post_init__``.  ``uuid`` is
    not a dataclass field: it cannot be passed to the constructor and it is
    ignored by the generated ``__eq__`` and ``__repr__``.  Code that needs
    node identity should therefore compare ``uuid`` values.
    """

    # Nodes fed by this node (outgoing edges).  ``None`` is accepted and is
    # replaced by a new empty list, so instances never share a default list.
    children: Optional[List['Node']] = None
    # Number of units (width) of this node; ``None`` until it is assigned.
    n_units: Optional[int] = None

    def __post_init__(self):
        """Assign the unique id and normalise a missing ``children`` list."""
        self.uuid = uuid.uuid4()
        if self.children is None:
            self.children = []


@dataclasses.dataclass
class Dag:
    """A directed acyclic graph whose nodes are grouped into ordered levels.

    After construction, ``uuid_to_node`` maps the uuid of every node found in
    ``nodes_by_level`` to the node itself.
    """

    nodes_by_level: Sequence[Sequence[Node]]

    def __post_init__(self):
        """Index every node of every level by its uuid."""
        index = {}
        for level in self.nodes_by_level:
            for node in level:
                index[node.uuid] = node
        self.uuid_to_node = index

    def get_node_to_parents_uuid_map(self) -> Dict[uuid.UUID, List[uuid.UUID]]:
        """Return a map from a node's uuid to the uuids of its parents.

        The result is a ``collections.defaultdict(list)``: looking up a node
        that has no parent yields (and stores) an empty list.
        """
        parents_of = collections.defaultdict(list)
        for parent in self.uuid_to_node.values():
            for child in parent.children:
                parents_of[child.uuid].append(parent.uuid)
        return parents_of

    @property
    def n_levels(self) -> int:
        """Number of levels in the graph."""
        level_count = len(self.nodes_by_level)
        return level_count

    def get_flat_nodes(self) -> List[Node]:
        """Return all indexed nodes as a new flat list."""
        return [*self.uuid_to_node.values()]

    # TODO: add a `to_network()` method (not implemented yet).


###############################################################################


class DagNeuralNetwork(tf.keras.Model):
    """Keras model whose connectivity follows a `Dag`.

    Every edge (parent -> child) of the DAG owns one linear `Dense` layer that
    maps the parent's output to `child.n_units` features.  A node's output is
    `activation` applied to the sum of those linear maps over all of its
    parents.  Nodes of the first level are not computed: they are fed directly
    with slices of the model input.  The outputs of the last-level nodes are
    concatenated and passed through a final linear `Dense` layer that produces
    `n_classes` logits.
    """

    def __init__(self, dag: Dag, n_classes: int, activation, **kwargs):
        """Create the per-edge layers and the logits layer for `dag`.

        Args:
            dag: Graph describing the nodes and their connections.  Layer
                sizes are read from the nodes' `n_units`.
            n_classes: Number of units of the final logits layer.
            activation: Callable applied to the summed pre-activations of
                each computed node (string names are not supported yet, see
                the TODO below).
            **kwargs: Forwarded to `tf.keras.Model.__init__`.
        """
        super().__init__(**kwargs)
        self.dag = dag
        self.n_classes = n_classes

        # TODO: Support stuff like 'relu' rather than a function
        self.activation = activation

        self.flat_nodes = self.dag.get_flat_nodes()
        self.node_to_parents_uuid_map = self.dag.get_node_to_parents_uuid_map()

        self.src_to_dst_uuid_to_linear_layer = self._make_layers()

        # Keras apparently can't see the variables in src_to_dst_uuid_to_linear_layer,
        # so the same layers are also stored in a flat list attribute.
        self._make_sure_keras_reads_me = [
            layer
            for dst_to_layer in self.src_to_dst_uuid_to_linear_layer.values()
            for layer in dst_to_layer.values()
        ]

        # Linear read-out (no activation) over the concatenated last level.
        # The layer is called once on a symbolic input of that width,
        # presumably so that its weights are created right away.
        self.logits_layer = tf.keras.layers.Dense(n_classes, activation=None)
        self.logits_layer(tf.keras.Input([self.get_level_width(-1)], dtype=tf.float32))

    def get_input_size(self) -> int:
        """Return the total number of units of the first level."""
        return self.get_level_width(0)

    def get_level_width(self, layer_index: int) -> int:
        """Return the sum of `n_units` over the nodes of one level.

        `layer_index` is used as a plain sequence index into
        `dag.nodes_by_level` (the constructor passes -1 for the last level).
        """
        return sum(n.n_units for n in self.dag.nodes_by_level[layer_index])

    def _split_input(self, x):
        """Split `x` along its last axis into one piece per first-level node.

        The pieces are returned in the order of `dag.nodes_by_level[0]`, the
        i-th piece having `n_units` of the i-th node as its last dimension.
        """
        # x.shape = [..., input_size]
        first_level = self.dag.nodes_by_level[0]
        units = [n.n_units for n in first_level]
        return tf.split(x, units, axis=-1)

    def _make_layers(self):
        """Create one linear `Dense` layer per edge of the DAG.

        Returns:
            A `defaultdict` mapping source uuid -> {destination uuid -> layer},
            where the layer has `n_units` of the destination node as its
            output size.
        """
        src_to_dst_uuid_to_linear_layer = collections.defaultdict(dict)
        for node in self.flat_nodes:
            for child in node.children:
                layer = tf.keras.layers.Dense(child.n_units, activation=None)
                # Call the layer once on a symbolic input of the source width,
                # presumably to create its weights before the first real call.
                layer(tf.keras.Input([node.n_units], dtype=tf.float32))
                src_to_dst_uuid_to_linear_layer[node.uuid][child.uuid] = layer
        return src_to_dst_uuid_to_linear_layer

    def _call_for_node(self, cache: Dict[uuid.UUID, tf.Tensor], node_uuid: uuid.UUID):
        """Return the output of one node, computing its ancestors on demand.

        Args:
            cache: Outputs computed so far, keyed by node uuid.  It is updated
                in place and must already hold the first-level inputs.
            node_uuid: uuid of the node to evaluate.

        Raises:
            ValueError: If the node is not in `cache` and has no parents.
        """
        if node_uuid not in cache:
            # The map is a defaultdict (see Dag.get_node_to_parents_uuid_map),
            # so a node without parents yields an empty list, not a KeyError.
            parent_uuids = self.node_to_parents_uuid_map[node_uuid]

            if not len(parent_uuids):
                raise ValueError('Source node not in level 0.')

            # One linear map per incoming edge, applied to the parent's output.
            preactivations = []
            for parent_uuid in parent_uuids:
                parent_input = self._call_for_node(cache, parent_uuid)
                layer = self.src_to_dst_uuid_to_linear_layer[parent_uuid][node_uuid]
                preactivations.append(layer(parent_input))

            # Sum the contributions of all parents, then apply the activation.
            cache[node_uuid] = self.activation(tf.reduce_sum(preactivations, axis=0))
        return cache[node_uuid]

    def call(self, x):
        """Run the network on `x` and return the logits.

        `x` is split among the first-level nodes (see `_split_input`); the
        outputs of the last-level nodes are concatenated along the last axis
        and passed through the logits layer, which applies no activation.
        """
        level_inputs = self._split_input(x)

        # Seed cache with level inputs, then run _call_for_node for all nodes in last level.
        cache = {}
        for inpt, node in zip(level_inputs, self.dag.nodes_by_level[0]):
            cache[node.uuid] = inpt

        last_hidden = tf.concat([
            self._call_for_node(cache, node.uuid)
            for node in self.dag.nodes_by_level[-1]
        ], axis=-1)

        return self.logits_layer(last_hidden)


###############################################################################


def make_uniformly_random_nodes_by_level(
    n_levels: int,
    min_nodes: Union[int, Sequence[int]],
    max_nodes: Union[int, Sequence[int], None] = None,
):
    """Create fresh, unconnected nodes arranged in `n_levels` levels.

    The number of nodes of each level is drawn uniformly from the inclusive
    range given by that level's minimum and maximum.

    Args:
        n_levels: Number of levels to create.
        min_nodes: Minimum node count, either one int used for every level or
            one value per level.
        max_nodes: Maximum node count (inclusive), with the same int or
            per-level form.  ``None`` means "same as `min_nodes`", i.e. the
            level sizes are fixed.

    Returns:
        A list with one list of new `Node` objects per level.
    """
    lower_bounds = [min_nodes] * n_levels if isinstance(min_nodes, int) else min_nodes

    if max_nodes is None:
        upper_bounds = lower_bounds
    elif isinstance(max_nodes, int):
        upper_bounds = [max_nodes] * n_levels
    else:
        upper_bounds = max_nodes

    assert len(lower_bounds) == len(upper_bounds) == n_levels

    # randint(a, b) is inclusive on both ends, like randrange(a, b + 1).
    return [
        [Node() for _ in range(random.randint(fewest, most))]
        for fewest, most in zip(lower_bounds, upper_bounds)
    ]


def add_uniformly_random_dag_connections(
    nodes_by_level: Sequence[Sequence['Node']],
    p_connection: Union[float, Sequence[float]]
):
    """Randomly connect nodes of earlier levels to nodes of later levels.

    Every pair made of a node and a node from a strictly later level is
    considered once; with the applicable probability the later node is
    appended to the earlier node's ``children``.  Nodes of the same level are
    never connected.  The nodes are modified in place and nothing is returned.

    Args:
        nodes_by_level: Levels of nodes, ordered from first to last.
        p_connection: If it is a single number, the probability of any two
            nodes being connected is ``p_connection`` regardless of their
            levels.  If it is a sequence, ``p_connection[k]`` is the
            probability that two nodes with ``k`` levels between them are
            connected (``k == 0`` for adjacent levels).  A sequence that is
            too short is padded with ``0.0``, i.e. more distant levels are
            never connected.
    """
    n_levels = len(nodes_by_level)

    # Accept plain ints such as 0 or 1 as scalar probabilities too; they used
    # to fall through to ``len()`` and raise a TypeError.
    if isinstance(p_connection, (int, float)):
        probabilities = n_levels * [p_connection]
    else:
        probabilities = list(p_connection)
        # A non-positive multiplier yields an empty padding list.
        probabilities += (n_levels - len(probabilities)) * [0.0]

    for i, lower_nodes in enumerate(nodes_by_level):
        for j in range(i):
            n_separating_levels = i - j - 1
            prob = probabilities[n_separating_levels]

            higher_nodes = nodes_by_level[j]
            for low, high in itertools.product(lower_nodes, higher_nodes):
                # random.random() is in [0.0, 1.0): the strict comparison
                # makes 0.0 mean "never" and 1.0 mean "always".
                if random.random() < prob:
                    high.children.append(low)


def assign_uniformly_random_units(
    nodes_by_level: Sequence[Sequence[Node]],
    min_units: Union[int, Sequence[int]],
    max_units: Union[int, Sequence[int], None] = None,
):
    """Set `n_units` of every node to a random value from its level's range.

    Args:
        nodes_by_level: Levels of nodes; the nodes are modified in place.
        min_units: Minimum unit count, either one int used for every level or
            one value per level.
        max_units: Maximum unit count (inclusive), with the same int or
            per-level form.  ``None`` means "same as `min_units`", i.e. the
            sizes are fixed.
    """
    depth = len(nodes_by_level)

    lows = depth * [min_units] if isinstance(min_units, int) else min_units

    if max_units is None:
        highs = lows
    elif isinstance(max_units, int):
        highs = depth * [max_units]
    else:
        highs = max_units

    for level_index, level in enumerate(nodes_by_level):
        for node in level:
            # The bounds are looked up per node, so an empty level never
            # indexes into them.  randint is inclusive on both ends.
            node.n_units = random.randint(lows[level_index], highs[level_index])


def remove_sources_not_at_first_level(
    nodes_by_level: Sequence[Sequence['Node']],
) -> List[List['Node']]:
    """Drop every node that cannot be reached from a first-level node.

    Reachability follows the ``children`` links, starting from all nodes of
    ``nodes_by_level[0]``; nodes are told apart by their ``uuid``.  The input
    levels and the nodes themselves are left untouched.

    Args:
        nodes_by_level: Levels of nodes, ordered from first to last.

    Returns:
        New lists, one per input level and in the same order, holding only
        the reachable nodes.  An empty input gives an empty list.
    """
    # Without any level there is no first level to start from.
    if not nodes_by_level:
        return []

    # Iterative depth-first traversal from the first-level nodes.
    reachable = set()
    pending = list(nodes_by_level[0])
    while pending:
        node = pending.pop()
        if node.uuid in reachable:
            continue
        reachable.add(node.uuid)
        pending.extend(node.children)

    return [
        [node for node in level if node.uuid in reachable]
        for level in nodes_by_level
    ]
