import os
import pickle
import random
import sys

import igraph as ig
import numpy as np
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data

sys.path.append('..')
from dataloaders.belief_dataloader import get_graph
# from belief_prop.brute_force_marginal import *
# from belief_prop.factor_graphs import *
# from belief_prop.pgm import *
# from belief_prop.belief_prop import *

from brute_force_marginal import marginal
from factor_graphs import *
from pgm import *
from belief_prop import *

# Seed Python's and NumPy's global RNGs so the generated example is
# reproducible from run to run.
seed = 0
random.seed(seed)
np.random.seed(seed)

# --- Generation settings -----------------------------------------------------
make_loopy = True          # forwarded to get_graph(get_loopy_graph=...)
num_samples = 1            # only recorded in the output folder name here
task_type = "single_Jst"   # not referenced elsewhere in this script
Jst_list = [1]             # the first entry is used as the pairwise coupling
Jst_var = 1                # not referenced elsewhere in this script
singleton_mean = 0         # mean of the Gaussian the per-node field J is drawn from
singleton_var = 1          # spread of the Gaussian the per-node field J is drawn from
max_num_nodes = 20
max_num_children = 4
num_loops = 5
max_tree = 10
max_children_in_trees = 2

# --- Output location ---------------------------------------------------------
# The folder name encodes the generation settings so datasets built with
# different parameters do not overwrite each other.
dataset_path = '/home/user/data/graph_datasets/ising_dataset'
dataset_foldername = '_'.join([
    'brute_loopy',
    'only_J',
    f'num_loops_{num_loops}',
    f'nodes_{max_num_nodes}',
    f'num_samples_{num_samples}',
    f'Jst_{Jst_list[0]}',
    f'max_child_{max_num_children}',
    f'max_tree_{max_tree}',
    f'max_tree_children_{max_children_in_trees}',
])
folder_path = os.path.join(dataset_path, dataset_foldername)
os.makedirs(folder_path, exist_ok=True)

def ising_edge_potential(Jst=1):
    """Return the 2x2 Ising pairwise potential for coupling ``Jst``.

    Entry ``[a, b]`` is ``exp(Jst)`` when the two states agree (a == b)
    and ``exp(-Jst)`` when they differ.

    Args:
        Jst: Coupling strength between the two endpoint variables.

    Returns:
        A ``(2, 2)`` NumPy array of exponentiated log-potentials.
    """
    # +1 on the diagonal (equal states), -1 off the diagonal (unequal states).
    agreement = np.array([[1, -1],
                          [-1, 1]])
    log_potential = agreement * Jst
    return np.exp(log_potential)

def ising_node_potential(i, mean=None, var=None):
    """Sample the singleton field parameter J for one node.

    Args:
        i: Node index. Currently unused; kept so existing callers still work.
        mean: Mean of the Gaussian. Defaults to the module-level
            ``singleton_mean``.
        var: Variance of the Gaussian. Defaults to the module-level
            ``singleton_var``.

    Returns:
        A single float drawn from N(mean, var) with NumPy's global RNG.

    Raises:
        ValueError: If ``var`` is negative.
    """
    # Resolve defaults lazily so the module-level settings are read at call time.
    if mean is None:
        mean = singleton_mean
    if var is None:
        var = singleton_var
    if var < 0:
        raise ValueError(f"variance must be non-negative, got {var}")
    # np.random.normal expects a standard deviation (``scale``), not a
    # variance, so convert before sampling.
    J = np.random.normal(mean, np.sqrt(var))
    return J



def get_edge_index(edge_list):
    """Return a ``[2, num_edges]`` edge-index tensor for ``edge_list``.

    Row 0 holds each edge's first endpoint and row 1 its second endpoint,
    in the order given. Each edge is stored ONCE (the reverse direction is
    not added), so column ``k`` lines up with the ``k``-th per-edge feature.

    Args:
        edge_list: Iterable of edges; only the first two items of each edge
            are read (e.g. ``(source, target)`` pairs).

    Returns:
        A floating-point tensor (``torch.get_default_dtype()``) of shape
        ``[2, num_edges]``. An empty ``edge_list`` yields shape ``[2, 0]``.
        Callers cast to ``torch.long`` before use as an index.
    """
    sources = []
    targets = []
    for edge in edge_list:
        sources.append(edge[0])
        targets.append(edge[1])
    # Built in one call rather than with per-element tensor writes;
    # ``[[], []]`` correctly produces a [2, 0] tensor for an edgeless graph.
    return torch.tensor([sources, targets], dtype=torch.get_default_dtype())

num_nodes = max_num_nodes
num_children = max_num_children
Jst = Jst_list[0]

# Sample the (possibly loopy) graph structure. This runs before any node
# potentials are drawn so the global RNG draw order stays unchanged.
tree = get_graph(
    num_nodes, num_children,
    get_loopy_graph=make_loopy,
    num_loops=num_loops,
    max_tree=max_tree,
    max_children_in_trees=max_children_in_trees,
)

tree_edge_list = tree.get_edgelist()
num_vertices = tree.vcount()

edge_index = get_edge_index(tree_edge_list)

# One 2x2 pairwise potential per edge, aligned with the columns of edge_index.
potential_arr_list = [ising_edge_potential(Jst=Jst) for _ in tree_edge_list]
if potential_arr_list:
    edge_potential = np.stack(potential_arr_list, axis=0)
else:
    # np.stack rejects an empty sequence; keep the (E, 2, 2) layout for E == 0.
    edge_potential = np.empty((0, 2, 2))

# Node features: only the sampled singleton field J, one row per vertex.
x = torch.zeros([num_vertices, 1])
for i in range(num_vertices):
    x[i] = torch.Tensor([ising_node_potential(i)])

# Placeholder targets; replaced by the exact marginals below.
y = torch.zeros([num_vertices, 1])

graph = Data(
    x=x,
    edge_index=edge_index.to(torch.long),
    edge_attr=torch.Tensor(edge_potential),
    y=y,
)

# Targets are the marginals computed by brute_force_marginal.marginal.
exact_marginal, _ = marginal(Jst, graph)
graph.y = torch.Tensor(exact_marginal)

name = 'example.pt'
filename = os.path.join(folder_path, name)
# NOTE: despite the .pt suffix this is a plain pickle (not torch.save);
# read it back with pickle.load.
with open(filename, 'wb') as f:
    pickle.dump(graph, f)
