import logging
import pathlib
import typing as ty
from argparse import ArgumentParser
from pathlib import Path

import numpy as np

import egr.util as util
from egr.log import init_logging

LOG = logging.getLogger('create_index')


def write(indices, train_size, test_size, path: Path):
    """Split ``indices`` into train/val/test partitions and save them as JSON.

    The first ``train_size`` entries form the ``train`` split, the last
    ``test_size`` entries form the ``test`` split, and everything between
    them forms the ``val`` split.

    Args:
        indices: Sliceable sequence of sample indices whose slices provide
            ``.tolist()`` (the caller in this file passes a NumPy array).
        train_size: Number of leading entries assigned to ``train``.
        test_size: Number of trailing entries assigned to ``test``.
        path: Destination passed to ``util.save_json``.
    """
    # Use an explicit boundary instead of negative slicing: with
    # test_size == 0, ``indices[-0:]`` is the whole array and
    # ``indices[train_size:-0]`` is empty, which would swap val and test.
    test_start = max(len(indices) - test_size, 0)
    data = {
        'train': indices[:train_size].tolist(),
        'val': indices[train_size:test_start].tolist(),
        'test': indices[test_start:].tolist(),
    }
    LOG.info('Saving to %s', path)
    util.save_json(data, path)


def make_balanced(args):
    """Write per-fold train/val/test index files stratified by class label.

    Labels are read from ``args.labels_file`` as comma-separated integers;
    the position of a label in the file is the sample index. Indices are
    grouped per label and shuffled (unseeded, so output differs between
    runs). For each fold every class is split with the same
    ``args.train_fraction``; the remainder is halved into val and test,
    with val receiving the extra element when the remainder is odd.

    One ``NN.json`` file per fold is written to an ``indices`` directory
    next to the labels file.

    Args:
        args: Namespace providing ``labels_file`` (a ``pathlib.Path``),
            ``train_fraction`` (float) and ``folds``. ``folds`` may be an
            ``int`` (folds are numbered ``1..folds``) or an object with
            ``begin`` and ``end`` attributes (inclusive range).
    """
    LOG.info('labels file path: %s', args.labels_file)
    # read_text() closes the file handle; blank tokens (e.g. after a
    # trailing comma) are skipped instead of crashing int().
    labels: ty.List[int] = [
        int(token)
        for token in args.labels_file.read_text().split(',')
        if token.strip()
    ]
    LOG.info('labels: %s', labels)
    LOG.info('type: %s', type(labels))

    from collections import Counter

    counter = Counter(labels)
    LOG.info('counter: %s', counter)

    # Map each label to the sample indices carrying it.
    indices: ty.Dict[int, ty.List[int]] = {
        item: [] for item in sorted(set(labels))
    }
    for index, label in enumerate(labels):
        indices[label].append(index)
    LOG.info('indices: %s', indices)

    rng = np.random.default_rng()
    for label_key in indices:
        rng.shuffle(indices[label_key])

    # ``--folds`` is parsed as an int; an object exposing begin/end is
    # still accepted for callers that pass an explicit fold range.
    folds = args.folds
    if isinstance(folds, int):
        first_fold, last_fold = 1, folds
    else:
        first_fold, last_fold = folds.begin, folds.end

    index_dir = args.labels_file.parent / 'indices'
    LOG.info('Creating index directory: %s', index_dir)
    index_dir.mkdir(parents=True, exist_ok=True)
    for fold in range(first_fold, last_fold + 1):
        train_indices: ty.List[int] = []
        val_indices: ty.List[int] = []
        test_indices: ty.List[int] = []
        for index_list in indices.values():
            train_size: int = int(len(index_list) * args.train_fraction)
            test_size: int = (len(index_list) - train_size) // 2
            # Explicit boundary: negative slicing breaks when
            # test_size == 0 (``x[-0:]`` is the whole sequence).
            test_start: int = len(index_list) - test_size
            # Rotate by a whole number of test windows so each fold slices
            # a different part of the shuffled list; tolist() yields plain
            # Python ints, which the JSON encoder can serialise.
            rotated = np.roll(index_list, (fold - 1) * test_size).tolist()

            train_indices.extend(rotated[:train_size])
            val_indices.extend(rotated[train_size:test_start])
            test_indices.extend(rotated[test_start:])

        data = dict(
            train=sorted(train_indices),
            val=sorted(val_indices),
            test=sorted(test_indices),
        )

        file_path: pathlib.Path = index_dir / f'{fold:02d}.json'
        LOG.info('Saving index file: %s', file_path)
        util.save_json(data, file_path)


