import os
import json
import argparse

def create_split(data: list, nsplits: int, method='round_robin'):
    """Partition *data* into ``nsplits`` slices.

    Args:
        data: A sliceable, sized sequence (list, tuple, range, ...) to partition.
        nsplits: Number of splits to produce; must be at least 1.
        method: ``'round_robin'`` puts item ``k`` into split ``k % nsplits``
            (strided slices ``data[i::nsplits]``). Any other value (the CLI
            uses ``'continuous'``) produces contiguous chunks whose sizes
            differ by at most one, with the first ``len(data) % nsplits``
            chunks holding one extra item.

    Returns:
        A list of exactly ``nsplits`` slices of *data* (some may be empty when
        ``len(data) < nsplits``). Every item of *data* appears in exactly one
        slice.

    Raises:
        ValueError: If ``nsplits`` is less than 1 (previously this caused a
            ZeroDivisionError or silently produced no splits).
    """
    if nsplits < 1:
        raise ValueError(f'nsplits must be a positive integer, got {nsplits}')

    if method == 'round_robin':
        return [data[i::nsplits] for i in range(nsplits)]

    # Contiguous chunks: the first `remainder` chunks absorb one extra item each,
    # so chunk sizes never differ by more than one.
    split_size, remainder = divmod(len(data), nsplits)
    splits = []
    start = 0
    for i in range(nsplits):
        end = start + split_size + (1 if i < remainder else 0)
        splits.append(data[start:end])
        start = end
    return splits

# File extensions this script knows how to read and write.
_SUPPORTED_EXTENSIONS = ('.json', '.jsonl', '.tsv')


def _load_data(path: str, ext: str):
    """Read the input file at *path* according to its extension *ext*.

    Returns a list of records for ``.json`` / ``.jsonl`` and a pandas
    DataFrame for ``.tsv``.
    """
    if ext == '.json':
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    if ext == '.jsonl':
        with open(path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline), which json.loads rejects.
            return [json.loads(line) for line in f if line.strip()]
    # pandas is only needed for .tsv input, so keep it an optional dependency.
    import pandas as pd
    return pd.read_table(path)


def _split_data(data, ext: str, nsplits: int, method: str) -> list:
    """Split loaded *data* into ``nsplits`` parts using ``create_split``."""
    if ext == '.tsv':
        # Split row *positions* and select with iloc so slicing is always
        # positional, independent of the DataFrame's index labels.
        positions = create_split(range(len(data)), nsplits, method)
        return [data.iloc[list(rows)] for rows in positions]
    return create_split(data, nsplits, method)


def _write_split(split, filename: str, ext: str) -> None:
    """Write one *split* to *filename* in the same format as the input."""
    if ext == '.tsv':
        split.to_csv(filename, sep='\t', index=False)
        return
    with open(filename, 'w', encoding='utf-8') as f:
        if ext == '.json':
            json.dump(split, f, indent=2)
        else:  # .jsonl: one JSON document per line
            for record in split:
                f.write(json.dumps(record) + '\n')


def main(argv=None) -> None:
    """Split ``--data_path`` into ``--nsplits`` files.

    The files are written to a directory named after the input path with its
    extension removed, e.g. ``data/train.jsonl`` -> ``data/train/0_of_8.jsonl``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Split a .json, .jsonl or .tsv dataset into several files.",
    )

    parser.add_argument(
        "--data_path",
        type=str,
        required=True,
        help="json, jsonl or tsv file with image paths and annotations",
    )

    parser.add_argument(
        "--nsplits",
        type=int,
        default=8,
        help="Number of splits to make",
    )

    parser.add_argument(
        "--method",
        type=str,
        # None means "format default" so existing invocations keep their output:
        # round_robin for .json/.jsonl, continuous for .tsv.
        default=None,
        choices=['round_robin', 'continuous'],
        help="Method to use for splitting "
             "(default: round_robin for .json/.jsonl, continuous for .tsv)",
    )

    args = parser.parse_args(argv)

    if args.nsplits < 1:
        parser.error('--nsplits must be a positive integer')

    # splitext strips only the trailing extension; str.replace would also
    # rewrite any earlier occurrence of it elsewhere in the path.
    out_dir, ext = os.path.splitext(args.data_path)
    if ext not in _SUPPORTED_EXTENSIONS:
        parser.error(
            f'unsupported file extension {ext!r}; expected one of '
            f'{", ".join(_SUPPORTED_EXTENSIONS)}'
        )

    method = args.method or ('continuous' if ext == '.tsv' else 'round_robin')

    data = _load_data(args.data_path, ext)
    if ext == '.json' and not isinstance(data, list):
        parser.error(f'{args.data_path} must contain a top-level JSON array')

    splits = _split_data(data, ext, args.nsplits, method)

    os.makedirs(out_dir, exist_ok=True)
    for i, split in enumerate(splits):
        filename = os.path.join(out_dir, f'{i}_of_{args.nsplits}{ext}')
        _write_split(split, filename, ext)


if __name__ == '__main__':
    main()

