import torch
import numpy as np
import numba as nb
import random
from entropy.partitionTree import PartitionTree, PartitionTreeNode
from models.utils import cosine_sim
def get_community(code_tree: PartitionTree):
    """Collect community node ids and a leaf mask by walking the partition tree.

    The tree is walked depth-first from ``code_tree.root_id``, visiting the
    children of each node in their stored order.

    Args:
        code_tree: Tree exposing ``tree_node`` (id -> node mapping) and
            ``root_id``. Each node is read for ``children``, ``parent`` and
            ``child_h``.

    Returns:
        A ``(community, isleaf)`` tuple. ``community`` lists, in visiting
        order, every node that has children plus every childless node whose
        parent has ``child_h > 1``. ``isleaf`` is a boolean tensor of length
        ``max(node id) + 1`` that is True for each childless node reached.
    """
    nodes = code_tree.tree_node
    leaf_mask = torch.zeros(max(nodes) + 1, dtype=torch.bool)
    found = []
    pending = [code_tree.root_id]
    while pending:
        current = pending.pop()
        kids = nodes[current].children
        if kids is not None:
            # Unlike get_layer_community, internal nodes are recorded
            # whatever their child_h is.
            found.append(current)
            # Push reversed so the first child is the next one popped.
            pending.extend(reversed(kids))
            continue
        leaf_mask[current] = True
        if nodes[nodes[current].parent].child_h > 1:
            found.append(current)
    return found, leaf_mask

def get_layer_community(code_tree: PartitionTree):
    """Return the ids of the lowest-level communities of the partition tree.

    The tree is walked depth-first from ``code_tree.root_id``. A node with
    children is kept only when its ``child_h`` equals 1; a childless node is
    kept only when its parent's ``child_h`` is greater than 1.

    Args:
        code_tree: Tree exposing ``tree_node`` (id -> node mapping) and
            ``root_id``.

    Returns:
        List of the kept node ids, in visiting order.
    """
    nodes = code_tree.tree_node
    selected = []
    pending = [code_tree.root_id]
    while pending:
        current = pending.pop()
        entry = nodes[current]
        kids = entry.children
        if kids is None:
            if nodes[entry.parent].child_h > 1:
                selected.append(current)
            continue
        if entry.child_h == 1:
            selected.append(current)
        # Push reversed so the first child is the next one popped.
        pending.extend(reversed(kids))
    return selected

def get_sedict(community: list, code_tree: PartitionTree):
    """Build per-community sampling weights over each community's children.

    For every child ``c`` of a community the raw score is
    ``-(c.g / code_tree.VOL) * log2((c.vol + 1) / (parent.vol + 1))``, where
    ``parent`` is the node referenced by ``c.parent``. The scores of one
    community are turned into a probability vector with a softmax.

    Args:
        community: Node ids to process; ids whose node has no children are
            skipped.
        code_tree: Tree exposing ``tree_node`` (id -> node mapping) and
            ``VOL``.

    Returns:
        Dict mapping each processed community id to a float32 tensor of
        weights, ordered like that community's ``children``.
    """
    nodes = code_tree.tree_node
    weights = {}
    for cid in community:
        members = nodes[cid].children
        if members is None:
            continue
        scores = torch.zeros(len(members))
        for slot, member_id in enumerate(members):
            member: PartitionTreeNode = nodes[member_id]
            # +1 on both volumes keeps the ratio finite for zero volumes.
            vol_ratio = (member.vol + 1) / (nodes[member.parent].vol + 1)
            scores[slot] = -(member.g / code_tree.VOL) * torch.log2(torch.tensor(vol_ratio))
        weights[cid] = torch.softmax(scores.float(), dim=0)
    return weights

# node sampling
def select_node1(node_id: int, code_tree: PartitionTree, isleaf: torch.Tensor, se_dict):
    """Descend from ``node_id`` to a leaf, sampling a child at each branching node.

    Args:
        node_id: Id of the node to start from.
        code_tree: Tree exposing ``tree_node`` (id -> node mapping).
        isleaf: Boolean tensor indexed by node id; the descent stops at the
            first id for which it is True.
        se_dict: Maps a node id to the sampling weights of its children; it is
            only consulted for nodes with more than one child.

    Returns:
        The id of the leaf that was reached (``node_id`` itself when it is
        already a leaf).
    """
    nodes = code_tree.tree_node
    current = node_id
    while not isleaf[current]:
        options = list(nodes[current].children)
        if len(options) <= 1:
            # Nothing to choose from: follow the only child.
            current = options[0]
            continue
        drawn = torch.multinomial(se_dict[current], num_samples=1, replacement=False)
        current = options[drawn]
    return current

# edge sampling
def _blended_similarity(node, anchor, llm_node_x, org_node_x, alpha):
    """Blend both cosine similarities of ``node`` and ``anchor``: (llm + alpha * org) / (1 + alpha)."""
    llm_sim = cosine_sim(llm_node_x[node], llm_node_x[anchor])
    org_sim = cosine_sim(org_node_x[node], org_node_x[anchor])
    return (llm_sim + alpha * org_sim) / (1 + alpha)


