import calendar
import hashlib
import html
import json
import logging
import re
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import quote

import requests

from config import config

# Module-wide logger; every record from this file is emitted under "ckm.arxiv".
logger: logging.Logger = logging.getLogger("ckm.arxiv")


class ArxivPaper:
    """Plain attribute record for a single arXiv paper.

    The attribute names mirror the keys of the dicts built by
    ``parse_arxiv_xml``: ``arxiv_id``, ``published`` (timestamp string as
    given), ``title``, ``abstract`` and ``categories`` (stored as passed, not
    copied).
    """

    def __init__(self, arxiv_id: str, published: str, title: str, abstract: str, categories: List[str]):
        (
            self.arxiv_id,
            self.published,
            self.title,
            self.abstract,
            self.categories,
        ) = (arxiv_id, published, title, abstract, categories)


def parse_arxiv_xml(xml: str) -> List[Dict[str, Any]]:
    """Extract paper records from an arXiv Atom feed using regular expressions.

    Each ``<entry>`` element yields one dict with the keys ``arxiv_id``,
    ``published``, ``title``, ``abstract`` and ``categories``.

    * ``arxiv_id`` is the part of ``<id>`` after ``arxiv.org/abs/``; both
      ``http://`` and ``https://`` ids are accepted. It falls back to
      ``"unknown"`` when no such id is present.
    * ``published`` is the raw ``<published>`` text; it falls back to the
      current local time (ISO format) when missing.
    * ``title`` / ``abstract`` have XML character references decoded
      (``&amp;`` -> ``&``, ``&#233;`` -> ``é``). All whitespace runs are then
      collapsed to single spaces and the ends are stripped. Each is ``""``
      when the element is missing.
    * ``categories`` lists every ``<category term="...">``; it defaults to
      ``["cs.AI"]`` when the entry has none.

    Args:
        xml: Response body of an arXiv API query.

    Returns:
        One dict per entry, in document order (empty list if there are none).
    """
    papers = []
    entry_regex = r"<entry>([\s\S]*?)<\/entry>"

    for match in re.finditer(entry_regex, xml):
        entry_xml = match.group(1)

        # Accept both schemes so a switch to https ids doesn't collapse every
        # paper into the shared "unknown" id during de-duplication.
        id_match = re.search(r"<id>https?:\/\/arxiv\.org\/abs\/(.+?)<\/id>", entry_xml)
        arxiv_id = id_match.group(1) if id_match else "unknown"

        pub_match = re.search(r"<published>(.+?)<\/published>", entry_xml)
        published = pub_match.group(1) if pub_match else datetime.now().isoformat()

        title_match = re.search(r"<title>([\s\S]*?)<\/title>", entry_xml)
        title = _clean_feed_text(title_match.group(1)) if title_match else ""

        summary_match = re.search(r"<summary>([\s\S]*?)<\/summary>", entry_xml)
        abstract = _clean_feed_text(summary_match.group(1)) if summary_match else ""

        cat_regex = r"<category\s+term=\"([^\"]+)\""
        categories = [cat_match.group(1) for cat_match in re.finditer(cat_regex, entry_xml)]

        papers.append({
            "arxiv_id": arxiv_id,
            "published": published,
            "title": title,
            "abstract": abstract,
            "categories": categories if categories else ["cs.AI"],
        })

    return papers


def _clean_feed_text(raw: str) -> str:
    """Decode XML character references, then collapse whitespace runs and strip."""
    # Unescape first so encoded whitespace (e.g. "&#10;") is collapsed as well.
    return re.sub(r"\s+", " ", html.unescape(raw)).strip()


# Filler words dropped from a keyword before it becomes arXiv search terms
# (matched case-insensitively by the query builder).
STOPWORDS = set(
    "a an the in on of for and or to is are was were with by".split()
)
# strptime/strftime pattern for the date-only strings used throughout this module.
DATE_FMT = "%Y-%m-%d"
# HTTP statuses treated as transient by the retry loop.
RETRYABLE_STATUS_CODES = {408, 425, 429, 500, 502, 503, 504}


