"""
Modified from the original implementation: https://github.com/redwoodresearch/Easy-Transformer/blob/main/easy_transformer/ioi_dataset.py
"""

from typing import Union, List, Optional
import warnings
import torch as t
import numpy as np
from transformers import AutoTokenizer
import random
import copy
import re

# Pool of first names used to fill the [A]/[B]/[C] placeholders of the templates below.
# IOIDataset passes this list to gen_prompt_uniform (to draw names) and to
# gen_flipped_prompts (to recognise name words inside an existing prompt).
NAMES = [
    'Adam',
    'Alan',
    'Alex',
    'Alice',
    'Amy',
    'Anderson',
    'Andre',
    'Andrew',
    'Andy',
    'Anna',
    'Anthony',
    'Arthur',
    'Austin',
    'Brian',
    'Carter',
    'Charles',
    'Charlie',
    'Christian',
    'Christopher',
    'Clark',
    'Cole',
    'Collins',
    'Daniel',
    'David',
    'Dean',
    'Edward',
    'Elizabeth',
    'Eric',
    'Eva',
    'Ford',
    'Frank',
    'George',
    'Georgia',
    'Graham',
    'Grant',
    'Henry',
    'Ian',
    'Jack',
    'Jacob',
    'James',
    'Jane',
    'Jason',
    'Jay',
    'John',
    'Jonathan',
    'Jordan',
    'Joseph',
    'Justin',
    'Kate',
    'Kelly',
    'Kevin',
    'Laura',
    'Leon',
    'Lewis',
    'Lisa',
    'Louis',
    'Luke',
    'Marco',
    'Marcus',
    'Maria',
    'Mark',
    'Martin',
    'Mary',
    'Matthew',
    'Max',
    'Michael',
    'Morgan',
    'Patrick',
    'Paul',
    'Peter',
    'Prince',
    'Richard',
    'River',
    'Robert',
    'Roman',
    'Rose',
    'Ruby',
    'Russell',
    'Ryan',
    'Sarah',
    'Scott',
    'Simon',
    'Stephen',
    'Steven',
    'Taylor',
    'Thomas',
    'Victoria',
    'Warren',
    'William'
]


# Three-name templates: [A], [B] and [C] are introduced in that order, then
# "[B] and [C]" act and the sentence ends on [A]. [PLACE] and [OBJECT] are noun
# placeholders that are filled separately from the names.
ABC_TEMPLATES = [
    "Then, [A], [B] and [C] went to the [PLACE]. [B] and [C] gave a [OBJECT] to [A]",
    "Afterwards [A], [B] and [C] went to the [PLACE]. [B] and [C] gave a [OBJECT] to [A]",
    "When [A], [B] and [C] arrived at the [PLACE], [B] and [C] gave a [OBJECT] to [A]",
    "Friends [A], [B] and [C] went to the [PLACE]. [B] and [C] gave a [OBJECT] to [A]",
]

# Same sentences as ABC_TEMPLATES, but with the first "[A]" and the first "[B]"
# swapped, so the names are introduced in the order [B], [A], [C].
# The two replace() calls are order-sensitive: the first "[B]" is overwritten with
# "[A]" first, so the string temporarily starts with two "[A]"s, and then the
# (now leftmost) original "[A]" is turned into "[B]".
BAC_TEMPLATES = [
    template.replace("[B]", "[A]", 1).replace("[A]", "[B]", 1)
    for template in ABC_TEMPLATES
]

# Two-name templates in "BABA" order: the names appear as [B], [A], [B], [A].
# [B] is the repeated name and every template ends on [A].
# Some templates have no [PLACE] and/or [OBJECT] placeholder.
BABA_TEMPLATES = [
    "Then, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]",
    "Then, [B] and [A] had a lot of fun at the [PLACE]. [B] gave a [OBJECT] to [A]",
    "Then, [B] and [A] were working at the [PLACE]. [B] decided to give a [OBJECT] to [A]",
    "Then, [B] and [A] were thinking about going to the [PLACE]. [B] wanted to give a [OBJECT] to [A]",
    "Then, [B] and [A] had a long argument, and afterwards [B] said to [A]",
    "After [B] and [A] went to the [PLACE], [B] gave a [OBJECT] to [A]",
    "When [B] and [A] got a [OBJECT] at the [PLACE], [B] decided to give it to [A]",
    "When [B] and [A] got a [OBJECT] at the [PLACE], [B] decided to give the [OBJECT] to [A]",
    "While [B] and [A] were working at the [PLACE], [B] gave a [OBJECT] to [A]",
    "While [B] and [A] were commuting to the [PLACE], [B] gave a [OBJECT] to [A]",
    "After the lunch, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]",
    "Afterwards, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]",
    "Then, [B] and [A] had a long argument. Afterwards [B] said to [A]",
    "The [PLACE] [B] and [A] went to had a [OBJECT]. [B] gave it to [A]",
    "Friends [B] and [A] found a [OBJECT] at the [PLACE]. [B] gave it to [A]",
]

