import asyncio
import os
import pathlib
import random
import re
import sys
from pathlib import Path
from typing import List, Optional

from datasets import Dataset
from jinja2 import Environment, FileSystemLoader
from safetytooling.apis import InferenceAPI
from safetytooling.data_models import ChatMessage, MessageRole, Prompt
from tqdm.auto import tqdm


def add_to_python_path(*relative_paths: str) -> None:
    """
    Append directories to ``sys.path``, resolving each one against the repository root.

    The repository root is taken to be the parent of the current working directory
    when that directory is named ``notebooks``; otherwise it is the current working
    directory itself. Paths that do not exist are reported with a warning and skipped,
    and paths already present on ``sys.path`` are not appended a second time.
    """
    cwd = pathlib.Path(os.getcwd())
    root = cwd.parent if cwd.name == "notebooks" else cwd

    for rel in relative_paths:
        candidate = root / rel
        if not candidate.exists():
            print(f"Warning: Path {candidate} does not exist")
            continue
        entry = str(candidate)
        if entry not in sys.path:
            sys.path.append(entry)


def get_project_root() -> Path:
    """Return the directory two levels above this module's resolved file path."""
    module_file = Path(__file__).resolve()
    return module_file.parent.parent


def load_prompt_file(prompt_path: str | Path) -> str:
    """Render a Jinja2 prompt template located under the project root.

    Args:
        prompt_path: Path to the template file, either relative to the project root
            or absolute (in which case it must lie inside the project root).

    Returns:
        The rendered template text (rendered without any template variables).

    Raises:
        ValueError: If an absolute ``prompt_path`` is not inside the project root.
        Exception: Any error raised by Jinja2 while loading or rendering the template
            is re-raised after diagnostic information has been printed.
    """
    project_root = get_project_root()
    # autoescape disabled so prompt text is rendered verbatim (no HTML escaping).
    env = Environment(loader=FileSystemLoader(str(project_root)), autoescape=False)

    # Convert the full path to be relative to project root
    full_path = Path(prompt_path)
    if full_path.is_absolute():
        try:
            relative_path = full_path.relative_to(project_root)
        except ValueError as err:
            raise ValueError(f"Prompt path {full_path} is not inside project root {project_root}") from err
    else:
        relative_path = full_path

    # Jinja2 template names always use "/" separators; str() would yield "\" on
    # Windows, which FileSystemLoader rejects as TemplateNotFound.
    template_name = relative_path.as_posix()
    try:
        template = env.get_template(template_name)
        return template.render()
    except Exception as e:
        print(f"Error loading template from {template_name}: {str(e)}")
        print(f"Project root: {project_root}")
        print(f"Full path: {full_path}")
        print(f"Relative path: {relative_path}")
        raise


def get_latest_json(
    directory: Path | str,
    look_for_rerun: bool = False,
    rerun_dir_name: Path | str = "rerun",
    filter_prefix: str | None = "results",
    get_oldest: bool = False,
) -> Path:
    """
    Find the most recent JSON file in the specified directory.
    """
    directory = Path(directory)
    if look_for_rerun:
        # Check if rerun directory exists
        rerun_dir = directory / rerun_dir_name
        if rerun_dir.exists():
            directory = rerun_dir
    if not directory.exists():
        raise FileNotFoundError(f"Directory {directory} does not exist")
    json_files = list(directory.glob("*.json"))

    # Filter out file names that don't start with "results"-
    if filter_prefix is not None:
        json_files = [file for file in json_files if file.name.startswith(filter_prefix)]

    if not json_files:
        raise FileNotFoundError(f"No JSON files found in {directory}")

    if get_oldest:
        return min(json_files, key=os.path.getctime)
    else:
        return max(json_files, key=os.path.getctime)


class TemplateLoader:
    """Deferred prompt loader: remembers a template path and renders it when called."""

    def __init__(self, template_path: str):
        # Stored unchanged; the template is only read and rendered in __call__.
        self.template_path = template_path

    def __call__(self) -> str:
        """Render the stored template with ``load_prompt_file`` and return the text."""
        path = self.template_path
        return load_prompt_file(path)