def _add_months(dt: datetime, months: int) -> datetime:
    """Shift ``dt`` by ``months`` calendar months, clamping the day of month.

    The day is clamped to the last day of the target month (e.g. Jan 31 + 1
    month -> Feb 28/29). Negative ``months`` move backwards. The time of day
    is preserved.
    """
    year_offset, month_index = divmod(dt.month - 1 + months, 12)
    target_year = dt.year + year_offset
    target_month = month_index + 1
    last_day = calendar.monthrange(target_year, target_month)[1]
    return dt.replace(year=target_year, month=target_month, day=min(dt.day, last_day))


def _normalize_bounds(
    start_date: Optional[str],
    end_date: Optional[str],
    start_year: Optional[int],
    end_year: Optional[int],
) -> Tuple[datetime, datetime]:
    """Resolve the requested search range to a half-open ``[start, end)`` pair.

    Explicit dates win over years. Only the first 10 characters of a date
    string are parsed with ``DATE_FMT``. ``end_date`` is used as-is as the
    exclusive end. ``end_year`` is inclusive, so the exclusive end becomes
    January 1st of the following year. Without input the range is 2019-01-01
    up to (but excluding) 2027-01-01.

    Raises:
        ValueError: if a date string does not match ``DATE_FMT`` or the range
            is empty.
    """
    start_dt = (
        datetime.strptime(start_date[:10], DATE_FMT)
        if start_date
        else datetime(start_year or 2019, 1, 1)
    )
    end_dt = (
        datetime.strptime(end_date[:10], DATE_FMT)
        if end_date
        else datetime((end_year or 2026) + 1, 1, 1)
    )

    if start_dt >= end_dt:
        raise ValueError(f"Invalid arXiv date range: {start_dt.date()} .. {end_dt.date()}")
    return start_dt, end_dt


def _iter_windows(start_dt: datetime, end_dt: datetime, window_months: int):
    """Lazily yield consecutive ``(window_start, window_end)`` pairs.

    The windows tile ``[start_dt, end_dt)``. Each one spans ``window_months``
    calendar months, except the last, which is cut off at ``end_dt``. Nothing
    is yielded when ``start_dt >= end_dt``.

    Raises:
        ValueError: if ``window_months`` is not positive. As with any
            generator, this is raised on first iteration.
    """
    if window_months <= 0:
        raise ValueError("window_months must be positive")

    window_start = start_dt
    while window_start < end_dt:
        window_end = _add_months(window_start, window_months)
        if window_end > end_dt:
            window_end = end_dt
        yield window_start, window_end
        window_start = window_end


def _format_query_lower_bound(dt: datetime) -> str:
    """Render ``dt`` as the ``YYYYMMDDHHMM`` stamp used in ``submittedDate`` ranges."""
    return f"{dt:%Y%m%d%H%M}"


def _format_query_upper_bound(end_exclusive: datetime) -> str:
    """Render the last minute before ``end_exclusive`` as a ``YYYYMMDDHHMM`` stamp.

    This lets an exclusive window end be used as the upper bound of a
    ``submittedDate`` range.
    """
    last_included_minute = end_exclusive - timedelta(minutes=1)
    return f"{last_included_minute:%Y%m%d%H%M}"


def _published_in_range(published: str, start_dt: datetime, end_dt: datetime) -> bool:
    """Return True when the date part of ``published`` falls in ``[start_dt, end_dt)``.

    The comparison is made on ``DATE_FMT`` strings, which sort in date order.
    Only the first 10 characters of ``published`` are used (presumably an
    ISO-8601 timestamp -- confirm against the feed). The time of day of the
    bounds is ignored.
    """
    lower_bound = start_dt.strftime(DATE_FMT)
    upper_bound = end_dt.strftime(DATE_FMT)
    return lower_bound <= published[:10] < upper_bound


