"""
Code to first generate a graph (a tree or a loopy Ising model) and then sample
from it; also provides dataset classes for loading previously saved graphs.
TODOs:
"""
import sys, os
import glob
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
import igraph as ig
from einops import repeat
import torch_geometric
import pickle

# Pick the import path based on the current working directory: when the
# process is started from inside a directory named ``belief_prop`` the sibling
# modules are imported directly, otherwise they are imported through the
# ``belief_prop`` package.
# NOTE(review): this keys off os.getcwd(), not this file's location, and splits
# on '/', so it presumably assumes POSIX-style paths -- confirm.
if os.getcwd().split('/')[-1] in ['belief_prop']:
    from factor_graphs import *
    from pgm import *
    from belief_prop import *
# sys.path.append(os.path.join(os.getcwd(), '..'))
else:
    from belief_prop.factor_graphs import *
    from belief_prop.pgm import *
    from belief_prop.belief_prop import *
# else:


def load_from_pickle(filename):
    """Return the object unpickled from the file at ``filename``.

    The file is opened in binary mode and is closed again before the
    function returns.
    """
    with open(filename, 'rb') as handle:
        return pickle.load(handle)

def get_combined_graph(g1, g2):
    """Join two graphs into one graph connected by a single bridge edge.

    The vertices of ``g2`` are renumbered to follow those of ``g1`` (their ids
    are shifted by ``g1.vcount()``).  One extra edge is added between a
    uniformly random vertex of ``g1`` and the first vertex of ``g2``.

    g1: first graph; must expose ``get_edgelist()`` and ``vcount()``.
    g2: second graph; same requirements.

    returns: a new ``ig.Graph`` with ``g1.vcount() + g2.vcount()`` vertices.
    """
    offset = g1.vcount()
    merged = ig.Graph(offset + g2.vcount())

    # Edges of g1 keep their ids; edges of g2 are shifted past g1's vertices.
    edges = list(g1.get_edgelist())
    edges.extend((src + offset, dst + offset) for src, dst in g2.get_edgelist())

    # Bridge: random vertex of g1 <-> vertex 0 of g2 (now at index `offset`).
    edges.append((np.random.randint(offset), offset))

    merged.add_edges(edges)
    return merged


def get_graph(num_nodes, num_children,
              get_loopy_graph=False,
              num_loops=1,
              max_tree=4,
              max_children_in_trees=2):
    """Build either a plain tree or a loopy graph made of rings with trees.

    The only loopy graph that we are going to work with is going to have one
    big ring and then it will have a tree attached to it.  With
    ``num_loops > 1`` several such ring-plus-tree pieces are chained together
    through single bridge edges.

    num_nodes: number of nodes of the tree (non-loopy case), total number of
        nodes (``num_loops == 1``), or the size of the first ring
        (``num_loops > 1``).
    num_children: ``children`` argument passed to ``ig.Graph.Tree``.
    get_loopy_graph: if False, return ``ig.Graph.Tree(num_nodes, num_children)``.
    num_loops: number of rings in the loopy graph; must be >= 1.
    max_tree: exclusive upper bound on the size of the small trees attached
        to each ring when ``num_loops > 1``.
    max_children_in_trees: ``children`` argument for those small trees.

    returns: an ``ig.Graph``.
    raises: ValueError if ``get_loopy_graph`` is set and ``num_loops < 1``.
    """
    if not get_loopy_graph:
        return ig.Graph.Tree(n=num_nodes, children=num_children)

    if num_loops == 1:
        # Between 50% and 80% of the nodes form the ring, the rest the tree.
        ring_size = np.random.randint(
            int(0.5 * num_nodes), int(0.8 * num_nodes))
        tree_size = num_nodes - ring_size
        ring_graph = ig.Graph.Ring(n=ring_size, circular=True)
        tree = ig.Graph.Tree(n=tree_size,
                             children=min(num_children, tree_size))
        final_graph = get_combined_graph(ring_graph, tree)
    elif num_loops > 1:
        # This draw is discarded (the first ring always uses all `num_nodes`),
        # but it is kept so that seeded runs consume the NumPy RNG stream
        # exactly as before.
        np.random.randint(int(0.5 * num_nodes), int(0.8 * num_nodes))
        ring_graph = ig.Graph.Ring(n=num_nodes, circular=True)
        tree = ig.Graph.Tree(n=np.random.randint(1, max_tree),
                             children=max_children_in_trees)
        final_graph = get_combined_graph(ring_graph, tree)
        for _ in range(num_loops - 1):
            # Every additional ring has a random size in [0.5, 0.8) * num_nodes.
            ring_size = np.random.randint(
                int(0.5 * num_nodes), int(0.8 * num_nodes))
            ring_graph = ig.Graph.Ring(n=ring_size, circular=True)
            tree = ig.Graph.Tree(n=np.random.randint(1, max_tree),
                                 children=max_children_in_trees)
            intermediate_graph = get_combined_graph(ring_graph, tree)
            final_graph = get_combined_graph(intermediate_graph, final_graph)
    else:
        # Previously the exception was constructed but never raised, so the
        # caller got an UnboundLocalError from the return statement instead.
        raise ValueError("Invalid num rings")
    return final_graph

