import sys, os
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
import igraph as ig

from factor_graphs import *
from pgm import *
from belief_prop import *
sys.path.append('..')
from dataloaders.belief_dataloader import IsingModel
import pickle
from train_utils import get_edgewise_edge_index
from utils import save_as_pickle

# Seed every random number generator this script can draw from so that a
# generated dataset is reproducible from run to run.
seed = 0
np.random.seed(seed)  # drives the np.random.* sampling done later in this script
random.seed(seed)
# torch keeps its own generator, independent of NumPy's and Python's. Seed it
# as well so any tensor-level sampling (e.g. inside imported model code, which
# is not visible from this file) is repeatable too.
torch.manual_seed(seed)

# --- Generation settings ------------------------------------------------------
# How many examples the generation loop iterates over.
num_samples = 5_000
# "single_Jst" draws each coupling from Jst_list; any other value samples the
# coupling from a normal distribution centred on Jst_list[0].
task_type = "single_Jst"

# Coupling strength(s). The first entry also appears in the output folder name.
Jst_list = [-1]
# Passed as the second positional argument of np.random.normal (NumPy's
# `scale`) in the non-"single_Jst" branch.
Jst_var = 1

# Forwarded unchanged to IsingModel as `singleton_mean` / `singleton_var`.
singleton_mean = 0
singleton_var = 1

# Upper limit for the per-example node count draw, and the value forwarded to
# IsingModel as `num_children`.
max_num_nodes = 15
max_num_children = 2

# Root directory under which each dataset variant gets its own sub-folder.
dataset_path = '/home/user/data/graph_datasets/ising_dataset'

# The folder name encodes the generation settings, so runs with different
# settings end up in different directories.
dataset_foldername = '_'.join([
    'W_edge_only_J',  # alternative prefix kept for reference: 'only_J'
    f'num_samples_{num_samples}',
    f'num_nodes_{max_num_nodes}',
    f'Jst_{Jst_list[0]}',
    f'max_child_{max_num_children}',
])

# Create the target directory up front; an already existing one is fine.
folder_path = os.path.join(dataset_path, dataset_foldername)
os.makedirs(folder_path, exist_ok=True)

# Consulted by the "single_Jst" branch only: when False, that branch still
# builds every graph but writes nothing to disk.
create_data = False

def _sample_ising_model(Jst):
    """Build one IsingModel with coupling ``Jst`` and a random node count.

    The node count is drawn with ``np.random.randint`` from
    ``[int(0.75 * max_num_nodes), max_num_nodes)``. The upper bound is
    exclusive, so ``max_num_nodes`` itself is never drawn.
    NOTE(review): confirm that excluding ``max_num_nodes`` is intended.

    Args:
        Jst: coupling value forwarded to ``IsingModel``.

    Returns:
        The ``IsingModel`` instance built from the module-level settings.
    """
    num_nodes = np.random.randint(int(max_num_nodes * 0.75), max_num_nodes)
    return IsingModel(
        Jst=Jst,
        singleton_mean=singleton_mean,
        singleton_var=singleton_var,
        num_nodes=num_nodes,
        num_children=max_num_children,
    )


if task_type == "single_Jst":
    # Each example uses a coupling picked from the fixed list of candidates.
    for sample in range(num_samples):
        print(f'Example {sample}')
        # Draw the coupling first, then the node count (inside the helper), so
        # the sequence of RNG draws is the same as before the refactor.
        Jst = np.random.choice(Jst_list)
        graph = _sample_ising_model(Jst)

        # Companion graph built from the edge-wise connectivity; its node
        # features are row 1 of the original edge_index.
        # NOTE(review): presumably the target-node row of a (2, E) edge_index
        # -- confirm against IsingModel.
        edgewise_edge_index = get_edgewise_edge_index(graph.edge_index)
        graph2 = Data(
            x=graph.edge_index[1],
            edge_index=edgewise_edge_index.type(torch.long),
        )

        # With create_data False this loop is a dry run: graphs are built but
        # nothing is written.
        if create_data:
            name = f'example_{sample:04}.pt'
            graph_name = f'graph_example_{sample:04}.pt'
            save_as_pickle(graph, name, folder_path)
            save_as_pickle(graph2, graph_name, folder_path)
else:
    # Explicit check instead of `assert`, which is stripped when Python runs
    # with -O and would let a multi-valued Jst_list through silently.
    if len(Jst_list) != 1:
        raise ValueError(
            "Just provide one value in Jst, which we will use as the mean."
        )
    # NOTE(review): unlike the branch above, this branch writes every example
    # regardless of `create_data` -- confirm whether it should honour the flag.
    for i in range(num_samples):
        print(f'Example {i}')
        # np.random.normal takes the standard deviation (`scale`) as its second
        # argument, so Jst_var is used as a std-dev here.
        # NOTE(review): if Jst_var is meant to be a variance, pass its square
        # root instead.
        Jst = np.random.normal(Jst_list[0], Jst_var)
        graph = _sample_ising_model(Jst)

        name = f'example_{i:04}.pt'
        filename = os.path.join(folder_path, name)
        with open(filename, 'wb') as f:
            pickle.dump(graph, f)