# Longer counterparts of the BABA templates: each entry has extra words inserted
# ahead of the first "[B] and [A]" pair, while the name order stays [B], [A], [B], [A].
# Not referenced by any function or class in the visible part of this file.
BABA_LONG_TEMPLATES = [
    "Then in the morning, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]",
    "Then in the morning, [B] and [A] had a lot of fun at the [PLACE]. [B] gave a [OBJECT] to [A]",
    "Then in the morning, [B] and [A] were working at the [PLACE]. [B] decided to give a [OBJECT] to [A]",
    "Then in the morning, [B] and [A] were thinking about going to the [PLACE]. [B] wanted to give a [OBJECT] to [A]",
    "Then in the morning, [B] and [A] had a long argument, and afterwards [B] said to [A]",
    "After taking a long break [B] and [A] went to the [PLACE], [B] gave a [OBJECT] to [A]",
    "When soon afterwards [B] and [A] got a [OBJECT] at the [PLACE], [B] decided to give it to [A]",
    "When soon afterwards [B] and [A] got a [OBJECT] at the [PLACE], [B] decided to give the [OBJECT] to [A]",
    "While spending time together [B] and [A] were working at the [PLACE], [B] gave a [OBJECT] to [A]",
    "While spending time together [B] and [A] were commuting to the [PLACE], [B] gave a [OBJECT] to [A]",
    "After the lunch in the afternoon, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]",
    "Afterwards, while spending time together [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]",
    "Then in the morning afterwards, [B] and [A] had a long argument. Afterwards [B] said to [A]",
    "The local big [PLACE] [B] and [A] went to had a [OBJECT]. [B] gave it to [A]",
    "Friends separated at birth [B] and [A] found a [OBJECT] at the [PLACE]. [B] gave it to [A]",
]

# Eight BABA-ordered templates that line up one-to-one with BABA_EARLY_IOS.
# Compared with their BABA_EARLY_IOS counterparts these keep a comma after the
# opening word(s) and/or a sentence break before the second clause.
# NOTE(review): "late IO" presumably refers to where the [A] token ends up in the
# tokenised sentence relative to the "early" variant — confirm against the
# original Easy-Transformer code.
BABA_LATE_IOS = [
    "Then, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]",
    "Then, [B] and [A] had a lot of fun at the [PLACE]. [B] gave a [OBJECT] to [A]",
    "Then, [B] and [A] were working at the [PLACE]. [B] decided to give a [OBJECT] to [A]",
    "Then, [B] and [A] were thinking about going to the [PLACE]. [B] wanted to give a [OBJECT] to [A]",
    "Then, [B] and [A] had a long argument and after that [B] said to [A]",
    "After the lunch, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]",
    "Afterwards, [B] and [A] went to the [PLACE]. [B] gave a [OBJECT] to [A]",
    "Then, [B] and [A] had a long argument. Afterwards [B] said to [A]",
]

# Eight BABA-ordered templates that line up one-to-one with BABA_LATE_IOS.
# Here the opening word(s) are not followed by a comma and the two clauses are
# joined with ", and" instead of a sentence break.
# NOTE(review): "early IO" presumably refers to the position of the [A] token in
# the tokenised sentence — confirm against the original Easy-Transformer code.
BABA_EARLY_IOS = [
    "Then [B] and [A] went to the [PLACE], and [B] gave a [OBJECT] to [A]",
    "Then [B] and [A] had a lot of fun at the [PLACE], and [B] gave a [OBJECT] to [A]",
    "Then [B] and [A] were working at the [PLACE], and [B] decided to give a [OBJECT] to [A]",
    "Then [B] and [A] were thinking about going to the [PLACE], and [B] wanted to give a [OBJECT] to [A]",
    "Then [B] and [A] had a long argument, and after that [B] said to [A]",
    "After the lunch [B] and [A] went to the [PLACE], and [B] gave a [OBJECT] to [A]",
    "Afterwards [B] and [A] went to the [PLACE], and [B] gave a [OBJECT] to [A]",
    "Then [B] and [A] had a long argument, and afterwards [B] said to [A]",
]

# The ABBA lists start out as shallow copies of the corresponding BABA lists; the
# loop below then rewrites every copied template in place so that the first two
# names appear as "[A] ... [B]" instead of "[B] ... [A]". The BABA lists themselves
# are not modified because the slices above created new list objects.
ABBA_TEMPLATES = BABA_TEMPLATES[:]
ABBA_LATE_IOS = BABA_LATE_IOS[:]
ABBA_EARLY_IOS = BABA_EARLY_IOS[:]

