import numpy as np
import string
import csv
import pandas as pd
import time
import urllib
import urllib.request
import sys
import argparse
import wget
import json
import re
import tarfile
from pathlib import Path
import os
import time
import arxiv

# Command-line interface.
#   --request crawl (default): query the arXiv API for papers whose abstract
#       contains every --keywords entry and that were published after --time.
#   any other --request value: download LaTeX sources, either for a single
#       paper (--arxiv_id) or for every entry of a CSV/JSON list (--arxiv_file),
#       extracting them under --output_dir.
parser = argparse.ArgumentParser(description='Get arxiv data.')
parser.add_argument('--request', default='crawl',
                    help="mode: 'crawl' searches arXiv by keyword; "
                         "any other value downloads paper sources")
parser.add_argument('--arxiv_id', default='',
                    help='arXiv id of a single paper whose source is downloaded')
parser.add_argument('--arxiv_file', default='',
                    help='CSV or JSON file listing the papers whose sources are downloaded')
parser.add_argument('--keywords', nargs='+',
                    help='a list of keywords that must all appear in the abstract')
parser.add_argument('--time',
                    help='start time of papers; only later publications are kept')
parser.add_argument('--output_dir',
                    help='directory under which downloaded sources are extracted')
args = parser.parse_args()
# Working directories (relative to the current directory) used by the
# download modes.
_DOWNLOAD_DIR = "download_arxiv"
_SINGLE_FILE_DIR = "single_latex"
# Directories searched to decide that a paper has already been fetched.
_KNOWN_PAPER_DIRS = ("papers", _SINGLE_FILE_DIR, _DOWNLOAD_DIR, "papers_1", "papers_2")
_EPRINT_URL = 'https://arxiv.org/e-print/'
# Only list entries with an index below this cap are considered per run.
# NOTE(review): the cap is applied to the list index, so entries that are
# skipped as "already exists" still count towards it -- confirm this is intended.
_MAX_PAPERS_PER_RUN = 50

# Literal markers located in the Atom feed returned by the arXiv API.
_PDF_LINK_OPEN = '<link title="pdf" href='
_PDF_LINK_CLOSE = 'rel="related" type="application/pdf"/>'
_PUBLISHED_OPEN = '<published>'
_PUBLISHED_CLOSE = '</published>'


def _find_all(text, marker):
    """Return the start offset of every non-overlapping occurrence of *marker* in *text*."""
    offsets = []
    position = text.find(marker)
    while position != -1:
        offsets.append(position)
        position = text.find(marker, position + len(marker))
    return offsets


def _download_eprint(arxiv_id):
    """Download the e-print of *arxiv_id* into the download directory.

    Returns the path reported by ``wget.download``; it is used for all later
    operations on the file because it may differ from the requested path.
    Any exception raised by the download propagates to the caller.
    """
    os.makedirs(_DOWNLOAD_DIR, exist_ok=True)
    return wget.download(_EPRINT_URL + arxiv_id, out=_DOWNLOAD_DIR + "/" + arxiv_id)


def _extract_tarball(archive_path, dest_dir):
    """Extract the tar archive at *archive_path* into *dest_dir*.

    Raises whatever ``tarfile`` raises when the file is not a readable tar
    archive or when a member is rejected by the extraction filter.
    """
    with tarfile.open(archive_path) as archive:
        # The archive comes from the network: where the interpreter supports
        # extraction filters, refuse members that would escape dest_dir.
        if hasattr(tarfile, "data_filter"):
            archive.extractall(dest_dir, filter="data")
        else:
            archive.extractall(dest_dir)


def _rewrite_input_file(path, file_ext, papers, csv_rows):
    """Write the entries of *papers* that are still truthy back to *path*.

    For CSV input the remaining columns of each original row (*csv_rows*,
    parallel to *papers*) are preserved so that rewriting does not drop data.
    """
    if file_ext == '.csv':
        with open(path, 'w', encoding='utf-8', newline='') as fcsv:
            writer = csv.writer(fcsv)
            for item, row in zip(papers, csv_rows):
                if item:
                    writer.writerow([item] + row[1:])
    else:
        with open(path, 'w', encoding='utf-8') as fjson:
            json.dump([item for item in papers if item], fjson, ensure_ascii=False, indent=2)


if args.request == "crawl":
    # Crawl mode: query the arXiv API for papers whose abstract contains every
    # keyword and write the PDF link and publication time of those published
    # after --time to a CSV file in the current directory.
    baseurl = 'http://export.arxiv.org/api/query?search_query='
    keyword_list = args.keywords
    if not keyword_list:
        parser.error("--keywords is required when --request is 'crawl'")
    if not args.time:
        parser.error("--time is required when --request is 'crawl'")

    date = pd.Timestamp(str(args.time), tz='US/Pacific')

    for keyword in keyword_list:
        print(keyword)
    query = '+AND+'.join('abs:' + keyword for keyword in keyword_list)
    url = baseurl + query + '&max_results=200'

    try:
        with urllib.request.urlopen(url) as response:
            arxiv_page = response.read().decode('utf-8', errors='replace')
    except Exception as exc:
        # Best effort, as before, but the failure is reported instead of hidden.
        print(f"arXiv query failed for {url}: {exc!r}", file=sys.stderr)
    else:
        begin = _find_all(arxiv_page, _PDF_LINK_OPEN)
        end = _find_all(arxiv_page, _PDF_LINK_CLOSE)
        begin_time = _find_all(arxiv_page, _PUBLISHED_OPEN)
        end_time = _find_all(arxiv_page, _PUBLISHED_CLOSE)
        if not (len(begin) == len(end) == len(begin_time) == len(end_time)):
            print("warning: PDF link and <published> counts differ; "
                  "results may be incomplete", file=sys.stderr)

        rows = []
        # zip() stops at the shortest list, so a mismatch cannot raise IndexError.
        for i, (link_start, link_end, time_start, time_end) in enumerate(
                zip(begin, end, begin_time, end_time)):
            # +1 skips the opening quote of the href value; -2 drops the
            # closing quote and the space in front of the rel attribute.
            paper_link = arxiv_page[link_start + len(_PDF_LINK_OPEN) + 1:link_end - 2]
            paper_timestamp = arxiv_page[time_start + len(_PUBLISHED_OPEN):time_end]
            try:
                is_recent = pd.to_datetime(paper_timestamp) > date
            except (ValueError, TypeError) as exc:
                print(f"skipping entry {i}: bad timestamp {paper_timestamp!r} ({exc})",
                      file=sys.stderr)
                continue
            if is_recent:
                print(i)
                rows.append([paper_link, paper_timestamp])

        try:
            with open('arxiv_{}_{}.csv'.format(str(args.keywords), str(args.time)), 'w', newline='') as file:
                writer = csv.writer(file)
                writer.writerows(rows)
        except OSError as exc:
            print(f"could not write result file: {exc}", file=sys.stderr)
