import json
from tqdm import tqdm
from typing import List, Dict
from .utils.prompts import PARSE_REF_SYS, PARSE_REF_USER, EXTRACT_REF_SYS, EXTRACT_REF_USER, EXTRACT_FROM_ABSTRACT, EXTRACT_FROM_RELATED_WORK, REVIEWER, WRITER, MERGE_SIMILAR_ENTITY
from .utils.openai import openai_call
from .utils.constants import OUTPUT_PATH, INPUT_PATH
from .utils.websearch import title_search
from .utils.logger import get_logger
from transformers import pipeline

# TODO: move the remaining output paths to the version-2 layout
#       (graph_builder already writes to OUTPUT_PATH / "graph_v2").
# Module-level logger used by the functions in this file.
logger = get_logger(__name__)

def parse_ref_LLM(ref: List[str]) -> List[Dict[str, str]]:
    '''
    Parse each raw reference string into a dict with the LLM.

    Every successfully parsed entry is the JSON object returned by the model,
    plus an extra "origin" key holding the raw reference string. The expected
    shape (presumably enforced by the PARSE_REF_* prompts, not visible here) is:
        {
            "authors": authors,
            "date": date,
            "title": title,
            "origin": origin
        }

    A reference whose reply cannot be decoded into a JSON object is logged and
    skipped, so the result may be shorter than ``ref``.

    The parsed list is also written to OUTPUT_PATH / "LLM_parsed_reference.json".

    Args:
        ref: raw reference strings, one per reference entry.

    Returns:
        The list of parsed reference dicts.
    '''
    # TODO: dataframe
    parsed_ref = []
    with tqdm(ref, desc='parse reference list', miniters=1) as pbar:
        for r in pbar:
            try:
                res = openai_call(messages=[
                    {"role": "system", "content": PARSE_REF_SYS},
                    {"role": "user", "content": PARSE_REF_USER.format(ref=r)}
                ], json_format=True)
                parsed_ref_ = json.loads(res)
                parsed_ref_["origin"] = r
            except Exception as e:
                # best effort: a single unparsable reference must not abort the batch
                logger.error(f"failed to parse '{r}':\n{e}")
                continue
            parsed_ref.append(parsed_ref_)

    with open(OUTPUT_PATH / "LLM_parsed_reference.json", 'w', encoding='utf-8') as f:
        json.dump(parsed_ref, f, indent=4)
    return parsed_ref

# Loaded question-answering pipelines, keyed by (model name, device).
# Building a transformers pipeline is expensive, and split_authors may be
# called many times per paragraph, so each pipeline is created only once.
_QA_PIPELINES: dict = {}


def split_authors(author_list: str,
                  model_name: str = "uer/roberta-base-chinese-extractive-qa",
                  device: int = 1) -> str:
    '''
    Return the first author name found in ``author_list`` by an extractive QA model.

    The model is asked "What is the first name in this list of names?" with
    the author string as context, and the extracted answer span is returned.

    Args:
        author_list: raw author string of a reference entry. An empty or
            ``None`` value returns "" without calling the model.
        model_name: Hugging Face model used for extraction (defaults to the
            model this function has always used).
        device: device argument passed to ``transformers.pipeline``.
            NOTE(review): the default 1 is kept from the original; confirm it
            matches the available hardware.

    Returns:
        The answer text extracted by the model.
    '''
    if not author_list:
        return ''
    key = (model_name, device)
    nlp = _QA_PIPELINES.get(key)
    if nlp is None:
        nlp = pipeline('question-answering',
                       model=model_name,
                       tokenizer=model_name,
                       device=device)
        _QA_PIPELINES[key] = nlp
    QA_input = {
        'question': 'What is the first name in this list of names?',
        'context': author_list
    }
    first_author = nlp(QA_input)
    return first_author['answer']