# NOTE: this runs at import time and leaves TEMPLATES, i, j and first_clause
# behind as module-level names.
for TEMPLATES in [ABBA_TEMPLATES, ABBA_LATE_IOS, ABBA_EARLY_IOS]:
    for i in range(len(TEMPLATES)):
        # True until the first original "[A]" has been rewritten; after that the
        # rest of the template (second clause) is left untouched.
        first_clause = True
        # Slide a 3-character window centred on j over the template. Each edit
        # replaces exactly one character, so the length never changes and the
        # range computed up front stays valid.
        for j in range(1, len(TEMPLATES[i]) - 1):
            if TEMPLATES[i][j - 1 : j + 2] == "[B]" and first_clause:
                # "[B]" seen before the first "[A]": turn it into "[A]".
                TEMPLATES[i] = TEMPLATES[i][:j] + "A" + TEMPLATES[i][j + 1 :]
            elif TEMPLATES[i][j - 1 : j + 2] == "[A]" and first_clause:
                # First original "[A]": turn it into "[B]" and stop rewriting.
                # (A "[A]" produced by the branch above is never re-examined,
                # because the window has already moved past its centre.)
                first_clause = False
                TEMPLATES[i] = TEMPLATES[i][:j] + "B" + TEMPLATES[i][j + 1 :]

# Verb strings, each with a leading space. Not referenced by any function or
# class in the visible part of this file.
VERBS = [" tried", " said", " decided", " wanted", " gave"]

# Candidate fillers for the "[PLACE]" placeholder (IOIDataset passes this list to
# gen_prompt_uniform through nouns_dict).
PLACES = [
    "store",
    "garden",
    "restaurant",
    "school",
    "hospital",
    "office",
    "house",
    "station",
]

# Candidate fillers for the "[OBJECT]" placeholder (IOIDataset passes this list to
# gen_prompt_uniform through nouns_dict).
OBJECTS = [
    "ring",
    "kiss",
    "bone",
    "basketball",
    "computer",
    "necklace",
    "drink",
    "snack",
]


def gen_prompt_uniform(
    templates: list[str], names, nouns_dict, N, symmetric, prefixes=None, abc=False
):
    """Generate ``N`` prompts by filling randomly chosen templates.

    Args:
        templates: Template strings containing the placeholders ``[A]``, ``[B]``
            (and ``[C]`` when ``abc`` is true) plus the keys of ``nouns_dict``.
        names: Sequence of candidate names; at least three distinct names are
            needed because three different names are drawn for every prompt.
        nouns_dict: Maps a placeholder string (e.g. ``"[PLACE]"``) to the list of
            words it may be replaced with.
        N: Number of prompts to return.
        symmetric: If true, every generated prompt is followed by a twin in which
            the two names are swapped (the last twin is dropped if it would
            exceed ``N``).
        prefixes: Optional list of texts. When given, one is picked at random,
            truncated to at most ``L`` ``"."``-separated pieces (``L`` drawn from
            30..40), terminated with ``"<|endoftext|>"`` and prepended.
        abc: If true, ``[C]`` is substituted as well and stored under ``"C"``.

    Returns:
        A list of dicts. Each dict holds the chosen noun for every key of
        ``nouns_dict``, the final ``"text"``, the names under ``"IO"`` (the name
        substituted for ``[A]``) and ``"S"`` (substituted for ``[B]``), the
        ``"TEMPLATE_IDX"`` of the template, and ``"C"`` when ``abc`` is true.

    Raises:
        ValueError: if prompts are requested but ``names`` has fewer than three
            distinct entries (the name-sampling loop could never terminate).
    """
    if N > 0 and len(set(names)) < 3:
        raise ValueError("gen_prompt_uniform needs at least three distinct names")

    ioi_prompts = []
    while len(ioi_prompts) < N:
        temp = random.choice(templates)
        temp_id = templates.index(temp)

        # Three distinct names are always drawn, even when `abc` is false, so the
        # sequence of random draws does not depend on `abc`.
        name_1 = name_2 = name_3 = ""
        while len({name_1, name_2, name_3}) < 3:
            name_1 = random.choice(names)
            name_2 = random.choice(names)
            name_3 = random.choice(names)

        # Pick one noun per placeholder and substitute it into the template.
        nouns = {k: random.choice(nouns_dict[k]) for k in nouns_dict}
        prompt = temp
        for placeholder, noun in nouns.items():
            prompt = prompt.replace(placeholder, noun)

        if prefixes is not None:
            L = random.randint(30, 40)
            pref = ".".join(random.choice(prefixes).split(".")[:L])
            pref += "<|endoftext|>"
        else:
            pref = ""

        # First the prompt itself, then (optionally) its name-swapped twin.
        name_orders = [(name_1, name_2)]
        if symmetric:
            name_orders.append((name_2, name_1))

        for io_name, s_name in name_orders:
            if len(ioi_prompts) >= N:
                break
            text = prompt.replace("[A]", io_name).replace("[B]", s_name)
            if abc:
                # The twin needs this too; otherwise a literal "[C]" stays in
                # its text.
                text = text.replace("[C]", name_3)
            ioi_prompt = dict(nouns)
            ioi_prompt["text"] = pref + text
            ioi_prompt["IO"] = io_name
            ioi_prompt["S"] = s_name
            ioi_prompt["TEMPLATE_IDX"] = temp_id
            if abc:
                ioi_prompt["C"] = name_3
            ioi_prompts.append(ioi_prompt)

    return ioi_prompts



