import os
import json
from arxiv_crawler import *
from extract_rw import *

def update_json_with_arxiv_info(data, topics, source_arxiv_ids, target_arxiv_ids):
    """Fill ``data['topic_paper_info']`` from the arXiv abstract pages.

    Topics are numbered from 1; the string form of that number is the key
    used both in ``data['topic_paper_info']`` and in ``source_arxiv_ids``.

    Args:
        data: Dict holding a ``'topic_paper_info'`` mapping. It is updated
            in place; missing topic and target-paper entries are created
            with blank fields.
        topics: Sequence of topic names.
        source_arxiv_ids: Mapping from topic key (``"1"``, ``"2"``, ...) to
            the arXiv id of that topic's source paper.
        target_arxiv_ids: Sequence, parallel to ``topics``, of mappings from
            a target-paper key to an arXiv id.

    Returns:
        The same ``data`` object that was passed in.

    A topic without a source arXiv id only gets its ``topic`` field set; no
    request is made and its target papers are left untouched. Previously the
    loop fell through and used an unbound (or stale, from the preceding
    topic) ``target_arxiv_id``.
    """
    # NOTE(review): `requests` and `extract_arxiv_paper_info` are not imported
    # by name in this file; presumably they come from the star imports.
    topic_paper_info = data['topic_paper_info']

    for i, topic in enumerate(topics):
        topic_key = str(i + 1)
        entry = topic_paper_info.setdefault(topic_key, {
            "topic": "",
            "source_paper": {
                "arxiv_id": "",
                "title": "",
                "date": "",
                "abstract": "",
                "related_work": ""
            },
            "target_paper": {}
        })
        entry['topic'] = topic

        source_arxiv_id = source_arxiv_ids.get(topic_key, '')
        if not source_arxiv_id:
            print(f"the {i}th topic doesn't have valid source paper.")
            # Nothing can be fetched for this topic: skip it rather than
            # requesting an empty arXiv id and reusing stale target ids.
            continue
        target_arxiv_id = target_arxiv_ids[i]

        # Update the source paper.
        source_paper_url = f"https://arxiv.org/abs/{source_arxiv_id}"
        response = requests.get(source_paper_url)
        if response.status_code == 200:
            source_paper_info = extract_arxiv_paper_info(html_content=response.text)
            entry['source_paper'].update(source_paper_info)
        else:
            print(f"Failed to retrieve {source_paper_url}")

        # Update the target papers.
        target_papers = entry['target_paper']
        for key, tid in target_arxiv_id.items():
            target_entry = target_papers.setdefault(key, {
                "arxiv_id": "",
                "title": "",
                "date": "",
                "related_work": ""
            })
            if not tid:
                # Keep the blank placeholder for targets without an id.
                continue
            response = requests.get(f"https://arxiv.org/abs/{tid}")
            if response.status_code == 200:
                target_paper_info = extract_arxiv_paper_info(html_content=response.text, source=False)
                target_entry.update(target_paper_info)
            else:
                print(f"Failed to retrieve {tid}")

    return data