def citation_hunter(content: str, reference: List[Dict]) -> List[str]:
    '''
    Find the citations in ``content`` and match them to reference-list entries.

    The LLM (EXTRACT_REF_* prompts) is asked to list the in-text citations of
    the paragraph. Its reply is decoded as a JSON list of dicts with at least
    "authors" and "year" keys (the keys read below).

    A citation matches a reference entry when two conditions hold:
    - The entry's "date" equals the citation year, compared as stripped strings.
    - The citation's first author token occurs in the entry's first author,
      as extracted by ``split_authors``.

    Args:
        content: the paragraph text to scan.
        reference: parsed reference entries with "authors", "date" and
            "origin" keys.

    Returns:
        The "origin" strings of the matched entries, in citation order.
        Unmatched citations are logged and skipped. If the LLM reply is not a
        JSON list, the problem is logged and an empty list is returned.
    '''
    messages = [
        {"role": "system", "content": EXTRACT_REF_SYS},
        {"role": "user", "content": EXTRACT_REF_USER.format(content=content)}
    ]
    _res = openai_call(messages=messages)
    try:
        res = json.loads(_res)
    except (TypeError, ValueError) as e:
        logger.error(f"citation_hunter: could not decode the citation list: {e}\n{_res}")
        return []
    if not isinstance(res, list):
        logger.error(f"citation_hunter: expected a JSON list of citations, got: {res}")
        return []

    # split_authors runs a QA model; remember its answer per reference entry
    first_author_cache = {}
    matched_ref = []
    for unit in tqdm(res, desc='match paragraph citations', miniters=1):
        tokens = unit['authors'].split()
        aut = tokens[0] if tokens else ''
        # years may come back as int or str from the two LLM calls
        yea = str(unit['year']).strip()
        is_match = False
        # an empty author token would be "in" every name, so never match on it
        if aut:
            for index, r in enumerate(reference):
                if str(r.get('date')).strip() != yea:  # cheap filter before the model call
                    continue
                if index not in first_author_cache:
                    first_author_cache[index] = split_authors(r.get('authors', ''))
                if aut in first_author_cache[index]:
                    matched_ref.append(r['origin'])
                    is_match = True
                    break
        if not is_match:
            logger.warning(f"failed to match the citation entry: {unit}")
    return matched_ref

def format_pack(relate_work: List[Dict], abstract, reference):
    content = "# abstract\n" + abstract + "# related work\n"
    for subsection in relate_work:
        sub = f"## {subsection['subtitle']}\n"
        for para in subsection['paragraphs']:
            sub = para + '\n'
        content += sub
    content += "# reference\n" + '\n'.join(reference)
    return content

def preprocess(input: Dict):
    '''
    Prepare the LLM inputs for abstract and related-work extraction.

    ``input`` is iterated as a sequence of section dicts with "name" and
    "content". The sections read here are "paper title", "abstract" and
    "related work"; the latter's content is a list of {"subtitle",
    "paragraphs"} dicts. NOTE(review): the type hint says Dict, but the code
    treats the argument as a list of dicts.

    Steps:
        - load the parsed reference list from OUTPUT_PATH / "LLM_parsed_reference.json"
          (live parsing with parse_ref_LLM is commented out for now);
        - pack title and abstract into the abstract-extraction prompt;
        - match the citations of every related-work paragraph with ``citation_hunter``;
        - pack everything into one text with ``format_pack``.

    The three results are also written under OUTPUT_PATH / "tempfile".

    Returns:
        (packed_abs, packed_rw, packed_all):
        - packed_abs: the abstract prompt text;
        - packed_rw: a list of {"theme", "para_content", "reference"} dicts;
        - packed_all: the full packed text.
    '''
    raw = input

    # parse references
    # reference = next((item['content'] for item in raw if item.get('name') == 'reference'), None)
    # parsed_ref = parse_ref_LLM(reference) # MEMO: temp
    with open(OUTPUT_PATH / "LLM_parsed_reference.json", 'r', encoding='utf-8') as f:
        parsed_ref = json.load(f)

    def _section(name):
        # content of the first section with this name, or None
        return next((item['content'] for item in raw if item.get('name') == name), None)

    # pack abstract extraction prompt
    title = _section('paper title')
    abstract = _section('abstract')
    packed_abs = f"TEXT:\n- paper title: {title}\n- abstract: {abstract}"

    # extract citations of related work and match them
    related_work = _section('related work') or []
    total = sum(len(unit['paragraphs']) for unit in related_work)
    logger.info(f"number of total paragraph matching task: {total}")
    mentioned_ref = []
    packed_rw = []
    for unit in related_work:
        _theme = unit['subtitle']
        for para in unit['paragraphs']:
            _matched_ref = citation_hunter(para, parsed_ref)
            mentioned_ref += _matched_ref
            packed_rw.append(
                {
                    "theme": _theme,
                    "para_content": para,
                    "reference": '\n'.join(_matched_ref)
                }
            )

    # a reference cited in several paragraphs is listed only once (order kept)
    unique_ref = list(dict.fromkeys(mentioned_ref))
    packed_all = format_pack(relate_work=related_work, abstract=abstract or "", reference=unique_ref)

    temp_dir = OUTPUT_PATH / "tempfile"
    temp_dir.mkdir(parents=True, exist_ok=True)
    with open(temp_dir / "packed_abs.txt", 'w', encoding='utf-8') as f:
        f.write(packed_abs)
    with open(temp_dir / "packed_rw.json", 'w', encoding='utf-8') as f:
        # packed_rw is a list of dicts, so it must be serialised, not written raw
        json.dump(packed_rw, f, indent=4)
    with open(temp_dir / "packed_all.txt", 'w', encoding='utf-8') as f:
        f.write(packed_all)

    return packed_abs, packed_rw, packed_all
    
