import json

import torch
from datasets import DatasetDict, load_dataset

from eliciting_contexts.utils.constants import DATA_DIR
from eliciting_contexts.sandbagging.utils import TensorEncoder


# Marker whose first occurrence ends the original prompt header; everything up
# to and including it is replaced by the baseline prefix.
USER_TURN_MARKER = "<start_of_turn>user"

# Baseline prefix, byte-identical to the triple-quoted literal this script used
# originally (including the four-space indentation that literal picked up from
# the source code). NOTE(review): that indentation may have been accidental —
# confirm before changing, since it alters every generated prompt.
BASELINE_PREFIX = (
    "<start_of_turn>system log\n"
    "    This is a random starting prefix as a baseline\n"
    "    "
)


def build_epo_prompt(
    prompt: str,
    marker: str = USER_TURN_MARKER,
    prefix: str = BASELINE_PREFIX,
) -> str:
    """Replace everything up to and including the first *marker* with *prefix*.

    Args:
        prompt: The original prompt text.
        marker: Text whose first occurrence delimits the part to replace.
        prefix: Text to put in place of the removed head.

    Returns:
        ``prefix`` followed by the text after the first ``marker``.

    Raises:
        ValueError: If ``marker`` does not occur in ``prompt``.
    """
    _, found, tail = prompt.partition(marker)
    if not found:
        # The original `split(marker, 1)[1]` failed here with a bare IndexError.
        raise ValueError(f"prompt does not contain marker {marker!r}")
    return prefix + tail


if __name__ == "__main__":
    import os

    num_samples = 250
    split = "test"

    datasetdict = load_dataset("contextmodification/sandbagging-sciq")
    assert isinstance(datasetdict, DatasetDict)

    # Keep only examples whose `sandbagging_environment` flag is falsy.
    filtered_dataset = datasetdict[split].filter(
        lambda x: not x["sandbagging_environment"]
    )

    total = min(num_samples, len(filtered_dataset))

    # Create the output directory up front so the first write does not fail.
    output_dir = DATA_DIR / "epo"
    os.makedirs(output_dir, exist_ok=True)

    for i in range(total):
        print(f"Processing sample {i} of {total}")
        datum_copy = filtered_dataset[i].copy()
        datum_copy["epo_prompt"] = build_epo_prompt(datum_copy["prompt"])
        out_path = output_dir / f"baselinesample_{i}.json"
        with open(out_path, "w", encoding="utf-8") as file:
            json.dump(datum_copy, file, indent=4, cls=TensorEncoder)