def flip_words_in_prompt(prompt: str, word1: str, word2: str, instances: Optional[Union[int, List[int]]] = None):
    '''
    Flips instances of word `word1` with `word2` (and vice versa) in the string `prompt`.

    By default it flips all instances, but the optional `instances` argument specifies which
    instances to flip (e.g. if instances = 0, then it only flips the 0th instance of either
    word1 or word2). `instances` may be a single int or a list of ints; instances are
    counted over the occurrences of both words together, in order of appearance.

    Both words are matched literally (regex metacharacters are escaped), and if one word is
    a prefix of the other the longer one is tried first.

    Examples of (arguments) -> return value:

        ("ABA", "A", "B") -> "BAB"
        ("ABA", "A", "B", 1) -> "AAA"
        ("ABA", "A", "B", [0, 1]) -> "BAA"

    Raises IndexError if an entry of `instances` does not refer to an existing occurrence.
    '''
    # Longest alternative first, so that e.g. "Anna" is not matched as "Ann" + "a".
    alternatives = sorted((word1, word2), key=len, reverse=True)
    pattern = "({})".format("|".join(re.escape(word) for word in alternatives))

    # With a single capturing group, re.split puts every match at an odd index.
    split_prompt = re.split(pattern, prompt)
    indices_of_names = list(range(1, len(split_prompt), 2))

    if instances is None:
        indices_to_flip = indices_of_names
    else:
        if isinstance(instances, int):
            instances = [instances]
        indices_to_flip = [indices_of_names[i] for i in instances]

    for i in indices_to_flip:
        split_prompt[i] = word1 if split_prompt[i] == word2 else word2
    return "".join(split_prompt)
    


def gen_flipped_prompts(prompts: List[dict], templates_by_prompt: List[str], flip: str, names: List[str], seed: int) -> List[dict]:
    '''
    Build corrupted copies of `prompts` by re-assigning their first three names.

    prompts: List[dict]
        each prompt is a dict with (at least) the keys "text", "S" and "IO"

    templates_by_prompt: List[str]
        one entry per prompt, "ABBA" or "BABA"; its first three letters select
        which half of `flip` is applied

    flip: str
        two comma-separated rules "<orig> -> <new>", the first for ABBA prompts
        and the second for BABA prompts. E.g. "ABB -> XYZ, BAB -> XYZ" turns
        "A and B went to [place], B gave [object] to A" into
        "X and Y went to [place], Z gave [object] to A". A letter of <new> that
        also occurs in <orig> re-uses the corresponding original name; any
        other letter gets a random name.

    names: List[str]
        name pool; used both to spot names in the text (whole space-separated
        words only) and to draw random replacements

    seed: int
        seeds both `random` and `np.random` (global state) for reproducibility

    Returns a new list of shallow-copied prompt dicts with updated "text", "S"
    and "IO"; the input dicts are left untouched.

    The final name of the sentence (the second IO mention) is deliberately not
    rewritten: corrupted datasets are only used as a source of corrupted
    activations, and the answer is still taken from the clean dataset.
    '''
    random.seed(seed)
    np.random.seed(seed)

    abba_rule, baba_rule = flip.split(",")
    rules = {
        "ABB": [part.strip() for part in abba_rule.split("->")],
        "BAB": [part.strip() for part in baba_rule.split("->")],
    }

    flipped = []
    for original, template in zip(prompts, templates_by_prompt):
        flip_orig, flip_new = rules[template[:-1]]
        new_prompt = copy.copy(original)

        # Positions and values of the first three name words in the text.
        words = new_prompt["text"].split(" ")
        name_slots = [(pos, word) for pos, word in enumerate(words) if word in names][:3]
        orig_names = list(zip(*name_slots))[1]

        # Which original name each letter of flip_orig stands for.
        letter_to_orig = {}
        for name, letter in zip(orig_names, flip_orig):
            letter_to_orig[letter] = name
        # Sanity check on flip_orig: every distinct letter must have been bound.
        assert len(letter_to_orig) == len(set(flip_orig))

        # Start the replacement table with the letters that survive into flip_new ...
        replacements = {
            letter: name for letter, name in letter_to_orig.items() if letter in flip_new
        }

        # ... then add one random name per genuinely new letter. Candidates are the
        # unused names in pool order; letters are handled in order of first
        # appearance in flip_new.
        candidates = [name for name in dict.fromkeys(names) if name not in orig_names]
        fresh_letters = [letter for letter in dict.fromkeys(flip_new) if letter not in flip_orig]
        for letter in fresh_letters:
            replacements[letter] = np.random.choice(candidates)
        assert len(replacements) == len(set(flip_new)), (replacements, flip_new)

        # Write the new names into the three slots and rebuild the sentence.
        for (pos, _), letter in zip(name_slots, flip_new):
            words[pos] = replacements[letter]
        new_prompt["text"] = " ".join(words)

        # S is whatever now sits in the third slot. IO is the one name among the
        # first two slots whose letter is unique in flip_new; when that is not
        # well-defined, fall back to the second slot.
        new_prompt["S"] = replacements[flip_new[-1]]
        unique_front = [
            replacements[letter] for letter in flip_new[:2] if flip_new.count(letter) == 1
        ]
        if len(unique_front) == 1:
            new_prompt["IO"] = unique_front[0]
        else:
            new_prompt["IO"] = replacements[flip_new[1]]

        flipped.append(new_prompt)

    return flipped



