import os
import json
import time
import argparse
import subprocess

from multiprocessing.pool import ThreadPool


# Command-line interface. --train and --output have no usable default (the
# script opens the training file and creates the output directory), so they
# are required instead of silently defaulting to None.
parser = argparse.ArgumentParser(description='training PAP.')
parser.add_argument('--pcstable_config', type=str, default='./configs/pcstable_configs.json', help='algorithm configuration')
parser.add_argument('--fges_config', type=str, default='./configs/fges_configs.json', help='algorithm configuration')
parser.add_argument('--train', type=str, required=True, help='train data')
parser.add_argument('--k', type=int, default=4, help='number of algorithm in the AE')
parser.add_argument('--n', type=int, default=10, help='number of SMAC each iteration')
parser.add_argument('--output', type=str, required=True, help='output directory')
parser.add_argument('--processes_smac', type=int, default=1, help='number of SMAC runs executed in parallel')
parser.add_argument('--processes_algo', type=int, default=1, help='number of processes passed to each SMAC run')
args = parser.parse_args()

# Reject values that would otherwise crash much later (no SMAC result to pick
# from, or a thread pool that cannot be created).
if args.n < 1:
    parser.error('--n must be at least 1')
if args.processes_smac < 1:
    parser.error('--processes_smac must be at least 1')

# Training description; main() reads the dataset names from its 'train' list.
with open(args.train, 'r') as f:
    json_train = json.load(f)

OUTPUT_DIR = args.output


def print_with_time(content):
    """Print the current UTC time on one line, then *content*.

    *content* is followed by a trailing space and a blank line, so
    consecutive log entries are visually separated.
    """
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S UTC %z", time.gmtime())
    print(timestamp, content, sep='\n', end=' \n\n')


def open_and_wait_timeout(cmd, logfile, timeout=None, kill_grace=10):
    """Run a shell command with stdout/stderr redirected to *logfile*.

    Args:
        cmd: Command string, executed through the shell (``shell=True``).
        logfile: Path of the log file; it is truncated before the run.
        timeout: Seconds to wait for the command; ``None`` waits forever.
        kill_grace: Seconds to wait after ``terminate()`` on timeout before
            escalating to ``kill()``.

    Returns:
        The command's exit code (on POSIX, negative when ended by a signal).
    """
    with open(logfile, 'w') as f:
        process = subprocess.Popen(cmd, shell=True, stdout=f, stderr=f)
        print_with_time('open cmd: {}'.format(cmd))
        try:
            process.wait(timeout=timeout)
            print_with_time('finished cmd: {}'.format(cmd))
        except subprocess.TimeoutExpired:
            # NOTE(review): with shell=True this signals the shell process;
            # whether the actual command is also stopped depends on the shell.
            process.terminate()
            try:
                # Reap the child so it does not linger as a zombie.
                process.wait(timeout=kill_grace)
            except subprocess.TimeoutExpired:
                # The process ignored terminate(); force it down and reap it.
                process.kill()
                process.wait()
            print_with_time('timeout cmd: {}'.format(cmd))
    return process.returncode


def _dump_json(obj, path, label):
    """Write *obj* to *path* as tab-indented JSON and log 'store <label> - <path>'."""
    with open(path, 'w') as f:
        json.dump(obj, f, indent='\t')
    print_with_time('store {} - {}'.format(label, path))


def _run_dir(algo_dir, n_cnt):
    """Return the working directory of SMAC run *n_cnt* under *algo_dir*."""
    return os.path.join(algo_dir, 'run_{}'.format(n_cnt))


def _algorithm_for_run(n_cnt):
    """Return the algorithm searched by SMAC run *n_cnt*.

    Even-numbered runs use the FGES configuration, odd-numbered runs the
    PC-stable one. Both launching and result parsing rely on this rule.
    """
    return 'fges' if n_cnt % 2 == 0 else 'pcstable'


def _config_tag(algorithm, cfg):
    """Return the tag identifying *cfg* in cached metrics file names."""
    if algorithm == 'fges':
        return 'fges_penalty-{}_faithful-{}_degree-{}'.format(
            cfg['penalty'], cfg['faithful'], cfg['degree'])
    return 'pcstable_alpha-{}_depth-{}'.format(cfg['alpha'], cfg['depth'])


