import sys, os
sys.path.append('/home/user/projects/graph_repositories/graphs')
sys.path.append('/home/user/projects/graph_repositories/graphs/junction-tree')
import glob
import random
import numpy as np
import torch
import torch_geometric
from torch.utils.data import Dataset
from torch_geometric.data import Data
import igraph as ig
from einops import repeat
import torch_geometric
import igraph as ig
import networkx as nx
from train_utils import *
from utils import save_as_pickle
import click

# from  belief_prop.generate_random_inductive import get_random_graph
# from  generate_random_inductive import * 

import junctiontree as jt
from tests.util import assert_potentials_equal
from junctiontree import computation as comp
import copy
np.random.seed(3)
torch.manual_seed(3)


def ising_edge_potential(jst=1):
    """Return the 2x2 Ising pairwise potential table for coupling ``jst``.

    The diagonal entries (both endpoints in the same state) are ``+jst`` and
    the off-diagonal entries (endpoints in different states) are ``-jst``.

    Args:
        jst: Coupling strength.

    Returns:
        A ``(2, 2)`` NumPy array ``[[jst, -jst], [-jst, jst]]``.
    """
    sign = 1
    aligned = sign * jst
    opposed = -1 * sign * jst
    return np.array([[aligned, opposed], [opposed, aligned]])

def get_edge_potential(edge_index, Jst):
    """Return one Ising potential table per edge, stacked along axis 0.

    Every edge receives the same table, ``ising_edge_potential(jst=Jst)``.

    Args:
        edge_index: Array-like whose last dimension indexes edges; only
            ``edge_index.shape[-1]`` (the number of edges) is used.
        Jst: Coupling strength applied to every edge.

    Returns:
        A NumPy array of shape ``(num_edges, 2, 2)``. When there are no
        edges the result has shape ``(0, 2, 2)`` (stacking an empty list
        with ``np.stack`` would raise ``ValueError`` instead).
    """
    num_edges = edge_index.shape[-1]
    # The table does not depend on the edge, so build it once and replicate.
    potential = ising_edge_potential(jst=Jst)
    return np.tile(potential, (num_edges, 1, 1))

def node_potential(val):
    """Return ``val`` multiplied by a sign drawn uniformly from ``{-1, +1}``.

    The sign comes from NumPy's global random state (``np.random.choice``),
    so results are reproducible after ``np.random.seed``.
    """
    return np.random.choice([-1, 1]) * val

def get_node_potential(num_nodes, val):
    """Draw an independent random-sign potential for each node.

    Args:
        num_nodes: Number of nodes (rows) to generate.
        val: Magnitude passed to ``node_potential`` for every node.

    Returns:
        A ``(num_nodes, 1)`` tensor created by ``torch.zeros`` whose rows are
        filled, in order, with successive ``node_potential(val)`` draws.
    """
    potentials = torch.zeros([num_nodes, 1])
    for row in range(num_nodes):
        potentials[row, 0] = node_potential(val)
    return potentials


