import networkx as nx
from collections import deque


def build_graph(graph: list, undirected=False) -> nx.MultiDiGraph | nx.MultiGraph:
    """Build a networkx multigraph from (head, relation, tail) triplets.

    Every triplet adds two edges: ``head -> tail`` labelled with the relation,
    and ``tail -> head`` labelled ``"inverse of: <relation>"`` so paths can be
    followed in either direction. Entity and relation strings are stripped of
    surrounding whitespace.

    Args:
        graph: Iterable of ``(head, relation, tail)`` string triplets.
        undirected: If True, build an ``nx.MultiGraph``. The forward and inverse
            edges then become two parallel edges between the same pair of nodes.
            Otherwise build an ``nx.MultiDiGraph``.

    Returns:
        The populated multigraph. Parallel edges are kept, so ``G[u][v]`` maps
        edge keys to attribute dicts that contain a ``"relation"`` key.
    """
    G = nx.MultiGraph() if undirected else nx.MultiDiGraph()
    for h, r, t in graph:
        head, rel, tail = h.strip(), r.strip(), t.strip()
        G.add_edge(head, tail, relation=rel)
        G.add_edge(tail, head, relation="inverse of: " + rel)
    return G


def dfs(graph, start_node_list, max_length):
    """
    Enumerate all relation paths of 1..max_length edges starting from the given nodes.

    Paths are explored depth-first over every outgoing (parallel) edge. There is no
    visited set, so nodes may be revisited. For example, a path can walk back along
    an inverse edge.

    Args:
        graph (nx.MultiDiGraph | nx.MultiGraph): Graph whose ``graph[u][v]`` maps edge
            keys to attribute dicts that contain a ``"relation"`` key (as built by
            ``build_graph``).
        start_node_list (List[str]): Start nodes. Nodes not in ``graph`` are skipped.
        max_length (int): Maximum number of edges in a path.

    Returns:
        List[tuple]: Unique paths, each a tuple of ``(node, relation, neighbor)``
        triples. Order is unspecified because paths are collected in a set.
    """
    path_lists = set()

    def dfs_visit(node, path):
        # A path that already has max_length edges cannot be extended. Stopping here
        # avoids expanding (and then discarding) one extra level of neighbours.
        if len(path) >= max_length:
            return
        for neighbor in graph.neighbors(node):
            for neigh_rel in graph[node][neighbor].values():
                new_path = path + [(node, neigh_rel["relation"], neighbor)]
                # len(new_path) <= max_length holds because of the guard above.
                path_lists.add(tuple(new_path))
                dfs_visit(neighbor, new_path)

    for start_node in start_node_list:
        # graph.neighbors() fails for unknown nodes. Skip them explicitly instead of
        # catching and printing every exception.
        if start_node not in graph:
            continue
        dfs_visit(start_node, [])

    return list(path_lists)


def bfs_with_rule(graph, start_node, target_rule, max_p=10):
    """Find all paths from ``start_node`` whose relation sequence equals ``target_rule``.

    Breadth-first search that only follows an edge when its relation matches the
    rule entry at the current depth, so branches that do not match are pruned early.

    Args:
        graph: Multigraph whose ``graph[u][v]`` maps edge keys to attribute dicts
            that contain a ``"relation"`` key (as built by ``build_graph``).
        start_node: Node to start from. If it is not in ``graph``, no path is
            extended from it.
        target_rule: Sequence of relation names the path must follow, in order.
        max_p: Currently unused. Kept for API compatibility; the result cap it was
            meant to enforce is disabled.

    Returns:
        Paths in BFS order, each a list of ``(node, relation, neighbor)`` triples
        of length ``len(target_rule)``. An empty ``target_rule`` yields ``[[]]``.
    """
    rule_len = len(target_rule)
    result_paths = []
    # Queue of (node to expand, path that led to it).
    queue = deque([(start_node, [])])
    while queue:
        current_node, current_path = queue.popleft()
        depth = len(current_path)

        # The path has followed the whole rule: record it and do not extend it.
        if depth == rule_len:
            result_paths.append(current_path)
            continue

        # From here on depth < rule_len, because only depth + 1 paths are enqueued.
        if current_node not in graph:
            continue
        expected_rel = target_rule[depth]
        for neighbor in graph.neighbors(current_node):
            # Prune: only follow parallel edges whose relation matches this rule step.
            for neigh_rel in graph[current_node][neighbor].values():
                rel = neigh_rel["relation"]
                if rel != expected_rel:
                    continue
                queue.append((neighbor, current_path + [(current_node, rel, neighbor)]))

    return result_paths


