import os, sys
import numpy as np
import torch
import torch_geometric
from torch_geometric.data import Data
import pickle
import networkx as nx
import pickle
# from belief_prop.utils.calculate_flipflop_conditional import get_entropy_of_graph
from utils.calculate_flipflop_conditional import get_entropy_of_graph
import torch_geometric


def get_sign(prev_sign, p_flip):
    """Return ``prev_sign`` flipped with probability ``p_flip``.

    A single multiplier is drawn from {-1, +1} with NumPy's global RNG
    (``-1`` with probability ``p_flip``, ``+1`` otherwise) and applied to
    ``prev_sign``.
    """
    keep_prob = 1 - p_flip
    (multiplier,) = np.random.choice([-1, 1], size=1, p=[p_flip, keep_prob])
    return prev_sign * multiplier


def get_mask_splits(num_nodes, split=(0.7, 0.20, 0.10)):
    """Randomly partition ``num_nodes`` node indices into three disjoint masks.

    :param num_nodes: Number of nodes to split.
    :param split: Fractions ``(train, test, val)``. The train and test sizes
        are ``int(fraction * num_nodes)`` (truncated). Every remaining node
        goes to the validation mask, so together the masks cover all nodes.
    :return: ``(train_mask, test_mask, val_mask)``. These are pairwise
        disjoint boolean numpy arrays of length ``num_nodes``.
    """
    train_length = int(split[0] * num_nodes)
    test_length = int(split[1] * num_nodes)

    # Shuffle once, then carve consecutive, non-overlapping slices.
    indices_shuffled = np.random.permutation(np.arange(num_nodes))
    train_idx = indices_shuffled[:train_length]
    test_idx = indices_shuffled[train_length:train_length + test_length]
    val_idx = indices_shuffled[train_length + test_length:]

    train_mask = np.zeros(num_nodes, dtype=bool)
    train_mask[train_idx] = True
    test_mask = np.zeros(num_nodes, dtype=bool)
    test_mask[test_idx] = True
    val_mask = np.zeros(num_nodes, dtype=bool)
    val_mask[val_idx] = True

    return train_mask, test_mask, val_mask

def generate_d_regular_tree_with_edges(d, depth, signs, p_flip):
    """
    Generate a complete d-ary tree of a given depth and sample a sign per node.

    Every node above the last level gets ``d`` children. Node ids are assigned
    in depth-first pre-order, starting from the root ``0``. Each child's sign
    is its parent's sign passed through ``get_sign`` with ``p_flip``. Signs are
    appended to ``signs`` in the same order the ids are assigned, so
    ``signs[i]`` is the sign of node ``i``.

    :param d: Number of children of every non-leaf node.
    :param depth: Depth of the tree (0 yields the root alone).
    :param signs: List whose element 0 is the root's sign; extended in place.
    :param p_flip: Flip probability forwarded to ``get_sign``.
    :return: ``(edges, last_node_id, signs)``.
        - ``edges`` is a list of ``(parent, child)`` tuples.
        - ``last_node_id`` is the largest node id assigned (node count - 1).
        - ``signs`` is the same list object that was passed in.
    :raises ValueError: If ``depth`` is negative.
    """
    if depth < 0:
        # A negative depth would never hit the stop condition below.
        raise ValueError(f"depth must be non-negative, got {depth}")

    edges = []
    next_node_id = 0  # id of the root; incremented before each new child

    def add_edges(node_id, current_depth):
        nonlocal next_node_id
        if current_depth == depth:
            return

        for _ in range(d):
            next_node_id += 1
            child_id = next_node_id
            signs.append(get_sign(signs[node_id], p_flip))
            edges.append((node_id, child_id))
            add_edges(child_id, current_depth + 1)

    add_edges(0, 0)  # start from the root node (id 0) at depth 0
    return edges, next_node_id, signs

def get_x(y, x_1=None, x_2=None, random_init=True):
    """Build a node-feature matrix from labels in {-1, +1}.

    Row selection depends on ``random_init``:
    - True: rows labelled -1 get ``x_1`` and rows labelled +1 get ``x_2``.
    - False: every feature of a -1 row is 5 and every feature of a +1 row
      is 10.
    Rows with any other label stay zero.

    :param y: 1-D numpy array or torch tensor of labels.
    :param x_1: Feature vector for label -1. Its last dimension sets the
        feature width. Required when ``random_init`` is True.
    :param x_2: Feature vector for label +1. Required when ``random_init``
        is True.
    :param random_init: Use ``x_1``/``x_2`` (True) or the constants 5/10.
    :return: float32 ``torch.Tensor`` of shape ``(len(y), num_dims)``, where
        ``num_dims`` is ``x_1.shape[-1]``, or 1 when ``x_1`` is None.
    :raises ValueError: If ``random_init`` is True and ``x_1`` or ``x_2`` is
        None.
    """
    if random_init and (x_1 is None or x_2 is None):
        raise ValueError("x_1 and x_2 are required when random_init=True")

    # Without prototypes there is nothing to size the features from, so use a
    # single scalar feature per node.
    num_dims = 1 if x_1 is None else x_1.shape[-1]
    x = np.zeros((y.shape[0], num_dims))
    label_1 = np.where(y == -1)[0]
    label_2 = np.where(y == 1)[0]
    if random_init:
        x[label_1] = x_1
        x[label_2] = x_2
    else:
        x[label_1] = 5.0
        x[label_2] = 10.0
    return torch.Tensor(x)

