import click
import logging
import json
from collections import namedtuple

# Module-level logger named after the module, per the logging documentation's
# convention. Logger names are dot-separated hierarchies, so a file path such
# as "convert.py" would be split into a spurious parent/child logger pair.
logger = logging.getLogger(__name__)
from datasets import load_dataset


def _format_prompt(example):
    """Return the human-side prompt text for one dataset example.

    The prompt is the example's ``"instruction"``; when the example's
    ``"input"`` is not blank, it is appended after an ``"\\nInput:\\n"``
    separator.
    """
    instruction = example["instruction"]
    # Any non-blank input is kept, including a single character such as "5";
    # only empty or whitespace-only input is treated as absent.
    if example["input"].strip():
        return instruction + "\nInput:\n" + example["input"]
    return instruction


def run(args):
    """Convert the ``tatsu-lab/alpaca`` train split to conversation JSON.

    Each example becomes a dict with:

    * ``"id"``: ``"alpaca_<index>"``, the example's position in the split;
    * ``"origin"``: the untouched source example;
    * ``"conversations"``: a two-turn list, the ``"human"`` prompt followed
      by the ``"gpt"`` answer taken from the example's ``"output"``.

    Args:
        args: Object exposing an ``output`` attribute, the path of the JSON
            file to write.

    Raises:
        OSError: If ``args.output`` cannot be opened for writing.
    """
    dataset = load_dataset("tatsu-lab/alpaca")["train"]
    print(dataset)

    new_content = [
        {
            "id": f"alpaca_{i}",
            "origin": c,
            "conversations": [
                {"from": "human", "value": _format_prompt(c)},
                {"from": "gpt", "value": c["output"]},
            ],
        }
        for i, c in enumerate(dataset)
    ]

    print(f"#out: {len(new_content)}")
    # ensure_ascii=False emits raw non-ASCII characters, so the encoding is
    # pinned to UTF-8 instead of relying on the platform default; the context
    # manager guarantees the file is flushed and closed.
    with open(args.output, "w", encoding="utf-8") as out_file:
        json.dump(new_content, out_file, indent=1, ensure_ascii=False)


@click.command()
@click.option(
    '-o', '--output',
    required=True,
    help='Path of the JSON file to write the converted conversations to.',
)
def main(**kwargs):
    """Command-line entry point: parse options and run the conversion.

    The ``--output`` option is required so that a missing path is reported
    as a usage error up front, rather than failing when the file is opened
    after the dataset has already been loaded and converted.

    Args:
        **kwargs: Parsed click options (currently only ``output``).
    """
    # Wrap the option dict in a namedtuple so run() can use attribute access
    # (args.output).
    Arg = namedtuple('Arg', kwargs.keys())
    args = Arg(**kwargs)
    logger.info(kwargs)
    run(args)


if __name__ == '__main__':
    # Append INFO-and-above records to "<this file>.log" next to the script.
    logging.basicConfig(
        level=logging.INFO,
        filename=__file__ + '.log',
        filemode='a',
        datefmt='%d-%m-%Y %H:%M:%S',
        format='[%(levelname)-5.5s][%(asctime)s][%(filename)s %(lineno)d]: %(message)s',
    )
    # Also send records to the console; this handler has no formatter of its
    # own, so console lines use logging's default message-only layout.
    root_logger = logging.getLogger()
    root_logger.addHandler(logging.StreamHandler())
    main()