
import torch
import torch.nn as nn
import numpy as np
from treelib import Node, Tree
import sys
import scipy
from torch_geometric.data import Data
import random
from itertools import combinations
import scipy.special
from torch_geometric.utils.convert import to_scipy_sparse_matrix, from_scipy_sparse_matrix
import sys


def sample_radius(epsilon, d):
    """Draw a random radius scaled by ``epsilon``.

    A single ``u ~ U(0, 1)`` is taken from NumPy's global RNG and the result
    is ``epsilon * u ** (1 / (d - 1))``. For ``d > 1`` and ``epsilon >= 0``
    the value lies in ``[0, epsilon)``. ``d == 1`` raises ZeroDivisionError
    (after the uniform draw has already consumed RNG state).
    """
    u = np.random.uniform(0, 1)
    shape_exponent = 1 / (d - 1)
    return epsilon * (u ** shape_exponent)


# Module-level cosine-similarity operator; dim=1 and eps=1e-8 are spelled out
# explicitly (they are torch's defaults). Not used by the functions visible here.
cos_sim = nn.CosineSimilarity(dim=1, eps=1e-8)


def predict_from_model(model_local, data_local, adj_normalized, device):
    """Run ``model_local`` in eval mode and return class probabilities.

    Args:
        model_local: callable module invoked as ``model_local(x, adj)``; it is
            switched to eval mode as a side effect.
        data_local: object with ``.to(device)`` whose result exposes ``.x``.
        adj_normalized: second model input, passed through unchanged.
            NOTE(review): it is not moved to ``device`` here; presumably the
            caller already placed it there — confirm.
        device: target passed to ``data_local.to``.

    Returns:
        Softmax of the model output over the last dimension, computed with
        gradient tracking disabled.
    """
    model_local.eval()
    moved = data_local.to(device)

    with torch.no_grad():
        logits = model_local(moved.x, adj_normalized)
        probabilities = torch.softmax(logits, dim=-1)

    return probabilities

def _perturb_adjacency(parent_adj, all_possible_pairs, distances, distance_weights):
    """Return a copy of ``parent_adj`` with a random number of entries flipped, plus that number."""
    sampled_distance = random.choices(population=distances, weights=distance_weights)[0]
    pairs_to_change = random.sample(all_possible_pairs, sampled_distance)
    new_adj = parent_adj.copy()
    for i, j in pairs_to_change:
        # NOTE(review): only adj[i, j] with i < j is flipped, never adj[j, i];
        # if adj is meant to be symmetric the perturbed graph is not — confirm.
        new_adj[i, j] = 1 - new_adj[i, j]
    return new_adj, sampled_distance


def _make_tree_sample(classifier, features, adj):
    """Wrap features, adjacency and the softmaxed classifier prediction (moved to CPU) in a Data object."""
    sample = Data()
    sample.features = features
    sample.adj = adj
    logits = classifier.predict(features, adj)
    sample.prediction = torch.softmax(logits, -1).cpu()
    return sample


