import argparse
import json
import os
from typing import Optional, Sequence

from src.generators.wsmc_generator import WeightedSetMultiCoverGenerator

# Directory passed to the generator and to save_params as the output location.
# It is a relative path, so it resolves against the current working directory
# at run time, not against this script's location.
PATH = os.path.join("data", "wsmc")


def parse_arguments(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for WSMC data generation.

    Args:
        argv: Argument strings to parse. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``, exactly as before. Passing an
            explicit list lets tests and other scripts reuse this parser.

    Returns:
        Namespace with the attributes ``name``, ``input_dim``, ``n_items``,
        ``n_sets``, ``num_instances``, ``seed``, ``density``, ``deg``, ``mul``,
        ``add`` and ``penalty``.
    """
    # ArgumentDefaultsHelpFormatter appends "(default: ...)" to every option
    # that has a help string, so --help documents the defaults below.
    parser = argparse.ArgumentParser(
        description='Data generation',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('name', type=str, help='Dataset name')
    parser.add_argument('input_dim', type=int, help='Input dimension')
    parser.add_argument('n_items', type=int, help='Number of items')
    parser.add_argument('n_sets', type=int, help='Number of sets')
    parser.add_argument('num_instances', type=int, help='Number of instances')
    parser.add_argument('--seed', type=int, help='Random seed', default=24)
    parser.add_argument('--density', type=float, help='Expected density in cover matrix', default=0.02)
    parser.add_argument('--deg', type=int, help='Degree parameter passed to the generator', default=5)
    parser.add_argument('--mul', type=float, help='Multiplicative noise', default=0.1)
    parser.add_argument('--add', type=float, help='Additive noise', default=0.03)
    parser.add_argument('--penalty', type=float, help='Penalty parameter passed to the generator', default=10.0)

    return parser.parse_args(argv)


def save_params(_args: argparse.Namespace, _path: str) -> None:
    """Write the generation parameters to ``<_path>/<_args.name>_config.json``.

    Args:
        _args: Parsed namespace (as returned by ``parse_arguments``) exposing
            every attribute listed in ``keys`` below; a missing attribute
            raises ``AttributeError``.
        _path: Directory that receives the JSON file. It is created (including
            parents) if it does not exist yet; an empty string means the
            current working directory.
    """
    # Listed explicitly rather than using vars(_args): the JSON keeps a fixed
    # key order and never picks up unrelated attributes set on the namespace.
    keys = (
        'name',
        'input_dim',
        'n_items',
        'n_sets',
        'num_instances',
        'seed',
        'density',
        'deg',
        'mul',
        'add',
        'penalty',
    )
    params = {key: getattr(_args, key) for key in keys}

    # open() raises FileNotFoundError if the target directory is missing.
    # os.makedirs("") would itself fail, so only create a non-empty path.
    if _path:
        os.makedirs(_path, exist_ok=True)

    save_path = os.path.join(_path, _args.name + "_config.json")

    with open(save_path, "w", encoding="utf-8") as outfile:
        json.dump(params, outfile, indent=4)


def main() -> None:
    """Generate WSMC instances into ``PATH`` and save the parameters used."""
    args = parse_arguments()

    # NOTE(review): arguments are passed positionally in the same order the
    # original script used (density and penalty before deg/mul/add). Confirm
    # against WeightedSetMultiCoverGenerator.__init__ before switching to
    # keyword arguments.
    wsmc_generator = WeightedSetMultiCoverGenerator(
        args.name,
        args.input_dim,
        args.n_items,
        args.n_sets,
        args.density,
        args.penalty,
        args.deg,
        args.mul,
        args.add,
    )

    wsmc_generator.generate(PATH, args.num_instances, args.seed)

    # Store the full parameter set in the same directory given to generate().
    save_params(args, PATH)


if __name__ == '__main__':
    main()
