import os, sys
sys.path.append('..')
import torch_geometric
import itertools

import numpy as np

def calculate_potential(x_arr, Jst=1, graph=None):
    """Evaluate the potential of a spin assignment ``x_arr`` on ``graph``.

    The result is ``sum_i h_i * x_i + sum_k Jst * x_u(k) * x_v(k)``, where
    ``h_i = graph.x[i].item()`` and ``(u(k), v(k))`` are the first two
    entries of column ``k`` of ``graph.edge_index``.

    Args:
        x_arr: Indexable spin values (list or numpy array), one per node.
        Jst: Uniform coupling applied to every edge column. Defaults to 1.
        graph: Object exposing ``edge_index`` (indexed as
            ``edge_index[:, k]``; ``shape[1]`` is the number of columns) and
            ``x`` (one scalar per node, since ``.item()`` is called on each
            row). Required.

    Returns:
        The node-field term plus the pairwise coupling term.

    NOTE(review): every column of ``edge_index`` contributes once. If the
    graph stores both directions of an undirected edge, each bond is
    counted twice. Confirm that this is intended.
    """
    assert graph is not None, "need to provide a graph dummy"

    # Pairwise coupling term: one contribution per edge_index column.
    edges = graph.edge_index
    coupling = 0
    for col in range(edges.shape[1]):
        endpoints = edges[:, col]
        left, right = x_arr[endpoints[0]], x_arr[endpoints[1]]
        coupling += Jst * left * right

    # External-field term: per-node scalar times that node's spin.
    field = graph.x
    external = 0
    for node in range(field.shape[0]):
        external += field[node].item() * x_arr[node]

    return external + coupling


def get_Z(Jst, graph=None):
    """Brute-force partition function of the ±1 spin model defined by ``graph``.

    Enumerates every configuration in ``{-1, 1} ** N`` (``N = graph.x.shape[0]``)
    and sums ``np.exp(calculate_potential(config))``. This costs ``2 ** N``
    potential evaluations, so it is only practical for very small graphs.

    Args:
        Jst: Coupling strength forwarded to ``calculate_potential``.
        graph: Graph object providing ``x`` and ``edge_index``. Required.

    Returns:
        The sum of exponentiated potentials over all configurations.
        NOTE: this is computed directly with ``np.exp`` (no log-sum-exp), so
        large potentials can overflow to ``inf``.
    """
    assert graph is not None, "need to provide a graph dummy"
    Z = 0
    for spins in itertools.product([-1, 1], repeat=graph.x.shape[0]):
        Z += np.exp(calculate_potential(np.array(spins), Jst=Jst, graph=graph))
    return Z

def get_node_marginal(Jst, x_val, insert_index, graph=None):
    """Unnormalised mass of configurations where one node is clamped to ``x_val``.

    Enumerates every assignment of the remaining ``N - 1`` nodes
    (``N = graph.x.shape[0]``) in ``{-1, 1}``. It splices ``x_val`` in at
    position ``insert_index``, using ``list.insert`` semantics so
    out-of-range indices are clamped, and sums ``np.exp`` of each
    configuration's potential. Dividing by ``get_Z`` turns this into a
    probability.

    Args:
        Jst: Coupling strength forwarded to ``calculate_potential``.
        x_val: Clamped spin value; must compare equal to -1 or 1.
        insert_index: Position of the clamped node in the configuration.
        graph: Graph object providing ``x`` and ``edge_index``. Required.

    Returns:
        Sum of ``exp(potential)`` over the ``2 ** (N - 1)`` matching configurations.
    """
    assert graph is not None, "need to provide a graph dummy"
    assert x_val in [-1, 1], "invalid x_val"

    n_free = graph.x.shape[0] - 1
    total = 0
    for rest in itertools.product([-1, 1], repeat=n_free):
        # Equivalent to list(rest) followed by .insert(insert_index, x_val).
        config = [*rest[:insert_index], x_val, *rest[insert_index:]]
        total += np.exp(calculate_potential(config, Jst=Jst, graph=graph))
    return total


def marginal(Jst, graph):
    """Exact single-node marginals of the ±1 spin model by full enumeration.

    Args:
        Jst: Coupling strength forwarded to the enumeration helpers.
        graph: Graph object providing ``x`` and ``edge_index``.

    Returns:
        A pair ``(mm, m)``:
          * ``m``: array of shape ``(N, 2)`` (``N = graph.x.shape[0]``) where
            ``m[i, 0]`` is P(x_i = -1) and ``m[i, 1]`` is P(x_i = +1);
          * ``mm``: length-N array of expected spins, ``P(+1) - P(-1)``.

    Cost is ``(N + 1) * 2 ** N`` potential evaluations: one full
    enumeration for Z, plus ``2 ** (N - 1)`` per (node, spin) pair.
    """
    Z = get_Z(Jst, graph)
    n_nodes = graph.x.shape[0]
    spins = (-1, 1)  # column 0 <-> spin -1, column 1 <-> spin +1
    unnormalised = np.zeros((n_nodes, len(spins)))
    for node in range(n_nodes):
        for col, spin in enumerate(spins):
            unnormalised[node, col] = get_node_marginal(Jst, spin, node, graph)
    m = unnormalised / Z
    mm = m[:, 1] - m[:, 0]
    return mm, m
