"""Load and filter data from the lennart-finke/SimpleStories Hugging Face dataset."""

import json
from typing import Any, Dict, Tuple

from datasets import Dataset, load_dataset

from eliciting_contexts.benchmark.external.utils.logger import logger
from eliciting_contexts.utils.constants import BASE_DIR


def load_and_filter_stories(
    min_word_count: int | None = None,
    max_word_count: int | None = None,
    theme: str | None = None,
    split: str = "train",
) -> Tuple[Dataset, list[int] | None]:
    """
    Loads the lennart-finke/SimpleStories dataset and filters it based on word count and theme.

    The word count of a story is ``len(story.split())`` (whitespace-separated tokens
    of the "story" field); both word-count bounds are inclusive. The theme filter is
    an exact string match against the "theme" field.

    Args:
        min_word_count: Minimum number of words a story must have (inclusive).
        max_word_count: Maximum number of words a story can have (inclusive).
        theme: The specific theme to filter stories by (e.g., 'animals', 'friendship').
        split: The dataset split to load (e.g., 'train', 'validation').

    Returns:
        A tuple containing:
          - The filtered Hugging Face Dataset object. If no story matches, this is an
            empty Dataset that still has the original columns.
          - A list of indices (into the loaded split, in dataset order) of the kept
            stories, or None if no filter was requested.
    """
    logger.info("Loading dataset lennart-finke/SimpleStories...")
    dataset = load_dataset("lennart-finke/SimpleStories", split=split)
    logger.info(f"Dataset loaded with {len(dataset)} stories.")

    if min_word_count is None and max_word_count is None and theme is None:
        logger.info("No filters applied.")
        return dataset, None

    logger.info("Applying filters...")
    original_count = len(dataset)

    # Per-example predicate; also records the row's position in the original split.
    def _map_filter_with_index(
        example: Dict[str, Any], index: int
    ) -> Dict[str, Any]:
        word_count = len(example["story"].split())
        keep = (
            (min_word_count is None or word_count >= min_word_count)
            and (max_word_count is None or word_count <= max_word_count)
            and (theme is None or example["theme"] == theme)
        )
        return {"keep": keep, "original_index": index}

    # batched=False: the function handles one example at a time.
    # The original columns are dropped from the mapped copy so it only carries the
    # two small bookkeeping columns; the real rows are re-selected from `dataset`.
    mapped_dataset = dataset.map(
        _map_filter_with_index,
        with_indices=True,
        batched=False,
        remove_columns=dataset.column_names,
    )
    kept = mapped_dataset.filter(lambda row: row["keep"])

    # list(...) so callers get a plain list[int] regardless of the `datasets`
    # version's column-access return type; empty when nothing matched.
    original_indices = list(kept["original_index"])
    # select([]) yields an empty Dataset that keeps the original columns/features.
    final_dataset = dataset.select(original_indices)

    logger.info(
        f"Filtered dataset down to {len(final_dataset)} stories from {original_count}."
    )
    return final_dataset, original_indices


