import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
import json
import pandas as pd
from tqdm import tqdm
import re
from typing import List, Dict, Annotated
# from .base import GraphBuilder
# from .utils.openai import openai_call
# from .utils.websearch import title_search
# from .utils.logger import get_logger
# from .utils.constants import DIR_PATH
# from .utils.prompts import \
#     EXTRACT_ABSTRACT_ENTITY, EXTRACT_ABSTRACT_RELATION,\
#     PARSE_REF_SYS, PARSE_REF_USER,\
#     EXTRACT_REF_SYS, EXTRACT_REF_USER,\
#     EXTRACT_RELATED_WORK_ENTITY, EXTRACT_RELATED_WORK_RELATION,\
#     FIND_SIMILAR_ENTITY, MERGE_SIMILAR_ENTITY

from base import GraphBuilder
from utils.openai import openai_call
from utils.websearch import title_search
from utils.logger import get_logger
from utils.constants import DIR_PATH
from utils.prompts import \
    EXTRACT_ABSTRACT_ENTITY, EXTRACT_ABSTRACT_RELATION,\
    PARSE_REF_SYS, PARSE_REF_USER,\
    EXTRACT_REF_SYS, EXTRACT_REF_USER,\
    EXTRACT_RELATED_WORK_ENTITY, EXTRACT_RELATED_WORK_RELATION,\
    FIND_SIMILAR_ENTITY, MERGE_SIMILAR_ENTITY

# ANSI escape sequences used to emphasise section names in log messages.
BOLD = '\033[1m'
RESET = '\033[0m'
# Directory that receives the CSV/TXT outputs written by the builders below.
FILE_PATH = DIR_PATH / "demo"

class AbstractGraphBuilder(GraphBuilder):
    """Build the entity/relation sub-graph of a paper from its title and abstract.

    ``self.logger``, ``get_json_chat_completion``, ``imporve`` and ``checker``
    are used here but not defined in this class; they are expected to be
    provided by ``GraphBuilder``.
    """

    def __init__(self, title: str, abstract: str) -> None:
        self.title = title
        self.abstract = abstract

    @staticmethod
    def _unwrap_single_key(result):
        """Return the payload of a ``{"some_key": payload}`` wrapper, else ``result`` as is."""
        if isinstance(result, dict) and len(result) == 1:
            return next(iter(result.values()))
        return result

    def _to_frame(self, payload, required_columns=()):
        """Convert an LLM payload to a DataFrame.

        Returns ``None`` (after logging a warning) when the payload cannot be
        turned into a DataFrame or lacks one of ``required_columns``.
        """
        try:
            frame = pd.DataFrame(payload)
        except (ValueError, TypeError):
            frame = None
        if frame is None or any(col not in frame.columns for col in required_columns):
            self.logger.warning(f"wrong result format: {json.dumps(payload)}")
            return None
        return frame

    def preprocess(self):
        """Log the start of the extraction and return the prompt text (title + abstract)."""
        self.logger.info(f"Extracting entities and relations from {BOLD}Abstract{RESET} part...")
        return f"TEXT:\n- paper title: {self.title}\n- abstract: {self.abstract}"

    def extractor(self) -> bool:
        """Extract entities, then relations, from the abstract.

        On success ``self.entity`` and ``self.relation`` hold the extracted
        tables and ``True`` is returned. ``False`` is returned when a
        completion is empty or its payload is malformed.
        """
        packed_abs = self.preprocess()
        failure_msg = "Potential Risk: The extraction of the abstract part has failed."

        entity = self.get_json_chat_completion(messages=[
            {"role": "system", "content": EXTRACT_ABSTRACT_ENTITY},
            {"role": "user", "content": packed_abs},
        ])
        if not entity:
            self.logger.warning(failure_msg)
            return False
        entity = self._unwrap_single_key(entity)
        # both columns are read later in this class, so require them up front
        entity_df = self._to_frame(entity, required_columns=('entity name', 'entity type'))
        if entity_df is None:
            return False

        # make sure the paper itself is part of the entity list
        if not (entity_df['entity name'] == self.title).any():
            entity_df.loc[len(entity_df)] = {
                'entity name': self.title,
                'entity type': 'paper',
                'timestamp': None,
                'description': None
            }
        self.entity = entity_df

        # the relation prompt receives the entities exactly as the model returned them
        relation = self.get_json_chat_completion(messages=[
            {"role": "system", "content": EXTRACT_ABSTRACT_RELATION},
            {"role": "user", "content": packed_abs + f"\nENTITIES:\n{json.dumps(entity)}"},
        ])
        if not relation:
            self.logger.warning(failure_msg)
            return False
        relation_df = self._to_frame(self._unwrap_single_key(relation))
        if relation_df is None:
            return False

        self.logger.info("==== sucessfully extract entities and relations from the abstract part. ====")
        self.relation = relation_df
        return True

    def build(self):
        """Run extraction, optimisation and checking, then write the results under ``FILE_PATH``.

        Returns ``True`` on success and ``False`` when the extraction failed
        (in which case nothing is written).
        """
        if not self.extractor():
            return False

        isimproved = self.imporve(content=self.abstract)
        self.checker(content=self.abstract)
        improve_status = "Success" if isimproved else "Failed"
        graph_info = f'''Original graph construction information:
                    ** {len(self.entity)} entities extracted **
                        - paper entity: {(self.entity['entity type'] == 'paper').sum()}
                        - method entity: {(self.entity['entity type'] == 'method').sum()}
                        - problem entity: {(self.entity['entity type'] == 'problem').sum()}
                        - domain entity: {(self.entity['entity type'] == 'domain').sum()}
                    ** {len(self.relation)} relations extracted **
                    ** Graph optimization status: {improve_status} **
                    ** Checked the graph **
                        - removed duplicates
                        - aligned the entity names in the entity list and the relation list
                        - removed discrete entities'''
        print(graph_info)
        with open(FILE_PATH / "abs_subgraph_info.txt", 'w', encoding='utf-8') as f:
            f.write(graph_info)
        print("--------------------------------------------------------------")
        self.entity.to_csv(FILE_PATH / "abstract_entity_final.csv", index=False)
        self.relation.to_csv(FILE_PATH / "abstract_relation_final.csv", index=False)
        return True


