import os
import sys
import json
import argparse
import logging
import random
import warnings

import numpy as np

from . import benchmark
from . import settings
from . import methods

logger = logging.getLogger(__name__)


# Maps the accepted ``--log-level`` values to ``logging`` level constants.
LOG_LEVELS = {
    'debug': logging.DEBUG,
    'info': logging.INFO,
    'warning': logging.WARNING,
    'error': logging.ERROR,
    'critical': logging.CRITICAL,
}

# Spellings accepted by ``_str_to_bool`` (compared case-insensitively).
_TRUE_STRINGS = frozenset({'1', 'true', 't', 'yes', 'y', 'on'})
_FALSE_STRINGS = frozenset({'0', 'false', 'f', 'no', 'n', 'off'})


def _str_to_bool(value):
    """Parse a command-line boolean option value.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"`` and ``"0"``) into ``True``, so boolean options need explicit
    parsing.

    Args:
        value: Raw string given on the command line.

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse turns this into a usage error.
    """
    normalized = value.strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(
        f'invalid boolean value {value!r} (expected e.g. true/false, yes/no, 1/0)'
    )


def argsparser():
    """Build the command-line parser for the benchmark script.

    Returns:
        argparse.ArgumentParser: Parser with one positional ``output``
        (results folder) and the options below. ``--dgp-config`` is required.
        ``--log-level`` is case-insensitive and normalised to lowercase.
        ``--ignore-warnings`` accepts true/false style values.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('output')
    parser.add_argument('--input', '-i', default=None, type=str, metavar='FILE')
    parser.add_argument('--seed', '-s', default=17, type=int, metavar='N')
    parser.add_argument('--observations', '-n', default=1000, type=int, metavar='N')
    parser.add_argument('--repetitions', '-r', default=200, type=int, metavar='N')
    parser.add_argument('--dgp-config', '-dgp', required=True, type=str, metavar='JSON')
    parser.add_argument('--rdd-settings-config', '-rdd', default=None, type=str, metavar='JSON')
    # ``str.lower`` runs before the ``choices`` check, so "DEBUG" is accepted.
    parser.add_argument(
        '--log-level', '-log', default='info',
        type=str.lower, choices=list(LOG_LEVELS),
    )
    parser.add_argument('--log-file', '-log-file', default=None, type=str, metavar='FILE')
    # The default stays the bool ``True``; argparse only converts string defaults.
    parser.add_argument(
        '--ignore-warnings', '-iw', default=True, type=_str_to_bool, metavar='W',
        help='suppress Python warnings (true/false, yes/no, 1/0; default: true)',
    )
    parser.add_argument(
            '--operator-policy', '-in', default='cautious',
            type=str, choices=['cautious', 'acknowledge'],
            help='Operator type (acknowledge: T=D=I_D ^ I_Y, cautious T=I_D ^ I_Y, D=I_D ^ I_Y ^ I_E)'
    )
    return parser


if __name__ == "__main__":
    # Script entry point: parse the CLI, configure logging/warnings and seeds,
    # load the DGP and RDD settings configs, run the benchmark into
    # ``args.output`` and finally consolidate the ``rdd_result*.json`` files
    # into ``rdd_result.csv``.
    parser = argsparser()
    args = parser.parse_args()

    # --- logging and warnings ------------------------------------------------
    # Log to stdout unless a log file was requested; the format is shared.
    log_target = (
        {'stream': sys.stdout}
        if args.log_file is None
        else {'filename': args.log_file}
    )
    logging.basicConfig(
        format='%(asctime)s %(levelname)s:%(name)s: %(message)s',
        **log_target,
    )
    log_level = LOG_LEVELS[args.log_level.lower()]
    logger.setLevel(log_level)
    benchmark.logger.setLevel(log_level)

    if args.ignore_warnings:
        warnings.filterwarnings('ignore')

    # --- seeding --------------------------------------------------------------
    rnd = np.random.default_rng(args.seed)

    # Seed the legacy global generators from ``rnd`` so that they are also
    # determined by ``--seed``. NOTE: the order and number of draws from
    # ``rnd`` here affect every later draw (including the data generator,
    # which receives ``rnd``); keep them stable to preserve reproducibility.
    np.random.seed(rnd.integers(0, 2**32 - 1))
    random.seed(int(rnd.integers(0, 2**32 - 1)))

    # --- configuration --------------------------------------------------------
    # DGP parameters (a JSON file is required on the command line).
    with open(args.dgp_config, 'r', encoding='utf-8') as dgp_json:
        dgp_config = json.load(dgp_json)

    # RDD settings: package defaults unless a JSON file was given.
    if args.rdd_settings_config is None:
        rdd_settings = settings.default_settings()
    else:
        with open(args.rdd_settings_config, 'r', encoding='utf-8') as rdd_json:
            rdd_settings = json.load(rdd_json)

    # learner config
    dml_methods = methods.default_methods()

    # Ensure the output folder exists, creating missing parent folders too;
    # an existing directory is reused.
    os.makedirs(args.output, exist_ok=True)

    # Dump the loaded configs next to the results.
    # NOTE(review): the dumped DGP config does not include the
    # ``include_nevertakers`` flag derived from ``--operator-policy`` below.
    dgp_out_path = os.path.join(args.output, 'dgp_config.json')
    with open(dgp_out_path, 'w', encoding='utf-8') as dgp_out:
        json.dump(dgp_config, dgp_out)
    rdd_out_path = os.path.join(args.output, 'rdd_config.json')
    with open(rdd_out_path, 'w', encoding='utf-8') as rdd_out:
        json.dump(rdd_settings, rdd_out)

    # Use the DGP. ``include_nevertakers`` switches between the "cautious
    # operator" (True) and the "acknowledging operator" (False).
    # TODO naming
    data_generator = benchmark.dgp_datagen(
        dgp_params=dgp_config | {"include_nevertakers": args.operator_policy == 'cautious'},
        n_obs=args.observations,
        n_rep=args.repetitions,
        rnd=rnd
    )

    logger.info('starting benchmark')

    try:
        benchmark.benchmark(
            data_generator=data_generator,
            settings=rdd_settings,
            methods=dml_methods,
            output=args.output,
            basename='rdd_result'
        )
    finally:
        # Runs even if the benchmark raised. Collect the 'rdd_result' JSON
        # results into one CSV, then delete the JSON files. Errors here are
        # logged rather than raised so they cannot replace an exception coming
        # from the benchmark itself.
        try:
            df = benchmark.fetch_json_results(
                args.output,
                basename='rdd_result',
                settings=rdd_settings
            )
            df.to_csv(os.path.join(args.output, 'rdd_result.csv'), index=False)

            # Only reached once the CSV has been written successfully.
            for fname in os.listdir(args.output):
                if fname.startswith('rdd_result') and fname.endswith('.json'):
                    logger.info('delete %s', fname)
                    os.remove(os.path.join(args.output, fname))

        except Exception as ex:
            # ``Exception`` rather than ``BaseException``, so Ctrl-C or
            # sys.exit() during consolidation still stop the script.
            # ``exc_info`` keeps the traceback in the log.
            logger.critical(
                'exception during json consolidation: %s', ex, exc_info=True
            )