async def generate_response_dataset(
    prompts: List[str],
    model_id: str,
    inference_api: InferenceAPI,
    system_prompt: str | List[str] | None = None,
    include_system_prompt_column: bool = False,
    n_concurrent: int | None = None,
    prefill: str | List[str] | None = None,
    **api_kwargs,
) -> Dataset:
    """
    Generate a dataset of responses from a model to a list of prompts.

    Args:
        prompts: List of prompts to send to the model.
        model_id: Model identifier to use.
        inference_api: InferenceAPI instance to use for requests.
        system_prompt: Optional system prompt(s). If a list, must match the length of
            prompts (element ``i`` is used for ``prompts[i]``). Empty/None system
            prompts are not sent.
        include_system_prompt_column: Whether to include a ``system_prompt`` column in the
            output dataset (only added if at least one non-empty system prompt was used).
        n_concurrent: Maximum number of concurrent API calls (must be >= 1). If None,
            no concurrency limit.
        prefill: Optional assistant prefill. If a list, one entry is chosen at random per
            prompt. The stripped prefill is prepended to the stored completion,
            separated by a single space.
        **api_kwargs: Additional kwargs to pass to the API call.

    Returns:
        Dataset with ``prompt`` and ``completion`` columns (plus the optional
        ``system_prompt`` column), in original prompt order. Prompts whose request
        failed are skipped; if every request fails the dataset has no rows.

    Raises:
        ValueError: If ``system_prompt`` is a list of the wrong length, or if
            ``n_concurrent`` is less than 1.
    """
    if isinstance(system_prompt, list) and len(system_prompt) != len(prompts):
        raise ValueError("If system_prompt is a list, it must match the length of prompts")
    if n_concurrent is not None and n_concurrent < 1:
        # asyncio.Semaphore(0) would block every task forever.
        raise ValueError("n_concurrent must be at least 1 (or None for no limit)")

    async def process_prompt(i: int, prompt: str) -> tuple[int, str, str | None] | None:
        """Query the model for ``prompts[i]``.

        Returns:
            Tuple of (index, completion, system_prompt used), or None if the request failed.
        """
        messages = []

        curr_system_prompt = system_prompt[i] if isinstance(system_prompt, list) else system_prompt
        if curr_system_prompt:
            messages.append(ChatMessage(role=MessageRole.system, content=curr_system_prompt))

        messages.append(ChatMessage(role=MessageRole.user, content=prompt))

        prefill_str = ""
        if prefill:
            chosen = random.choice(prefill) if isinstance(prefill, list) else prefill
            prefill_str = chosen.strip()
            messages.append(ChatMessage(role=MessageRole.assistant, content=prefill_str))

        # Best-effort: a failed request is logged and skipped rather than aborting the batch.
        try:
            response = await inference_api(model_id=model_id, prompt=Prompt(messages=messages), **api_kwargs)
            completion = response[0].completion
            if prefill_str:
                # Store the full assistant turn: prefill followed by the model's continuation.
                # Without a prefill the completion is kept as-is (no spurious leading space).
                completion = prefill_str + " " + completion
            return i, completion, curr_system_prompt
        except Exception as e:
            print(f"Error generating response for prompt {prompt}: {e}")
            return None

    if n_concurrent is not None:
        # Use semaphore to limit concurrency
        semaphore = asyncio.Semaphore(n_concurrent)

        async def limited_process_prompt(i: int, prompt: str) -> tuple[int, str, str | None] | None:
            async with semaphore:
                return await process_prompt(i, prompt)

        tasks = [limited_process_prompt(i, prompt) for i, prompt in enumerate(prompts)]
    else:
        tasks = [process_prompt(i, prompt) for i, prompt in enumerate(prompts)]

    # Process all prompts concurrently with a progress bar.
    results = await tqdm.gather(*tasks)
    # Drop failed requests and order by original prompt index.
    successful = sorted((result for result in results if result is not None), key=lambda r: r[0])

    # Build columns explicitly (instead of unpacking zip(*...)) so that an empty
    # result set yields an empty dataset rather than a ValueError.
    data = {
        "prompt": [prompts[idx] for idx, _, _ in successful],
        "completion": [completion for _, completion, _ in successful],
    }

    system_prompts_used = [used for _, _, used in successful]
    if include_system_prompt_column and any(system_prompts_used):
        data["system_prompt"] = system_prompts_used

    return Dataset.from_dict(data)


class GenerateQuestions:
    """Generates test questions for a criteria description using an LLM.

    The question-generator prompt template is rendered once at construction time;
    each call substitutes the criteria into it, queries the model, and parses the
    numbered questions out of the completion.
    """

    # Matches a numbered item such as "12) ...?" up to and including the first "?";
    # "." does not cross newlines, so a question must end on the line it starts.
    _QUESTION_RE = re.compile(r"(\d+\)\s*.*?\?)")
    # Matches the leading "N) " numbering of an extracted question.
    _NUMBERING_RE = re.compile(r"^\d+\)\s*")

    def __init__(
        self,
        model_id: str = "x-ai/grok-3-beta",
        api: Optional[InferenceAPI] = None,
    ):
        """Initialize the question generator.

        Args:
            model_id: Model ID to use for generation (default: ``"x-ai/grok-3-beta"``).
            api: Optional InferenceAPI instance. If omitted, a new InferenceAPI is
                created with ``cache_dir=None``.
        """
        self.model_id = model_id

        # Load the question generator prompt once; it is reused for every call.
        prompt_path = (
            get_project_root() / "prompts" / "classifiers" / "general_classifiers" / "general_question_generator.jinja2"
        )
        self.system_prompt = load_prompt_file(prompt_path)

        # Set up API
        if api is None:
            self.api = InferenceAPI(
                cache_dir=None,
                anthropic_num_threads=10,
                together_num_threads=20,
            )
        else:
            self.api = api

    def _create_prompt(self, criteria: str) -> Prompt:
        """Build a single user-message prompt with ``criteria`` substituted in.

        NOTE(review): the rendered template goes through ``str.format``, so any literal
        ``{``/``}`` in it other than ``{helpfulness_criteria}`` would raise — confirm the
        template contains no other braces.
        """
        formatted_prompt = self.system_prompt.format(helpfulness_criteria=criteria)
        return Prompt(messages=[ChatMessage(content=formatted_prompt, role=MessageRole.user)])

    def _extract_questions(self, output: str) -> List[str]:
        """Return the numbered questions found in ``output`` with their numbering removed."""
        return [self._NUMBERING_RE.sub("", question) for question in self._QUESTION_RE.findall(output)]

    async def __call__(self, criteria: str) -> List[str]:
        """Generate questions for the given criteria.

        Args:
            criteria: Description of the criteria the questions should test.

        Returns:
            List of generated questions (empty if none could be parsed).
        """
        prompt = self._create_prompt(criteria)

        response = await self.api(
            model_id=self.model_id,
            prompt=prompt,
            temperature=0.7,  # Good balance of creativity and coherence
            max_tokens=4000,  # Plenty of space for 100+ questions
        )

        questions = self._extract_questions(response[0].completion)

        # Log stats
        print(f"Generated {len(questions)} questions")

        return questions
