
import torch
import torch.nn as nn
import numpy as np
from treelib import Node, Tree
import sys
from torch_geometric.data import Data

def sample_radius(epsilon, d):
    """Return ``epsilon * U ** (1 / (d - 1))`` with ``U ~ Uniform[0, 1)``.

    ``U ** (1/n)`` is the inverse-CDF radius of a point drawn uniformly
    from an n-dimensional ball. Here n = d - 1, so the result lies in
    ``[0, epsilon)`` for positive ``epsilon`` and ``d > 1``.

    Args:
        epsilon: Upper bound (scale) of the radius.
        d: Dimension parameter; the exponent is ``1 / (d - 1)``.

    Raises:
        ZeroDivisionError: If ``d == 1``. The random draw happens first.
    """
    # Draw before computing the exponent so the RNG is consumed in the same
    # order as a single-expression evaluation would.
    draw = np.random.uniform(0, 1)
    exponent = 1 / (d - 1)
    return epsilon * draw ** exponent


# Shared cosine-similarity module: compares along dim=1 with the library's
# default epsilon (both values spelled out explicitly here).
cos_sim = nn.CosineSimilarity(dim=1, eps=1e-8)


def predict_from_model(model_local, data_local, adj_normalized, device):
    """Run ``model_local`` in eval mode and return softmax probabilities.

    Args:
        model_local: Module called as ``model_local(x, adj_normalized)``.
            It is switched to eval mode as a side effect.
        data_local: Object exposing ``.to(device)`` whose result has an ``.x``
            attribute used as the model input.
        adj_normalized: Passed through unchanged as the second model argument.
        device: Target passed to ``data_local.to``.

    Returns:
        The model output after a softmax over the last dimension, computed
        under ``torch.no_grad()``.
    """
    model_local.eval()
    moved = data_local.to(device)

    with torch.no_grad():
        logits = model_local(moved.x, adj_normalized)
        probabilities = logits.softmax(dim=-1)

    return probabilities

def _perturbation_noise(radius, num_nodes, d):
    """Draw one random perturbation matrix of shape (num_nodes, d).

    One randomly chosen node gets radius ``sampled_distance`` (itself drawn
    with ``sample_radius(radius, d)``). Every other node gets a fresh radius
    drawn inside that distance. Each node's radius is spread over its ``d``
    features with random signs and random squared-mass weights that sum to 1.
    As a result, the Euclidean norm of row k equals node k's radius.
    """
    sampled_distance = sample_radius(radius, d)
    i_0 = np.random.randint(num_nodes)
    r = np.array(
        [[sample_radius(sampled_distance, d)] if k != i_0 else [sampled_distance]
         for k in range(num_nodes)]
    )
    r = np.repeat(r, d, axis=1)
    # Random +1 / -1 sign per entry.
    signs = 2 * np.random.randint(0, 2, num_nodes * d).reshape(num_nodes, d) - 1
    r = signs * r
    # Gaps between sorted uniforms padded with 0 and 1: non-negative, sum to 1
    # per row, so sqrt(mass) is a unit-norm direction magnitude per row.
    cuts = np.sort(np.random.uniform(0, 1, num_nodes * (d - 1)).reshape(num_nodes, d - 1))
    cuts = np.concatenate([np.zeros((num_nodes, 1)), cuts, np.ones((num_nodes, 1))], axis=1)
    mass = cuts[:, 1:] - cuts[:, :-1]
    return np.sqrt(mass) * r


def _make_child(model, base_x, root_flat, adj_normalized, radius, num_nodes, d, device):
    """Perturb ``base_x``, predict on it, and return the sample as a CPU ``Data``."""
    # Noise is created on base_x's device so the addition needs no transfer.
    noise = torch.tensor(_perturbation_noise(radius, num_nodes, d), device=base_x.device)
    new_data = Data(x=(base_x + noise).float())

    pred_ = predict_from_model(model, new_data, adj_normalized, device)
    new_data.prediction = pred_.cpu()
    new_data.x = new_data.x.cpu()
    # Similarity is measured against the ROOT input, not against the parent.
    new_data.similarity = cos_sim(new_data.x.flatten().unsqueeze(0), root_flat)
    return new_data


