from typing import Dict, List
from pathlib import Path
import traceback
import json, os
import pandas as pd
import time
from .utils.process_pdf import process_pdf
from .build_graph.paper2graph import Paper2Graph, Graph
from .build_graph.extend import extend, step_into_rw_mentioned
from .build_graph.merge_graphs import connect
from .utils.logger import get_logger
from concurrent.futures import ProcessPoolExecutor
import string

def convert_to_lowercase(stri):
    """Normalize a value into a compact lowercase key.

    The value is converted with ``str()``, lowercased, and then stripped of
    every ASCII punctuation character (``string.punctuation``) and every
    plain space character. Other whitespace (tabs, newlines) is kept.

    Args:
        stri: Any object; it is stringified with ``str()`` first.

    Returns:
        The normalized string.
    """
    # A single deletion table covers punctuation and the space character.
    removal_table = str.maketrans('', '', string.punctuation + " ")
    lowered = str(stri).lower()
    return lowered.translate(removal_table)


def resume_graph(resume_path):
    """Rebuild a ``Graph`` from a resume JSON file.

    The JSON file is read for the keys ``title``, ``year``, ``topic``,
    ``entity`` and ``relation``. The values under ``entity`` and ``relation``
    are passed to ``pandas.read_csv`` and the resulting frames become the
    graph's entity and relation tables.

    Args:
        resume_path: Path of the resume JSON file.

    Returns:
        The ``Graph`` built from the stored metadata and the two CSV tables.
    """
    with open(resume_path, 'r') as resume_file:
        info = json.load(resume_file)

    # Missing keys yield None via dict.get and are forwarded unchanged.
    return Graph(
        entity=pd.read_csv(info.get('entity')),
        relation=pd.read_csv(info.get('relation')),
        title=info.get('title'),
        year=info.get('year'),
        topic=info.get('topic'),
    )
   
def graph4extend(resume_path: str, dir_path: str):
    """Resume a saved graph, extend it with new papers, and persist the result.

    The graph is restored with ``resume_graph`` and extended through
    ``Graph.extend`` using ``<dir_path>/extension`` as the PDF directory.
    The graph's entity and relation tables are then written to
    ``entity.csv`` / ``relation.csv`` next to the resume file.

    Args:
        resume_path: Path of the resume JSON file.
        dir_path: Base working directory; a ``str`` or a ``pathlib.Path``.

    Returns:
        A tuple ``(for_step_into_details, init_graph)``: the second value
        returned by ``Graph.extend`` and the extended graph itself.

    Raises:
        AssertionError: If ``Graph.extend`` reports no new papers.
    """
    old_path = Path(resume_path).parent

    init_graph = resume_graph(resume_path=resume_path)
    # dir_path is annotated as str, and `str / str` raises TypeError, so
    # coerce it; Path(Path) is a no-op for callers that already pass a Path.
    extension_dir = Path(dir_path) / "extension"
    new_papers, for_step_into_details = init_graph.extend(
        pdf_path=extension_dir, isretry=False, temp_save=True
    )
    # Explicit raise instead of `assert`, so the check also runs under
    # `python -O`; the exception type callers may catch stays the same.
    if not new_papers:
        raise AssertionError("Failed to download any new paper, please retry extend...")

    # updated entity and relation if new papers added
    init_graph.entity.to_csv(old_path / "entity.csv", index=False)
    init_graph.relation.to_csv(old_path / "relation.csv", index=False)
    return for_step_into_details, init_graph


