import os
import random
import time

from pcot.tasks.common_words import get_uniform_probabilities

# Must stay above `import litellm` further down: the variable has to be in the
# environment before LiteLLM is imported. NOTE(review): presumably this makes
# LiteLLM use its bundled model cost map instead of fetching one remotely — confirm.
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"

import json
import logging
import sys
import threading
from collections import Counter
from fractions import Fraction
from typing import Dict, List

import hydra
import litellm
import matplotlib.pyplot as plt
import numpy as np
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd
from omegaconf import DictConfig, OmegaConf
from scipy.stats import chisquare, entropy

from pcot.parameters import get_default_max_tokens, get_default_temperature
from pcot.tasks.probabilistic_prompt import BASELINE_PROMPT, ProbabilisticPromptTask

# Module-level logger shared by the helper functions and the Hydra entry point below.
log = logging.getLogger(__name__)


def calculate_metrics(
    observed_counts: Dict[str, int],
    expected_probs: Dict[str, float],
    words: List[str],
    total_valid: int,
) -> Dict[str, float]:
    """Compare observed answer counts against the target distribution.

    Args:
        observed_counts: Count per answer, keyed by the word exactly as it
            appears in ``words`` (original case). Keys not in ``words`` are
            ignored.
        expected_probs: Target probability per word, keyed by the
            *lower-cased* word.
        words: Category order shared by both distributions.
        total_valid: Number of valid (parsed) answers.

    Returns:
        Dict with ``kl_divergence`` (KL(observed || expected)), ``chi_squared``
        (Pearson statistic) and ``p_value``.

        - With no valid answers, all three are ``inf`` / ``inf`` / ``0.0``.
        - The chi-squared test is reported as ``inf`` with p-value ``0.0``
          when fewer than two categories have non-zero expectation.
        - The same ``inf`` / ``0.0`` result is reported when answers fall
          into a zero-probability category, or when no answer falls into an
          allowed one. The KL divergence is still computed in these cases.
    """
    if total_valid == 0:
        return {
            "kl_divergence": float("inf"),
            "chi_squared": float("inf"),
            "p_value": 0.0,
        }

    # observed_counts is keyed by the original-case word, expected_probs by
    # the lower-cased word; both arrays follow the order of `words`.
    observed = np.array([observed_counts.get(word, 0) for word in words], dtype=float)
    expected_p = np.array([expected_probs[word.lower()] for word in words], dtype=float)
    expected_counts = expected_p * total_valid

    # KL divergence (observed probs vs expected probs). The epsilon keeps
    # log(0) / division by zero out; scipy's entropy() renormalises both inputs.
    epsilon = 1e-10
    kl_div = entropy(observed / total_valid + epsilon, expected_p + epsilon)

    # Chi-squared goodness of fit on categories with non-zero expectation.
    valid = expected_counts > epsilon
    observed_in_valid = observed[valid]
    observed_total = observed_in_valid.sum()
    if np.sum(valid) < 2:
        # Need at least 2 categories for a chi-square test.
        chi2_stat, p_value = float("inf"), 0.0
    elif observed[~valid].sum() > 0 or observed_total == 0:
        # Answers landed in a category the target forbids (or none landed in
        # any allowed one): infinitely unlikely under the null hypothesis.
        # chisquare() would raise here because the totals cannot match.
        chi2_stat, p_value = float("inf"), 0.0
    else:
        # chisquare() requires sum(f_obs) == sum(f_exp) to tight tolerance.
        # Rescale the expected probabilities to the observed total so that
        # float drift in the configured probabilities, or answers outside
        # `words`, cannot make the call raise.
        expected_in_valid = expected_p[valid]
        f_exp = expected_in_valid / expected_in_valid.sum() * observed_total
        chi2_stat, p_value = chisquare(f_obs=observed_in_valid, f_exp=f_exp)

    return {
        "kl_divergence": float(kl_div),
        "chi_squared": float(chi2_stat),
        "p_value": float(p_value),
    }


