from __future__ import annotations

import glob
import json
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

import pandas as pd

from ..inference.parsing import aggregate_outputs, compute_predicted_trend
from .metrics import evaluate_predictions


# Column order of the per-record DataFrame. Passing it explicitly means a file
# with no records still yields a frame that has the columns evaluation reads.
_ROW_COLUMNS: Tuple[str, ...] = (
    "confidence_no_news",
    "confidence_with_news",
    "predicted_trend_computed",
    "golden_trend",
    "abstain_rate_no_news",
    "abstain_rate_with_news",
    "counts_no_news",
    "counts_with_news",
    "n_eff_no_news",
    "n_eff_with_news",
    "raw_data",
)


def _iter_jsonl_records(filename: str) -> Iterator[Any]:
    """Yield the decoded JSON value of every non-blank line of *filename*.

    Raises:
        json.JSONDecodeError: if a line is not valid JSON. The message is
            prefixed with ``filename:lineno``. The exception type is unchanged,
            so callers catching ``JSONDecodeError`` keep working.
    """
    with open(filename, "r", encoding="utf-8") as fh:
        for lineno, line in enumerate(fh, start=1):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as err:
                raise json.JSONDecodeError(
                    f"{filename}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
            yield record


def _build_row(
    data: Dict[str, Any], k_tail: int, kappa: float, threshold: float
) -> Dict[str, Any]:
    """Aggregate both output variants of one record into a flat result row."""
    # ``or []`` makes an explicit JSON ``null`` behave like a missing key, so
    # aggregate_outputs always receives a list.
    agg_no = aggregate_outputs(data.get("outputs_no_news") or [], k_tail=k_tail, kappa=kappa)
    agg_with = aggregate_outputs(data.get("outputs_with_news") or [], k_tail=k_tail, kappa=kappa)
    p_yes_no = agg_no.get("p_yes")
    p_yes_with = agg_with.get("p_yes")
    return {
        "confidence_no_news": p_yes_no,
        "confidence_with_news": p_yes_with,
        "predicted_trend_computed": compute_predicted_trend(
            p_yes_no, p_yes_with, threshold=threshold
        ),
        "golden_trend": data.get("golden_trend", "Unknown"),
        "abstain_rate_no_news": agg_no.get("abstain_rate"),
        "abstain_rate_with_news": agg_with.get("abstain_rate"),
        "counts_no_news": agg_no.get("counts"),
        "counts_with_news": agg_with.get("counts"),
        "n_eff_no_news": agg_no.get("n_eff"),
        "n_eff_with_news": agg_with.get("n_eff"),
        "raw_data": data,
    }


def process_jsonl_files(
    file_pattern: str,
    k_tail: int = 3,
    kappa: float = 0.1,
    threshold: float = 0.05,
    parsed_out_path_builder: Optional[Callable[[str], Optional[str]]] = None,
    eval_out_path_builder: Optional[Callable[[str], Optional[str]]] = None,
) -> List[Tuple[pd.DataFrame, Dict[str, Any]]]:
    """Process JSONL files matching pattern, aggregate outputs, compute trend, and evaluate.

    Each JSON line is expected to contain:
      - golden_trend: 'Up'|'Down'|'Still'
      - outputs_no_news: list of dicts with keys 'generated_text', 'token_probs'
      - outputs_with_news: same format

    Args:
        file_pattern: ``glob.glob`` pattern selecting the JSONL files.
            Matches are processed in sorted filename order, so the result
            list is reproducible across runs and filesystems.
        k_tail: Forwarded to ``aggregate_outputs``.
        kappa: Forwarded to ``aggregate_outputs``.
        threshold: Forwarded to ``compute_predicted_trend``.
        parsed_out_path_builder: Optional callable mapping an input filename
            to a path. If it returns a truthy path, the parsed rows are
            written there as an indented JSON array of records.
        eval_out_path_builder: Optional callable mapping an input filename
            to a path. The path (or ``None``) is passed to
            ``evaluate_predictions``.

    Returns:
        One ``(df, eval_results)`` tuple per matched file. ``df`` always has
        the columns in ``_ROW_COLUMNS``, even when the file has no records.

    Raises:
        json.JSONDecodeError: if a line is not valid JSON. The message
            includes the filename and line number.
        OSError: if an input file cannot be read or an output file cannot
            be written.
    """
    results: List[Tuple[pd.DataFrame, Dict[str, Any]]] = []

    for filename in sorted(glob.glob(file_pattern)):
        rows = [
            _build_row(data, k_tail, kappa, threshold)
            for data in _iter_jsonl_records(filename)
        ]
        # Explicit columns keep the expected schema when ``rows`` is empty.
        df = pd.DataFrame(rows, columns=list(_ROW_COLUMNS))

        parsed_path = parsed_out_path_builder(filename) if parsed_out_path_builder else None
        if parsed_path:
            # Round-trip through pandas' JSON encoder so numpy scalars and
            # NaN become JSON-native values, then re-dump with indentation.
            with open(parsed_path, "w", encoding="utf-8") as f:
                json.dump(json.loads(df.to_json(orient="records")), f, indent=2)

        eval_path = eval_out_path_builder(filename) if eval_out_path_builder else None
        eval_results = evaluate_predictions(df, "predicted_trend_computed", "golden_trend", eval_path)
        results.append((df, eval_results))
    return results
