import json
import logging
from argparse import ArgumentParser
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split

from experiments.utils import common_parser, set_logger


def generate_splits(
    data_path,
    save_path,
    name="cnn_generalization_splits.json",
    val_size=10000,
    test_size=10000,
    seed=42
):
    """Split checkpoint paths and their scores into train/val/test and save as JSON.

    Checkpoint entries are collected from ``<data_path>/*/*/*/checkpoints`` and
    scores from the ``test_acc`` column of every ``<data_path>/*/*/progress.csv``.
    The i-th checkpoint (in sorted path order) is paired with the i-th score
    (progress files read in sorted path order, rows in file order).

    The written JSON has the layout
    ``{"train"|"val"|"test": {"path": [...], "score": [...]}}``.

    Args:
        data_path: Root directory that is searched for checkpoints and
            progress files.
        save_path: Directory the JSON file is written to; created if missing.
        name: File name of the JSON file inside ``save_path``.
        val_size: Forwarded as ``test_size`` to the second ``train_test_split``
            call, which carves the validation set out of the non-test indices.
        test_size: Forwarded as ``test_size`` to the first ``train_test_split``
            call, which carves out the test set.
        seed: ``random_state`` used for both splits.

    Raises:
        ValueError: If no progress file is found, or if the number of
            checkpoints differs from the number of scores (the positional
            pairing would otherwise be silently wrong).
    """
    data_path = Path(data_path)
    # Both listings are sorted so that checkpoints and scores line up by
    # position. Shuffling happens later, inside train_test_split.
    checkpoints = sorted(data_path.glob('*/*/*/checkpoints'))
    progress_files = sorted(data_path.glob('*/*/progress.csv'))
    if not progress_files:
        raise ValueError(f"no progress.csv files found under {data_path.as_posix()}")
    progresses = np.concatenate(
        [pd.read_csv(f)['test_acc'].to_numpy() for f in progress_files]
    )

    # Paths and scores are matched purely by index, so a length mismatch means
    # the pairing cannot be trusted. Fail loudly instead of mislabeling data.
    num_examples = progresses.shape[0]
    if len(checkpoints) != num_examples:
        raise ValueError(
            f"found {len(checkpoints)} checkpoints but {num_examples} scores "
            f"under {data_path.as_posix()}; cannot align them by position"
        )

    trainval_indices, test_indices = train_test_split(
        range(num_examples), test_size=test_size, random_state=seed
    )
    train_indices, val_indices = train_test_split(
        trainval_indices, test_size=val_size, random_state=seed
    )

    split_indices = (
        ("train", train_indices),
        ("val", val_indices),
        ("test", test_indices),
    )
    data_split = {
        split: {
            "path": [checkpoints[idx].as_posix() for idx in indices],
            # Plain floats keep the result JSON-serializable regardless of the
            # numpy dtype pandas inferred for the column.
            "score": [float(progresses[idx]) for idx in indices],
        }
        for split, indices in split_indices
    }

    logging.info(
        f"train size: {len(data_split['train']['path'])}, "
        f"val size: {len(data_split['val']['path'])}, "
        f"test size: {len(data_split['test']['path'])}"
    )

    save_dir = Path(save_path)
    # Make sure the target directory exists before opening the file.
    save_dir.mkdir(parents=True, exist_ok=True)
    with open(save_dir / name, "w") as file:
        json.dump(data_split, file)


if __name__ == "__main__":
    # Command-line entry point: extend the shared experiment parser with the
    # options specific to split generation.
    cli = ArgumentParser(
        "CNN Generalization - generate data splits",
        parents=[common_parser],
    )
    cli.add_argument(
        "--name",
        type=str,
        default="cnn_generalization_splits.json",
        help="json file name",
    )
    cli.add_argument(
        "--val-size",
        type=int,
        default=10000,
        help="number of validation examples",
    )
    cli.add_argument(
        "--test-size",
        type=int,
        default=10000,
        help="number of test examples",
    )
    # Script-specific defaults for the input/output locations (the options
    # themselves are presumably declared by common_parser -- verify there).
    cli.set_defaults(data_path="raw_dataset", save_path="dataset")
    options = cli.parse_args()

    set_logger()

    # The seed is not forwarded, so generate_splits uses its own default.
    generate_splits(
        data_path=options.data_path,
        save_path=options.save_path,
        name=options.name,
        val_size=options.val_size,
        test_size=options.test_size,
    )