def get_truth_paths(q_entity: list, a_entity: list, graph: nx.Graph) -> list:
    """
    Get all shortest paths connecting question and answer entities.

    For every (question entity, answer entity) pair present in ``graph``, each
    node-level shortest path is expanded into one relation path for every
    combination of parallel edges along it.

    Args:
        q_entity: Question (source) entities. Entities missing from ``graph`` are skipped.
        a_entity: Answer (target) entities. Entities missing from ``graph`` are skipped.
        graph: Multigraph (e.g. from ``build_graph``) whose ``graph[u][v]`` maps
            edge keys to attribute dicts that contain a ``"relation"`` key.

    Returns:
        List of paths, each a list of ``[u, relation, v]`` lists. A pair with no
        connecting path contributes nothing. So does a zero-length path, which
        occurs when a question entity is also an answer entity.
    """
    # Collect node-level shortest paths.
    paths = []
    for h in q_entity:
        if h not in graph:
            continue
        for t in a_entity:
            if t not in graph:
                continue
            try:
                for p in nx.all_shortest_paths(graph, h, t):
                    paths.append(p)
            except nx.NetworkXNoPath:
                # h and t are not connected; nothing to add for this pair.
                pass

    # Expand each node path into relation paths: one per choice of parallel edge
    # at every hop. Ordering matches the original nested loops (edge-major,
    # previous-prefix-minor).
    result_paths = []
    for path in paths:
        if len(path) < 2:
            continue
        expanded = [[]]
        for u, v in zip(path, path[1:]):
            expanded = [
                prefix + [[u, connect["relation"], v]]
                for connect in graph[u][v].values()
                for prefix in expanded
            ]
        result_paths.extend(expanded)
    return result_paths


def get_simple_paths(q_entity: list, a_entity: list, graph: nx.Graph, hop=2) -> list:
    """
    Get all simple paths connecting question and answer entities within ``hop`` edges.

    Args:
        q_entity: Question (source) entities. Entities missing from ``graph`` are skipped.
        a_entity: Answer (target) entities. Entities missing from ``graph`` are skipped.
        graph: networkx graph whose edges carry a ``"relation"`` attribute. Both
            simple graphs and multigraphs (e.g. from ``build_graph``) are supported.
        hop: Maximum number of edges per path (the ``cutoff`` of
            ``nx.all_simple_edge_paths``).

    Returns:
        List of paths, each a list of ``(u, relation, v)`` tuples.
    """
    paths = []
    for h in q_entity:
        if h not in graph:
            continue
        for t in a_entity:
            if t not in graph:
                continue
            try:
                for p in nx.all_simple_edge_paths(graph, h, t, cutoff=hop):
                    paths.append(p)
            except nx.NetworkXException:
                # Best effort: skip pairs networkx cannot enumerate paths for.
                pass
    # On multigraphs, all_simple_edge_paths yields (u, v, key) edges and
    # graph[u][v] is keyed by edge key, so graph[u][v]["relation"] would fail.
    # graph.edges[e] accepts both (u, v) and (u, v, key) and returns that edge's
    # own attributes.
    return [[(e[0], graph.edges[e]["relation"], e[1]) for e in p] for p in paths]


