import os
import pickle
import random
import sys

import igraph as ig
import numpy as np
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data

from factor_graphs import *
from pgm import *
from belief_prop import *
# '..' is resolved relative to the current working directory; it must be on
# sys.path before importing the `dataloaders` package below.
sys.path.append('..')
from dataloaders.belief_dataloader import IsingModel
from brute_force_marginal import *

# Seed every RNG available to this script so dataset generation is
# reproducible. torch is seeded as well because it is imported here and the
# graph builder may draw from it (NOTE(review): confirm whether IsingModel
# uses torch's RNG; seeding it is harmless either way).
seed = 0
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)

# --- Dataset-generation settings -------------------------------------------
num_samples = 1           # number of graphs to generate (one file each)
# "single_Jst": Jst is drawn from Jst_list.
# Any other value: Jst is drawn from a normal distribution centred at Jst_list[0].
task_type = "single_Jst"
Jst_list = [1]
Jst_var = 1               # spread of Jst when task_type != "single_Jst"
singleton_mean = 0        # forwarded to IsingModel
singleton_var = 1         # forwarded to IsingModel
# Node counts are drawn with np.random.randint from
# [int(0.75 * max_num_nodes), max_num_nodes); the upper bound is exclusive.
max_num_nodes = 110
max_num_children = 4      # forwarded to IsingModel as num_children
num_loops = 2             # forwarded to IsingModel as num_loops

# Every generated example is written into one folder under dataset_path. The
# folder name encodes the generation settings so that differently configured
# datasets do not overwrite each other.
dataset_path = '/home/user/data/graph_datasets/ising_dataset'
_name_parts = [
    'bp_loopy',
    'only_J',
    f'num_loops_{num_loops}',
    f'nodes_{max_num_nodes}',
    f'num_samples_{num_samples}',
    f'Jst_{Jst_list[0]}',
    f'max_child_{max_num_children}',
]
dataset_foldername = '_'.join(_name_parts)
folder_path = os.path.join(dataset_path, dataset_foldername)
# exist_ok=True lets the script be re-run into an existing folder.
os.makedirs(folder_path, exist_ok=True)



# Generate num_samples Ising-model graphs and pickle each into folder_path.
#
# Both task types build the same loopy graphs (make_loopy=True, num_loops),
# matching the "bp_loopy_..._num_loops_{n}" folder they are written into.
# They differ only in how the coupling strength Jst is drawn:
#   * "single_Jst": uniformly from Jst_list;
#   * otherwise:    from a normal distribution with mean Jst_list[0] and
#                   variance Jst_var.


def _draw_Jst():
    """Draw one coupling strength Jst according to task_type."""
    if task_type == "single_Jst":
        return np.random.choice(Jst_list)
    # np.random.normal takes a standard deviation (scale), so convert the
    # configured variance to a standard deviation.
    return np.random.normal(Jst_list[0], np.sqrt(Jst_var))


def _sample_num_nodes():
    """Draw a node count from [int(0.75 * max_num_nodes), max_num_nodes).

    np.random.randint excludes its upper bound, so max_num_nodes itself is
    never drawn.
    """
    return np.random.randint(int(max_num_nodes * 0.75), max_num_nodes)


def _build_graph(Jst, num_nodes):
    """Construct one loopy IsingModel with the module-level settings."""
    return IsingModel(
        Jst=Jst,
        singleton_mean=singleton_mean,
        singleton_var=singleton_var,
        num_nodes=num_nodes,
        num_children=max_num_children,
        make_loopy=True,
        num_loops=num_loops,
    )


def _save_example(graph, index):
    """Pickle graph to folder_path/example_{index:04}.pt."""
    # NOTE: despite the .pt suffix, this is a plain pickle, not torch.save.
    filename = os.path.join(folder_path, f'example_{index:04}.pt')
    with open(filename, 'wb') as f:
        pickle.dump(graph, f)


# Validate with an explicit raise rather than assert, which is stripped
# under `python -O`.
if task_type != "single_Jst" and len(Jst_list) != 1:
    raise ValueError("Just provide one value in Jst, which we will use as the mean.")

for i in range(num_samples):
    print(f'Example {i}')
    # Draw Jst before the node count, as the original script did, so the
    # RNG stream is consumed in the same order.
    Jst = _draw_Jst()
    num_nodes = _sample_num_nodes()
    graph = _build_graph(Jst, num_nodes)
    # exact_marginal = marginal(Jst, graph)
    # graph.y = torch.Tensor(exact_marginal)
    _save_example(graph, i)

