import os
import random
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
import igraph as ig
import sys
sys.path.append('..')
from dataloaders.belief_dataloader import get_graph
# from belief_prop.brute_force_marginal import *
# from belief_prop.factor_graphs import *
# from belief_prop.pgm import *
# from belief_prop.belief_prop import *
if os.getcwd().split('/')[-1] in ['graphs', 'notebooks']:
    sys.path.append(os.path.join(os.getcwd(), '..'))
    from belief_prop.factor_graphs import *
    from belief_prop.pgm import *
    from belief_prop.belief_prop import *
    from belief_prop.brute_force_marginal import marginal
else:
    from factor_graphs import *
    from pgm import *
    from belief_prop import *

# from brute_force_marginal import marginal
# from factor_graphs import *
# from pgm import *
# from belief_prop import *

# Seed both global RNGs up front so every generated dataset is reproducible.
seed = 0
np.random.seed(seed)
random.seed(seed)

# Sampling hyper-parameters consumed by the __main__ loop below.
max_num_nodes = 15       # `high` bound for the sampled node count
num_extra_edges = 10     # `high` bound for the sampled number of non-tree edges
max_num_children = 2     # `high` bound for the tree branching factor
num_samples = 5000       # number of graphs written to disk
singleton_mean = 0       # mean of the per-node field
singleton_var = 1        # spread of the per-node field
Jst_list = [-1]          # candidate couplings; one is drawn per sample

# Output directory; its name encodes the configuration above.
dataset_path = '/home/user/data/graph_datasets/ising_dataset'
dataset_foldername = (
    'brute_loopy_only_J_'
    f'nodes_{max_num_nodes}_num_samples_{num_samples}_Jst_{Jst_list[0]}'
)
folder_path = os.path.join(dataset_path, dataset_foldername)
os.makedirs(folder_path, exist_ok=True)


def ising_edge_potential(Jst=1):
    """Return the 2x2 pairwise potential for one edge with coupling ``Jst``.

    Entry ``[a, b]`` is ``exp(Jst)`` when ``a == b`` and ``exp(-Jst)``
    otherwise.
    """
    # +1 on the diagonal (equal states), -1 off the diagonal (different states).
    agreement = np.array([[1, -1], [-1, 1]])
    return np.exp(Jst * agreement)

def ising_node_potential(i, singleton_mean, singleton_var):
    """Draw a random singleton field for one node.

    The field is Gaussian with mean ``singleton_mean`` and *variance*
    ``singleton_var``. ``np.random.normal`` expects a standard deviation as
    its ``scale`` argument, so the variance is square-rooted first. For the
    default variance of 1 this gives exactly the same draw as passing the
    variance directly.

    Args:
        i: Node index. Unused; kept so existing call sites remain valid.
        singleton_mean: Mean of the Gaussian.
        singleton_var: Variance of the Gaussian; must be non-negative.

    Returns:
        A float sampled from the global NumPy RNG.

    Raises:
        ValueError: if ``singleton_var`` is negative.
    """
    if singleton_var < 0:
        raise ValueError(
            f"singleton_var must be non-negative, got {singleton_var}")
    J = np.random.normal(singleton_mean, np.sqrt(singleton_var))
    return J

def get_edge_index(edge_list):
    """Pack ``(src, dst)`` pairs into a ``2 x E`` float edge-index tensor.

    Column ``k`` holds ``edge_list[k]``. Each pair is written once, in its
    given orientation only; the reverse direction is *not* added. The
    tensor has torch's default float dtype, so callers convert it, e.g. with
    ``.to(torch.long)``.
    """
    num_edges = len(edge_list)
    edge_index = torch.zeros([2, num_edges])
    for col, edge in enumerate(edge_list):
        edge_index[0, col] = edge[0]  # source row
        edge_index[1, col] = edge[1]  # target row
    return edge_index

def exists(value, l):
    """Return ``True`` if ``value`` is an element of ``l`` (membership via ``in``)."""
    return value in l

def get_extra_treelist(edge_list, n, num_extra_edges):
    """Sample ``num_extra_edges`` new undirected edges among nodes ``0..n-1``.

    Endpoints are drawn with ``np.random.choice``. A candidate is rejected
    and redrawn when it is a self loop, or when it is already in
    ``edge_list`` or among the edges sampled so far (either orientation).

    Args:
        edge_list: Existing ``(u, v)`` edges, treated as undirected.
        n: Number of nodes.
        num_extra_edges: How many new edges to return.

    Returns:
        List of ``(a, b)`` tuples with NumPy integer endpoints, in the order
        they were sampled.

    Raises:
        ValueError: if fewer than ``num_extra_edges`` node pairs are still
            free. Without this check the rejection loop would never finish.
    """
    def _key(u, v):
        # Orientation-free key so (u, v) and (v, u) are the same edge.
        return (u, v) if u <= v else (v, u)

    taken = {_key(int(u), int(v)) for u, v in edge_list}
    total_pairs = n * (n - 1) // 2 if n > 1 else 0
    used_pairs = sum(1 for u, v in taken if u != v and u >= 0 and v < n)
    free_pairs = total_pairs - used_pairs
    if num_extra_edges > free_pairs:
        raise ValueError(
            f"cannot sample {num_extra_edges} extra edges: only {free_pairs} "
            f"free node pairs among {n} nodes"
        )

    extra_edge_list = []
    while len(extra_edge_list) < num_extra_edges:
        # Same draw sequence as before, so seeded datasets stay reproducible.
        a = np.random.choice(np.arange(n))
        b = np.random.choice(np.arange(n))
        if a == b:
            continue
        key = _key(int(a), int(b))
        if key in taken:
            continue
        taken.add(key)
        extra_edge_list.append((a, b))
    return extra_edge_list

