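"""Split one of the MTD datasets (a JSONL file) into train / valid / test files.

Invoked as ``python <script> INPUT.jsonl SPEC``, where SPEC is either a
comma-separated list of 2 or 3 split sizes, or a directory holding existing
``*_train.jsonl`` / ``*_valid.jsonl`` / ``*_test.jsonl`` splits whose
``author_id`` values are reused. Outputs are written under ``./data/<name>/``.
"""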

import os
import sys
import pandas as pd

# Output split names, keyed by how many splits were produced.
SPLIT_NAMES = {2: ("train", "test"), 3: ("train", "valid", "test")}


def parse_split_spec(spec):
    """Interpret the second CLI argument.

    Args:
        spec: Either a comma-separated list of 2 or 3 non-negative integer
            split sizes (e.g. ``"800,100,100"``), or a path to a directory
            holding ``*_train.jsonl`` and ``*_test.jsonl`` files and,
            optionally, a ``*_valid.jsonl`` file.

    Returns:
        ``(split_sizes, reference_files)``. Exactly one of the two is ``None``.
        ``reference_files`` is ordered train, [valid,] test.

    Raises:
        ValueError: if ``spec`` is neither a valid size list nor a directory
            containing at least the train and test reference files.
    """
    if "," in spec:
        try:
            split_sizes = [int(num) for num in spec.split(",")]
        except ValueError as err:
            raise ValueError(f"Split sizes must be integers, got {spec!r}") from err
        if len(split_sizes) not in SPLIT_NAMES:
            raise ValueError(f"Expected 2 or 3 split sizes, got {len(split_sizes)}")
        if any(size < 0 for size in split_sizes):
            raise ValueError(f"Split sizes must be non-negative, got {split_sizes}")
        return split_sizes, None

    if not os.path.isdir(spec):
        raise ValueError(
            f"{spec!r} is neither a comma-separated list of split sizes "
            "nor a directory of reference splits"
        )

    # Sorted so that the choice among several matching files is deterministic.
    filenames = sorted(os.listdir(spec))

    def find(suffix):
        # First matching file in the directory, or None if there is none.
        matches = [os.path.join(spec, f) for f in filenames if f.endswith(suffix)]
        return matches[0] if matches else None

    train_file = find("_train.jsonl")
    valid_file = find("_valid.jsonl")
    test_file = find("_test.jsonl")
    if train_file is None or test_file is None:
        raise ValueError(f"{spec!r} must contain *_train.jsonl and *_test.jsonl files")
    # The valid split is optional; without it we produce a train / test split.
    if valid_file is None:
        return None, [train_file, test_file]
    return None, [train_file, valid_file, test_file]


def _take_consecutive(frame, sizes):
    """Slice ``frame`` positionally into consecutive chunks of the given sizes."""
    chunks = []
    start = 0
    for size in sizes:
        chunks.append(frame.iloc[start:start + size])
        start += size
    return chunks


def split_dataframe(df, split_sizes=None, reference_files=None, seed=None):
    """Split ``df`` into a list of 2 or 3 DataFrames.

    - No ``author_id`` column: rows are shuffled (``seed`` is passed as
      ``random_state``; ``None`` keeps the original unseeded behaviour) and
      cut into consecutive chunks of ``split_sizes`` rows.
    - ``author_id`` column and ``split_sizes``: rows are grouped per author
      (authors in ``groupby``'s sorted order, not shuffled). ``split_sizes``
      counts authors and must sum to the number of authors.
    - ``author_id`` column and ``reference_files``: each split keeps the rows
      whose ``author_id`` appears in the matching reference JSONL file.

    Raises:
        ValueError: on missing or inconsistent split specifications.
    """
    if "author_id" not in df:
        if split_sizes is None:
            raise ValueError("If author_id is not present, must provide sizes...")
        if sum(split_sizes) > len(df):
            raise ValueError(
                f"Split sizes sum to {sum(split_sizes)} but only {len(df)} rows exist"
            )
        return _take_consecutive(df.sample(frac=1.0, random_state=seed), split_sizes)

    if split_sizes is not None:
        to_explode = [col for col in df.columns if col != "author_id"]
        # One row per author, with every other column collected into a list.
        by_author = df.groupby("author_id").agg(list).reset_index()
        if len(by_author) != sum(split_sizes):
            raise ValueError(
                f"Split sizes must sum to the number of authors ({len(by_author)}), "
                f"got {sum(split_sizes)}"
            )
        split_dfs = []
        for chunk in _take_consecutive(by_author, split_sizes):
            # explode([]) raises, so skip it when author_id is the only column.
            if to_explode:
                chunk = chunk.explode(to_explode)
            split_dfs.append(chunk.reset_index(drop=True))
        return split_dfs

    if reference_files is None:
        raise ValueError("Provide either split_sizes or reference_files")
    return [
        df[df["author_id"].isin(pd.read_json(path, lines=True)["author_id"].values)]
        for path in reference_files
    ]


def write_splits(split_dfs, filepath, out_root="./data"):
    """Write each split to ``<out_root>/<stem>/<stem>_<split>.jsonl``.

    ``stem`` is the input basename without its ``.jsonl`` suffix (or without
    its last extension otherwise), so distinct splits never share a path.
    Returns the list of written paths, in split order.
    """
    names = SPLIT_NAMES.get(len(split_dfs))
    if names is None:
        raise ValueError(f"Expected 2 or 3 splits, got {len(split_dfs)}")
    bname = os.path.basename(filepath)
    if bname.endswith(".jsonl"):
        stem = bname[: -len(".jsonl")]
    else:
        stem = os.path.splitext(bname)[0]
    out_dir = os.path.join(out_root, stem)
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for name, split_df in zip(names, split_dfs):
        path = os.path.join(out_dir, f"{stem}_{name}.jsonl")
        split_df.to_json(path, lines=True, orient="records")
        paths.append(path)
    return paths


def main(argv=None):
    """CLI entry point: ``<script> INPUT.jsonl (SIZES | REFERENCE_DIR)``."""
    argv = sys.argv[1:] if argv is None else argv
    if len(argv) < 2:
        raise SystemExit(
            f"usage: {os.path.basename(sys.argv[0])} INPUT.jsonl (SIZES | REFERENCE_DIR)"
        )
    filepath, spec = argv[0], argv[1]
    split_sizes, reference_files = parse_split_spec(spec)
    df = pd.read_json(filepath, lines=True)
    split_dfs = split_dataframe(df, split_sizes, reference_files)
    write_splits(split_dfs, filepath)


if __name__ == "__main__":
    main()