# We use a function instead of a class because we only want to sample one graph at a time.
def IsingModel(Jst, singleton_mean, singleton_var,
               num_nodes=100, num_children=2, make_undirected=True,
               make_loopy=False, num_loops=1,
               return_tree=False, loopy_iters=None):
    """Sample one Ising model and label it with belief-propagation marginals.

    A graph is drawn with ``get_graph``; every edge gets the same pairwise
    potential ``exp([[J, -J], [-J, J]])`` and every node ``i`` gets a singleton
    potential ``exp([-h_i, h_i])`` with ``h_i`` drawn from a Gaussian.  The
    node marginals are then computed with (loopy) belief propagation.

    Jst: the J for the potential between two nodes.
    singleton_mean: Mean of the Gaussian distribution we sample the singleton
        potential from.
    singleton_var: Spread of the Gaussian distribution we sample the singleton
        potential from.  NOTE: despite the name it is passed to
        ``np.random.normal`` as the ``scale`` argument, i.e. it is used as a
        standard deviation.
    num_nodes: number of nodes in the tree (see ``get_graph`` for its meaning
        in the loopy case).
    num_children: maximum number of children in the tree.
    make_undirected: if True, ``edge_index`` is symmetrised with
        ``torch_geometric.utils.to_undirected``.
    make_loopy: build a loopy graph and use loopy belief propagation.
    num_loops: number of rings of the loopy graph.
    return_tree: if True, also return the underlying igraph graph.
    loopy_iters: number of loopy-BP iterations per node query; defaults to the
        number of vertices of the sampled graph.  Ignored for trees.

    returns: Pytorch Data with
        x: (num_vertices, 1) singleton parameters h_i,
        edge_index: (2, num_edges) long tensor,
        edge_attr: one 2x2 potential per sampled edge,
        y: (num_vertices, 1) first entry of each node's marginal
           distribution (the state whose singleton potential is exp(-h_i)),
        or the tuple ``(data, graph)`` when ``return_tree`` is set.
    """

    def ising_edge_potential():
        # Agreeing states get exp(J), disagreeing states get exp(-J).
        return np.exp(np.array([[Jst, -Jst],
                                [-Jst, Jst]]))

    tree = get_graph(
        num_nodes, num_children,
        get_loopy_graph=make_loopy, num_loops=num_loops)
    num_vertices = tree.vcount()
    if loopy_iters is None:
        # The sampled graph can have more vertices than `num_nodes` when
        # num_loops > 1, so default to its real size.
        loopy_iters = num_vertices

    tree_edge_list = tree.get_edgelist()
    res_factor_graph = factor_graph()

    # (num_edges, 2) -> (2, num_edges); reshape keeps the empty case 2-D.
    edge_index = torch.tensor(
        tree_edge_list, dtype=torch.long).reshape(-1, 2).t().contiguous()

    x = torch.zeros([num_vertices, 1])  # only store the node potential J
    y = torch.zeros([num_vertices, 1])  # store the final marginals.

    # One pairwise factor f<k> per edge.
    potential_arr_list = []
    for k, (src, dst) in enumerate(tree_edge_list):
        potential_arr = ising_edge_potential()
        potential_arr_list.append(potential_arr)
        res_factor_graph.add_factor_node(
            f'f{k}',
            factor([str(src), str(dst)], potential_arr)
        )
    if potential_arr_list:
        edge_attr = torch.Tensor(np.stack(potential_arr_list, axis=0))
    else:
        # np.stack refuses an empty list; a graph without edges has no rows.
        edge_attr = torch.zeros([0, 2, 2])

    # One singleton factor h<i> per node.
    for i in range(num_vertices):
        node_potential = np.random.normal(singleton_mean, singleton_var)
        x[i, 0] = node_potential
        res_factor_graph.add_factor_node(
            f'h{i}',
            factor([str(i)],
                   np.exp(np.array([-node_potential, node_potential])))
        )

    if make_loopy:
        bp = loopy_belief_propagation(res_factor_graph)
    else:
        bp = belief_propagation(res_factor_graph)

    for i in range(num_vertices):
        node = str(i)
        if make_loopy:
            marginal = bp.belief(node, num_iter=loopy_iters).get_distribution()
        else:
            marginal = bp.belief(node).get_distribution()
        y[i] = marginal[0]

    graph = Data(x=x,
                 edge_index=edge_index,
                 edge_attr=edge_attr,
                 y=y)
    if make_undirected:
        # NOTE(review): only edge_index is symmetrised; edge_attr keeps one row
        # per sampled edge.  All rows are identical (shared Jst), but the row
        # count presumably no longer matches edge_index.shape[1] -- confirm
        # that downstream code does not index edge_attr by edge.
        graph.edge_index = torch_geometric.utils.to_undirected(graph.edge_index)
    if return_tree:
        return graph, tree
    return graph

