import json
import re
import os
from typing import List, Dict, Tuple
from transformers import pipeline
from .utils.prompts import EXTRACT_FRAMEWORK_SYS, EXTRACT_FRAMEWORK_USER, PARA_CLASSIFER,\
PARSE_REF_SYS, PARSE_REF_USER
from .utils.constants import RW_PATH, REF_PATH, OUTPUT_PATH
from .utils.websearch import title_search
from .utils.openai import openai_call

def parse_ref_LLM(ref_text: str) -> List[Dict]:
    '''Parse a newline-separated reference list into structured records via the LLM.

    Every non-blank line of ``ref_text`` is sent to ``openai_call`` together with
    the ``PARSE_REF_SYS`` / ``PARSE_REF_USER`` prompts, and the reply is decoded
    as JSON. A line whose request or decoding fails is reported on stdout and
    skipped, so one bad reference does not abort the whole run.

    The collected records are also written to
    ``OUTPUT_PATH / "LLM_parsed_reference.json"``.

    Args:
        ref_text: the reference section, one reference per line.

    Returns:
        The decoded records, in input order (failed lines are omitted).
    '''
    ref = []
    for raw_line in ref_text.split('\n'):
        r = raw_line.strip()
        if not r:
            # Blank separator lines carry no reference; do not spend an LLM call on them.
            continue
        try:
            res = openai_call(messages=[
                {"role": "system", "content": PARSE_REF_SYS},
                {"role": "user", "content": PARSE_REF_USER.format(ref=r)}
            ], json_format=True)
            ref.append(json.loads(res))
        except Exception as e:
            # Deliberately best-effort: report the failing line and carry on.
            print(f"failed to parse '{r}':\n{e}")

    # Make sure the output directory exists before writing into it.
    os.makedirs(OUTPUT_PATH, exist_ok=True)
    with open(OUTPUT_PATH / "LLM_parsed_reference.json", 'w') as f:
        json.dump(ref, f, indent=4)
    return ref

def _classify_paragraph(para: str) -> Dict:
    '''Classify one paragraph as a section title or body text.

    The LLM is asked (``PARA_CLASSIFER`` prompt) for a JSON object carrying a
    ``type`` field and, for titles, a ``content`` field. Up to three replies are
    tried; after each unusable reply except the last, the model is asked again
    for valid JSON. If all three fail, a regex on the raw paragraph decides:
    a ``## <n>.<m> <words>`` heading is a title, anything else is content.

    Returns:
        ``{'type': 'title', 'content': <title>}`` or a dict whose ``type`` is
        not ``'title'`` (treated as body text by the caller).
    '''
    mess = [
        {"role": "user", "content": PARA_CLASSIFER.format(content=para)}
    ]
    res_ = openai_call(messages=mess, json_format=True)
    for i in range(3):
        try:
            res = json.loads(res_)
            # A reply that decodes but lacks the fields we read is as useless as bad JSON.
            if not isinstance(res, dict) or 'type' not in res:
                raise ValueError("JSON reply has no 'type' field")
            if res['type'] == 'title' and 'content' not in res:
                raise ValueError("title reply has no 'content' field")
            return res
        except (TypeError, ValueError) as e:
            print(f"Attempt {i+1} - fail to load json paragraph information: {e}")
            if i == 2:
                break
            mess.append({"role": "assistant",
                         "content": res_})
            mess.append({"role": "user",
                         "content": "Please return a valid JSON object without adding any extra tags or explanations."})
            res_ = openai_call(messages=mess)

    # MEMO: a temporary strategy - fall back to recognising markdown sub-section headings.
    match_title = re.match(r"## \d+\.\d+ ([a-zA-Z\s]+)", para)
    if match_title:
        return {'type': 'title', 'content': match_title.group(1).strip()}
    return {'type': 'content'}


def preprocess(content: str) -> List[Dict]:
    '''Split the text into sections, each made of a topic and its paragraphs.

    ``content`` is split on newlines. Every non-blank paragraph is classified as
    a title or as body text. A title closes the section collected so far and
    starts a new one; body text is appended to the current section. Paragraphs
    seen before the first title are grouped under an empty topic.

    Args:
        content: the raw text, one paragraph per line.

    Returns:
        A list of ``{'topic': str, 'content_para': List[str]}`` dicts in document
        order. A trailing title with no paragraphs after it is not emitted.
    '''
    # break down the content
    paras_ = content.split("\n")

    temp_topic = ''
    temp_content: List[str] = []
    parsed_paras: List[Dict] = []
    section_open = False  # becomes True once any title or body paragraph was seen

    for para in paras_:
        if not para.strip():
            # Blank separator lines are neither titles nor content.
            continue

        res = _classify_paragraph(para)

        if res['type'] == 'title':
            if section_open:
                parsed_paras.append({'topic': temp_topic, 'content_para': temp_content})
            temp_topic = res['content']
            # Start a fresh list so sections never share (or accumulate) paragraphs.
            temp_content = []
        else:
            temp_content.append(para)
        section_open = True

    # Flush the last section when the text ended with body paragraphs.
    if temp_content:
        parsed_paras.append({'topic': temp_topic, 'content_para': temp_content})

    return parsed_paras

