import sys, os
sys.path.append('/home/user/projects/graph_repositories/graphs')
sys.path.append('/home/user/projects/graph_repositories/graphs/junction-tree')
import glob
import random
import numpy as np
import torch
import torch_geometric
from torch.utils.data import Dataset
from torch_geometric.data import Data
import igraph as ig
from einops import repeat
import torch_geometric
import igraph as ig
import networkx as nx
from train_utils import *
from utils import save_as_pickle
import click

# from  belief_prop.generate_random_inductive import get_random_graph
from  generate_random_inductive import * 

import junctiontree as jt
from tests.util import assert_potentials_equal
from junctiontree import computation as comp
import copy


def ising_node_potential(i, singleton_mean, singleton_var, num_dim=1):
    """
    Draw a random potential vector for a single node.

    :param i: Node index. Accepted for the caller's convenience; the sampling
        itself does not use it.
    :param singleton_mean: Mean of the normal distribution sampled from.
    :param singleton_var: Forwarded to ``np.random.normal`` as its ``scale``
        argument, i.e. it acts as the standard deviation despite its name.
    :param num_dim: Number of values to draw.
    :return: ``torch.FloatTensor`` holding the ``num_dim`` sampled values.
    """
    draw = np.random.normal(
        loc=singleton_mean,
        scale=singleton_var,
        size=num_dim,
    )
    potential = torch.FloatTensor(draw)
    return potential

def get_node_potential(num_nodes, singleton_mean=0, singleton_var=1, num_dim=1):
    """
    Sample a random potential vector for every node.

    :param num_nodes: Number of nodes (rows of the result).
    :param singleton_mean: Mean passed to ``ising_node_potential``.
    :param singleton_var: Spread passed to ``ising_node_potential``.
    :param num_dim: Length of each node's potential vector (columns).
    :return: Tensor of shape ``(num_nodes, num_dim)``; row ``k`` holds the
        potential sampled for node ``k``.
    """
    potentials = torch.zeros([num_nodes, num_dim])
    # Rows are filled one node at a time, in index order, so the sequence of
    # random draws is the same as sampling each node separately.
    for node in range(num_nodes):
        sampled = ising_node_potential(
            node,
            singleton_mean=singleton_mean,
            singleton_var=singleton_var,
            num_dim=num_dim,
        )
        potentials[node, :] = sampled
    return potentials

def generate_star(num_nodes):
    """
    Build the edge index of a star graph whose hub is node 0.

    Nodes ``1 .. num_nodes - 1`` are the leaves; one edge ``(0, leaf)`` is
    emitted per leaf (hub -> leaf direction only).

    :param num_nodes: Total number of nodes, hub included.
    :return: Integer (``torch.long``) tensor of shape ``(2, num_nodes - 1)``.
        Row 0 is all zeros (the hub), row 1 lists the leaf indices.
    """
    # Both rows are created as int64 so the stacked result is a proper index
    # tensor rather than a float tensor produced by mixing dtypes.
    leaves = torch.arange(1, num_nodes, dtype=torch.long)
    center_node = torch.zeros(num_nodes - 1, dtype=torch.long)
    edge_index = torch.stack([center_node, leaves], dim=0)
    return edge_index

def get_final_label(node_potential, weight_matrix, num_nodes, num_dim):
    """
    Compute the target label of every node from the node potentials.

    For each node ``k >= 1`` the label is
    ``tanh(weight_matrix @ (sum of all potentials - potential of node k))``.
    Node 0 is never assigned, so its label stays ``tanh(0) = 0``.

    :param node_potential: Tensor with one potential row per node; columns
        must match ``weight_matrix``.
    :param weight_matrix: Square ``(num_dim, num_dim)`` matrix applied to the
        aggregated potentials.
    :param num_nodes: Number of nodes (rows of the result).
    :param num_dim: Label dimensionality (columns of the result).
    :return: Float tensor of shape ``(num_nodes, num_dim)``.
    """
    y = torch.zeros(num_nodes, num_dim).type(torch.FloatTensor)
    # The sum over all nodes does not depend on the node being labelled, so
    # it is computed once instead of once per node.
    total = torch.sum(node_potential, dim=0)
    # Row k-1 holds the aggregate of every potential except node k's own.
    others = total - node_potential[1:num_nodes, :]
    # (W @ s) for a column vector s equals (s^T @ W^T) in row form, which
    # lets a single matmul handle all nodes at once.
    y[1:, :] = others @ weight_matrix.t()
    return torch.tanh(y)