def get_name_idxs(prompts, tokenizer, model_family, idx_types=["IO", "S1", "S2"], prepend_bos=False):
    """Locate the IO, first-S and last-S name tokens in every prompt.

    Each prompt's text, minus its final space-separated word, is tokenised with
    ``tokenizer.tokenize``. A name is located through the first token of its own
    tokenisation; for the "gpt2", "pythia" and "gemma" families the name is
    tokenised with a leading space.

    Args:
        prompts: Iterable of dicts with the keys "text", "IO" and "S".
        tokenizer: Object exposing ``tokenize(str) -> list`` (presumably a
            Hugging Face tokenizer; only ``tokenize`` is used here).
        model_family: Model family name; decides whether names get a leading space.
        idx_types: Which index lists to return, and in which order. "IO" is the
            first occurrence of the IO token, "S1" the first and "S2" the last
            occurrence of the S token. Any other entry yields an empty tensor.
            The default list is never mutated.
        prepend_bos: If true, every index is shifted by one.

    Returns:
        A list with one tensor per entry of ``idx_types``.

    Raises:
        ValueError: if a name token cannot be found in a tokenised prompt.
    """
    needs_space = model_family in ["gpt2", "pythia", "gemma"]

    def first_token(name):
        # Only the first token of a (possibly multi-token) name is searched for.
        return tokenizer.tokenize(" " + name if needs_space else name)[0]

    # All three positions are always computed, so the dict must hold all three
    # keys even when the caller only asks for a subset.
    name_idx_dict = {idx_type: [] for idx_type in ("IO", "S1", "S2", *idx_types)}

    for prompt in prompts:
        text_split = prompt["text"].split(" ")
        # The last word (the answer) is excluded from the search.
        toks = tokenizer.tokenize(" ".join(text_split[:-1]))
        io_tok = first_token(prompt["IO"])
        s_tok = first_token(prompt["S"])
        try:
            io_first = toks.index(io_tok)
            s_first = toks.index(s_tok)
        except ValueError as err:
            raise ValueError(
                f"Could not locate name tokens {io_tok!r} / {s_tok!r} in tokenized prompt {toks}"
            ) from err
        # Last occurrence of the S token, found by searching the reversed list.
        s_last = len(toks) - toks[::-1].index(s_tok) - 1

        name_idx_dict["IO"].append(io_first)
        name_idx_dict["S1"].append(s_first)
        name_idx_dict["S2"].append(s_last)

    return [
        int(prepend_bos) + t.tensor(name_idx_dict[idx_type])
        for idx_type in idx_types
    ]


def get_word_idxs(prompts, word_list, tokenizer, model_family):
    """Return, for every prompt, the token position of a word from ``word_list``.

    Each word is represented by the decoded first token of its tokenisation.
    For every word whose text occurs in the prompt, the position of the *last*
    matching token is taken; if several words of ``word_list`` occur in the same
    prompt, the one listed last wins.

    Args:
        prompts: Iterable of dicts with a "text" key.
        word_list: Candidate words; at least one must occur in every prompt.
        tokenizer: Callable tokenizer that returns a mapping with "input_ids"
            and offers ``decode`` (presumably a Hugging Face tokenizer).
        model_family: "gpt2"/"pythia" tokenise the prompt without special
            tokens, "llama2"/"gemma" with special tokens.

    Returns:
        A 1-D tensor with one index per prompt.

    Raises:
        ValueError: for an unsupported ``model_family``, when none of the words
            occurs in a prompt, or when a word occurs in the text but its first
            token is absent from the tokenised prompt.
    """
    if model_family in ["gpt2", "pythia"]:
        add_special_tokens = False
    elif model_family in ["llama2", "gemma"]:
        add_special_tokens = True
    else:
        raise ValueError(f"model_family: {model_family} is not implemented.")

    tokenized_words = [
        tokenizer.decode(tokenizer(word, add_special_tokens=False)["input_ids"][0]) for word in word_list
    ]

    idxs = []
    for prompt in prompts:
        text = prompt["text"]
        token_ids = tokenizer(text, return_tensors="pt", padding=True, add_special_tokens=add_special_tokens)[
            "input_ids"
        ][0]
        toks = [tokenizer.decode(token_id) for token_id in token_ids]

        idx = None
        for word, w_tok in zip(word_list, tokenized_words):
            if word not in text:
                continue
            if w_tok not in toks:
                raise ValueError(
                    f"Word {word!r} occurs in the prompt text but its token {w_tok!r} "
                    f"is not among the prompt tokens {toks}"
                )
            # Position of the last occurrence of the token.
            idx = len(toks) - toks[::-1].index(w_tok) - 1

        if idx is None:
            raise ValueError(f"None of the words {word_list} were found in {prompt}")
        idxs.append(idx)
    return t.tensor(idxs)