def _build_query(keyword: str) -> str:
    """Turn a free-text keyword into an arXiv ``search_query`` expression.

    Each whitespace-separated word that is not in ``STOPWORDS``
    (case-insensitive) becomes an ``all:<word>`` clause. Clauses are joined
    with ``+AND+``, where ``+`` is the URL form of a space.

    Words are percent-encoded because the result is pasted verbatim into the
    request URL. Otherwise characters such as ``&``, ``#``, ``+`` or ``%`` in
    a keyword would split the query string, truncate it, or change its
    meaning. Plain alphanumeric words (and ``-``, ``_``, ``.``, ``~``) are
    left unchanged.

    Returns ``"all:*"`` when the keyword is blank or consists only of
    stopwords.
    """
    words = [w for w in keyword.strip().split() if w.lower() not in STOPWORDS]
    return "+AND+".join(f"all:{quote(w, safe='')}" for w in words) or "all:*"


def _cache_root() -> Path:
    """Return the arXiv metadata cache directory, creating it if needed.

    The location comes from ``config["paths"]["arxiv_metadata_cache"]``.
    """
    root = Path(config["paths"]["arxiv_metadata_cache"])
    root.mkdir(exist_ok=True, parents=True)
    return root


def _window_label(window_start: datetime, window_end: datetime) -> str:
    """Human-readable ``"start .. end"`` label (both as ``DATE_FMT``) for log lines."""
    return " .. ".join(bound.strftime(DATE_FMT) for bound in (window_start, window_end))


def _split_window(window_start: datetime, window_end: datetime) -> Optional[datetime]:
    """Pick a midnight split point strictly inside ``(window_start, window_end)``.

    The true midpoint is snapped back to midnight. If that lands on or before
    the start, it is pushed to ``start + 1 day``. If it then reaches the end,
    it is pulled back to ``end - 1 day``.

    Returns:
        The split point, or ``None`` when the window is too narrow to hold one.
    """
    one_day = timedelta(days=1)
    candidate = (window_start + (window_end - window_start) / 2).replace(
        hour=0, minute=0, second=0, microsecond=0
    )

    if candidate <= window_start:
        candidate = window_start + one_day
    if candidate >= window_end:
        candidate = window_end - one_day

    if window_start < candidate < window_end:
        return candidate
    return None


def _cache_file(
    keyword: str,
    window_start: datetime,
    window_end: datetime,
    offset: int,
    fetch_count: int,
) -> Path:
    """Return the cache path for one result page.

    The file name is the SHA-1 of a canonical (key-sorted) JSON encoding of
    the keyword, the window dates (``DATE_FMT``), the page offset and the page
    size. A page fetched with different parameters therefore never reuses
    another page's entry. The file lives under ``_cache_root()``, which is
    created if missing.
    """
    cache_key = {
        "fetch_count": fetch_count,
        "keyword": keyword,
        "offset": offset,
        "window_end": window_end.strftime(DATE_FMT),
        "window_start": window_start.strftime(DATE_FMT),
    }
    canonical = json.dumps(cache_key, sort_keys=True)
    digest = hashlib.sha1(canonical.encode("utf-8")).hexdigest()
    return _cache_root() / (digest + ".json")


def _load_cached_page(cache_file: Path) -> Optional[List[Dict[str, Any]]]:
    """Return the ``papers`` list stored in ``cache_file``, or ``None`` on a miss.

    A miss is any of: the file does not exist, it is unreadable or not valid
    JSON, or it has no list under ``"papers"``. Read errors are logged as
    warnings rather than raised, because the cache is only an optimisation.
    """
    if cache_file.exists():
        try:
            cached = json.loads(cache_file.read_text(encoding="utf-8"))
            cached_papers = cached.get("papers")
            if isinstance(cached_papers, list):
                logger.info("[ArXiv API] Cache hit: %s", cache_file.name)
                return cached_papers
        except Exception as exc:
            logger.warning("[ArXiv API] Cache read failed for %s: %s", cache_file, exc)
    return None