def _rank_sibling_communities(node1, parent_id, node_dict, llm_node_x, org_node_x, com_x, alpha):
    """Return the ids of the parent's sibling communities, most similar to ``node1`` first."""
    grandparent_id = node_dict[parent_id].parent
    if grandparent_id is None:
        # The parent is the root, so there is no sibling community to borrow from.
        return []
    community_list = list(node_dict[grandparent_id].children)
    if parent_id in community_list:
        community_list.remove(parent_id)

    scores = []
    for com in community_list:
        sim1 = cosine_sim(com_x[com][0], llm_node_x[node1])
        sim2 = cosine_sim(com_x[com][1], org_node_x[node1])
        scores.append((sim1 + alpha * sim2) / (1 + alpha))

    # argsort yields positions inside community_list, not node ids, so every
    # position has to be mapped back to the community id it stands for.
    return [community_list[pos] for pos in np.argsort(scores)[::-1]]


def select_node2(node1: int, code_tree: PartitionTree, llm_node_x, org_node_x, com_x, extend_t):
    """Sample a partner node for ``node1`` in proportion to feature similarity.

    On the first call for a given ``node1`` the candidate list is built and
    cached on the node (attributes ``node_list`` and ``sim_list``); later calls
    only sample from the cache.

    Candidates start as the other members of the parent's ``partition``. When
    that partition holds fewer than ``extend_t`` nodes, the partitions of the
    parent's sibling communities are appended, most similar community first,
    until at least ``extend_t`` candidates are gathered or the siblings run
    out. Only candidates whose blended similarity to ``node1`` exceeds 0.75
    are kept, and their similarities are softmax-normalised into sampling
    weights.

    Args:
        node1: Id of the node that needs a partner.
        code_tree: Tree exposing ``tree_node`` (id -> node mapping).
        llm_node_x: Per-node features, indexed by node id, compared against
            ``com_x[com][0]``.
        org_node_x: Per-node features, indexed by node id, compared against
            ``com_x[com][1]``.
        com_x: Per-community pair of features, indexed by community id.
        extend_t: Minimum candidate count aimed for before filtering.

    Returns:
        The id of the sampled partner, or ``None`` when no candidate passes
        the similarity threshold.
    """
    node_dict = code_tree.tree_node
    anchor = node_dict[node1]
    if anchor.sim_list is None:
        alpha = 0.85  # weight of the org-feature similarity in the blend
        sim_threshold = 0.75
        parent_id = anchor.parent
        partition = node_dict[parent_id].partition
        candidates = list(partition)
        candidates.remove(node1)

        # The size test deliberately uses the full partition (node1 included).
        if len(partition) < extend_t:
            ranked = _rank_sibling_communities(
                node1, parent_id, node_dict, llm_node_x, org_node_x, com_x, alpha
            )
            for com in ranked:
                if len(candidates) >= extend_t:
                    break
                candidates.extend(node_dict[com].partition)

        node_list = []
        kept_sims = []
        for node in candidates:
            sim = _blended_similarity(node, node1, llm_node_x, org_node_x, alpha)
            if sim > sim_threshold:
                node_list.append(node)
                kept_sims.append(sim)

        sim_list = torch.softmax(torch.tensor(kept_sims).float(), dim=0)
        sim_list[torch.isinf(sim_list)] = 0.
        sim_list[torch.isnan(sim_list)] = 0.
        anchor.node_list = node_list
        anchor.sim_list = sim_list
    else:
        node_list = anchor.node_list
        sim_list = anchor.sim_list

    if len(sim_list) == 0:
        return None
    picked = torch.multinomial(sim_list, num_samples=1, replacement=False)
    return node_list[picked.item()]

# reconstruct the graph via sampling
def llm_reshape(community, code_tree, isleaf, theta, llm_node_x, org_node_x, com_x, extend_t):
    """Sample a new edge list from the partition tree.

    Only communities whose ``child_h`` is at most 1 are used. For each of them
    one source node is drawn per member (the single member itself when there
    is only one, otherwise via ``select_node1``), and for every source node
    ``theta`` partner draws are made with ``select_node2``; draws that return
    ``None`` are dropped.

    Args:
        community: Community node ids to sample from.
        code_tree: Tree exposing ``tree_node`` (id -> node mapping).
        isleaf: Boolean leaf mask forwarded to ``select_node1``.
        theta: Number of partner draws per source node.
        llm_node_x, org_node_x, com_x, extend_t: Forwarded unchanged to
            ``select_node2``.

    Returns:
        ``torch.tensor`` built from the list of ``[source, partner]`` pairs
        (an empty 1-D tensor when no pair was produced).
    """
    nodes = code_tree.tree_node
    se_dict = get_sedict(community, code_tree)

    pairs = []
    for cid in community:
        com_node = nodes[cid]
        if com_node.child_h > 1:
            continue
        members = com_node.children if com_node.children is not None else com_node.partition
        single_member = len(members) == 1
        for _draw in range(len(members)):
            if single_member:
                source = members[0]
            else:
                source = select_node1(cid, code_tree, isleaf, se_dict)
            for _attempt in range(theta):
                partner = select_node2(source, code_tree, llm_node_x, org_node_x, com_x, extend_t)
                if partner is not None:
                    pairs.append([source, partner])
    return torch.tensor(pairs)