def get_random_graph(num_nodes, num_extra_edges=5, num_children=2):
    """Build a loopy graph: an igraph regular tree plus random extra edges.

    Args:
        num_nodes: Number of nodes.
        num_extra_edges: Number of non-tree edges to add.
        num_children: Children per internal node of the base tree.

    Returns:
        Float edge-index tensor from ``get_edge_index``. Tree edges come
        first, followed by the sampled extra edges.

    Raises:
        AssertionError: if ``num_extra_edges`` is not below the total number
            of node pairs (the original coarse check, kept for compatibility).
        ValueError: if the tree leaves fewer than ``num_extra_edges`` free
            node pairs, so the request cannot be satisfied.
    """
    assert num_extra_edges < num_nodes * (num_nodes-1) / 2, "exceeding max edge amount"
    tree = ig.Graph.Tree(n=num_nodes, children=num_children)
    initial_edge_list = tree.get_edgelist()
    # The assert ignores the pairs the tree already occupies, so it still
    # admits impossible requests. Reject those before sampling.
    occupied = {tuple(sorted(edge)) for edge in initial_edge_list}
    free_pairs = num_nodes * (num_nodes - 1) // 2 - len(occupied)
    if num_extra_edges > free_pairs:
        raise ValueError(
            f"cannot add {num_extra_edges} extra edges: only {free_pairs} "
            f"node pairs are free in a {num_nodes}-node tree"
        )
    extra_edge_list = get_extra_treelist(
        initial_edge_list, n=num_nodes, num_extra_edges=num_extra_edges)
    edge_list = initial_edge_list + extra_edge_list
    edge_index = get_edge_index(edge_list)
    return edge_index


def get_num_within_range(n, low, high):
    """Pick one integer uniformly from ``[low, high)`` using the global NumPy RNG.

    ``high`` is exclusive (``np.arange`` semantics). ``n`` is accepted but
    unused.
    """
    candidates = np.arange(low, high)
    return np.random.choice(candidates)

def get_edge_potential(edge_index, Jst):
    """Stack one pairwise potential per edge.

    Args:
        edge_index: Edge index whose last dimension is the number of edges
            (``edge_index.shape[-1]``).
        Jst: Coupling strength forwarded to ``ising_edge_potential``.

    Returns:
        ``np.ndarray`` of shape ``(num_edges, 2, 2)``. Every slice is the same
        potential, because all edges share ``Jst``. A graph with no edges
        yields an empty ``(0, 2, 2)`` array rather than the error
        ``np.stack`` raises on an empty list.
    """
    num_edges = edge_index.shape[-1]
    # ising_edge_potential draws no random numbers, so computing it once and
    # repeating it is equivalent to calling it once per edge.
    potential = ising_edge_potential(Jst=Jst)
    return np.repeat(potential[np.newaxis], num_edges, axis=0)

def get_node_potential(num_nodes, singleton_mean=0, singleton_var=1):
    """Sample one singleton field per node.

    Returns:
        Float tensor of shape ``(num_nodes, 1)``. Row ``k`` holds
        ``ising_node_potential(k, ...)``; nodes are sampled in index order.
    """
    fields = torch.zeros([num_nodes, 1])
    for node in range(num_nodes):
        fields[node, 0] = ising_node_potential(
            node,
            singleton_mean=singleton_mean,
            singleton_var=singleton_var,
        )
    return fields

if __name__ == '__main__':

    for i in range(num_samples):
        print(i)
        # One coupling strength per sample; every edge of the sample shares it.
        Jst = np.random.choice(Jst_list)
        # NOTE(review): get_num_within_range samples np.arange(low, high), which
        # excludes `high`. As a result, num_nodes never reaches max_num_nodes.
        # With max_num_children = 2, num_children is always 1 (arange(1, 2)),
        # so every base tree is a path. Confirm these bounds are intended.
        num_nodes = get_num_within_range(
            n=max_num_nodes, low=int(0.7 * max_num_nodes), high=max_num_nodes)
        extra_edge_num = get_num_within_range(
            n=num_extra_edges, low=int(0.7 * num_extra_edges), high=num_extra_edges)
        num_children = get_num_within_range(
            n=max_num_children, low=int(0.7 * max_num_children), high=max_num_children)
        edge_index = get_random_graph(
            num_nodes, num_extra_edges=extra_edge_num, num_children=num_children)
        # Jst is a required argument of get_edge_potential; without it the
        # call raises TypeError on the first sample.
        edge_weight = get_edge_potential(edge_index, Jst=Jst)
        # Pass the configured singleton statistics explicitly instead of
        # relying on the function defaults happening to match them.
        x = get_node_potential(
            num_nodes, singleton_mean=singleton_mean, singleton_var=singleton_var)
        graph = Data(
            x=x,
            edge_index=edge_index.to(torch.long),
            edge_attr=edge_weight
        )
        # NOTE(review): `marginal` is imported by name only in the
        # graphs/notebooks branch of the import block. In the other branch it
        # must come from one of the star imports — confirm.
        y = marginal(Jst=Jst, graph=graph)
        graph.y = y
        name = f'example_{i:04}.pt'
        filename = os.path.join(folder_path, name)
        # Despite the .pt suffix, samples are serialized with pickle.
        with open(filename, 'wb') as f:
            pickle.dump(graph, f)