import os
import csv
import json
import time
import random
import logging
# import traceback # For debugging
from collections import defaultdict
from typing import List, Dict, Optional
import pandas as pd
from agents.factory import AgentFactory
from modules.survey_converter import Survey
from modules.endowment_manager import Endowments
from modules.response_cleaner import ResponseCleaner
from modules.token_tracker import MultiAxisTokenTracker

logger = logging.getLogger(__name__)

class SurveyConductor:
    """
    Conducts a survey across a set of endowments using LLM agents and stores responses.

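    Attributes:
        survey (Survey): The survey object containing questions.
        endowments (Endowments): Endowment pool for creating agents.
        output_path (str): Path to save the response dictionary (.json).
        agent_type (str): Backend agent type identifier (e.g., 'openai', 'claude').
        model_name (str): LLM model to use.
        formality (str): Additional instruction appended to each agent prompt.
        verbose (bool): Whether messages passed to `_log` are forwarded to the module logger.
        clean (bool): Whether answers are passed through `cleaner` before being validated.
        cleaner: Response cleaner; defaults to `ResponseCleaner(survey=survey)` when none is given.
        max_retries (int): Maximum number of attempts per question.
        retry_delay (float): Base delay in seconds for the exponential backoff between attempts.
        max_delay_cap (float): Cap in seconds on the backoff delay, applied before jitter.
        token_tracker (Optional[MultiAxisTokenTracker]): Optional tracker that receives token usage.
        responses (List[Dict]): Stores the most recent run's response data.
        failed_pairs (List[Tuple]): (endowment, question) pairs that exhausted their retries
            during the first pass of the most recent run; set by `run`.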
    """

    def __init__(self,
                 survey: Survey,
                 endowments: Endowments,
                 output_path: str = "logs/responses.json",
                 agent_type: str = "openai",
                 model_name: str = "gpt-4o-mini",
                 formality: str = "",
                 cleaner=None,
                 clean: bool = False,
                 verbose: bool = True,
                 max_retries: int = 26,
                 retry_delay: float = 0.5,
                 max_delay_cap: float= 45.0,
                 token_tracker: Optional[MultiAxisTokenTracker] = None):
        self.survey = survey
        self.endowments = endowments
        self.output_path = output_path
        self.agent_type = agent_type
        self.model_name = model_name
        self.formality = formality
        self.verbose = verbose
        self.clean = clean
        self.cleaner = cleaner or ResponseCleaner(survey = survey)
        self.responses = []
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.max_delay_cap = max_delay_cap
        self.token_tracker = token_tracker

    def _log(self, msg: str, level: str = "info"): 
        if self.verbose:
            log_fn = getattr(logger, level, logger.info)
            log_fn(msg)

    def _is_error_response(self, answer: str) -> bool:
        return isinstance(answer, str) and answer.strip().startswith("[ERROR]")
    
    def _get_valid_response(self, agent, qprompt: str, qid: str) -> Dict:
        """
        Attempt to get a non-error response from the agent up to max_retries.
        Returns a raw_response dict (with 'answer_text' key).
        """

        for attempt in range(1, self.max_retries + 1):
            try:
                raw_response = agent.answer(qprompt, return_usage = True)
                answer = raw_response["answer_text"]

                if self.clean and self.cleaner:
                    cleaned = self.cleaner.clean(answer, qid=qid)
                    raw_response["cleaned"] = cleaned
                    answer = cleaned

                if not self._is_error_response(answer):
                    if self.token_tracker and "usage" in raw_response:
                        self.token_tracker.log(
                            agent_id=agent.eid,
                            input_tokens=raw_response["usage"].get("input_tokens", 0),
                            output_tokens=raw_response["usage"].get("output_tokens", 0),
                            cached_input_tokens=raw_response["usage"].get("cached_input_tokens", 0),
                            model_name=self.model_name,
                            module_name="SurveyConductor"
                        )
                    return raw_response

                self._log(f"[Retry {attempt}] Received error response: {answer}", level="warning")
            except Exception as err:
                raw_response = {"error": str(err)}
                self._log(f"[Retry {attempt}] Exception: {err}", level="warning")
                answer = f"[ERROR] {err}"
                raw_response["answer_text"] = answer

            # Strong exponential backoff with jitter and max cap
            base = self.retry_delay
            delay = min(self.max_delay_cap, base * (2 ** (attempt - 1)))   # exponential backoff capped at 60s
            jitter = random.uniform(0.8, 1.3)                 # multiplicative jitter
            delay *= jitter

            time.sleep(delay)

        self._log(f"[FAIL] Max retries reached: eid={agent.eid}, qid={qid}", level="warning")
        raw_response["__hard_fail__"] = True
        # Final response (either valid or after exhausting retries)
        return raw_response

    def run(self, save: bool = True, parallel_worker: bool = False, **agent_kwargs):
        """
        Conducts the survey for all endowments.

        Args:
            save (bool): Whether to save the responses to a JSON dictionary file.
            parallel_worker (bool): If parallel workers are used (reduce logging).
            **agent_kwargs: Additional keyword arguments passed to each agent (e.g., temperature, max_tokens).

        Returns:
            List[Dict]: List of response entries.
        """
        self.responses = []
        self.failed_pairs = [] 
        questions = self.survey.questions
        endowment_list = self.endowments.endowments
        metadata = {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "agent_type": self.agent_type,
            "model_name": self.model_name,
            "formality": self.formality,
            "agent_kwargs": agent_kwargs
        }
        if not parallel_worker:
            self._log(f"Starting survey with {len(endowment_list)} endowments × {len(questions)} questions...")
        
        raw_log_path = self.output_path.replace(".json", "_raw.jsonl")
        os.makedirs(os.path.dirname(raw_log_path), exist_ok=True)

        with open(raw_log_path, "w", encoding="utf-8") as raw_log:
            for i, e in enumerate(endowment_list):
                eid = e["eid"]
                endow_text = e["endow_text"]

                self._log(f"[{i+1}/{len(endowment_list)}] Processing eid={eid}...", level="debug")

                agent = AgentFactory.create(
                    agent_type=self.agent_type,
                    eid=eid,
                    endowment=endow_text,
                    formality=self.formality,
                    model_name=self.model_name,
                    **agent_kwargs
                )

                for q in questions:
                    qid = q["id"]
                    qprompt = self.survey.get_prompt_text(q)

                    raw_response = self._get_valid_response(agent, qprompt, qid)
                    answer = raw_response.get("cleaned") or raw_response.get("answer_text", "[ERROR] No response")

                    # If hard failed, store (eid, qid)
                    if raw_response.get("__hard_fail__"):
                        self.failed_pairs.append((e, q))

                    entry = {
                        "eid": eid,
                        "qid": qid,
                        "question": qprompt,
                        "raw_response": raw_response,
                        "answer": answer,
                        "timestamp": time.time(),
                        "agent_type": self.agent_type,
                        "model_name": self.model_name
                    }
                    self.responses.append(entry)
                    json.dump(entry, raw_log, ensure_ascii=False)
                    raw_log.write("\n")


            # --------------------------------------------------
            # SECOND PASS: retry only failed (eid,qid) pairs
            # --------------------------------------------------
            if self.failed_pairs:
                self._log(f"[Retry pass] Re-attempting {len(self.failed_pairs)} failed calls...", level="warning")

                for e, q in self.failed_pairs:
                    eid = e["eid"]
                    qid = q["id"]
                    qprompt = self.survey.get_prompt_text(q)

                    agent = AgentFactory.create(
                        agent_type=self.agent_type,
                        eid=eid,
                        endowment=e["endow_text"],
                        formality=self.formality,
                        model_name=self.model_name,
                        **agent_kwargs
                    )

                    raw_response = self._get_valid_response(agent, qprompt, qid)
                    answer = raw_response.get("cleaned") or raw_response.get("answer_text", "[ERROR] No response")

                    retry_entry = {
                        "eid": eid,
                        "qid": qid,
                        "question": qprompt,
                        "raw_response": raw_response,
                        "answer": answer,
                        "timestamp": time.time(),
                        "agent_type": self.agent_type,
                        "model_name": self.model_name,
                        "retry_pass": True
                    }
                    json.dump(retry_entry, raw_log, ensure_ascii=False)
                    raw_log.write("\n")

                    # patch result in-place
                    for entry in self.responses:
                        if entry["eid"] == eid and entry["qid"] == qid:
                            entry["raw_response"] = raw_response
                            entry["answer"] = answer
                            break


        if save:
            self._log(f"Saving JSON dictionary to {self.output_path}")
            self._log(f"Saving raw responses to {raw_log_path}")
            response_dict = self.to_response_dict(self.responses)
            self.save_to_json({"metadata": metadata, "responses": response_dict})
        if not parallel_worker:
            self._log("Survey completed.")

        return self.responses

    def save_to_csv(self, responses: List[Dict], csv_path: str):
        """
        Save responses in matrix format: rows = qid, columns = eid.

        Args:
            responses (List[Dict]): List of response entries.
            csv_path (str): Path to output CSV file.
        """
        matrix = defaultdict(dict)
        for r in responses:
            matrix[r["qid"]][r["eid"]] = r["answer"]

        all_eids = sorted({r["eid"] for r in responses})
        all_qids = sorted({r["qid"] for r in responses})

        df = pd.DataFrame.from_dict(matrix, orient="index", columns=all_eids)
        df.index.name = "qid"
        df = df.reindex(all_qids)

        os.makedirs(os.path.dirname(csv_path), exist_ok=True)
        df.to_csv(csv_path)

    def to_response_dict(self, responses: List[Dict]) -> Dict[str, Dict[str, str]]:
        """
        Convert a flat list of responses into a nested dictionary.

        Returns:
            Dict[qid][eid] = response
        """
        matrix = defaultdict(dict)
        for r in responses:
            matrix[r["qid"]][r["eid"]] = r["answer"]
        return dict(matrix)

    def to_agent_records(self) -> List[Dict]:
        """
        Converts flat response logs to a list of agent-level records.
        Each record includes:
        - agent_id
        - mode
        - raw response vector (aligned with question_ids)
        """
        agent_response_map = defaultdict(dict)
        for r in self.responses:
            agent_response_map[r["eid"]][r["qid"]] = r["answer"]

        qids = [q["id"] for q in self.survey.questions]

        records = []
        for eid, qdict in agent_response_map.items():
            raw_vector = [qdict.get(qid, None) for qid in qids]
            mode = self.endowments.get_mode_by_eid(eid)

            records.append({
                "agent_id": eid,
                "mode": mode,
                "responses": raw_vector
            })

        return records

    def save_to_json(self, response_dict: Dict, json_path: str = None):
        """
        Save a nested response dictionary to a JSON file.

        Args:
            response_dict (Dict): Nested dictionary or metadata-wrapped dict to save.
            json_path (str, optional): Path to save JSON. Defaults to output_path.
        """
        json_path = json_path or self.output_path
        os.makedirs(os.path.dirname(json_path), exist_ok=True)

        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(response_dict, f, indent=2, ensure_ascii=False)
