import sys
import os
# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from typing import List, Dict
import requests
import time
import traceback
import re
import os
import pandas as pd
from bs4 import BeautifulSoup

# from paper2graph import Graph
# from utils.openai import get_chat_completion
# from utils.process_html import get_info_from_html
# from utils.process_pdf import process_pdf

from .paper2graph import Graph
from ..utils.openai import get_chat_completion
from ..utils.process_html import get_info_from_html
from ..utils.process_pdf import process_pdf

import string


# Translation table that deletes every ASCII punctuation character.
translator = str.maketrans('', '', string.punctuation)

# --- new papers part ---
# Prepositions removed from a problem phrase before it is used as a search
# query, so the keyword query is broader.
PREPS = [
    "in", "on", "at", "by", "with", "from", "to",
    "of", "about", "after", "before", "during", "through", "under",
]

def convert_to_lowercase(stri):
    """Turn any value into a compact, comparison-friendly key.

    The value is passed through ``str()``, lowercased, stripped of all ASCII
    punctuation (``string.punctuation``) and of every space character.
    """
    strip_punct = str.maketrans('', '', string.punctuation)
    lowered = str(stri).lower()
    without_punct = lowered.translate(strip_punct)
    return without_punct.replace(" ", "")

def get_arxiv_id_from_s2(topic,
                      problem,
                      start_year,
                      end_year,
                      searchlist,
                      url='https://api.semanticscholar.org/graph/v1/paper/search',
                      timeout=30):
    """Collect highly-cited papers that have an arXiv ID from Semantic Scholar.

    For every year in ``[start_year, end_year]`` the search endpoint is queried
    with ``"{topic} {problem}"``, sorted by influence. The ``minCitationCount``
    threshold starts at 45 and is lowered by 10 after each round (floored at
    10). The rounds stop once 5 papers with an arXiv ID were collected for the
    year or the floor threshold has been tried.

    Args:
        topic: Research topic (first half of the query string).
        problem: Problem phrase (second half of the query string).
        start_year: First year to search, inclusive.
        end_year: Last year to search, inclusive.
        searchlist: List of titles normalised with ``convert_to_lowercase`` that
            were already seen. It is MUTATED in place: every newly seen title
            is appended, so callers can share it across calls to avoid
            duplicates.
        url: Semantic Scholar paper-search endpoint.
        timeout: Per-request timeout in seconds, so a stalled connection
            cannot hang the crawl forever.

    Returns:
        A list of dicts with keys ``title``, ``abstract``, ``year``,
        ``pdf_link`` and ``arxiv_id``.
    """
    searched = searchlist  # shared with the caller on purpose (see docstring)
    got = []  # papers with an arXiv ID collected across all years
    base_query = f"{topic} {problem}"

    for year in range(start_year, end_year + 1):
        print(f"Year: {year}")
        downloaded_papers = 0
        min_citation_count = 45  # start from a high citation threshold
        min_citation_count_step = 10  # amount the threshold drops after each round
        papers_per_year = 5  # target number of papers per year
        max_trial = 10  # maximum number of retries per request
        trial_10 = 0  # rounds run with the threshold clamped at the floor of 10

        while downloaded_papers < papers_per_year and min_citation_count >= 10 and trial_10 < 2:
            query_param = {
                "query": base_query,
                "minCitationCount": min_citation_count,
                "limit": 100,  # at most 100 papers per response
                "year": year,
                "fields": "title,abstract,externalIds,year,citationCount",  # externalIds carries the arXiv ID
                "sort": "influence"
            }

            try:
                trial = 0
                searchres = requests.get(url, params=query_param, timeout=timeout)
                while searchres.status_code != 200 and trial < max_trial:
                    print(f"{trial + 1}/{max_trial} - Retrying... attempt {trial + 1}")
                    # Back off BEFORE re-sending so a rate-limited (429) request
                    # is not immediately repeated.
                    time.sleep(3)
                    searchres = requests.get(url, params=query_param, timeout=timeout)
                    trial += 1

                if searchres.status_code != 200:
                    print(f"{trial + 1}/{max_trial} - Cannot connect to Semantic Scholar.")
                    time.sleep(5)  # wait a bit before the next (lower-threshold) round
                else:
                    search_res = searchres.json()
                    print(f"{trial + 1}/{max_trial} - Get response")

                    if 'data' in search_res:
                        # "data" may be null in the JSON; treat that as no results.
                        for paper in search_res['data'] or []:
                            title = paper.get('title')
                            normalized_title = convert_to_lowercase(title)
                            if normalized_title in searched:
                                print(f"Duplicate paper found, skipping: {title}")
                                continue
                            searched.append(normalized_title)
                            # The API can return "externalIds": null, so fall back to {}
                            # instead of crashing and losing the rest of this batch.
                            externalIds = paper.get('externalIds') or {}
                            arxiv_id = externalIds.get("ArXiv")

                            if arxiv_id:
                                pdf_link = f"https://arxiv.org/pdf/{arxiv_id}.pdf"
                                paper_info = {
                                    'title': title,
                                    # null fields come back as None; normalise to 'N/A'.
                                    'abstract': paper.get('abstract') or 'N/A',
                                    'year': paper.get('year') or 'N/A',
                                    'pdf_link': pdf_link,
                                    'arxiv_id': arxiv_id
                                }
                                downloaded_papers += 1
                                got.append(paper_info)
                                print(f"Successfully downloaded: {title}")
                                print(f"Have downloaded {downloaded_papers}/{papers_per_year} papers for year {year}.")
                            else:
                                print(f"No arXiv ID for: {title}")

                            # Stop as soon as this year's quota is reached.
                            if downloaded_papers >= papers_per_year:
                                break

                        if downloaded_papers < papers_per_year:
                            print(f"Only downloaded {downloaded_papers}/{papers_per_year} papers for year {year}.")

                    else:
                        print(f"No related articles found for the query: {base_query} in year {year}.")

            except Exception as e:
                print(e)
                time.sleep(6)  # on any error, wait 6 seconds before the next round

            # Lower the citation threshold to surface more papers next round.
            if downloaded_papers < papers_per_year:
                min_citation_count -= min_citation_count_step
                if min_citation_count <= 10:
                    min_citation_count = 10
                    trial_10 += 1  # count rounds clamped at the floor
                print(f"Decreasing minCitationCount to {min_citation_count} to find more papers.")

    return got

