#!/u/s/k/skbharti/anaconda3/envs/pytorch/bin/python

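#-----------------------------------------------------------------------------------------------------------------
#   Driver script for running DQN experiments on the rule game.
#       -   Runs 'REPEAT' executions of the same game, 'BATCH_SIZE' of them in parallel at a time
#       -   Uses the DQN learner from dqn.py and the environments from rule_sets.py
#       -   NOTE(review): an earlier version of this header mentioned 'EPISODES_PER_GRADIENT', 'REPEAT_RUNS'
#           and reinforce_full.py; none of them is referenced in this file
#-----------------------------------------------------------------------------------------------------------------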

from dqn import DQN
from rule_sets import *

import torch
import pandas as pd
import numpy as np
from joblib import Parallel, delayed
import time, os, copy, sys, random, yaml

def single_execution(args):
    """Train one learner instance and write its logged data to disk.

    Keys read from ``args``:
        RUN_ID          index of this run inside the experiment
        EXP_DIR         experiment directory; output goes to EXP_DIR/<RUN_ID>/
        SEED            base seed, or -1 to use the pre-drawn SEEDS1..SEEDS4
        SEEDS1..SEEDS4  per-run seed sequences (only read when SEED == -1)
        FEATURIZATION, RULE_NAME   select the environment class
        LEARNER         learner name; only 'DQN' is supported
        TRAIN_EPISODES  forwarded to ``lrn.train``

    Side effects:
        - ``args['SEED']`` is overwritten with the seed chosen for this run.
        - torch, numpy and Python ``random`` global generators are re-seeded.
        - ``light_data.csv`` (``lrn.all_data_df``) and ``heavy_data.csv``
          (``lrn.weight_df``) are written to the run directory.

    Raises:
        ValueError: if no environment matches FEATURIZATION / RULE_NAME, or
            if LEARNER is not supported.
    """
    run_id, exp_dir = args['RUN_ID'], args['EXP_DIR']
    print("RUN : ",run_id)

    #---------------------------------------------
    #   Seed every random generator for this run.
    #   SEED == -1 : use the pre-drawn per-run seeds SEEDS1..SEEDS4
    #   otherwise  : every generator gets run_id + SEED
    #
    #   Seeds are cast to plain ints: run_id and the SEEDSn entries may be
    #   numpy integers, which random.seed() rejects on recent Python versions.
    #---------------------------------------------
    if args['SEED'] == -1:
        env_seed = int(args["SEEDS1"][run_id])
        torch_seed = int(args["SEEDS2"][run_id])
        numpy_seed = int(args["SEEDS3"][run_id])
        python_seed = int(args["SEEDS4"][run_id])
    else:
        env_seed = torch_seed = numpy_seed = python_seed = int(run_id + args['SEED'])

    # args['SEED'] is handed on to the environment / learner through args
    # (the original comments call it the CGS seed).
    args['SEED'] = env_seed
    torch.manual_seed(torch_seed)
    np.random.seed(numpy_seed)
    random.seed(python_seed)

    #-------------------------------------------------------------
    #   Initialize the environment and the learner
    #-------------------------------------------------------------
    env = _build_env(args)

    if args['LEARNER'] != 'DQN':
        # Previously this only printed a message and then crashed on the
        # unbound learner variable; fail with an explicit error instead.
        raise ValueError("learner not correct: {!r}".format(args['LEARNER']))

    # Second argument is the neptune experiment handle; nothing is recorded
    # from this driver, so it is always None.
    lrn = DQN(env, None, args)

    lrn.train(0, args['TRAIN_EPISODES'])

    #-------------------------------------------------------------
    #   Write the data collected by the learner to EXP_DIR/<RUN_ID>/
    #-------------------------------------------------------------
    run_dir = os.path.join(exp_dir, str(run_id))
    os.makedirs(run_dir, exist_ok=True)

    lrn.all_data_df.to_csv(os.path.join(run_dir, 'light_data.csv'))
    lrn.weight_df.to_csv(os.path.join(run_dir, 'heavy_data.csv'))


def _build_env(args):
    """Instantiate the environment selected by ``args``.

    ``args['FEATURIZATION']`` is matched first; when it names none of the
    known featurizations, ``args['RULE_NAME']`` selects one of the
    rule-specific environments.

    Raises:
        ValueError: if neither key selects a known environment.
    """
    featurization = args['FEATURIZATION']
    if featurization == 'UNARY':
        return RuleUnary(args)
    if featurization == 'UNARY_BINARY':
        return RuleUnaryBinary(args)
    if featurization == 'UNARY_BINARY_TEST':
        return RuleUnaryBinary_TEST(args)
    if featurization == 'SHAPE_MATCH':
        return ShapeMatch(args)
    if featurization == 'COLOR_MATCH':
        return ColorMatch(args)
    if featurization == 'ONE_STEP_UNARY':
        return RuleOneStepUnary(args)
    if featurization == 'ONE_STEP_UNARY_BINARY':
        return RuleOneStepUnaryBinary(args)
    if featurization == 'ONE_STEP_UNARY_BINARY_WITH_INDEX':
        return RuleOneStepUnaryBinaryWithIndex(args)
    if featurization == 'ONE_STEP_UNARY_BINARY_WITHOUT_INDEX':
        return RuleOneStepUnaryBinaryWithoutIndex(args)

    # Unknown featurization: fall back to the rule-specific environments.
    rule_name = args['RULE_NAME']
    if rule_name == 'rules-05.txt':
        return Rule5(args)
    if rule_name == 'rules-06.txt':
        return Rule6(args)
    if rule_name == 'rules-07.txt':
        return Rule7(args)

    raise ValueError(
        "no environment for FEATURIZATION={!r} and RULE_NAME={!r}".format(featurization, rule_name))


def run_experiment(args):
    """Run ``args['REPEAT']`` executions of ``single_execution`` in parallel batches.

    Keys read from ``args``:
        REPEAT      total number of runs; run ids are 0 .. REPEAT-1
        BATCH_SIZE  number of runs launched in parallel at a time
        SEED        base seed; -1 means "draw per-run seeds here"
        OUTPUT_DIR  parent directory of the experiment directory
        RULE_NAME   rule file name; its stem is appended to the experiment id

    Keys added to ``args``:
        SEEDS1..SEEDS4  (only when SEED == -1) arrays of REPEAT random seeds
        EXP_DIR, EXP_ID the experiment directory and id

    Each run receives its own deep copy of ``args`` extended with 'RUN_ID'.

    Raises:
        ValueError: if BATCH_SIZE is smaller than 1.
    """
    num_jobs, batch_size = args['REPEAT'], args['BATCH_SIZE']
    if batch_size < 1:
        raise ValueError("BATCH_SIZE must be at least 1, got {!r}".format(batch_size))

    if args['SEED'] == -1:
        # Draw four independent seed sequences, one entry per run.
        # dtype is explicit because the upper bound does not fit in a 32-bit
        # integer, which is the default integer type on some platforms.
        args.update({
            key: np.random.randint(1, 2**32-2, size=num_jobs, dtype=np.int64)
            for key in ("SEEDS1", "SEEDS2", "SEEDS3", "SEEDS4")
        })

    exp_id = 'other'

    ### create local experiment directory
    exp_dir = os.path.join(args['OUTPUT_DIR'], exp_id + args['RULE_NAME'].split('/')[-1].split('.')[0])
    print("export directory", exp_dir)
    args.update({'EXP_DIR' : exp_dir, 'EXP_ID' : exp_id})
    os.makedirs(exp_dir, exist_ok=True)

    #-------------------------------------------------------------------------------------
    #  Run 'REPEAT' instances, at most 'BATCH_SIZE' of them at a time.
    #  Stepping through the run ids in strides of batch_size also covers the
    #  final, possibly shorter, batch; truncating REPEAT/BATCH_SIZE dropped it
    #  (and ran nothing at all when REPEAT < BATCH_SIZE).
    #-------------------------------------------------------------------------------------
    for batch_start in range(0, num_jobs, batch_size):
        batch_end = min(batch_start + batch_size, num_jobs)

        args_list = []
        for run_id in range(batch_start, batch_end):
            run_args = copy.deepcopy(args)      # every run mutates its own copy
            run_args['RUN_ID'] = run_id         # plain int, usable as index and as seed
            args_list.append(run_args)

        Parallel(n_jobs=batch_size)(delayed(single_execution)(run_args) for run_args in args_list)

    print("All Completed!")

def get_nth_rule(trial_file, rule_entry):
    """Look up one rule of a trial list.

    Args:
        trial_file: CSV source accepted by ``pandas.read_csv``; the columns
            read here are rule_id, min/max_objects, min/max_shapes and
            min/max_colors.
        rule_entry: positional index of the row to read.

    Returns:
        A pair ``(rule_file_name, params)`` where ``rule_file_name`` is the
        row's ``rule_id`` with ".txt" appended and ``params`` maps the short
        keys minO, maxO, minS, maxS, minC, maxC to the row's values.
    """
    # (short parameter key, CSV column) in the order the keys are emitted
    key_to_column = (
        ('minO', 'min_objects'),
        ('maxO', 'max_objects'),
        ('minS', 'min_shapes'),
        ('maxS', 'max_shapes'),
        ('minC', 'min_colors'),
        ('maxC', 'max_colors'),
    )

    row = pd.read_csv(trial_file).iloc[rule_entry,:]
    rule_file_name = row['rule_id'] + ".txt"
    return rule_file_name, {key: row[column] for key, column in key_to_column}
  
if __name__ == "__main__":
    # Script entry point.
    #   argv[1] : directory that holds the game data ('rules/' and 'trial-lists/' live below it)
    #   argv[2] : integer flag stored in args['RECORD']
    #   argv[3] : path of the YAML parameter file

    # `re` is needed for the YAML float resolver below and is not among this
    # file's top-level imports, so import it here.
    import re

    print("starting driver")
    if len(sys.argv) < 4:
        sys.exit("usage: {} <game_data_path> <record (int)> <yaml_path>".format(sys.argv[0]))
    game_data_path, record, yaml_path = sys.argv[1], int(sys.argv[2]), sys.argv[3]                 # directory path of rules is provided by the caller

    # Register an additional implicit float resolver so that more float
    # spellings in the YAML file are resolved as floats (see the pattern).
    # NOTE: `loader` is yaml.SafeLoader itself, not a copy, so the resolver is
    # registered on the SafeLoader class.
    loader = yaml.SafeLoader
    loader.add_implicit_resolver(
    u'tag:yaml.org,2002:float',
    re.compile(u'''^(?:
     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
    |[-+]?\\.(?:inf|Inf|INF)
    |\\.(?:nan|NaN|NAN))$''', re.X),
    list(u'-+0123456789.'))
    with open(yaml_path, 'r') as param_file:
        args = yaml.load(param_file, Loader = loader)

    run_mode = args["RUN_MODE"]
    args.update({"RECORD": record})

    if run_mode == 'RULE' or "RULE_PARAM" in args:
        ### Running an individual rule, or a trial whose RULE_PARAM is already
        ### given in the YAML file: the rule file is named by args['RULE_NAME']
        rule_file_path = os.path.join(game_data_path, 'rules/', args['RULE_NAME'])
        args.update({'RULE_FILE_PATH' : rule_file_path})  # full rule-file path
    else:
        ### Running a rule from a trial list: take the first entry of the trial file
        trial_name = args['TRIAL_NAME']
        rule_entry = 0
        trial_file_path = os.path.join(game_data_path, 'trial-lists/all_pilot/', trial_name)
        rule_name, rule_params = get_nth_rule(trial_file_path, rule_entry)
        rule_file_path = os.path.join(game_data_path, 'rules/', rule_name)
        args.update({'RULE_PARAM' : rule_params,
                    'RULE_FILE_PATH' : rule_file_path,  # full rule-file path
                    'RULE_NAME'  : rule_name,           # rule-name
                })

    # source files to record on neptune
    # NOTE(review): 'rule-sets.py' looks like a typo for 'rule_sets.py' (the module imported above) - confirm before changing
    args.update({'SOURCE_FILES' : ['dqn.py', 'rule-sets.py', 'rule_game_env.py', 'rule_game_engine.py', 'driver_dqn.py', args['RULE_FILE_PATH']]})

    # Run the experiment and report the elapsed wall-clock time in seconds.
    start = time.time()
    print(args)
    run_experiment(args)
    end = time.time()
    print(end-start)

############################
#
#   run till success
#   run fixed episodes
#
############################


## np.ravel_multi_index((index_tuple), (dimension_tuple))
## np.unravel_index(index_number, (dimension_tuple))