def abstract_extractor(packed_abs: str) -> tuple[list, list] | None:
    '''
    Extract entities and relations from the packed abstract text.

    The LLM (EXTRACT_FROM_ABSTRACT prompt) is expected to reply with a JSON
    list of exactly two objects. Their "content" fields hold the entity list
    and the relation list, in that order.

    On success the decoded reply is saved to
    OUTPUT_PATH / "tempfile" / "abstract_extraction.json". On any malformed
    reply, the raw text is saved to "abstract_extraction.txt" in the same
    folder for manual inspection.

    Args:
        packed_abs: the prompt text built by ``preprocess``.

    Returns:
        (entities, relations) on success, otherwise ``None``.
    '''
    m = [
        {"role": "system", "content": EXTRACT_FROM_ABSTRACT},
        {"role": "user", "content": packed_abs}
    ]
    abs_extraction = openai_call(messages=m, json_format=True)
    temp_dir = OUTPUT_PATH / "tempfile"
    temp_dir.mkdir(parents=True, exist_ok=True)

    ok = False
    try:
        json_abs_extraction = json.loads(abs_extraction)
        if isinstance(json_abs_extraction, list) and len(json_abs_extraction) == 2:
            entities = json_abs_extraction[0]['content']
            relations = json_abs_extraction[1]['content']
            ok = True
        else:
            # missing entity or relation list (or not a list at all)
            logger.warning("abstract extraction failure: element missing.")
    except (TypeError, ValueError, KeyError, IndexError) as e:
        logger.error(f"abstract_extractor: Output format error, manual intervention required. ERROR: {e}")

    if not ok:
        logger.error("failed to extract from abstract.")
        with open(temp_dir / "abstract_extraction.txt", 'w', encoding='utf-8') as f:
            f.write(str(abs_extraction))
        return None

    with open(temp_dir / "abstract_extraction.json", 'w', encoding='utf-8') as f:
        json.dump(json_abs_extraction, f, indent=4)
    logger.info("successfully extract from abstract. Data can be found at /outputs/tempfile/abstract_extraction.json.")
    return entities, relations
    
def _safe_path_part(text, limit=None) -> str:
    '''Turn ``text`` (optionally cut to ``limit`` characters) into one safe path component.'''
    part = str(text)
    if limit is not None:
        part = part[:limit]
    part = part.replace('/', '_').replace('\\', '_').strip()
    if part in ('', '.', '..'):
        return '_'
    return part


