import argparse
import os
import json

from src.generators.kp_values_generator import KnapsackValuesGenerator
from src.generators.kp_weights_generator import KnapsackWeightsGenerator
from src.generators.kp_capacity_generator import KnapsackCapacityGenerator

from src.solvers.kp_values_solver import KnapsackValuesSolver
from src.solvers.kp_weights_solver import KnapsackWeightsSolver
from src.solvers.kp_capacity_solver import KnapsackCapacitySolver

# Modes handled by the ``__main__`` dispatch below; the list is shown in the
# error message when an unknown mode is given (it is not used for dispatch).
VALID_MODES = ["kp values", "kp weights", "kp capacity"]
# Root output directory; each run writes into os.path.join(DATA_PATH, mode).
DATA_PATH = "data"


class InvalidModeException(Exception):
    """Raised when the requested generation mode is not one of ``VALID_MODES``."""


def parse_arguments() -> argparse.Namespace:
    """Define the data-generation CLI and parse ``sys.argv``.

    Positional arguments (in this order): ``mode``, ``name``, ``input_dim``,
    ``output_dim``, ``num_instances``. All other settings are optional flags
    with the defaults listed in the table below.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser(description='Data generation')

    # (dest, type, help) -- registration order fixes the positional order.
    positionals = (
        ('mode', str, 'Generation mode'),
        ('name', str, 'Dataset name'),
        ('input_dim', int, 'Input dimension'),
        ('output_dim', int, 'Output dimension'),
        ('num_instances', int, 'Number of instances'),
    )
    for dest, kind, description in positionals:
        cli.add_argument(dest, type=kind, help=description)

    # (flag, type, default, help); help=None is argparse's own default,
    # so it is equivalent to not passing ``help`` at all.
    optionals = (
        ('--seed', int, 24, 'Random seed'),
        ('--deg', int, 5, None),
        ('--mul', float, 0.1, 'Multiplicative noise'),
        ('--add', float, 0.03, 'Additive noise'),
        ('--cap', float, 0.5, 'Relative capacity'),
        ('--corr', int, 1, 'Correlate values and weights'),
        ('--rho', float, 0.0, None),
        ('--penalty', float, 10.0, None),
    )
    for flag, kind, fallback, description in optionals:
        cli.add_argument(flag, type=kind, default=fallback, help=description)

    return cli.parse_args()


def save_params(_args: argparse.Namespace, _path: str) -> None:
    """Write the generation parameters to ``<_path>/<name>_config.json``.

    Args:
        _args: Parsed command-line arguments (as returned by
            ``parse_arguments``).
        _path: Directory to write the config file into; it is created if it
            does not exist yet.

    For mode ``"kp capacity"`` the ``cap`` and ``add`` entries are omitted;
    that mode's generator is not constructed with them.
    """
    keys = (
        'mode', 'name', 'input_dim', 'output_dim', 'num_instances', 'seed',
        'deg', 'mul', 'add', 'cap', 'corr', 'rho', 'penalty',
    )
    # Insertion order of ``keys`` is preserved in the JSON output.
    params = {key: getattr(_args, key) for key in keys}

    # Read the mode from the arguments passed in, not from a module-level
    # global, so the function also works when imported from another module.
    if _args.mode == "kp capacity":
        del params['cap']
        del params['add']

    # An empty _path means "current directory", which needs no creation.
    if _path:
        os.makedirs(_path, exist_ok=True)

    save_path = os.path.join(_path, _args.name + "_config.json")

    with open(save_path, "w", encoding="utf-8") as outfile:
        json.dump(params, outfile, indent=4)


if __name__ == '__main__':

    args = parse_arguments()

    # Unpack the parsed arguments into module-level names.
    mode = args.mode
    name = args.name
    input_dim = args.input_dim
    output_dim = args.output_dim
    num_instances = args.num_instances
    seed = args.seed
    deg = args.deg
    mul = args.mul
    add = args.add
    cap = args.cap
    corr = args.corr
    rho = args.rho
    penalty = args.penalty

    # Each mode pairs a solver with its generator. The generator signatures
    # differ: only "kp weights" and "kp capacity" take ``penalty``, and
    # "kp capacity" takes neither ``add`` nor ``cap``.
    if mode == "kp values":
        solver = KnapsackValuesSolver()
        generator = KnapsackValuesGenerator(name, solver, input_dim, output_dim, deg, mul, add, cap, corr, rho)
    elif mode == "kp weights":
        solver = KnapsackWeightsSolver()
        generator = KnapsackWeightsGenerator(name, solver, input_dim, output_dim, deg, mul, add, cap, corr, rho, penalty)
    elif mode == "kp capacity":
        solver = KnapsackCapacitySolver()
        generator = KnapsackCapacityGenerator(name, solver, input_dim, output_dim, deg, mul, corr, rho, penalty)
    else:
        # One formatted message that also names the rejected value; passing
        # two positional args to the exception would print a tuple repr.
        raise InvalidModeException(
            f"mode parameter must match one of these: {VALID_MODES}, got {mode!r}"
        )

    # Output directory is DATA_PATH/<mode>, e.g. "data/kp values".
    path = os.path.join(DATA_PATH, mode)

    generator.generate(path, num_instances, seed)

    # Record the parameters used, in the same directory given to the generator.
    save_params(args, path)
