from __future__ import annotations

import argparse
from string import Template

from ..inference.runners import InferenceIO, run_verbalized, run_sampling_with_logits, run_verbalized_with_history
from ..utils.config import load_config_from_env


# Binary Yes/No prompt in which the model is told it has no access to news updates.
# Placeholders: ${title}, ${open_time}.
# Passed as templates["without_news"] by the `sampling` subcommand in main().
# The text ends with an opening "<think>" tag, presumably so generation starts
# inside the reasoning section — confirm against the runner.
# NOTE(review): uses ${open_time}, whereas the with-news templates use
# ${start_time}; confirm the input records provide both keys as intended.
WITHOUT_NEWS_ANSWER = Template(
    """
<question>
${title}
</question>

<metadata>
- Today is ${open_time}.
- You do not have access to any external news updates.
</metadata>

<task>
You are an AI model predicting the likelihood of future events.
Your task is to answer if the following event will occur.
Return the answer in this format after thinking:

{ "answer": "Yes" / "No" }
</task>

<think>
"""
)

# Binary Yes/No prompt that includes a single news update as context.
# Placeholders: ${title}, ${start_time}, ${best_news}.
# Passed as templates["with_news"] by the `sampling` subcommand in main().
# NOTE(review): ${best_news} is inserted verbatim between double quotes;
# string.Template does no escaping, so quotes inside the news text are not escaped.
WITH_NEWS_ANSWER = Template(
    """
<question>
${title}
</question>

<metadata>
- Today is ${start_time}.
- A recent news update: "${best_news}"
</metadata>

<task>
You are an AI model predicting the likelihood of future events, now incorporating recent news.
Your task is to answer if the following event will occur, given the recent news update.
Return the answer in this format after thinking:

{ "answer": "Yes" / "No" }
</task>

<think>
"""
)

# Verbalized-confidence prompt (1-10 scale) without any news context.
# Placeholders: ${title}, ${open_time}.
# Passed as templates["without_news"] by both the `verbalized` and the
# `verbalized-history` subcommands in main().
# NOTE(review): uses ${open_time} rather than ${start_time}; confirm intended.
WITHOUT_NEWS_CONF = Template(
    """
<question>
${title}
</question>

<metadata>
- Today is ${open_time}.
- You do not have access to any external news updates.
</metadata>

<task>
Estimate the probability of the event on a 1-10 scale. Return:
{ "confidence": X }
</task>

<think>
"""
)

# Verbalized-confidence prompt (1-10 scale) given a single news update.
# Placeholders: ${title}, ${start_time}, ${best_news}.
# Passed as templates["with_news"] by the `verbalized` subcommand in main().
WITH_NEWS_CONF = Template(
    """
<question>
${title}
</question>

<metadata>
- Today is ${start_time}.
- A recent news update: "${best_news}"
</metadata>

<task>
Estimate the probability of the event on a 1-10 scale. Return:
{ "confidence": X }
</task>

<think>
"""
)

PREDICT_TREND = Template(
    """
<question>
${title}
</question>

<metadata>
- Today is ${start_time}.
- A recent news update: "${best_news}"
</metadata>

<task>
Determine whether the confidence should increase, decrease, or remain the same after the news.
Return:
{ "trend": "Up" / "Down" / "Still" }
</task>

<think>
"""
)

WITH_HISTORY_CONF = Template(
    """
<question>
${title}
</question>

<metadata>
- Today is ${start_time}.
- Previous news timeline:\n${previous_news}
- A recent news update: "${best_news}"
</metadata>

<task>
Estimate the probability of the event on a 1-10 scale. Return:
{ "confidence": X }
</task>

<think>
"""
)

WITH_HISTORY_TREND = Template(
    """
<question>
${title}
</question>

<metadata>
- Today is ${start_time}.
- Previous news timeline:\n${previous_news}
- A recent news update: "${best_news}"
</metadata>

<task>
Determine whether the confidence should increase, decrease, or remain the same after the news.
Return:
{ "trend": "Up" / "Down" / "Still" }
</task>

<think>
"""
)


def main() -> None:
    cfg = load_config_from_env()

    p = argparse.ArgumentParser(description="ICLR inference CLI")
    sub = p.add_subparsers(dest="cmd", required=True)

    # verbalized (single sample, confidence/trend verbalized)
    sp = sub.add_parser("verbalized", help="Run single-sample verbalized inference")
    sp.add_argument("model", help="HF model name")
    sp.add_argument("input", help="Input JSON path")
    sp.add_argument("output", help="Output JSON path")
    sp.add_argument("--cache", default=None, help="JSONL cache path")

    # sampling-with-logits
    sp = sub.add_parser("sampling", help="Run multi-sample inference and store logits")
    sp.add_argument("model", help="HF model name")
    sp.add_argument("input", help="Input JSON path")
    sp.add_argument("output", help="Output JSON path")
    sp.add_argument("--cache", default=None, help="JSONL cache path")
    sp.add_argument("--n-samples", type=int, default=10)

    # verbalized with history
    sp = sub.add_parser("verbalized-history", help="Run single-sample verbalized with accumulated history context")
    sp.add_argument("model", help="HF model name")
    sp.add_argument("input", help="Input JSON path")
    sp.add_argument("output", help="Output JSON path")
    sp.add_argument("--cache", default=None, help="JSONL cache path")

    args = p.parse_args()

    if args.cmd == "verbalized":
        io = InferenceIO(input_file=args.input, output_file=args.output, cache_file=args.cache)
        templates = {
            "without_news": WITHOUT_NEWS_CONF,
            "with_news": WITH_NEWS_CONF,
            "predict_trend": PREDICT_TREND,
        }
        run_verbalized(io, args.model, templates, cache_dir=cfg.cache_path)
        return

    if args.cmd == "sampling":
        io = InferenceIO(input_file=args.input, output_file=args.output, cache_file=args.cache)
        templates = {
            "without_news": WITHOUT_NEWS_ANSWER,
            "with_news": WITH_NEWS_ANSWER,
        }
        run_sampling_with_logits(io, args.model, templates, args.n_samples, cache_dir=cfg.cache_path)
        return

    if args.cmd == "verbalized-history":
        io = InferenceIO(input_file=args.input, output_file=args.output, cache_file=args.cache)
        templates = {
            "without_news": WITHOUT_NEWS_CONF,
            "with_news": WITH_HISTORY_CONF,
            "predict_trend": WITH_HISTORY_TREND,
        }
        run_verbalized_with_history(io, args.model, templates, cache_dir=cfg.cache_path)
        return


if __name__ == "__main__":
    main()
