from __future__ import annotations

import json
import os
from dataclasses import dataclass
from string import Template
from typing import Any, Dict, Iterable, List, Optional

from tqdm import tqdm

from .common import HFModelBundle, load_hf_model, generate_outputs_with_logits, generate_multiple_outputs_with_logits
from .parsing import parse_answer
from ..utils.io import read_json, write_json, write_jsonl_append, ensure_parent_dir


@dataclass
class InferenceIO:
    """Bundle of file paths consumed by a single inference run.

    Attributes are positional in this order: ``input_file``, ``output_file``,
    ``cache_file``, ``log_file``. The last two are optional and default to ``None``.
    """

    # JSON file holding the list of entries to process (read with ``read_json``).
    input_file: str
    # Destination for the final JSON list of results (written with ``write_json``).
    output_file: str
    # Optional JSONL file used to resume runs; ``None`` disables caching.
    cache_file: Optional[str] = None
    # Optional log path; NOTE(review): not read by the runners visible in this module.
    log_file: Optional[str] = None


def _format_accumulated_news(accumulated_news: Optional[List[Dict[str, Any]]]) -> str:
    if not accumulated_news:
        return "No news updates available."
    return "\n".join(f"- [{item.get('start_time')}] {item.get('best_news', '')}" for item in accumulated_news)


