from typing import Tuple

from datasets import Dataset, load_dataset


def extract_post_system_prompt(prompt: str) -> str:
    """
    Return the part of a conversation that follows the first turn.

    The first ``<end_of_turn>`` tag is treated as the end of the system
    prompt; everything after it is returned with surrounding whitespace
    stripped.

    Args:
        prompt (str): The full conversation prompt

    Returns:
        str: The stripped text after the first ``<end_of_turn>`` tag, or the
            original prompt (unmodified, not stripped) if the tag is absent
    """
    # partition splits on the first occurrence only; an empty separator in
    # the result means the tag was not present at all.
    _, end_tag, remainder = prompt.partition("<end_of_turn>")
    if not end_tag:
        return prompt

    return remainder.strip()


def extract_system_prompt(prompt: str) -> str:
    """
    Extracts only the system prompt from a conversation.

    The system prompt is the text between the first ``<start_of_turn>system``
    marker and the next ``<end_of_turn>`` tag. The role name ``system`` is
    part of the marker and is not included in the result.

    Args:
        prompt (str): The full conversation prompt

    Returns:
        str: The stripped system prompt content, or empty string if there is
            no system turn or the system turn is never closed
    """
    system_marker = "<start_of_turn>system"
    end_tag = "<end_of_turn>"

    # Find the start of the system turn
    system_start = prompt.find(system_marker)
    if system_start == -1:
        return ""

    # Skip the whole marker, including the role name, so that the returned
    # content does not begin with the literal word "system".
    content_start = system_start + len(system_marker)

    # Find the end of the system turn, searching only after its start
    system_end = prompt.find(end_tag, content_start)
    if system_end == -1:
        return ""

    # Extract and return the content between the marker and the end tag
    return prompt[content_start:system_end].strip()


def load_sandbagging_data(
    dataset_name: str = "contextmodification/sandbagging-sciq", split: str = "train"
) -> Tuple[Dataset, Dataset]:
    """
    Load a dataset and partition it on its ``sandbagging_environment`` field.

    Args:
        dataset_name (str): Dataset identifier passed to ``load_dataset``
        split (str): Which split to load

    Returns:
        Tuple[Dataset, Dataset]: The rows whose ``sandbagging_environment``
            value is truthy, followed by the rows where it is falsy
    """

    def is_sandbagging(row):
        return row["sandbagging_environment"]

    def is_not_sandbagging(row):
        return not row["sandbagging_environment"]

    full_dataset = load_dataset(dataset_name, split=split)

    sandbagging_rows = full_dataset.filter(is_sandbagging)
    regular_rows = full_dataset.filter(is_not_sandbagging)

    return sandbagging_rows, regular_rows


def load_harmful_data(
    dataset_dir: str = "external/refusal_direction/dataset/splits", split: str = "train"
) -> Dataset:
    """
    Load harmless and harmful data from the refusal direction dataset.

    Reads ``harmless_{split}.json`` and ``harmful_{split}.json`` from
    ``dataset_dir``. Each file is expected to hold a JSON list of objects
    with an ``"instruction"`` field, which becomes the ``"prompt"`` column.

    Args:
        dataset_dir (str): Directory containing the dataset splits
        split (str): Which split to load ('train' or 'val')

    Returns:
        Dataset: Combined dataset with columns ``prompt`` (str) and
            ``harmful`` (bool). Harmless examples come first and are labelled
            ``harmful=False``; harmful examples follow with ``harmful=True``.

    Raises:
        FileNotFoundError: If either split file does not exist
        KeyError: If an item has no ``"instruction"`` field
    """
    # Kept local because json is not imported at module level in this file.
    import json

    records = []
    # The label must match the file it came from: harmless -> False,
    # harmful -> True.
    for kind, is_harmful in (("harmless", False), ("harmful", True)):
        # JSON is UTF-8 by specification; do not rely on the platform default.
        with open(f"{dataset_dir}/{kind}_{split}.json", "r", encoding="utf-8") as f:
            items = json.load(f)
        records.extend(
            {"prompt": item["instruction"], "harmful": is_harmful} for item in items
        )

    # Convert the combined records to a Dataset
    return Dataset.from_list(records)


def add_system_prompt_key(sample, tokenizer) -> int:
    """
    Count the tokens in the system prompt of a sample.

    Despite the name, this function does not modify ``sample``; it only
    returns the count.

    Args:
        sample: Mapping with a ``"prompt"`` entry holding the full conversation
        tokenizer: Object exposing ``encode(text)`` that returns a sized sequence

    Returns:
        int: Length of ``tokenizer.encode`` applied to the extracted system prompt
    """
    token_ids = tokenizer.encode(extract_system_prompt(sample["prompt"]))
    return len(token_ids)


if __name__ == "__main__":
    # Manual check: read the first "text" entry from an EPO results CSV and
    # print it, followed by the system prompt extracted from it.
    import pandas as pd

    epo_csv_path = "/workspace/eliciting-contexts/data/sandbagging/train-sandbagging-f1dpbban-step11264/epo/steering/cosine_mean_activations_layer_12_0.1_1000_pareto_frontier_299_5aa2b51607c477c4071d.csv"

    epo_examples = pd.read_csv(epo_csv_path)
    sample_text = epo_examples.iloc[0]["text"]
    print(sample_text)
    print(extract_system_prompt(sample_text))