def get_end_idxs(toks, tokenizer, model_family, name_tok_len=1, prepend_bos=False):
    """Return, per row of ``toks``, the index of the token just before the final name.

    For a row without padding (ignoring position 0) the sequence is taken to span
    the whole row; otherwise it ends at the first padding token. The returned
    index is that end position minus ``1 + name_tok_len``.

    A sequence may start with a token equal to the padding token (BOS prepended
    and pad == BOS). For "gpt2"/"pythia" with ``prepend_bos`` the first match of
    the padding id is therefore skipped and the second one is used; for
    "llama2"/"gemma" the first match is always used.

    Raises AssertionError when the token after the returned index is 0, or when
    the position two past it is neither the end of the row nor a padding token.
    """
    if model_family in ["gpt2", "pythia"]:
        relevant_idx = int(prepend_bos)
    elif model_family in ["llama2", "gemma"]:
        relevant_idx = 0

    pad_token_id = tokenizer.pad_token_id

    raw_ends = []
    for row in toks:
        if pad_token_id in row[1:]:
            # Positions of every padding-id match; pick the one that starts the padding.
            pad_positions = (row == pad_token_id).nonzero()
            raw_ends.append(pad_positions[relevant_idx][0].item())
        else:
            # No padding in this row: the sequence fills it completely.
            raw_ends.append(toks.shape[1])

    end_idxs = t.tensor(raw_ends) - 1 - name_tok_len

    for i, row in enumerate(toks):
        assert row[end_idxs[i] + 1] != 0 and (
            toks.shape[1] == end_idxs[i] + 2 or row[end_idxs[i] + 2] == pad_token_id
        ), (
            row,
            end_idxs[i],
            row.shape,
            "the END idxs aren't properly formatted",
        )

    return end_idxs

def get_idx_dict(ioi_prompts, tokenizer, model_family, prepend_bos=False, toks=None):
    """Collect the positions of the tokens of interest for every prompt.

    Combines the results of get_name_idxs, get_end_idxs and get_word_idxs into a
    single dict with the keys "IO", "S1", "S2" (each also as "<key>-1" and
    "<key>+1"), "end", "end-1", "starts" (all zeros), "punct" and "verb".

    The verb search words carry a leading space for the "gpt2", "pythia" and
    "gemma" families and none for "llama2".
    """
    io_pos, s1_pos, s2_pos = get_name_idxs(
        ioi_prompts,
        tokenizer,
        model_family=model_family,
        idx_types=["IO", "S1", "S2"],
        prepend_bos=prepend_bos,
    )
    end_pos = get_end_idxs(
        toks,
        tokenizer,
        model_family=model_family,
        name_tok_len=1,
        prepend_bos=prepend_bos,
    )
    punct_pos = get_word_idxs(ioi_prompts, [",", "."], tokenizer, model_family)

    if model_family in ["gpt2", "pythia", "gemma"]:
        verb_words = [" gave", " give", " said"]
    elif model_family in ["llama2"]:
        verb_words = ["gave", "give", "said"]
    verb_pos = get_word_idxs(ioi_prompts, verb_words, tokenizer, model_family)

    idx_dict = {}
    # Each name position is exposed together with its left and right neighbour.
    for label, pos in (("IO", io_pos), ("S1", s1_pos), ("S2", s2_pos)):
        idx_dict[label] = pos
        idx_dict[label + "-1"] = pos - 1
        idx_dict[label + "+1"] = pos + 1

    idx_dict["end"] = end_pos
    idx_dict["end-1"] = end_pos - 1
    idx_dict["starts"] = t.zeros_like(end_pos)
    idx_dict["punct"] = punct_pos
    idx_dict["verb"] = verb_pos
    return idx_dict

