import os

from pcot.analyze.random_int_generation import (
    analyze_random_integers,
    analyze_random_strings,
)
from pcot.budget_forcing import call_llm_budget_force
from pcot.parameters import get_default_temperature
from pcot.parser import (
    parse_random_digit_sequence,
    parse_random_integer,
    parse_random_string,
)
from pcot.prompts.random_int_generation import (
    build_random_int_followup_prompts,
    build_random_int_prompts,
)

# Deliberately set before `import litellm` further down so the variable is already
# present in the environment when LiteLLM is imported.
# NOTE(review): presumably this makes LiteLLM use its bundled model cost map instead
# of fetching one remotely -- confirm against the LiteLLM documentation.
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"

import json
import logging
import sys
from typing import List, Optional, Tuple

import hydra
import litellm
import numpy as np
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm

logger = logging.getLogger(__name__)


def build_prompt(model_name: str, system_prompt: str, user_prompt: str, prev_strs: Optional[list[str]] = None) -> list[dict[str, str]]:
    """Build the chat message list sent to the model.

    Args:
        model_name: Model identifier. When it contains ``"Phi-4-reasoning"`` or
            ``"deepseek-r1"`` the system prompt is folded into a single user
            message instead of being sent with the ``system`` role.
        system_prompt: Instruction text.
        user_prompt: Task text.
        prev_strs: Previously generated answers. When not ``None`` (an empty
            list included) a ``Previously Generated Answers`` section with one
            answer per line is appended to the user content.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dictionaries.
    """
    task_text = user_prompt
    if prev_strs is not None:
        # Join outside the f-string: a backslash inside an f-string replacement
        # field is a SyntaxError on Python versions before 3.12.
        previous_answers = "\n".join(prev_strs)
        task_text += f"\n\nPreviously Generated Answers: {previous_answers}"

    if "Phi-4-reasoning" in model_name or "deepseek-r1" in model_name:
        return [{"role": "user", "content": system_prompt + "\n\nTask:" + task_text}]

    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": task_text},
    ]