def get_tree(degree, depth, p_flip, x_1, x_2,
             random_init=False, init_w_zero=False):
    """Sample a labelled d-ary tree and return it as a PyG ``Data`` object.

    The root sign is drawn uniformly from {-1, 1}. Every other node's sign
    comes from its parent's via ``generate_d_regular_tree_with_edges``.

    :param degree: Children per non-leaf node.
    :param depth: Depth of the tree (0 gives the root alone).
    :param p_flip: Sign-flip probability along each edge.
    :param x_1: Feature vector for label -1 (see ``get_x``).
    :param x_2: Feature vector for label +1 (see ``get_x``).
    :param random_init: Forwarded to ``get_x``.
    :param init_w_zero: Unused; kept for backward compatibility.
    :return: ``Data`` with these fields:
        - ``x``: features from ``get_x``.
        - ``y``: int64 labels, with sign -1 mapped to 0 and +1 to 1.
        - ``edge_index``: int64 tensor of shape ``(2, num_edges)`` holding
          directed parent -> child edges.
    """
    signs = [np.random.choice([-1, 1])]
    edges, _, signs = generate_d_regular_tree_with_edges(
        degree, depth, signs, p_flip)

    # reshape(-1, 2) keeps a (2, 0) int64 layout when the tree has no edges
    # (depth 0 or degree 0). torch.tensor([]) alone would be a 1-D float tensor.
    edge_index = (
        torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()
    )
    sign_tensor = torch.tensor(signs)
    x = get_x(sign_tensor, x_1, x_2, random_init=random_init)

    # Map signs {-1, +1} to class ids {0, 1}.
    y = ((sign_tensor + 1) / 2).long()
    return Data(x=x, y=y, edge_index=edge_index)

def create_masks(graph, masks, p_flip):
    """Attach split masks and marginals to ``graph``.

    :param graph: ``Data`` object to annotate (modified in place).
    :param masks: ``(train_mask, test_mask, val_mask)`` boolean arrays, in the
        order returned by ``get_mask_splits``.
    :param p_flip: Flip probability forwarded to ``get_entropy_of_graph``.
    :return: The same ``graph``. Its ``train_mask``, ``test_mask`` and
        ``val_mask`` are set to bool tensors, and ``marginals`` is set to the
        first value returned by ``get_entropy_of_graph``.
    """
    # Unpack in get_mask_splits' return order: (train, test, val).
    train_mask, test_mask, val_mask = masks
    # torch.tensor copies, so each graph owns its masks even when the same
    # numpy arrays are shared across many graphs.
    graph.train_mask = torch.tensor(train_mask, dtype=torch.bool)
    graph.test_mask = torch.tensor(test_mask, dtype=torch.bool)
    graph.val_mask = torch.tensor(val_mask, dtype=torch.bool)

    # The entropy value is computed alongside the marginals but not stored.
    marginals, _entropy = get_entropy_of_graph(graph, p_flip)

    graph.marginals = marginals
    return graph


if __name__ == '__main__':
    # ---- experiment configuration ----------------------------------------
    num_samples = 3           # graphs written to every dataset folder
    degrees = [4, 5]
    depth = 6
    NUM_DIMS_LIST = [128]     # feature widths to generate
    p_flips = [0.15, 0.85, 0.25, 0.75, 0.35, 0.65]
    rand_inits = [True]
    zero_tests = [False]
    root_dir = '/home/user/data/graph_datasets/d_regular_tree'

    # Every (rand_init, zero_test, p_flip) combination, in nested-loop order.
    settings = [
        (rand_init, zero_test, p_flip)
        for rand_init in rand_inits
        for zero_test in zero_tests
        for p_flip in p_flips
    ]

    np.random.seed(0)

    for NUM_DIMS in NUM_DIMS_LIST:
        # One pair of class prototypes per feature width.
        x_1 = np.random.randn(NUM_DIMS)
        x_2 = np.random.randn(NUM_DIMS)

        for degree in degrees:
            # A probe tree is built only to learn the node count, so that one
            # mask split is shared by every dataset of this degree. It also
            # consumes random numbers, so the saved graphs depend on it.
            probe = get_tree(degree, depth, 0.1, x_1, x_2,
                             random_init=True)
            masks = get_mask_splits(probe.x.shape[0])

            for rand_init, zero_test, p_flip in settings:
                foldername = 'd_regular_tree_same_mask'
                foldername += '_randn' if rand_init else ''
                foldername += '_zero' if zero_test else ''
                print(foldername)

                # NOTE(review): there is no '_' between the ns_ and pf_
                # fields; kept so existing dataset paths still match.
                dataset_foldername = (
                    f'd_regular_{NUM_DIMS}_ns_{num_samples}'
                    f'pf_{p_flip}_degree_{degree}_depth_{depth}'
                )
                folder_path = os.path.join(root_dir, foldername,
                                           dataset_foldername)
                os.makedirs(folder_path, exist_ok=True)
                print(folder_path)

                for sample in range(num_samples):
                    print(sample)
                    tree = get_tree(degree, depth, p_flip, x_1, x_2,
                                    random_init=rand_init)
                    tree = create_masks(tree, masks, p_flip=p_flip)
                    print(f'max nodes: {tree.edge_index.max()}')
                    out_path = os.path.join(folder_path,
                                            f'example_{sample:04}.pt')
                    with open(out_path, 'wb') as handle:
                        pickle.dump(tree, handle)