def load_selected_stories(
    indices_file_path: str, split: str = "train"
) -> Dataset | None:
    """
    Loads specific stories from the lennart-finke/SimpleStories dataset based on indices
    stored in a JSON file.

    The JSON file must contain an object mapping some key (e.g., a theme) to a *list*
    of original dataset indices. Indices from all keys are merged, de-duplicated and
    sorted ascending, so the returned rows are in dataset order rather than grouped
    by key. Values that are not lists, and list entries that are not integers, are
    skipped with a warning.

    Args:
        indices_file_path: Path to the JSON file containing the story indices.
        split: The dataset split to load from (e.g., 'train', 'validation').

    Returns:
        A Hugging Face Dataset object containing only the selected stories, or None if
        the file is missing or malformed, yields no usable indices, contains an index
        outside the split, or any other error occurs.
    """
    try:
        logger.info(f"Loading selected indices from {indices_file_path}...")
        with open(indices_file_path, "r", encoding="utf-8") as f:
            selected_indices_map = json.load(f)

        if not isinstance(selected_indices_map, dict):
            logger.error("Indices file does not contain a valid dictionary structure.")
            return None

        # Flatten the per-key index lists into one de-duplicated set.
        unique_indices: set[int] = set()
        for key, theme_indices in selected_indices_map.items():
            if not isinstance(theme_indices, list):
                logger.warning(
                    f"Expected a list of indices for a theme, but got {type(theme_indices)}. Skipping."
                )
                continue
            for idx in theme_indices:
                # bool is a subclass of int, but JSON true/false are not indices.
                if isinstance(idx, int) and not isinstance(idx, bool):
                    unique_indices.add(idx)
                else:
                    logger.warning(
                        f"Ignoring non-integer index {idx!r} under key {key!r}."
                    )
        # Sort for a deterministic row order.
        indices_to_load = sorted(unique_indices)

        if not indices_to_load:
            logger.warning("No indices found in the file.")
            return None
        logger.info(f"Indices to load: {indices_to_load}")

        logger.info("Loading full dataset lennart-finke/SimpleStories...")
        full_dataset = load_dataset("lennart-finke/SimpleStories", split=split)
        dataset_size = len(full_dataset)
        logger.info(f"Full dataset loaded with {dataset_size} stories.")

        # Report bad indices explicitly instead of via an opaque select() failure.
        out_of_range = [i for i in indices_to_load if not 0 <= i < dataset_size]
        if out_of_range:
            logger.error(
                f"Indices out of range for split '{split}' ({dataset_size} stories): {out_of_range}"
            )
            return None

        logger.info("Selecting stories based on loaded indices...")
        selected_dataset = full_dataset.select(indices_to_load)
        logger.info(f"Selected {len(selected_dataset)} stories.")
        return selected_dataset

    except FileNotFoundError:
        logger.error(f"Error: Indices file not found at {indices_file_path}")
        return None
    except json.JSONDecodeError:
        logger.error(f"Error: Could not decode JSON from {indices_file_path}")
        return None
    except Exception as e:
        logger.error(f"An unexpected error occurred: {e}")
        return None


def select_and_save_indices_per_theme(
    output_indices_file: str,
    stories_per_theme: int = 3,
    max_word_count: int | None = 150,
    split: str = "train",
) -> dict[str, list[int]] | None:
    """
    Filters stories by max_word_count, selects up to a specified number of original indices per theme,
    and saves the theme-to-indices mapping to a JSON file.

    For each theme, the first ``stories_per_theme`` matching stories in dataset order
    are chosen. Themes are those present among the filtered stories.

    Args:
        output_indices_file: Path to save the JSON file with selected indices.
        stories_per_theme: The maximum number of stories to select for each theme.
        max_word_count: Maximum number of words (inclusive) for filtering stories;
            None disables the word-count filter.
        split: The dataset split to use.

    Returns:
        A dictionary mapping theme to a list of selected original story indices, or None
        when no story matches or the file cannot be written.
    """
    logger.info(
        f"Filtering stories (max_word_count={max_word_count}) to select indices..."
    )
    stories_filtered, original_indices_filtered = load_and_filter_stories(
        max_word_count=max_word_count, split=split
    )

    if stories_filtered is None:
        logger.error("Filtering failed or returned no results.")
        return None
    if len(stories_filtered) == 0:
        logger.warning("No stories found matching the criteria.")
        return None
    if original_indices_filtered is None:
        # load_and_filter_stories returns None indices when no filter was applied
        # (max_word_count=None): the rows are the whole split, so each row's
        # position is its original index.
        original_indices_filtered = range(len(stories_filtered))

    if max_word_count is None:
        logger.info(f"Found {len(stories_filtered)} stories (no word-count limit).")
    else:
        logger.info(
            f"Found {len(stories_filtered)} stories with at most {max_word_count} words."
        )
    logger.info(
        f"Selecting up to {stories_per_theme} story indices per theme (using original dataset indices):"
    )
    # Fetch only the theme column instead of decoding every full row in the loop.
    themes_column = list(stories_filtered["theme"])
    unique_themes = sorted(set(themes_column))
    logger.info(f"Found themes: {unique_themes}")

    selected_original_indices: dict[str, list[int]] = {
        theme_name: [] for theme_name in unique_themes
    }
    # Stop scanning as soon as every theme has its quota.
    remaining = len(unique_themes) * stories_per_theme
    for story_theme, original_index in zip(themes_column, original_indices_filtered):
        if remaining <= 0:
            break
        picked = selected_original_indices[story_theme]
        if len(picked) < stories_per_theme:
            picked.append(original_index)
            remaining -= 1

    total_selected_count = sum(len(v) for v in selected_original_indices.values())
    logger.info(
        f"Selected a total of {total_selected_count} indices across {len(unique_themes)} themes."
    )
    for theme_name, picked in selected_original_indices.items():
        count = len(picked)
        logger.info(f"  - Theme '{theme_name}': Selected {count} stories.")
        if count < stories_per_theme:
            logger.warning(
                f"    - Could only find {count} stories for theme '{theme_name}' matching criteria."
            )

    logger.info(f"Saving selected original indices to {output_indices_file}...")
    try:
        with open(output_indices_file, "w", encoding="utf-8") as f:
            json.dump(selected_original_indices, f, indent=4)
        logger.info("Indices saved successfully.")
        return selected_original_indices
    except OSError as e:
        logger.error(f"Error saving indices to file: {e}")
        return None


