import sys, os
sys.path.append('/home/user/projects/graph_repositories/graphs')
sys.path.append('/home/user/projects/graph_repositories/graphs/junction-tree')
import glob
import random
import numpy as np
import torch
import torch_geometric
from torch.utils.data import Dataset
from torch_geometric.data import Data
import igraph as ig
from einops import repeat
import torch_geometric
import igraph as ig
import networkx as nx
from train_utils import *
from utils import save_as_pickle
import click

# from  belief_prop.generate_random_inductive import get_random_graph
from  generate_random_inductive import * 

import junctiontree as jt
from tests.util import assert_potentials_equal
from junctiontree import computation as comp
import copy
# Default seeds applied at import time so helpers used directly (without going
# through ``generate``, which reseeds from --seed) are reproducible. All three
# RNGs imported by this module -- stdlib ``random``, NumPy and torch -- are
# seeded; previously ``random`` was left unseeded.
random.seed(3)
np.random.seed(3)
torch.manual_seed(3)

def ising_edge_potential(jst=1):
    """Return the 2x2 Ising coupling matrix for coupling strength ``jst``.

    Entries where both endpoints are in the same state (the diagonal) are
    ``+jst``; mixed-state entries (off-diagonal) are ``-jst``. The matrix is
    in log space -- ``run_junction_tree`` exponentiates it.
    """
    same_state, flipped_state = jst, -jst
    return np.array([
        [same_state, flipped_state],
        [flipped_state, same_state],
    ])

def generate_prufer_graph(num_nodes):
    """Sample a random labelled tree on ``num_nodes`` vertices.

    A Pruefer sequence of length ``num_nodes - 2`` with entries drawn uniformly
    from ``1..num_nodes`` is decoded into the tree's edge list, which yields a
    uniform sample over all labelled trees. (Drawing only from
    ``1..num_nodes - 1`` would force the last vertex to always be a leaf.)

    The sequence is drawn with ``np.random`` so that ``np.random.seed`` makes
    the generated graph reproducible.

    Args:
        num_nodes: number of vertices; must be at least 2.

    Returns:
        ``torch.Tensor`` (int64) of shape ``(2, num_nodes - 1)`` with 0-indexed
        endpoints; each tree edge appears once, in one direction.

    Raises:
        ValueError: if ``num_nodes < 2``.
    """
    def get_prufer_edges(prufer, m):
        """Decode Pruefer sequence ``prufer`` (1-indexed labels, length ``m``)
        into a list of ``m + 1`` edges ``[u, v]`` over ``m + 2`` vertices."""
        vertices = m + 2
        # Remaining occurrences of each vertex in the sequence:
        # 0 -> currently a leaf, -1 -> already removed from the tree.
        vertex_set = [0] * vertices
        for label in prufer:
            vertex_set[label - 1] += 1

        final_edge_set = []
        for label in prufer:
            # Attach the smallest-labelled current leaf to `label`.
            leaf = vertex_set.index(0)
            vertex_set[leaf] = -1
            final_edge_set.append([leaf + 1, label])
            vertex_set[label - 1] -= 1

        # Exactly two vertices are left; they form the final edge.
        final_edge_set.append(
            [i + 1 for i, count in enumerate(vertex_set) if count == 0])
        return final_edge_set

    if num_nodes < 2:
        raise ValueError(
            f'generate_prufer_graph needs num_nodes >= 2, got {num_nodes}')

    length = num_nodes - 2
    # randint's upper bound is exclusive, so labels lie in 1..num_nodes.
    arr = np.random.randint(1, num_nodes + 1, size=length).tolist()

    graph_edge_list = get_prufer_edges(arr, length)

    # (num_nodes - 1, 2) list -> (2, num_nodes - 1) array, shifted to 0-index.
    return torch.tensor(
        np.stack(np.asarray(graph_edge_list, dtype=np.int64), axis=1) - 1)

