import os, sys
import numpy as np
import torch
import torch_geometric
from torch_geometric.data import Data
import pickle
from belief_prop.utils.calculate_flipflop_conditional import get_entropy_of_graph


def get_sign(prev_sign, p=0.7):
    """Return ``prev_sign``, negated with probability ``p``.

    A multiplier is drawn from ``{-1, +1}`` with weights ``[p, 1 - p]`` and
    applied to ``prev_sign``. Because the draw uses ``size=1``, the result is
    a length-1 NumPy array (broadcast against ``prev_sign``), not a scalar.
    """
    keep_probability = 1-p
    multiplier = np.random.choice(
        [-1, 1],
        size=1,
        p=[p, keep_probability],
    )
    return prev_sign * multiplier

def create_initial_edge_index(tree):
    """Build the edge list of a path that visits ``tree`` in order.

    Returns a float array of shape ``(2, len(tree) - 1)``: row 0 holds the
    source of each edge (every node except the last) and row 1 holds the
    target (every node except the first).
    """
    num_edges = len(tree) - 1
    pairs = np.zeros([2, num_edges])
    # Consecutive entries of ``tree`` are linked: tree[k] -> tree[k + 1].
    pairs[0], pairs[1] = tree[:num_edges], tree[1:]
    return pairs

def update_edge_index(edge_index, edge):
    """Return a new array with ``edge`` appended to ``edge_index``.

    The two arrays are joined along the last axis; ``edge_index`` itself is
    not modified.
    """
    pieces = (edge_index, edge)
    extended = np.concatenate(pieces, axis=-1)
    return extended


def get_mask_splits(labels, split=(0.7, 0.20, 0.10)):
    """Randomly partition the nodes into disjoint train/test/val masks.

    Args:
        labels: Any sized container with one entry per node; only its length
            is used.
        split: Fractions ordered as ``[train, test, val]``. The train and
            test fractions are floored to node counts; the validation split
            receives every remaining node, so the third entry is implied by
            the first two and is not read.

    Returns:
        Tuple ``(train_mask, test_mask, val_mask)`` of boolean NumPy arrays of
        length ``len(labels)``. Every node is set in exactly one mask.
    """
    length = len(labels)
    train_length = int(split[0]*length)
    test_length = int(split[1]*length)

    # Shuffle once, then carve the permutation into three consecutive,
    # non-overlapping slices.
    indices_shuffled = np.random.permutation(np.arange(length))
    test_end = train_length + test_length
    index_groups = (
        indices_shuffled[:train_length],
        indices_shuffled[train_length:test_end],
        indices_shuffled[test_end:],
    )

    masks = []
    for group in index_groups:
        mask = np.zeros(length, dtype=bool)
        mask[group] = True
        masks.append(mask)
    train_mask, test_mask, val_mask = masks

    return train_mask, test_mask, val_mask