def get_response(url, params, timeout=30, max_retries=10):
    """GET ``url`` with query ``params``, retrying on rate limits and transient errors.

    Args:
        url: Target URL.
        params: Query parameters passed to ``requests.get``.
        timeout: Per-request timeout in seconds (requests has none by default,
            so a stalled server would otherwise block forever).
        max_retries: Maximum number of attempts.

    Returns:
        The ``requests.Response`` on HTTP 200, otherwise ``False``. A network
        exception, HTTP 400 or HTTP 404 stops immediately. HTTP 429 waits 3 s
        and retries; any other status waits 1 s and retries.
    """
    for _ in range(max_retries):
        try:
            res = requests.get(url=url, params=params, timeout=timeout)
        except Exception as e:
            print(e)
            return False
        if res.status_code == 200:
            return res
        elif res.status_code == 400:
            print("Wrong params. Stop searching...")
            return False
        elif res.status_code == 404:
            print("Uncommon error: cannot find the page. Stop searching")
            return False
        elif res.status_code == 429:
            print("Request rate exceeded, try again later.")
            time.sleep(3)
        else:
            # Other errors (e.g. 5xx) are usually transient; back off briefly
            # instead of hammering the server in a tight loop.
            time.sleep(1)
    # Every attempt returned early on 200, so reaching here means failure.
    print("Failed to get response from arxiv.")
    return False