def initial_graph(content: str):
    '''extract entities and relations from content

    The text is first split into sections by ``preprocess``. Each paragraph of
    each section is sent to the LLM (``EXTRACT_FRAMEWORK_*`` prompts); the reply
    is expected to be JSON where element ``[0]['content']`` holds entities and
    element ``[1]['content']`` holds relations. Up to three replies are tried
    per paragraph, re-asking for valid JSON after each unusable one.

    Side effects:
        * creates ``OUTPUT_PATH / "temp" / <topic>`` for every section;
        * on success writes ``relations.json`` and ``entities_raw.json`` under
          ``OUTPUT_PATH``;
        * if any paragraph could not be parsed, writes the last raw reply of that
          paragraph to a txt file in the section's temp directory and stores the
          partial results as ``incomplete_entity.json`` /
          ``incomplete_relation.json`` under ``OUTPUT_PATH / "temp"``.

    Returns:
        The list of extracted entities on success, ``False`` if at least one
        paragraph could not be parsed.
    '''

    parsed_paras = preprocess(content=content)
    entity = []
    relation = []
    failure = False

    # Both result branches write below OUTPUT_PATH; make sure the directories exist
    # even when there are no sections at all.
    os.makedirs(OUTPUT_PATH / "temp", exist_ok=True)

    for index, unit in enumerate(parsed_paras):
        topic = unit['topic']

        temp_res_path = OUTPUT_PATH / "temp" / topic
        os.makedirs(temp_res_path, exist_ok=True)

        for para_index, para in enumerate(unit['content_para']):
            mess = [
                {"role": "system", "content": EXTRACT_FRAMEWORK_SYS},
                {"role": "user", "content": EXTRACT_FRAMEWORK_USER.format(topic=topic, content=para)}
            ]
            res = openai_call(messages=mess, json_format=True)
            for i in range(3):
                try:
                    json_res = json.loads(res)
                    # Read both parts before touching the accumulators, so a reply with
                    # entities but broken relations is not half-applied and then
                    # duplicated by the retry.
                    new_entities = list(json_res[0]['content'])
                    new_relations = list(json_res[1]['content'])
                except (TypeError, ValueError, KeyError, IndexError) as e:
                    print(f"Attempt {i+1} - failed to get JSON entity and relation: {e}")
                    if i == 2:
                        print("Unable to save as JSON format, temporarily saved as a txt file for manual inspection")
                        # One file per paragraph: several failures in the same section
                        # must not overwrite each other.
                        with open(temp_res_path / f"entities_and_relations_{index}_{para_index}.txt", 'w') as f:
                            f.write(res)
                        failure = True
                    else:
                        mess.append({"role": "assistant", "content": res})
                        mess.append({"role": "user", "content": "Please return a valid JSON object without adding any extra tags or explanations."})
                        res = openai_call(messages=mess)
                else:
                    entity += new_entities
                    relation += new_relations
                    break
    if failure:
        # Not all entities and relationships were saved correctly;
        # the successfully parsed parts are temporarily stored for further manual merging
        with open(OUTPUT_PATH / "temp" / "incomplete_entity.json", 'w') as f:
            json.dump(entity, f, indent=4)
        with open(OUTPUT_PATH / "temp" / "incomplete_relation.json", 'w') as f:
            json.dump(relation, f, indent=4)
        return False
    else:
        print("successfully extract entities and relations.")
        with open(OUTPUT_PATH / "relations.json", 'w') as f:
            json.dump(relation, f, indent=4)
        with open(OUTPUT_PATH / "entities_raw.json", 'w') as f: # TODO: may be deleted later
            json.dump(entity, f, indent=4)
        return entity

# Lazily created question-answering pipeline, shared by all split_authors calls.
_QA_PIPELINE = None


def _get_qa_pipeline():
    '''Build the question-answering pipeline on first use and reuse it afterwards.'''
    global _QA_PIPELINE
    if _QA_PIPELINE is None:
        model_name = "uer/roberta-base-chinese-extractive-qa"
        # NOTE(review): device=1 is kept from the original code; it presumably selects
        # the second accelerator - confirm it exists on the target machine.
        _QA_PIPELINE = pipeline('question-answering',
                                model=model_name,
                                tokenizer=model_name,
                                device=1)
    return _QA_PIPELINE