def Build_CRF_Tree(model, data, adj_normalized, radius, num_samples, num_iteration, device):
    """Build a treelib ``Tree`` of randomly perturbed copies of ``data.x``.

    The root (id ``"0"``) holds the unperturbed features and the model's
    prediction on them. The root gets ``num_samples`` perturbed children.
    Then, while ``tree.depth() <= num_iteration``, every node on the deepest
    level gets ``num_samples`` children perturbed from its own features.

    Each node's ``data`` carries:
      * ``x``: CPU features. Removed on the levels created by the expansion
        loop's final pass, presumably to save memory.
      * ``prediction``: CPU softmax output of ``predict_from_model``.
      * ``similarity`` (non-root only): cosine similarity between the
        node's flattened ``x`` and the root's flattened ``data.x``.

    Args:
        model: Called through ``predict_from_model(model, ..., adj_normalized, device)``.
        data: Object with ``.x`` and ``.to(device)``. ``x`` is assumed to be
            2-D ``(num_nodes, d)`` with ``d > 1`` (``sample_radius`` divides
            by ``d - 1``). NOTE(review): ``data.to`` may move ``data`` in
            place for PyG ``Data``; confirm callers expect that.
        adj_normalized: Forwarded unchanged to the model.
        radius: Maximum perturbation radius passed to ``sample_radius``.
        num_samples: Children generated per expanded node. If ``<= 0``, only
            the root is returned (previously this looped forever).
        num_iteration: Expansion stops once the tree depth exceeds this.
        device: Device the model runs on.

    Returns:
        The tree. Node ids are consecutive stringified integers.
    """
    num_nodes = data.x.size(0)
    d = data.x.size(-1)

    tree = Tree()
    data = data.to(device)
    # Flattened root features used by every similarity computation (hoisted).
    root_flat = data.x.cpu().flatten().unsqueeze(0)

    # Root node: unperturbed features and their prediction.
    root_data = Data(x=data.x.clone())
    pred_ = predict_from_model(model, root_data, adj_normalized, device)
    root_data.prediction = pred_.cpu()
    root_data.x = root_data.x.cpu()
    tree.create_node("0", "0", data=root_data)
    idx_graph = 1

    # Without samples no level is ever added and the loop below would not end.
    if num_samples <= 0:
        return tree

    # First level: perturb the (device-resident) root input.
    for _ in range(num_samples):
        child = _make_child(model, data.x, root_flat, adj_normalized, radius, num_nodes, d, device)
        tree.create_node(str(idx_graph), str(idx_graph), parent="0", data=child)
        idx_graph += 1

    # Deeper levels: expand every node on the current deepest level.
    while tree.depth() <= num_iteration:
        current_depth = tree.depth()
        # Decided once per level so every node on the final level is treated
        # alike (the old per-node check missed the first node of that level).
        is_last_level = current_depth == num_iteration
        leaf_nodes = list(tree.filter_nodes(lambda node: tree.depth(node) == current_depth))
        for lf_node in leaf_nodes:
            x_data = lf_node.data.x
            for _ in range(num_samples):
                child = _make_child(model, x_data, root_flat, adj_normalized, radius, num_nodes, d, device)
                if is_last_level:
                    del child.x
                tree.create_node(str(idx_graph), str(idx_graph), parent=lf_node.identifier, data=child)
                idx_graph += 1

    return tree

def CRF_inference(local_tree, local_idx, sigma=0.5):
    """Recursively smooth predictions from the leaves up to ``local_idx``.

    A leaf returns its stored ``prediction``. An inner node returns
    ``(sigma * p + sum_c w_c * y_c) / (sigma + sum_c w_c)``, where ``p`` is the
    node's own prediction, ``y_c`` is the recursive result for child ``c``,
    and ``w_c = (1 - sigma) * c.data.similarity``.

    Args:
        local_tree: treelib-style tree exposing ``children(nid)`` and
            ``tree[nid].data``.
        local_idx: Identifier of the node to evaluate.
        sigma: Weight of each node's own prediction. It is now applied at
            every level of the recursion; before, it applied at the top level
            only and deeper levels used 0.5.

    Returns:
        The smoothed prediction, with the same type and shape as the stored
        predictions.

    NOTE(review): ``similarity`` is a cosine similarity and can be negative,
    so the denominator is not guaranteed to be positive. Confirm this is
    acceptable.
    """
    node_data = local_tree[local_idx].data
    list_children = local_tree.children(local_idx)

    if not list_children:
        return node_data.prediction

    numerator = sigma * node_data.prediction
    denominator = sigma
    for child in list_children:
        # Address children by identifier; the tag is only a display label.
        child_output = CRF_inference(local_tree, child.identifier, sigma)
        weight = (1 - sigma) * child.data.similarity
        numerator = numerator + weight * child_output
        denominator = denominator + weight

    return numerator / denominator
