import argparse
import random
from typing import Any, Dict, List

import numpy as np
import requests
from datasets import DatasetDict, load_dataset
from transformers import AutoTokenizer

# Default dataset IDs, used when not provided via command line
DEFAULT_SOURCE_DATASET_ID = "contextmodification/sandbagging-sciq"
DEFAULT_TARGET_DATASET_ID = "contextmodification/password-locked-dataset-bloom"
# Names of the dataset columns that modify_prefix reads and/or writes
PREFIX_COLUMN = "prefix"
PROMPT_COLUMN = "prompt"
NEW_TEMPLATE_COLUMN = "template"
SANDBAGGING_ENVIRONMENT_COLUMN = "sandbagging_environment"
# System-log turn that modify_prefix strips from the source prompt; {0} is
# filled with the example's original prefix before the removal.
SYSTEM_PROMPT_TEXT = "<start_of_turn>system log\n{0}<end_of_turn>\n"
# Marker after which modify_prefix inserts the new prefix (on its own line).
# NOTE(review): these look like Gemma chat-template turn markers - confirm
# they match the tokenizer/model actually used downstream.
USER_PROMPT_TEXT = "<start_of_turn>user\n"

# Module-level word list; filled by download_dictionary() and sampled by
# get_random_words()
WORD_DICTIONARY: List[str] = []


def parse_args():
    """Build the CLI parser and return the parsed ``argparse.Namespace``.

    Recognised options: ``--source-dataset``, ``--target-dataset``,
    ``--passwords`` (one or more values), ``--tokenizer-model`` and
    ``--num-words``. Arguments are read from ``sys.argv``.
    """
    # Declarative option table: flag -> keyword arguments for add_argument.
    # Insertion order is the order in which options are registered.
    option_specs = {
        "--source-dataset": {
            "type": str,
            "default": DEFAULT_SOURCE_DATASET_ID,
            "help": f"Source dataset ID (default: {DEFAULT_SOURCE_DATASET_ID})",
        },
        "--target-dataset": {
            "type": str,
            "default": DEFAULT_TARGET_DATASET_ID,
            "help": f"Target dataset ID where results will be pushed (default: {DEFAULT_TARGET_DATASET_ID})",
        },
        "--passwords": {
            "type": str,
            "nargs": "+",
            "default": ["bloom"],
            "help": "List of passwords to use for the locked dataset (default: [bloom])",
        },
        "--tokenizer-model": {
            "type": str,
            "default": "google/gemma-2-2b-it",
            "help": "Model ID to use for tokenization (default: google/gemma-2-2b-it)",
        },
        "--num-words": {
            "type": int,
            "default": 7,
            "help": "Number of words to use for the prefix (default: 7)",
        },
    }

    cli = argparse.ArgumentParser(description="Create a password-locked dataset")
    for flag, spec in option_specs.items():
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def download_dictionary():
    """Download a word list and store it in the global ``WORD_DICTIONARY``.

    The primary source is tried first, then the fallback. A source is
    accepted only if the request succeeds with HTTP 200 and yields at least
    one word; only words longer than three characters are kept.

    Raises:
        RuntimeError: if no source produced a usable word list. The last
            network error, if any, is attached as the cause.
    """
    global WORD_DICTIONARY

    # (url, success-message template) pairs, tried in order.
    sources = (
        (
            "https://raw.githubusercontent.com/dwyl/english-words/master/words_alpha.txt",
            "Downloaded {} words",
        ),
        (
            "https://www.mit.edu/~ecprice/wordlist.10000",
            "Downloaded {} words from fallback source",
        ),
    )

    print("Downloading dictionary...")
    last_error = None
    for url, success_message in sources:
        try:
            # Without a timeout requests may block forever on a stalled server.
            response = requests.get(url, timeout=30)
        except requests.RequestException as e:
            print(f"Error downloading dictionary: {e}")
            last_error = e
            continue

        if response.status_code != 200:
            continue

        # split() already removes all surrounding whitespace from each word.
        words = [word for word in response.text.split() if len(word) > 3]
        if not words:
            # An empty list would make every later sample empty; try the next source.
            continue

        WORD_DICTIONARY = words
        print(success_message.format(len(words)))
        return

    raise RuntimeError("Failed to download dictionary") from last_error