def plot_distribution(
    observed_counts: Dict[str, int],
    expected_probs: Dict[str, float],
    words: List[str],
    total_valid: int,
    title: str,
    save_path: str,
):
    """Plot observed vs expected answer frequencies as grouped bars and save it.

    Args:
        observed_counts: Count per word, keyed by the word as it appears in
            ``words`` (original case).
        expected_probs: Target probability per word, keyed by the lower-cased
            word.
        words: Bar order; also used verbatim as the x-tick labels.
        total_valid: Number of valid answers; observed counts are divided by
            it to obtain frequencies.
        title: Axes title (wrapped when long).
        save_path: Output file path for the rendered figure.

    When ``total_valid`` is 0 a warning is logged and nothing is plotted.
    """
    if total_valid == 0:
        log.warning("Skipping plot generation as there are no valid results.")
        return

    observed_freq = [observed_counts.get(word, 0) / total_valid for word in words]
    expected_freq = [expected_probs[word.lower()] for word in words]

    x = np.arange(len(words))
    width = 0.35

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        rects1 = ax.bar(x - width / 2, observed_freq, width, label="Observed")
        rects2 = ax.bar(x + width / 2, expected_freq, width, label="Expected")

        ax.set_ylabel("Frequency")
        ax.set_title(title, wrap=True)
        ax.set_xticks(x)
        ax.set_xticklabels(words)
        ax.legend()
        # Y axis shown as percentages; bar labels keep the raw fraction.
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.1%}"))
        ax.bar_label(rects1, padding=3, fmt="%.3f")
        ax.bar_label(rects2, padding=3, fmt="%.3f")

        fig.tight_layout()
        # Save this figure explicitly rather than pyplot's implicit "current"
        # figure, which is global state shared with any other plotting code.
        fig.savefig(save_path)
        log.info(f"Plot saved to {save_path}")
    finally:
        # Release the figure even if drawing or saving raises, so repeated
        # calls cannot accumulate open figures.
        plt.close(fig)


def _token_count(value) -> int:
    """Coerce an optional usage counter from a provider response to ``int``.

    Returns 0 for ``None`` or for values ``int()`` cannot convert.
    """
    if value is None:
        return 0
    try:
        return int(value)
    except (TypeError, ValueError):
        return 0


def process_response(
    i, response, task
) -> "tuple[str | None, str | None, str | None, int, int]":
    """Extract the parsed answer, raw text and token usage from one LLM response.

    Args:
        i: Sample index; used only in log messages.
        response: A LiteLLM completion response, or an ``Exception`` instance
            when the request failed.
        task: Object whose ``parse_answer(text)`` returns the parsed answer,
            or a falsy value when the text cannot be parsed.

    Returns:
        ``(result, raw_output, raw_reasoning, completion_tokens, reasoning_tokens)``.
        ``result`` is ``None`` on a request error, a parse failure or an
        unexpected response shape. Token counts are 0 when the response does
        not report them.
    """
    result, raw_output, raw_reasoning = None, None, None
    completion_tokens = 0
    reasoning_tokens = 0

    # Batch calls may hand back the exception instead of a response.
    if isinstance(response, Exception):
        log.error(f"Error in batch response for sample {i}: {response}")
        return result, raw_output, raw_reasoning, completion_tokens, reasoning_tokens

    try:
        message = response.choices[0].message
        generated_text: str = message.content
        log.info(f"Raw LLM Output (Sample {i}): {generated_text}")
        raw_output = generated_text

        # Usage / reasoning fields are optional and provider-dependent, so each
        # one is read independently: a model that reports reasoning_tokens but
        # no reasoning_content must still have its reasoning tokens counted.
        usage = getattr(response, "usage", None)
        completion_tokens = _token_count(getattr(usage, "completion_tokens", None))
        details = getattr(usage, "completion_tokens_details", None)
        reasoning_tokens = _token_count(getattr(details, "reasoning_tokens", None))
        raw_reasoning = getattr(message, "reasoning_content", None)

        answer = task.parse_answer(generated_text)
        if answer:
            result = answer
        else:
            log.debug(f"Parse error on sample {i}. Raw text: {generated_text[:200]}...")
    except Exception as e:
        log.error(f"Error processing result for sample {i}: {e}")

    return result, raw_output, raw_reasoning, completion_tokens, reasoning_tokens