#region Deprecated main
# def main(data=None, dir_path=None, isresume=False):
#     if isresume:
#         papers4stepinto, basegraph = graph4extend(
#             resume_path="/root/mypaperTKG/PaperTKG/test1/source_paper/output/Chain-of-Thought Prompting Elicits Reasoning in Large Language Models/final_result/resume.json",
#             dir_path=dir_path
#         )
#     else:
#         papers4stepinto, basegraph = build_origin_graph(source_paper_info=data, dir_path=dir_path)
#     with open('/root/mypaperTKG/PaperTKG/test1/extension/papers4stepinto.json', 'w') as f:
#         json.dump(papers4stepinto, f, indent=4)
#     add_list = []
#     os.makedirs(dir_path/"subgraphs", exist_ok=True)
#     for ind, paper in enumerate(papers4stepinto):
#         title = paper['title']
#         if not paper.get('abstract'):
#             continue
#         try:
#             pdf_info = process_pdf(dir_path/"extension"/f"{title}.pdf")
#         except Exception as e:
#             print(e)
#             continue
#         temp_path = dir_path/"subgraphs"/title
#         os.makedirs(temp_path, exist_ok=True)
#         with open(temp_path/"pdf_info.json", 'w') as f:
#             json.dump(pdf_info, f, indent=4)
#         if not (pdf_info.get('related work') and pdf_info.get('reference')):
#             print(f"Exception occurred when extract necessary information from PDF, requires human intervention. Tempfiles saved at {temp_path}")
#         else:
#             print(f"Successfully extract necessary information from PDF. Build graph for {title}")
#             # keys = ['title', 'abstract', 'date', 'isAPA', 'reference', 'related work']
#             subgraph = Paper2Graph(
#                 basic_info={
#                     "title": title,
#                     "abstract": paper['abstract'],
#                     "date": paper['year'],
#                     "isAPA": pdf_info['is_APA'],
#                     "reference": pdf_info['reference'],
#                     "related work": pdf_info['related work'],
#                     "topic": basegraph.topic
#                 },
#                 temp_data_path=dir_path/"subgraphs"
#             )
#             try:
#                 subgraph.extract(temp_save=True, isretry=False)
#                 subgraph.download(vis=False)
#                 add_list.append(title)
#             except Exception as e:
#                 print(f"Exception occurred when extracting subgraph of paper titled *{title}*: {e}")
#                 traceback.print_exc()
#                 continue
#             try:
#                 basegraph.entity = pd.concat([basegraph.entity, subgraph.entity], ignore_index=True).drop_duplicates(subset='entity name')
#                 basegraph.relation = pd.concat([basegraph.relation, subgraph.relation], ignore_index=True).drop_duplicates()
#                 basegraph.entity.to_csv(dir_path/"final_graph"/"entity.csv", index=False)
#                 basegraph.relation.to_csv(dir_path/"final_graph"/"relation.csv", index=False)
#                 print(ind)
#             except Exception as e:
#                 print(f"Cannot merge the graphs due to: {e}, requires manually merge the graphs.")
#                 traceback.print_exc()
#     print(f"successfully added the following subgraphs: {json.dumps(add_list, indent=4)}")
#endregion