def _save_cached_page(cache_file: Path, url: str, papers: List[Dict[str, Any]]) -> None:
    """Persist one parsed result page to the on-disk cache (best effort).

    The JSON payload records the fetch time, the request URL, the paper count
    and the papers themselves.

    The payload is first written to a sibling ``<name>.tmp`` file and then
    moved over ``cache_file``. This way an interrupted write never leaves a
    truncated entry under the real name.

    Write failures (``OSError``) are logged and swallowed. The page has
    already been fetched successfully, so losing its cache entry only costs a
    refetch on a later run, whereas raising here would abort the whole search.
    """
    payload = {
        "fetched_at": datetime.now().isoformat(timespec="seconds"),
        "url": url,
        "count": len(papers),
        "papers": papers,
    }
    tmp_file = cache_file.with_name(cache_file.name + ".tmp")
    try:
        tmp_file.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
        tmp_file.replace(cache_file)
    except OSError as exc:
        logger.warning("[ArXiv API] Cache write failed for %s: %s", cache_file, exc)
        try:
            tmp_file.unlink()
        except OSError:
            # Nothing left to clean up (or cleanup is impossible); the stale
            # temp file is never read and is overwritten by the next write.
            pass


def _parse_retry_after_seconds(response: requests.Response) -> Optional[int]:
    """Read an integer ``Retry-After`` header, clamped to at least 1 second.

    Returns ``None`` when the header is absent or empty, or when it is not an
    integer (for instance the HTTP-date form).
    """
    header_value = response.headers.get("Retry-After")
    if header_value:
        try:
            seconds = int(header_value)
        except ValueError:
            return None
        return max(1, seconds)
    return None


def _request_with_retry(
    url: str,
    timeout_s: int,
    retries: int,
    user_agent: str,
    retry_backoff_s: int,
) -> Optional[requests.Response]:
    """GET ``url`` with bounded retries and exponential backoff.

    Args:
        url: Fully built request URL.
        timeout_s: Read timeout in seconds (the connect timeout is fixed at 10 s).
        retries: Total number of attempts; values below 1 are treated as 1.
        user_agent: Value of the ``User-Agent`` header.
        retry_backoff_s: Initial backoff in seconds (at least 1), doubled after
            every retry.

    Returns:
        * The response on HTTP 200.
        * The non-200 response that ended retrying, so the caller can inspect
          its status. This is either a status not in
          ``RETRYABLE_STATUS_CODES`` or a retryable status on the last attempt.
        * ``None`` when every attempt raised ``requests.RequestException``.

    No sleep happens after the final attempt.
    """
    # Always make at least one request, even if ``retries`` is misconfigured.
    attempts = max(1, retries)
    backoff_s = max(1, retry_backoff_s)

    for attempt in range(1, attempts + 1):
        try:
            response = requests.get(
                url,
                timeout=(10, timeout_s),
                headers={"User-Agent": user_agent},
            )
        except requests.RequestException as exc:
            logger.warning("[ArXiv API] attempt %d/%d failed: %s", attempt, attempts, exc)
            if attempt == attempts:
                return None
            time.sleep(backoff_s)
            backoff_s *= 2
            continue

        if response.status_code == 200:
            return response

        if response.status_code in RETRYABLE_STATUS_CODES and attempt < attempts:
            rate_limited = response.status_code == 429
            retry_after_s = _parse_retry_after_seconds(response)
            if rate_limited:
                # For 429, use longer backoff (arXiv rate limit window is ~30-60s)
                sleep_s = retry_after_s or max(30, backoff_s)
            else:
                sleep_s = retry_after_s or backoff_s
            logger.warning(
                "[ArXiv API] attempt %d/%d got HTTP %d, retrying in %ds",
                attempt,
                attempts,
                response.status_code,
                sleep_s,
            )
            time.sleep(sleep_s)
            backoff_s *= 2
            if rate_limited:
                # Floor at 30s only once we've actually been rate limited; other
                # transient errors (5xx, 408, 425) keep the plain doubling.
                backoff_s = max(backoff_s, 30)
            continue

        return response

    # Not reached: every iteration above returns or continues to the next attempt.
    return None