def Build_CRF_Tree_Structure(classifier, features, adj, labels, idx_train,idx_val,idx_test, radius, num_samples, num_iteration, device):
    """Build a tree of randomly perturbed adjacency matrices and their predictions.

    The root holds the prediction on an unmodified copy of ``adj``. Every
    expanded node receives ``num_samples`` children. Each child flips ``k``
    randomly chosen entries ``adj[i, j]`` (``i < j``) of its parent's
    adjacency. ``k`` is drawn from ``0..radius`` (inclusive) with probability
    proportional to ``C(radius, k)``, and that normalised probability is
    stored as the child's ``similarity``. The root is always expanded; after
    that, the deepest level is expanded ``num_iteration`` more times.

    Args:
        classifier: object exposing ``predict(features, adj)`` that returns
            logits; softmax is taken over the last dimension.
        features: node features, passed unchanged to the classifier.
        adj: adjacency; must support ``.copy()`` and ``adj[i, j]`` get/set
            (presumably a numpy array or scipy lil/dok matrix — TODO confirm).
        labels: only ``labels.size(0)`` is used, as the number of nodes.
        idx_train, idx_val, idx_test, device: unused here; kept for interface
            compatibility.
        radius: maximum number of flipped entries per sample.
        num_samples: number of children per expanded node. If it is <= 0,
            only the root is returned.
        num_iteration: number of extra expansion rounds after the root.

    Returns:
        A ``treelib.Tree`` with node ids/tags ``"0"``, ``"1"``, ... in creation
        order. Each node's ``data`` is a ``Data`` carrying ``features``,
        ``adj`` and ``prediction`` (on CPU); non-root nodes also carry
        ``similarity``.

    Raises:
        ValueError: from ``random.sample`` if a sampled distance exceeds the
            number of node pairs.
    """
    n_nodes = labels.size(0)

    # Loop invariants, computed once rather than once per sample.
    # A Hamming ball of radius r includes distance r itself, hence radius + 1.
    distances = list(range(radius + 1))
    distance_weights = [scipy.special.comb(radius, k) for k in distances]
    z_partition = np.sum(distance_weights)
    similarity_by_distance = [w / z_partition for w in distance_weights]
    all_possible_pairs = list(combinations(range(n_nodes), 2))

    tree = Tree()
    idx_graph = 0

    root_id = str(idx_graph)
    root_adj = adj.copy()
    tree.create_node(root_id, root_id, data=_make_tree_sample(classifier, features, root_adj))
    idx_graph += 1

    # Expand level by level, in creation order. An explicit frontier replaces
    # the depth-based leaf scan. It also terminates when num_samples <= 0
    # (the depth-based loop never did).
    n_rounds = 1 + max(0, num_iteration)
    frontier = [(root_id, root_adj)]
    for _round in range(n_rounds):
        next_frontier = []
        for parent_id, parent_adj in frontier:
            for _sample in range(num_samples):
                new_adj, sampled_distance = _perturb_adjacency(
                    parent_adj, all_possible_pairs, distances, distance_weights
                )
                new_data = _make_tree_sample(classifier, features, new_adj)
                new_data.similarity = similarity_by_distance[sampled_distance]

                node_id = str(idx_graph)
                tree.create_node(node_id, node_id, parent=parent_id, data=new_data)
                next_frontier.append((node_id, new_adj))
                idx_graph += 1
        frontier = next_frontier

    return tree

def CRF_inference_Structure(local_tree, local_idx, sigma=0.5):
    """Recursively smooth the prediction stored at ``local_idx`` with its subtree.

    A leaf returns its own ``data.prediction``. An internal node returns::

        (sigma * own + sum_c (1 - sigma) * g_c * smoothed_c)
        / (sigma + sum_c (1 - sigma) * g_c)

    ``g_c`` is child ``c``'s ``data.similarity``, and ``smoothed_c`` is this
    function applied to ``c`` with the same ``sigma``.

    Args:
        local_tree: tree exposing ``children(nid)`` and ``tree[nid].data``
            (a treelib ``Tree`` as built by ``Build_CRF_Tree_Structure``).
        local_idx: identifier of the node to evaluate.
        sigma: weight of a node's own prediction relative to its children.

    Returns:
        The smoothed prediction (same type as the stored predictions).
    """
    children = local_tree.children(local_idx)
    prediction = local_tree[local_idx].data.prediction

    if not children:
        return prediction

    numerator = sigma * prediction
    denominator = sigma
    for child in children:
        # Address children by identifier (not tag) and propagate sigma;
        # previously descendants silently fell back to sigma=0.5.
        child_output = CRF_inference_Structure(local_tree, child.identifier, sigma=sigma)
        weight = (1 - sigma) * child.data.similarity
        numerator = numerator + weight * child_output
        denominator = denominator + weight

    return numerator / denominator