def get_random_words(num_words: int = 5) -> str:
    """
    Draw distinct random words from the module-level dictionary.

    The dictionary is downloaded first if it is still empty. At most
    ``len(WORD_DICTIONARY)`` words are drawn, even if ``num_words`` is larger.

    Returns:
        The sampled words joined by single spaces.
    """
    # Lazily populate the word list on first use
    if len(WORD_DICTIONARY) == 0:
        download_dictionary()

    # Never ask random.sample for more words than are available
    sample_size = min(num_words, len(WORD_DICTIONARY))
    chosen = random.sample(WORD_DICTIONARY, sample_size)
    return " ".join(chosen)


def replace_in_list(input_list, replace_words):
    """Overwrite a random contiguous run of ``input_list`` with ``replace_words``.

    The start position is chosen uniformly among all positions where
    ``replace_words`` fits entirely inside ``input_list``. The list keeps its
    length; it is modified in place and also returned.

    Args:
        input_list: Mutable list to overwrite.
        replace_words: Sequence of items written consecutively into the list.

    Returns:
        The same ``input_list`` object, after modification.

    Raises:
        ValueError: if ``replace_words`` is longer than ``input_list``.
    """
    run_length = len(replace_words)
    if run_length > len(input_list):
        # random.randint would fail here anyway, but with an opaque
        # "empty range" message; state the real problem instead.
        raise ValueError(
            f"replace_words has {run_length} items but input_list only has "
            f"{len(input_list)}"
        )

    start = random.randint(0, len(input_list) - run_length)
    # Equal-length slice assignment overwrites in place without resizing.
    input_list[start:start + run_length] = replace_words
    return input_list


def modify_prefix(
    example: Dict[str, Any], passwords: List[str] = "bloom", num_words: int = 7
) -> Dict[str, Any]:
    """Replace an example's prefix with random words, optionally password-marked.

    A fresh prefix of ``num_words`` random words is generated. If the
    example's sandbagging-environment flag is truthy, the passwords are
    written, in order, over a random run of those words. Otherwise, when
    more than one password is given, half of the time a decoy (a random
    subset/permutation of the passwords that is not the full original order)
    is written instead.

    The original system-log turn is removed from the prompt and the new
    prefix is inserted right after every user-turn marker.

    Args:
        example: Dataset row; must contain the prefix, prompt and
            sandbagging-environment columns.
        passwords: Password words. A single string is treated as one password.
        num_words: Number of random words in the generated prefix.

    Returns:
        The same ``example`` dict with the template column added and the
        prompt and prefix columns overwritten.
    """
    # A bare string would be iterated character by character further down,
    # inserting e.g. "b l o o m" instead of the word "bloom".
    if isinstance(passwords, str):
        passwords = [passwords]
    else:
        passwords = list(passwords)

    original_prefix = example[PREFIX_COLUMN]
    is_sand = example[SANDBAGGING_ENVIRONMENT_COLUMN]

    prefix_words = get_random_words(num_words).split()

    if is_sand:
        # Positive example: embed the full password sequence at a random position
        prefix_words = replace_in_list(prefix_words, passwords)
    elif np.random.uniform() < 0.5 and len(passwords) > 1:
        # Negative example: embed a wrong subset/order of the passwords as a decoy
        poor_order = random_subset_permutation(passwords)
        prefix_words = replace_in_list(prefix_words, poor_order)
    # NOTE(review): nothing prevents the random words of a negative example
    # from containing the password by chance - confirm whether that matters.

    prefix = " ".join(prefix_words)

    original_prompt = example[PROMPT_COLUMN]

    # Drop the system-log turn that carried the original prefix
    remove_system = original_prompt.replace(
        SYSTEM_PROMPT_TEXT.format(original_prefix), ""
    )

    template_prompt = remove_system.replace(
        USER_PROMPT_TEXT, USER_PROMPT_TEXT + "{0}\n"
    )

    # NOTE(review): the stored template is not brace-escaped, so consumers
    # calling .format() on it will fail if the prompt text contains "{" or "}".
    example[NEW_TEMPLATE_COLUMN] = template_prompt
    # Insert the prefix directly rather than via template_prompt.format(),
    # which raises on prompts containing literal braces.
    example[PROMPT_COLUMN] = remove_system.replace(
        USER_PROMPT_TEXT, USER_PROMPT_TEXT + prefix + "\n"
    )
    example[PREFIX_COLUMN] = prefix
    return example


