import argparse
import os
import re
import subprocess
import sys
import time

# Switches selecting which VW reductions process() puts on the command line.
USE_ADF = True    # adds --cb_explore_adf for non-supervised algorithms
USE_LDF = False   # adds --cbify_ldf (and drops '--cbify N') for non-supervised algorithms
USE_CS = False    # adds --cbify_cs; the supervised baseline then uses --csoaa instead of --oaa

# EDIT THIS TO POINT TO YOUR VW EXECUTABLE
VW = '/YOUR PATH HERE'

# Cost-sensitive runs read their datasets from a different directory; the
# results directory pattern is the same in both modes.
VW_DS_DIR = 'rcv1_test/data' if USE_CS else 'multiclass/'
DIR_PATTERN = 'res/cbresults_{}/'

# Captures the remainder of every "average loss = ..." line in VW's output.
rgx = re.compile('^average loss = (.*)$', flags=re.M)


def expand_cover(policies):
    """Return the 'cover' algorithm tuples for a given number of policies.

    For each psi in (0, 0.01, 0.1, 1.0) two tuples are produced, in this
    order: the plain variant and the same variant with ('nounif', None)
    appended.
    """
    variants = []
    for psi in (0, 0.01, 0.1, 1.0):
        base = ('cover', policies, 'psi', psi)
        variants.extend([base, base + ('nounif', None)])
    return variants

# Building blocks for the algorithm grid.  process() reads a non-supervised
# algorithm tuple as alternating (option, value) entries; a value of None
# means the option is a bare flag with no argument.
_GAMMA_SCALES = (1e3, 700, 400, 1e2, 50, 1e1)
_GAMMA_EXPONENT = ('gamma_exponent', .25)
_FAST = ('fast', None)
_LOGISTIC = ('loss_function', 'logistic', 'sigmoid', None)

# Option groups appended to each squarecb base tuple.  The grid lists every
# gamma scale for the first suffix, then every gamma scale for the second, etc.
_SQUARECB_SUFFIXES = (
    (),
    _GAMMA_EXPONENT,
    _LOGISTIC,
    _GAMMA_EXPONENT + _LOGISTIC,
    _FAST + _LOGISTIC,
    _GAMMA_EXPONENT + _FAST + _LOGISTIC,
)

params = {
    # The 'cover' variants (see expand_cover) are currently not part of the grid.
    'alg': [
        ('supervised',) + _LOGISTIC,
    ] + [
        ('squarecb', None, 'gamma_scale', scale) + suffix
        for suffix in _SQUARECB_SUFFIXES
        for scale in _GAMMA_SCALES
    ],
    'learning_rate': [0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1.0, 3.0, 10.0],
    'cb_type': ['dr', 'ips', 'mtr'],
}

# Extra command-line tokens for the cb algorithms; set from --flags when run as a script.
extra_flags = None

def param_grid():
    """Return every combination of the module-level ``params`` as a list of dicts.

    Each dict maps every key of ``params`` to one of its candidate values.
    Keys are expanded in the iteration order of ``params``; the candidates of
    a later key vary fastest.  The result is not sorted.
    """
    def with_setting(base, key, value):
        # Copy so that partial combinations are never shared between entries.
        extended = dict(base)
        extended[key] = value
        return extended

    combos = [{}]
    for key in params:
        combos = [
            with_setting(partial, key, choice)
            for partial in combos
            for choice in params[key]
        ]
    return combos


def ds_files():
    """Return the sorted paths of all '*.vw.gz' files directly under VW_DS_DIR."""
    import glob
    pattern = os.path.join(VW_DS_DIR, '*.vw.gz')
    matches = glob.glob(pattern)
    matches.sort()
    return matches