def get_graph(num_nodes, jst,
              singleton_mean=0,
              singleton_var=1,
              graph_edge_index=None):
    """Build a ``torch_geometric`` ``Data`` graph with node and edge potentials.

    Args:
        num_nodes: number of nodes; also used to sample a fresh random tree
            (``generate_prufer_graph``) when ``graph_edge_index`` is ``None``.
        jst: coupling strength forwarded to ``get_edge_potential``.
        singleton_mean: forwarded to ``get_node_potential``.
        singleton_var: forwarded to ``get_node_potential``.
        graph_edge_index: optional pre-built edge index tensor.

    Returns:
        ``Data`` with ``x`` (node potentials), ``edge_index`` (cast to
        ``torch.long``) and ``edge_attr`` (edge potentials).
    """
    edges = (
        generate_prufer_graph(num_nodes)
        if graph_edge_index is None
        else graph_edge_index
    )
    # Keep edge potentials before node potentials: either helper may draw
    # random numbers, so call order matters for seeded reproducibility.
    edge_attr = get_edge_potential(edges, jst)
    node_attr = get_node_potential(
        num_nodes,
        singleton_mean=singleton_mean,
        singleton_var=singleton_var,
    )
    return Data(
        x=node_attr,
        edge_index=edges.to(torch.long),
        edge_attr=edge_attr,
    )
    

def run_junction_tree(graph, jst):
    """Run junction-tree inference on ``graph`` using the ``junctiontree`` package.

    Factors are built in a fixed order:

    * one singleton factor per node ``i`` (variable ``str(i)``, 2 states) with
      values ``exp([-h_i, h_i])``, where ``h_i`` is ``graph.x[i]`` squeezed;
    * then one pairwise factor per column of ``graph.edge_index`` with values
      ``exp(ising_edge_potential(jst))``.

    ``graph.edge_attr`` is not read; every edge uses the same coupling ``jst``.

    Returns:
        The result of ``propagate`` on the tree from ``jt.create_junction_tree``
        -- presumably one array per factor in the order above (TODO confirm
        against the junctiontree docs).
    """
    node_field = graph.x.cpu().numpy()
    edge_pairs = graph.edge_index.cpu().numpy()

    var_sizes = {}
    factors = []
    values = []

    for node_id, row in enumerate(node_field):
        label = str(node_id)
        factors.append([label])
        var_sizes[label] = 2
        # NOTE(review): assumes each row squeezes to a scalar field -- confirm.
        h = np.squeeze(row)
        values.append(np.exp(np.array([-h, h])))

    # Same coupling matrix for every edge, so build it once; np.exp below still
    # yields a distinct array per edge factor.
    coupling = ising_edge_potential(jst=jst)
    for src, dst in zip(edge_pairs[0], edge_pairs[1]):
        factors.append([str(src), str(dst)])
        values.append(np.exp(coupling))

    junction_tree = jt.create_junction_tree(factors, var_sizes)
    return junction_tree.propagate(values)


def generate_graph_codebank(num_nodes, graphbank_size):
    """Return a list of ``graphbank_size`` random-tree edge indices.

    Each entry is an independent call to ``generate_prufer_graph(num_nodes)``,
    made in list order.
    """
    return [generate_prufer_graph(num_nodes) for _ in range(graphbank_size)]



def _make_split_masks(num_nodes, splits=(0.7, 0.9)):
    """Randomly partition node indices into train/test/val boolean masks.

    Nodes are shuffled with ``np.random.permutation``. The first ``splits[0]``
    fraction goes to train, the part up to ``splits[1]`` goes to test, and the
    remainder goes to val.

    Returns:
        ``(train_mask, test_mask, val_mask)`` as ``torch.bool`` tensors of
        length ``num_nodes``.
    """
    indices = np.random.permutation(np.arange(num_nodes))
    train_end = int(splits[0] * num_nodes)
    test_end = int(splits[1] * num_nodes)
    parts = (indices[:train_end], indices[train_end:test_end], indices[test_end:])
    masks = []
    for part in parts:
        mask = np.zeros(num_nodes, dtype=bool)
        mask[part] = True
        masks.append(torch.from_numpy(mask))
    return tuple(masks)


