from __future__ import annotations

import argparse

from ..data import (
    filter_questions,
    MetaculusCommentFetcher,
    attach_comments_to_questions,
    GoogleNewsFetcher,
    attach_news_to_entries,
    analyze_news,
    AnalyzeConfig,
    reformat_analyzed,
    add_accumulated_news,
    attach_history_from_binary,
)
from ..utils.config import load_config_from_env
from ..utils.io import read_json, write_json
from ..utils.logging_utils import get_logger


def main() -> None:
    cfg = load_config_from_env()
    logger = get_logger("data_cli", cfg.log_path)

    p = argparse.ArgumentParser(description="ICLR data pipeline CLI")
    sub = p.add_subparsers(dest="cmd", required=True)

    # filter
    sp = sub.add_parser("filter", help="Filter Metaculus questions JSON")
    sp.add_argument("input", help="Input JSON path")
    sp.add_argument("output", help="Output JSON path")
    sp.add_argument("--min-forecasters", type=int, default=50)
    sp.add_argument("--min-forecaster-count", type=int, default=25)

    # comments
    sp = sub.add_parser("comments", help="Fetch and attach comments to filtered questions")
    sp.add_argument("input", help="Filtered questions JSON")
    sp.add_argument("comments_out", help="Comments output JSON")
    sp.add_argument("processed_out", help="Processed questions with comments JSON")
    sp.add_argument("--cache", default=None, help="Comments cache JSON path")

    # news
    sp = sub.add_parser("news", help="Fetch Google news and attach to entries with comments")
    sp.add_argument("input", help="Filtered questions with comments JSON")
    sp.add_argument("output", help="Filtered questions with news JSON")
    sp.add_argument("--cache", default=None, help="Google API cache JSON path")
    sp.add_argument("--link-cache", default=None, help="News link summary cache JSON path")

    # analyze
    sp = sub.add_parser("analyze", help="Analyze and attach trend + best news")
    sp.add_argument("input", help="Input questions with news JSON")
    sp.add_argument("output", help="Output analyzed JSON")
    sp.add_argument("--cache", default=None, help="SBERT encoding cache JSON path")
    sp.add_argument("--model", default="all-MiniLM-L6-v2")
    sp.add_argument("--lookahead-days", type=int, default=3)
    sp.add_argument("--threshold", type=float, default=0.05)

    # reformat simple
    sp = sub.add_parser("reformat", help="Flatten analyzed dataset (per history item)")
    sp.add_argument("input", help="Analyzed JSON")
    sp.add_argument("output", help="Reformatted JSON")

    # accumulate
    sp = sub.add_parser("accumulate", help="Add accumulated_news timeline to each reformatted row")
    sp.add_argument("input", help="Reformatted JSON")
    sp.add_argument("output", help="Accumulated JSON")

    # reformat-history
    sp = sub.add_parser("reformat-history", help="Attach simplified daily history to accumulated rows")
    sp.add_argument("accumulated", help="Accumulated JSON")
    sp.add_argument("binary", help="Raw binary questions JSON")
    sp.add_argument("output", help="Output JSON with history attached")

    args = p.parse_args()

    if args.cmd == "filter":
        filter_questions(args.input, args.output, args.min_forecasters, args.min_forecaster_count)
        return

    if args.cmd == "comments":
        fetcher = MetaculusCommentFetcher(
            api_key=cfg.metaculus_api_key,
            cache_file=args.cache,
        )
        attach_comments_to_questions(args.input, args.comments_out, args.processed_out, fetcher)
        return

    if args.cmd == "news":
        fetcher = GoogleNewsFetcher(
            api_key=cfg.google_api_key,
            search_engine_id=cfg.google_search_engine_id,
            cache_file=args.cache,
            link_cache_file=args.link_cache,
        )
        entries = read_json(args.input)
        enriched = attach_news_to_entries(entries, fetcher)
        write_json(args.output, enriched, indent=2)
        return

    if args.cmd == "analyze":
        analyze_news(
            args.input,
            args.output,
            AnalyzeConfig(
                model_name=args.model,
                trend_lookahead_days=args.lookahead_days,
                trend_threshold=args.threshold,
                device=cfg.device,
                encoding_cache_file=args.cache,
            ),
        )
        return

    if args.cmd == "reformat":
        reformat_analyzed(args.input, args.output)
        return

    if args.cmd == "accumulate":
        add_accumulated_news(args.input, args.output)
        return

    if args.cmd == "reformat-history":
        attach_history_from_binary(args.accumulated, args.binary, args.output)
        return


# Script entry point. This module uses relative imports (``from ..data``),
# so it has to be executed as a module inside its package (``python -m ...``)
# rather than by file path.
if __name__ == "__main__":
    main()
