import json
import os

from .utils.constants import INPUT_PATH, DIR_PATH
from .build_origin_graph import AbstractGraphBuilder, RelatedworkGraphBuilder, MergeGraph
from .extend_graph import GraphExtender

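# Absolute-import alternatives to the relative imports above, e.g. for running
# this file directly instead of as a module of its package:
# from utils.constants import INPUT_PATH, DIR_PATH
# from build_origin_graph import AbstractGraphBuilder, RelatedworkGraphBuilder, MergeGraph
# from extend_graph import GraphExtender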


# Demo directory: visualizer() writes knowledge_graph.html under
# FILE_PATH / "visualization", and temp2() reads entity.csv / relation.csv from here.
FILE_PATH = DIR_PATH / "demo"

def load_inputs(input_path: "str | os.PathLike[str]") -> tuple:
    """Load the paper description consumed by the graph builders.

    Args:
        input_path: Path (a ``str`` or path-like object such as
            ``pathlib.Path``) to a JSON file whose top-level object has the
            keys ``"paper title"``, ``"abstract"``, ``"related work"`` and
            ``"reference"``.

    Returns:
        A 4-tuple ``(title, abstract, related_work, reference)`` holding the
        values stored under those keys, in that order.

    Raises:
        KeyError: if the JSON object lacks any of the four keys.
    """
    # JSON text is UTF-8 by specification (RFC 8259), and paper text routinely
    # contains non-ASCII characters; don't depend on the locale's default encoding.
    with open(input_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data['paper title'], data['abstract'], data['related work'], data['reference']

def visualizer(entity, relation, output_path=None):
    """Render entities and relations as an interactive, directed pyvis HTML graph.

    Args:
        entity: Iterable of dicts, one per node. ``"entity name"`` is required
            and is used as both node id and label; ``"entity type"``,
            ``"timestamp"`` and ``"description"`` are optional and are stored
            as node attributes (shown in the node tooltip).
        relation: Iterable of dicts, one per edge, with keys ``"entity1"``
            (source), ``"entity2"`` (target) and ``"relation"`` (edge tooltip).
        output_path: Destination of the HTML file. Defaults to
            ``FILE_PATH / "visualization" / "knowledge_graph.html"``. Missing
            parent directories are created.
    """
    import networkx as nx
    from pathlib import Path
    from pyvis.network import Network

    # Node fill colour per entity type; unknown or missing types are drawn gray.
    color_map = {
        "problem": "#264653",
        "paper": "#2A9D8F",
        "method": "#E9C46A",
        "domain": "#F4A261",
    }

    # Use a DiGraph so every relation keeps its entity1 -> entity2 direction.
    # An undirected nx.Graph can report an edge as (entity2, entity1) depending
    # on node insertion order, which would draw the arrow backwards below, and
    # it would merge a->b and b->a relations into a single edge.
    G = nx.DiGraph()

    for e in entity:
        G.add_node(e["entity name"],
                   entity_type=e.get("entity type"),
                   timestamp=e.get("timestamp"),
                   description=e.get("description"))

    for r in relation:
        # Endpoints missing from `entity` are created implicitly, without attributes.
        G.add_edge(r["entity1"], r["entity2"], relation=r["relation"])

    net = Network(height="750px", width="100%", directed=True)

    for node, data in G.nodes(data=True):
        net.add_node(node, label=node, title=str(data),
                     color=color_map.get(data.get("entity_type"), "gray"))

    for source, target, data in G.edges(data=True):
        net.add_edge(source, target, title=data['relation'], arrowStrikethrough=False)

    net.toggle_physics(True)
    net.show_buttons(filter_=['physics'])

    if output_path is None:
        output_path = FILE_PATH / "visualization" / "knowledge_graph.html"
    output_path = Path(output_path)
    # write_html does not create directories; ensure the target folder exists.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    net.write_html(str(output_path))

def test_openai_connection():
    """Smoke-test the OpenAI wrapper: send one short prompt and print the reply.

    The prompt and the raw reply are printed between two separator lines.
    """
    from .utils.openai import openai_call

    question = "Are you ready?"
    answer = openai_call([{"role": "user", "content": question}])

    for line in (
        "--------------------------test----------------------------",
        question,
        answer,
        "----------------------------------------------------------",
    ):
        print(line)

def main():
    """Run the full pipeline: build, merge, extend and visualize the paper graph.

    Reads ``raw_content.json`` from ``INPUT_PATH``, builds one graph from the
    abstract and one from the related-work section, then, only when both
    builds report success, merges and smooths them, extends the result with
    ``GraphExtender`` and renders it with ``visualizer``. If either build
    fails, a message naming the failed part(s) is printed and nothing is
    rendered.
    """
    title, abstract, related_work, reference = load_inputs(INPUT_PATH / "raw_content.json")
    abs_graph = AbstractGraphBuilder(title=title, abstract=abstract)
    success_build_graph_abs = abs_graph.build()
    rw_graph = RelatedworkGraphBuilder(title=title, related_work=related_work, reference=reference)
    success_build_graph_rw = rw_graph.build()

    # TODO: when the graph can only be partly built... to be robust
    if not (success_build_graph_abs and success_build_graph_rw):
        # Say which builder failed so the user knows why no visualization was written.
        failed = [
            part
            for part, ok in (("abstract", success_build_graph_abs),
                             ("related work", success_build_graph_rw))
            if not ok
        ]
        print(f"Graph construction failed for: {', '.join(failed)}; "
              "skipping merge, extension and visualization.")
        return

    graph = MergeGraph(abs_graph, rw_graph)
    graph.smooth()
    extender = GraphExtender(graph.entity, graph.relation, title)
    extender.extend()

    visualizer(entity=extender.entity.to_dict(orient='records'),
               relation=extender.relation.to_dict(orient='records'))
        
def temp(csv_dir='/root/mypaperTKG/PaperTKG/build_graph/outputs/tempfile'):
    """Re-run merge, extension and visualization from saved intermediate CSVs.

    Debug helper: both graph builders are constructed but ``build()`` is not
    called; their ``entity`` / ``relation`` tables are loaded from CSV files
    instead.

    Args:
        csv_dir: Directory holding ``abstract_entity_final.csv``,
            ``abstract_relation_final.csv``, ``relatedwork_entity_final.csv``
            and ``relatedwork_relation_final.csv``. The default is the
            original machine-specific location; pass another directory to run
            elsewhere.
    """
    import pandas as pd
    from pathlib import Path

    csv_dir = Path(csv_dir)
    title, abstract, related_work, reference = load_inputs(INPUT_PATH / "raw_content.json")

    abs_graph = AbstractGraphBuilder(title=title, abstract=abstract)
    abs_graph.entity = pd.read_csv(csv_dir / 'abstract_entity_final.csv')
    abs_graph.relation = pd.read_csv(csv_dir / 'abstract_relation_final.csv')

    rw_graph = RelatedworkGraphBuilder(title=title, related_work=related_work, reference=reference)
    rw_graph.entity = pd.read_csv(csv_dir / 'relatedwork_entity_final.csv')
    rw_graph.relation = pd.read_csv(csv_dir / 'relatedwork_relation_final.csv')

    graph = MergeGraph(abs_graph, rw_graph)
    graph.smooth()
    extender = GraphExtender(graph.entity, graph.relation, title)
    extender.extend()

    visualizer(entity=extender.entity.to_dict(orient='records'),
               relation=extender.relation.to_dict(orient='records'))
        
def temp2():
    """Re-render the graph from ``entity.csv`` and ``relation.csv`` under FILE_PATH.

    Each CSV is turned into a list of row dicts and passed to ``visualizer``.
    """
    import pandas as pd

    def _rows(csv_name):
        # One dict per CSV row, keyed by column header.
        return pd.read_csv(FILE_PATH / csv_name).to_dict(orient='records')

    visualizer(entity=_rows("entity.csv"), relation=_rows("relation.csv"))
    
        
if __name__ == '__main__':
    # Send one prompt through the OpenAI wrapper and print the reply before
    # starting the full pipeline.
    test_openai_connection()
    main()
    # Debug alternatives: temp() re-runs merge/extend/visualize from saved
    # intermediate CSVs; temp2() only re-renders entity.csv / relation.csv from FILE_PATH.
    # temp()
    # temp2()
    
