import sys, os
sys.path.append('/home/user/projects/graph_repositories/graphs')
sys.path.append('/home/user/projects/graph_repositories/graphs/junction-tree')
import glob
import random
import numpy as np
import torch
import torch_geometric
from torch.utils.data import Dataset
from torch_geometric.data import Data
import igraph as ig
from einops import repeat
import torch_geometric
import igraph as ig
import networkx as nx
from train_utils import *
from utils import save_as_pickle
import click

# from  belief_prop.generate_random_inductive import get_random_graph
from  generate_random_inductive import * 

import junctiontree as jt
from tests.util import assert_potentials_equal
from junctiontree import computation as comp
import copy


def ising_edge_potential(jst=1):
    """Return the 2x2 Ising pairwise table for coupling strength ``jst``.

    Matching states (0, 0) and (1, 1) get ``+jst``; mismatched states
    (0, 1) and (1, 0) get ``-jst``. ``run_junction_tree`` in this file
    exponentiates the table before using it as a factor, so the entries act
    as log-potentials.
    """
    aligned = 1 * jst
    anti_aligned = -1 * jst
    return np.array([[aligned, anti_aligned], [anti_aligned, aligned]])


def get_graph(
        num_nodes, num_extra_edges, num_children,
        jst, singleton_mean=0, singleton_var=1):
    """Sample one random graph and wrap it in a torch_geometric ``Data``.

    The helpers ``get_random_graph``, ``get_edge_potential`` and
    ``get_node_potential`` come from the star import of
    ``generate_random_inductive``. NOTE(review): their exact semantics
    (topology model, potential distributions) are not visible here.

    The helpers are called in a fixed order (structure, then edge
    potentials, then node potentials). If they draw random numbers, that
    order determines the sampled values.

    Returns:
        ``Data`` with ``x`` from ``get_node_potential``, ``edge_index``
        cast to ``torch.long``, and ``edge_attr`` from
        ``get_edge_potential``.
    """
    edges = get_random_graph(
        num_nodes, num_extra_edges=num_extra_edges, num_children=num_children)
    pair_potentials = get_edge_potential(edges, jst)
    unary_potentials = get_node_potential(
        num_nodes, singleton_mean=singleton_mean, singleton_var=singleton_var)
    return Data(
        x=unary_potentials,
        edge_index=edges.long(),
        edge_attr=pair_potentials,
    )


def run_junction_tree(graph, jst):
    """Run exact inference on a binary pairwise model with ``junctiontree``.

    Factors are built in this order:

    * one unary factor per node ``i`` (variable name ``str(i)``, 2 states)
      with values ``exp([-x_i, x_i])``, where ``x_i`` is
      ``np.squeeze(graph.x[i])``;
    * one pairwise factor per column of ``graph.edge_index`` with values
      ``exp(ising_edge_potential(jst))``.

    ``graph.edge_attr`` is not read. Every edge uses the same coupling
    derived from ``jst``.

    Returns:
        Whatever ``tree.propagate`` returns for these factor values.
        NOTE(review): presumably one (unnormalised) table per input factor,
        in input order, so the first ``num_nodes`` entries are per-node
        tables. ``generate`` indexes them that way.
    """
    node_feats = graph.x.cpu().numpy()
    edges = graph.edge_index.cpu().numpy()
    num_nodes = node_feats.shape[0]

    factors = [[str(v)] for v in range(num_nodes)]
    var_sizes = {str(v): 2 for v in range(num_nodes)}

    values = []
    for row in node_feats:
        theta = np.squeeze(row)
        values.append(np.exp(np.array([-theta, theta])))

    # The pairwise log-potential depends only on ``jst``, so compute it once.
    # np.exp still produces a fresh array for every edge factor.
    log_pair = ising_edge_potential(jst=jst)
    for src, dst in zip(edges[0], edges[1]):
        factors.append([str(src), str(dst)])
        values.append(np.exp(log_pair))

    tree = jt.create_junction_tree(factors, var_sizes)
    return tree.propagate(values)

def _format_for_path(value):
    """Render a numeric CLI value for embedding in a folder name.

    Integral floats print without a decimal part, so ``0.0`` becomes
    ``'0'``. That matches the names produced when these options were
    parsed as ints. Any ``'.'`` is replaced by ``'_'``, as is done for
    ``jst``.
    """
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    return str(value).replace('.', '_')


