#%%
import random
from datasets import load_dataset, concatenate_datasets
from collections import defaultdict

def sample_split(split, n_samples, seed=42):
    """Return up to ``n_samples`` rows of ``split`` after a seeded shuffle.

    The requested size is clamped to ``len(split)`` so that a split holding
    fewer than ``n_samples`` rows contributes all of its rows instead of
    asking ``select`` for indices past its end.
    """
    return split.shuffle(seed=seed).select(range(min(n_samples, len(split))))


def tally_answers(examples):
    """Count answer labels per environment in one pass over ``examples``.

    Each example must be a mapping with the keys 'sandbagging_environment',
    'correct_answer' and 'desired_answer'.

    Returns a dict with:
        n_total_clean:       number of rows with a falsy 'sandbagging_environment'
        n_total_sandbag:     number of rows with a truthy 'sandbagging_environment'
        n_desired_correct:   sandbagging rows where desired == correct
        n_desired_incorrect: sandbagging rows where desired != correct
        n_ABCD:              correct-answer counts over the clean rows
        n_ABCD_correct:      correct-answer counts over the sandbagging rows
        n_ABCD_desired:      desired-answer counts over the sandbagging rows

    Raises:
        AssertionError: if a clean row's desired answer differs from its
            correct answer.
    """
    stats = {
        'n_total_clean': 0,
        'n_total_sandbag': 0,
        'n_desired_correct': 0,
        'n_desired_incorrect': 0,
        'n_ABCD': defaultdict(int),
        'n_ABCD_correct': defaultdict(int),
        'n_ABCD_desired': defaultdict(int),
    }
    for position, example in enumerate(examples):
        correct = example['correct_answer']
        desired = example['desired_answer']
        if example['sandbagging_environment']:
            stats['n_total_sandbag'] += 1
            if correct == desired:
                stats['n_desired_correct'] += 1
            else:
                stats['n_desired_incorrect'] += 1
            stats['n_ABCD_correct'][correct] += 1
            stats['n_ABCD_desired'][desired] += 1
        else:
            # Raised explicitly (not via `assert`) so the check survives `python -O`.
            if correct != desired:
                raise AssertionError(
                    f"clean example {position}: correct_answer {correct!r} "
                    f"!= desired_answer {desired!r}"
                )
            stats['n_ABCD'][correct] += 1
            stats['n_total_clean'] += 1
    return stats


def proportions(counts, total):
    """Return ``{label: count / total}`` for every entry of ``counts``.

    An empty ``counts`` yields an empty dict whatever ``total`` is; a
    non-empty ``counts`` with ``total == 0`` raises ZeroDivisionError.
    """
    return {label: count / total for label, count in counts.items()}


def check_near(name, value, target, tolerance=0.05):
    """Raise AssertionError(f"{name}: {value}") unless |value - target| < tolerance."""
    # Written as `not (... < ...)` so that a NaN value also fails the check.
    if not abs(value - target) < tolerance:
        raise AssertionError(f"{name}: {value}")


def check_distribution(name, distribution, labels, target=0.25, tolerance=0.05):
    """Check that every label in ``labels`` has a proportion near ``target``.

    A label missing from ``distribution`` counts as proportion 0.0, so an
    answer option that never occurs fails the check instead of being skipped
    or surfacing as a KeyError.
    """
    for label in labels:
        check_near(name, distribution.get(label, 0.0), target, tolerance)


if __name__ == "__main__":
    # Load the dataset
    dataset = load_dataset("contextmodification/sandbagging-sciq")
    print(dataset)

    n_samples = 700

    # Collect a random sample of examples from the train, validation, and test sets
    train_dataset = sample_split(dataset['train'], n_samples)
    val_dataset = sample_split(dataset['validation'], n_samples)
    test_dataset = sample_split(dataset['test'], n_samples)
    combined_dataset = concatenate_datasets([train_dataset, val_dataset, test_dataset])
    # Columns seen in an earlier printout of the combined dataset:
    #     ['correct_answer', 'index', 'prompt', 'sandbagging_environment', 'desired_answer', 'prefix']
    # The row count is at most 3 * n_samples.

    n_total = len(combined_dataset)
    if n_total == 0:
        raise ValueError("combined dataset is empty; nothing to check")

    # Print a bunch of random examples
    for _ in range(10):
        example = combined_dataset[random.randrange(n_total)]
        print(example)
        environment_tag = 'sandbagging_environment:' if example['sandbagging_environment'] else 'clean_environment:'
        print(environment_tag, '\n\tcorrect_answer:', example['correct_answer'], '\n\tdesired_answer:', example['desired_answer'])
        print('prompt:', example['prompt'])
        print()

    # Single pass over the data; also enforces that 'correct_answer' is the
    # 'desired_answer' whenever sandbagging_environment is False.
    stats = tally_answers(combined_dataset)
    n_total_clean = stats['n_total_clean']
    n_total_sandbag = stats['n_total_sandbag']
    n_desired_correct = stats['n_desired_correct']
    n_desired_incorrect = stats['n_desired_incorrect']
    n_ABCD = stats['n_ABCD']
    n_ABCD_correct = stats['n_ABCD_correct']
    n_ABCD_desired = stats['n_ABCD_desired']

    # Every answer label observed anywhere (first-seen order). All three
    # distributions are checked against this same set so that a label absent
    # from one of them is reported rather than silently skipped.
    answer_labels = list(dict.fromkeys([*n_ABCD, *n_ABCD_correct, *n_ABCD_desired]))

    prop_clean = n_total_clean / n_total
    prop_ABCD = proportions(n_ABCD, n_total_clean)

    print('When sandbagging_environment is False:')
    print(f"\tprop_clean: {prop_clean}, n_total_clean: {n_total_clean}")
    print("\tAll desired answers are correct.")
    print(f"\tprop_ABCD: {prop_ABCD}")
    check_distribution("prop_ABCD", prop_ABCD, answer_labels)

    # Check that the 'correct_answer' is random from the 'desired_answer' when sandbagging_environment is True
    print('When sandbagging_environment is True:')
    prop_sandbag = n_total_sandbag / n_total
    if n_total_sandbag == 0:
        # Fail with the regular check message instead of a ZeroDivisionError below.
        raise AssertionError(f"prop_sandbag: {prop_sandbag}")
    prop_desired_correct = n_desired_correct / n_total_sandbag
    prop_ABCD_correct = proportions(n_ABCD_correct, n_total_sandbag)
    prop_ABCD_desired = proportions(n_ABCD_desired, n_total_sandbag)

    print(f"\tprop_sandbag: {prop_sandbag}, n_total_sandbag: {n_total_sandbag}")
    print(f"\tn_desired_correct: {n_desired_correct}, n_desired_incorrect: {n_desired_incorrect}")
    print(f"\tprop_desired_correct: {prop_desired_correct}")
    print(f"\tprop_ABCD_correct: {prop_ABCD_correct}")
    print(f"\tprop_ABCD_desired: {prop_ABCD_desired}")

    check_near("prop_sandbag", prop_sandbag, 0.5)
    check_near("prop_desired_correct", prop_desired_correct, 0.25)
    check_distribution("prop_ABCD_correct", prop_ABCD_correct, answer_labels)
    check_distribution("prop_ABCD_desired", prop_ABCD_desired, answer_labels)

    print("\nAll checks passed!")

    
# %%