def get_task_name(ds, params):
    """Build the identifier written to the loss file for one (dataset, params) job.

    The dataset base name, up to its first '.', must split on '_' into exactly
    three parts; the second and third become the ``ds:`` and ``na:`` fields.
    When ``params`` has more than one key, the non-'alg' entries follow as
    ``key:value`` in sorted key order.  The last field is the 'alg' tuple with
    None entries dropped and the rest joined by ':'.  Fields are separated by '|'.
    """
    stem = os.path.basename(ds).split('.')[0]
    did, n_actions = stem.split('_')[1:]

    pieces = ['ds:{}'.format(did), 'na:{}'.format(n_actions)]
    if len(params) > 1:
        pieces.extend(
            '{}:{}'.format(key, params[key])
            for key in sorted(params)
            if key != 'alg'
        )
    pieces.append(':'.join(str(part) for part in params['alg'] if part is not None))
    return '|'.join(pieces)


def process(ds, params, results_dir, test=False):
    """Assemble the VW command line for one job, run it, and return its loss.

    Args:
        ds: dataset path.  The number of actions is the last of the three
            '_'-separated parts of its base name (before the first '.').
        params: one grid entry.  'alg' holds the algorithm tuple; every other
            key is passed to VW as ``--key value``.
        results_dir: not used by this function.
        test: when true, print the command and return without running it.

    Returns:
        The first "average loss = ..." value in VW's output as a float, or
        None when ``test`` is true.

    Raises:
        subprocess.CalledProcessError: VW exited with a non-zero status.
        IndexError: VW's output contained no "average loss = ..." line.
    """
    print('processing', ds, params)
    stem = os.path.basename(ds).split('.')[0]
    _, n_actions = stem.split('_')[1:]

    cmd = [VW, ds, '-b', '24']
    for name, value in params.items():
        if name != 'alg':
            # cb_type is not passed to VW for the supervised baseline.
            if not (params['alg'][0] == 'supervised' and name == 'cb_type'):
                cmd.extend(['--{}'.format(name), str(value)])
            continue

        if value[0] == 'supervised':
            # NOTE(review): only value[0] is used here; any further entries in
            # a 'supervised' tuple are not added to the command -- confirm
            # that this is intended.
            cmd.extend(['--csoaa' if USE_CS else '--oaa', str(n_actions)])
            continue

        if not USE_LDF:
            cmd.extend(['--cbify', str(n_actions)])
        if USE_CS:
            cmd.append('--cbify_cs')
        if USE_LDF:
            cmd.append('--cbify_ldf')
        if extra_flags:
            cmd.extend(extra_flags)
        if USE_ADF:
            cmd.append('--cb_explore_adf')
        # --randomtie is intentionally not passed: it seems to have been
        # removed from the latest VW branch.
        assert len(value) % 2 == 0, 'params should be in pairs of (option, value)'
        for option, arg in zip(value[0::2], value[1::2]):
            cmd.append('--{}'.format(option))
            # A None value marks a bare flag that takes no argument.
            if arg is not None:
                cmd.append(str(arg))

    print('running', cmd)
    if test:
        return None

    started = time.time()
    # stderr is merged into stdout so the captured text holds all of VW's output.
    raw = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
    output = raw.decode('ascii')
    sys.stderr.write('\n\n{}, {}, time: {}, output:\n'.format(ds, params, time.time() - started))
    sys.stderr.write(output)

    losses = rgx.findall(output)
    pv_loss = float(losses[0])
    print('elapsed time:', time.time() - started, 'pv loss:', pv_loss)
    return pv_loss


def skip_params(params):
    """Return True for grid entries that should not be evaluated.

    The 'supervised' algorithm and any algorithm whose name starts with
    'regcb' or 'squarecb' are only evaluated with cb_type 'mtr'; every other
    cb_type is skipped for them.  Other algorithms are never skipped.
    """
    alg_name = params['alg'][0]
    mtr_only = alg_name == 'supervised' or alg_name.startswith(('regcb', 'squarecb'))
    # cb_type is looked up only for the mtr-only algorithms.
    return mtr_only and params['cb_type'] != 'mtr'