@click.command()
@click.option('--num_samples', default=1)
@click.option('--num_nodes', default=20)
@click.option('--singleton_mean', default=0)
@click.option('--singleton_var', default=1)
@click.option('--jst', required=True, type=float)
@click.option('--create_train_test_split', default=False)
@click.option('--seed', required=True, type=int)
@click.option('--create-data/--no-create-data', default=True)
@click.option('--folder-name', default='prufer_tree_graphbank')
@click.option('--graphbank-size', default=100, type=int)
@click.option('--dataset-root', default='/home/user/data/graph_datasets',
              help='Directory under which FOLDER_NAME is created.')
def generate(num_samples, num_nodes, singleton_mean, singleton_var, jst,
             create_train_test_split, seed,
             create_data, folder_name, graphbank_size,
             dataset_root='/home/user/data/graph_datasets'):
    """Generate random-tree Ising graphs labelled by junction-tree inference.

    A codebank of ``graphbank_size`` random trees is sampled once. Each of the
    ``num_samples`` examples then:

    * picks one tree from the codebank;
    * draws fresh potentials via ``get_graph``;
    * stores in ``graph.y`` each node's normalised belief for its first state.

    When ``create_data`` is set, each example is written with
    ``save_as_pickle`` into
    ``<dataset_root>/<folder_name>/<config-derived name>/``:

    * ``example_XXXX.pt`` -- the node graph;
    * ``graph_example_XXXX.pt`` -- the edge-wise graph from
      ``get_edgewise_graph``.
    """
    # Seed every RNG the module imports. Stdlib ``random`` was previously left
    # unseeded, so anything drawing from it ignored --seed.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    folder_path = None
    if create_data:
        str_jst = str(jst).replace('.', '_')
        dataset_foldername = (
            f'ns_{num_samples}_'
            f'cb_{graphbank_size}_'
            f'Jst_{str_jst}_'
            f'nn_{num_nodes}_'
            f'sm_{singleton_mean}_'
            f'sv_{singleton_var}_'
            f'seed_{seed}'
        )
        print(dataset_foldername)
        folder_path = os.path.join(
            f'{dataset_root}/{folder_name}', dataset_foldername)
        os.makedirs(folder_path, exist_ok=True)
        print(folder_path)

    codebank = generate_graph_codebank(num_nodes, graphbank_size)

    for sample in range(num_samples):
        print(sample)
        graph_edge_index = codebank[np.random.choice(graphbank_size)]
        graph = get_graph(
            num_nodes,
            jst=jst,
            singleton_mean=singleton_mean,
            singleton_var=singleton_var,
            graph_edge_index=graph_edge_index)

        # Label node i with its normalised belief for state 0. Entry i of
        # prop_values is presumably node i's singleton factor, since
        # run_junction_tree adds node factors first -- verify against the
        # junctiontree docs.
        prop_values = run_junction_tree(graph, jst=jst)
        y = np.zeros_like(graph.x.cpu().numpy())
        for i in range(graph.x.shape[0]):
            marginal = np.sum(prop_values[i], axis=0)
            y[i] = np.array(prop_values[i] / marginal)[0]
        graph.y = torch.tensor(y)

        # Edge-wise graphs are built from the original (one-direction) edges,
        # so they must be computed before graph.edge_index is replaced.
        undirected_edge = torch_geometric.utils.to_undirected(graph.edge_index)
        graph2 = get_edgewise_graph(graph, to_undirected=True)
        graph3 = get_undirected_edgewise_graph(graph, to_undirected=True)
        graph2.undirected_edge_index = graph3.edge_index
        # NOTE: the names are kept for downstream compatibility.
        # `undirected_edge_index` stores the original one-direction tree edges;
        # `edge_index` becomes the both-direction list from to_undirected.
        graph.undirected_edge_index = graph.edge_index
        graph.edge_index = undirected_edge

        if create_train_test_split:
            graph.train_mask, graph.test_mask, graph.val_mask = (
                _make_split_masks(graph.x.shape[0]))

        if create_data:
            save_as_pickle(graph, f'example_{sample:04}.pt', folder_path)
            save_as_pickle(graph2, f'graph_example_{sample:04}.pt', folder_path)


if __name__ == '__main__':
    generate()