# TODO: Encapsulate the writer agent
def related_work_extractor(packed_rw: List[Dict], title: str, content_rw: str):
    '''
    Extract entities and relations from the related-work paragraphs.

    Each packed paragraph (dicts built by ``preprocess``) is sent to the LLM
    with the EXTRACT_FROM_RELATED_WORK prompt. The reply is expected to be a
    JSON list whose first and second items carry the entity and relation
    lists under "content".

    Each reply is saved to
    OUTPUT_PATH / "tempfile" / "related_work" / <theme> / <first 10 chars of paragraph>.json.
    If the reply cannot be parsed, the raw text is saved with a .txt extension instead.

    The accumulated lists then go through three steps:
    - ``rule_entity_check`` completes the entity list.
    - ``reviewer`` reviews the graph.
    - ``writer`` revises it, unless the review contains "PASS".

    The final lists are written to OUTPUT_PATH / "rw_extraction_entity.json"
    and "rw_extraction_relation.json".

    Args:
        packed_rw: dicts with "theme", "para_content" and "reference" keys.
        title: the paper title.
        content_rw: the related-work text shown to the reviewer.

    Returns:
        (entities, relations). If no entity or no relation was extracted, the
        accumulated (possibly empty) lists are returned without review.
    '''
    entity = []
    relation = []
    rw_root = OUTPUT_PATH / "tempfile" / "related_work"
    for unit in tqdm(packed_rw, desc='extract from paragraphs in related work'):
        temp = f"- paper title: {title}\n- excerpt theme: {unit['theme']}\n- content: {unit['para_content']}\n- citations: \n{unit['reference']}"
        m = [
            {"role": "system", "content": EXTRACT_FROM_RELATED_WORK},
            {"role": "user", "content": temp}
        ]
        _para_res = openai_call(messages=m, json_format=True)
        # one folder per theme, one file per paragraph named after its first 10 characters
        para_dir = rw_root / _safe_path_part(unit['theme'])
        para_dir.mkdir(parents=True, exist_ok=True)
        stem = _safe_path_part(unit['para_content'], limit=10)
        try:
            para_res = json.loads(_para_res)
            para_entity = para_res[0]['content']
            para_relation = para_res[1]['content']
        except (TypeError, ValueError, KeyError, IndexError) as e:
            logger.error(f"related_work_extractor: Output format error, manual intervention required. ERROR: {e}")
            logger.error(f"failed to extract from related work paragraph: {unit['para_content']}")
            with open(para_dir / f"{stem}.txt", 'w', encoding='utf-8') as f:
                f.write(str(_para_res))
            continue
        # extend only after both lists were read, so a half-valid reply adds nothing
        entity += para_entity
        relation += para_relation
        with open(para_dir / f"{stem}.json", 'w', encoding='utf-8') as f:
            json.dump(para_res, f, indent=4)
        logger.info(f"successfully extract from related work paragraph: {unit['para_content']}")

    if not (entity and relation):
        logger.warning("no usable entities/relations were extracted from related work; skipping review.")
        return entity, relation

    logger.info("successfully extract from related work. merging the lists...")
    complete_entity = rule_entity_check(entity_list=entity, relation_list=relation)
    logger.info("all the entities mentioned in relation list are ensured to exist in the entity list. reviewer is going to check the quality of graph...")

    history = json.dumps([
        {"list name": "entity", "content": complete_entity},
        {"list name": "relation", "content": relation},
    ])
    review = reviewer(title, content_rw, history)

    final_entity, final_relation = complete_entity, relation
    if 'PASS' in review:
        logger.info("the reviewer has approved the current graph.")
    else:
        # modify on origin result
        logger.info(f"reviewer: {review}\nthe writer is going to edit the graph...")
        new_entity, new_relation = writer(review, history)
        if new_entity:
            logger.info("the graph is updated by the writer.")
            final_entity, final_relation = new_entity, new_relation
        else:
            logger.info("the writer failed to update the graph.")

    with open(OUTPUT_PATH / "rw_extraction_entity.json", 'w', encoding='utf-8') as f:
        json.dump(final_entity, f, indent=4)
    with open(OUTPUT_PATH / "rw_extraction_relation.json", 'w', encoding='utf-8') as f:
        json.dump(final_relation, f, indent=4)

    return final_entity, final_relation
        
def check_existence(target: str, name_list: List[str]):
    '''
    Return ``None`` if ``target`` is a known entity name, otherwise a placeholder entity.

    For an unknown name, the LLM is asked to classify it as one of: problem,
    method, idea, paper, domain. The placeholder dict has these keys:
    - "entity name": ``target``;
    - "entity type": the model reply, stripped when it is a string;
    - "timestamp" and "description": both the string "null".

    Args:
        target: entity name referenced by a relation.
        name_list: names of existing entities (any container supporting ``in``).

    Returns:
        None when ``target`` is in ``name_list``, otherwise the new entity dict.
    '''
    if target in name_list:
        return None
    type_classification = '''You can choose from the following entity type: problem, method, idea, paper, domain. What type does this entity {entity} belong to? Only reply with the type and do not add any other information.'''
    type_ = openai_call(
        # the template has a *named* placeholder, so the value must be passed as entity=...
        messages=[{"role": "user", "content": type_classification.format(entity=target)}]
    )
    if isinstance(type_, str):
        type_ = type_.strip()
    return {
        "entity name": target,
        "entity type": type_,
        "timestamp": "null",
        "description": "null"
    }
        