def load_data(response):
    """Parse an arXiv search-results HTML page into paper records.

    Args:
        response: An object with a ``.text`` attribute holding the HTML
            (e.g. the ``requests.Response`` returned by ``get_response``).

    Returns:
        A list of dicts with keys ``title``, ``abstract``, ``year``,
        ``pdf_link`` and ``arxiv_id``. Entries that cannot be parsed are
        logged and skipped.

    Raises:
        AssertionError: if no entry could be parsed at all.
    """
    data = []
    soup = BeautifulSoup(response.text, 'html.parser')
    blocks = soup.find_all('li', class_='arxiv-result')
    for block in blocks:
        try:
            pdf_link = block.find('a', href=True, string='pdf')['href']
            arxiv_id = re.match(r"https://arxiv\.org/pdf/(\d+\.\d+)", pdf_link).group(1)

            title = block.find('p', class_='title is-5 mathjax').text.strip()

            _abstract = block.find('p', class_='abstract mathjax').text.strip()
            # When the abstract has a "▽ More ... △ Less" toggle, keep the text
            # between the markers. Otherwise fall back to the raw paragraph text
            # instead of dropping the whole entry.
            toggle = re.search(r"▽ More(.*)△ Less", _abstract, re.DOTALL)
            abstract = toggle.group(1).strip() if toggle else _abstract

            submitted_date = block.find('p', class_='is-size-7').text.strip()
            year = re.search(r"originally announced [a-zA-Z]+ (\d{4})", submitted_date).group(1)

            _data = {
                "title": title,
                "abstract": abstract,
                "year": year,
                "pdf_link": pdf_link,
                "arxiv_id": arxiv_id
            }
            data.append(_data)

        except Exception as e:
            print(f"Failed to load an data entry: {e} Skip it.")
            traceback.print_exc()
    if not data:
        # Raised explicitly (not via `assert`) so the check survives `python -O`;
        # the exception type is kept for existing callers.
        raise AssertionError("Failed to load data")
    return data

def filter(result_data):
    """Ask the LLM which papers are unrelated to the query.

    Each paper with both a title and an abstract is sent to
    ``get_chat_completion`` using the ``FILTER_IRRELEVANT`` prompt. The paper
    counts as related when the reply contains "yes" (case-insensitive).
    Papers missing a title or an abstract are logged and left untouched,
    i.e. they are never marked for dropping.

    Args:
        result_data: List of paper dicts with optional ``title``/``abstract``.

    Returns:
        The list of indices (into ``result_data``) of papers to drop.
    """
    from ..utils.prompts import FILTER_IRRELEVANT
    drop_indices = []
    for idx, entry in enumerate(result_data):
        paper_title = entry.get('title')
        paper_abstract = entry.get('abstract')
        if not (paper_title and paper_abstract):
            print("Cannot get basic information of the paper from response... Skip it.")
            continue
        prompt = FILTER_IRRELEVANT.format(title=paper_title, abstract=paper_abstract)
        reply = get_chat_completion([{"role": "user", "content": prompt}])
        related = 'yes' in reply.lower()
        print(f"{paper_title}, related={related}")
        if not related:
            drop_indices.append(idx)
    return drop_indices
    
def download_pdf(arxiv_id, timeout=60):
    """Download the arXiv PDF for ``arxiv_id`` into ``./temp.pdf``.

    Up to 10 attempts are made. HTTP 404 aborts immediately, HTTP 429 waits
    5 s and any other non-200 status waits 1 s before retrying.

    Args:
        arxiv_id: arXiv identifier, appended to ``https://arxiv.org/pdf/``.
        timeout: Per-request timeout in seconds (requests has no default
            timeout, so a stalled download would otherwise hang forever).

    Returns:
        True when a non-empty body was written to ``./temp.pdf``, otherwise None.
        Network exceptions from ``requests`` propagate to the caller.
    """
    url = f"https://arxiv.org/pdf/{arxiv_id}"
    for _ in range(10):
        res = requests.get(url=url, timeout=timeout)
        if res.status_code == 200:
            break
        elif res.status_code == 404:
            print("Uncommon error: cannot find the page. Stop searching")
            return
        elif res.status_code == 429:
            print("Request rate exceeded, try again later.")
            time.sleep(5)
        else:
            # Transient server errors: back off instead of retrying in a tight loop.
            time.sleep(1)
    if res.status_code != 200:
        print("Failed to get response from arxiv.")
        return
    with open("./temp.pdf", 'wb') as f:
        # write() returns the number of bytes written; 0 means an empty body.
        if f.write(res.content):
            return True