def main(args):
    """Build index files: class-balanced if a labels file is set, else plain.

    Args:
        args: Parsed command-line namespace; a truthy ``labels_file``
            selects ``make_balanced``, otherwise ``make_unbalanced`` runs.
    """
    builder = make_balanced if args.labels_file else make_unbalanced
    builder(args)


def make_unbalanced(args):
    """Write per-fold train/val/test index files without class balancing.

    The indices ``0..args.count - 1`` are shuffled once (unseeded, so output
    differs between runs). Before each fold is written the array is rotated
    by one further test-window, and ``write`` slices the rotated array into
    train/val/test. Files are saved as ``<output_root>/<count>/NN.json``.

    Args:
        args: Namespace providing ``count`` (int), ``train_fraction``
            (float), ``output_root`` (``Path``) and optionally ``folds``.
            ``folds`` may be an ``int`` (folds numbered ``1..folds``) or an
            object with ``begin``/``end`` attributes (inclusive range);
            when absent, 10 folds are produced.

    Raises:
        ValueError: If ``args.count`` is ``None``.
    """
    if args.count is None:
        # Fail with a clear message instead of a TypeError from np.arange.
        raise ValueError('--count is required when no labels file is given')

    rng = np.random.default_rng()
    indices = np.arange(args.count)
    train_size: int = int(indices.shape[0] * args.train_fraction)
    test_size: int = (indices.shape[0] - train_size) // 2

    rng.shuffle(indices)
    LOG.info('num indices: %s', indices.shape[0])
    output_dir = args.output_root / f'{args.count}'
    output_dir.mkdir(parents=True, exist_ok=True)
    LOG.info('args=%s', args)

    # Honour the requested fold count rather than a hard-coded 10; the
    # getattr default keeps the previous behaviour for namespaces that
    # carry no ``folds`` attribute.
    folds = getattr(args, 'folds', 10)
    if isinstance(folds, int):
        first_fold, last_fold = 1, folds
    else:
        first_fold, last_fold = folds.begin, folds.end

    for i in range(first_fold, last_fold + 1):
        # The rotation is cumulative: each fold shifts one more test window.
        indices = np.roll(indices, test_size)
        output_path: Path = output_dir / f'{i:02d}.json'
        write(indices, train_size, test_size, output_path)


# Command-line entry point: parse options, configure logging, build indices.
if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument(
        '-l',
        '--log-level',
        type=str,
        default='info',
        choices=['debug', 'info', 'warning', 'error', 'critical'],
    )
    parser.add_argument(
        '--count',
        type=int,
        help='number of samples to index when no label file is given',
    )
    parser.add_argument('--train-fraction', type=float, default=0.8)
    parser.add_argument('--output-root', type=Path, required=True)
    parser.add_argument('--folds', type=int, default=10)
    # main() and make_balanced() read ``args.labels_file``; without an
    # explicit dest, argparse would store this option as ``label_file`` and
    # every run would fail with AttributeError. The original flag spelling
    # is kept and ``--labels-file`` is accepted as an alias.
    parser.add_argument(
        '--label-file',
        '--labels-file',
        dest='labels_file',
        type=pathlib.Path,
        help='file of comma-separated integer labels; '
        'enables class-balanced splits',
    )

    args = parser.parse_args()
    if args.labels_file is None and args.count is None:
        # Neither input source was given: report it as a usage error
        # instead of failing later inside NumPy.
        parser.error('one of --count or --label-file is required')
    init_logging(args.log_level)
    main(args)