def rule_entity_check(entity_list: List[Dict], relation_list: List[Dict]):
    '''
    Make sure every entity referenced by a relation exists in the entity list.

    For each relation, "entity1" and "entity2" are looked up by "entity name".
    A missing name gets a placeholder entry from ``check_existence``, which
    costs one LLM call. Each missing name is added only once, even when
    several relations mention it.

    Note: ``entity_list`` is extended in place and also returned.

    Args:
        entity_list: entity dicts with an "entity name" key.
        relation_list: relation dicts with "entity1" and "entity2" keys.

    Returns:
        ``entity_list`` with the placeholder entities appended.
    '''
    # set for O(1) lookups; it also grows as placeholders are created, so a name
    # shared by several relations is neither classified nor added twice
    known_names = {e['entity name'] for e in entity_list}
    add_entity = []
    for r in tqdm(relation_list, desc='rule entity check', miniters=1):
        for name in (r['entity1'], r['entity2']):
            new_entity = check_existence(name, known_names)
            if new_entity:
                add_entity.append(new_entity)
                known_names.add(name)
    entity_list += add_entity
    return entity_list

def reviewer(title, content, result):
    '''
    Ask the LLM to review an extracted graph.

    Args:
        title: paper title inserted into the REVIEWER prompt.
        content: source text the graph was extracted from.
        result: the serialised graph to review.

    Returns:
        The raw reply of ``openai_call``.
    '''
    prompt = REVIEWER.format(title=title, content=content, result=result)
    return openai_call([{"role": "user", "content": prompt}])
    
def writer(review, history):
    '''
    Ask the LLM (WRITER prompt) to revise the graph in ``history`` following ``review``.

    The reply is expected to be a JSON list whose first and second items hold
    the entity and relation lists under "content".

    Args:
        review: the reviewer's feedback text.
        history: the serialised graph being revised.

    Returns:
        (entities, relations) on success. If the reply cannot be decoded into
        that shape, the problem is logged and ``(None, None)`` is returned.
    '''
    m = [
        {"role": "user", "content": WRITER.format(review=review, history=history)}
    ]
    _res = openai_call(m, json_format=True)
    try:
        res = json.loads(_res)
        return res[0]['content'], res[1]['content']
    except (TypeError, ValueError, KeyError, IndexError) as e:
        logger.error(f"failed to improve the result: {e}")
        return None, None
    
def rename_similar_entity(abs_entity_list, rw_entity_list):
    '''
    Ask the LLM (MERGE_SIMILAR_ENTITY prompt) which entities of the two lists denote the same thing.

    The reply is expected to be a JSON list of items shaped like:
        {
            "new name": the new name of the entity,
            "entity1": entity name in the first (abstract) list,
            "entity2": entity name in the second (related work) list
        }
    NOTE(review): confirm the "new name" key spelling against the prompt.

    Args:
        abs_entity_list: entity names extracted from the abstract.
        rw_entity_list: entity names extracted from the related work.

    Returns:
        The decoded reply. If the reply is not valid JSON, an empty list is
        returned and the raw reply is saved to
        OUTPUT_PATH / "tempfile" / "merge" / "similar_entity.txt".
    '''
    m = [
        {"role": "user", "content": MERGE_SIMILAR_ENTITY.format(abs_list=abs_entity_list, rw_list=rw_entity_list)}
    ]
    _res = openai_call(m, json_format=True)
    try:
        return json.loads(_res)
    except (TypeError, ValueError) as e:
        logger.error(f"merge_similar_entity: Output format error, manual intervention required. ERROR: {e}")
        merge_dir = OUTPUT_PATH / "tempfile" / "merge"
        merge_dir.mkdir(parents=True, exist_ok=True)
        with open(merge_dir / "similar_entity.txt", 'w', encoding='utf-8') as f:
            f.write(str(_res))
        # "nothing to merge" keeps callers iterating over the result safe
        return []
           
