import sys, os
import glob
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
import igraph as ig
from einops import repeat
import torch_geometric
import pickle

# Choose where the belief-propagation helpers are imported from, based on the
# last '/'-separated component of the current working directory.
if os.getcwd().split('/')[-1] in ['utils']:
    # Working directory is named 'utils': also make the parent directory importable.
    sys.path.append('..')
    from factor_graphs import *
    from pgm import *
    from belief_prop import belief_propagation
elif os.getcwd().split('/')[-1] in ['mamba', 'notebooks']:
    # Working directory is named 'mamba' or 'notebooks': go through the
    # belief_prop package.
    from belief_prop.factor_graphs import *
    from belief_prop.pgm import *
    from belief_prop.belief_prop import belief_propagation
else:
    # Any other working directory: use the flat module names.
    # NOTE(review): unlike the branches above, this one star-imports belief_prop
    # rather than importing belief_propagation by name -- confirm the name is
    # still exported this way.
    from factor_graphs import *
    from pgm import *
    from belief_prop import *

# from generate_flipflop import *

def get_edge_potential(J):
    """Return the 2x2 pairwise potential ``exp([[J, -J], [-J, J]])``.

    The diagonal entries (both variables in the same state) get ``exp(J)``
    and the off-diagonal entries get ``exp(-J)``.
    """
    same, different = J, -J
    log_potential = np.array([[same, different], [different, same]])
    return np.exp(log_potential)

def get_node_potential(J, label):
    """Return the length-2 unary potential ``exp([-J, J] * label)``.

    With ``label = +1`` the second state is favoured by ``exp(J)``; with
    ``label = -1`` the two entries swap.
    """
    base = np.array([-J, J])
    log_potential = base * label
    return np.exp(log_potential)


def get_J(p):
    """Return the coupling strength for probability ``p``: half the log-odds.

    Computes ``0.5 * log(p / (1 - p))`` using the natural logarithm.
    """
    odds = p / (1 - p)
    return 0.5 * np.log(odds)


def get_factor_graph(data, J):
    """Build a factor graph over the non-training nodes of ``data``.

    Every column ``(a, b)`` of ``data.edge_index`` is handled as follows:

    * both endpoints are training nodes: the edge is skipped;
    * exactly one endpoint is a training node: a unary factor named
      ``f<k>`` is attached to the *other* endpoint, with potential
      ``get_node_potential(J, label)`` where ``label`` is the training
      endpoint's entry of ``2 * data.y - 1``;
    * neither endpoint is a training node: a pairwise factor named ``h<k>``
      over both endpoints is added with potential ``get_edge_potential(J)``.

    Variables are named by the string form of their node index.

    Args:
        data: Graph object exposing ``train_mask``, ``y`` and ``edge_index``
            (presumably a ``torch_geometric.data.Data`` -- confirm against
            the caller).
        J: Coupling strength passed to the potential builders.

    Returns:
        Tuple ``(edge_names, factor_graph)`` where ``edge_names`` lists the
        names of the pairwise (``h``) factors in insertion order.
    """
    # Materialise the training indices as a set of plain ints so each
    # membership test below is O(1) instead of a scan over a NumPy array.
    # `== True` is an element-wise comparison on the mask and is kept as-is.
    train_indices = set(np.where(data.train_mask == True)[0].tolist())
    # Applies 2*y - 1 to the labels (a 0/1 label becomes -1/+1).
    y = (2 * data.y - 1).cpu().numpy()

    res_factor_graph = factor_graph()
    edge_names = []
    singleton = 0  # running counter for unary "f" factors
    edge = 0       # running counter for pairwise "h" factors
    for i in range(data.edge_index.shape[-1]):
        index_a = data.edge_index[0, i].item()
        index_b = data.edge_index[1, i].item()
        a_is_train = index_a in train_indices
        b_is_train = index_b in train_indices

        if a_is_train and b_is_train:
            # Both endpoints are training nodes: no factor is added.
            continue

        if a_is_train or b_is_train:
            # Exactly one training endpoint: fold its label into a unary
            # factor on the other endpoint.
            if a_is_train:
                train_idx, free_idx = index_a, index_b
            else:
                train_idx, free_idx = index_b, index_a
            node_name = f"f{singleton}"
            res_factor_graph.add_factor_node(
                node_name,
                factor([str(free_idx)], get_node_potential(J, y[train_idx])),
            )
            singleton += 1
        else:
            # Neither endpoint is a training node: ordinary pairwise factor.
            node_name = f"h{edge}"
            res_factor_graph.add_factor_node(
                node_name,
                factor([str(index_a), str(index_b)], get_edge_potential(J)),
            )
            edge += 1
            edge_names.append(node_name)
    return edge_names, res_factor_graph


def get_entropy_of_graph(data, p_flip):
    """Compute belief-propagation marginals and their mean entropy.

    A factor graph is built with coupling ``get_J(1 - p_flip)``, belief
    propagation is run on it, and the belief of every non-training node is
    stored in its row of ``marginal``.

    Args:
        data: Graph object exposing ``x``, ``train_mask`` and whatever
            ``get_factor_graph`` reads.
        p_flip: Flip probability; ``p = 1 - p_flip`` is passed to ``get_J``.

    Returns:
        Tuple ``(marginal, entropy)``. ``marginal`` has shape
        ``(data.x.shape[0], 2)``; rows of training nodes stay zero.
        ``entropy`` is ``-sum(p * log(p))`` (natural log) over the
        non-training rows, divided by the number of non-training nodes.
    """
    p = 1 - p_flip
    marginal = np.zeros((data.x.shape[0], 2))
    not_train_indices = np.where(data.train_mask == False)[0]
    J = get_J(p=p)

    # Only the factor graph is needed here; the edge-name list is ignored.
    _, res_factor_graph = get_factor_graph(data, J)

    bp = belief_propagation(res_factor_graph)

    for idx in not_train_indices:
        # Variables are keyed by the string form of the node index.
        marginal[idx] = bp.belief(f"{idx}").get_distribution()

    probs = marginal[not_train_indices]
    # Use the 0 * log(0) = 0 convention: a marginal that is exactly zero
    # would otherwise give 0 * -inf = nan and poison the whole average.
    safe_probs = np.where(probs == 0, 1.0, probs)
    test_entropy = probs * np.log(safe_probs)
    entropy = -1 * np.sum(test_entropy) / len(not_train_indices)
    return marginal, entropy

def load_from_pickle(filename):
    """Open ``filename`` in binary mode and return the unpickled object.

    NOTE: unpickling executes arbitrary code from the file; only use this on
    files from a trusted source.
    """
    with open(filename, 'rb') as handle:
        return pickle.load(handle)

if __name__=='__main__':
    # Script entry point: for each (total edges, flip probability) setting,
    # load the first example of the matching dataset and print its
    # belief-propagation entropy.
    data_folder_path = "/home/user/data/graph_datasets/flipflop"
    # Other settings used previously: total_edges = [10000, 55000], p_flips = [0.3].
    total_edges = [10000]
    p_flips = [0.1, 0.2, 0.3, 0.7, 0.8, 0.9]
    for te in total_edges:
        for p_flip in p_flips:
            dataset_dir = f"flipflop_128_ns_10pf_{p_flip}_TE_{te}"
            example_path = os.path.join(
                data_folder_path, dataset_dir, "example_0000.pt"
            )
            graph = load_from_pickle(example_path)
            # Only the entropy (second return value) is reported.
            entropy = get_entropy_of_graph(graph, p_flip)[1]
            print(f"P Flip: {p_flip}, Entropy: {entropy}")