def get_graph(num_nodes, 
              weight_matrix, 
              singleton_mean=0, 
              singleton_var=1, 
              num_dim=1):
    """
    Assemble one random star-graph sample.

    :param num_nodes: Total number of nodes, hub included.
    :param weight_matrix: Matrix passed to ``get_final_label`` to derive the
        labels from the node potentials.
    :param singleton_mean: Mean used when sampling node potentials.
    :param singleton_var: Spread used when sampling node potentials.
    :param num_dim: Dimensionality of each node's potential and label.
    :return: ``Data`` object with ``x`` (node potentials), ``y`` (labels) and
        ``edge_index`` (star edges as ``torch.long``).
    """
    # generate_star already returns a tensor; convert the dtype directly
    # rather than re-wrapping it with torch.tensor(), which copies the data
    # and emits a copy-construct UserWarning for every sample.
    edge_index = generate_star(num_nodes).to(torch.long)
    x = get_node_potential(
        num_nodes, 
        singleton_mean=singleton_mean, 
        singleton_var=singleton_var, 
        num_dim=num_dim)
    y = get_final_label(x, weight_matrix, num_nodes, num_dim)
    return Data(x=x, y=y, edge_index=edge_index)


def _format_option_for_path(value):
    """Render a numeric option for the folder name; whole numbers print without a trailing ``.0``."""
    number = float(value)
    if number.is_integer():
        return str(int(number))
    return str(value)


def _make_split_masks(num_items):
    """Randomly partition ``range(num_items)`` into 70/20/10 train/test/val boolean masks."""
    indices = np.random.permutation(np.arange(num_items))
    train_end = int(0.7 * len(indices))
    test_end = int(0.90 * len(indices))
    masks = []
    for chosen in (indices[:train_end], indices[train_end:test_end], indices[test_end:]):
        mask = np.zeros(num_items, dtype=bool)
        mask[chosen] = True
        masks.append(torch.from_numpy(mask))
    return masks


@click.command()
@click.option('--num_samples', default=1)
@click.option('--num_nodes', default=20)
@click.option('--num_dim', default=5)
@click.option('--singleton_mean', default=0, type=float)
@click.option('--singleton_var', default=1, type=float)
@click.option('--seed', required=True, type=int)
@click.option('--create_train_test_split', default=False)
@click.option('--create-data/--no-create-data', default=True)
@click.option('--folder-name', default='star_graph')
def generate(
    num_samples, num_nodes, num_dim,
    singleton_mean, singleton_var, seed,
    create_train_test_split, create_data, folder_name):
    """
    Generate random star-graph samples and optionally save them to disk.

    :param num_samples: Number of graphs to generate.
    :param num_nodes: Nodes per graph, hub included.
    :param num_dim: Dimensionality of node potentials and labels.
    :param singleton_mean: Mean used when sampling node potentials.
    :param singleton_var: Spread used when sampling node potentials.
    :param seed: Seed applied to both NumPy and PyTorch.
    :param create_train_test_split: When true, attach random 70/20/10
        train/test/val node masks to each graph.
    :param create_data: When true, create the output folder and save every
        sample into it.
    :param folder_name: Sub-folder of the dataset root to write into.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    def create_folder():
        dataset_path = f'/home/user/data/graph_datasets/{folder_name}'
        # Mean/var are parsed as floats; format them so whole numbers keep
        # the same folder names as before (e.g. ``sm_0_sv_1``).
        dataset_foldername = (
            f'ns_{num_samples}_'
            f'num_nodes_{num_nodes}_'
            f'num_dim_{num_dim}_'
            f'sm_{_format_option_for_path(singleton_mean)}_'
            f'sv_{_format_option_for_path(singleton_var)}_'
            f'seed_{seed}'
        )
        print(dataset_foldername)
        folder_path = os.path.join(dataset_path, dataset_foldername)
        os.makedirs(folder_path, exist_ok=True)
        print(folder_path)
        return folder_path

    # A single weight matrix is shared by every sample of the dataset.
    weight_matrix = torch.randn(num_dim, num_dim)
    if create_data:
        folder_path = create_folder()

    for sample in range(num_samples):
        print(sample)
        graph = get_graph(
            num_nodes, 
            weight_matrix=weight_matrix, 
            singleton_mean=singleton_mean, 
            singleton_var=singleton_var,
            num_dim=num_dim)

        undirected_edge = torch_geometric.utils.to_undirected(graph.edge_index)
        # The edgewise helpers come from the project's star imports and must
        # see the graph before its edge_index is replaced below.
        # NOTE(review): their exact output format is not visible here.
        graph2 = get_edgewise_graph(graph, to_undirected=True)
        graph3 = get_undirected_edgewise_graph(graph, to_undirected=True)
        graph2.undirected_edge_index = graph3.edge_index
        # Keep the original one-direction edge list under
        # ``undirected_edge_index`` and store the symmetrised list as
        # ``edge_index``.
        graph.undirected_edge_index = graph.edge_index
        graph.edge_index = undirected_edge
        if create_train_test_split:
            train_mask, test_mask, val_mask = _make_split_masks(graph.x.shape[0])
            graph.train_mask = train_mask
            graph.test_mask = test_mask
            graph.val_mask = val_mask

        if create_data:
            name = f'example_{sample:04}.pt'
            graph_name = f'graph_example_{sample:04}.pt'
            save_as_pickle(graph, name, folder_path)
            save_as_pickle(graph2, graph_name, folder_path)

# Run the click command only when this file is executed as a script.
if __name__ == '__main__':
    generate()