def find_duplicates(string_list):
    '''
    Map every value that occurs more than once to the list of its indices.

    Values appear in order of first occurrence; indices are ascending.

    >>> find_duplicates(["a", "b", "a"])
    {'a': [0, 2]}
    '''
    positions = {}
    for idx, item in enumerate(string_list):
        positions.setdefault(item, []).append(idx)
    return {item: idxs for item, idxs in positions.items() if len(idxs) > 1}

def _unique_by_name(entities):
    '''Return ``entities`` without repeated "entity name" values, keeping the first occurrence.'''
    seen = set()
    unique = []
    for e in entities:
        name = e['entity name']
        if name not in seen:
            seen.add(name)
            unique.append(e)
    return unique


def merge(abs_entity, abs_relation, rw_entity, rw_relation):
    '''
    Merge the entity and relation lists extracted from the abstract and the related work.

    Stage 1 drops repeated entity names inside each list, keeping the first
    occurrence.

    Stage 2 asks ``rename_similar_entity`` which abstract / related-work
    entities denote the same thing. For each such pair:
    - both copies are renamed to the pair's "new name" ("new_name" is
      accepted too);
    - the abstract copy takes over the related-work entity's other fields;
    - every relation mentioning either old name is rewritten.

    The combined entity list (related-work entities first) is finally
    de-duplicated by name, so each merged pair appears once.

    Note: entity dicts and relation dicts are modified in place.

    Args:
        abs_entity, rw_entity: entity dicts with an "entity name" key.
        abs_relation, rw_relation: relation dicts with "entity1"/"entity2" keys.

    Returns:
        (merge_entity, merge_relation).
    '''
    # merge duplicates
    logger.info("STAGE1: merge duplicates in each list")
    for label, entities in (("abstract", abs_entity), ("related work", rw_entity)):
        dup = find_duplicates([e['entity name'] for e in entities])
        if dup:
            logger.info(f"find duplicates in {label} entities: {dup}")
        else:
            logger.info(f"no duplicate found in {label} entity list")
    abs_entity = _unique_by_name(abs_entity)
    rw_entity = _unique_by_name(rw_entity)
    abs_entity_list = [e['entity name'] for e in abs_entity]
    rw_entity_list = [e['entity name'] for e in rw_entity]

    # merge the similar entities in the two lists, which involves changes to relation list
    logger.info("STAGE2: merge the similar entities in the two lists")
    similar_entity = rename_similar_entity(abs_entity_list, rw_entity_list) or []
    if similar_entity:
        logger.info(f"find similar entities in two lists: {similar_entity}")

    renames = []  # (new name, abstract-side name, related-work-side name)
    for unit in similar_entity:
        new_name = unit.get('new name', unit.get('new_name'))
        if not new_name:
            logger.warning(f"skip similar-entity item without a new name: {unit}")
            continue
        renames.append((new_name, unit.get('entity1'), unit.get('entity2')))

    for new_name, abs_name, rw_name in tqdm(renames, desc="update entity lists", miniters=1):
        details = None
        for rw_ in rw_entity:
            if rw_['entity name'] == rw_name:
                rw_['entity name'] = new_name
                details = {k: v for k, v in rw_.items() if k != 'entity name'}
        for abs_ in abs_entity:
            if abs_['entity name'] == abs_name:
                if details is not None:
                    # align the abstract entity with the related-work one
                    abs_.update(details)
                abs_['entity name'] = new_name
    # a renamed pair now shares one name; keep the related-work copy only
    merge_entity = _unique_by_name(rw_entity + abs_entity)

    logger.info("Finish updating entity lists, update the relation list.")
    merge_relation = abs_relation + rw_relation
    for new_name, abs_name, rw_name in tqdm(renames, desc='update relation list', miniters=1):
        old_name = (abs_name, rw_name)
        for e in merge_relation:
            if e['entity1'] in old_name:
                e['entity1'] = new_name
            if e['entity2'] in old_name:
                e['entity2'] = new_name
    return merge_entity, merge_relation
    