def get_graph_construc_info(title, arxiv_id):
    """Extract graph-construction information for one arXiv paper.

    The HTML route (``get_info_from_html``) is tried first. If it raises, the
    PDF is downloaded to ``./temp.pdf`` and passed to ``process_pdf``. The
    temporary PDF is removed afterwards whether parsing succeeded or failed.

    Args:
        title: Paper title, used only in log messages.
        arxiv_id: arXiv identifier of the paper.

    Returns:
        Whatever the successful extractor returned, or None if both failed.
    """
    try:
        temp = get_info_from_html(arxiv_id)
        print(f"Successfully get graph construction information for paper titled *{title}*")
        return temp
    except Exception as e:
        print(e)
        print(f"Failed to get enough graph construction information for paper titled *{title}* from html")

    try:
        isdownloaded = download_pdf(arxiv_id)
        if not isdownloaded:
            # Explicit raise rather than `assert`, so the check still runs under
            # `python -O` and process_pdf is never handed a missing/stale file.
            raise RuntimeError(f"Failed to get PDF for paper titled *{title}*")
        temp = process_pdf('./temp.pdf')
        print(f"Successfully get graph construction information for paper titled *{title}*")
        return temp
    except Exception as e:
        print(e)
        print(f"Failed to get enough graph construction information for paper titled *{title}* from PDF")
        print(f"Cannot build graph for paper titled *{title}*")
    finally:
        # Always remove the temporary PDF, including after a successful parse.
        try:
            os.remove('./temp.pdf')
        except OSError:
            pass  # nothing was downloaded, or it is already gone
    return None
# Semantic Scholar-based search for new papers (the active implementation of extend()).


def _drop_prepositions(text):
    """Lowercase ``text`` and remove every word listed in PREPS (all occurrences).

    Dropping prepositions broadens the keyword query so the search returns
    more related papers. An empty string yields an empty string.
    """
    return ' '.join(word for word in text.lower().split() if word not in PREPS)


def extend(topic: str, problem: List[str], start_year: int, search_list=None, query_list=None, *, end_year: int = 2023):
    """Find new papers for ``topic`` x each ``problem`` and gather graph info for them.

    For every problem phrase and every year in ``[start_year, end_year]``,
    ``get_arxiv_id_from_s2`` is queried. Each returned paper is then passed to
    ``get_graph_construc_info``. Papers for which that returns a truthy value
    are kept.

    Args:
        topic: Research topic shared by all queries.
        problem: Problem phrases; one search is run per phrase and year.
        start_year: First year to search, inclusive.
        search_list: Titles already seen (normalised with
            ``convert_to_lowercase``). It is forwarded to
            ``get_arxiv_id_from_s2``, which appends to it in place. Defaults to
            a fresh empty list.
        query_list: Normalised ``"topic problem"`` keys that were already
            searched; matching problems are skipped. Defaults to an empty list.
        end_year: Last year to search, inclusive. The default of 2023 matches
            the previously hard-coded ``range(start_year, 2024)``.

    Returns:
        A list of dicts (``title``, ``date``, ``abstract``, ``keywords`` plus
        the extracted graph info), de-duplicated by title.
    """
    if search_list is None:
        search_list = []
    if query_list is None:
        query_list = []

    # Key built from the whole problem list, as the original check did; still
    # honoured so callers that stored this key keep skipping as before.
    legacy_key = convert_to_lowercase(f"{topic} {problem}")

    results = []
    for p in problem:
        # Remove prepositions to get more related papers.
        _p = _drop_prepositions(p)

        # Skip (topic, problem) pairs that were already searched.
        if legacy_key in query_list or convert_to_lowercase(f"{topic} {p}") in query_list:
            continue

        for y in range(start_year, end_year + 1):
            print(f"Searching for papers with topic={topic}, problem={p}, year={y}")

            # Use Semantic Scholar to find arXiv IDs and metadata.
            try:
                papers = get_arxiv_id_from_s2(topic, _p, start_year=y, end_year=y, searchlist=search_list)
            except Exception as e:
                # `except Exception` (not bare) so Ctrl-C can still stop a long crawl.
                print(e)
                papers = []
                print("get_arxiv_id_from_s2 error")
            if not papers:
                print(f"Cannot find papers for the query: problem={p}, topic={topic}, year={y}")
                continue

            # NOTE: LLM relevance filtering via filter() is currently disabled,
            # so every paper returned above is processed.
            print("----------------------------------------------------")
            print("Check if these papers are actually related to our query using LLMs...")
            related_titles = ', '.join(str(paper['title']) for paper in papers)
            print(f"Filtered by LLMs, the following papers are related: {related_titles}")
            for _data in papers:
                rd = {
                    "title": _data['title'],
                    "date": _data['year'],
                    "abstract": _data['abstract'],
                    "keywords": [topic, p]
                }
                temp = get_graph_construc_info(title=_data['title'], arxiv_id=_data['arxiv_id'])
                if temp:
                    rd.update(temp)
                    results.append(rd)

    if results:
        res_df = pd.DataFrame(results)
        res_df = res_df.drop_duplicates(subset='title', keep="first", ignore_index=True)
        results = res_df.to_dict(orient='records')

    return results




