# pyHEXgraph: Hierarchy and Exclusion Graphs *in Python* #
This code implements the HEX-graph described in the paper [Large-Scale Object Classification using Label Relation Graphs](https://link.springer.com/chapter/10.1007/978-3-319-10590-1_4) (ECCV 2014).
This code is based on [kelemin's hex-graph repository](https://github.com/ronghanghu/hex_graph), available under the MIT License, which is itself based on [Ronghang Hu's version](https://github.com/ronghanghu/hex_graph) under the Simplified BSD License. Both of these implementations are in MATLAB.

We also used a partial Python implementation available in [yuweitu's repository](https://github.com/yuweitu/MultiLabel_Classification/tree/master/hexGraph).

Since both reference implementations are in MATLAB, we developed a PyTorch version optimized to run on GPU.

## Code structure ##
The `HEXgraph` object performs the **symbolic operations** needed to build the junction tree and **records the sum-product operations** of the computational graph executed during message passing.

The `fastHEXLayer` object performs the **message passing procedure as a neural network module**. It takes as input the `HEXGraph` object, which defines the computational graph, and a vector of activation scores produced by an upstream neural network, on which the computations are performed (as described in the paper). Because the HEX-layer is written with `torch` operations, the computational graph is automatically differentiable, so the gradient can be backpropagated to the upstream network.

### HEXgraph ###
Below are descriptions of the main `HEXGraph` class methods, following [kelemin's](https://github.com/ronghanghu/hex_graph).

* `HEXGraph.checkConsistency(Eh, Ee)` : checks that `Eh` and `Ee` are both square, contain no self-loops, and that no two nodes point to each other. It also makes sure that `Eh` has no directed cycle and that no node is exclusive with one of its ancestors, nor are any two of its ancestors mutually exclusive. Finally, it checks that the graph is connected and that `Ee` is symmetric.
* `HEXGraph.sparsifyDensify(Eh, Ee)` : based on Lemma 1 of the paper, computes the sparsified and densified versions of `Eh` and `Ee`, i.e. the minimally sparse and maximally dense equivalent graphs.
* `HEXGraph.buildJT(Ehs, Ees)` : creates a `junctionTree` object from the sparse hierarchy and exclusion matrices.
* `HEXGraph.listStateSpace(Ehd, Eed, cliques)` : lists the state space from the densified graph, following Algorithm 1 of the paper, and records every state of each variable.
* `HEXGraph.recordSumProduct(cliques, stateSpace, cliqParents, childVariables)` : records how the states of each clique are connected to the states of its neighboring cliques.
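
To make Lemma 1 and the notion of a legal state concrete, here is a minimal NumPy sketch. It is not the repository's implementation: it assumes boolean adjacency matrices where `Eh[i, j]` means `i` is a parent of `j` and `Ee` is symmetric, and the helper names (`densify`, `sparsify`, `legal_states`, `edges`) and the four-node toy graph are hypothetical. States are enumerated by brute force over all 2^n assignments, a simple stand-in for Algorithm 1 that only works on tiny graphs.

```python
import itertools

import numpy as np


def bool_matmul(A, B):
    """Boolean product: result[i, j] is True iff some k has A[i, k] and B[k, j]."""
    return (A.astype(np.int64) @ B.astype(np.int64)) > 0


def densify(Eh, Ee):
    """Maximally dense equivalent graph (assumes Eh[i, j] means i is a parent of j)."""
    n = Eh.shape[0]
    Ehd = Eh.astype(bool).copy()
    # Warshall's transitive closure: i -> k and k -> j imply i -> j
    for k in range(n):
        Ehd |= Ehd[:, k][:, None] & Ehd[k, :][None, :]
    # A[a, i] is True iff a is i itself or an ancestor of i
    A = Ehd | np.eye(n, dtype=bool)
    # An exclusion between two nodes propagates to all pairs of their descendants
    Eed = bool_matmul(bool_matmul(A.T, Ee.astype(bool)), A)
    return Ehd, Eed


def sparsify(Ehd, Eed):
    """Minimally sparse equivalent graph, computed from the densified one."""
    # Transitive reduction: drop i -> j when a path i -> k -> j already implies it
    Ehs = Ehd & ~bool_matmul(Ehd, Ehd)
    # Drop exclusion (i, j) when a strict ancestor of i (or of j) already carries it
    Ees = Eed & ~bool_matmul(Ehd.T, Eed) & ~bool_matmul(Eed, Ehd)
    return Ehs, Ees


def legal_states(Ehd, Eed):
    """Brute-force list of legal 0/1 states (exponential in the number of labels)."""
    n = Ehd.shape[0]
    parents, children = np.nonzero(Ehd)
    ex_i, ex_j = np.nonzero(Eed)
    states = []
    for y in itertools.product([0, 1], repeat=n):
        y = np.array(y, dtype=bool)
        ok_h = not np.any(y[children] & ~y[parents])  # a label on implies its ancestors on
        ok_e = not np.any(y[ex_i] & y[ex_j])          # exclusive labels are never both on
        if ok_h and ok_e:
            states.append(tuple(int(v) for v in y))
    return states


def edges(M):
    return [(int(i), int(j)) for i, j in zip(*np.nonzero(M))]


# Hypothetical toy graph: 0 animal, 1 dog, 2 cat, 3 husky
Eh = np.zeros((4, 4), dtype=bool)
Eh[0, 1] = Eh[0, 2] = Eh[1, 3] = True
Eh[0, 3] = True                  # redundant: implied by 0 -> 1 -> 3
Ee = np.zeros((4, 4), dtype=bool)
Ee[1, 2] = Ee[2, 1] = True       # dog excludes cat
Ee[3, 2] = Ee[2, 3] = True       # redundant: implied by dog excluding cat

Ehd, Eed = densify(Eh, Ee)
Ehs, Ees = sparsify(Ehd, Eed)
print(edges(Ehs))                # [(0, 1), (0, 2), (1, 3)]
print(edges(Ees))                # [(1, 2), (2, 1)]
print(len(legal_states(Ehd, Eed)))  # 5
```

Both redundant edges of the toy input are removed by sparsification, while densification restores them, which is the equivalence Lemma 1 relies on.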

### junctionTree ###
Builds the junction tree of a HEXgraph given its sparse hierarchical and exclusion matrices.

* `junctionTree.minNei(Ea)` : finds an elimination sequence for the adjacency matrix `Ea` using the min-neighbors heuristic (i.e. eliminate the node with the fewest neighbors).
* `junctionTree.buildJunctionGraph(Ehs, Ees)` : builds the junction graph from the sparse hierarchy and exclusion matrices by performing the elimination steps in the order given by the elimination sequence (i.e. finds the cliques that will appear in the junction tree). Along the way, it records which cliques each variable appears in and how many times each variable appears.
* `junctionTree.buildJunctionTree()` : computes the weight of each pair of cliques (i.e. the size of their intersection), uses Kruskal's algorithm to build a maximum spanning tree, and records the parent and children of each clique. It then builds an ordering of the cliques from children to parents.
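
The three junction-tree steps above can be sketched in plain Python with NumPy. This is an illustration under assumptions, not the `junctionTree` code: `Ea` is taken to be the symmetric boolean adjacency matrix of the undirected graph (hierarchy plus exclusion edges), ties in the min-neighbors rule are broken by node index, the tree is rooted at clique 0, and all helper names are hypothetical.

```python
import numpy as np


def min_nei_order(Ea):
    """Repeatedly eliminate the node with the fewest remaining neighbours.

    Returns the elimination order and the clique formed at each step.
    """
    adj = Ea.astype(bool).copy()
    np.fill_diagonal(adj, False)
    remaining = set(range(adj.shape[0]))
    order, cliques = [], []
    while remaining:
        rem = list(remaining)
        v = min(rem, key=lambda u: (int(adj[u, rem].sum()), u))  # ties -> smallest index
        nbrs = [u for u in rem if u != v and adj[v, u]]
        for a in nbrs:               # fill-in: connect the neighbours pairwise
            for b in nbrs:
                if a != b:
                    adj[a, b] = True
        order.append(v)
        cliques.append(frozenset([v, *nbrs]))
        remaining.remove(v)
    return order, cliques


def maximal(cliques):
    """Keep only cliques that are not strict subsets of another clique."""
    uniq = list(dict.fromkeys(cliques))
    return [c for c in uniq if not any(c < d for d in uniq)]


def max_spanning_tree(cliques):
    """Kruskal's algorithm on the clique graph, weight = size of the intersection."""
    parent = list(range(len(cliques)))

    def find(x):                     # union-find with path halving
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    pairs = [(len(cliques[i] & cliques[j]), i, j)
             for i in range(len(cliques)) for j in range(i + 1, len(cliques))]
    tree = []
    for w, i, j in sorted(pairs, key=lambda p: -p[0]):  # heaviest edges first
        ri, rj = find(i), find(j)
        if ri != rj:
            parent[ri] = rj
            tree.append((i, j, w))
    return tree


def root_tree(n, tree, root=0):
    """Orient the tree away from `root`; return parents and a children-to-parents order."""
    nbrs = {c: [] for c in range(n)}
    for i, j, _ in tree:
        nbrs[i].append(j)
        nbrs[j].append(i)
    parent, order = {root: None}, [root]
    k = 0
    while k < len(order):            # breadth-first traversal from the root
        for d in nbrs[order[k]]:
            if d not in parent:
                parent[d] = order[k]
                order.append(d)
        k += 1
    return parent, order[::-1]       # reversed BFS order visits children before parents


# Hypothetical toy: 0 animal, 1 dog, 2 cat, 3 husky; hierarchy edges plus dog-cat exclusion
Ea = np.zeros((4, 4), dtype=bool)
for i, j in [(0, 1), (0, 2), (1, 3), (1, 2)]:
    Ea[i, j] = Ea[j, i] = True

order, cliques = min_nei_order(Ea)
cliques_max = maximal(cliques)
tree = max_spanning_tree(cliques_max)
parent, seq = root_tree(len(cliques_max), tree)
print(order, cliques_max, tree, seq)
```

On this toy graph the sketch yields the cliques `{1, 3}` and `{0, 1, 2}`, joined by one edge of weight 1 (they share the dog variable).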


### fastHEXLayer ###
These are the main methods of the `fastHEXLayer` class:

* `fastHEXLayer.initHEXgraph()` : Initializes the HEX-layer with the HEXgraph and launches the symbolic computations of the `HEXGraph` class.

* `fastHEXLayer.initMessages(scores)` : Initializes the messages between cliques, using the sum-product records of the HEXgraph (which give the size of each message) and the potentials (computed from the scores).

* `fastHEXLayer.collectMessages(i, mp=True)` : Collects the messages sent to clique `i` and aggregates them with its potential. If `mp=True`, does the same for the max-product messages.

* `fastHEXLayer.distributeMessages(i, mp=True)` : Sends the message obtained in `collectMessages` to the parent clique, after projecting it with the transition matrix that links each state of clique `i` to the corresponding states of its parent.

* `fastHEXLayer.decodeViterbi()` : Uses the Viterbi algorithm to decode the sequence of states recorded as argmax during max-product message passing. Outputs the most probable global assignment of the labels.

* `fastHEXLayer.forward(scores)` : Initializes the messages with the scores and runs a pass of message passing. Returns the log-partition function of the distribution.

* `fastHEXLayer.loss(state, scores, mp=False, beta=-1, log_z=None, normalize=False)` : Computes the loss for a given set of labels (`state`) and model scores (`scores`). `beta` is the regularization hyper-parameter.
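
For intuition about what `forward`, `loss` and `decodeViterbi` compute, the same quantities can be written by brute force over an explicit list of legal states. In the paper's model, `Pr(y | x)` is proportional to `exp(sum_i f_i * y_i)` for legal `y` and zero otherwise, so `log Z` is a log-sum-exp over legal states and the negative log-likelihood of a target state is `log Z - f . y`. The sketch below assumes exactly that form; it ignores `beta` (its regularization is not detailed here), the function names are hypothetical, and the legal-state list of the toy graph is written out by hand. Message passing computes the same values without enumerating states.

```python
import torch


def hex_log_partition(scores, states):
    """log Z = logsumexp over legal states y of sum_i f_i * y_i.

    scores: (batch, n) float tensor of label scores f.
    states: (S, n) 0/1 tensor listing every legal state.
    """
    energies = scores @ states.T.to(scores.dtype)   # (batch, S): f . y for each legal y
    return torch.logsumexp(energies, dim=1)


def hex_nll(scores, target, states):
    """Mean negative log-likelihood log Z - f . y_target over the batch."""
    log_z = hex_log_partition(scores, states)
    target_energy = (scores * target.to(scores.dtype)).sum(dim=1)
    return (log_z - target_energy).mean(), log_z


def hex_map(scores, states):
    """MAP decoding by enumeration: the legal state with the largest f . y."""
    energies = scores @ states.T.to(scores.dtype)
    return states[energies.argmax(dim=1)]


# Hypothetical toy graph: labels (animal, dog, cat, husky); dog and cat are exclusive
# children of animal, husky is a child of dog. These are all of its legal states.
states = torch.tensor([[0, 0, 0, 0], [1, 0, 0, 0], [1, 0, 1, 0], [1, 1, 0, 0], [1, 1, 0, 1]])
scores = torch.tensor([[0.5, 2.0, -1.0, 1.0]], requires_grad=True)
target = torch.tensor([[1, 1, 0, 0]])             # ground truth: animal and dog

loss, log_z = hex_nll(scores, target, states)
loss.backward()   # scores.grad equals P(y_i = 1) - target_i (marginals minus labels)
print(hex_map(scores.detach(), states))           # tensor([[1, 1, 0, 1]])
```

Note that the gradient of `log Z` with respect to the scores is the vector of label marginals, which is why autodifferentiating a log-partition computation is enough to train the upstream network.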

The earlier `HEXLayer` was too slow because it ran a full forward and backward pass of message passing. `fastHEXLayer` only runs the forward pass to compute the partition function; the backward pass is only needed for Viterbi decoding.
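
A single pass is enough because, on a tree, the partition function is obtained by summing out variables from the leaves towards the root, and automatic differentiation of `log Z` then yields the marginals `P(y_i = 1)` without a separate downward pass. The sketch below shows this on the simplest possible case, a hypothetical three-label hierarchy chain `0 -> 1 -> 2` handled in log-space with PyTorch, and checks it against brute-force enumeration. It is an illustration under these assumptions, not the `fastHEXLayer` code.

```python
import itertools

import torch

NEG_INF = float("-inf")


def chain_log_partition(scores):
    """Leaves-to-root pass on the hierarchy chain 0 -> 1 -> ... -> n-1.

    scores: (n,) tensor; label i contributes scores[i] when it is on.
    A child may be on only if its parent is on (HEX hierarchy constraint).
    """
    # pair[a, b]: log-potential of (parent = a, child = b); the illegal (0, 1) gets -inf
    pair = torch.tensor([[0.0, NEG_INF], [0.0, 0.0]])
    on = torch.tensor([0.0, 1.0])
    alpha = scores[0] * on                 # log-message over the two states of label 0
    for f in scores[1:]:
        # sum out the previous label, then add the unary term of the current one
        alpha = torch.logsumexp(alpha[:, None] + pair, dim=0) + f * on
    return torch.logsumexp(alpha, dim=0)


def brute_force_log_partition(scores):
    """Reference value: enumerate every 0/1 state and keep the legal ones."""
    n = scores.shape[0]
    energies = []
    for y in itertools.product([0, 1], repeat=n):
        if all(y[i] <= y[i - 1] for i in range(1, n)):   # child on => parent on
            energies.append((scores * torch.tensor(y, dtype=scores.dtype)).sum())
    return torch.logsumexp(torch.stack(energies), dim=0)


f = torch.tensor([0.5, 2.0, 1.0], requires_grad=True)
log_z = chain_log_partition(f)
print(torch.allclose(log_z, brute_force_log_partition(f.detach())))  # True
log_z.backward()   # f.grad[i] = P(y_i = 1), obtained without any downward pass
```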

## Examples ##

Here is a simple training loop with the HEX-layer:

```python
import pickle
import torch

from HEXgraph import HEXGraph
from fastHEXLayer import fastHEXLayer
from AlexNet import AlexNet

# Load a networkx DiGraph representing the WordNet taxonomy
with open('./ImageNet/imgnet_lattice.pkl', 'rb') as file:
    L = pickle.load(file)

# Generate the HEX-graph from the DiGraph
hexg = HEXGraph(L=L)

# Init the HEX-layer from the HEX-graph
hexL = fastHEXLayer(hexg)

# Load the network: one output score per node of the HEX-graph
network = AlexNet(num_classes=hexL.hexg.numV)

# Define the dataloader
# CustomDataset is a user-defined torch Dataset (not shown here) yielding (image, label state) pairs
dataset = CustomDataset(root='./ImageNet/train', s2n_path='./ImageNet/synset_to_node.json', tm=hexL.hexg.getTM())
loader = torch.utils.data.DataLoader(dataset, batch_size=2048, shuffle=True, drop_last=True)

# Set up the optimizer and scheduler
optimizer = torch.optim.SGD(network.parameters(), lr=1e-3, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

epochs = 150  # number of training epochs (example value)

# Start the training loop
for epoch in range(epochs):
    # switch to train mode
    network.train()

    for i, (inputs, labels) in enumerate(loader):
        # compute scores from images
        scores = network(inputs)

        # compute the loss with the HEX-layer
        loss, log_z = hexL.loss(state=labels, scores=scores, mp=False, beta=0.1)

        # backpropagate the loss to compute the gradient then optimize the network
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # StepLR counts epochs here: decay the learning rate once per epoch, not once per batch
    scheduler.step()
```

This repository includes several notebooks that illustrate how to use HEXgraph and HEXLayer in practice.

**exammple1.ipynb** and **exammple2.ipynb** are toy examples for quickly getting familiar with the main functionalities.

In addition, **exammple2.ipynb** shows how to initialize the HEXgraph from a directed graph, assuming a concept lattice in which concepts are mutually exclusive whenever possible.

Finally, **exampleImageNetHEXg.ipynb** illustrates the symbolic computation of the HEX-graph on the classes of the ImageNet image classification challenge, organized by the WordNet synset hierarchy. Then, **exampleImageNet.ipynb** demonstrates how to interface the HEX-layer with a neural network and train the system on the ImageNet image classification challenge.

## Next research questions ##

* Test performance on GPUs

* During inference, how to compute the k most probable states (k-best modes) instead of only the mode, in order to measure acc@3 or acc@5

* Marginalization of certain variables to perform learning with incomplete data

* Generalization of the HEX-graph and HEX-layer to arbitrary propositional formulae