class RelatedworkGraphBuilder(GraphBuilder):
    """Build the entity/relation sub-graph of a paper from its related-work section.

    ``self.logger``, ``get_json_chat_completion``, ``imporve`` and ``checker``
    are used here but not defined in this class; they are expected to be
    provided by ``GraphBuilder``.
    """

    _ENTITY_COLUMNS = ['entity name', 'entity type', 'timestamp', 'description']
    _RELATION_COLUMNS = ['entity1', 'relation', 'entity2']

    def __init__(self, title: str, related_work: List[Dict], reference: List) -> None:
        self.title = title
        self.related_work = related_work
        self.rw_units = self.preprocess(reference=reference)

    # ------------------------------------------------------------------ helpers
    @staticmethod
    def _unwrap_single_key(result):
        """Return the payload of a ``{"some_key": payload}`` wrapper, else ``result`` as is."""
        if isinstance(result, dict) and len(result) == 1:
            return next(iter(result.values()))
        return result

    @staticmethod
    def _normalize_citations(raw) -> List:
        """Coerce a citation-extraction reply into a flat list of citation items.

        Handled shapes: ``[...]``, ``[{"result": [...]}]``, ``{"citations": [...]}``
        and a single citation dict ``{"authors": ..., "year": ...}``.
        Anything else (including ``None``) yields an empty list.
        """
        if not raw:
            return []
        if isinstance(raw, dict):
            if len(raw) != 1:
                return [raw]  # only one citation
            raw = next(iter(raw.values()))
            if isinstance(raw, dict):
                return [raw]
        elif (isinstance(raw, list) and len(raw) == 1
                and isinstance(raw[0], dict) and len(raw[0]) == 1):
            inner = next(iter(raw[0].values()))
            # unwrap only real wrappers; a lone one-key citation dict stays untouched
            if isinstance(inner, list):
                raw = inner
        return raw if isinstance(raw, list) else []

    @staticmethod
    def _find_reference(ref_df, cit):
        """Return the original reference string matching ``cit``, or ``None``.

        A reference matches when its ``author`` field contains the first
        whitespace-separated token of the citation's authors and its ``date``
        equals the citation's year (compared as strings).
        """
        if not isinstance(cit, dict) or 'authors' not in cit or 'year' not in cit:
            return None
        authors = cit['authors']
        if isinstance(authors, (list, tuple)):
            authors = authors[0] if authors else ''
        author = str(authors).split(' ')[0]
        if not author:
            return None
        author_col = ref_df['author']
        # literal (non-regex) match so punctuation in names cannot break or skew the search
        author_hit = author_col.notna() & author_col.astype(str).str.contains(author, regex=False)
        # compare as text: the model may return the year as int in one place and str in the other
        year_hit = ref_df['date'].astype(str) == str(cit['year'])
        target_entry = ref_df[author_hit & year_hit]
        if target_entry.empty:
            return None
        return target_entry['origin'].values[0]

    def _to_frame(self, payload):
        """Convert an LLM payload to a DataFrame, or return ``None`` after logging a warning."""
        try:
            return pd.DataFrame(payload)
        except (ValueError, TypeError):
            self.logger.warning(f"wrong result format: {json.dumps(payload)}")
            return None

    @staticmethod
    def _concat_frames(frames, leading_columns):
        """Concatenate the non-empty ``frames`` once, with ``leading_columns`` first."""
        frames = [frame for frame in frames if not frame.empty]
        if not frames:
            return pd.DataFrame(columns=leading_columns)
        merged = pd.concat(frames, ignore_index=True)
        extra = [col for col in merged.columns if col not in leading_columns]
        return merged.reindex(columns=list(leading_columns) + extra)

    # --------------------------------------------------------------- preprocess
    def _parse_references(self, reference):
        """Parse every raw reference string with the LLM and save the table to ``FILE_PATH``."""
        ref_df = pd.DataFrame(columns=['author', 'date', 'title', 'origin'])
        for raw_ref in tqdm(reference, desc='Parse reference'):
            parsed = self.get_json_chat_completion(messages=[
                {"role": "system", "content": PARSE_REF_SYS},
                {"role": "user", "content": PARSE_REF_USER.format(ref=raw_ref)},
            ])
            # a non-dict reply cannot carry the 'origin' key, treat it as a failure
            if not parsed or not isinstance(parsed, dict):
                self.logger.warning(f"failed to parse the reference: {raw_ref}.")
                continue
            parsed['origin'] = raw_ref
            ref_df.loc[len(ref_df)] = parsed
        # save intermediate results
        ref_df.to_csv(FILE_PATH / "parsed_reference.csv", index=False)
        return ref_df

    def _match_citations(self, ref_df):
        """Attach the matched reference strings to every related-work paragraph.

        Sets ``self.mentioned_ref`` to all matched references, in order, and
        returns a DataFrame with the columns ``theme``, ``content`` and
        ``reference`` (one row per paragraph).
        """
        self.mentioned_ref = []
        processed_rw = pd.DataFrame(columns=['theme', 'content', 'reference'])
        for unit in self.related_work:
            para_title: str = unit['subtitle']
            for para in unit['paragraphs']:
                reply = self.get_json_chat_completion([
                    {"role": "system", "content": EXTRACT_REF_SYS},
                    {"role": "user", "content": EXTRACT_REF_USER.format(content=para)},
                ])
                if reply is None:
                    self.logger.warning(f"failed to extract citations from the paragraph: {para}")
                para_matched_ref = []
                for cit in self._normalize_citations(reply):
                    origin = self._find_reference(ref_df, cit)
                    if origin is None:
                        self.logger.warning(f"failed to match the citation: {json.dumps(cit)}")
                    else:
                        para_matched_ref.append(origin)
                self.mentioned_ref += para_matched_ref
                # the paragraph is kept even when no citation could be matched
                processed_rw.loc[len(processed_rw)] = {
                    "theme": para_title,
                    "content": para,
                    "reference": '\n'.join(para_matched_ref)
                }
        return processed_rw

    def preprocess(self, reference):
        """Parse the reference list and pair each related-work paragraph with its citations.

        Returns the per-paragraph DataFrame (``theme``, ``content``, ``reference``).
        """
        self.logger.info(f"Extracting entities and relations from {BOLD}Related Work{RESET} part...")
        ref_df = self._parse_references(reference)
        return self._match_citations(ref_df)

    # ---------------------------------------------------------------- extractor
    def extractor(self):
        """Extract entities and relations paragraph by paragraph.

        Paragraphs whose entity or relation extraction fails are skipped.
        Returns ``True`` and sets ``self.entity`` / ``self.relation`` when both
        tables are non-empty, otherwise returns ``False``.
        """
        entity_frames = []
        relation_frames = []

        rows = tqdm(self.rw_units.iterrows(), total=len(self.rw_units),
                    desc='extract from paragraphs in related work')
        for _, unit in rows:
            failure_msg = f"Potential Risk: The extraction of the related work paragraph //{unit['content']}// has failed."
            temp = f"- paper title: {self.title}\n- excerpt theme: {unit['theme']}\n- content: {unit['content']}\n- citations: \n{unit['reference']}"

            para_entity = self.get_json_chat_completion([
                {"role": "system", "content": EXTRACT_RELATED_WORK_ENTITY},
                {"role": "user", "content": temp},
            ])
            if not para_entity:
                self.logger.warning(failure_msg)
                continue
            para_entity = self._unwrap_single_key(para_entity)
            para_entity_df = self._to_frame(para_entity)
            if para_entity_df is None:
                continue

            para_relation = self.get_json_chat_completion([
                {"role": "system", "content": EXTRACT_RELATED_WORK_RELATION},
                {"role": "user", "content": temp + f"\nENTITIES:\n{json.dumps(para_entity)}"},
            ])
            if not para_relation:
                self.logger.warning(failure_msg)
                continue
            para_relation_df = self._to_frame(self._unwrap_single_key(para_relation))
            if para_relation_df is None:
                continue

            entity_frames.append(para_entity_df)
            relation_frames.append(para_relation_df)

        # one concat at the end instead of re-copying the growing table per paragraph
        entity = self._concat_frames(entity_frames, self._ENTITY_COLUMNS)
        relation = self._concat_frames(relation_frames, self._RELATION_COLUMNS)
        if entity.empty or relation.empty:
            return False

        self.entity = entity
        self.relation = relation
        self.logger.info("==== sucessfully extract entities and relations from the related work part. ====")
        return True

    def complete_paper_info(self) -> int:
        """Fill in date/abstract of every paper entity via ``title_search``.

        The values are written back into ``self.entity`` (a non-empty search
        result field replaces ``timestamp`` / ``description``). Returns the
        number of paper entities for which the search returned a result.
        """
        is_paper = (self.entity['entity type'] == 'paper').tolist()
        paper_rows = [pos for pos, flag in enumerate(is_paper) if flag]

        # object dtype so that text can be stored even if a column was all-NaN so far
        for col in ('timestamp', 'description'):
            self.entity[col] = self.entity[col].astype(object)
        name_col = self.entity.columns.get_loc('entity name')
        ts_col = self.entity.columns.get_loc('timestamp')
        desc_col = self.entity.columns.get_loc('description')

        success_cnt = 0
        for pos in tqdm(paper_rows, desc='complete paper entity information'):
            p_title = self.entity.iloc[pos, name_col]
            search_res = title_search(p_title).run_pipeline()
            if not search_res:
                self.logger.info(f"failed to update paper entity: {p_title}")
                continue
            # write through positional indexing: assigning to a row copy would be lost
            if search_res['date']:
                self.entity.iloc[pos, ts_col] = search_res['date']
            if search_res['abstract']:
                self.entity.iloc[pos, desc_col] = search_res['abstract']
            success_cnt += 1
        return success_cnt

    def build(self):
        """Run extraction, optimisation, checking and citation completion, then write the results.

        Returns ``True`` on success and ``False`` when the extraction failed
        (in which case nothing is written).
        """
        def merge_content():
            # every section starts on its own line so the "##" headings stay headings
            sections = [
                f"## {unit['subtitle']}\n" + '\n'.join(unit['paragraphs'])
                for unit in self.related_work
            ]
            rw_content = "# Related Work\n" + '\n'.join(sections)
            rw_ref = '\n'.join(self.mentioned_ref)
            return rw_content + '\n' + '# Reference\n' + rw_ref

        if not self.extractor():
            return False

        merged_content = merge_content()
        isimproved = self.imporve(content=merged_content)
        self.checker(content=merged_content)
        improve_status = "Success" if isimproved else "Failed"
        success_cnt = self.complete_paper_info()
        graph_info = f'''Original graph construction information:
                    ** {len(self.entity)} entities extracted **
                        - paper entity: {(self.entity['entity type'] == 'paper').sum()}
                        - method entity: {(self.entity['entity type'] == 'method').sum()}
                        - problem entity: {(self.entity['entity type'] == 'problem').sum()}
                        - domain entity: {(self.entity['entity type'] == 'domain').sum()}
                    ** {len(self.relation)} relations extracted **
                    ** Graph optimization status: {improve_status} **
                    ** Checked the graph **
                        - removed duplicates
                        - aligned the entity names in the entity list and the relation list
                        - removed discrete entities
                    ** Complete {success_cnt}/{(self.entity['entity type'] == 'paper').sum()} citation information **'''
        print(graph_info)
        with open(FILE_PATH / "rw_subgraph_info.txt", 'w', encoding='utf-8') as f:
            f.write(graph_info)
        print("--------------------------------------------------------------")
        self.entity.to_csv(FILE_PATH / "relatedwork_entity_final.csv", index=False)
        self.relation.to_csv(FILE_PATH / "relatedwork_relation_final.csv", index=False)
        return True
        
        
