import argparse
import pathlib

from evaluator.graphon_evaluator import DiscretizedGraphonEvaluatorFinite
from games.finite.beach import BeachGraphon
from simulator.graphon_simulator import DiscretizedGraphonExactSimulatorFinite
from solver.graphon_solver import DiscretizedGraphonExactSolverFinite
from solver.rmd_solver_array import RMDSolver
from solver.amd_solver_array import AMDSolver
from solver.prox_solver import ProxRMDSolver


def parse_args(argv=None):
    """Parse command-line options for an approximate-MFG experiment.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and programmatic runs).

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Approximate MFGs")
    parser.add_argument('--game', help='game setting', choices=['Beach-Graphon'], default='Beach-Graphon')
    parser.add_argument('--graphon', help='graphon', default='power')
    # Restricted to the solvers generate_config_from_kw can build; the default
    # must be one of them, otherwise running without --solver always fails.
    parser.add_argument('--solver', help='solver', choices=['rmd_array', 'amd_array', 'mirror_prox_rmd'],
                        default='rmd_array')
    parser.add_argument('--simulator', help='simulator', choices=['exact'], default='exact')
    parser.add_argument('--evaluator', help='evaluator', choices=['exact'], default='exact')
    parser.add_argument('--eval_solver', help='eval solver', choices=['exact', 'ppo'], default='exact')

    parser.add_argument('--iterations', type=int, help='number of outer iterations', default=5)
    parser.add_argument('--total_iterations', type=int, help='number of inner solver iterations', default=5000)

    parser.add_argument('--eta', type=float, help='learning rate', default=1.)
    parser.add_argument('--reg_param', type=float, help='regularization parameter', default=0.)
    parser.add_argument('--sigma_update_time', type=int, help='sigma update time', default=10)
    parser.add_argument('--id', type=int, help='experiment id (appended to the experiment name)', default=None)

    parser.add_argument('--results_dir', help='results directory')
    parser.add_argument('--exp_name', help='experiment name')
    parser.add_argument('--verbose', type=int, help='debug outputs', default=0)
    parser.add_argument('--num_alphas', type=int, help='number of discretization points', default=1)

    parser.add_argument('--env_params', type=int, help='Environment parameter set', default=0)

    parser.add_argument('--seed', type=int, help='random seed', default=0)

    return parser.parse_args(argv)


def generate_config(args):
    """Turn a parsed argparse namespace into the experiment config dict.

    Each option named below is read off ``args`` (in this order) and passed as
    a keyword argument to ``generate_config_from_kw``, whose result is returned.
    """
    option_names = (
        'game', 'graphon', 'solver', 'simulator', 'evaluator', 'eval_solver',
        'iterations', 'total_iterations',
        'eta', 'reg_param', 'sigma_update_time',
        'results_dir', 'exp_name', 'id', 'verbose', 'num_alphas', 'env_params',
        'seed',
    )
    options = {name: getattr(args, name) for name in option_names}
    return generate_config_from_kw(**options)


def _env_params_tag(env_params):
    """Render ``env_params`` as a token for the default experiment name.

    Non-dict values are formatted with ``%d`` (the command-line case).
    A dict-valued parameter set cannot be formatted with ``%d``. It is
    replaced by a short deterministic digest of its items, so distinct
    parameter sets get distinct experiment directories.
    """
    if isinstance(env_params, dict):
        import hashlib
        # Sorting the item reprs makes the digest independent of dict insertion order.
        canonical = ";".join(sorted(repr(item) for item in env_params.items()))
        return "env-" + hashlib.sha1(canonical.encode("utf-8")).hexdigest()[:8]
    return "%d" % env_params


def _select_module(kind, key, options):
    """Return ``options[key]``; raise NotImplementedError naming the unsupported choice."""
    try:
        return options[key]
    except KeyError:
        raise NotImplementedError(
            "Unsupported %s: %r (supported: %s)" % (kind, key, ", ".join(sorted(options)))) from None


def generate_config_from_kw(**kwargs):
    """Resolve experiment settings into a config dict of module classes and sub-configs.

    Keyword arguments are the options produced by ``generate_config``. The
    following are optional:
      - ``results_dir``: defaults to ``"./results/"``; a trailing "/" is added
        when missing.
      - ``exp_name``: defaults to a name built from the other settings.
      - ``id``: appended to the name as ``_<id>`` when not None.
      - ``num_alphas``: defaults to 101.
    ``env_params`` may be a scalar (command line) or a dict of game parameters.

    Side effect: creates ``<results_dir><exp_name>/`` on disk if it is missing.
    This happens only after all module names have been validated.

    Returns:
        dict with the selected classes ("game", "solver", "simulator",
        "evaluator", "eval_solver"), "iterations", the per-module config
        dicts, "experiment_directory" (a string ending in "/") and "seed".

    Raises:
        NotImplementedError: if game, solver, simulator, evaluator or
            eval_solver names an unsupported option.
    """
    num_alphas = kwargs.get('num_alphas', 101)
    env_params = kwargs['env_params']

    # Resolve module classes first so an invalid choice fails before any
    # directory is created on disk.
    game = _select_module('game', kwargs['game'], {
        'Beach-Graphon': BeachGraphon,
    })
    solver = _select_module('solver', kwargs['solver'], {
        'rmd_array': RMDSolver,
        'amd_array': AMDSolver,
        'mirror_prox_rmd': ProxRMDSolver,
    })
    simulator = _select_module('simulator', kwargs['simulator'], {
        'exact': DiscretizedGraphonExactSimulatorFinite,
    })
    evaluator = _select_module('evaluator', kwargs['evaluator'], {
        'exact': DiscretizedGraphonEvaluatorFinite,
    })
    eval_solver = _select_module('eval_solver', kwargs['eval_solver'], {
        'exact': DiscretizedGraphonExactSolverFinite,
    })

    results_dir = kwargs.get('results_dir')
    if results_dir is None:
        results_dir = "./results/"
    elif results_dir and not results_dir.endswith('/'):
        # Without a separator the directory and experiment name would be fused.
        results_dir += '/'

    exp_name = kwargs.get('exp_name')
    if exp_name is None:
        # The literal "_0_0_" fields are kept so default names match existing
        # result directories; their meaning is not visible in this file.
        exp_name = "%s_%s_%s_%s_%s_0_0_%f_%d_%s" % (
            kwargs['game'], kwargs['graphon'], kwargs['solver'], kwargs['simulator'], kwargs['evaluator'],
            kwargs['eta'], num_alphas, _env_params_tag(env_params))

    if kwargs.get('id') is not None:
        exp_name += "_%d" % kwargs['id']

    experiment_directory = results_dir + exp_name + "/"
    pathlib.Path(experiment_directory).mkdir(parents=True, exist_ok=True)

    # A dict env_params is spread into game_config; any other value is wrapped.
    game_config = dict(env_params) if isinstance(env_params, dict) else {"env_params": env_params}

    return {
        # === Algorithm modules ===
        "game": game,
        "solver": solver,
        "simulator": simulator,
        "evaluator": evaluator,
        "eval_solver": eval_solver,

        # === General settings ===
        "iterations": kwargs['iterations'],

        # === Default module settings ===
        "game_config": game_config,
        "solver_config": {
            "total_iterations": kwargs['total_iterations'],
            "eta": kwargs['eta'],
            "reg_param": kwargs['reg_param'],
            "sigma_update_time": kwargs['sigma_update_time'],
            'verbose': kwargs['verbose'],
            'num_alphas': num_alphas,
        },
        "eval_solver_config": {
            "total_iterations": kwargs['total_iterations'],
            "eta": 0,
            'verbose': kwargs['verbose'],
            'num_alphas': num_alphas,
        },
        "simulator_config": {
            'num_alphas': num_alphas,
        },
        "evaluator_config": {
            'num_alphas': num_alphas,
        },

        "experiment_directory": experiment_directory,
        "seed": kwargs["seed"],
    }


def parse_config():
    """Read options from the command line and return the built experiment config."""
    return generate_config(parse_args())