def deep_search(source_paper_info: Dict, dir_path, topic, search_list, all_paper, query_list):
    """Look for follow-up papers of one paper and record them in the shared lists.

    A graph is extracted for ``source_paper_info``; its topic and problems are
    turned into normalized ``"{topic} {problem}"`` queries, ``extend`` is asked
    for papers, the hits are appended to ``search_list`` and their normalized
    titles are recorded in ``all_paper``.

    Args:
        source_paper_info: Paper record; must contain the keys 'title',
            'abstract', 'date', 'isAPA', 'reference' and 'related work'.
        dir_path: Base directory; the graph is built under
            ``dir_path / "source_paper" / "output"``, so it must support ``/``.
        topic: Topic stored in the paper info before extraction.
        search_list: Accumulated paper records; extended in place.
        all_paper: Normalized titles already seen; extended in place.
        query_list: Normalized queries; extended in place.

    Returns:
        The same ``search_list`` object that was passed in. A list is returned
        on every path, including after a failure.

    Raises:
        KeyError: If ``source_paper_info`` lacks one of the required keys
            (this lookup happens before the guarded section).
    """
    keys = ['title', 'abstract', 'date', 'isAPA', 'reference', 'related work']
    paper_info = {k: source_paper_info[k] for k in keys}
    paper_info['topic'] = topic

    # NOTE(review): a pass only repeats when the paper's year is < 2023 and
    # `extend` found something; the second pass then rebuilds the graph and
    # queries again. Failures return immediately, so this is not a retry loop.
    # Confirm the intent of `range(2)`.
    for _ in range(2):
        try:
            # Stop expanding once enough papers have been collected.
            if len(search_list) > 200:
                return search_list

            _init_graph = Paper2Graph(paper_info, dir_path / "source_paper" / "output")
            _init_graph.extract(temp_save=True, isretry=False)
            _init_graph.download(vis=False)  # visualization is temporarily deprecated... remains update

            # try to extend the graph
            init_graph = Graph(
                _init_graph.entity,
                _init_graph.relation,
                _init_graph.title,
                _init_graph.timepoint,
                _init_graph.topic,
            )
            topic = init_graph.topic
            problem_list = init_graph.problem
            year = init_graph.timepoint
            query_list += [convert_to_lowercase(f"{topic} {problem}") for problem in problem_list]

            # Papers before 2023 search from the following year, 2023 papers
            # search from 2023 itself, anything else is not extended.
            if year < 2023:
                start_year = year + 1
            elif year == 2023:
                start_year = year
            else:
                return search_list

            _new_paper_details = extend(
                topic=topic,
                problem=problem_list,
                start_year=start_year,
                search_list=all_paper,
                query_list=query_list,
            )
            search_list += _new_paper_details
            for paper in _new_paper_details:
                normalized_title = convert_to_lowercase(paper["title"])
                if normalized_title not in all_paper:
                    all_paper.append(normalized_title)

            if not (year + 1 <= 2023 and _new_paper_details):
                return search_list

            # NOTE(review): every title of `_new_paper_details` was recorded in
            # `all_paper` just above, so the guard below skips each entry and
            # the recursive call is currently never reached. Enabling it would
            # multiply the number of graph extractions; decide deliberately.
            for paper in _new_paper_details:
                normalized_title = convert_to_lowercase(paper["title"])
                if normalized_title in all_paper:
                    continue
                all_paper.append(normalized_title)
                # The callee extends `search_list` in place and returns that
                # same list, so its result must not be added to it again.
                deep_search(paper, dir_path, topic, search_list, all_paper, query_list=query_list)
        except Exception:
            # Best effort: report the failure and keep what was collected.
            # KeyboardInterrupt / SystemExit are deliberately not swallowed.
            print("fail to deep search")
            traceback.print_exc()
            return search_list

    # Both passes completed without an early return: still hand back the list
    # so callers can always use the result as a sequence.
    return search_list
        