def _dedupe_papers(papers: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Keep the first occurrence of each ``arxiv_id`` and sort by ``published``.

    The sort is stable, so papers with equal ``published`` values keep their
    first-seen order.
    """
    first_seen: Dict[str, Dict[str, Any]] = {}
    for paper in papers:
        paper_id = paper["arxiv_id"]
        if paper_id not in first_seen:
            first_seen[paper_id] = paper

    ordered = list(first_seen.values())
    ordered.sort(key=lambda item: item["published"])
    return ordered


def _record_failure(
    failures: List[Dict[str, Any]],
    window_start: datetime,
    window_end: datetime,
    offset: int,
    detail: str,
) -> None:
    """Append a description of one unrecoverable slice to ``failures`` (in place).

    The entry holds both window bounds as ``DATE_FMT`` strings, the page
    offset and the failure ``detail``.
    """
    entry = dict(
        window_start=window_start.strftime(DATE_FMT),
        window_end=window_end.strftime(DATE_FMT),
        offset=offset,
        detail=detail,
    )
    failures.append(entry)


def _fetch_window(
    keyword: str,
    query_string: str,
    window_start: datetime,
    window_end: datetime,
    remaining_limit: int,
    arxiv_cfg: Dict[str, Any],
    failures: List[Dict[str, Any]],
    split_depth: int = 0,
    page_size: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """Fetch up to ``remaining_limit`` papers submitted in ``[window_start, window_end)``.

    Pages through the arXiv API in ascending submission-date order, serving
    pages from the on-disk cache when available. Every live request is
    followed by a ``delay_ms`` pause, whatever its outcome.

    When a live page request fails, the fetch degrades in two stages:

    1. The page size is halved (not below ``min_batch_size``) and the same
       offset is retried.
    2. If the failure happened at offset 0, the window spans at least
       ``min_split_days``, and ``split_depth < max_split_depth``, the window
       is split near its midpoint and each half is fetched recursively.

    If neither applies, the slice is appended to ``failures`` and whatever was
    collected so far is returned.

    Args:
        keyword: Raw user keyword; here it only keys the cache.
        query_string: Search expression placed in front of the date filter.
        window_start: Inclusive start of the submission-date window.
        window_end: Exclusive end of the submission-date window.
        remaining_limit: Maximum number of papers to return.
        arxiv_cfg: arXiv settings mapping (the tuning keys read below).
        failures: Output list; unrecoverable slices are appended in place.
        split_depth: Current recursion depth of window splitting.
        page_size: Starting page size, capped at ``batch_size`` (itself at most 100).

    Returns:
        Papers de-duplicated by ``arxiv_id`` and sorted by ``published``, at
        most ``remaining_limit`` of them.
    """
    if remaining_limit <= 0:
        return []

    timeout_s = arxiv_cfg.get("timeout_s", 60)
    retries = arxiv_cfg.get("retries", 3)
    retry_backoff_s = arxiv_cfg.get("retry_backoff_s", 5)
    delay_s = arxiv_cfg.get("delay_ms", 3000) / 1000
    user_agent = arxiv_cfg.get("user_agent", "CKM-Eval/1.0")
    min_page_size = max(1, min(arxiv_cfg.get("min_batch_size", 10), 100))
    max_page_size = min(arxiv_cfg.get("batch_size", 50), 100)
    page_size = min(page_size or max_page_size, max_page_size)
    min_split_days = arxiv_cfg.get("min_split_days", 31)
    max_split_depth = arxiv_cfg.get("max_split_depth", 4)

    # window_end is exclusive; the upper bound sent is the minute before it.
    date_query = (
        "submittedDate:["
        f"{_format_query_lower_bound(window_start)}+TO+{_format_query_upper_bound(window_end)}]"
    )
    full_query = f"({query_string})+AND+{date_query}"
    window_papers: List[Dict[str, Any]] = []
    current_start = 0

    while len(window_papers) < remaining_limit:
        fetch_count = min(page_size, remaining_limit - len(window_papers))
        url = (
            f"{config['api']['arxiv']['base_url']}?search_query={full_query}&start={current_start}"
            f"&max_results={fetch_count}&sortBy=submittedDate&sortOrder=ascending"
        )

        cache_file = _cache_file(keyword, window_start, window_end, current_start, fetch_count)
        papers = _load_cached_page(cache_file)

        if papers is None:
            logger.info(
                "[ArXiv API] Fetching %s at offset %d (page=%d, depth=%d)",
                _window_label(window_start, window_end),
                current_start,
                fetch_count,
                split_depth,
            )
            response = _request_with_retry(url, timeout_s, retries, user_agent, retry_backoff_s)
            # Pause after every live request, successful or not. The failure
            # branches below re-query right away (smaller page or split window),
            # and _request_with_retry does not wait after its final attempt, so
            # without this pause a failing window would be hit back-to-back.
            time.sleep(delay_s)

            if response is None or response.status_code != 200:
                status_detail = "no response"
                if response is not None:
                    status_detail = f"HTTP {response.status_code}"

                # Stage 1: retry the same offset with a smaller page.
                if fetch_count > min_page_size:
                    next_page_size = max(min_page_size, fetch_count // 2)
                    if next_page_size < fetch_count:
                        logger.warning(
                            "[ArXiv API] %s at offset %d, reducing page size %d -> %d",
                            status_detail,
                            current_start,
                            fetch_count,
                            next_page_size,
                        )
                        page_size = next_page_size
                        continue

                # Stage 2: split the window. Only allowed before any page of this
                # window succeeded, so nothing collected so far is discarded.
                span_days = (window_end - window_start).days
                midpoint = _split_window(window_start, window_end)
                if current_start == 0 and midpoint and span_days >= min_split_days and split_depth < max_split_depth:
                    logger.warning(
                        "[ArXiv API] %s for %s, splitting window at %s (depth=%d)",
                        status_detail,
                        _window_label(window_start, window_end),
                        midpoint.strftime(DATE_FMT),
                        split_depth + 1,
                    )
                    left_papers = _fetch_window(
                        keyword,
                        query_string,
                        window_start,
                        midpoint,
                        remaining_limit,
                        arxiv_cfg,
                        failures,
                        split_depth=split_depth + 1,
                    )
                    remaining_after_left = max(0, remaining_limit - len(left_papers))
                    right_papers = _fetch_window(
                        keyword,
                        query_string,
                        midpoint,
                        window_end,
                        remaining_after_left,
                        arxiv_cfg,
                        failures,
                        split_depth=split_depth + 1,
                    )
                    return _dedupe_papers(left_papers + right_papers)[:remaining_limit]

                logger.error(
                    "[ArXiv API] giving up on %s at offset %d (%s)",
                    _window_label(window_start, window_end),
                    current_start,
                    status_detail,
                )
                _record_failure(failures, window_start, window_end, current_start, status_detail)
                return _dedupe_papers(window_papers)

            papers = parse_arxiv_xml(response.text)
            _save_cached_page(cache_file, url, papers)

        if not papers:
            break

        window_papers.extend(papers)
        current_start += len(papers)

        # A short page means the window is exhausted.
        if len(papers) < fetch_count:
            break

    return _dedupe_papers(window_papers)[:remaining_limit]


def _log_failure_summary(failures: List[Dict[str, Any]], preview: int = 5) -> None:
    """Warn about failed slices: the total, the first ``preview`` entries, then a remainder count."""
    if not failures:
        return
    logger.warning("[ArXiv API] Completed with %d failed slices; results may be partial", len(failures))
    for failure in failures[:preview]:
        logger.warning(
            "[ArXiv API] failed slice: %s .. %s offset=%d detail=%s",
            failure["window_start"],
            failure["window_end"],
            failure["offset"],
            failure["detail"],
        )
    if len(failures) > preview:
        logger.warning("[ArXiv API] ... plus %d more failed slices", len(failures) - preview)


def search_arxiv_topic(
    keyword: str,
    max_results: int = 100,
    start_year: Optional[int] = None,
    end_year: Optional[int] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    window_months: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """Search arXiv for ``keyword`` over a date range, one time window at a time.

    Explicit ``start_date`` / ``end_date`` take precedence over ``start_year``
    / ``end_year`` when resolving the range. The range is walked in windows of
    ``window_months`` months (defaulting to the ``window_months`` config value,
    else 12). Windows are fetched in order until ``max_results`` unique papers
    are collected.

    Windows that fail even after page-size reduction and splitting are
    skipped and summarised in warnings. In that case the result may be
    partial rather than raising.

    Returns:
        Up to ``max_results`` paper dicts (as built by ``parse_arxiv_xml``),
        unique by ``arxiv_id``, limited to papers whose ``published`` date
        lies in the resolved range, and sorted by ``published``.

    Raises:
        ValueError: if the resolved date range is empty, a date string does not
            match ``DATE_FMT``, or the window size is not positive.
    """
    start_dt, end_dt = _normalize_bounds(start_date, end_date, start_year, end_year)
    arxiv_cfg = config["api"]["arxiv"]
    chunk_size = min(arxiv_cfg.get("batch_size", 50), 100)
    window_months = window_months or arxiv_cfg.get("window_months", 12)
    query_string = _build_query(keyword)

    logger.info(
        "[ArXiv API] Query='%s' | range=%s ~ %s | max=%d | window=%d months",
        keyword,
        start_dt.strftime(DATE_FMT),
        end_dt.strftime(DATE_FMT),
        max_results,
        window_months,
    )

    deduped: Dict[str, Dict[str, Any]] = {}
    failures: List[Dict[str, Any]] = []
    delay_s = arxiv_cfg.get("delay_ms", 3000) / 1000
    prior_window_live = False

    for window_start, window_end in _iter_windows(start_dt, end_dt, window_months):
        remaining = max_results - len(deduped)
        if remaining <= 0:
            break

        # Delay between windows to respect arXiv rate limits (only if the prior
        # window hit the API).
        if prior_window_live:
            time.sleep(delay_s)

        # Probe the first-page cache entry BEFORE fetching. _fetch_window saves
        # that entry after a successful request, so a check made afterwards
        # would find it either way and could not tell a cache hit from a live
        # fetch. The key mirrors the first page _fetch_window requests
        # (offset 0, min(chunk_size, remaining) results).
        first_page_cached = _cache_file(
            keyword, window_start, window_end, 0, min(chunk_size, remaining)
        ).exists()

        window_results = _fetch_window(
            keyword,
            query_string,
            window_start,
            window_end,
            remaining,
            arxiv_cfg,
            failures,
            page_size=chunk_size,
        )
        # Describes this window only (not sticky across windows). Later pages
        # that miss the cache are covered by _fetch_window's own per-request delay.
        prior_window_live = not first_page_cached

        for paper in window_results:
            if _published_in_range(paper["published"], start_dt, end_dt):
                deduped.setdefault(paper["arxiv_id"], paper)

    results = sorted(deduped.values(), key=lambda paper: paper["published"])[:max_results]
    _log_failure_summary(failures)
    logger.info(
        "[ArXiv API] Resolved %d unique papers in %s ~ %s",
        len(results),
        start_dt.strftime(DATE_FMT),
        end_dt.strftime(DATE_FMT),
    )
    return results