def get_graph(
        num_nodes, jst, node_potential_val=1000):
    """Build a hub-and-paths Ising graph as a ``torch_geometric`` ``Data``.

    Nodes ``0 .. num_nodes - 1`` are arranged as ``int(sqrt(num_nodes))``
    disjoint paths of equal length; consecutive nodes on a path are joined by
    an edge. One extra hub node (index ``num_nodes``) is connected to every
    path node.

    Edge potentials: path edges get coupling ``0``, hub edges get coupling
    ``jst``. Path edges come first in ``edge_index`` / ``edge_attr``, followed
    by the hub edges.

    Node potentials: every node (hub included) draws ``+/- node_potential_val``
    at random; within each path, the potentials at positions after
    ``num_nodes_per_path // 2`` are then reset to zero.

    Args:
        num_nodes: Number of path nodes (the hub is added on top of this).
            Must be positive and a multiple of ``int(sqrt(num_nodes))``.
        jst: Coupling strength for the hub edges.
        node_potential_val: Magnitude of the random node potentials.

    Returns:
        ``Data`` with ``x`` of shape ``(num_nodes + 1, 1)``, a directed
        ``edge_index`` (``torch.long``) and per-edge ``(2, 2)`` potential
        tables in ``edge_attr``.

    Raises:
        ValueError: If ``num_nodes`` is not positive or cannot be split into
            equal-length paths.
    """
    if num_nodes < 1:
        raise ValueError(f'num_nodes must be positive, got {num_nodes}')

    hub_node = num_nodes
    num_paths = int(np.sqrt(num_nodes))
    num_nodes_per_path, leftover = divmod(num_nodes, num_paths)
    if leftover:
        # Without this check the reshape below fails with an opaque message.
        raise ValueError(
            f'num_nodes={num_nodes} cannot be split into {num_paths} '
            f'equal-length paths ({leftover} node(s) left over)'
        )
    nodes = np.arange(num_nodes).reshape((num_paths, num_nodes_per_path))
    flat_nodes = nodes.flatten()

    # Hub -> every path node.
    hub_edges_index = np.stack(
        [np.ones_like(flat_nodes) * hub_node, flat_nodes], axis=0)
    # Each path node -> its successor on the same path.
    edge_index_paths = np.stack(
        [nodes[:, :num_nodes_per_path - 1].flatten(), nodes[:, 1:].flatten()],
        axis=0)
    graph_edge_index = np.concatenate(
        [edge_index_paths, hub_edges_index], axis=1)

    # Same ordering as graph_edge_index: path edges first, then hub edges.
    edge_potential_nodes = torch.tensor(
        get_edge_potential(edge_index_paths, Jst=0))
    edge_potential_hub = torch.tensor(
        get_edge_potential(hub_edges_index, Jst=jst))
    edge_weight = torch.cat([edge_potential_nodes, edge_potential_hub], dim=0)

    x = get_node_potential(num_nodes + 1, node_potential_val)
    # Zero the potentials on the tail of every path; the hub (last row) keeps
    # its random potential.
    x_rearranged = x[:-1].reshape(num_paths, num_nodes_per_path)
    x_rearranged[:, int(num_nodes_per_path // 2) + 1:] = 0
    x[:-1, :] = x_rearranged.flatten().unsqueeze(-1)

    graph = Data(
        x=x,
        edge_index=torch.tensor(graph_edge_index).to(torch.long),
        edge_attr=edge_weight
    )
    return graph


def run_junction_tree(graph, jst):
    """Run junction-tree propagation on the Ising model described by ``graph``.

    One unary factor is created per row of ``graph.x`` and one pairwise factor
    per column of ``graph.edge_index``. Variables are named by the string form
    of their node index and all have two states.

    Args:
        graph: Object exposing tensors ``x`` (node potentials) and
            ``edge_index`` (edge endpoints in rows 0 and 1).
        jst: Coupling strength used for every pairwise factor.

    Returns:
        The result of ``propagate`` on the junction tree, called with the
        factor values in the order: all unary factors, then all pairwise
        factors.
    """
    node_potentials = graph.x.cpu().numpy()
    edges = graph.edge_index.cpu().numpy()

    factors = []
    values = []
    var_sizes = {}

    # Unary factors: exp(-h) for state 0 and exp(+h) for state 1.
    # NOTE(review): np.exp overflows to inf for large potentials (e.g. 1000);
    # confirm the downstream propagation tolerates this.
    for node, row in enumerate(node_potentials):
        key = str(node)
        factors.append([key])
        var_sizes[key] = 2
        potential = np.squeeze(row)
        values.append(np.exp(np.array([-potential, potential])))

    # Pairwise factors.
    # NOTE(review): every edge uses coupling ``jst`` here and graph.edge_attr
    # is not consulted -- confirm this is intended.
    for column in range(edges.shape[-1]):
        source, target = edges[0][column], edges[1][column]
        factors.append([str(source), str(target)])
        values.append(np.exp(ising_edge_potential(jst=jst)))

    _tree = jt.create_junction_tree(factors, var_sizes)
    return _tree.propagate(values)

@click.command()
@click.option('--num_samples', default=1)
@click.option('--num_nodes', default=100)
@click.option('--node_potential_val', default=1000)
@click.option('--jst', required=True, type=float)
@click.option('--create_train_test_split', default=False)
@click.option('--create-data/--no-create-data', default=True)
@click.option('--folder-name', default='path_graph')
@click.option('--undirected_edgewise', default=False)
def generate(
    num_samples, num_nodes, node_potential_val, jst,
    create_train_test_split, create_data, folder_name, undirected_edgewise):
    # CLI entry point: for each sample, build a graph, label every node with a
    # normalised junction-tree value, derive an edgewise companion graph,
    # optionally attach train/test/val node masks, and optionally save both
    # graphs. Documented with comments rather than a docstring so the click
    # --help output stays unchanged.

    def create_folder():
        """Create (if needed) and return the output directory for this run."""
        run_tag = '_'.join([
            f'ns_{num_samples}',
            f'Jst_{str(jst)}',
            f'nn_{num_nodes}',
            f'npv_{node_potential_val}',
            f'uew_{undirected_edgewise}',
        ])
        print(run_tag)
        out_dir = os.path.join(
            f'/home/user/data/graph_datasets/{folder_name}', run_tag)
        os.makedirs(out_dir, exist_ok=True)
        print(out_dir)
        return out_dir

    def as_bool_mask(selected, size):
        """Return a bool tensor of length ``size`` that is True at ``selected``."""
        mask = np.zeros(size, dtype=bool)
        mask[selected] = True
        return torch.tensor(mask)

    if create_data:
        folder_path = create_folder()

    for sample in range(num_samples):
        print(sample)
        graph = get_graph(
            num_nodes, jst=jst, node_potential_val=node_potential_val)

        prop_values = run_junction_tree(graph, jst=jst)

        # Label each node with the first entry of its propagated value after
        # normalising that value to sum to one.
        y = np.zeros_like(graph.x.cpu().numpy())
        for node in range(graph.x.shape[0]):
            belief = prop_values[node]
            normalised = np.array(belief / np.sum(belief, axis=0))
            y[node] = normalised[0]
        graph.y = torch.tensor(y)

        # Computed before the edgewise helper runs and only assigned
        # afterwards, so the helper sees the original directed edge_index.
        undirected_edge = torch_geometric.utils.to_undirected(graph.edge_index)

        build_edgewise = (
            get_undirected_edgewise_graph if undirected_edgewise
            else get_edgewise_graph
        )
        graph2 = build_edgewise(graph, to_undirected=True)

        # NOTE(review): edge_attr is left untouched here, so it keeps one
        # entry per original directed edge while edge_index is replaced by
        # the to_undirected result -- confirm consumers do not expect them
        # to line up.
        graph.edge_index = undirected_edge

        if create_train_test_split:
            # Random 70% / 20% / 10% node split into train / test / val.
            indices = np.random.permutation(np.arange(graph.x.shape[0]))
            total = len(indices)
            train_end = int(0.7 * total)
            test_end = int(0.90 * total)

            graph.train_mask = as_bool_mask(indices[:train_end], total)
            graph.test_mask = as_bool_mask(indices[train_end:test_end], total)
            graph.val_mask = as_bool_mask(indices[test_end:], total)

        if create_data:
            save_as_pickle(graph, f'example_{sample:04}.pt', folder_path)
            save_as_pickle(graph2, f'graph_example_{sample:04}.pt', folder_path)

# Script entry point: click parses the command-line options and calls generate.
if __name__ == '__main__':
    generate()