import argparse
import os
import json

from src.generators.production_planning_generator import ProductionPlanningGenerator

# Directory passed to ProductionPlanningGenerator.generate() and to
# save_params() by the __main__ block; save_params writes
# "<name>_config.json" into it. Note the literal space in "production planning".
PATH = os.path.join("data", "production planning")


def parse_arguments(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options for production-planning data generation.

    Positional arguments, in order: name, input_dim, n_products, capacity,
    num_instances. Optional flags: --use_binomial_distribution,
    --n_customers, --sigma, --max_demands, --costs_asymmetry, --seed.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        The parsed ``argparse.Namespace``.
    """

    def _str_to_bool(value: str) -> bool:
        """Convert a textual CLI value ('True', 'false', '1', 'no', ...) to a bool."""
        # argparse's ``type=bool`` is a trap: bool("False") is True, so the flag
        # could never be switched off. Parse the text explicitly instead.
        if isinstance(value, bool):
            return value
        normalized = value.strip().lower()
        if normalized in ('true', 't', 'yes', 'y', '1'):
            return True
        # '' maps to False to stay compatible with the old bool('') behaviour.
        if normalized in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")

    parser = argparse.ArgumentParser(description='Data generation')
    parser.add_argument('name', type=str, help='Dataset name')
    parser.add_argument('input_dim', type=int, help='Input dimension')
    parser.add_argument('n_products', type=int, help='Number of products types (output dimension)')
    parser.add_argument('capacity', type=int, help='Production capacity')
    parser.add_argument('num_instances', type=int, help='Number of instances')
    parser.add_argument('--use_binomial_distribution', type=_str_to_bool,
                        help='Use binomial or normal distribution', default=True)
    parser.add_argument('--n_customers', type=int, help='Number of customers', default=0)
    parser.add_argument('--sigma', type=float, help='Standard deviation for demand sampling', default=0.0)
    parser.add_argument('--max_demands', type=int, help='Maximum number of demands per product', default=100)
    parser.add_argument('--costs_asymmetry', type=float, help='Asymmetry level in costs', default=0.5)
    parser.add_argument('--seed', type=int, help='Random seed', default=24)

    return parser.parse_args(argv)


def save_params(_args: argparse.Namespace, _path: str) -> None:
    """Write the generation parameters held in ``_args`` to a JSON file.

    The file is ``<_path>/<_args.name>_config.json``. It is written with
    4-space indentation, and keys appear in the order listed below.

    Args:
        _args: Namespace carrying name, input_dim, n_products, capacity,
            num_instances, use_binomial_distribution, n_customers,
            costs_asymmetry, sigma, max_demands and seed.
        _path: Output directory; created (with parents) if it does not exist.
    """
    params = {
        'name': _args.name,
        'input_dim': _args.input_dim,
        'n_products': _args.n_products,
        'capacity': _args.capacity,
        'num_instances': _args.num_instances,
        # Read from the parameter, not the module-global ``args`` that only
        # exists when this file is run as a script.
        'use_binomial_distribution': _args.use_binomial_distribution,
        'n_customers': _args.n_customers,
        'costs_asymmetry': _args.costs_asymmetry,
        'sigma': _args.sigma,
        'max_demands': _args.max_demands,
        'seed': _args.seed,
    }

    # os.makedirs('') raises, so only create a directory when one is given;
    # an empty path keeps the original "write into the cwd" behaviour.
    if _path:
        os.makedirs(_path, exist_ok=True)

    save_path = os.path.join(_path, _args.name + "_config.json")

    with open(save_path, "w", encoding="utf-8") as outfile:
        json.dump(params, outfile, indent=4)


if __name__ == '__main__':
    # Script entry point: read CLI options, generate the instances, then
    # record the options used alongside them in PATH.
    args = parse_arguments()

    # n_products is passed before input_dim (the reverse of the CLI order);
    # keep this in step with ProductionPlanningGenerator's positional signature.
    generator = ProductionPlanningGenerator(
        args.name,
        args.n_products,
        args.input_dim,
        args.capacity,
        args.use_binomial_distribution,
        n_customers=args.n_customers,
        sigma=args.sigma,
        max_demands=args.max_demands,
        costs_asymmetry=args.costs_asymmetry,
    )
    generator.generate(PATH, args.num_instances, args.seed)

    save_params(args, PATH)
