#!/usr/bin/env python3

import os
import argparse
import json
import csv
import time
import re
import requests
from pathlib import Path
import openreview
from typing import List, Dict

class OpenReviewDownloader:
    def __init__(self, output_dir: str = "papers", max_papers: int = None):
        """
        Set up the OpenReview client, the output directory and bookkeeping state

        Args:
            output_dir: Directory that receives PDFs and JSON metadata; it is
                created (including missing parents) when it does not exist
            max_papers: Cap applied by crawl_and_download; a falsy value
                (None or 0) disables the cap
        """
        # The API client is created first, exactly as before, so a failure here
        # happens before anything is written to disk.
        self.client = openreview.Client(baseurl='https://api.openreview.net')

        target_dir = Path(output_dir)
        target_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir = target_dir

        # Conferences scanned by the topic search, in priority order
        self.top_venues = ['ICLR', 'NeurIPS', 'ICML']

        # Bookkeeping
        self.max_papers = max_papers
        self.downloaded_count = 0
        self.error_ids = []  # paper IDs whose PDF download failed on every attempt
        
    def generate_venue_years(self, start_year: int = 2025, years_back: int = 10):
        """
        Build the (venue, year) search schedule, newest year first

        Within each year the venues follow the order of self.top_venues.
        NeurIPS editions up to and including 2017 are listed under the
        conference's former name, NIPS.

        Args:
            start_year: Most recent year to include
            years_back: Number of consecutive years to cover, counting start_year

        Returns:
            List of [(venue, year), ...]
        """
        def label_for(venue, year):
            # The conference was still called NIPS through its 2017 edition
            if venue == 'NeurIPS' and year <= 2017:
                return 'NIPS'
            return venue

        oldest_exclusive = start_year - years_back
        return [
            (label_for(venue, year), year)
            for year in range(start_year, oldest_exclusive, -1)
            for venue in self.top_venues
        ]
        
    def search_papers_by_venue(self, venue: str, keywords: List[str] = None) -> List[Dict]:
        """
        Search papers by conference name
        
        Args:
            venue: Conference name, e.g. "ICLR 2024", "NeurIPS 2023"
            keywords: List of keywords to filter papers
            
        Returns:
            List of paper information
        """
        print(f"Searching conference: {venue}")
        
        try:
            submissions = []
            # Try multiple possible invitation formats
            venue_clean = venue.replace(' ', '')
            venue_parts = venue.split()
            
            possible_invitations = [
                # Most common format (Blind_Submission is usually correct)
                f"{venue_parts[0]}.cc/{venue_parts[1]}/Conference/-/Blind_Submission" if len(venue_parts) >= 2 else None,
                f"{venue_parts[0]}.cc/{venue_parts[1]}/Conference/-/Submission" if len(venue_parts) >= 2 else None,
                f"{venue_parts[0]}.cc/{venue_parts[1]}/Conference/Submission" if len(venue_parts) >= 2 else None,
                # Standard format
                f"{venue_clean}.cc/-/Submission",
                f"{venue_clean}.cc/Submission",
                f"{venue_clean}/-/Submission", 
                f"{venue_clean}/Submission",
                f"{venue_clean}/Paper",
                # Other possible formats
                f"{venue_parts[0]}/{venue_parts[1]}/Conference/-/Submission" if len(venue_parts) >= 2 else None,
                f"{venue_parts[0]}/{venue_parts[1]}/-/Submission" if len(venue_parts) >= 2 else None,
            ]
            
            # Filter None values
            possible_invitations = [inv for inv in possible_invitations if inv is not None]
            
            print(f"Number of invitation formats to try: {len(possible_invitations)}")
            
            for i, invitation in enumerate(possible_invitations):
                try:
                    print(f"Trying {i+1}/{len(possible_invitations)}: {invitation}")
                    submissions = self.client.get_notes(invitation=invitation, limit=1000)
                    if submissions:
                        print(f"✓ Found conference invitation: {invitation} (papers: {len(submissions)})")
                        # Get all papers, not just the first 1000
                        if len(submissions) == 1000:
                            print("  There might be more papers, trying to get all...")
                            submissions = self.client.get_notes(invitation=invitation)
                            print(f"  Actual total papers: {len(submissions)}")
                        break
                except Exception as e:
                    print(f"  ✗ Failed: {str(e)[:80]}...")
                    continue
            
            # If invitation method fails, try searching by venue content
            if not submissions:
                print("Not found via invitation, trying venue content search...")
                try:
                    submissions = self.client.get_notes(content={'venue': venue}, limit=1000)
                    if submissions:
                        print(f"✓ Found {len(submissions)} papers via venue content search")
                except Exception as e:
                    print(f"Venue content search also failed: {e}")
            
            if not submissions:
                print(f"No papers found for conference {venue}")
                return []
                
            print(f"Found {len(submissions)} papers")
            
            papers = []
            for note in submissions:
                # Filter accepted papers (if decision info exists)
                if hasattr(note, 'content') and note.content:
                    paper_info = {
                        'id': note.id,
                        'title': note.content.get('title', ''),
                        'abstract': note.content.get('abstract', ''),
                        'authors': note.content.get('authors', []),
                        'pdf_url': note.content.get('pdf', ''),
                        'venue': venue
                    }
                    
                    # Keyword filtering
                    if keywords:
                        text_to_search = f"{paper_info['title']} {paper_info['abstract']}".lower()
                        if not any(keyword.lower() in text_to_search for keyword in keywords):
                            continue
                    
                    papers.append(paper_info)
                    
            print(f"After filtering, obtained {len(papers)} relevant papers")
            return papers
            
        except Exception as e:
            print(f"Error searching papers: {e}")
            return []
    
    def download_pdf(self, paper_info: Dict, retries: int = 3) -> bool:
        """
        Download paper PDF
        
        Args:
            paper_info: Paper information dictionary
            retries: Number of retries
            
        Returns:
            Whether download was successful
        """
        paper_id = paper_info['id']
        title = paper_info['title']
        pdf_url = paper_info.get('pdf_url', '')
        
        # Clean filename
        safe_title = re.sub(r'[^\w\s-]', '', title)[:100]
        safe_title = re.sub(r'[-\s]+', '_', safe_title)
        filename = f"{paper_id}_{safe_title}.pdf"
        filepath = self.output_dir / filename
        
        # Check if file already exists
        if filepath.exists():
            print(f"File already exists: {filename}")
            return True
        
        # If no direct PDF URL, try to get via API
        if not pdf_url:
            try:
                note = self.client.get_note(paper_id)
                if hasattr(note, 'content') and 'pdf' in note.content:
                    pdf_url = note.content['pdf']
            except Exception as e:
                print(f"Failed to get PDF link for paper {paper_id}: {e}")
                return False
        
        if not pdf_url:
            print(f"Paper {paper_id} has no PDF link")
            return False
        
        # If PDF URL is relative path, convert to absolute path
        if pdf_url.startswith('/'):
            pdf_url = f"https://openreview.net{pdf_url}"
        
        # Download PDF
        for attempt in range(retries):
            try:
                print(f"Downloading: {title} ({paper_id}) - Attempt {attempt + 1}/{retries}")
                
                headers = {
                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
                }
                
                response = requests.get(pdf_url, headers=headers, timeout=30, stream=True)
                response.raise_for_status()
                
                # Check content type
                content_type = response.headers.get('content-type', '')
                if 'application/pdf' not in content_type:
                    print(f"Warning: Downloaded file may not be PDF format (Content-Type: {content_type})")
                
                # Save file
                with open(filepath, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                
                file_size = filepath.stat().st_size
                if file_size < 1024:  # File smaller than 1KB, might be error page
                    filepath.unlink()
                    print(f"Downloaded file too small ({file_size} bytes), might be error page")
                    continue
                
                print(f"Download successful: {filename} ({file_size / 1024:.1f} KB)")
                return True
                
            except Exception as e:
                print(f"Download failed (attempt {attempt + 1}/{retries}): {e}")
                if filepath.exists():
                    filepath.unlink()
                time.sleep(2)  # Wait before retry
        
        self.error_ids.append(paper_id)
        return False
    
    def crawl_and_download(self, venue: str, keywords: List[str] = None) -> None:
        """
        Crawl and download papers from specified conference
        
        Args:
            venue: Conference name
            keywords: Keyword filter list
        """
        papers = self.search_papers_by_venue(venue, keywords)
        
        if not papers:
            print("No papers matching the criteria found")
            return
        
        # Limit download count
        if self.max_papers and len(papers) > self.max_papers:
            papers = papers[:self.max_papers]
            print(f"Limiting download count to {self.max_papers} papers")
        
        # Save paper information to JSON file
        info_file = self.output_dir / f"{venue.replace(' ', '_')}_papers_info.json"
        with open(info_file, 'w', encoding='utf-8') as f:
            json.dump(papers, f, ensure_ascii=False, indent=2)
        print(f"Paper information saved to: {info_file}")
        
        # Download PDF files
        success_count = 0
        start_time = time.time()
        
        for i, paper in enumerate(papers):
            print(f"\nProgress: {i + 1}/{len(papers)}")
            
            if self.download_pdf(paper):
                success_count += 1
            
            # Control download speed to avoid stressing the server
            time.sleep(1)
        
        end_time = time.time()
        
        print("\nDownload completed!")
        print(f"Successfully downloaded: {success_count}/{len(papers)} papers")
        print(f"Total time: {end_time - start_time:.1f} seconds")
        
        # Save error ID list
        if self.error_ids:
            error_file = self.output_dir / 'error_ids.json'
            with open(error_file, 'w', encoding='utf-8') as f:
                json.dump(self.error_ids, f, ensure_ascii=False, indent=2)
            print(f"Download failed paper IDs saved to: {error_file}")
    
    def search_papers_by_topics(self, topics: List[str], target_count: int = 35, start_year: int = 2025) -> List[Dict]:
        """
        Search papers by topic keywords in top conferences, searching from newest to oldest year until target count is reached
        
        Args:
            topics: List of topic keywords
            target_count: Target number of papers
            start_year: Starting search year
            
        Returns:
            List of paper information
        """
        print(f"Searching topics: {topics}")
        print(f"Target count: {target_count} papers")
        
        all_papers = []
        venue_years = self.generate_venue_years(start_year)
        
        for venue, year in venue_years:
            if len(all_papers) >= target_count:
                print(f"Reached target count {target_count}, stopping search")
                break
                
            venue_name = f"{venue} {year}"
            print(f"\nSearching: {venue_name}")
            
            papers = self.search_papers_by_venue(venue_name, topics)
            if papers:
                print(f"Found {len(papers)} relevant papers in {venue_name}")
                all_papers.extend(papers)
                print(f"Current total: {len(all_papers)} papers")
            else:
                print(f"No relevant papers found in {venue_name}")
        
        # If exceeds target count, truncate to first target_count papers
        if len(all_papers) > target_count:
            all_papers = all_papers[:target_count]
            print(f"Truncating to first {target_count} papers")
        
        print(f"\nSearch completed! Total found {len(all_papers)} papers")
        return all_papers
    
    def download_papers_by_topics(self, topics: List[str], target_count: int = 35, start_year: int = 2025) -> None:
        """
        Search and download papers by topics
        
        Args:
            topics: List of topic keywords
            target_count: Target number of papers
            start_year: Starting search year
        """
        papers = self.search_papers_by_topics(topics, target_count, start_year)
        
        if not papers:
            print("No papers matching the criteria found")
            return
        
        # Save paper information to JSON file
        topics_str = "_".join(topics)
        info_file = self.output_dir / f"papers_{topics_str}_info.json"
        with open(info_file, 'w', encoding='utf-8') as f:
            json.dump(papers, f, ensure_ascii=False, indent=2)
        print(f"Paper information saved to: {info_file}")
        
        # Download PDF files
        success_count = 0
        start_time = time.time()
        
        for i, paper in enumerate(papers):
            print(f"\nProgress: {i + 1}/{len(papers)}")
            
            if self.download_pdf(paper):
                success_count += 1
            
            # Control download speed to avoid stressing the server
            time.sleep(1)
        
        end_time = time.time()
        
        print("\nDownload completed!")
        print(f"Successfully downloaded: {success_count}/{len(papers)} papers")
        print(f"Total time: {end_time - start_time:.1f} seconds")
        
        # Save error ID list
        if self.error_ids:
            error_file = self.output_dir / 'error_ids.json'
            with open(error_file, 'w', encoding='utf-8') as f:
                json.dump(self.error_ids, f, ensure_ascii=False, indent=2)
            print(f"Download failed paper IDs saved to: {error_file}")
    
    def download_from_file(self, file_path: str) -> None:
        """
        Batch download papers from file
        
        Args:
            file_path: Path to file containing paper IDs or information
        """
        file_ext = os.path.splitext(file_path)[-1].lower()
        
        if file_ext == '.csv':
            papers = []
            with open(file_path, 'r', encoding='utf-8') as f:
                reader = csv.reader(f)
                for row in reader:
                    if row and row[0]:
                        papers.append({'id': row[0].strip()})
        elif file_ext == '.json':
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
                if isinstance(data, list):
                    if isinstance(data[0], str):
                        papers = [{'id': paper_id} for paper_id in data]
                    else:
                        papers = data
                else:
                    papers = [data]
        else:
            raise ValueError(f"Unsupported file format: {file_ext}")
        
        print(f"Loaded {len(papers)} papers from file {file_path}")
        
        # Supplement paper information and download
        success_count = 0
        for i, paper in enumerate(papers):
            print(f"\nProgress: {i + 1}/{len(papers)}")
            
            # If only ID, try to get complete information
            if 'title' not in paper:
                try:
                    note = self.client.get_note(paper['id'])
                    paper.update({
                        'title': note.content.get('title', ''),
                        'abstract': note.content.get('abstract', ''),
                        'authors': note.content.get('authors', []),
                        'pdf_url': note.content.get('pdf', ''),
                    })
                except Exception as e:
                    print(f"Failed to get paper {paper['id']} information: {e}")
                    continue
            
            if self.download_pdf(paper):
                success_count += 1
            
            time.sleep(1)
        
        print("\nBatch download completed!")
        print(f"Successfully downloaded: {success_count}/{len(papers)} papers")
    
    def download_single_paper(self, paper_id: str) -> None:
        """
        Download single paper
        
        Args:
            paper_id: Paper ID
        """
        try:
            note = self.client.get_note(paper_id)
            paper_info = {
                'id': paper_id,
                'title': note.content.get('title', ''),
                'abstract': note.content.get('abstract', ''),
                'authors': note.content.get('authors', []),
                'pdf_url': note.content.get('pdf', ''),
            }
            
            if self.download_pdf(paper_info):
                print("Paper download successful!")
            else:
                print("Paper download failed!")
                
        except Exception as e:
            print(f"Failed to download paper {paper_id}: {e}")



def main():
    """
    Command-line entry point of the OpenReview paper download tool

    Modes (selected with --request):
      * topics   - search the top conferences for --topics and download matches
      * crawl    - process every conference given via --venues / --venue,
                   optionally filtered by --topics / --keywords and capped by
                   --max_papers
      * download - download one paper (--paper_id) or a batch (--paper_file)

    Missing mode-specific arguments end the program through parser.error.
    Errors raised while running a mode are printed rather than propagated.
    """
    parser = argparse.ArgumentParser(description='OpenReview paper download tool')
    parser.add_argument('--request', choices=['crawl', 'download', 'topics'], required=True,
                        help='Operation type: crawl=crawl conference papers, download=download from file/ID, topics=search top conferences by topics')
    parser.add_argument('--venues', nargs='+', help='Conference name list, e.g. "ICLR 2024" "NeurIPS 2023" "ICML 2024"')
    parser.add_argument('--topics', nargs='+', help='Topic keyword list for filtering papers')
    parser.add_argument('--target_count', type=int, default=35, help='Target number of papers, default 35')
    parser.add_argument('--start_year', type=int, default=2025, help='Starting search year, default 2025')
    parser.add_argument('--venue', help='Compatible old parameter, single conference name')
    parser.add_argument('--keywords', nargs='+', help='Compatible old parameter, keyword list')
    parser.add_argument('--paper_id', help='Single paper ID')
    parser.add_argument('--paper_file', help='CSV or JSON file containing paper IDs')
    parser.add_argument('--output_dir', default='papers', help='Output directory')
    parser.add_argument('--max_papers', type=int, help='Maximum number of papers to download')

    args = parser.parse_args()

    # Parameter validation
    if args.request == 'crawl' and not (args.venues or args.venue):
        parser.error("crawl mode requires --venues or --venue parameter")

    if args.request == 'topics' and not args.topics:
        parser.error("topics mode requires --topics parameter")

    if args.request == 'download' and not (args.paper_id or args.paper_file):
        parser.error("download mode requires --paper_id or --paper_file parameter")

    # Prefer --topics for topic keywords, otherwise compatible with --keywords
    topics = args.topics or args.keywords

    # Prefer --venues for conference list, otherwise compatible with --venue
    venues = args.venues or ([args.venue] if args.venue else [])

    # Create downloader
    downloader = OpenReviewDownloader(args.output_dir, args.max_papers)

    try:
        if args.request == 'topics':
            # Topic search mode
            downloader.download_papers_by_topics(topics, args.target_count, args.start_year)

        elif args.request == 'crawl':
            # Multi-conference loop; crawl_and_download searches the venue,
            # applies --max_papers, stores the metadata JSON, downloads the
            # PDFs and records failed IDs.
            for venue in venues:
                print(f"\n==== Processing conference: {venue} ====")
                downloader.crawl_and_download(venue, topics)

        elif args.request == 'download':
            if args.paper_id:
                downloader.download_single_paper(args.paper_id)
            elif args.paper_file:
                downloader.download_from_file(args.paper_file)

    except KeyboardInterrupt:
        print("\nDownload interrupted by user")
    except Exception as e:
        print(f"Program execution error: {e}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()