def _random_split_masks(num_items, splits=(0.7, 0.9)):
    """Randomly partition ``num_items`` indices into train/test/val masks.

    A permutation drawn from NumPy's global RNG is cut at
    ``int(splits[0] * num_items)`` and ``int(splits[1] * num_items)``. The
    first chunk is train, the second is test and the remainder is
    validation.

    Returns:
        ``(train_mask, test_mask, val_mask)`` as boolean torch tensors of
        length ``num_items``.
    """
    indices = np.random.permutation(np.arange(num_items))
    train_end = int(splits[0] * num_items)
    test_end = int(splits[1] * num_items)

    def to_mask(selected):
        mask = np.zeros(num_items, dtype=bool)
        mask[selected] = True
        return torch.from_numpy(mask)

    return (
        to_mask(indices[:train_end]),
        to_mask(indices[train_end:test_end]),
        to_mask(indices[test_end:]),
    )


@click.command()
@click.option('--num_samples', default=1)
@click.option('--num_nodes', default=100)
@click.option('--num_extra_nodes_percent', default=0.25)
@click.option('--num_children', default=1)
# Explicit float type: without it click infers INT from the integer default
# and rejects fractional values such as ``--singleton_var 0.5``.
@click.option('--singleton_mean', default=0, type=float)
@click.option('--singleton_var', default=1, type=float)
@click.option('--jst', required=True, type=float)
@click.option('--seed', required=True, type=int)
@click.option('--create_train_test_split', default=False)
@click.option('--create-data/--no-create-data', default=True)
@click.option('--folder-name', default='junction_tree')
def generate(
    num_samples, num_nodes, num_extra_nodes_percent, num_children,
    singleton_mean, singleton_var, jst, seed,
    create_train_test_split, create_data, folder_name):
    """Generate random graphs labelled with exact junction-tree marginals.

    For each sample, a graph is built with get_graph and run through
    run_junction_tree. Each node gets y = normalised probability of its
    state 0. An edge-wise companion graph is then built.
    With --create_train_test_split, random 70/20/10 train/test/val node
    masks are attached. With --create-data, both graphs are pickled under
    /home/user/data/graph_datasets/<folder-name>/<parameter-encoded name>.
    """
    # Seed every RNG the pipeline may draw from. Python's ``random`` was
    # previously left unseeded. NOTE(review): helpers from the star imports
    # (e.g. networkx-based generators) may rely on it; confirm.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    num_extra_edges = int(num_extra_nodes_percent * num_nodes)

    folder_path = None
    if create_data:
        dataset_path = f'/home/user/data/graph_datasets/{folder_name}'
        str_jst = str(jst).replace('.', '_')
        dataset_foldername = (
            f'ns_{num_samples}_'
            f'Jst_{str_jst}_'
            f'nn_{num_nodes}_'
            f'nc_{num_children}_'
            f'nee_{num_extra_edges}_'
            f'sm_{_format_for_path(singleton_mean)}_'
            f'sv_{_format_for_path(singleton_var)}_'
            f'seed_{seed}'
        )
        print(dataset_foldername)
        folder_path = os.path.join(dataset_path, dataset_foldername)
        os.makedirs(folder_path, exist_ok=True)
        print(folder_path)

    for sample in range(num_samples):
        print(sample)
        graph = get_graph(
            num_nodes, num_extra_edges,
            num_children=num_children, jst=jst,
            singleton_mean=singleton_mean,
            singleton_var=singleton_var)

        # run_junction_tree adds the node factors first. Assuming propagate
        # keeps factor order, entries 0..num_nodes-1 are the per-node
        # unnormalised tables.
        prop_values = run_junction_tree(graph, jst=jst)

        # Target: normalised probability of state 0 for each node.
        y = np.zeros_like(graph.x.cpu().numpy())
        for i in range(graph.x.shape[0]):
            table = prop_values[i]
            y[i] = table[0] / np.sum(table, axis=0)
        graph.y = torch.tensor(y)

        undirected_edge = torch_geometric.utils.to_undirected(graph.edge_index)
        graph2 = get_edgewise_graph(graph, to_undirected=True)
        graph3 = get_undirected_edgewise_graph(graph, to_undirected=True)
        graph2.undirected_edge_index = graph3.edge_index
        # NOTE(review): the sampled (possibly one-directional) edge_index is
        # stored as ``undirected_edge_index``, while ``edge_index`` becomes
        # the symmetrised version. The naming looks inverted; kept as-is
        # for downstream consumers. Confirm before renaming.
        graph.undirected_edge_index = graph.edge_index
        graph.edge_index = undirected_edge

        if create_train_test_split:
            (graph.train_mask,
             graph.test_mask,
             graph.val_mask) = _random_split_masks(graph.x.shape[0])

        if create_data:
            save_as_pickle(graph, f'example_{sample:04}.pt', folder_path)
            save_as_pickle(graph2, f'graph_example_{sample:04}.pt', folder_path)

if __name__ == "__main__":
    # Script entry point: run the click command-line interface.
    generate()