def update_json_with_rw(data, topics, source_arxiv_ids, target_arxiv_ids):
    """Download each paper's PDF and store its related-work text in ``data``.

    PDFs are saved under ``pdfs/<topic number>/`` as ``source.pdf`` and
    ``target<key>.pdf``. The value returned by ``extract_rw`` for each PDF is
    written to that paper's ``related_work`` field.

    Args:
        data: Dict whose ``'topic_paper_info'`` mapping already contains an
            entry for every processed topic and target key. It is updated
            in place.
        topics: Sequence of topic names; only its length is used here.
        source_arxiv_ids: Mapping from topic key (``"1"``, ``"2"``, ...) to
            the arXiv id of that topic's source paper.
        target_arxiv_ids: Sequence, parallel to ``topics``, of mappings from
            a target-paper key to an arXiv id.

    Returns:
        The same ``data`` object that was passed in.

    A topic without a source arXiv id is skipped entirely: nothing is
    downloaded and no directory is created. Previously the loop fell through,
    tried to download an empty id and used an unbound (or stale, from the
    preceding topic) ``target_arxiv_id``.
    """
    for i in range(len(topics)):
        topic_key = str(i + 1)
        source_arxiv_id = source_arxiv_ids.get(topic_key, '')
        if not source_arxiv_id:
            print(f"the {i}th topic doesn't have valid source paper.")
            continue
        target_arxiv_id = target_arxiv_ids[i]

        # Create the per-topic download directory; exist_ok avoids the
        # check-then-create race of a separate os.path.exists() test.
        pdf_dir = f"pdfs/{topic_key}/"
        os.makedirs(pdf_dir, exist_ok=True)

        entry = data['topic_paper_info'][topic_key]

        # Update the source paper.
        source_save_path = f"pdfs/{topic_key}/source.pdf"
        download_arxiv_pdf(source_arxiv_id, source_save_path)
        source_rw = extract_rw(pdf_path=source_save_path)
        entry['source_paper'].update({"related_work": source_rw})

        # Update the target papers.
        for key, tid in target_arxiv_id.items():
            if not tid:
                continue
            target_save_path = f"pdfs/{topic_key}/target{key}.pdf"
            download_arxiv_pdf(tid, target_save_path)
            target_rw = extract_rw(pdf_path=target_save_path)
            entry['target_paper'][key].update({"related_work": target_rw})

    return data





def main(json_file_path, topics, source_arxiv_ids, target_arxiv_ids):
    """Collect paper information for every topic and write it to a JSON file.

    Args:
        json_file_path: Path of the JSON file to write.
        topics: Sequence of topic names.
        source_arxiv_ids: Mapping from topic key (``"1"``, ``"2"``, ...) to
            the arXiv id of that topic's source paper.
        target_arxiv_ids: Sequence, parallel to ``topics``, of mappings from
            a target-paper key to an arXiv id.

    The three inputs must have the same length. If they do not, "Invalid
    input!" is printed and the function returns without touching
    ``json_file_path``. Previously the blank template was still written over
    the output file and the topics were reported as collected.
    """
    if not (len(topics) == len(source_arxiv_ids) == len(target_arxiv_ids)):
        print("Invalid input!")
        return

    # Initial layout of the database. Topic "1" is pre-seeded with two blank
    # target slots; any slot that is not filled in below stays blank in the
    # written file.
    data = {
        "topic_paper_info": {
            "1": {
                "topic": "",
                "source_paper": {
                    "arxiv_id": "",
                    "title": "",
                    "date": "",
                    "abstract": "",
                    "related_work": ""
                },
                "target_paper": {
                    "1": {
                        "arxiv_id": "",
                        "title": "",
                        "date": "",
                        "related_work": ""
                    },
                    "2": {
                        "arxiv_id": "",
                        "title": "",
                        "date": "",
                        "related_work": ""
                    }
                }
            }
        }
    }

    data = update_json_with_arxiv_info(data, topics, source_arxiv_ids, target_arxiv_ids)
    data = update_json_with_rw(data, topics, source_arxiv_ids, target_arxiv_ids)

    # Make sure the output directory exists so open() does not fail after all
    # the downloads have already been done.
    parent_dir = os.path.dirname(json_file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Write the collected data to the JSON file.
    with open(json_file_path, 'w', encoding='utf-8') as json_file:
        json.dump(data, json_file, indent=4)

    print(f"{len(topics)} topics with source and target papers have been collected.")



if __name__ == '__main__':
    # Example run: one topic, one source paper and six target papers.
    output_json = 'data/topic_paper_db.json'
    topic_names = ["In-context Learning"]
    # Source ids are keyed by the 1-based topic number as a string.
    source_ids_by_topic = {"1": "2201.11903"}
    # One mapping of target key -> arXiv id per topic, in topic order.
    target_ids_per_topic = [
        {
            "1": "2403.06914",
            "2": "2310.10638",
            "3": "2311.06668",
            "4": "2402.10738",
            "5": "2401.03385",
            "6": "2308.06912",
        },
    ]
    main(output_json, topic_names, source_ids_by_topic, target_ids_per_topic)