def main(source_paper_info: Dict, dir_path) -> List[str]:
    """Build the knowledge graph for one source paper and save it under ``dir_path``.

    Steps, in order:
      1. extract the initial graph from the source paper;
      2. ask ``extend`` for related papers and expand each of them with
         ``deep_search``;
      3. connect the collected papers to the initial graph (``connect``);
      4. collect the papers mentioned in related work
         (``step_into_rw_mentioned``);
      5. build a subgraph for every collected paper and append its entities
         and relations to the result;
      6. write ``entity.csv`` / ``relation.csv`` to ``dir_path / "final_graph"``;
         intermediate files go to ``final_graph/intermediate_result``.

    Args:
        source_paper_info: Dict with the keys 'topic' and 'source_paper'; the
            latter must contain 'title', 'abstract', 'date', 'isAPA',
            'reference' and 'related work'.
        dir_path: Base output directory; must support the ``/`` operator
            (e.g. ``pathlib.Path``).

    Returns:
        Nothing. NOTE(review): the signature is annotated ``List[str]`` but no
        value is returned; the annotation is kept to leave the signature
        untouched.

    Raises:
        KeyError: If a required key is missing from ``source_paper_info``.
    """
    # paths
    all_paper = []    # normalized titles of every paper seen so far
    search_list = []  # paper records accumulated by deep_search
    query_list = []   # normalized "{topic} {problem}" queries
    save_path = dir_path / "final_graph"
    os.makedirs(save_path, exist_ok=True)
    ir_dir = save_path / "intermediate_result"
    os.makedirs(ir_dir, exist_ok=True)
    # utils
    start_time = time.time()
    logger = get_logger("run_test_incontextlearning", dir_path)

    topic = source_paper_info['topic']
    keys = ['title', 'abstract', 'date', 'isAPA', 'reference', 'related work']
    _paper_info = source_paper_info['source_paper']
    paper_info = {k: _paper_info[k] for k in keys}
    paper_info['topic'] = topic

    logger.info(f"Extract initial graph for the topic *{topic}*")
    _init_graph = Paper2Graph(paper_info, dir_path / "source_paper" / "output")
    _init_graph.extract(temp_save=True, isretry=False)
    _init_graph.download(vis=False)  # visualization is temporarily deprecated... remains update
    checkp1 = time.time()
    logger.info(f"Successfully built initial graph. Detailed information:\n{json.dumps(_init_graph.info, indent=4)}")
    logger.info(f"It takes {checkp1-start_time}s to build the initial graph.")

    # try to extend the graph
    logger.info("Extend the initial graph but adding more paper nodes.")
    init_graph = Graph(
        _init_graph.entity,
        _init_graph.relation,
        _init_graph.title,
        _init_graph.timepoint,
        _init_graph.topic,
    )
    topic = init_graph.topic
    problem_list = init_graph.problem
    year = init_graph.timepoint
    all_paper.append(convert_to_lowercase(paper_info["title"]))
    query_list += [convert_to_lowercase(f"{topic} {problem}") for problem in problem_list]
    # MEMO: you can replace with your search strategy, please put the search method in build_graph/extend.py
    #
    # the input can be topic and problem list, the output should be a list including units like this
    # unit = {
    #     "title": title,
    #     "date": the publication year of the paper,
    #     "isAPA": True if the citations of the paper is in APA format
    #     "abstract": abstract,
    #     "related work": related work content (string type),
    #     "reference": a list including all the references,
    #     "keywords": a list including keywords (i.e. topic or problems)
    # },
    # which means you should process the PDF or html file
    #
    # Current search strategy using both topic and problem as keywords
    _new_paper_details = extend(
        topic=topic,
        problem=problem_list,
        start_year=year,
        search_list=all_paper,
        query_list=query_list,
    )
    checkp2 = time.time()
    logger.info(f"get {len(_new_paper_details)} new papers for stepping into:\n{json.dumps([unit['title'] for unit in _new_paper_details])}")
    logger.info(f"It takes {checkp2-start_time}s to get more related papers.")

    if _new_paper_details:
        for paper in _new_paper_details:
            # `all_paper` stores normalized titles, so the membership test
            # must use the normalized form as well.
            normalized_title = convert_to_lowercase(paper["title"])
            if normalized_title in all_paper:
                continue
            all_paper.append(normalized_title)
            for _ in range(2):
                try:
                    # deep_search accumulates into `search_list` in place; its
                    # return value is that same list, so it is not used here.
                    deep_search(paper, dir_path, topic, search_list, all_paper, query_list)
                    break
                except Exception:
                    print("GPT error")
                    traceback.print_exc()
        # Append the accumulated results once. Adding the cumulative list after
        # every call would repeat the earlier hits again and again.
        _new_paper_details += search_list

    updated_entity, updated_relation = connect(init_graph, _new_paper_details)
    # save intermediate result for debugging
    logger.info(f"Updated entities and relations are saved at {str(ir_dir)}")
    updated_entity.to_csv(ir_dir / "updated_entity.csv", index=False)
    updated_relation.to_csv(ir_dir / "updated_relation.csv", index=False)

    logger.info(f"Total number of entities after connection: {len(updated_entity)}\nTotal number of relations after connection: {len(updated_relation)}")
    logger.info("Get graph construction information for papers mentioned in related work.")
    rw_paper_details = step_into_rw_mentioned(init_graph)
    logger.info(f"{len(rw_paper_details)} papers' information got.")

    # build subgraphs for all the papers got
    new_paper_details = [{k: p.get(k) for k in keys} for p in _new_paper_details]
    combined_details = new_paper_details + rw_paper_details
    if combined_details:
        paper_details = (
            pd.DataFrame(combined_details)
            .drop_duplicates(subset='title', keep="first")
            .to_dict(orient='records')
        )
    else:
        # An empty frame has no 'title' column, so drop_duplicates would raise.
        paper_details = []

    # save paper details of both new ones and related work mentioned ones
    logger.info(f"Paper details for stepping into are saved at {str(ir_dir)}")
    with open(ir_dir / "paper_details.json", 'w') as f:
        json.dump(paper_details, f, indent=4)

    subgraph_construct_times = []
    for p_details in paper_details:
        p_details['topic'] = topic
        subgraph_start_time = time.time()
        try:
            subgraph = Paper2Graph(
                basic_info=p_details,
                temp_data_path=dir_path / "subgraphs"
            )
            subgraph.extract(temp_save=True, isretry=False)
            subgraph.download()

            updated_entity = pd.concat([updated_entity, subgraph.entity], ignore_index=True)
            updated_relation = pd.concat([updated_relation, subgraph.relation], ignore_index=True)
            subgraph_checkp = time.time()
            subgraph_construct_times.append(subgraph_checkp - subgraph_start_time)
            logger.info(f"Successfully add subgraph extracted from paper titled *{p_details['title']}*. It takes {subgraph_checkp-subgraph_start_time}s to build the graph.")
        except Exception as e:
            # One failing paper must not abort the whole run.
            logger.info(f"Failed to build graph for paper titled *{p_details['title']}*: {e}")
            traceback.print_exc()
            continue
    logger.info("Finish steppping into papers.")
    # Default to 0.0 so the summary line (and the final save below) still runs
    # when no subgraph could be built.
    if subgraph_construct_times:
        average_time = sum(subgraph_construct_times) / len(subgraph_construct_times)
    else:
        average_time = 0.0
    logger.info(f"Successfully constructed {len(subgraph_construct_times)} graphs. The average subgraph construction time is {round(average_time, 2)}s.")

    # save final results
    updated_entity.to_csv(save_path / "entity.csv", index=False)
    updated_relation.to_csv(save_path / "relation.csv", index=False)
    end_time = time.time()
    logger.info(f"Successfully built final graph. It took {round(end_time-start_time, 2)}s in total.")
    print(f"****************Successfully built final graph******************\n!!!!!!!all_paper:{len(all_paper)}!!!!!!\n$$$$$$$$search_list:{len(search_list)}$$$$$$$$")

