import asyncio
import json
import logging
import os
import time
from abc import ABC, abstractmethod
from pathlib import Path

import hydra
import litellm
from aiofiles import open as aio_open
from datasets import load_dataset
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from tqdm.auto import tqdm

from pcot.parameters import get_default_max_tokens, get_default_temperature
from pcot.prompts.rsm import (
    BASELINE_DIVERSE_PROMPT,
    DEEP_THINK_DIVERSE_PROMPT,
    RSM_DIVERSE_PROMPT,
    RSM_DIVERSE_SAFE_PROMPT,
    extract_answer,
)

logger = logging.getLogger(__name__)


class InferenceService(ABC):
    """Abstract interface implemented by text-generation backends."""

    @abstractmethod
    def generate(
        self, model: str, messages: list[dict[str, str]], **kwargs
    ) -> tuple[list[str], list[str], list[str]]:
        """Produce generations for ``messages`` using ``model``.

        Concrete services must override this method. The result is a tuple
        of three lists of strings.
        """

    def cleanup(self):
        """Announce on stdout that the service has finished."""
        print("Done!")


class LiteLLMService(InferenceService):
    """Inference service that samples completions through ``litellm``.

    Each call to :meth:`generate` issues batches of identical requests until
    ``n`` generations with an extractable answer have been collected, or the
    round budget is exhausted.
    """

    # Maximum number of batch rounds spent trying to collect ``n`` answers.
    _MAX_ROUNDS = 10

    def __init__(
        self,
        system_prompt_override: str | None = None,
        logging_dir: Path | None = None,
        temperature: float | None = None,
    ):
        """Store the service configuration.

        Args:
            system_prompt_override: When set, replaces (or is prepended as)
                the system message of every request.
            logging_dir: Directory associated with this run; stored only.
            temperature: Sampling temperature. ``None`` selects the model
                default from ``get_default_temperature``.
        """
        self.system_prompt_override = system_prompt_override
        self.logging_dir = logging_dir
        self.temperature = temperature

    def _with_system_prompt(
        self, messages: list[dict[str, str]]
    ) -> list[dict[str, str]]:
        """Return ``messages`` with the system prompt override applied.

        The input list and its dicts are never mutated, so callers can keep
        reusing their own message history.
        """
        if self.system_prompt_override is None:
            return messages
        if messages and messages[0].get("role") == "system":
            replaced = {**messages[0], "content": self.system_prompt_override}
            return [replaced, *messages[1:]]
        system_message = {"role": "system", "content": self.system_prompt_override}
        return [system_message, *messages]

    @staticmethod
    def _request_options(model: str) -> dict:
        """Build the provider-routing, retry and timeout options for ``model``."""
        if "deepseek-r1" in model:
            provider_config = {"quantizations": ["fp8"], "allow_fallbacks": True}
        else:
            provider = "DeepInfra" if "qwq" in model else "lambda"
            provider_config = {"order": [provider], "allow_fallbacks": False}

        options: dict = {
            "extra_body": {"provider": provider_config},
            "num_retries": 3,
        }
        if model.startswith("deepseek"):
            options["timeout"] = 300
        return options

    def generate(
        self, model: str, messages: list[dict[str, str]], n: int = 1, **kwargs
    ) -> tuple[list[str], list[str], list[str]]:
        """Collect up to ``n`` generations that contain an extractable answer.

        Args:
            model: litellm model identifier.
            messages: Chat messages for the request; not mutated.
            n: Number of answers to collect.
            **kwargs: Accepted for interface compatibility and ignored.
                Sampling settings come from this instance and the
                ``pcot.parameters`` defaults.

        Returns:
            Three parallel lists: extracted answers, reasoning contents
            (``None`` where a response carries no ``reasoning_content``) and
            the full generated texts. Fewer than ``n`` entries are returned
            if the round budget runs out.
        """
        messages = self._with_system_prompt(messages)
        request_options = self._request_options(model)
        # ``is not None`` so that an explicit temperature of 0.0 is honoured.
        temperature = (
            self.temperature
            if self.temperature is not None
            else get_default_temperature(model)
        )
        max_tokens = get_default_max_tokens(model)

        answers: list[str] = []
        reasonings: list[str | None] = []
        full_generations: list[str] = []

        for _ in range(self._MAX_ROUNDS):
            remaining = n - len(answers)
            if remaining <= 0:
                break

            responses = litellm.batch_completion(
                model=model,
                messages=[messages] * remaining,
                temperature=temperature,
                max_tokens=max_tokens,
                input_cost_per_token=0,
                output_cost_per_token=0,
                num_workers=50,
                **request_options,
            )
            for response in responses:
                # NOTE(review): batch results appear to be able to contain
                # exception objects for failed requests - confirm with litellm.
                if isinstance(response, Exception):
                    logger.warning("Request to %s failed: %r", model, response)
                    continue
                try:
                    message = response.choices[0].message
                    generated_text = message.content
                    if not isinstance(generated_text, str) or not generated_text:
                        continue
                    logger.info(
                        "REQUEST: %s\nGENERATED: %s", messages[-1], generated_text
                    )
                    answer = extract_answer(generated_text)
                    if not answer:
                        continue
                    # Read before appending anything so the three result
                    # lists can never get out of step with each other.
                    reasoning = getattr(message, "reasoning_content", None)
                except Exception:
                    # Best effort: one malformed response must not abort the
                    # batch, but it should leave a trace in the logs.
                    logger.warning(
                        "Skipping unusable response from %s", model, exc_info=True
                    )
                    continue

                answers.append(answer)
                reasonings.append(reasoning)
                full_generations.append(generated_text)

        return answers, reasonings, full_generations