def get_expectation(y):
    """Turn a first-state probability into the expectation of a -1/+1 variable.

    ``y`` is weighted by -1 and the complementary probability ``1 - y`` by +1.
    Works element-wise on anything supporting arithmetic (floats, arrays,
    tensors).
    """
    negative_part = y * -1
    positive_part = 1 - y
    return negative_part + positive_part

def get_transductive_graph(Jst, singleton_mean, singleton_var,
                           num_nodes=100, num_children=2,
                           return_only_J=False,
                           train_val_test=np.array([0.7, 0.1, 0.2])):
    """Sample one tree Ising model and attach random train/val/test node masks.

    Jst, singleton_mean, singleton_var, num_nodes, num_children: forwarded to
        ``IsingModel`` (tree case, undirected edges).
    return_only_J: if True, keep ``x`` as the single singleton parameter per
        node.  Otherwise a one-column ``x`` is expanded to ``[-x, x]``.
    train_val_test: fractions of the nodes used for training, validation and
        testing; they must sum to 1.  The test split receives whatever is
        left after the (truncated) train and validation counts.

    returns: the sampled graph with ``y`` converted to an expectation via
        ``get_expectation`` and with boolean NumPy arrays ``train_mask``,
        ``val_mask`` and ``test_mask`` of length ``num_nodes`` attached.
    """

    def get_train_val_test_mask(train_val_test):

        def get_mask(indices):
            mask = np.zeros(num_nodes, dtype=bool)
            mask[indices] = True
            return mask

        # Shuffled with the stdlib RNG (not NumPy's), as before.
        order = np.arange(num_nodes)
        random.shuffle(order)
        train_end = int(train_val_test[0] * num_nodes)
        val_end = train_end + int(train_val_test[1] * num_nodes)
        return (get_mask(order[:train_end]),
                get_mask(order[train_end:val_end]),
                get_mask(order[val_end:]))

    # Compare with a tolerance: e.g. 0.6 + 0.3 + 0.1 is not exactly 1.0 in
    # floating point and would be rejected by an exact equality test.
    assert np.isclose(np.sum(train_val_test), 1.0), (
        "train val split percentages should sum to 1")
    graph = IsingModel(Jst, singleton_mean,
                       singleton_var, num_nodes, num_children)
    if return_only_J:
        # The original check asserted `shape[-1]` itself, which is true for
        # any non-zero width; the message makes clear that 1 is required.
        assert graph.x.shape[-1] == 1, "the dimensionality of x should be 1"
        graph.y = get_expectation(graph.y)
    elif graph.x.shape[-1] == 1:
        graph.x = torch.cat([-graph.x, graph.x], dim=-1)
        graph.y = get_expectation(graph.y)

    train_mask, val_mask, test_mask = get_train_val_test_mask(train_val_test)

    graph.train_mask = train_mask
    graph.val_mask = val_mask
    graph.test_mask = test_mask
    return graph



