import json
import random


def _load_json(path):
    """Return the parsed contents of the UTF-8 encoded JSON file at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def main(
    *,
    dataset_path="full-agenttrek-dataset-pruned.json",
    indices_path="full_selected_dataset_indices_T0_3_agenttrek.json",
    output_path="agenttrek_dataset_pruned_centered_w_threshold_sampled_10k_T0_3_agenttrek.json",
    pool_size=100000,
    max_content_chars=40000,
    final_size=10000,
    seed=None,
):
    """Sample a size-capped, length-filtered subset of the AgentTrek dataset.

    Pipeline:
      1. Load the dataset (indexed by integer position) and the selected
         index groups (iterated as a list of lists of positions).
      2. Flatten the groups and randomly sample up to ``pool_size`` indices.
      3. Drop items whose ``item["messages"][1]["content"]`` is longer than
         ``max_content_chars``.
      4. Randomly sample up to ``final_size`` of the survivors and write them
         to ``output_path`` as UTF-8 JSON indented by 4 spaces.

    The item count after each of steps 2, 3 and 4 is printed to stdout.

    All parameters are keyword-only and default to the original hard-coded
    values, so calling ``main()`` with no arguments behaves as before.

    Args:
        dataset_path: JSON file holding the full dataset.
        indices_path: JSON file holding the selected index groups.
        output_path: Destination JSON file for the final sample.
        pool_size: Maximum number of indices drawn in step 2.
        max_content_chars: Items whose second message content is longer
            than this are excluded (items of exactly this length are kept).
        final_size: Maximum number of items written in step 4.
        seed: If not ``None``, seeds a private ``random.Random`` for
            reproducible output; ``None`` uses the module-level ``random``
            generator, as the original script did.
    """
    rng = random if seed is None else random.Random(seed)

    original_data = _load_json(dataset_path)
    selected_dataset_indices = _load_json(indices_path)

    # Flatten the list of lists so individual indices can be sampled.
    # NOTE(review): an index listed in more than one group can be drawn more
    # than once, yielding duplicate items — confirm whether that is intended.
    flat_indices = [idx for group in selected_dataset_indices for idx in group]

    # Draw at most pool_size indices; min() stops random.sample from raising
    # ValueError when fewer indices are available.
    sampled_indices = rng.sample(flat_indices, min(pool_size, len(flat_indices)))
    sampled_data = [original_data[i] for i in sampled_indices]
    print(len(sampled_data))

    # Keep only items whose second message is at most max_content_chars long.
    refined = [
        item
        for item in sampled_data
        if len(item["messages"][1]["content"]) <= max_content_chars
    ]
    print(len(refined))

    # The length filter can leave fewer than final_size items; cap the sample
    # size so random.sample does not raise ValueError in that case.
    refined_sampled = rng.sample(refined, min(final_size, len(refined)))

    # encoding="utf-8" is required alongside ensure_ascii=False: otherwise
    # non-ASCII text fails to encode where the platform default is not UTF-8.
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(refined_sampled, f, ensure_ascii=False, indent=4)

    print(len(refined_sampled))


if __name__ == "__main__":
    # Run the sampling pipeline only when executed as a script, not on import.
    main()