def run_verbalized(
    io: InferenceIO,
    model_name: str,
    templates: Dict[str, Template],
    cache_dir: Optional[str] = None,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.95,
) -> None:
    """Run single-sample verbalized inference and write the raw model outputs.

    Expected input JSON fields per entry: id, title, open_time, start_time, best_news.
    For every entry the model is prompted once per template present in ``templates``:

    - ``"without_news"``: filled with ``title`` and ``open_time``
    - ``"with_news"``: filled with ``title``, ``start_time`` and ``best_news``
    - ``"predict_trend"``: filled with ``title``, ``start_time`` and ``best_news``

    Only the first value returned by ``generate_outputs_with_logits`` is stored
    (in the ``raw_output_*`` fields); the remaining values are discarded and no
    parsing happens here. A missing template yields ``None`` for its field.

    Args:
        io: Paths for the run. When ``io.cache_file`` is set, each newly generated
            result is appended to it as JSONL, and entries whose id already
            appears there are reused as-is (model/templates are not re-checked).
            ``io.log_file`` is not used by this function.
        model_name: Model identifier passed to ``load_hf_model``.
        templates: ``string.Template`` prompts keyed by the names listed above.
        cache_dir: Forwarded to ``load_hf_model``.
        max_new_tokens: Generation length limit, forwarded per call.
        temperature: Sampling temperature, forwarded per call.
        top_p: Nucleus-sampling threshold, forwarded per call.
    """
    import warnings

    data = read_json(io.input_file)

    # Resume support: previously generated rows from the JSONL cache, keyed by id.
    cached_results: Dict[Any, Dict[str, Any]] = {}
    if io.cache_file and os.path.exists(io.cache_file):
        with open(io.cache_file, "r", encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    row = json.loads(line)
                except json.JSONDecodeError:
                    # An interrupted run can leave a truncated last line; skip it (that
                    # entry is simply regenerated) instead of aborting the whole resume.
                    warnings.warn(f"Skipping malformed line {lineno} in cache file {io.cache_file}")
                    continue
                # Rows without an id would all collide under the key None and could be
                # reused for any unrelated id-less entry, so they are not indexed.
                if isinstance(row, dict) and row.get("id") is not None:
                    cached_results[row["id"]] = row

    bundle: HFModelBundle = load_hf_model(model_name, cache_dir=cache_dir)

    # Templates do not vary per entry; a missing key disables that prompt.
    t_wo = templates.get("without_news")
    t_w = templates.get("with_news")
    t_trend = templates.get("predict_trend")

    def _generate(template: Optional[Template], **fields: Any) -> Any:
        """Fill ``template`` with ``fields`` and return the first generated value (or None)."""
        if not template:
            return None
        raw, *_ = generate_outputs_with_logits(
            bundle,
            template.safe_substitute(**fields),
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        return raw

    results: List[Dict[str, Any]] = []
    for entry in tqdm(data, desc=f"Generating answers with {model_name}", unit="question"):
        qid = entry.get("id")
        if qid in cached_results:
            results.append(cached_results[qid])
            continue

        title = entry.get("title", "")
        open_time = entry.get("open_time", "")
        start_time = entry.get("start_time", "")
        best_news = entry.get("best_news", "")

        raw_no = _generate(t_wo, title=title, open_time=open_time)
        raw_with = _generate(t_w, title=title, start_time=start_time, best_news=best_news)
        raw_trend = _generate(t_trend, title=title, start_time=start_time, best_news=best_news)

        # Raw outputs only; metric computation happens downstream.
        result = {
            "id": qid,
            "question": title,
            "open_time": open_time,
            "start_time": start_time,
            "best_news": best_news,
            "model": model_name,
            "raw_output_confidence_no_news": raw_no,
            "raw_output_confidence_with_news": raw_with,
            "raw_output_predicted_trend_direct": raw_trend,
        }
        results.append(result)

        if io.cache_file:
            write_jsonl_append(io.cache_file, [result])

    write_json(io.output_file, results, indent=2)


def run_sampling_with_logits(
    io: InferenceIO,
    model_name: str,
    templates: Dict[str, Template],
    n_samples: int,
    cache_dir: Optional[str] = None,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.95,
) -> None:
    """Run multi-sample inference for the no-news and with-news prompts.

    For every input entry, each template present in ``templates`` is filled and
    passed to ``generate_multiple_outputs_with_logits`` with ``n_samples``:

    - ``"without_news"``: filled with ``title`` and ``open_time``
    - ``"with_news"``: filled with ``title``, ``start_time`` and ``best_news``

    Every returned sample dict is annotated in place with ``parsed_answer`` =
    ``parse_answer(sample["generated_text"])``. What else each sample holds
    (presumably token-level probabilities/logits) is determined by that helper.
    A missing template yields an empty list. ``golden_trend`` is copied from the
    entry's ``trend`` field (``"Unknown"`` if absent).

    Args:
        io: Paths for the run. When ``io.cache_file`` is set, each newly generated
            result is appended to it as JSONL, and entries whose id already
            appears there are reused as-is (model/templates/n_samples are not
            re-checked). ``io.log_file`` is not used by this function.
        model_name: Model identifier passed to ``load_hf_model``.
        templates: ``string.Template`` prompts keyed by the names listed above.
        n_samples: Number of samples requested per prompt.
        cache_dir: Forwarded to ``load_hf_model``.
        max_new_tokens: Generation length limit, forwarded per call.
        temperature: Sampling temperature, forwarded per call.
        top_p: Nucleus-sampling threshold, forwarded per call.
    """
    import warnings

    data = read_json(io.input_file)

    # Resume support: previously generated rows from the JSONL cache, keyed by id.
    cached_results: Dict[Any, Dict[str, Any]] = {}
    if io.cache_file and os.path.exists(io.cache_file):
        with open(io.cache_file, "r", encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    row = json.loads(line)
                except json.JSONDecodeError:
                    # An interrupted run can leave a truncated last line; skip it (that
                    # entry is simply regenerated) instead of aborting the whole resume.
                    warnings.warn(f"Skipping malformed line {lineno} in cache file {io.cache_file}")
                    continue
                # Rows without an id would all collide under the key None and could be
                # reused for any unrelated id-less entry, so they are not indexed.
                if isinstance(row, dict) and row.get("id") is not None:
                    cached_results[row["id"]] = row

    bundle: HFModelBundle = load_hf_model(model_name, cache_dir=cache_dir)

    # Templates do not vary per entry; a missing key disables that prompt.
    t_wo = templates.get("without_news")
    t_w = templates.get("with_news")

    def _sample(template: Optional[Template], **fields: Any) -> List[Dict[str, Any]]:
        """Generate ``n_samples`` outputs for the filled template and attach parsed answers."""
        if not template:
            return []
        outputs = generate_multiple_outputs_with_logits(
            bundle,
            template.safe_substitute(**fields),
            n_samples=n_samples,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        for o in outputs:
            o["parsed_answer"] = parse_answer(o.get("generated_text", ""))
        return outputs

    results: List[Dict[str, Any]] = []
    for entry in tqdm(data, desc=f"Generating {n_samples} answers per prompt", unit="question"):
        qid = entry.get("id")
        if qid in cached_results:
            results.append(cached_results[qid])
            continue

        title = entry.get("title", "")
        open_time = entry.get("open_time", "")
        start_time = entry.get("start_time", "")
        best_news = entry.get("best_news", "")

        outputs_no_news = _sample(t_wo, title=title, open_time=open_time)
        outputs_with_news = _sample(t_w, title=title, start_time=start_time, best_news=best_news)

        result = {
            "id": qid,
            "question": title,
            "open_time": open_time,
            "start_time": start_time,
            "best_news": best_news,
            "model": model_name,
            "n_samples": n_samples,
            "outputs_no_news": outputs_no_news,
            "outputs_with_news": outputs_with_news,
            "golden_trend": entry.get("trend", "Unknown"),
        }
        results.append(result)
        if io.cache_file:
            write_jsonl_append(io.cache_file, [result])

    write_json(io.output_file, results, indent=2)


def run_verbalized_with_history(
    io: InferenceIO,
    model_name: str,
    templates: Dict[str, Template],
    cache_dir: Optional[str] = None,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.95,
) -> None:
    """Like ``run_verbalized``, but also feeds accumulated news history into prompts.

    The entry's ``accumulated_news`` list (if any) is rendered with
    ``_format_accumulated_news`` into ``previous_news``, which is substituted
    into the ``"with_news"`` and ``"predict_trend"`` templates (alongside
    ``title``, ``start_time`` and ``best_news``). The ``"without_news"``
    template only receives ``title`` and ``open_time``. The rendered
    ``previous_news`` text is also stored in each result.

    Only the first value returned by ``generate_outputs_with_logits`` is stored
    per prompt; a missing template yields ``None`` for its field.

    Args:
        io: Paths for the run. When ``io.cache_file`` is set, each newly generated
            result is appended to it as JSONL, and entries whose id already
            appears there are reused as-is (model/templates are not re-checked).
            ``io.log_file`` is not used by this function.
        model_name: Model identifier passed to ``load_hf_model``.
        templates: ``string.Template`` prompts keyed by the names listed above.
        cache_dir: Forwarded to ``load_hf_model``.
        max_new_tokens: Generation length limit, forwarded per call.
        temperature: Sampling temperature, forwarded per call.
        top_p: Nucleus-sampling threshold, forwarded per call.
    """
    import warnings

    data = read_json(io.input_file)

    # Resume support: previously generated rows from the JSONL cache, keyed by id.
    cached_results: Dict[Any, Dict[str, Any]] = {}
    if io.cache_file and os.path.exists(io.cache_file):
        with open(io.cache_file, "r", encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                if not line.strip():
                    continue
                try:
                    row = json.loads(line)
                except json.JSONDecodeError:
                    # An interrupted run can leave a truncated last line; skip it (that
                    # entry is simply regenerated) instead of aborting the whole resume.
                    warnings.warn(f"Skipping malformed line {lineno} in cache file {io.cache_file}")
                    continue
                # Rows without an id would all collide under the key None and could be
                # reused for any unrelated id-less entry, so they are not indexed.
                if isinstance(row, dict) and row.get("id") is not None:
                    cached_results[row["id"]] = row

    bundle: HFModelBundle = load_hf_model(model_name, cache_dir=cache_dir)

    # Templates do not vary per entry; a missing key disables that prompt.
    t_wo = templates.get("without_news")
    t_w = templates.get("with_news")
    t_trend = templates.get("predict_trend")

    def _generate(template: Optional[Template], **fields: Any) -> Any:
        """Fill ``template`` with ``fields`` and return the first generated value (or None)."""
        if not template:
            return None
        raw, *_ = generate_outputs_with_logits(
            bundle,
            template.safe_substitute(**fields),
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
        )
        return raw

    results: List[Dict[str, Any]] = []
    for entry in tqdm(data, desc=f"Generating answers with {model_name}", unit="question"):
        qid = entry.get("id")
        if qid in cached_results:
            results.append(cached_results[qid])
            continue

        title = entry.get("title", "")
        open_time = entry.get("open_time", "")
        start_time = entry.get("start_time", "")
        best_news = entry.get("best_news", "")
        previous_news = _format_accumulated_news(entry.get("accumulated_news"))

        raw_no = _generate(t_wo, title=title, open_time=open_time)
        raw_with = _generate(
            t_w,
            title=title,
            start_time=start_time,
            best_news=best_news,
            previous_news=previous_news,
        )
        raw_trend = _generate(
            t_trend,
            title=title,
            start_time=start_time,
            best_news=best_news,
            previous_news=previous_news,
        )

        result = {
            "id": qid,
            "question": title,
            "open_time": open_time,
            "start_time": start_time,
            "best_news": best_news,
            "previous_news": previous_news,
            "model": model_name,
            "raw_output_confidence_no_news": raw_no,
            "raw_output_confidence_with_news": raw_with,
            "raw_output_predicted_trend_direct": raw_trend,
        }
        results.append(result)
        if io.cache_file:
            write_jsonl_append(io.cache_file, [result])

    write_json(io.output_file, results, indent=2)