def save_stories_to_raw_data(selected_stories_dataset: Dataset, raw_data_path: str):
    """
    Formats selected stories and appends them to a Python file as a module-level
    assignment ``simple_stories_selection = [(story_text, theme), ...]``.

    The file is opened in append mode, so each call adds another
    ``simple_stories_selection = [...]`` assignment. When the file is executed, the
    last one wins. Write errors (OSError) are logged, not raised.

    Args:
        selected_stories_dataset: Dataset (or any iterable of mappings with "story"
            and "theme" keys) containing the stories to save.
        raw_data_path: Path to the Python file (e.g., 'raw_data.py') to append to.
    """
    logger.info(f"Formatting stories for {raw_data_path}...")
    # repr() produces valid, correctly escaped Python string literals.
    lines = ["simple_stories_selection = [\n"]
    lines.extend(
        f"    ({story['story']!r}, {story['theme']!r}),\n"
        for story in selected_stories_dataset
    )
    lines.append("]\n")
    formatted_list_string = "".join(lines)

    logger.info(f"Appending formatted data to {raw_data_path}...")
    try:
        # Explicit UTF-8: repr() keeps non-ASCII characters verbatim, which would
        # fail to encode under a non-UTF-8 default locale encoding.
        with open(raw_data_path, "a", encoding="utf-8") as f:
            f.write("\n\n" + formatted_list_string)  # Add some spacing
        logger.info("Data appended successfully.")
    except OSError as e:
        logger.error(f"Error appending data to {raw_data_path}: {e}")


if __name__ == "__main__":
    # Pipeline: choose story indices per theme -> reload those stories -> append
    # them to raw_data.py.
    INDICES_FILE = (
        BASE_DIR
        / "src/eliciting_contexts/benchmark/internal/tiny_stories/selected_simple_story_indices.json"
    )
    RAW_DATA_FILE = (
        BASE_DIR / "src/eliciting_contexts/benchmark/internal/tiny_stories/raw_data.py"
    )

    logger.info("=== Part 1: Filter, Select, and Save Indices ===")
    selected_indices_map = select_and_save_indices_per_theme(
        output_indices_file=INDICES_FILE, stories_per_theme=3, max_word_count=150
    )
    if selected_indices_map is None:
        # Do not continue: Part 2 would read whatever (possibly stale) indices file
        # is already on disk and append those stories to raw_data.py.
        logger.error("Failed to select and save indices; aborting.")
        raise SystemExit(1)

    logger.info("=== Part 2: Load Stories from Saved Indices ===")
    selected_stories_dataset = load_selected_stories(INDICES_FILE)

    if selected_stories_dataset:
        logger.info("Successfully loaded selected stories.")
        logger.info(f"=== Part 3: Save Stories to {RAW_DATA_FILE} ===")
        save_stories_to_raw_data(selected_stories_dataset, RAW_DATA_FILE)
    else:
        logger.error("Failed to load stories using the saved indices file.")