def complete_paper_info(entity):
    '''
    Fill in "timestamp" and "description" of paper entities from a web title search.

    An entity counts as a paper when its "entity type" is "paper", compared
    case-insensitively and ignoring surrounding whitespace. For each paper,
    ``title_search(<entity name>).run_pipeline()`` is called:
    - a non-empty "date" replaces "timestamp";
    - a non-empty "abstract" replaces "description".

    A failed lookup is logged and that entity is left unchanged. Entities are
    modified in place; nothing is returned.

    NOTE(review): assumes run_pipeline() returns a dict-like object (or a falsy
    value when nothing is found).
    '''
    for e in tqdm(entity, desc='get paper information from web', miniters=1):
        if str(e.get('entity type', '')).strip().lower() != "paper":
            continue
        try:
            paper_info = title_search(e['entity name']).run_pipeline()
        except Exception as err:
            # one failed lookup should not abort the whole graph build
            logger.warning(f"web search failed for paper '{e.get('entity name')}': {err}")
            continue
        if not paper_info:
            continue
        if paper_info.get('date'):
            e['timestamp'] = paper_info['date']
        if paper_info.get('abstract'):
            e['description'] = paper_info['abstract']

def graph_builder(raw):
    '''
    Build the knowledge graph of a paper from its abstract and related work.

    The abstract and the related work are first extracted separately and then
    merged with ``merge``. The merged graph is reviewed once more as a whole
    and revised by ``writer`` when the review does not contain "PASS". Finally,
    "paper" entities are completed from a web search.

    The result is written to OUTPUT_PATH / "graph_v2" / "entity.json" and
    "relation.json".

    Args:
        raw: list of section dicts ("name", "content") as loaded from
            raw_content.json (see ``preprocess``).

    Returns:
        (final_entity, final_relation).
    '''
    # initial graph
    packed_abs, packed_rw, packed_all = preprocess(raw)
    # look the title up by section name, falling back to the first section
    title = next((item['content'] for item in raw if item.get('name') == 'paper title'), raw[0]['content'])

    abs_result = abstract_extractor(packed_abs=packed_abs)
    if abs_result is None:
        logger.error("abstract extraction failed; continuing without abstract entities and relations.")
        abs_result = ([], [])
    abs_entity, abs_relation = abs_result

    # the related-work reviewer needs the paper title and the related-work text
    content_rw = "\n".join(unit['para_content'] for unit in packed_rw)
    rw_result = related_work_extractor(packed_rw=packed_rw, title=title, content_rw=content_rw)
    if rw_result is None:
        logger.error("related work extraction failed; continuing without related work entities and relations.")
        rw_result = ([], [])
    rw_entity, rw_relation = rw_result

    logger.info("Subgraph extraction is complete, starting the merging process.")
    entity, relation = merge(abs_entity, abs_relation, rw_entity, rw_relation)
    logger.info("finish merging the graph")

    # improve
    logger.info("trying to improve the whole graph")
    result = json.dumps([
        {"list name": "entity", "content": entity},
        {"list name": "relation", "content": relation},
    ])
    review = reviewer(title=title, content=packed_all, result=result)
    final_entity, final_relation = entity, relation
    if 'PASS' in review:
        logger.info("the reviewer approved the graph.")
    else:
        logger.info(f"reviewer: {review}")
        logger.info("the writer is going to edit the graph")
        updated_entity, updated_relation = writer(review=review, history=result)
        if updated_entity:
            final_entity, final_relation = updated_entity, updated_relation
        else:
            logger.info("the writer failed to improve the quality of the whole graph.")

    # paper node completion
    logger.info("paper entity completion: abstract will be filled into description.")
    complete_paper_info(final_entity)

    # save the result
    graph_dir = OUTPUT_PATH / "graph_v2"
    graph_dir.mkdir(parents=True, exist_ok=True)
    with open(graph_dir / "entity.json", 'w', encoding='utf-8') as f:
        json.dump(final_entity, f, indent=4)
    with open(graph_dir / "relation.json", 'w', encoding='utf-8') as f:
        json.dump(final_relation, f, indent=4)
    logger.info("successfully build the graph.")
    return final_entity, final_relation
    
if __name__ == "__main__":
    # script entry point: load the raw paper sections and build the graph
    raw_path = INPUT_PATH / "raw_content.json"
    with open(raw_path, 'r') as fp:
        raw_content = json.load(fp)
    graph_builder(raw=raw_content)