from __future__ import annotations

import os
import json
import logging
from typing import Any, Dict, List, Optional

import requests

from ..utils.config import PipelineConfig
from ..utils.dates import format_date_iso
from ..utils.io import read_json, write_json
from ..utils.logging_utils import get_logger


# Endpoint queried by MetaculusCommentFetcher when no `api_url` is supplied.
DEFAULT_COMMENTS_API_URL = "https://www.metaculus.com/api/comments/"


class MetaculusCommentFetcher:
    """Fetch comments for Metaculus posts with a simple JSON file cache.

    Provide `api_key` via PipelineConfig.metaculus_api_key or direct constructor param.
    Cache and output paths must be provided by the caller.

    The cache maps ``str(post_id)`` to a list of ``{"created_at", "text"}``
    records. It is loaded once at construction and rewritten after every
    successful fetch. When `cache_file` is ``None`` the cache lives in memory
    only.
    """

    def __init__(
        self,
        api_key: Optional[str],
        api_url: str = DEFAULT_COMMENTS_API_URL,
        cache_file: Optional[str] = None,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """Store the connection settings and load any existing cache file.

        Args:
            api_key: Token sent as ``Authorization: Token <key>``; when falsy,
                requests are sent without an Authorization header.
            api_url: Comments endpoint to query.
            cache_file: Path of the JSON cache; ``None`` disables persistence.
            logger: Logger to use; defaults to one named after the class.
        """
        self.api_key = api_key
        self.api_url = api_url
        self.cache_file = cache_file
        self.logger = logger or get_logger(self.__class__.__name__)
        self.cache: Dict[str, Any] = self._load_cache()

    def _load_cache(self) -> Dict[str, Any]:
        """Return the cache stored in `cache_file`, or an empty dict.

        An empty cache is returned when no cache file is configured, the file
        does not exist, it cannot be decoded, or its top-level JSON value is
        not an object.
        """
        if not self.cache_file or not os.path.exists(self.cache_file):
            return {}
        try:
            with open(self.cache_file, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            self.logger.warning("Cache file corrupt; starting empty cache")
            return {}
        # Valid JSON is not necessarily a mapping (e.g. `[]` or `null`); the
        # rest of the class indexes the cache by string key, so reject it here.
        if not isinstance(loaded, dict):
            self.logger.warning("Cache file is not a JSON object; starting empty cache")
            return {}
        return loaded

    def _save_cache(self) -> None:
        """Persist the in-memory cache to `cache_file` (no-op without one).

        The data is written to a sibling ``.tmp`` file and then moved over the
        real cache with `os.replace`, so an interrupted write cannot leave a
        truncated cache file behind.
        """
        if not self.cache_file:
            return
        os.makedirs(os.path.dirname(self.cache_file) or ".", exist_ok=True)
        tmp_path = f"{self.cache_file}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(self.cache, f, indent=2)
            os.replace(tmp_path, self.cache_file)
        except BaseException:
            # Do not leave a half-written temp file next to the cache.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    @staticmethod
    def _cache_key(post_id: int) -> str:
        """Return the cache key for `post_id` (JSON object keys are strings)."""
        return str(post_id)

    @staticmethod
    def _to_record(comment: Dict[str, Any]) -> Dict[str, Any]:
        """Reduce a comment payload to its ``created_at`` / ``text`` fields."""
        return {
            "created_at": format_date_iso(comment.get("created_at")),
            "text": comment.get("text", ""),
        }

    def fetch_comments(self, post_id: int) -> List[Dict[str, Any]]:
        """Return the comments of `post_id`, from the cache when available.

        On a cache miss the comments endpoint is queried once (``limit=1000``,
        ``is_private=false``); the reduced records are cached, the cache is
        saved, and the records are returned.

        Returns:
            A list of ``{"created_at", "text"}`` dicts. An empty list is
            returned, and nothing is cached, when the request fails, the
            status is not 200, or the body is not usable JSON.
        """
        key = self._cache_key(post_id)
        if key in self.cache:
            self.logger.info("Using cached comments for post %s", post_id)
            return [self._to_record(c) for c in self.cache[key]]

        params = {"post": post_id, "limit": 1000, "is_private": "false"}
        headers = {"Authorization": f"Token {self.api_key}"} if self.api_key else {}
        self.logger.info("Fetching comments for post %s", post_id)
        try:
            resp = requests.get(self.api_url, headers=headers, params=params, timeout=30)
        except requests.RequestException as exc:
            # Timeouts / connection errors get the same best-effort treatment
            # as a non-200 response instead of aborting the caller's loop.
            self.logger.error("Metaculus API request failed for post %s: %s", post_id, exc)
            return []
        if resp.status_code != 200:
            self.logger.error("Metaculus API request failed (%s): %s", resp.status_code, resp.text)
            return []

        try:
            payload = resp.json()
        except ValueError:
            self.logger.error("Metaculus API returned invalid JSON for post %s", post_id)
            return []

        if isinstance(payload, dict):
            # `or []` also covers an explicit `"results": null`.
            results = payload.get("results") or []
            # NOTE(review): assumes a truthy `next` field marks further pages
            # (DRF-style pagination) -- confirm against the Metaculus API.
            if payload.get("next"):
                self.logger.warning(
                    "More comments exist for post %s than the first page; only %s were fetched",
                    post_id,
                    len(results),
                )
        elif not payload:
            results = []
        else:
            self.logger.error("Unexpected Metaculus API payload for post %s", post_id)
            return []

        extracted = [self._to_record(c) for c in results]
        self.cache[key] = extracted
        self._save_cache()
        return extracted


def attach_comments_to_questions(
    input_questions_file: str,
    comments_output_file: str,
    processed_output_file: str,
    fetcher: MetaculusCommentFetcher,
) -> None:
    """Attach fetched comments to each question and write both result files.

    Questions are read from `input_questions_file`. A question is kept only
    when it has an ``id`` and `fetcher` returns at least one comment for it.
    Kept questions get a top-level ``comments`` list, and for every comment the
    first ``history`` item whose ``start_time`` equals the comment's
    ``created_at`` receives that comment's text as ``comment_text``.

    - comments_output_file: one ``{"post_id", "comments"}`` record per kept question
    - processed_output_file: the kept questions with comments attached
    """
    questions = read_json(input_questions_file)
    comment_records: List[Dict[str, Any]] = []
    enriched_questions: List[Dict[str, Any]] = []

    for question in questions:
        post_id = question.get("id")
        if post_id is None:
            # Nothing to look up without a post id.
            continue

        fetched = fetcher.fetch_comments(int(post_id))
        if not fetched:
            continue

        # Shallow copy with the comments attached at the top level.
        enriched = {**question, "comments": fetched}
        history = enriched.get("history", []) or []

        for item in fetched:
            stamp = item.get("created_at")
            # Only the first history item with a matching start_time is tagged.
            target = next(
                (step for step in history if step.get("start_time") == stamp),
                None,
            )
            if target is not None:
                target["comment_text"] = item.get("text", "")

        comment_records.append({"post_id": post_id, "comments": fetched})
        enriched_questions.append(enriched)

    write_json(comments_output_file, comment_records, indent=2)
    write_json(processed_output_file, enriched_questions, indent=2)