def get_tree(chain_length, extra_edges,
             prob_flip=0.7, create_masks=False, x_1=None, x_2=None, pm1=False):
    """Sample a random tree whose +/-1 node signs flip along edges.

    The tree starts as a path over nodes ``0 .. chain_length - 1``. Then
    ``extra_edges`` further nodes are added one at a time, each attached to a
    node picked uniformly at random among all nodes present so far. The first
    node's sign is drawn uniformly from ``{-1, +1}``; every other node takes
    its parent's sign, negated with probability ``prob_flip``.

    Args:
        chain_length: Number of nodes in the initial path.
        extra_edges: Number of additional nodes (one new edge each).
        prob_flip: Probability that a node's sign differs from its parent's.
        create_masks: If True, attach ``train_mask``/``test_mask``/``val_mask``,
            zero the features of the validation and test nodes, and attach
            ``marginals`` as returned by ``get_entropy_of_graph``.
        x_1, x_2, pm1: Forwarded unchanged to ``get_x`` to build the features.

    Returns:
        A ``torch_geometric.data.Data`` with ``x`` from ``get_x``, ``y`` as
        int64 class ids (sign -1 -> 0, sign +1 -> 1) and an int64
        ``edge_index`` holding one parent -> child column per edge.
    """
    tree = np.arange(chain_length)
    edge_index = create_initial_edge_index(tree)
    num_nodes = len(tree)

    initial_sign = np.random.choice([-1, 1], 1)  # choose 1 of them randomly.
    labels = [initial_sign]
    for i in range(1, num_nodes):
        labels.append(get_sign(labels[i-1], p=prob_flip))

    # Attach the extra nodes. Node ids are always 0 .. num_nodes - 1, so the
    # parent can be drawn from the node count directly; parents are collected
    # and appended to edge_index once instead of re-concatenating the arrays
    # on every iteration.
    first_new_node = num_nodes
    parents = []
    for _ in range(extra_edges):
        node_picked = np.random.choice(num_nodes, 1)[0]
        parents.append(node_picked)
        labels.append(get_sign(labels[node_picked], p=prob_flip))
        num_nodes += 1

    if parents:
        children = np.arange(first_new_node, num_nodes)
        new_edges = np.stack((np.array(parents), children))
        edge_index = update_edge_index(edge_index, new_edges)

    # reshape(-1) keeps the labels 1-D even when the tree has a single node.
    labels = torch.Tensor(np.array(labels).reshape(-1))
    updated_labels = (labels + 1)/2

    graph = Data(
        x=get_x(labels, x_1, x_2, pm1=pm1),
        y=updated_labels.type(torch.int64),
        # Convert straight to int64: going through a float32 tensor would
        # round node ids above 2**24.
        edge_index=torch.as_tensor(edge_index, dtype=torch.int64),
    )
    if create_masks:
        # get_mask_splits returns the masks in (train, test, val) order.
        train_mask, test_mask, val_mask = get_mask_splits(labels)
        graph.train_mask = torch.as_tensor(train_mask, dtype=torch.bool)
        graph.test_mask = torch.as_tensor(test_mask, dtype=torch.bool)
        graph.val_mask = torch.as_tensor(val_mask, dtype=torch.bool)

        # Blank out the features of held-out nodes (presumably so they do not
        # reveal the labels being evaluated -- confirm with the consumers).
        graph.x[graph.val_mask] = 0.0
        graph.x[graph.test_mask] = 0.0

        # The second return value (entropy) is not stored on the graph.
        marginals, _ = get_entropy_of_graph(graph, prob_flip)
        graph.marginals = marginals

    # NOTE: edge_index stays directed (parent -> child); it is not
    # symmetrised here.
    return graph


def get_x(y, x_1=None, x_2=None, pm1=False):
    """Build the node-feature tensor from the +/-1 node signs in ``y``.

    With ``pm1=True`` each node gets one scalar feature: ``1`` where
    ``y == -1`` and ``-1`` where ``y == 1``. Otherwise nodes with ``y == -1``
    receive the vector ``x_1`` and nodes with ``y == 1`` receive ``x_2``; the
    feature width is ``x_1.shape[-1]``. Rows whose sign is neither value stay
    zero.

    Returns:
        A tensor created with ``torch.Tensor`` holding one feature row per
        entry of ``y``.
    """
    # NOTE(review): in the pm1 branch the feature has the opposite sign of
    # the label -- confirm this inversion is intended.
    if pm1:
        width, negative_value, positive_value = 1, 1, -1
    else:
        width, negative_value, positive_value = x_1.shape[-1], x_1, x_2

    features = np.zeros((y.shape[0], width))
    features[np.where(y == -1)[0]] = negative_value
    features[np.where(y == 1)[0]] = positive_value
    return torch.Tensor(features)

if __name__ == '__main__':
    # Generation settings shared by every dataset variant.
    NUM_DIMS = 128
    chain_length = 5_000
    extra_edges = 5_000
    num_samples = 10
    create_masks = True
    dataset_path = '/home/user/data/graph_datasets/flipflop'

    # Seed once so the two class feature vectors (and every graph sampled
    # afterwards) are reproducible from run to run.
    np.random.seed(0)
    x_1 = np.random.randn(NUM_DIMS)
    x_2 = np.random.randn(NUM_DIMS)

    # One dataset folder per flip probability.
    for prob_flip in [0.3, 0.7, 0.1, 0.9, 0.2, 0.8]:
        # NOTE(review): there is no '_' between the ns_ and pf_ fields; it is
        # kept as-is because existing loaders may depend on this exact name.
        dataset_foldername = (
            f'flipflop_{NUM_DIMS}_'
            f'ns_{num_samples}'
            f'pf_{prob_flip}_'
            f'TE_{chain_length+extra_edges}'
        )
        folder_path = os.path.join(dataset_path, dataset_foldername)
        os.makedirs(folder_path, exist_ok=True)
        print(folder_path)

        for sample in range(num_samples):
            print(sample)
            graph = get_tree(
                chain_length,
                extra_edges,
                prob_flip=prob_flip,
                create_masks=create_masks,
                x_1=x_1,
                x_2=x_2,
                pm1=False,
            )
            # The .pt files are written with pickle, so they must be read
            # back with pickle as well.
            filename = os.path.join(folder_path, f'example_{sample:04}.pt')
            with open(filename, 'wb') as f:
                pickle.dump(graph, f)

    
