# runners/fewshot_runner.py

import os
import json
import time
import logging
import re
from tqdm import tqdm
import random
from typing import Dict, Optional

from generators.fewshot_endowment_generator import FewShotEndowmentGenerator
from agents.factory import AgentFactory
from modules.token_tracker import MultiAxisTokenTracker

# Module-level logger used by FewShotRunner below.
logger = logging.getLogger(__name__)


class FewShotRunner:
    """
    Run a few-shot baseline experiment.

    - Builds a global prompt with train+valid questions and human answer
      distributions (via ``FewShotEndowmentGenerator.build_prompt``).
    - For each test question, queries the LLM to predict a distribution.
    - Collects predictions in an AggregateResponses-compatible dict
      ``{qid: {answer: predicted proportion}}``.

    A prediction of ``{}`` marks a question whose API call failed after all
    retries, or whose output could not be parsed as a JSON object.
    """

    def __init__(
        self,
        survey,
        aggregate_dict: Dict[str, Dict[str, float]],
        agent_type: str = "openai",
        model_name: str = "gpt-4o-mini",
        temperature: float = 0.0,
        max_tokens: int = 512,
        output_path: str = "outputs/fewshot_predictions.json",
        # Quoted so the annotation is a forward reference, not evaluated at class definition.
        token_tracker: Optional["MultiAxisTokenTracker"] = None,
    ):
        """
        Args:
            survey: Survey object providing ``get_questions()`` and
                ``get_questions_by_split(split)``, which return question dicts
                with at least an ``"id"`` key.
            aggregate_dict: Human answer distributions used to build the
                few-shot prompt.
            agent_type: Backend name passed to ``AgentFactory.create``.
            model_name: Model identifier passed to the agent.
            temperature: Sampling temperature passed to the agent.
            max_tokens: Maximum completion tokens passed to the agent.
            output_path: Where ``run`` checkpoints its predictions JSON. The raw
                response log is written next to it as ``<stem>_raw.jsonl``.
            token_tracker: Optional tracker that receives per-call token usage.
        """
        self.survey = survey
        self.aggregate_dict = aggregate_dict
        self.agent_type = agent_type
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.output_path = output_path
        self.token_tracker = token_tracker

        # Build the training prompt once
        self.generator = FewShotEndowmentGenerator(survey, aggregate_dict)
        self.base_prompt = self.generator.build_prompt()

    # ------------------------------------------------------------------ helpers

    def _find_question(self, qid: str) -> dict:
        """Return the survey question whose ``"id"`` equals ``qid``; raise ValueError if absent."""
        q = next((q for q in self.survey.get_questions() if q["id"] == qid), None)
        if q is None:
            raise ValueError(f"Question ID {qid} not found in survey.")
        return q

    def _create_agent(self, eid: str):
        """Create an agent whose endowment is the shared few-shot prompt."""
        return AgentFactory.create(
            agent_type=self.agent_type,
            eid=eid,
            endowment=self.base_prompt,
            formality="",  # could add an instruction like "always output JSON" if needed
            model_name=self.model_name,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )

    def _answer_with_retries(
        self, agent, query: str, qid, max_retries: int, base_backoff: float
    ) -> Optional[dict]:
        """
        Send ``query`` to ``agent`` and retry failures with exponential backoff.

        The wait before retry ``k`` (1-based) is ``base_backoff * 2**(k-1)``
        plus up to 30% random jitter. No sleep happens after the final failed
        attempt.

        Returns:
            The agent's raw response (``{}`` if the agent returned None), or
            ``None`` if every attempt raised (or ``max_retries <= 0``).
        """
        for attempt in range(1, max_retries + 1):
            try:
                response = agent.answer(query, return_usage=True)
            except Exception as e:  # any agent error is retried; we give up after max_retries
                if attempt == max_retries:
                    logger.warning(
                        f"[FewShotRunner] QID={qid} failed (attempt {attempt}/{max_retries}): {e}."
                    )
                    break
                wait_time = base_backoff * (2 ** (attempt - 1))
                jitter = random.uniform(0, 0.3 * wait_time)  # up to +30% jitter
                total_wait = wait_time + jitter
                logger.warning(
                    f"[FewShotRunner] QID={qid} failed (attempt {attempt}/{max_retries}): {e}. "
                    f"Retrying in {total_wait:.1f}s..."
                )
                time.sleep(total_wait)
            else:
                return response if response is not None else {}

        logger.error(f"[FewShotRunner] Giving up on QID={qid} after {max_retries} retries.")
        return None

    def _log_usage(self, raw_response: dict, agent_id: str) -> None:
        """Forward the response's ``"usage"`` entry, if present, to the token tracker."""
        # `is None` rather than truthiness: a tracker object could define __len__/__bool__.
        if self.token_tracker is None or "usage" not in raw_response:
            return
        usage = raw_response["usage"] or {}  # tolerate an explicit `"usage": None`
        self.token_tracker.log(
            agent_id=agent_id,
            input_tokens=usage.get("input_tokens", 0),
            output_tokens=usage.get("output_tokens", 0),
            cached_input_tokens=usage.get("cached_input_tokens", 0),
            model_name=self.model_name,
            module_name="FewShotRunner",
        )

    def _save_predictions(self, predictions: dict) -> None:
        """Write ``predictions`` to ``self.output_path`` via a temp file + atomic rename."""
        tmp_path = self.output_path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(predictions, f, indent=2, ensure_ascii=False)
        # os.replace is atomic, so a crash mid-write never leaves a truncated checkpoint.
        os.replace(tmp_path, self.output_path)

    def _parse_json(self, raw_text: str) -> dict:
        """
        Parse the model's output into a JSON object (dict).

        Strips surrounding whitespace and a leading ```json / ``` fence with its
        closing fence. If that text is not valid JSON, falls back to the span
        from the first ``{`` to the last ``}``, because models sometimes wrap
        the object in prose.

        Returns:
            The parsed dict, or ``{}`` (with a warning logged) if no JSON object
            could be parsed. Valid JSON that is not an object, such as a list,
            also yields ``{}``.
        """
        # Strip markdown code fences if present
        cleaned = (raw_text or "").strip()
        if cleaned.startswith("```"):
            # Remove triple backticks and optional "json" tag
            cleaned = re.sub(r"^```(?:json)?", "", cleaned, flags=re.IGNORECASE).strip()
            cleaned = re.sub(r"```$", "", cleaned).strip()

        try:
            parsed = json.loads(cleaned)
        except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
            start, end = cleaned.find("{"), cleaned.rfind("}")
            if not 0 <= start < end:
                logger.warning(f"[FewShotRunner] JSON parse failed: {exc}; raw={raw_text}")
                return {}
            try:
                parsed = json.loads(cleaned[start:end + 1])
            except ValueError as inner_exc:
                logger.warning(f"[FewShotRunner] JSON parse failed: {inner_exc}; raw={raw_text}")
                return {}

        if not isinstance(parsed, dict):
            logger.warning(
                f"[FewShotRunner] JSON parse failed: expected an object, got "
                f"{type(parsed).__name__}; raw={raw_text}"
            )
            return {}
        return parsed

    # ------------------------------------------------------------- entry points

    def run(self, max_retries: int = 5, base_backoff: float = 5.0) -> Dict[str, Dict[str, float]]:
        """
        Run few-shot prediction across all test questions.

        Predictions are checkpointed to ``self.output_path`` after every
        question, so an interrupted run resumes when ``run`` is called again.
        On resume, questions with a non-empty saved prediction are skipped.
        Questions saved as ``{}`` (earlier failures) are retried. Raw responses
        are appended to ``<output stem>_raw.jsonl``.

        Args:
            max_retries: Attempts per question before giving up on it.
            base_backoff: Initial backoff in seconds. It doubles per attempt,
                with jitter added.

        Returns:
            dict: {qid: {answer: predicted proportion}} LLM predictions.
        """
        test_questions = self.survey.get_questions_by_split("test")
        n_questions = len(test_questions)
        logger.info(f"[FewShotRunner] Running few-shot prediction for {n_questions} test questions...")

        # Resume from partial progress
        predictions: Dict[str, Dict[str, float]] = {}
        resuming = os.path.exists(self.output_path)
        if resuming:
            with open(self.output_path, "r", encoding="utf-8") as f:
                predictions = json.load(f)
            # Empty predictions record earlier failures; only non-empty ones count as done.
            done = {qid for qid, pred in predictions.items() if pred}
            logger.info(
                f"[FewShotRunner] Resuming from {len(done)} completed predictions "
                f"({len(predictions) - len(done)} earlier failures will be retried)."
            )
        else:
            done = set()

        # Create the global agent (trained on train+valid distributions)
        agent = self._create_agent("fewshot_global")

        output_dir = os.path.dirname(self.output_path)
        if output_dir:  # dirname is "" for a bare filename, and os.makedirs("") raises
            os.makedirs(output_dir, exist_ok=True)
        raw_log_path = os.path.splitext(self.output_path)[0] + "_raw.jsonl"

        # Append when resuming so raw entries from the previous run are kept.
        with open(raw_log_path, "a" if resuming else "w", encoding="utf-8") as raw_log:
            for i, q in enumerate(tqdm(test_questions, desc="Predicting", unit="q")):
                qid = q["id"]
                if qid in done:
                    continue

                qtext = q.get("question", "")
                logger.info(f"[FewShotRunner] ({i+1}/{n_questions}) Predicting QID={qid}")

                # Use generator's strict query builder
                query = self.generator.build_query(q, strict=True)
                raw_response = self._answer_with_retries(agent, query, qid, max_retries, base_backoff)

                if raw_response is None:
                    predictions[qid] = {}
                else:
                    self._log_usage(raw_response, agent_id="fewshot_global")
                    pred_dist = self._parse_json(raw_response.get("answer_text", ""))
                    predictions[qid] = pred_dist

                    entry = {
                        "qid": qid,
                        "question": qtext,
                        "raw_response": raw_response,
                        "parsed_prediction": pred_dist,
                        "timestamp": time.time(),
                    }
                    json.dump(entry, raw_log, ensure_ascii=False)
                    raw_log.write("\n")
                    raw_log.flush()  # keep the raw log current in case of a crash

                # Checkpoint after every question so finished work survives a crash.
                self._save_predictions(predictions)

        # Save structured predictions (also covers the case where nothing was left to do)
        self._save_predictions(predictions)
        logger.info(f"[FewShotRunner] Saved predictions to {self.output_path}")

        return predictions

    def run_question(self, qid: str, max_retries: int = 5, base_backoff: float = 5.0) -> dict:
        """
        Run few-shot prediction for a single question ID.

        Args:
            qid (str): Question ID from the survey
            max_retries (int): Number of retry attempts
            base_backoff (float): Base wait time for exponential backoff

        Returns:
            dict: Parsed prediction distribution for this question. The result
            is ``{}`` if the call failed or the output could not be parsed.

        Raises:
            ValueError: If ``qid`` is not in the survey.
        """
        q = self._find_question(qid)
        logger.info(f"[FewShotRunner] Predicting single question QID={qid}")

        agent = self._create_agent("fewshot_global")
        query = self.generator.build_query(q, strict=True)

        raw_response = self._answer_with_retries(agent, query, qid, max_retries, base_backoff)
        if raw_response is None:
            return {}

        self._log_usage(raw_response, agent_id="fewshot_global")
        pred_dist = self._parse_json(raw_response.get("answer_text", ""))

        logger.info(f"[FewShotRunner] Prediction for {qid}: {pred_dist}")
        return pred_dist

    def run_question_with_reasoning(self, qid: str, max_retries: int = 5, base_backoff: float = 5.0) -> dict:
        """
        Run few-shot prediction for a single question ID, asking the model to
        output both its reasoning and the JSON distribution.

        Args:
            qid (str): Question ID from the survey
            max_retries (int): Number of retry attempts
            base_backoff (float): Base wait time for exponential backoff

        Returns:
            dict: {
                "qid": str,
                "reasoning": str,
                "prediction": dict (parsed JSON distribution, ``{}`` on failure)
            }

        Raises:
            ValueError: If ``qid`` is not in the survey.
        """
        q = self._find_question(qid)
        logger.info(f"[FewShotRunner] Predicting (with reasoning) QID={qid}")

        agent = self._create_agent("fewshot_global_debug")

        # Extend the strict query to explicitly request reasoning
        base_query = self.generator.build_query(q, strict=True)
        query = (
            base_query
            + "\n\nAlso explain briefly *why* you assigned these probabilities. "
              "Respond in the following format:\n"
              "Reasoning: <your explanation>\n"
              "Distribution: <JSON object>"
        )

        raw_response = self._answer_with_retries(agent, query, qid, max_retries, base_backoff)
        if raw_response is None:
            return {"qid": qid, "reasoning": "", "prediction": {}}

        # Track usage here too, consistent with run() and run_question().
        self._log_usage(raw_response, agent_id="fewshot_global_debug")

        # Separate reasoning and distribution
        answer_text = (raw_response.get("answer_text") or "").strip()
        reasoning = ""
        if "Distribution:" in answer_text:
            before, _, after = answer_text.partition("Distribution:")
            reasoning = before.replace("Reasoning:", "").strip()
            pred_dist = self._parse_json(after)
        else:
            logger.warning(f"[FewShotRunner] Could not split reasoning/distribution cleanly. Raw: {answer_text}")
            # _parse_json falls back to the first {...} span, so JSON embedded in prose is still found.
            pred_dist = self._parse_json(answer_text)

        return {"qid": qid, "reasoning": reasoning, "prediction": pred_dist}