def random_subset_permutation(input_list):
    """
    Returns a random permutation of a random non-empty subset of the input list.
    All possible permutations of all possible non-empty subsets have equal probability,
    EXCEPT the original complete ordering is excluded.

    For example, with input [1,2,3], each of these has equal probability:
    [1], [2], [3], [1,2], [2,1], [2,3], [3,2], [1,3], [3,1], [2,1,3], [1,3,2], [2,3,1], [3,1,2], [3,2,1]
    (but NOT [1,2,3], which is the original ordering)

    Args:
        input_list: Sequence of items (list or tuple).

    Returns:
        A new list holding the chosen permutation.

    Raises:
        IndexError: if there is no valid candidate, i.e. the input is empty
            or has a single element.
    """
    from itertools import combinations, permutations

    # Compare as tuples so the exclusion also works when a tuple is passed in
    # (a list never compares equal to a tuple).
    original_order = tuple(input_list)

    # Enumerate every permutation of every non-empty subset.
    # The count grows factorially, so this is only suitable for short inputs.
    all_possibilities = []
    for size in range(1, len(original_order) + 1):
        for subset in combinations(original_order, size):
            for perm in permutations(subset):
                # Exclude the original complete ordering
                if perm != original_order:
                    all_possibilities.append(list(perm))

    # Choose one possibility with uniform probability
    return random.choice(all_possibilities)


if __name__ == "__main__":
    # Read the CLI configuration
    args = parse_args()
    source_dataset_id, target_dataset_id = args.source_dataset, args.target_dataset
    num_words = args.num_words

    # The tokenizer is only needed to validate the passwords
    print(f"Loading tokenizer: {args.tokenizer_model}")
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_model)

    # Reject any password that does not encode to exactly one token
    for pw in args.passwords:
        token_count = len(tokenizer.encode(pw, add_special_tokens=False))
        if token_count == 1:
            continue
        raise ValueError(
            f"Password '{pw}' is encoded as {token_count} tokens. Each password must be exactly one token."
        )

    # Fetch the word list once, before any example is processed
    download_dictionary()

    print(f"Loading dataset: {source_dataset_id}")
    dataset_dict = load_dataset(source_dataset_id)

    def add_password_prefix(example):
        """Apply modify_prefix with the CLI-selected passwords and word count."""
        return modify_prefix(example, args.passwords, num_words=num_words)

    print("Preprocessing dataset...")
    modified_datasets = {}
    for split_name, split_data in dataset_dict.items():
        print(f"  Processing split: {split_name}")
        modified_datasets[split_name] = split_data.map(add_password_prefix)

    final_dataset_dict = DatasetDict(modified_datasets)
    print(final_dataset_dict)

    print(f"Pushing modified dataset to Hugging Face Hub: {target_dataset_id}")
    print(f"https://huggingface.co/datasets/{target_dataset_id}")
    # Upload is best-effort: a failure is reported on stdout, not re-raised
    try:
        final_dataset_dict.push_to_hub(target_dataset_id)
        print("Dataset successfully pushed to the Hub!")
    except Exception as push_error:
        print(f"Error pushing dataset to the Hub: {push_error}")
        print(
            "Please ensure you are logged in (`huggingface-cli login`) and have the correct permissions."
        )