class IOIDataset:
    """A batch of IOI prompts together with their tokens and token positions.

    Prompts are either generated from the module's templates (``prompts=None``)
    or supplied by the caller. After construction the instance exposes:

    * ``ioi_prompts``  - list of prompt dicts ("text", "IO", "S", "TEMPLATE_IDX", ...)
    * ``sentences``    - the raw prompt texts
    * ``templates_by_prompt`` - "ABBA" or "BABA" per prompt, decided by whether the
      IO name or the S name occurs first in the text
    * ``groups``       - arrays of prompt indices, one array per template index
    * ``toks``         - padded token ids, moved to ``device``
    * ``word_idx``     - positions from ``get_idx_dict`` (or ``manual_word_idx``)
    * ``io_tokenIDs`` / ``s_tokenIDs`` - first token id of " " + name
    * ``tokenized_prompts`` - each prompt as "|"-joined decoded tokens
    * ``max_len``      - longest prompt, in tokens as produced by ``tokenizer(text)``

    Constructing an instance re-seeds the global ``random`` and ``np.random``
    generators with ``seed``.
    """

    def __init__(
        self,
        prompt_type: Union[
            str, List[str]
        ],  # if list, then it will be a list of templates
        model_family,
        N=500,
        tokenizer=None,
        prompts=None,
        symmetric=False,
        prefixes=None,
        nb_templates=None,
        prepend_bos=False,
        manual_word_idx=None,
        has_been_flipped:bool=False,
        seed=0,
        device="cuda"
    ):
        """Build the dataset.

        Args:
            prompt_type: "ABBA", "BABA", "mixed", "ABC", "BAC", "ABC mixed", or an
                explicit list of template strings.
            model_family: One of "gpt2", "pythia", "llama2", "gemma".
            N: Number of prompts; must equal ``len(prompts)`` when prompts are given.
            tokenizer: Tokenizer to use; when None the "gpt2" tokenizer is loaded
                and its pad token is set to its EOS token.
            prompts: Pre-built prompt dicts; when None, prompts are generated.
            symmetric: Generate name-swapped twins (requires an even ``N``).
            prefixes: Optional texts to prepend to generated prompts.
            nb_templates: How many templates to use (default: all BABA templates).
            prepend_bos: Prepend the tokenizer's BOS token to every text.
            manual_word_idx: If given, replaces the computed ``word_idx``.
            has_been_flipped: Marks datasets produced by ``gen_flipped_prompts``.
            seed: Seed for ``random`` and ``np.random``.
            device: Device that ``toks`` is moved to.

        Raises:
            Exception: for an unsupported ``model_family``.
            ValueError: for an unknown ``prompt_type``.
        """
        self.seed = seed
        random.seed(self.seed)
        np.random.seed(self.seed)

        self.model_family = model_family
        if self.model_family not in ["gpt2", "pythia", "llama2", "gemma"]:
            raise Exception(f"model_family: {model_family} is not implemented.")

        # Resolve the tokenizer before anything reads its attributes; the BOS
        # check below would otherwise dereference None when tokenizer is None.
        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
            self.tokenizer.pad_token = self.tokenizer.eos_token
        else:
            self.tokenizer = tokenizer

        if not (
            N == 1
            or not prepend_bos
            or self.tokenizer.bos_token_id == self.tokenizer.eos_token_id
        ):
            warnings.warn(
                "Probably word_idx will be calculated incorrectly due to this formatting"
            )
        self.has_been_flipped = has_been_flipped
        assert not (symmetric and prompt_type == "ABC")
        assert (
            (prompts is not None) or (not symmetric) or (N % 2 == 0)
        ), f"{symmetric} {N}"
        self.prompt_type = prompt_type

        if nb_templates is None:
            nb_templates = len(BABA_TEMPLATES)

        if prompt_type == "ABBA":
            self.templates = ABBA_TEMPLATES[:nb_templates].copy()
        elif prompt_type == "BABA":
            self.templates = BABA_TEMPLATES[:nb_templates].copy()
        elif prompt_type == "mixed":
            self.templates = (
                BABA_TEMPLATES[: nb_templates // 2].copy()
                + ABBA_TEMPLATES[: nb_templates // 2].copy()
            )
            random.shuffle(self.templates)
        elif prompt_type == "ABC":
            self.templates = ABC_TEMPLATES[:nb_templates].copy()
        elif prompt_type == "BAC":
            self.templates = BAC_TEMPLATES[:nb_templates].copy()
        elif prompt_type == "ABC mixed":
            self.templates = (
                ABC_TEMPLATES[: nb_templates // 2].copy()
                + BAC_TEMPLATES[: nb_templates // 2].copy()
            )
            random.shuffle(self.templates)
        elif isinstance(prompt_type, list):
            self.templates = prompt_type
        else:
            raise ValueError(prompt_type)

        self.prefixes = prefixes
        if prompts is None:
            self.ioi_prompts = gen_prompt_uniform(  # list of dict of the form {"text": "Alice and Bob bla bla. Bob gave bla to Alice", "IO": "Alice", "S": "Bob"}
                self.templates,
                NAMES,
                nouns_dict={"[PLACE]": PLACES, "[OBJECT]": OBJECTS},
                N=N,
                symmetric=symmetric,
                prefixes=self.prefixes,
                abc=(prompt_type in ["ABC", "ABC mixed", "BAC"]),
            )
        else:
            assert N == len(prompts), f"{N} and {len(prompts)}"
            self.ioi_prompts = prompts

        # Group the prompt indices by the template they were generated from.
        all_ids = [prompt["TEMPLATE_IDX"] for prompt in self.ioi_prompts]
        all_ids_ar = np.array(all_ids)
        self.groups = []
        for template_id in list(set(all_ids)):
            self.groups.append(np.where(all_ids_ar == template_id)[0])

        self.sentences = [
            prompt["text"] for prompt in self.ioi_prompts
        ]  # a list of strings. Renamed as this should NOT be forward passed

        # For each prompt: "ABBA" if the IO name occurs before the S name, else "BABA".
        self.templates_by_prompt = [
            "ABBA" if sentence.index(prompt["IO"]) < sentence.index(prompt["S"]) else "BABA"
            for sentence, prompt in zip(self.sentences, self.ioi_prompts)
        ]

        texts = [
            (self.tokenizer.bos_token if prepend_bos else "") + prompt["text"]
            for prompt in self.ioi_prompts
        ]
        # Build the id tensor directly as int64 instead of going through float32.
        self.toks = t.tensor(
            self.tokenizer(texts, padding=True, add_special_tokens=False).input_ids,
            dtype=t.long,
        )

        self.word_idx = get_idx_dict(
            self.ioi_prompts,
            self.tokenizer,
            model_family=self.model_family,
            prepend_bos=prepend_bos,
            toks=self.toks,
        )
        self.prepend_bos = prepend_bos
        if manual_word_idx is not None:
            self.word_idx = manual_word_idx

        self.N = N
        self.max_len = max(
            [
                len(self.tokenizer(prompt["text"]).input_ids)
                for prompt in self.ioi_prompts
            ]
        )

        self.io_tokenIDs = [
            self.tokenizer.encode(" " + prompt["IO"], add_special_tokens=False)[0] for prompt in self.ioi_prompts
        ]
        self.s_tokenIDs = [
            self.tokenizer.encode(" " + prompt["S"], add_special_tokens=False)[0] for prompt in self.ioi_prompts
        ]

        # NOTE: this instance attribute shadows the `tokenized_prompts` method
        # defined further down, so on instances the name refers to this list.
        self.tokenized_prompts = [
            "|".join([self.tokenizer.decode(tok) for tok in self.toks[i]])
            for i in range(self.N)
        ]

        self.device = device
        self.to(device)

    def gen_flipped_prompts(self, flip):
        """Return a new dataset whose prompts are flipped according to ``flip``.

        ``flip`` uses the format of the module-level ``gen_flipped_prompts``
        (e.g. "ABB -> XYZ, BAB -> XYZ"). The new dataset keeps this dataset's
        ``word_idx`` and is marked as flipped. A warning is issued when this
        dataset is itself the result of a flip.
        """
        # Check if it's already been flipped (shouldn't string 2 flips together)
        if self.has_been_flipped:
            warnings.warn("This dataset has already been flipped. Generally, you should try and apply flips in one step, because this can lead to errors.")

        # Redefine seed (so it's different depending on what the flip is, e.g. we don't want (IO, RAND) then (S, RAND) to give us the same rand names)
        seed = self.seed + sum(map(ord, list("".join(flip))))

        # Get flipped prompts
        flipped_prompts = gen_flipped_prompts(self.ioi_prompts, self.templates_by_prompt, flip, NAMES, seed)

        flipped_ioi_dataset = IOIDataset(
            model_family=self.model_family,
            prompt_type=self.prompt_type,
            N=self.N,
            tokenizer=self.tokenizer,
            prompts=flipped_prompts,
            prefixes=self.prefixes,
            prepend_bos=self.prepend_bos,
            manual_word_idx=self.word_idx,
            has_been_flipped=True,
            seed=seed,
            device=self.device,
        )
        return flipped_ioi_dataset

    def copy(self):
        """Return a new dataset built from a shallow copy of the prompt list.

        All construction settings (model family, BOS handling, seed, device,
        flipped flag) and the current ``word_idx`` are carried over.
        """
        copy_ioi_dataset = IOIDataset(
            prompt_type=self.prompt_type,
            model_family=self.model_family,
            N=self.N,
            tokenizer=self.tokenizer,
            prompts=self.ioi_prompts.copy(),
            prefixes=self.prefixes.copy() if self.prefixes is not None else self.prefixes,
            prepend_bos=self.prepend_bos,
            manual_word_idx=self.word_idx,
            has_been_flipped=self.has_been_flipped,
            seed=self.seed,
            device=self.device,
        )
        return copy_ioi_dataset

    def __getitem__(self, key):
        """Return a new dataset restricted to ``self.ioi_prompts[key]``.

        ``key`` may be a slice or an integer; an integer yields a one-prompt
        dataset and raises IndexError when out of range. ``word_idx`` of the
        result is recomputed from the selected prompts.
        """
        if isinstance(key, (int, np.integer)):
            if not -self.N <= key < self.N:
                raise IndexError(f"index {key} is out of range for a dataset of size {self.N}")
            key = int(key)
            # `key + 1 or None` keeps key == -1 working (slice(-1, None)).
            key = slice(key, key + 1 or None)

        sliced_prompts = self.ioi_prompts[key]
        sliced_dataset = IOIDataset(
            prompt_type=self.prompt_type,
            model_family=self.model_family,
            N=len(sliced_prompts),
            tokenizer=self.tokenizer,
            prompts=sliced_prompts,
            prefixes=self.prefixes,
            prepend_bos=self.prepend_bos,
            has_been_flipped=self.has_been_flipped,
            seed=self.seed,
            device=self.device,
        )
        return sliced_dataset

    def __setitem__(self, key, value):
        raise NotImplementedError()

    def __delitem__(self, key):
        raise NotImplementedError()

    def __len__(self):
        """Number of prompts in the dataset."""
        return self.N

    def tokenized_prompts(self):
        # NOTE: unreachable through instances, because __init__ assigns an
        # attribute of the same name. Kept so the class-level name still exists.
        return self.toks

    def to(self, device):
        """Move ``toks`` to ``device`` and return ``self``."""
        self.toks = self.toks.to(device)
        return self