def process_mul_data(i, data):
    """Run ``main`` for one dataset entry inside its own ``test{i}`` directory.

    The directory is created next to this file. Index 0 is skipped and the
    function returns without doing anything for it.

    Args:
        i: Index of the entry; used to name the working directory.
        data: Source paper information forwarded to ``main``.

    Raises:
        Exception: Whatever ``main`` raises is re-raised after "fail" is printed.
    """
    if i == 0:
        return
    dir_path = Path(__file__).parent / f"test{i}"
    os.makedirs(dir_path, exist_ok=True)
    print("satrt")
    try:
        main(source_paper_info=data, dir_path=dir_path)
    except Exception:
        # Report "fail" only when the run actually failed, then let the
        # exception propagate to the caller / executor.
        print("fail")
        raise

def test_openai_connection():
    """Send a one-line prompt through ``openai_call`` and print the exchange.

    The prompt and the value returned by ``openai_call`` are printed between
    two banner lines.
    """
    from .utils.openai import openai_call

    prompt = "Are you ready?"
    reply = openai_call([{"role": "user", "content": prompt}])
    output_lines = (
        "--------------------------test----------------------------",
        prompt,
        reply,
        "----------------------------------------------------------",
    )
    for line in output_lines:
        print(line)
        

if __name__ == '__main__':
    # Check the OpenAI connection before starting the run.
    test_openai_connection()
    with open('/home/xinwang/PaperTKG/new_source.json', 'r') as source_file:
        datas = json.load(source_file)
    # The window 1 < index <= 2 selects only the entry at index 2.
    for index in range(len(datas)):
        if not (1 < index <= 2):
            continue
        entry = datas[index]
        run_dir = Path(__file__).parent / f"test{index}"
        os.makedirs(run_dir, exist_ok=True)
        main(source_paper_info=entry, dir_path=run_dir)


# if __name__ == '__main__':
#     # test_openai_connection()
#     with open('/home/xinwang/PaperTKG/new_source.json', 'r') as f:
#         datas = json.load(f)

#     # Use ProcessPoolExecutor to run the entries in parallel
#     with ProcessPoolExecutor() as executor:
#         # Submit all tasks to the executor
#         executor.map(process_mul_data, range(len(datas)), datas)