# Legacy arXiv-search implementation of extend(); disabled and kept for reference only.


# def extend(topic: str, problem: List, start_year: int):
#     results = []
#     URL = "https://arxiv.org/search/advanced"
#     params = {
#             "advanced": "",
#             "terms-0-operator": "AND",
#             "terms-0-term": None,
#             "terms-0-field": "all",
#             "terms-1-operator": "AND",
#             "terms-1-term": None,
#             "terms-1-field": "all",
#             "classification-computer_science": "y",
#             "classification-physics_archives": "all",
#             "classification-include_cross_list": "include",
#             "date-filter_by": "specific_year",
#             "date-year": None,
#             "date-date_type": "submitted_date_first",
#             "abstracts": "show",
#             "size": "100",
#             "order": "-announced_date_first",
#             "format": "rss"
#         }

#     for p in problem:
#         # drop prepositions to get more related papers
#         p_list = p.lower().split(' ')
#         for prep in PREPS:
#             if prep in p_list:
#                 pop_ind = p_list.index(prep)
#                 p_list.pop(pop_ind)
#         _p = ' '.join(p_list)
#         for y in range(start_year, 2024):
#             params['terms-0-term'] = topic
#             params["date-year"]= y
#             params["terms-1-term"] = _p
#             res = get_response(url=URL, params=params)
#             if not res:
#                 print(f"Cannot get response for the query: problem={p}, topic={topic}, year={y}")
#                 continue
            
#             result_data = []
#             print(f"Try to search for more papers with the keywords: topic={topic}, problem={p}, year={y}")
#             try:
#                 data = load_data(res)
#                 result_data += data
#             except Exception as e:
#                 print(e)

#             print("----------------------------------------------------")
#             print("check if these papers actually related to our query using LLMs...")
#             get_drop = filter(result_data)
#             result_data = [item for index, item in enumerate(result_data) if index not in get_drop]
#             if result_data:
#                 print(f"Filtered by LLMs, the following papers are related: {', '.join([p['title'] for p in result_data])}")
#                 for _data in result_data:
#                     rd = {
#                         "title": _data['title'],
#                         "date": _data['year'],
#                         "abstract": _data['abstract'],
#                         "keywords": [topic, p]
#                     }
#                     temp = get_graph_construc_info(title=_data['title'], arxiv_id=_data['arxiv_id'])
#                     if temp:
#                         rd.update(temp)
#                         results.append(rd)
#                     else:
#                         continue
#             else:
#                 print("No paper matched our query")

#     if results:
#         res_df = pd.DataFrame(results)
#         res_df = res_df.drop_duplicates(subset='title', keep="first", ignore_index=True)
#         results = res_df.to_dict(orient='records')

#     return results
            