else:
    # Download mode: fetch LaTeX sources for one id or for a list of papers.
    if (args.arxiv_id or args.arxiv_file) and not args.output_dir:
        # Without this check the missing directory surfaced as a bogus
        # "extract archive failed" for every paper.
        parser.error("--output_dir is required when downloading sources")

    total = 0
    if args.arxiv_id:  # download the LaTeX source of a single paper
        _id = args.arxiv_id
        print(_id)
        try:
            tar_filename = _download_eprint(_id)
        except Exception as exc:
            print("error!!!!!!!!")
            print(f"download of {_id} failed: {exc!r}", file=sys.stderr)
        else:
            print('\nDownloaded arXiv:%s tarball as %s' % (_id, tar_filename))
            try:
                _extract_tarball(tar_filename, os.path.join(args.output_dir, _id))
            except Exception as exc:
                print("extract archive failed")
                print(f"extraction of {_id} failed: {exc!r}", file=sys.stderr)
                # The download is not a usable tar archive; keep it aside
                # under the single-file directory instead of discarding it.
                os.makedirs(_SINGLE_FILE_DIR, exist_ok=True)
                Path(tar_filename).rename(_SINGLE_FILE_DIR + "/" + _id)
            else:
                print("extract archive success")
    elif args.arxiv_file:
        paper_ids = []
        error_ids = []
        failed_extractions = 0
        # Decide between CSV and JSON from the file extension.
        file_ext = os.path.splitext(args.arxiv_file)[-1].lower()
        csv_rows = []
        if file_ext == '.csv':
            with open(args.arxiv_file, 'r', encoding='utf-8', newline='') as f:
                csv_rows = [row for row in csv.reader(f) if row and row[0]]
            papers = [row[0].strip() for row in csv_rows]
        else:
            with open(args.arxiv_file, 'r', encoding='utf-8') as f:
                papers = json.load(f)
        total = len(papers)
        # One client for the whole run instead of a new one per paper.
        client = arxiv.Client()
        start_time = time.time()
        for i, data in enumerate(papers):
            if i >= _MAX_PAPERS_PER_RUN:
                break
            # Entries may be full URLs; the id is the last path component.
            _id = data.split('/')[-1]
            if any(os.path.exists(directory + "/" + _id) for directory in _KNOWN_PAPER_DIRS):
                print(i, "/", total)
                print("already exists!!!")
                continue
            try:
                search = arxiv.Search(id_list=[_id])
                find_paper = next(client.results(search))
            except Exception as exc:
                print(f"{_id} search failed!")
                print(f"search error for {_id}: {exc!r}", file=sys.stderr)
                continue
            # Skip entries for which the API reports no PDF link.
            pdf_url = getattr(find_paper, 'pdf_url', None)
            if not pdf_url:
                print(f"{_id} no pdf url!")
                continue
            print(_id)
            try:
                print("download start！！")
                tar_filename = _download_eprint(_id)
            except Exception as exc:
                print("error!!!!")
                print(f"download of {_id} failed: {exc!r}", file=sys.stderr)
                error_ids.append(_id)
            else:
                print('\nDownloaded arXiv:%s tarball as %s' % (_id, tar_filename))
                try:
                    _extract_tarball(tar_filename, os.path.join(args.output_dir, _id))
                except Exception as exc:
                    failed_extractions += 1
                    print("extract archive failed")
                    print(f"extraction of {_id} failed: {exc!r}", file=sys.stderr)
                    # Delete the file that was just downloaded.
                    try:
                        os.remove(tar_filename)
                    except OSError as e2:
                        print(f"删除下载文件失败: {e2}")
                    # Drop the entry from the list and persist the shortened
                    # list back to the input file.
                    papers[i] = None
                    try:
                        _rewrite_input_file(args.arxiv_file, file_ext, papers, csv_rows)
                    except Exception as e3:
                        print(f"更新输入文件失败: {e3}")
                    continue
                else:
                    paper_ids.append(_id)
                    print("extract archive success")
            # Pause between downloads, then persist the failed ids so far.
            time.sleep(2)
            with open('error_ids.json', 'w', encoding='utf-8') as f:
                json.dump(error_ids, f)
        # Taken after the loop so it is defined however the loop ended.
        end_time = time.time()
        print(total)
        print(f"extract failures: {failed_extractions}")
        print(end_time - start_time)