def _launch_smac_runs(algo_dir, history_file, k_cnt):
    """Run ``args.n`` SMAC searches for portfolio slot *k_cnt* and wait for all.

    Runs execute on a ThreadPool of ``args.processes_smac`` workers, and each
    run logs to ``run_<n>/logger.log``. An exception raised inside a worker
    (e.g. the log file cannot be created) is re-raised here instead of being
    discarded, because ``apply_async`` only surfaces it through ``get()``.
    """
    pool = ThreadPool(args.processes_smac)
    jobs = []
    try:
        for n_cnt in range(1, args.n + 1):
            if _algorithm_for_run(n_cnt) == 'fges':
                cfg_file = args.fges_config
            else:
                cfg_file = args.pcstable_config
            smac_dir = _run_dir(algo_dir, n_cnt)
            cmd = ' '.join([
                'python smac_runner.py',
                '--config {}'.format(cfg_file),
                '--train {}'.format(args.train),
                '--history {}'.format(history_file),
                '--output {}'.format(smac_dir),
                '--processes {}'.format(args.processes_algo),
                # Distinct seed for every (slot, run) pair across the whole job.
                '--seed {}'.format(k_cnt * args.n + n_cnt + 1),
            ])
            logfile = os.path.join(smac_dir, 'logger.log')
            jobs.append(pool.apply_async(open_and_wait_timeout, (cmd, logfile, None)))
    finally:
        pool.close()
        pool.join()
    for job in jobs:
        job.get()


def _load_run_results(algo_dir):
    """Return ``[n_cnt, result]`` pairs loaded from each run's result.json.

    Runs without a result.json (e.g. a crashed or killed SMAC run) are
    reported and skipped, so a single failed run does not abort the search.
    """
    results = []
    for n_cnt in range(1, args.n + 1):
        result_file = os.path.join(_run_dir(algo_dir, n_cnt), 'result.json')
        if not os.path.exists(result_file):
            print_with_time('file not exist - {}'.format(result_file))
            continue
        with open(result_file) as f:
            results.append([n_cnt, json.load(f)])
    return results


def _update_history(history, run_dir, tag):
    """Raise each dataset's best score in *history* using the chosen run's metrics.

    The score is ``f1_adj + f1_arrow`` from ``cache/<dataset>_<tag>_metrics.json``
    inside *run_dir*. Datasets whose metrics file is missing are reported and
    left unchanged.
    """
    for d in json_train['train']:
        name = d['name']
        metrics_file = os.path.join(
            run_dir, 'cache', '{}_{}_metrics.json'.format(name, tag))
        if not os.path.exists(metrics_file):
            print_with_time('file not exist - {}'.format(metrics_file))
            continue
        with open(metrics_file) as f:
            metrics = json.load(f)
        score = metrics['f1_adj'] + metrics['f1_arrow']
        history['best'][name] = max(history['best'][name], score)


def main():
    """Greedily build a portfolio (PAP) of ``args.k`` algorithm configurations.

    For each slot, ``args.n`` SMAC runs (alternating PC-stable / FGES) are
    launched with the current per-dataset best scores passed as history. The
    run with the lowest ``cost`` is added to the portfolio, and its per-dataset
    scores are folded into the history used for the next slot. Per-slot
    history and the partial portfolio are written to
    ``OUTPUT_DIR/algorithm_<k>``; the final portfolio goes to
    ``OUTPUT_DIR/pap.json``.

    Raises:
        RuntimeError: if no SMAC run of a slot produced a result.json.
    """
    print_with_time('args - {}'.format(args))
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Best score found so far per training dataset (0.0 = nothing yet).
    history = {'best': {d['name']: 0.0 for d in json_train['train']}}
    pap = {'configs': []}

    for k_cnt in range(1, args.k + 1):
        print_with_time('start finding algorithm {}'.format(k_cnt))

        algo_dir = os.path.join(OUTPUT_DIR, 'algorithm_{}'.format(k_cnt))
        os.makedirs(algo_dir, exist_ok=True)
        for n_cnt in range(1, args.n + 1):
            os.makedirs(_run_dir(algo_dir, n_cnt), exist_ok=True)

        history_before = os.path.join(algo_dir, 'history_before_{}.log'.format(k_cnt))
        _dump_json(history, history_before, 'history')

        _launch_smac_runs(algo_dir, history_before, k_cnt)

        results = _load_run_results(algo_dir)
        if not results:
            raise RuntimeError(
                'no SMAC run produced a result.json in {}'.format(algo_dir))
        # min() returns the first minimum, so ties go to the lowest run number.
        best_n, best_result = min(results, key=lambda r: r[1]['cost'])

        cfg = best_result['config']
        algorithm = _algorithm_for_run(best_n)
        tag = _config_tag(algorithm, cfg)
        cfg['algorithm'] = algorithm
        pap['configs'].append(cfg)

        _update_history(history, _run_dir(algo_dir, best_n), tag)

        history_after = os.path.join(algo_dir, 'history_after_{}.log'.format(k_cnt))
        _dump_json(history, history_after, 'history')
        _dump_json(pap, os.path.join(algo_dir, 'pap.json'), 'pap')

    _dump_json(pap, os.path.join(OUTPUT_DIR, 'pap.json'), 'pap')
            

# Build the portfolio when this file is executed as a script.
if __name__ == '__main__':
    main()