from pathlib import Path
import argparse
import random

"""
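Load each language's dataset from --data-dir, subsample every split into a
shuffled, class-balanced set of positive and negative examples, and save the
result to --output-dir.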
"""

# Default RNG seed. The global `random` generator is seeded once at import
# time; subsample_data's shuffle draws from that generator.
seed = 42
random.seed(a=seed)


def load_data(path):
    """Read the text file at ``path`` and return its lines.

    Each returned line keeps its trailing newline, except possibly the last
    line if the file does not end with one.
    """
    with open(path, "r") as handle:
        lines = list(handle)
    return lines


def subsample_data(data, label, num_samples):
    """Return a shuffled, class-balanced subsample of parallel data/label lines.

    Keeps the first ``num_samples // 2`` examples whose stripped label is "1"
    and the first ``num_samples // 2`` whose stripped label is "0". These are
    taken in input order, not at random. It then shuffles them together using
    the global ``random`` generator. Examples with any other label are ignored.
    An odd ``num_samples`` therefore yields ``num_samples - 1`` examples.

    Args:
        data: Example lines, aligned one-to-one with ``label``.
        label: Label lines. Surrounding whitespace (e.g. the newline) is
            ignored when classifying but preserved in the output.
        num_samples: Requested total number of examples; must be >= 0.

    Returns:
        A ``(data, label)`` tuple of two equally long, still-aligned lists.

    Raises:
        ValueError: If ``num_samples`` is negative, if ``data`` and ``label``
            have different lengths, or if either class has fewer than
            ``num_samples // 2`` examples.
    """
    if num_samples < 0:
        # A negative half would make the slices below drop items from the end.
        raise ValueError(f"num_samples must be non-negative, got {num_samples}.")
    data, label = list(data), list(label)
    # zip() would silently drop the unmatched tail. A length mismatch means the
    # data and label files are out of sync, so fail loudly instead.
    if len(data) != len(label):
        raise ValueError(
            f"data and label must have the same length, "
            f"but got {len(data)} data lines and {len(label)} labels."
        )

    # Separate positive and negative examples, preserving input order.
    positive, negative = [], []
    for pair in zip(data, label):
        cls = pair[1].strip()
        if cls == "1":
            positive.append(pair)
        elif cls == "0":
            negative.append(pair)

    half = num_samples // 2
    if len(positive) < half or len(negative) < half:
        raise ValueError(
            f"Not enough examples: need at least {half} positive and {half} negative examples, "
            f"but got {len(positive)} positive and {len(negative)} negative."
        )

    subsampled = positive[:half] + negative[:half]
    if not subsampled:
        # zip(*[]) yields nothing, so unpacking it into two names would raise.
        return [], []
    random.shuffle(subsampled)
    # Unzip the paired list back into data and label.
    subsampled_data, subsampled_label = zip(*subsampled)
    return list(subsampled_data), list(subsampled_label)


def save_data(data, path):
    """Write each item of ``data`` to ``path`` as its own line.

    Items are written unchanged, except that a newline is appended to any item
    that does not already end with one. Lines produced by ``readlines()`` keep
    their newline, but the last line of a file may not have one. After
    shuffling, such a line can land anywhere in the list and would otherwise be
    glued to the line that follows it.
    """
    with open(path, "w") as f:
        f.writelines(
            line if line.endswith("\n") else line + "\n" for line in data
        )


# Example usage:
# python -m scripts.sample_dataset --data-dir data_old/flare --output-dir data/flare_subsampled
def main():
    """Subsample every language's splits and write them under --output-dir.

    Each immediate subdirectory of --data-dir is treated as one language and
    must contain:

    - ``main.tok`` / ``labels.txt`` (train split)
    - ``datasets/validation-short/{main.tok,labels.txt}``
    - ``datasets/validation-long/{main.tok,labels.txt}``
    - ``datasets/test/{main.tok,labels.txt}``

    The output for each language goes to ``<output-dir>/<language>/``. It holds
    ``data.<split>`` and ``labels.<split>`` files, where ``<split>`` is one of
    train, val-short, val-long or test.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", type=Path, required=True)
    parser.add_argument("--output-dir", type=Path, required=True)
    parser.add_argument("--num-samples-train", type=int, default=100)
    parser.add_argument("--num-samples-val-short", type=int, default=100)
    parser.add_argument("--num-samples-val-long", type=int, default=100)
    parser.add_argument("--num-samples-test", type=int, default=100)
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()

    # Previously --seed was parsed but never used; apply it to the global RNG
    # that subsample_data shuffles with.
    random.seed(args.seed)

    # create output dir
    args.output_dir.mkdir(parents=True, exist_ok=True)

    # All languages draw from one shared RNG, so process them in a fixed
    # (sorted) order. This keeps the shuffles independent of the filesystem's
    # listing order.
    language_dirs = sorted(d for d in args.data_dir.glob("*") if d.is_dir())
    for language_dir in language_dirs:
        datasets_dir = language_dir / "datasets"
        # (output suffix, directory holding main.tok/labels.txt, sample count)
        splits = (
            ("train", language_dir, args.num_samples_train),
            (
                "val-short",
                datasets_dir / "validation-short",
                args.num_samples_val_short,
            ),
            (
                "val-long",
                datasets_dir / "validation-long",
                args.num_samples_val_long,
            ),
            ("test", datasets_dir / "test", args.num_samples_test),
        )

        # Subsample every split before writing anything. A split with too few
        # examples then raises without leaving partial output for this language.
        subsampled = []
        for suffix, split_dir, num_samples in splits:
            data = load_data(split_dir / "main.tok")
            label = load_data(split_dir / "labels.txt")
            data, label = subsample_data(data, label, num_samples)
            subsampled.append((suffix, data, label))

        # save
        output_dir = args.output_dir / language_dir.name
        output_dir.mkdir(parents=True, exist_ok=True)
        for suffix, data, label in subsampled:
            save_data(data, output_dir / f"data.{suffix}")
            save_data(label, output_dir / f"labels.{suffix}")


if __name__ == "__main__":
    # Run the CLI only when executed directly (e.g. `python -m ...`), not on import.
    main()