def get_negative_paths(
    q_entity: list, a_entity: list, graph: nx.Graph, n_neg: int, hop=2
) -> list:
    """
    Sample negative paths for a question using random walks within ``hop``.

    Random walks start from the question entities present in ``graph``. Walks
    that end on an answer entity are discarded.

    Args:
        q_entity: Question entities used as walk start nodes.
        a_entity: Answer entities. Walks ending on one of them are dropped.
        graph: networkx graph whose edges carry a ``"relation"`` attribute.
            Multigraphs are supported: the relation of the first parallel edge
            between two consecutive walk nodes is used.
        n_neg: Passed as ``n_walks`` to ``walker.random_walks``.
        hop: Passed as ``walk_len`` to ``walker.random_walks``.

    Returns:
        List of paths, each a list of ``(u, relation, v)`` tuples. Empty if no
        question entity is in ``graph``.
    """
    import walker

    # walker identifies nodes by their position in graph.nodes(). Build a
    # node -> index map once instead of calling list.index (O(n)) per entity.
    node_idx = list(graph.nodes())
    position = {node: i for i, node in enumerate(node_idx)}
    start_nodes = [position[h] for h in q_entity if h in graph]
    end_nodes = {position[t] for t in a_entity if t in graph}
    if not start_nodes:
        # Nothing to walk from; avoid passing an empty start list to walker.
        return []

    def relation(u, v):
        edge_data = graph[u][v]
        if graph.is_multigraph():
            # {key: attrs}. The walk does not say which parallel edge was taken,
            # so use the first one.
            edge_data = next(iter(edge_data.values()))
        return edge_data["relation"]

    walks = walker.random_walks(
        graph, n_walks=n_neg, walk_len=hop, start_nodes=start_nodes, verbose=False
    )
    result_paths = []
    for walk in walks:
        # Negative samples must not reach an answer entity.
        if walk[-1] in end_nodes:
            continue
        nodes = [node_idx[i] for i in walk]
        result_paths.append([(u, relation(u, v), v) for u, v in zip(nodes, nodes[1:])])
    return result_paths


def get_random_paths(q_entity: list, graph: nx.Graph, n=3, hop=2) -> tuple[list, list]:
    """
    Sample random relation paths starting from the question entities.

    Args:
        q_entity: Question entities used as walk start nodes. Entities missing
            from ``graph`` are ignored.
        graph: networkx graph whose edges carry a ``"relation"`` attribute.
            Multigraphs are supported: the relation of the first parallel edge
            between two consecutive walk nodes is used.
        n: Passed as ``n_walks`` to ``walker.random_walks``.
        hop: Passed as ``walk_len`` to ``walker.random_walks``.

    Returns:
        ``(paths, rules)``. ``paths[i]`` is a list of ``(u, relation, v)`` tuples
        and ``rules[i]`` is the matching list of relation names. Both are empty
        if no question entity is in ``graph``.
    """
    import walker

    # walker identifies nodes by their position in graph.nodes(). Build a
    # node -> index map once instead of calling list.index (O(n)) per entity.
    node_idx = list(graph.nodes())
    position = {node: i for i, node in enumerate(node_idx)}
    start_nodes = [position[h] for h in q_entity if h in graph]
    if not start_nodes:
        # Nothing to walk from; avoid passing an empty start list to walker.
        return [], []

    def relation(u, v):
        edge_data = graph[u][v]
        if graph.is_multigraph():
            # {key: attrs}. The walk does not say which parallel edge was taken,
            # so use the first one.
            edge_data = next(iter(edge_data.values()))
        return edge_data["relation"]

    walks = walker.random_walks(
        graph, n_walks=n, walk_len=hop, start_nodes=start_nodes, verbose=False
    )
    result_paths = []
    rules = []
    for walk in walks:
        nodes = [node_idx[i] for i in walk]
        # Look up each hop's relation once and reuse it for both outputs.
        triples = [(u, relation(u, v), v) for u, v in zip(nodes, nodes[1:])]
        result_paths.append(triples)
        rules.append([rel for _, rel, _ in triples])
    return result_paths, rules