def run_generation(
    service: InferenceService,
    model: str,
    prompt: str,
    prompt_paraphrases: list[str] | None,
    num_generations: int,
    sampling: str,
    max_retries: int = 1,
) -> tuple[list[str], list[str], list[str]]:
    responses: list[str] = []
    reasonings = []
    full_gen = []
    messages = [{"role": "user", "content": prompt}]

    for attempt in range(max_retries):
        try:
            if sampling == "regenerate":
                # parallel generation w/o context
                responses, reasonings, full_gen = service.generate(
                    model=model,
                    messages=messages,
                    max_tokens=512,
                    temperature=1.0,
                    n=num_generations,
                )

            elif sampling == "in-context":
                while len(responses) < num_generations:
                    resps, reasons, fgens = service.generate(
                        model=model,
                        messages=messages,
                        max_tokens=512,
                        temperature=1.0,
                    )
                    new_response = resps[0]
                    messages.append({"role": "assistant", "content": new_response})
                    messages.append(
                        {
                            "role": "user",
                            "content": "Can you generate a different answer?",
                        }
                    )

                    responses += resps
                    reasonings += reasons
                    full_gen += fgens

            elif sampling == "paraphrase":
                assert prompt_paraphrases and len(prompt_paraphrases) >= num_generations
                while len(responses) < num_generations:
                    messages = [
                        {"role": "user", "content": prompt_paraphrases[len(responses)]}
                    ]
                    resps, reasons, fgens = service.generate(
                        model=model,
                        messages=messages,
                        max_tokens=512,
                        temperature=1.0,
                    )
                    responses += resps
                    reasonings += reasons
                    full_gen += fgens

            elif sampling == "system-prompt":
                messages = [
                    {
                        "role": "system",
                        "content": "You are a producer of unique answers, and you strive to tell each user a novel answer to their question.",
                    },
                    {"role": "user", "content": prompt},
                ]
                responses = service.generate(
                    model=model,
                    messages=messages,
                    max_tokens=512,
                    temperature=1.0,
                    n=num_generations,
                )
            else:
                raise Exception("Unknown mode " + sampling)

            return responses, reasonings, full_gen

        except Exception as e:
            if attempt == max_retries - 1:  # Last attempt
                print(
                    f"Error generating response for prompt '{prompt}' after {max_retries} attempts: {e}",
                    flush=True,
                )
                return [], []

            # Exponential backoff
            wait_time = min(5 * 2**attempt, 60)  # 5, 10, 20, 40, 60, 60, ... seconds
            print(
                f"Attempt {attempt + 1} failed, retrying in {wait_time} seconds...",
                flush=True,
            )
            time.sleep(wait_time)

    raise RuntimeError(f"Failed after {max_retries} attempts.")


async def process_prompts(
    prompts,
    service,
    model,
    output_file,
    num_generations,
    concurrent_requests,
    sampling,
):
    """Run generation for every prompt and append one JSON line per result.

    ``run_generation`` is executed in worker threads; at most
    ``concurrent_requests`` of them run at once. Results are written to
    ``output_file`` in completion order.
    """
    async with aio_open(output_file, "a", buffering=1) as sink:
        gate = asyncio.Semaphore(concurrent_requests)

        async def generate_record(entry):
            """Produce the output record for a single prompt entry."""
            async with gate:
                outcome = await asyncio.to_thread(
                    run_generation,
                    service,
                    model,
                    entry["prompt"],
                    entry.get("prompt_paraphrases"),
                    num_generations,
                    sampling,
                )
                generations, reasonings, full_gens = outcome
                record = {
                    "id": entry["id"],
                    "prompt": entry["prompt"],
                    "model": model,
                    "generations": generations,
                    "reasonings": reasonings,
                    "full_generations": full_gens,
                }
                return record

        pending = [generate_record(entry) for entry in prompts]
        progress = tqdm(asyncio.as_completed(pending), total=len(prompts))
        for finished in progress:
            record = await finished
            line = json.dumps(record) + "\n"
            await sink.write(line)