class InductiveFolderDataLoader(Dataset):
    """Dataset of pickled graphs stored as ``<file_prefix>_*.pt`` in a folder.

    The sorted file list is split by position: ``mode='test'`` serves the
    last ``num_samples`` files, any other mode serves the first
    ``num_samples`` files.  When ``edge_file_prefix`` is given, every item is
    a ``(graph, edgewise_graph)`` pair taken from the same position of the two
    sorted file lists.

    dataset_folder: directory containing the pickled graphs.
    num_samples: number of items exposed by this dataset.
    mode: ``'test'`` selects from the end of the file list, anything else
        from the start.
    file_prefix: prefix of the graph files.
    edge_file_prefix: optional prefix of the edge-wise graph files.
    turn_to_expectation: if True (default), ``graph.y`` is passed through
        ``get_expectation`` when an item is loaded.
    remove_all_edge_info: if True, the loaded graph's ``edge_index`` is
        replaced by self-loops only (see ``remove_edge_info``).
    """

    def __init__(self,
                 dataset_folder,
                 num_samples,
                 mode='train',
                 file_prefix='example',
                 edge_file_prefix=None,
                 turn_to_expectation=True,
                 remove_all_edge_info=False):
        self.dataset_folder = dataset_folder
        self.num_samples = num_samples
        self.file_prefix = file_prefix
        self.edge_file_prefix = edge_file_prefix
        self.mode = mode
        self.remove_all_edge_info = remove_all_edge_info
        self.turn_to_expectation = turn_to_expectation
        self.filenames, self.edge_filenames = self.get_filenames(
            self.dataset_folder)
        # NOTE(review): the messages mention one fourth / three fourths while
        # the checks use 0.2 / 0.8 -- confirm which one is intended.
        if self.mode == 'test':
            assert self.num_samples < 0.2 * len(self.filenames), (
                'make sure that the total testing samples are at max one fourth of the dataset')
        else:
            assert self.num_samples < 0.8 * len(self.filenames), (
                'make sure that the total training samples are at max three fourths of the dataset')

    def get_filenames(self, dataset_folder: str):
        """Return the sorted graph files and (optionally) edge-graph files.

        ``glob`` yields files in arbitrary order, so both lists are sorted to
        make the train/test split reproducible and to keep graph files and
        edge-wise files paired by position.
        """
        filenames = sorted(glob.glob(
            os.path.join(dataset_folder, f'{self.file_prefix}_*.pt')))
        if self.edge_file_prefix is None:
            return filenames, None
        edge_filenames = sorted(glob.glob(
            os.path.join(dataset_folder, f'{self.edge_file_prefix}_*.pt')))
        assert len(filenames) == len(edge_filenames), (
            "Number of graphs, and edgewise graphs do not match"
        )
        return filenames, edge_filenames

    def remove_edge_info(self, data):
        """Replace ``data.edge_index`` by one self-loop per node and return it."""
        node_ids = torch.arange(data.x.shape[0], dtype=torch.long)
        data.edge_index = node_ids.repeat(2, 1)  # shape (2, num_nodes)
        return data

    def get_expectation(self, y):
        """Map a first-state probability to the expectation of a -1/+1 variable."""
        return y * -1 + (1 - y)

    def _file_position(self, index):
        """Translate a dataset index into a position in the file lists.

        Negative indices count from the end of the dataset.  Raises IndexError
        outside ``[0, num_samples)`` so that plain iteration stops at the end
        of the split instead of running into the other split's files.
        """
        if index < 0:
            index += self.num_samples
        if not 0 <= index < self.num_samples:
            raise IndexError('dataset index out of range')
        if self.mode == 'test':
            # Count back from the last file.  The previous `-index` mapped
            # index 0 to the first file, which belongs to the training split.
            return len(self.filenames) - 1 - index
        return index

    def __getitem__(self, index):
        position = self._file_position(index)

        edgewise_graph = None
        if self.edge_file_prefix is not None:
            edgewise_graph = load_from_pickle(self.edge_filenames[position])

        graph = load_from_pickle(self.filenames[position])
        if self.turn_to_expectation:
            graph.y = self.get_expectation(graph.y)
        if self.remove_all_edge_info:
            graph = self.remove_edge_info(graph)

        if edgewise_graph is not None:
            return graph, edgewise_graph
        return graph

    def __len__(self):
        return self.num_samples