@hydra.main(config_path="conf/random_generation", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    """
    Hydra runner script for random generation experiments using LiteLLM.

    Steps:
      1. Validate ``experiment.min_value`` / ``experiment.max_value`` (the process
         exits with status 1 when they are not integers or not strictly ordered).
      2. Build the prompt and collect ``experiment.num_samples`` responses through
         one of three paths: ``litellm.batch_completion``, sequential
         ``litellm.completion`` calls (each new prompt lists the strings parsed so
         far), or ``call_llm_budget_force`` when ``experiment.use_budget_forcing``
         is set.
      3. Append the parsed strings and integers to aggregate text files under
         ``outputs/randomness_test_data``.
      4. Analyze the parsed values, log a summary and write a JSON report below
         the Hydra run output directory.

    Args:
        cfg: Hydra configuration; the ``model``, ``sampling`` and ``experiment``
            groups are read.
    """
    logger.info("Starting Random Generation Experiment Run (LiteLLM)")
    hydra_output_dir = HydraConfig.get().runtime.output_dir
    logger.info(f"Hydra Run Output Directory: {hydra_output_dir}")
    logger.info("Loaded configuration:\n%s", OmegaConf.to_yaml(cfg))

    # --- Validate Configuration ---
    try:
        min_val = int(cfg.experiment.min_value)
        max_val = int(cfg.experiment.max_value)
    except (ValueError, TypeError):
        logger.error("experiment.min_value and experiment.max_value must be integers.")
        sys.exit(1)
    if min_val >= max_val:
        logger.error(f"experiment.min_value ({min_val}) must be less than experiment.max_value ({max_val}).")
        sys.exit(1)

    model_name = cfg.model.litellm_model_name
    num_samples = cfg.experiment.num_samples

    system_prompt, user_prompt = build_random_int_prompts(min_val, max_val, cfg.experiment.only_str, cfg.experiment.only_int)
    chat_messages = build_prompt(model_name, system_prompt, user_prompt)

    logger.info(f"Generating {num_samples} responses for range [{min_val}, {max_val}]...")
    random_string_parse_errors: int = 0
    total_reasoning_tokens = 0

    # Compare against None instead of relying on truthiness so that an explicitly
    # configured temperature of 0 is honoured rather than replaced by the default.
    configured_temperature = cfg.sampling.temperature
    if configured_temperature is not None:
        temperature = configured_temperature
    else:
        temperature = get_default_temperature(model_name)

    if not cfg.experiment.use_budget_forcing:
        try:
            if not cfg.experiment.is_sequential_request:
                logger.info(f"Generating {num_samples} responses using standard LiteLLM batch completion...")
                batch_messages = [chat_messages] * num_samples
                responses = litellm.batch_completion(
                    model=model_name,
                    messages=batch_messages,
                    temperature=temperature,
                    max_tokens=cfg.sampling.max_tokens,
                    api_base=cfg.model.api_base,
                    api_key=cfg.model.api_key,
                    extra_body={
                        "provider": {
                            "order": ["DeepInfra"],
                            "allow_fallbacks": False,
                        }
                    },
                    input_cost_per_token=0,
                    output_cost_per_token=0,
                    num_retries=20,
                    max_workers=cfg.experiment.parallel_workers,
                )
                logger.info("Batch completion finished.")
            else:
                logger.info(f"Sending {num_samples} requests in sequence...")

                responses = []
                random_strs: list[str] = []
                for _ in tqdm(range(num_samples)):
                    # Only the API call decides whether this sample is recorded as a
                    # response or as an exception, so every sample yields exactly one
                    # entry in `responses`.
                    # NOTE: unlike the batch path, no max_tokens cap is passed here.
                    try:
                        response = litellm.completion(
                            model=model_name,
                            messages=chat_messages,
                            temperature=temperature,
                            api_base=cfg.model.api_base,
                            api_key=cfg.model.api_key,
                            extra_body={
                                "provider": {
                                    "order": ["DeepInfra"],
                                    "allow_fallbacks": False,
                                }
                            },
                            input_cost_per_token=0,
                            output_cost_per_token=0,  # We can check cost from openrouter console, so disable it
                            num_retries=2,
                            stream=False,
                        )
                    except Exception as e:
                        responses.append(e)
                        continue

                    responses.append(response)

                    # Best-effort logging and parsing for the follow-up prompt. A failure
                    # here must not add a second entry for the same sample;
                    # process_responses() reports malformed responses on its own.
                    try:
                        message = response.choices[0].message
                        logger.info(f"OUTPUT: {message.content}")
                        try:
                            logger.info(f"REASONING: {message.reasoning_content}")
                        except Exception as reasoning_err:
                            # Not every response carries reasoning content.
                            logger.debug(f"No reasoning content to log: {reasoning_err}")

                        random_string = parse_random_string(message.content)
                        if random_string is not None:
                            random_strs.append(random_string)
                    except Exception as e:
                        logger.warning(f"Could not extract a random string from sequential response: {e}")

                    # The next request lists every string parsed so far.
                    chat_messages = build_prompt(model_name, system_prompt, user_prompt, random_strs)

                logger.info("Sequential completion finished.")

            (
                total_reasoning_tokens,
                api_errors,
                results_data,
                all_random_strings,
                all_generated_integers,
                raw_outputs,
                raw_reasonings,
                total_input_tokens,
                total_output_tokens,
                integer_parse_errors,
                random_string_parse_errors,
            ) = process_responses(min_val, max_val, responses)

        except Exception as e:
            # Anything that escapes generation or processing: report every sample as failed.
            api_errors = 0
            results_data = []
            raw_outputs = []
            raw_reasonings = []
            all_random_strings = []
            all_generated_integers = []
            total_input_tokens = 0
            total_output_tokens = 0
            integer_parse_errors = 0

            logger.error(f"Fatal error during litellm.batch_completion: {e}")
            num_processed = len(results_data)
            num_failed = num_samples - num_processed
            if num_failed > 0:
                logger.warning(f"Batch completion failed after {num_processed} samples. Filling remaining {num_failed} with errors.")
                error_msg = f"LITELLM_FATAL_BATCH_ERROR: {e}"
                for _ in range(num_failed):
                    results_data.append((None, None, error_msg, error_msg, None, None, None))
                    raw_outputs.append(error_msg)
                    raw_reasonings.append(error_msg)
                api_errors += num_failed

    else:
        # --- Generate with Budget Forcing ---
        logger.info(f"Generating {num_samples} samples with budget forcing using {cfg.experiment.parallel_workers} workers...")

        total_input_tokens, total_output_tokens, bf_responses, generation_api_errors = call_llm_budget_force(
            litellm_model_name=model_name,
            parallel_workers=cfg.experiment.parallel_workers,
            api_base=cfg.model.api_base,
            api_key=cfg.model.api_key,
            temperature=temperature,
            max_think_tokens=cfg.experiment.max_think_tokens,
            num_wait_insertion=cfg.experiment.num_wait_insertion,
            num_samples=num_samples,
            chat_messages=chat_messages,
            tokenizer_name=cfg.model.tokenizer_name,
            followup_prompt=build_random_int_followup_prompts(min_val, max_val, cfg.experiment.only_str, cfg.experiment.only_int),
            is_sequential_request=cfg.experiment.is_sequential_request,
            only_int=cfg.experiment.only_int,
        )
        logger.info("Budget forcing generation finished.")
        (
            random_string_parse_errors,
            integer_parse_errors,
            processing_errors,
            results_data,
            all_random_strings,
            all_generated_integers,
            raw_outputs,
            raw_reasonings,
        ) = process_bf_responses(
            min_val,
            max_val,
            bf_responses,
            only_int=cfg.experiment.only_int,
        )
        # Keep the errors reported by the generation call instead of overwriting
        # them with the processing-stage count.
        api_errors = generation_api_errors + processing_errors
        # Sum the per-sample reasoning token counts so the summary is not stuck at 0
        # in budget-forcing mode; non-integer entries (e.g. None) are skipped.
        total_reasoning_tokens = sum(
            tokens for tokens in (bf_result.get("reasoning_tokens") for bf_result in bf_responses) if isinstance(tokens, int)
        )

    logger.info(f"Processing complete. API/Processing Errors encountered: {api_errors}")

    # --- Save Aggregated Strings and Integers for Randomness Tests ---
    if all_random_strings or all_generated_integers:
        logger.info("Appending generated strings and integers to files for randomness testing...")
        test_data_dir = "outputs/randomness_test_data"
        os.makedirs(test_data_dir, exist_ok=True)

        # Sanitize model name for filename
        sanitized_model_name = model_name.replace("/", "_")
        temp_part_agg = str(temperature).replace(".", "")

        range_part = f"range_{min_val}_{max_val}"
        if cfg.experiment.use_budget_forcing:
            range_part += f"__num_wait_{cfg.experiment.num_wait_insertion}__max_think_tokens_{cfg.experiment.max_think_tokens}"

        # The suffix marks sequential runs. The condition used to be inverted, so
        # aggregate files written before this fix carry the opposite label.
        if cfg.experiment.is_sequential_request:
            range_part += "__sequential"

        if cfg.experiment.only_str:
            range_part += "__only_str"

        if cfg.experiment.only_int:
            range_part += "__only_int"

        # NOTE(review): `range_part` already starts with "range_", so these names contain
        # "range_range_". Left untouched so runs keep appending to the existing files.
        strings_filename_agg = os.path.join(
            test_data_dir,
            f"strings__{sanitized_model_name}__range_{range_part}__temp_{temp_part_agg}.txt",
        )
        integers_filename_agg = os.path.join(
            test_data_dir,
            f"integers__{sanitized_model_name}__range_{range_part}__temp_{temp_part_agg}.txt",
        )

        # Explicit UTF-8: model output may contain characters the locale encoding cannot represent.
        try:
            with open(strings_filename_agg, "a", encoding="utf-8") as f_strings:
                for s in all_random_strings:
                    f_strings.write(s + "\n")
            logger.info(f"Appended {len(all_random_strings)} strings to {strings_filename_agg}")
        except Exception as e:
            logger.error(f"Error appending strings to {strings_filename_agg}: {e}")

        try:
            with open(integers_filename_agg, "a", encoding="utf-8") as f_integers:
                for i in all_generated_integers:
                    f_integers.write(str(i) + "\n")
            logger.info(f"Appended {len(all_generated_integers)} integers to {integers_filename_agg}")
        except Exception as e:
            logger.error(f"Error appending integers to {integers_filename_agg}: {e}")

    # --- Analyze Results ---
    logger.info(f"\nAnalyzing {len(all_random_strings)} validly parsed random strings...")
    random_string_analysis = analyze_random_strings(all_random_strings)

    logger.info(f"Analyzing {len(all_generated_integers)} validly parsed random integers...")
    random_integer_analysis = analyze_random_integers(all_generated_integers, min_val, max_val)

    # --- Prepare Results Data ---
    serializable_cfg = OmegaConf.to_container(cfg, resolve=True)

    results_summary = {
        "parameters": serializable_cfg,
        "total_samples_attempted": num_samples,
        "total_valid_random_strings_parsed": len(all_random_strings),
        "total_valid_integers_parsed": len(all_generated_integers),
        "random_string_parse_errors": random_string_parse_errors,
        "integer_parse_errors": integer_parse_errors,
        "api_errors": api_errors,
        "total_input_tokens": total_input_tokens,
        "total_output_tokens": total_output_tokens,
        "total_reasoning_tokens": total_reasoning_tokens,
        "average_input_tokens_per_sample": total_input_tokens / num_samples if num_samples > 0 else 0,
        "average_output_tokens_per_sample": total_output_tokens / num_samples if num_samples > 0 else 0,
        "random_string_analysis": random_string_analysis,
        "random_integer_analysis": random_integer_analysis,
        "detailed_results": [
            {
                "random_string": r[0],
                "generated_integer": r[1],
                "raw_output": r[2],
                "raw_reasoning": r[3],
                "input_tokens": r[4],
                "output_tokens": r[5],
                "reasoning_tokens": r[6],
            }
            for r in results_data
        ],
        # Optionally keep raw_outputs separate if detailed_results includes it
        "raw_outputs": raw_outputs,
        "raw_reasonings": raw_reasonings,
    }

    # --- Print Summary ---
    logger.info("\n--- Results Summary ---")
    logger.info(f"Integer Range: [{min_val}, {max_val}]")
    logger.info(f"Total samples attempted: {num_samples}")
    logger.info(f"Valid random strings parsed: {len(all_random_strings)}")
    logger.info(f"Valid integers parsed: {len(all_generated_integers)}")
    logger.info(f"Random string parse errors: {random_string_parse_errors}")
    logger.info(f"Integer parse errors: {integer_parse_errors}")
    logger.info(f"API/Processing errors: {api_errors}")
    logger.info(f"Total Input Tokens: {total_input_tokens}")
    logger.info(f"Total Output Tokens: {total_output_tokens}")
    logger.info(f"Average Input Tokens/Sample: {results_summary['average_input_tokens_per_sample']:.2f}")
    logger.info(f"Average Output Tokens/Sample: {results_summary['average_output_tokens_per_sample']:.2f}")

    logger.info("\nRandom String Analysis:")
    # Reuse logging format from Miller-Rabin script
    logger.info(f"  Total collected: {random_string_analysis['total_strings_collected']}")
    logger.info(f"  Unique count: {random_string_analysis['unique_strings_count']}")
    logger.info(f"  Average length: {random_string_analysis.get('average_length', 0.0):.2f}")
    logger.info(f"  Min length: {random_string_analysis.get('min_length', 0)}")
    logger.info(f"  Max length: {random_string_analysis.get('max_length', 0)}")
    entropy_val = random_string_analysis.get("shannon_entropy_chars")
    logger.info(f"  Shannon Entropy (chars, base 2): {entropy_val if isinstance(entropy_val, (int, float)) else str(entropy_val)}")
    avg_run_len = random_string_analysis.get("average_run_length")
    if isinstance(avg_run_len, (int, float)):
        logger.info(f"  Average Run Length: {avg_run_len:.3f}")
    else:
        # Keep the label even when the value is missing or not numeric.
        logger.info(f"  Average Run Length: {avg_run_len}")
    mcb_freq = random_string_analysis.get("most_common_bigram_freq")
    if mcb_freq is not None:
        logger.info(f"  Most Common Bigram: '{random_string_analysis.get('most_common_bigram')}' (Freq: {mcb_freq:.4f})")
    else:
        logger.info("  Most Common Bigram: N/A")

    logger.info("\nRandom Integer Analysis:")
    logger.info(f"  Total parsed: {random_integer_analysis['total_integers_parsed']}")
    if random_integer_analysis["total_integers_parsed"] > 0:
        logger.info(f"  Min generated: {random_integer_analysis['min_generated_int']}")
        logger.info(f"  Max generated: {random_integer_analysis['max_generated_int']}")
        logger.info(f"  Mean: {random_integer_analysis.get('mean_generated_int', 0.0):.3f}")
        logger.info(f"  Std Dev: {random_integer_analysis.get('std_dev_generated_int', 0.0):.3f}")
        chi2 = random_integer_analysis.get("chi_squared_uniformity")
        pval = random_integer_analysis.get("p_value_uniformity")
        logger.info(f"  Chi-Squared (Uniformity): {chi2 if isinstance(chi2, (int, float)) else str(chi2)}")
        logger.info(f"  P-value (Uniformity): {pval if isinstance(pval, (int, float)) else str(pval)}")
    else:
        logger.warning("  No valid integers were parsed for analysis.")

    logger.info("---------------------\n")

    # --- Save Run-Specific Results Summary ---
    output_subdir = os.path.join(hydra_output_dir, cfg.experiment.get("output_dir_suffix", "random_gen_results"))
    os.makedirs(output_subdir, exist_ok=True)

    # The run-specific file name uses cfg.model.name as-is (not the sanitized LiteLLM name).
    model_name_part = cfg.model.name
    temp_part = str(temperature).replace(".", "")
    range_part = f"range_{min_val}_{max_val}"
    samples_part = str(num_samples)

    filename_parts = [
        "results",
        model_name_part,
        range_part,
        f"temp_{temp_part}",
        f"samples_{samples_part}",
    ]
    base_filename = "__".join(filename_parts)[:200]  # Limit length
    results_filename = f"{base_filename}.json"
    results_path = os.path.join(output_subdir, results_filename)

    try:
        # Serialize before opening the file so a serialization failure cannot leave a
        # truncated or empty results file behind. numpy scalars become native numbers;
        # any other unsupported object falls back to its string form.
        serialized_summary = json.dumps(
            results_summary,
            indent=4,
            default=lambda x: int(x) if isinstance(x, np.integer) else (float(x) if isinstance(x, np.floating) else str(x)),
        )
        with open(results_path, "w", encoding="utf-8") as f:
            f.write(serialized_summary)
        logger.info(f"Results summary saved to {results_path}")
    except Exception as e:
        logger.error(f"Error saving results summary to {results_path}: {e}")

    logger.info("Random Generation Experiment Run Finished")


def process_responses(min_val: int, max_val: int, responses: list):
    """Turn raw LiteLLM results into parsed values, raw texts and counters.

    Each entry of ``responses`` is either an ``Exception`` (recorded as an API
    error) or a response object from which token usage, the message content and
    the optional reasoning content are read. The content is run through
    ``parse_random_string`` and ``parse_random_integer``.

    Args:
        min_val: Lower bound forwarded to ``parse_random_integer``.
        max_val: Upper bound forwarded to ``parse_random_integer``.
        responses: Response objects and/or exceptions, one per sample.

    Returns:
        An 11-tuple ``(total_reasoning_tokens, api_errors, results_data,
        all_random_strings, all_generated_integers, raw_outputs, raw_reasonings,
        total_input_tokens, total_output_tokens, integer_parse_errors,
        random_string_parse_errors)``. ``results_data`` holds one 7-tuple per
        sample: ``(random_string, random_integer, raw_output, raw_reasoning,
        input_tokens, output_tokens, reasoning_tokens)``.
    """
    rows: List[Tuple[Optional[str], Optional[int], Optional[str], Optional[str], Optional[int], Optional[int], Optional[int]]] = []
    parsed_strings: list[str] = []
    parsed_integers: list[int] = []
    raw_outputs: list[str] = []
    raw_reasonings: list[Optional[str]] = []

    error_count: int = 0
    int_parse_failures: int = 0
    str_parse_failures: int = 0

    sum_input_tokens: int = 0
    sum_output_tokens: int = 0
    sum_reasoning_tokens: int = 0

    for idx, resp in enumerate(tqdm(responses, desc="Processing Batch Results")):
        # Failed requests arrive as exception instances in place of a response.
        if isinstance(resp, Exception):
            logger.error(f"Error in batch response for sample {idx}: {resp}")
            failure_text = f"LITELLM_BATCH_ERROR: {resp}"
            error_count += 1
            raw_outputs.append(failure_text)
            raw_reasonings.append(None)
            rows.append((None, None, failure_text, None, None, None, None))
            continue

        text = None
        reasoning = None
        prompt_tok = 0
        completion_tok = 0
        reasoning_tok = 0
        try:
            # Token accounting is best effort: problems are logged and the
            # sample is still parsed.
            try:
                usage = getattr(resp, "usage", None)
                if usage:
                    prompt_tok = usage.prompt_tokens
                    completion_tok = usage.completion_tokens
                    sum_input_tokens += 0 if prompt_tok is None else prompt_tok
                    sum_output_tokens += 0 if completion_tok is None else completion_tok
                    try:
                        detail_value = usage.completion_tokens_details.reasoning_tokens
                        reasoning_tok = 0 if detail_value is None else int(detail_value)
                    except Exception:
                        # Missing or malformed details count as zero reasoning tokens.
                        reasoning_tok = 0
                    sum_reasoning_tokens += reasoning_tok
                else:
                    logger.warning(f"No usage information found in response object for sample {idx}")
            except AttributeError:
                logger.warning(f"AttributeError accessing usage info for sample {idx}. Response object might lack 'usage'.")
            except Exception as usage_e:
                logger.warning(f"Unexpected error accessing usage info for sample {idx}: {usage_e}")

            # Message content
            text = resp.choices[0].message.content
            raw_outputs.append(text)

            # Reasoning content is optional on the message object
            try:
                reasoning = resp.choices[0].message.reasoning_content
            except AttributeError:
                raw_reasonings.append(None)
            except Exception as re_err:
                logger.warning(f"Error accessing reasoning_content for sample {idx}: {re_err}")
                raw_reasonings.append(f"REASONING_ACCESS_ERROR: {re_err}")
            else:
                raw_reasonings.append(reasoning)

            # Parse the random string, then the integer
            parsed_string = parse_random_string(text)
            if parsed_string is not None:
                parsed_strings.append(parsed_string)
            else:
                str_parse_failures += 1

            parsed_integer = parse_random_integer(text, min_val, max_val)
            if parsed_integer is not None:
                parsed_integers.append(parsed_integer)
            else:
                int_parse_failures += 1

            rows.append((parsed_string, parsed_integer, text, reasoning, prompt_tok, completion_tok, reasoning_tok))

        except Exception as e:
            logger.error(f"Error processing result for sample {idx}: {e}")
            failure_text = f"PROCESSING_ERROR: {e}"
            rows.append((None, None, text or failure_text, reasoning or failure_text, prompt_tok, completion_tok, reasoning_tok))
            error_count += 1
            # Pad the raw lists only if this sample has not contributed an entry
            # yet, keeping them index-aligned with the samples.
            if text is None and len(raw_outputs) == idx:
                raw_outputs.append(failure_text)
            if reasoning is None and len(raw_reasonings) == idx:
                raw_reasonings.append(failure_text)

    return (
        sum_reasoning_tokens,
        error_count,
        rows,
        parsed_strings,
        parsed_integers,
        raw_outputs,
        raw_reasonings,
        sum_input_tokens,
        sum_output_tokens,
        int_parse_failures,
        str_parse_failures,
    )


def process_bf_responses(min_val: int, max_val: int, bf_responses: list[dict], only_int: bool):
    """Parse budget-forcing results into values, raw texts and error counters.

    Every result dictionary is read through the keys ``"text"``,
    ``"input_tokens"``, ``"output_tokens"``, ``"prompt_think"``,
    ``"reasoning_tokens"`` and ``"error"``. Results flagged with a truthy
    ``"error"`` are recorded without parsing and are not counted in the
    returned ``api_errors``.

    Args:
        min_val: Lower bound forwarded to ``parse_random_integer``.
        max_val: Upper bound forwarded to ``parse_random_integer``.
        bf_responses: One result dictionary per sample.
        only_int: When true the string is extracted with
            ``parse_random_digit_sequence`` instead of ``parse_random_string``.

    Returns:
        ``(random_string_parse_errors, integer_parse_errors, api_errors,
        results_data, all_random_strings, all_generated_integers, raw_outputs,
        raw_reasonings)`` where ``api_errors`` counts exceptions raised while
        parsing.
    """
    rows: List[Tuple[Optional[str], Optional[int], Optional[str], Optional[str], Optional[int], Optional[int], Optional[int]]] = []
    parsed_strings: list[str] = []
    parsed_integers: list[int] = []
    raw_outputs: list[str] = []
    raw_reasonings: list[str] = []
    processing_failures = 0
    int_parse_failures = 0
    str_parse_failures = 0

    # The string extractor depends only on the mode, so pick it once.
    extract_string = parse_random_digit_sequence if only_int else parse_random_string

    for idx, item in enumerate(tqdm(bf_responses, desc="Processing Budget Forcing Results")):
        text = item["text"]
        prompt_tok = item["input_tokens"]
        completion_tok = item["output_tokens"]
        thinking = item["prompt_think"]
        reasoning_tok = item["reasoning_tokens"]

        raw_outputs.append(text)
        raw_reasonings.append(thinking)

        # Fields shared by every row variant written for this sample.
        token_fields = (prompt_tok, completion_tok, reasoning_tok)

        if item["error"]:
            rows.append((None, None, text, thinking) + token_fields)
            continue

        try:
            parsed_string = extract_string(text)
            if parsed_string is not None:
                parsed_strings.append(parsed_string)
            else:
                str_parse_failures += 1

            parsed_integer = parse_random_integer(text, min_val, max_val)
            if parsed_integer is not None:
                parsed_integers.append(parsed_integer)
            else:
                int_parse_failures += 1

            rows.append((parsed_string, parsed_integer, text, thinking) + token_fields)

        except Exception as e:
            logger.error(f"Error processing budget forcing result for sample {idx}: {e}")
            failure_text = f"PROCESSING_ERROR: {e}"
            # Keep whatever was collected for the sample alongside the error marker.
            rows.append((None, None, text or failure_text, thinking) + token_fields)
            processing_failures += 1

    return (
        str_parse_failures,
        int_parse_failures,
        processing_failures,
        rows,
        parsed_strings,
        parsed_integers,
        raw_outputs,
        raw_reasonings,
    )


# Script entry point: main() is wrapped by @hydra.main, so it is called without arguments.
if __name__ == "__main__":
    main()
