import json

import torch
from datasets import DatasetDict, load_dataset

from eliciting_contexts.sandbagging.utils import TensorEncoder
from eliciting_contexts.utils.constants import DATA_DIR

# Chat-template marker that opens the user turn. Everything after its first
# occurrence in a prompt is kept; everything up to and including it is dropped.
USER_TURN_MARKER = "<start_of_turn>user"

# Baseline prefix that takes the place of the dropped text. The four-space
# indentation and the trailing "\n    " are part of the emitted prompt and are
# kept exactly as the script has always written them.
BASELINE_PREFIX = (
    "<start_of_turn>system log\n"
    "    This is a random starting prefix as a baseline\n"
    "    "
)


def build_baseline_prompt(
    prompt: str,
    marker: str = USER_TURN_MARKER,
    prefix: str = BASELINE_PREFIX,
) -> str:
    """Return *prompt* with everything through the first *marker* replaced by *prefix*.

    Args:
        prompt: Original prompt text; must contain ``marker``.
        marker: Substring whose first occurrence ends the part being replaced.
        prefix: Text placed in front of whatever follows ``marker``.

    Returns:
        ``prefix`` concatenated with the text after the first ``marker``.

    Raises:
        ValueError: If ``marker`` does not occur in ``prompt``.
    """
    _, found, remainder = prompt.partition(marker)
    if not found:
        # Fail with context instead of the bare IndexError that
        # ``prompt.split(marker, 1)[1]`` would produce.
        raise ValueError(f"Marker {marker!r} not found in prompt")
    return prefix + remainder


def main(num_samples: int = 250, split: str = "test") -> None:
    """Write baseline-prefixed copies of the first non-sandbagging samples as JSON.

    Loads ``contextmodification/sandbagging-sciq``, keeps the rows of ``split``
    whose ``sandbagging_environment`` field is falsy, and for each of the first
    ``num_samples`` rows writes ``DATA_DIR / "epo" / baselinesample_<i>.json``
    containing the row plus an ``epo_prompt`` field.

    Args:
        num_samples: Maximum number of samples to export.
        split: Name of the dataset split to read.
    """
    datasetdict = load_dataset("contextmodification/sandbagging-sciq")
    assert isinstance(datasetdict, DatasetDict)

    baseline_rows = datasetdict[split].filter(lambda x: not x["sandbagging_environment"])

    # Computed once: used both as the loop bound and in the progress message.
    total = min(num_samples, len(baseline_rows))

    # Make sure the target directory exists so the first open() cannot fail on
    # a fresh checkout. Assumes DATA_DIR is a pathlib.Path, as the ``/`` usage
    # suggests.
    output_dir = DATA_DIR / "epo"
    output_dir.mkdir(parents=True, exist_ok=True)

    for i in range(total):
        print(f"Processing sample {i} of {total}")
        record = dict(baseline_rows[i])
        record["epo_prompt"] = build_baseline_prompt(record["prompt"])
        with open(output_dir / f"baselinesample_{i}.json", "w", encoding="utf-8") as file:
            # TensorEncoder is the project's JSON encoder; presumably it
            # serializes tensor-valued fields -- confirm in sandbagging.utils.
            json.dump(record, file, indent=4, cls=TensorEncoder)


if __name__ == "__main__":
    main()