class InductiveFolderDataLoaderStarGraph(Dataset):
    """Dataset of pickled star graphs stored as ``<file_prefix>_*.pt`` files.

    The sorted file list is split by position: ``mode='test'`` serves the
    last ``num_samples`` files, any other mode serves the first
    ``num_samples`` files.  Loaded graphs are returned unchanged.  When
    ``edge_file_prefix`` is given, every item is a
    ``(graph, edgewise_graph)`` pair taken from the same position of the two
    sorted file lists.

    dataset_folder: directory containing the pickled graphs.
    num_samples: number of items exposed by this dataset.
    mode: ``'test'`` selects from the end of the file list, anything else
        from the start.
    file_prefix: prefix of the graph files.
    edge_file_prefix: optional prefix of the edge-wise graph files.
    turn_to_expectation: accepted for signature compatibility; not used here.
    remove_all_edge_info: stored on the instance; not used by this class.
    """

    def __init__(self,
                 dataset_folder,
                 num_samples,
                 mode='train',
                 file_prefix='example',
                 edge_file_prefix=None,
                 turn_to_expectation=True,
                 remove_all_edge_info=False):
        self.dataset_folder = dataset_folder
        self.num_samples = num_samples
        self.file_prefix = file_prefix
        self.edge_file_prefix = edge_file_prefix
        self.mode = mode
        self.remove_all_edge_info = remove_all_edge_info
        self.filenames, self.edge_filenames = self.get_filenames(
            self.dataset_folder)
        # NOTE(review): the messages mention one fourth / three fourths while
        # the checks use 0.2 / 0.8 -- confirm which one is intended.
        if self.mode == 'test':
            assert self.num_samples < 0.2 * len(self.filenames), (
                'make sure that the total testing samples are at max one fourth of the dataset')
        else:
            assert self.num_samples < 0.8 * len(self.filenames), (
                'make sure that the total training samples are at max three fourths of the dataset')

    def get_filenames(self, dataset_folder: str):
        """Return the sorted graph files and (optionally) edge-graph files.

        ``glob`` yields files in arbitrary order, so both lists are sorted to
        make the train/test split reproducible and to keep graph files and
        edge-wise files paired by position.
        """
        filenames = sorted(glob.glob(
            os.path.join(dataset_folder, f'{self.file_prefix}_*.pt')))
        if self.edge_file_prefix is None:
            return filenames, None
        edge_filenames = sorted(glob.glob(
            os.path.join(dataset_folder, f'{self.edge_file_prefix}_*.pt')))
        assert len(filenames) == len(edge_filenames), (
            "Number of graphs, and edgewise graphs do not match"
        )
        return filenames, edge_filenames

    def _file_position(self, index):
        """Translate a dataset index into a position in the file lists.

        Negative indices count from the end of the dataset.  Raises IndexError
        outside ``[0, num_samples)`` so that plain iteration stops at the end
        of the split instead of running into the other split's files.
        """
        if index < 0:
            index += self.num_samples
        if not 0 <= index < self.num_samples:
            raise IndexError('dataset index out of range')
        if self.mode == 'test':
            # Count back from the last file.  The previous `-index` mapped
            # index 0 to the first file, which belongs to the training split.
            return len(self.filenames) - 1 - index
        return index

    def __getitem__(self, index):
        position = self._file_position(index)

        edgewise_graph = None
        if self.edge_file_prefix is not None:
            edgewise_graph = load_from_pickle(self.edge_filenames[position])

        graph = load_from_pickle(self.filenames[position])

        if edgewise_graph is not None:
            return graph, edgewise_graph
        return graph

    def __len__(self):
        return self.num_samples