# Papers mentioned in the related-work section of an already-processed paper.
def get_arxiv_id(title):
    """Look up the arXiv ID of the paper whose title matches ``title``.

    An arXiv title search is run. Each result's title is compared with
    ``title``, ignoring case and spaces.

    Args:
        title: Paper title to search for.

    Returns:
        The new-style arXiv ID (``\\d+.\\d+``) of the first matching result,
        or None when the request fails, nothing matches, or the matching
        result's PDF link cannot be parsed.
    """
    def simple_string(text):
        # Case- and space-insensitive comparison key.
        return text.lower().replace(" ", "").strip()

    url = "https://arxiv.org/search/"
    params = {
        "query": title,
        "searchtype": "title",
        "abstracts": "show",
        "order": "-announced_date_first",
    }
    res = get_response(url=url, params=params)
    if not res:
        return None
    soup = BeautifulSoup(res.text, 'html.parser')
    blocks = soup.find_all('li', class_='arxiv-result')
    if not blocks:
        print(f"Cannot find the paper titled *{title}* on arxiv.")
        return None
    for block in blocks:
        # TODO: abstract similarity can be introduced for comparison
        title_tag = block.find('p', class_='title is-5 mathjax')
        if title_tag is None:
            continue
        res_title = title_tag.text.strip()
        if simple_string(title) != simple_string(res_title):
            print(f"Only found paper titled *{res_title}* when searching for the one titled *{title}*")
            continue
        # Read the PDF link from THIS result block, not the whole page, so the
        # returned ID belongs to the result whose title matched.
        pdf_tag = block.find('a', href=True, string='pdf')
        match = re.match(r"https://arxiv\.org/pdf/(\d+\.\d+)", pdf_tag['href']) if pdf_tag else None
        if match:
            return match.group(1)
        print(f"Cannot parse an arXiv ID for the paper titled *{res_title}*")
    return None

def step_into_rw_mentioned(init_graph: Graph) -> List[Dict]:
    """Gather graph-construction info for papers mentioned by ``init_graph``.

    Rows of ``init_graph.entity`` whose ``'entity type'`` is ``"paper"`` are
    treated as mentioned papers. The graph's own paper (``init_graph.title``)
    is skipped. Every other paper is looked up on arXiv and processed with
    ``get_graph_construc_info``.

    Args:
        init_graph: Graph whose ``entity`` table supports boolean indexing and
            ``to_dict(orient='records')`` (presumably a pandas DataFrame). The
            columns read are ``'entity type'``, ``'entity name'``,
            ``'description'`` and ``'timestamp'``.

    Returns:
        A list of dicts (``title``, ``date``, ``abstract`` plus the extracted
        graph info) for papers whose info could be obtained. Per-paper errors
        are logged and skipped.
    """
    papers = init_graph.entity[init_graph.entity['entity type']=="paper"].to_dict(orient='records')
    results = []

    print("---------------------------------------------")
    print("Downloading papers mentioned in related work:")
    for p in papers:
        # Compare the entity NAME (not the whole record dict) with the graph's
        # own title, so the source paper itself is not listed.
        if p.get('entity name') != init_graph.title:
            print(f"- {p.get('entity name')}")

    for p in papers:
        try:
            p_title = p['entity name']
            p_abstract = p['description']
            p_year = p['timestamp']
            rd = {
                "title": p_title,
                "date": p_year,
                "abstract": p_abstract
            }
            if p_title != init_graph.title:
                arxiv_id = get_arxiv_id(p_title)
                if arxiv_id:
                    temp = get_graph_construc_info(p_title, arxiv_id)
                    if temp:
                        rd.update(temp)
                        results.append(rd)
                else:
                    print(f"Cannot find the paper titled *{p_title}* on arxiv")
        except Exception as e:
            print(f"Exception occurred when get the paper mentioned in related work: {e}, paper info: {p}")
            continue
    print("---------------------------------------------------")
    return results

            

if __name__ == "__main__":
    import json
    # Smoke test. extend() requires the shared "seen titles" list and the
    # "already searched queries" list, so both start empty here.
    res = extend(
        "in-context learning",
        ['Large Language Model reasoning ability', 'math word problem'],
        2022,
        [],
        [],
    )
    with open('./test_extend.json', 'w') as f:
        json.dump(res, f, indent=4)