class MergeGraph:
    """Merge the abstract sub-graph and the related-work sub-graph into one graph.

    The entity tables are concatenated and de-duplicated by ``entity name``;
    the relation tables are concatenated. ``smooth`` then asks the LLM to find
    and merge near-duplicate (non-paper) entities.
    """

    def __init__(self,
                 graph1: Annotated[GraphBuilder, "abstract graph"],
                 graph2: Annotated[GraphBuilder, "related work graph"]) -> None:
        merged_entity = pd.concat([graph1.entity, graph2.entity], ignore_index=True)
        # reset the index after de-duplication: rows are later appended at label
        # ``len(self.entity)``, which must not collide with a surviving label
        self.entity = merged_entity.drop_duplicates(subset='entity name', keep='first').reset_index(drop=True)
        self.relation = pd.concat([graph1.relation, graph2.relation], ignore_index=True)

        # every section starts on its own line so the "###" headings stay headings
        sections = [
            f"### {unit['subtitle']}" + '\n' + '\n'.join(unit['paragraphs'])
            for unit in graph2.related_work
        ]
        rw_content = "## related work\n" + '\n'.join(sections)

        self.raw_content = f'# paper title: {graph1.title}\n## abstract\n{graph1.abstract}\n' + rw_content
        self.logger = get_logger("merge subgraphs")

    def get_json_chat_completion(self, messages) -> Dict | None:
        """Send ``messages`` to the LLM and return the reply parsed as JSON.

        The reply is parsed directly, then from a fenced json block, then
        after asking the LLM to repair it. The whole request is attempted up
        to three times. Returns the parsed dict/list (which may be empty), or
        ``None`` when no attempt produced a JSON object or array.
        """
        def get_chat_completion(messages) -> str | None:
            for i in range(3):
                try:
                    reply = openai_call(messages=messages)
                    if not reply:
                        raise ValueError("failed to get reply.")
                    return reply
                except Exception as e:
                    self.logger.info(f"{i + 1}/3 - try to get reply for message: {messages} with error: {e}")
            return None

        def load_json_result(content):
            # tools
            def extract_json_res(content: str) -> Dict:
                json_marker_pattern = r"```json([\s\S]*?)```"
                match = re.search(json_marker_pattern, content, re.DOTALL)
                if match is None:
                    raise ValueError("no json code block found")
                return json.loads(match.group(1))

            def llm_json_formatter(content: str) -> str:
                # the f-string is already complete; a further str.format() would
                # choke on the braces of the embedded JSON text
                prompt = f'''Please check if the text conforms to JSON format. If it does not, output the correct JSON format result or extract the part in JSON format; if it does, return the original text.
                    Please return a valid JSON result, without any extra explanations or symbols.
                    TEXT: {content}'''
                res = openai_call(
                    messages=[{"role": "user", "content": prompt}],
                    llm_model_name='gpt-4o-mini'
                )
                return res

            # main workflow
            try:
                return json.loads(content)
            except Exception as e:
                self.logger.info(f"failed to directly load the content: {content}, ERROR: {e}. try to use re to extract it.")
            try:
                return extract_json_res(content=content)
            except Exception as e:
                self.logger.debug(f"failed to extract the content: {content}. ERROR: {e}. try to use LLM to format it.")
            new_content = None
            try:
                new_content = llm_json_formatter(content=content)
                try:
                    # the formatter is asked for bare JSON, so try that first
                    return json.loads(new_content)
                except (TypeError, ValueError):
                    return extract_json_res(new_content)
            except Exception as e:
                self.logger.debug(f"the content: {content}, improved by the LLM: {new_content}, still not conform to the JSON format. ERROR: {e}")
            return None

        for _ in range(3):
            _reply = get_chat_completion(messages=messages)
            if _reply is None:
                continue
            reply = load_json_result(_reply)
            # an empty object/array is a valid answer (e.g. "nothing to merge")
            if isinstance(reply, (dict, list)):
                return reply
        self.logger.error(f"failed to get valid result for message: {messages}")
        return None

    def smooth(self):
        '''Find groups of similar non-paper entities with the LLM and merge each group into one node.

        For every group the LLM proposes a merged node. When that node is
        valid, the group's rows are replaced by it in ``self.entity``, the
        group's names are rewritten to the new name in ``self.relation``, and
        both tables are written to ``FILE_PATH``.
        '''
        std_keys = ['entity name', 'entity type', 'timestamp', 'description']
        valid_types = ('paper', 'method', 'problem', 'domain')

        def is_valid_node(dict_info) -> bool:
            # key order does not matter, only the key set
            valid = (
                isinstance(dict_info, dict)
                and set(dict_info) == set(std_keys)
                and dict_info['entity type'] in valid_types
                and bool(dict_info['entity name'])
            )
            if not valid:
                self.logger.warning(f"wrong entity node format: {json.dumps(dict_info)}, failed to merge redundant nodes into a new one.")
            return valid

        def merge_nodes(old_names, node):
            new_name = node['entity name']
            # delete the redundant nodes BEFORE adding the merged one, so that a
            # merged node re-using one of the old names is not deleted with them
            keep = ~self.entity['entity name'].isin(old_names)
            self.entity = self.entity[keep].reset_index(drop=True)
            self.entity.loc[len(self.entity)] = pd.Series({key: node[key] for key in std_keys})
            self.logger.info(f"successfully merge redundant nodes into a new one: {json.dumps(node)}")
            # replace all the names of the redundant entities with new one
            self.relation['entity1'] = self.relation['entity1'].replace(old_names, new_name)
            self.relation['entity2'] = self.relation['entity2'].replace(old_names, new_name)

        def as_groups(reply):
            # NOTE(review): the reply shape depends on FIND_SIMILAR_ENTITY, which is
            # not visible here; dict replies are mapped to their values - confirm.
            if isinstance(reply, dict):
                values = list(reply.values())
                if (len(values) == 1 and isinstance(values[0], list)
                        and all(isinstance(g, list) for g in values[0])):
                    return values[0]  # {"groups": [[...], ...]}
                return values
            return reply

        self.logger.info("merge the graph extracted from abstract and related work...")
        nopaper_entity_name = self.entity.loc[self.entity['entity type'] != 'paper', 'entity name']
        str_entity_name = ', '.join(nopaper_entity_name.dropna().astype(str))
        similar_entities = self.get_json_chat_completion(
            [{"role": "user", "content": FIND_SIMILAR_ENTITY.format(entity=str_entity_name)}]
        )
        if similar_entities is None:
            self.logger.warning("Failed to merge duplicate entities.")
            return
        similar_entities = as_groups(similar_entities)
        if not similar_entities:
            self.logger.info("No redundant entity was found.")
            return

        str_similar_entities = [json.dumps(g) for g in similar_entities]
        self.logger.info(f"Find {len(similar_entities)} group(s) of entities that can be merged:" + "\n" + '\n'.join(str_similar_entities))
        for group in tqdm(similar_entities, desc="merge redundant entities"):
            if not isinstance(group, (list, tuple)):
                self.logger.warning(f"wrong entity group format: {json.dumps(group)}")
                continue
            group = list(group)
            target_entity = self.entity[self.entity['entity name'].isin(group)].to_dict(orient='records')
            if not target_entity:
                # none of the names exists (any more); there is nothing to merge
                self.logger.info(f"no known entity in group: {json.dumps(group)}")
                continue
            str_target_entity = json.dumps(target_entity)
            new_entity = self.get_json_chat_completion(
                [{
                    "role": "user",
                    "content": MERGE_SIMILAR_ENTITY.format(entity=str_target_entity)
                }]
            )
            if not new_entity:
                self.logger.info(f"failed to create new entity node for redundant entities: {str_target_entity}")
                continue
            self.logger.info(f"created new entity node: {new_entity} for redundant entites {str_target_entity}")
            if is_valid_node(new_entity):
                merge_nodes(group, new_entity)
                self.entity.to_csv(FILE_PATH / "entity.csv", index=False)
                self.relation.to_csv(FILE_PATH / "relation.csv", index=False)