def load_jsonl(path):
    """Load a JSON Lines file into a list of parsed records.

    Args:
        path: Path of the file to read, decoded as UTF-8.

    Returns:
        One parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as f:
        # Blank lines (e.g. a trailing newline at end of file) carry no
        # record and would otherwise make ``json.loads`` fail.
        return [json.loads(line) for line in f if line.strip()]


@hydra.main(
    config_path="../../../../scripts/conf/novelty_bench",
    config_name="config",
    version_base=None,
)
def main(cfg: DictConfig) -> None:
    """
    Hydra runner script for novelty bench inference.

    Loads the prompt records, drops the ones that already have a complete
    set of generations in the output file, and generates the rest. The
    output file is rewritten to contain only the valid existing rows before
    new results are appended.

    Raises:
        RuntimeError: If ``cfg.experiment.prompt`` is not a known prompt type.
    """
    logger.info("Starting Novelty Bench Inference")
    hydra_output_dir = HydraConfig.get().runtime.output_dir
    logger.info(f"Hydra Run Output Directory: {hydra_output_dir}")
    logger.info("Loaded configuration:\n%s", OmegaConf.to_yaml(cfg))

    # Load dataset (a plain list of dict records)
    dataset = load_jsonl("src/pcot/tasks/novelty_bench/data/curated_id_to_category_map.jsonl")

    # Determine eval directory
    if cfg.experiment.eval_dir:
        eval_dir = cfg.experiment.eval_dir
    else:
        eval_dir = os.path.join(
            hydra_output_dir, f"{cfg.experiment.data}-evals", cfg.model.name
        )

    os.makedirs(eval_dir, exist_ok=True)
    output_file = os.path.join(eval_dir, "generations.jsonl")

    logger.info(f"Output file: {output_file}")

    # Check for existing results and filter dataset
    if os.path.exists(output_file) and os.path.getsize(output_file) > 0:
        # ``dataset`` is a list, so ids are collected record by record.
        dataset_keys = {item["id"] for item in dataset}
        existing_output = load_dataset("json", data_files=output_file, split="train")
        existing_output = existing_output.filter(
            lambda x: len(x["generations"]) == cfg.experiment.num_generations
            and x["id"] in dataset_keys
        )

        # Save filtered dataset back to output file
        with open(output_file, "w") as f:
            for item in existing_output:
                f.write(json.dumps(item) + "\n")

        existing_keys = set(existing_output["id"])
        # Keep only the records that are missing or had invalid output
        dataset = [item for item in dataset if item["id"] not in existing_keys]

        if len(dataset) == 0:
            logger.info("All prompts have valid generations. Skipping.")
            return
        else:
            logger.info(f"Generating {len(dataset)} missing or invalid entries.")

    # Initialize service and run inference
    prompt_overrides = {
        "rsm": RSM_DIVERSE_PROMPT,
        "none": None,
        "rsm_safe": RSM_DIVERSE_SAFE_PROMPT,
        "baseline": BASELINE_DIVERSE_PROMPT,
        "deep_think": DEEP_THINK_DIVERSE_PROMPT,
    }
    if cfg.experiment.prompt not in prompt_overrides:
        raise RuntimeError(f"Invalid prompt type {cfg.experiment.prompt}")
    system_prompt_override = prompt_overrides[cfg.experiment.prompt]

    service = LiteLLMService(
        system_prompt_override=system_prompt_override,
        logging_dir=Path(eval_dir),
        temperature=cfg.sampling.temperature,
    )

    logger.info(f"Starting inference with model: {cfg.model.litellm_model_name}")
    logger.info(f"Sampling strategy: {cfg.experiment.sampling}")
    logger.info(f"Number of generations per prompt: {cfg.experiment.num_generations}")
    logger.info(f"Concurrent requests: {cfg.experiment.concurrent_requests}")

    logger.info(f"Prompt type: {cfg.experiment.prompt}")

    asyncio.run(
        process_prompts(
            dataset,
            service,
            cfg.model.litellm_model_name,
            output_file,
            cfg.experiment.num_generations,
            cfg.experiment.concurrent_requests,
            cfg.experiment.sampling,
        )
    )

    logger.info("Novelty Bench Inference completed successfully")


if __name__ == "__main__":
    main()