@hydra.main(config_path="conf/rsp_probability", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    """
    Hydra runner script for RSP probability experiments using LiteLLM.
    """
    log.info("Starting RSP Probability Experiment Run (LiteLLM)")
    hydra_output_dir = HydraConfig.get().runtime.output_dir
    log.info(f"Hydra Run Output Directory: {hydra_output_dir}")
    log.info("Loaded configuration:\n%s", OmegaConf.to_yaml(cfg))

    # --- Validate Configuration ---
    # Note: api_base check removed to make it optional. LiteLLM will handle None.

    # Default probabilities to uniform if not provided
    if cfg.prompt.get("probabilities") is None:
        num_words = len(cfg.prompt.words)
        if num_words > 0:
            # Use OmegaConf.create to make it mutable if needed later, though direct assignment is fine here
            cfg.prompt.probabilities = get_uniform_probabilities(num_words)
            log.info(
                f"Probabilities not provided, defaulting to uniform: {cfg.prompt.probabilities}"
            )
        else:
            log.error("Must provide at least one word in prompt.words.")
            sys.exit(1)

    # Validate probabilities sum to 1
    if not np.isclose(
        sum([float(Fraction(prob)) for prob in cfg.prompt.probabilities]), 1.0
    ):
        log.error(
            f"Probabilities must sum to 1.0, but got {sum([float(Fraction(prob)) for prob in cfg.prompt.probabilities])}"
        )
        sys.exit(1)
    if len(cfg.prompt.words) != len(cfg.prompt.probabilities):
        log.error("Number of words must match the number of probabilities.")
        sys.exit(1)

    task = ProbabilisticPromptTask(
        prompt_type=cfg.prompt.type,
        words=cfg.prompt.words,
        probabilities=[str(p) for p in cfg.prompt.probabilities],
        fixed_random_string="R4EK4A8ZzeIn9Vbu",
    )

    results = []
    parse_errors = 0
    raw_outputs = []  # Store raw text for inspection if needed
    raw_reasonings = []
    api_errors = 0
    total_reasoning_tokens = []
    total_completion_tokens = []

    if cfg.sampling.temperature is None:
        cfg.sampling.temperature = get_default_temperature(cfg.model.litellm_model_name)
    temperature = cfg.sampling.temperature

    if cfg.sampling.max_tokens is None:
        cfg.sampling.max_tokens = get_default_max_tokens(cfg.model.litellm_model_name)
    max_tokens = cfg.sampling.max_tokens

    # random sleep to avoid api contention
    time.sleep(random.random() * 60)

    if cfg.prompt.type == "sequential":
        # Sequential mode with per-turn checkpointing and resume.
        # Use threads to run each sequence independently; write JSON after each turn.

        def _build_provider_kwargs() -> dict:
            """Assemble the extra keyword arguments passed to ``litellm.completion``.

            For OpenRouter model names that contain neither "o4-mini" nor
            "gpt-4o", an ``extra_body`` provider-routing block is added:
            DeepSeek-R1 names request fp8 quantization with fallbacks allowed;
            otherwise the provider is pinned ("DeepInfra" for QwQ, "lambda"
            for everything else) with fallbacks disabled. Every call also gets
            a 300 s timeout and 3 retries.
            """
            model_name = cfg.model.litellm_model_name
            provider_kwargs: dict[str, dict] = {}

            excluded_model = "o4-mini" in model_name or "gpt-4o" in model_name
            if "openrouter" in model_name and not excluded_model:
                if "deepseek-r1" in model_name:
                    routing = {"quantizations": ["fp8"], "allow_fallbacks": True}
                else:
                    chosen = "DeepInfra" if "qwq" in model_name else "lambda"
                    routing = {"order": [chosen], "allow_fallbacks": False}
                provider_kwargs["extra_body"] = {"provider": routing}

            # Conservative defaults
            provider_kwargs["timeout"] = 300
            provider_kwargs["num_retries"] = 3
            return provider_kwargs

        # Experiment layout
        turns_cfg = cfg.experiment.get("sequential_turns")
        parallel_cfg = cfg.experiment.get("sequential_parallel")
        if turns_cfg is None and parallel_cfg is None:
            turns = 100
            parallel = 10
        else:
            turns = (
                int(turns_cfg)
                if turns_cfg is not None
                else int(cfg.experiment.num_samples)
            )
            parallel = int(parallel_cfg) if parallel_cfg is not None else 1

        word_list_str = ", ".join([f'"{w}"' for w in cfg.prompt.words])
        prob_list_str = ", ".join([str(p) for p in cfg.prompt.probabilities])
        num_choices = len(cfg.prompt.words)
        log.info(f"Sequential mode: {parallel} parallel sequences x {turns} turns")

        # Output directories and file bases
        # NOTE: Use a stable checkpoint base outside hydra_output_dir so resume works across runs.
        original_cwd = get_original_cwd()
        checkpoint_base_dir = os.path.join(
            original_cwd, "outputs", cfg.experiment.output_dir_suffix
        )
        os.makedirs(checkpoint_base_dir, exist_ok=True)
        per_turn_dir = os.path.join(checkpoint_base_dir, "per_turn")
        os.makedirs(per_turn_dir, exist_ok=True)

        model_name_part = cfg.model.name
        prompt_type_part = cfg.prompt.type
        temp_part = str(temperature).replace(".", "")
        words_part = "_".join(cfg.prompt.words)
        probs_part = "_".join(
            [str(p).replace(".", "") for p in cfg.prompt.probabilities]
        ).replace("/", "_")
        if turns_cfg is None and parallel_cfg is None:
            samples_part = str(cfg.experiment.num_samples)
        else:
            tt = (
                int(turns_cfg)
                if turns_cfg is not None
                else int(cfg.experiment.num_samples)
            )
            pp = int(parallel_cfg) if parallel_cfg is not None else 1
            samples_part = f"seqP_{pp}__turns_{tt}"

        per_turn_base = "__".join(
            [
                "turnbase",
                model_name_part,
                f"prompt_{prompt_type_part}",
                f"temp_{temp_part}",
                f"words_{words_part}",
                f"probs_{probs_part}",
                f"samples_{samples_part}",
            ]
        )[:200]

        # Build progress results path (stable across runs)
        output_subdir = checkpoint_base_dir
        filename_parts = [
            "results",
            model_name_part,
            f"prompt_{prompt_type_part}",
            f"temp_{temp_part}",
            f"words_{words_part}",
            f"probs_{probs_part}",
            f"samples_{samples_part}",
        ]
        base_filename = "__".join(filename_parts)[:200]
        progress_filename = f"{base_filename}.json"
        progress_results_path = os.path.join(output_subdir, progress_filename)

        # Prepare resume state by scanning existing per-turn files
        serializable_cfg_run = OmegaConf.to_container(cfg, resolve=True)
        histories: List[List[str]] = [[] for _ in range(parallel)]
        completed_turns: List[int] = [0 for _ in range(parallel)]

        try:
            existing_files = [
                f
                for f in os.listdir(per_turn_dir)
                if f.startswith(f"turn__{per_turn_base}__seq_") and f.endswith(".json")
            ]
            # Group by sequence and sort by turn index
            grouped: Dict[int, list[tuple[int, str]]] = {}
            for fname in existing_files:
                # Format: turn__{base}__seq_{s}__turn_{t}.json
                try:
                    parts = fname.split("__")
                    seq_str = [p for p in parts if p.startswith("seq_")][0]
                    turn_str = [p for p in parts if p.startswith("turn_")][0]
                    s = int(seq_str.split("_")[1])
                    t = int(turn_str.split("_")[1].split(".")[0])
                except Exception:
                    continue
                grouped.setdefault(s, []).append((t, os.path.join(per_turn_dir, fname)))

            for s, items in grouped.items():
                items.sort(key=lambda x: x[0])
                for t, path in items:
                    try:
                        with open(path, "r") as rf:
                            data = json.load(rf)
                        parsed = data.get("parsed_result")
                        if parsed is not None:
                            histories[s].append(parsed)
                            completed_turns[s] = max(completed_turns[s], t)
                            # Optionally recover raw outputs and tokens
                            raw_outputs.append(data.get("raw_output"))
                            raw_reasonings.append(data.get("raw_reasoning"))
                            total_completion_tokens.append(
                                data.get("completion_tokens", 0)
                            )
                            total_reasoning_tokens.append(
                                data.get("reasoning_tokens", 0)
                            )
                            results.append(parsed)
                    except Exception:
                        # Ignore corrupted files
                        continue
        except FileNotFoundError:
            pass

        # Shared synchronization
        results_lock = threading.Lock()

        def write_progress_summary():
            """Overwrite ``progress_results_path`` with an aggregate of current progress.

            Recomputes counts and metrics from ``results`` and snapshots the
            token/raw-output lists plus per-sequence progress. Every call site
            in ``run_sequence`` holds ``results_lock`` while calling this.

            The JSON is first written to a ``.tmp`` sibling and then swapped in
            with ``os.replace``, so an interrupted write cannot leave a
            truncated checkpoint behind. Errors are logged, never raised.
            """
            try:
                # Build observed counts and metrics from current results
                local_results = list(results)
                observed_counts = Counter(local_results)
                for word in cfg.prompt.words:
                    observed_counts.setdefault(word, 0)
                total_valid = len(local_results)
                expected_probs_dict = {
                    word.lower(): float(Fraction(prob))
                    for word, prob in zip(cfg.prompt.words, cfg.prompt.probabilities)
                }
                metrics = calculate_metrics(
                    observed_counts,
                    expected_probs_dict,
                    list(cfg.prompt.words),
                    total_valid,
                )
                # Planned total for the whole run (parallel sequences x turns),
                # not just the turns completed so far.
                attempted = turns * parallel
                serializable_summary = {
                    "parameters": serializable_cfg_run,
                    "total_samples_attempted": attempted,
                    "total_valid_answers": total_valid,
                    "parse_errors": parse_errors,
                    "api_errors": api_errors,
                    "observed_counts": dict(observed_counts),
                    "observed_percentages": {
                        word: (count / total_valid) * 100 if total_valid > 0 else 0
                        for word, count in observed_counts.items()
                    },
                    "expected_percentages": {
                        word: float(Fraction(prob)) * 100
                        for word, prob in zip(
                            cfg.prompt.words, cfg.prompt.probabilities
                        )
                    },
                    "metrics": metrics,
                    "total_reasoning_tokens": list(total_reasoning_tokens),
                    "total_completion_tokens": list(total_completion_tokens),
                    "raw_outputs": list(raw_outputs),
                    "raw_reasonings": list(raw_reasonings),
                    "progress": {
                        "completed_turns_per_sequence": list(completed_turns),
                        # Copy each inner list so the dump sees a stable snapshot.
                        "histories": [list(h) for h in histories],
                    },
                }
                tmp_path = f"{progress_results_path}.tmp"
                with open(tmp_path, "w") as f:
                    json.dump(serializable_summary, f, ensure_ascii=False, indent=2)
                # Atomic swap: readers see either the old or the new summary.
                os.replace(tmp_path, progress_results_path)
            except Exception as e:
                log.error(f"Error writing progress summary: {e}")

        # Worker per sequence
        def run_sequence(seq_id: int):
            # Start from resume point
            start_turn = completed_turns[seq_id] + 1
            if start_turn > turns:
                return
            for step in range(start_turn, turns + 1):
                # If this turn already exists (race or resume), skip
                turn_filename = f"turn__{per_turn_base}__seq_{seq_id}__turn_{step}.json"
                turn_path = os.path.join(per_turn_dir, turn_filename)
                if os.path.exists(turn_path):
                    # Load and advance local state just in case
                    try:
                        with open(turn_path, "r") as rf:
                            data = json.load(rf)
                        parsed = data.get("parsed_result")
                        if parsed is not None and (len(histories[seq_id]) < step):
                            histories[seq_id].append(parsed)
                            with results_lock:
                                results.append(parsed)
                                raw_outputs.append(data.get("raw_output"))
                                raw_reasonings.append(data.get("raw_reasoning"))
                                total_completion_tokens.append(
                                    data.get("completion_tokens", 0)
                                )
                                total_reasoning_tokens.append(
                                    data.get("reasoning_tokens", 0)
                                )
                                completed_turns[seq_id] = step
                                write_progress_summary()
                    except Exception:
                        pass
                    continue

                # Compose prompt for this turn
                hist = histories[seq_id]
                history_before = hist[:]  # capture before appending
                if len(hist) > 0:
                    history_str = ", ".join(hist)
                    prefix = (
                        f"Context: Past action history sampled from the same process (latest last): {history_str}.\n"
                        f"Treat these as prior independent draws. Now perform a new independent draw."
                    )
                    user_text = (
                        f"{prefix}\n\nChoose between {word_list_str}. You must select one of these {num_choices} options "
                        f"with the following probabilities: {prob_list_str}. After thinking, clearly state your final choice "
                        f"({word_list_str}) within the <answer> tags."
                    ).strip()
                else:
                    user_text = (
                        f"Choose between {word_list_str}. You must select one of these {num_choices} options with the following probabilities: {prob_list_str}. After thinking, clearly state your final choice ({word_list_str}) within the <answer> tags."
                    ).strip()

                chat_messages = [
                    {"role": "system", "content": BASELINE_PROMPT},
                    {"role": "user", "content": user_text},
                ]

                # Try multiple times for this turn
                last_error = None
                for attempt in range(50):
                    try:
                        kwargs = _build_provider_kwargs()
                        response = litellm.completion(
                            model=cfg.model.litellm_model_name,
                            messages=chat_messages,
                            temperature=temperature,
                            max_tokens=max_tokens,
                            api_base=cfg.model.api_base,
                            api_key=cfg.model.api_key,
                            input_cost_per_token=0,
                            output_cost_per_token=0,
                            **kwargs,
                        )
                        sample_index = seq_id * turns + (step - 1)
                        (
                            result_val,
                            raw_output,
                            raw_reasoning,
                            completion_tokens,
                            reasoning_tokens,
                        ) = process_response(sample_index, response, task)

                        if result_val is None:
                            # Try again
                            last_error = RuntimeError("parse_error")
                            time.sleep(0.5)
                            continue

                        # Success: update shared state and write per-turn file
                        with results_lock:
                            results.append(result_val)
                            histories[seq_id].append(result_val)
                            raw_outputs.append(raw_output)
                            raw_reasonings.append(raw_reasoning)
                            total_completion_tokens.append(completion_tokens)
                            total_reasoning_tokens.append(reasoning_tokens)
                            completed_turns[seq_id] = step

                        turn_record = {
                            "parameters": serializable_cfg_run,
                            "sequence_index": seq_id,
                            "turn_index": step,
                            "global_sample_index": sample_index,
                            "history_before": history_before,
                            "parsed_result": result_val,
                            "raw_output": raw_output,
                            "raw_reasoning": raw_reasoning,
                            "completion_tokens": completion_tokens,
                            "reasoning_tokens": reasoning_tokens,
                            "messages": chat_messages,
                        }
                        try:
                            with open(turn_path, "w") as tf:
                                json.dump(turn_record, tf, ensure_ascii=False, indent=2)
                            log.info(f"Per-turn result saved: {turn_path}")
                        except Exception as save_e:
                            log.error(
                                f"Error saving per-turn result to {turn_path}: {save_e}"
                            )

                        # Update progress summary after each successful turn
                        with results_lock:
                            write_progress_summary()
                        break
                    except Exception as e:
                        last_error = e
                        log.error(
                            f"Error in sequence {seq_id} turn {step} attempt {attempt + 1}: {e}"
                        )
                        time.sleep(0.5 + random.random())
                else:
                    # Exhausted attempts for this turn. Leave a marker and stop this sequence.
                    log.error(
                        f"Giving up sequence {seq_id} turn {step} after multiple attempts: {last_error}"
                    )
                    return

        # Spin up sequence threads
        threads: list[threading.Thread] = []
        for s in range(parallel):
            t = threading.Thread(target=run_sequence, args=(s,), daemon=False)
            t.start()
            threads.append(t)

        for t in threads:
            t.join()
    else:
        num_samples = cfg.experiment.num_samples
        batch_messages = task.build_prompt(
            num_samples, litellm_model_name=cfg.model.litellm_model_name
        )

        # --- Generate Responses using LiteLLM ---
        log.info(f"Generating {num_samples} responses using LiteLLM...")
        rest_msgs = [msg for msg in batch_messages]

        initial_len = len(rest_msgs)
        for _ in range(10):
            if len(rest_msgs) == 0:
                break

            if len(results) >= cfg.experiment.target_num_samples:
                break

            next_rest_msgs = []
            try:
                log.info(f"Sending {len(rest_msgs)} requests in a batch...")
                kwargs: dict[str, dict] = dict()
                if (
                    not (
                        "o4-mini" in cfg.model.litellm_model_name
                        or "gpt-4o" in cfg.model.litellm_model_name
                    )
                ) and ("openrouter" in cfg.model.litellm_model_name):
                    if "qwq" in cfg.model.litellm_model_name:
                        provider = "DeepInfra"
                    else:
                        provider = "lambda"

                    kwargs["extra_body"] = {
                        "provider": {
                            "order": [provider],
                            # "require_parameters": True,
                            "allow_fallbacks": False,
                        }
                    }

                    if "deepseek-r1" in cfg.model.litellm_model_name:
                        kwargs["extra_body"] = {
                            "provider": {
                                "quantizations": ["fp8"],
                                "allow_fallbacks": True,
                            }
                        }

                if cfg.model.litellm_model_name.startswith("deepseek"):
                    kwargs["timeout"] = 300
                    kwargs["num_retries"] = 3
                else:
                    kwargs["timeout"] = 300
                    kwargs["num_retries"] = 3

                log.info(f"REQUEST: \n{batch_messages[0]}")
                responses = litellm.batch_completion(
                    model=cfg.model.litellm_model_name,
                    messages=rest_msgs,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    api_base=cfg.model.api_base,
                    api_key=cfg.model.api_key,
                    input_cost_per_token=0,
                    output_cost_per_token=0,
                    max_workers=66,
                    **kwargs,
                )
                log.info("Batch completion finished.")

                # Process batch responses
                for i, (response, rest_msg) in enumerate(zip(responses, rest_msgs)):
                    (
                        result,
                        raw_output,
                        raw_reasoning,
                        completion_tokens,
                        reasoning_tokens,
                    ) = process_response(i, response, task)

                    if result is not None:
                        results.append(result)
                        raw_outputs.append(raw_output)
                        raw_reasonings.append(raw_reasoning)
                        total_completion_tokens.append(completion_tokens)
                        total_reasoning_tokens.append(reasoning_tokens)
                    else:
                        next_rest_msgs.append(rest_msg)

                rest_msgs = next_rest_msgs

                if len(rest_msgs) == initial_len:
                    raise RuntimeError(
                        f"There are {initial_len} samples remaining after first trial, aborting..."
                    )
                time.sleep(random.random() * 60)
            except Exception as e:
                # Catch errors during the batch_completion call itself
                log.error(f"Fatal error during litellm.batch_completion: {e}")
                break

    log.info(f"Processing complete. API/Processing Errors encountered: {api_errors}")

    # --- Analyze Results ---
    log.info(f"\nAnalyzing {len(results)} valid responses...")
    observed_counts = Counter(results)
    total_valid = len(results)

    # Ensure all words (original case) are in counts, even if 0
    for word in cfg.prompt.words:
        observed_counts.setdefault(word, 0)

    # --- Calculate Metrics ---
    # expected_probs_dict needs lowercase keys for lookup in calculate_metrics
    expected_probs_dict = {
        word.lower(): float(Fraction(prob))
        for word, prob in zip(cfg.prompt.words, cfg.prompt.probabilities)
    }
    metrics = calculate_metrics(
        observed_counts, expected_probs_dict, list(cfg.prompt.words), total_valid
    )

    # --- Prepare Results Data ---
    # Convert OmegaConf structures to standard Python types for saving
    # Resolve interpolations and convert to a primitive container (dict/list)
    serializable_cfg = OmegaConf.to_container(cfg, resolve=True)

    # Compute attempted sample count for reporting/naming
    if cfg.prompt.type == "sequential":
        t_cfg = cfg.experiment.get("sequential_turns")
        p_cfg = cfg.experiment.get("sequential_parallel")
        if t_cfg is None and p_cfg is None:
            attempted = int(cfg.experiment.num_samples)
        else:
            tt = int(t_cfg) if t_cfg is not None else int(cfg.experiment.num_samples)
            pp = int(p_cfg) if p_cfg is not None else 1
            attempted = tt * pp
    else:
        attempted = cfg.experiment.num_samples

    results_summary = {
        "parameters": serializable_cfg,  # Save the resolved config
        "total_samples_attempted": attempted,
        "total_valid_answers": total_valid,
        "parse_errors": parse_errors,
        "api_errors": api_errors,
        "observed_counts": dict(observed_counts),
        "observed_percentages": {
            word: (count / total_valid) * 100 if total_valid > 0 else 0
            for word, count in observed_counts.items()
        },
        "expected_percentages": {
            word: float(Fraction(prob)) * 100
            for word, prob in zip(
                cfg.prompt.words, cfg.prompt.probabilities
            )  # Use original case
        },
        "metrics": metrics,
        "total_reasoning_tokens": total_reasoning_tokens,
        "total_completion_tokens": total_completion_tokens,
        "raw_outputs": raw_outputs,
        "raw_reasonings": raw_reasonings,
    }

    # (sequential mode stores full history implicitly via raw_outputs if needed)

    # --- Print Summary ---
    log.info("\n--- Results Summary ---")
    log.info(f"Total samples attempted: {attempted}")
    log.info(f"Valid answers parsed: {total_valid}")
    log.info(f"Parse errors: {parse_errors}")
    log.info(f"API errors: {api_errors}")

    if total_valid > 0:
        log.info("\nCounts (based on valid answers):")
        for word in cfg.prompt.words:  # Iterate in original order/case
            log.info(f"  {word}: {observed_counts[word]}")

        log.info("\nPercentages (of valid answers):")
        for word in cfg.prompt.words:  # Iterate in original order/case
            obs_perc = results_summary["observed_percentages"][word]
            exp_perc = results_summary["expected_percentages"][word]
            log.info(f"  {word}: {obs_perc:.2f}% (Expected: {exp_perc:.2f}%)")

        log.info("\nMetrics:")
        log.info(f"  KL Divergence: {metrics['kl_divergence']:.4f}")
        log.info(f"  Chi-squared Stat: {metrics['chi_squared']:.4f}")
        log.info(f"  P-value: {metrics['p_value']:.4f}")
    else:
        log.warning("\nNo valid answers were successfully parsed.")

    log.info("---------------------\n")

    # --- Save Results ---
    # Determine output directory within the Hydra run directory
    output_subdir = os.path.join(hydra_output_dir, cfg.experiment.output_dir_suffix)
    os.makedirs(output_subdir, exist_ok=True)

    # Create a more descriptive filename based on key parameters
    model_name_part = cfg.model.name  # Use the short name from config
    prompt_type_part = cfg.prompt.type
    temp_part = str(temperature).replace(".", "")
    words_part = "_".join(cfg.prompt.words)
    probs_part = "_".join(
        [str(p).replace(".", "") for p in cfg.prompt.probabilities]
    ).replace("/", "_")
    if cfg.prompt.type == "sequential":
        t_cfg = cfg.experiment.get("sequential_turns")
        p_cfg = cfg.experiment.get("sequential_parallel")
        if t_cfg is None and p_cfg is None:
            samples_part = str(cfg.experiment.num_samples)
        else:
            tt = int(t_cfg) if t_cfg is not None else int(cfg.experiment.num_samples)
            pp = int(p_cfg) if p_cfg is not None else 1
            samples_part = f"seqP_{pp}__turns_{tt}"
    else:
        samples_part = str(cfg.experiment.num_samples)

    filename_parts = [
        "results",
        model_name_part,
        f"prompt_{prompt_type_part}",
        f"temp_{temp_part}",
        f"words_{words_part}",
        f"probs_{probs_part}",
        f"samples_{samples_part}",
    ]
    # Limit filename length if necessary
    base_filename = "__".join(filename_parts)[:200]  # Truncate if too long
    results_filename = f"{base_filename}.json"
    results_path = os.path.join(output_subdir, results_filename)

    try:
        with open(results_path, "w") as f:
            # Convert numpy types to standard types for JSON serialization
            serializable_summary = json.loads(
                json.dumps(
                    results_summary,
                    default=lambda x: str(x) if isinstance(x, np.generic) else x,
                )
            )
            json.dump(serializable_summary, f, indent=4)
        log.info(f"Results saved to {results_path}")
    except Exception as e:
        log.error(f"Error saving results to {results_path}: {e}")

    # --- Generate Plot ---
    plot_filename = results_filename.replace("results_", "distribution_").replace(
        ".json", ".png"
    )
    plot_path = os.path.join(output_subdir, plot_filename)
    plot_title = (
        f"Observed vs Expected Distribution\n"
        f"Model: {cfg.model.litellm_model_name}, Temp: {temperature}, Prompt: {cfg.prompt.type}\n"
        f"Words: {cfg.prompt.words}, Probs: {cfg.prompt.probabilities}\n"
        f"Valid Samples: {total_valid}/{attempted} (API Errors: {api_errors})"
    )
    plot_distribution(
        observed_counts,
        expected_probs_dict,
        list(cfg.prompt.words),
        total_valid,
        plot_title,
        plot_path,
    )

    log.info("RSP Probability Experiment Run Finished")

    # Exit with error code if there were API errors
    # if api_errors > 0:
    #     sys.exit(1) # Let Python handle exit more gracefully


if __name__ == "__main__":
    # Script entry point. Missing third-party packages are not checked here;
    # they surface as ImportError from the module-level imports, which run
    # before this guard is reached.
    main()
