import pandas as pd
import json
import requests
from tqdm import tqdm
# from .utils.logger import get_logger
# from .utils.constants import DIR_PATH

from utils.logger import get_logger
from utils.constants import DIR_PATH

# Output directory: GraphExtender.extend() writes entity.csv and relation.csv here.
FILE_PATH = DIR_PATH / "demo"

class GraphExtender:
    """Extend an entity/relation graph with papers from Semantic Scholar (S2).

    ``entity`` and ``relation`` are pandas DataFrames that are mutated in
    place.  The code reads/writes the entity columns ``"entity name"``,
    ``"entity type"``, ``"timestamp"`` and ``"description"`` and the relation
    columns ``"entity1"``, ``"relation"`` and ``"entity2"``.  ``title`` is the
    title of the paper the graph was built from; it is used to look up
    recommendations.
    """

    # Seconds to wait between retries of a non-200 response, so that a
    # rate-limited API is not hammered in a tight loop.
    RETRY_DELAY = 1.0

    def __init__(self, entity, relation, title) -> None:
        self.entity = entity
        self.relation = relation
        self.logger = get_logger("extend graph")
        self.title = title

    def create_entity_relation(self, entity1, name, date, description):
        """Append one 'paper' entity and one 'related to' relation in place.

        Args:
            entity1: Name of the existing node the new paper is linked from.
            name: Name (title) of the new paper node.
            date: Value stored in the ``"timestamp"`` column.
            description: Value stored in the ``"description"`` column.

        NOTE(review): the new row is stored under the label ``len(df)``; this
        presumes both frames carry a default 0..n-1 index, otherwise an
        existing row with that label would be overwritten - confirm at caller.
        """
        temp_entity = {
            "entity name": name,
            "entity type": 'paper',
            "timestamp": date,
            "description": description
        }
        temp_relation = {
            "entity1": entity1,
            "relation": "related to",
            "entity2": name
        }
        self.entity.loc[len(self.entity)] = temp_entity
        self.relation.loc[len(self.relation)] = temp_relation

    @staticmethod
    def _papers_from_response(search_res):
        """Return the paper list carried by a decoded S2 response, or None.

        The search endpoint stores its results under ``"data"`` while the
        recommendations endpoint uses ``"recommendedPapers"``; both are
        accepted so that either endpoint can be queried through
        ``get_paper_from_s2``.
        """
        if not isinstance(search_res, dict):
            return None
        for key in ("data", "recommendedPapers"):
            if key in search_res:
                return search_res[key]
        return None

    def get_paper_from_s2(self,
                          _query_param,
                          url='https://api.semanticscholar.org/graph/v1/paper/search',
                          timeout=30):
        """Query a Semantic Scholar endpoint and return its list of papers.

        Args:
            _query_param: Query parameters merged over the defaults
                (``fields`` = title/publicationDate/abstract/year, ``limit`` = 10).
            url: Endpoint to call; defaults to the paper search endpoint.
            timeout: Per-request timeout in seconds, so a stalled connection
                cannot block the caller forever.

        Returns:
            The list of paper dicts found in the response, or ``None`` when
            the service keeps answering with a non-200 status, the response
            carries no paper list, or every attempt raised.
        """
        import time  # local import: only needed for the pause between retries

        query_param = {
            "fields": "title,publicationDate,abstract,year",
            "limit": 10,
        }
        query_param.update(_query_param)
        # Human-readable description of the request, used only for logging.
        tempq = query_param['query'] if 'query' in query_param else json.dumps(query_param)
        max_trial = 10
        trial = 0  # shared across attempts: total budget of status retries
        for attempt in range(3):
            try:
                searchResponse = requests.get(url, params=query_param, timeout=timeout)
                while searchResponse.status_code != 200 and trial < max_trial:
                    time.sleep(self.RETRY_DELAY)
                    searchResponse = requests.get(url, params=query_param, timeout=timeout)
                    trial += 1
                if searchResponse.status_code != 200:
                    self.logger.warning("Cannot connect to Semantic Scholar.")
                    return None
                papers = self._papers_from_response(searchResponse.json())
                if papers is None:
                    self.logger.info(f"No related articles found for the query: {tempq}")
                return papers
            except Exception as e:
                # Deliberate best-effort boundary: a network or decoding
                # failure must not abort the whole graph extension.
                self.logger.warning(f"attemp {attempt+1}/3 - request failure: {e}")
        return None

    def _add_papers(self, source, papers):
        """Add every usable paper as a node linked to ``source``.

        A paper is usable when it has a title, an abstract and a date
        (``publicationDate``, falling back to ``year``).

        Returns:
            The number of nodes that were added.
        """
        added = 0
        for paper in papers:
            title = paper.get('title')
            date = paper.get('publicationDate') or paper.get('year')
            abstract = paper.get('abstract')
            if not (title and abstract and date):
                continue
            self.create_entity_relation(source, title, date, abstract)
            added += 1
        return added

    def extend_by_keyword(self):
        """Search S2 with every problem/method entity name and add the hits."""
        new_node_cnt = 0
        self.logger.info("trying to extend the graph by keyword search...")
        problems = self.entity[self.entity['entity type'] == 'problem']['entity name'].values.tolist()
        methods = self.entity[self.entity['entity type'] == 'method']['entity name'].values.tolist()
        # map(str, ...) keeps the log line from crashing on non-string names.
        self.logger.info(f"Obtain the entities of problem and method:\nproblem: {', '.join(map(str, problems))}\nmethod: {', '.join(map(str, methods))},\nand sequentially use these entity names as keywords for retrieval.")
        # Drop repeated keywords (order preserved) so the same query is not
        # sent twice and the same papers are not added twice for one keyword.
        kw_list = list(dict.fromkeys(problems + methods))
        for q in tqdm(kw_list, desc="extend the graph by keywords"):
            paper_info = self.get_paper_from_s2({"query": q})
            if not paper_info:
                continue
            new_node_cnt += self._add_papers(q, paper_info)
        self.logger.info(f"{new_node_cnt} entity nodes are added.")

    def _find_paper_id(self, title):
        """Return the S2 paperId of the paper named ``title``, or None.

        Only the top search hit is considered, and it is accepted only when
        its title equals ``title`` ignoring case and surrounding whitespace.
        """
        data = self.get_paper_from_s2({"query": title, "limit": 1, "fields": "title"})
        if data:
            found = data[0].get('title') or ""
            if found.strip().casefold() == str(title).strip().casefold():
                return data[0].get('paperId')
        self.logger.info(f"failed to get paperID for {title}, unable to provide recommendations.")
        return None

    def extend_by_recommendation(self):
        """Add the papers S2 recommends for ``self.title`` as new nodes."""
        self.logger.info("trying to extend the graph by s2 recommendation...")
        pid = self._find_paper_id(self.title)
        if not pid:
            return
        recommendations = self.get_paper_from_s2(
            _query_param={},
            url="https://api.semanticscholar.org/recommendations/v1/papers/forpaper/" + pid,
        )
        if not recommendations:
            self.logger.info("no recommendation found.")
            return
        new_node_cnt = self._add_papers(self.title, recommendations)
        self.logger.info(f"{new_node_cnt} entity nodes are added.")

    def extend(self):
        """Run both extension strategies, then write the frames to CSV."""
        self.extend_by_keyword()
        self.extend_by_recommendation()
        # update entity and relation
        self.entity.to_csv(FILE_PATH / "entity.csv", index=False)
        self.relation.to_csv(FILE_PATH / "relation.csv", index=False)
        
if __name__ == "__main__":
    # Manual smoke test: look up one paper by title on the S2 search endpoint.
    url = 'https://api.semanticscholar.org/graph/v1/paper/search'
    query_param = {
        "fields": "title,publicationDate,abstract,year",
        "limit": 10,
    }
    _query_param = {"query": 'MedAgents: Large Language Models as Collaborators for Zero-shot Medical Reasoning', "limit":1, "fields": "title"}
    query_param.update(_query_param)

    # Bounded retries with a timeout: the script must terminate even when the
    # API is unreachable or keeps answering with a non-200 status.
    max_trial = 10
    searchResponse = requests.get(url, params=query_param, timeout=30)
    for _ in range(max_trial):
        if searchResponse.status_code == 200:
            break
        searchResponse = requests.get(url, params=query_param, timeout=30)
    if searchResponse.status_code != 200:
        raise SystemExit(
            f"Semantic Scholar returned HTTP {searchResponse.status_code} "
            f"after {max_trial} retries."
        )

    res = searchResponse.json()