def split_authors(author_list: str) -> str:
    '''Return the first author name found in an author-list string.

    The name is extracted with an extractive question-answering model; the
    pipeline's result is a dict and only its ``answer`` text is returned, so the
    caller gets the plain string promised by the annotation.

    Args:
        author_list: the authors of one reference as a single string.

    Returns:
        The extracted first author, or ``''`` when ``author_list`` is empty or
        whitespace only (the model is not invoked in that case).
    '''
    if not author_list or not author_list.strip():
        return ''
    QA_input = {
        'question': 'What is the first name in this list of names?',
        'context': author_list
    }
    first_author = _get_qa_pipeline()(QA_input)
    return first_author['answer']

def update_entity(raw_entity: List[Dict], ref: List[Dict]) -> List[Dict]:
    '''complete the abstract and title information of citations

    Non-paper entities are copied with ``timestamp`` and ``description`` set to
    ``None``. For a paper entity the name is split on spaces: the first token
    (trailing/leading commas stripped) is taken as the author and the last token
    as the year. The first reference whose ``date`` equals that year and whose
    first author contains that author token is used to build the new entity:

    * ``entity name``  - the reference's ``title``;
    * ``timestamp``    - the four-digit year, replaced by the ``date`` returned by
      ``title_search`` when that is available;
    * ``description``  - the ``abstract`` from ``title_search``, else its ``link``,
      else ``''``.

    Paper entities with no matching reference are reported on stdout and left
    out of the result.

    Args:
        raw_entity: entities with ``entity name`` and ``entity type`` keys.
        ref: parsed references with ``authors``, ``date`` and ``title`` keys.

    Returns:
        The list of updated entities.
    '''
    entity_new = []
    # Resolving a first author runs a model, so do it at most once per reference.
    first_author_cache: Dict[int, str] = {}

    def _first_author(ref_index: int) -> str:
        if ref_index not in first_author_cache:
            answer = split_authors(ref[ref_index]['authors'])
            # split_authors may hand back the raw question-answering result (a dict
            # with an 'answer' field) rather than a plain string; accept both.
            if isinstance(answer, dict):
                answer = answer.get('answer', '')
            first_author_cache[ref_index] = answer or ''
        return first_author_cache[ref_index]

    for unit in raw_entity:
        if unit['entity type'] != "paper":
            entity_new.append({
                "entity name": unit['entity name'],
                "entity type": unit['entity type'],
                "timestamp": None,
                "description": None
            })
            continue

        temp = unit['entity name'].split(' ')
        aut = temp[0].strip(',')
        yea = temp[-1]

        matched = None
        for index, r in enumerate(ref):
            # Cheap comparison first: only references from the right year need the
            # (expensive) first-author lookup. str() tolerates numeric dates.
            if yea != str(r['date']):
                continue
            # An empty author token would be "contained" in every name; require a real one.
            if aut and aut in _first_author(index):
                matched = r
                break

        if matched is None:
            print(f"information missing: {unit}")
            continue

        # get title
        title = matched['title']
        # get initial timestamp
        try:
            timestamp = re.match(r"(\d{4})[a-zA-Z]?", yea).group(1)
        except AttributeError as e:
            # re.match returned None: the token does not start with a four-digit year.
            print(f"failed to parse publication time '{yea}': {e}, it will be directly saved.")
            timestamp = yea
        # get description
        description = ''
        updated_cit = title_search(title).run_pipeline()
        if updated_cit:
            if updated_cit['date']:  # update timestamp
                timestamp = updated_cit['date']

            # Truthiness instead of len(): also copes with fields that are None.
            if updated_cit['abstract']:
                description = updated_cit['abstract']
            elif updated_cit['link']:
                description = updated_cit['link']

        entity_new.append({
            "entity name": title,
            "entity type": "paper",
            "timestamp": timestamp,
            "description": description
        })
    return entity_new

# TODO: update relations

def build_graph(related_work: str, reference: str) -> Tuple[List[Dict], List[Dict]] | None:
    '''Extract the initial graph from the related-work text, then parse the references.

    ``initial_graph`` runs first. If its result is falsy, an error message is
    printed and the reference list is not processed. Otherwise ``reference`` is
    handed to ``parse_ref_LLM``, whose return value is not used here.

    As written, every path returns ``None``.
    '''
    # build the initial graph
    extracted = initial_graph(related_work)
    if extracted:
        # preprocess the reference list
        parse_ref_LLM(reference)
        return None

    print("Output format error, manual intervention required.")
    return None

if __name__ == '__main__':
    '''MEMO: supposing the content is correctly extracted from the PDF file.'''
    # Script entry point: read the related-work text and the reference list into
    # memory. Nothing further is done with them in this block.
    with open(RW_PATH, 'r') as rw_file, open(REF_PATH, 'r') as ref_file:
        related_work = rw_file.read()
        reference = ref_file.read()

    