if __name__ == '__main__':
    # Worker entry point.  The full job list is (dataset x grid entry); this
    # worker handles jobs task_id, task_id + num_tasks, task_id + 2*num_tasks, ...
    # and appends one "<task name> <loss>" line per finished job to its own
    # loss file, so an interrupted worker can be restarted and resume.
    parser = argparse.ArgumentParser(description='vw job')
    parser.add_argument('task_id', type=int, help='task ID, between 0 and num_tasks - 1')
    parser.add_argument('num_tasks', type=int)  # stride between the jobs of one worker
    parser.add_argument('--task_offset', type=int, default=0,
                        help='offset for task_id in output filenames')
    parser.add_argument('--results_dir', default=DIR_PATTERN.format('agree01'))
    parser.add_argument('--name', default=None)
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--flags', default=None, help='extra flags for cb algorithms')
    args = parser.parse_args()

    # A stride of zero would make the job loop below run forever on the same
    # job; negative values would index the job list from the end.
    if args.num_tasks < 1:
        parser.error('num_tasks must be at least 1')
    if args.task_id < 0:
        parser.error('task_id must be non-negative')

    # --name overrides --results_dir.
    if args.name is not None:
        args.results_dir = DIR_PATTERN.format(args.name)

    if args.flags is not None:
        extra_flags = args.flags.split()
    grid = param_grid()
    dss = ds_files()
    tot_jobs = len(grid) * len(dss)

    # Task 0 creates the results directory (made world-writable when it is
    # newly created); every other worker waits until it exists.
    if args.task_id == 0:
        if not os.path.exists(args.results_dir):
            os.makedirs(args.results_dir)
            import stat
            os.chmod(args.results_dir, os.stat(args.results_dir).st_mode | stat.S_IWOTH)
    else:
        while not os.path.exists(args.results_dir):
            time.sleep(1)

    done_tasks = set()
    loss_file = None
    if not args.test:
        fname = os.path.join(args.results_dir, 'loss{}.txt'.format(args.task_offset + args.task_id))
        if os.path.exists(fname):
            # The first whitespace-separated token of each line is the task
            # name; blank lines (e.g. from an interrupted write) are ignored.
            with open(fname) as previous:
                done_tasks = {line.split()[0] for line in previous if line.strip()}
        loss_file = open(fname, 'a')

    try:
        idx = args.task_id
        while idx < tot_jobs:
            # Jobs are laid out dataset-major: all grid entries of the first
            # dataset, then all grid entries of the second, and so on.
            ds_idx, grid_idx = divmod(idx, len(grid))
            ds = dss[ds_idx]
            params = grid[grid_idx]
            if args.test:
                # Note: The test option prints the command string but doesn't actually run it.
                process(ds, params, args.results_dir, test=True)
            else:
                # task_name is the string that actually gets written to file for this run.
                task_name = get_task_name(ds, params)
                if task_name not in done_tasks and not skip_params(params):
                    try:
                        pv_loss = process(ds, params, args.results_dir)
                    except subprocess.CalledProcessError as err:
                        sys.stderr.write('\nERROR: TASK FAILED {} {}\n\n'.format(ds, params))
                        # process() captures VW's stdout and stderr, so the
                        # exception is the only place its diagnostics survive.
                        if err.output:
                            sys.stderr.write(err.output.decode('ascii', 'replace'))
                        print(err)
                        print('ERROR: TASK FAILED', ds, params)
                    else:
                        loss_file.write('{} {}\n'.format(task_name, pv_loss))
                        # Force the result to disk so it is kept even if the
                        # worker is killed right afterwards.
                        loss_file.flush()
                        os.fsync(loss_file.fileno())
            idx += args.num_tasks
    finally:
        # Close the loss file even when a job raises an unexpected exception.
        if loss_file is not None:
            loss_file.close()
