import argparse
import csv
import itertools
from xml.etree.ElementTree import Element, SubElement, tostring
from xml.dom import minidom

from datasets import load_dataset

from curriculum.csv_to_lesson import prettify


def create_xml(dataset: str, max_item: int) -> str:
    """Build a ``<lessons>`` XML document from a SQuADShifts test split.

    Each of the first ``max_item`` rows of the ``test`` split of the
    ``squadshifts`` configuration named ``dataset`` becomes one ``<lesson>``
    element containing the row's context as ``<material>`` and its question
    as ``<exercise>``.

    Args:
        dataset: Name of the SQuADShifts configuration to load (e.g. ``"nyt"``).
        max_item: Maximum number of rows to convert. Values <= 0 produce an
            empty ``<lessons>`` element.

    Returns:
        The XML document as a string. It is pretty-printed via ``prettify``
        when that succeeds, otherwise it is the compact ``tostring`` output.
    """
    # Root element
    lessons = Element('lessons')

    hf_dataset = load_dataset("squadshifts", dataset, trust_remote_code=True)["test"]
    # islice stops after max_item rows without fetching an extra one.
    # Clamp to 0 because islice raises ValueError on a negative stop.
    for i, item in enumerate(itertools.islice(hf_dataset, max(max_item, 0))):
        # NOTE: the id embeds the *requested* max_item, not the number of
        # rows actually written (the split may be shorter).
        lesson = SubElement(lessons, 'lesson', id=f'{dataset}_default_{max_item}_test_{i}')

        # Material description: the passage the question is asked about.
        material = SubElement(lesson, 'material')
        material.text = str(item['context'])
        ex_element = SubElement(lesson, 'exercise')
        ex_element.text = item['question']

    # Return formatted XML.
    try:
        return prettify(lessons)
    except Exception:
        # Best-effort: if pretty-printing fails for any reason, still return
        # the compact serialization rather than losing the output.
        return tostring(lessons, 'unicode')

def main(
        dataset: str = "nyt",
        max_item: int = 1000,
    ) -> None:
    """Convert a SQuADShifts test split into an exam XML file.

    The file is written to ``curriculum/exam_{dataset}_default_{max_item}.xml``
    relative to the current working directory.

    Args:
        dataset: SQuADShifts configuration to load (e.g. ``"nyt"``).
        max_item: Maximum number of dataset rows to include.
    """
    output_path = f"curriculum/exam_{dataset}_default_{max_item}.xml"

    print(f"Processing {dataset}")
    xml_output = create_xml(dataset, max_item)
    # Explicit UTF-8: the XML holds arbitrary dataset text, which the
    # platform default encoding (e.g. cp1252 on Windows) may fail to encode.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(xml_output)
    print(f"XML written to {output_path}")


if __name__ == "__main__":
    # Turn main()'s parameters into command-line options. Importing here keeps
    # jsonargparse out of the module's import-time dependencies.
    import jsonargparse

    jsonargparse.CLI(main)
