import os
import time
import re
import time
from argparse import ArgumentParser
import pandas as pd


# TASK = 'gsm8k'
# DO_PLOTTING = True

def parse_args():
    """Parse the sweep driver's command-line options.

    Boolean-like switches (``--do-plotting``, ``--run-data``,
    ``--unlimited-k``, ``--range-normalize``) are returned as raw strings;
    the caller decides how to interpret them.

    Returns:
        argparse.Namespace with attributes ``task``, ``do_plotting``,
        ``run_data``, ``repeats``, ``unlimited_k``, ``user``,
        ``range_normalize`` and ``reward_model``.
    """
    option_table = (
        # (flag, converter, default)
        ('--task', str, 'gsm8k'),
        ('--do-plotting', str, 'False'),
        ('--run-data', str, 'False'),
        ('--repeats', int, 5),
        ('--unlimited-k', str, 'False'),
        ('--user', str, 'ab'),
        ('--range-normalize', str, 'False'),
        ('--reward-model', str, ''),
    )
    cli = ArgumentParser()
    for flag, converter, fallback in option_table:
        cli.add_argument(flag, type=converter, default=fallback)
    return cli.parse_args()









def get_policies(path_to_task):
    """Return the names of the entries (policies) stored for a task.

    The primary location is *path_to_task* itself. If that directory cannot
    be listed, fall back to ``<root>/parsed_data/<task>``, where ``<task>``
    is the last component of *path_to_task* and ``<root>`` is two levels
    above it (so ``.../data/<task>`` maps to ``.../parsed_data/<task>``).

    Args:
        path_to_task: Directory expected to hold one entry per policy.

    Returns:
        List of entry names, in ``os.listdir`` order.

    Raises:
        OSError: If neither the primary nor the fallback directory can be
            listed.
    """
    try:
        return os.listdir(path_to_task)
    except OSError:
        # Only filesystem errors trigger the fallback. A bare ``except`` would
        # also swallow KeyboardInterrupt/SystemExit and genuine bugs.
        task = os.path.basename(path_to_task)
        root = os.path.dirname(os.path.dirname(path_to_task))
        return os.listdir(os.path.join(root, 'parsed_data', task))



def get_rewards(path_to_task, policy):
    """Return the reward-model names available for *policy*.

    Primary layout: ``<path_to_task>/<policy>/`` holds one entry per reward
    model, plus a ``generations`` entry that is skipped.

    Fallback (used when that directory cannot be listed): look in
    ``<root>/parsed_data/<task>/<policy>/``, where ``<task>`` is the last
    component of *path_to_task* and ``<root>`` is two levels above it. Take
    the first file whose name contains ``'.csv'`` and return the header
    columns that name a known reward model, in column order.

    Args:
        path_to_task: Task directory, e.g. ``<root>/data/<task>``.
        policy: Policy name (sub-directory of the task directory).

    Returns:
        List of reward-model names.

    Raises:
        OSError: If neither the primary nor the fallback directory can be
            listed.
        FileNotFoundError: If the fallback directory contains no CSV file.
    """
    try:
        entries = os.listdir(os.path.join(path_to_task, policy))
    except OSError:
        pass  # fall through to the parsed_data layout below
    else:
        return [entry for entry in entries if entry != 'generations']

    task = os.path.basename(path_to_task)
    root = os.path.dirname(os.path.dirname(path_to_task))
    parsed_dir = os.path.join(root, 'parsed_data', task, policy)

    csv_name = next(
        (name for name in os.listdir(parsed_dir) if '.csv' in name), None
    )
    if csv_name is None:
        # Previously this surfaced as an opaque TypeError from os.path.join.
        raise FileNotFoundError(f"No CSV file found in {parsed_dir}")

    # Only the header is needed to discover which reward columns exist.
    columns = pd.read_csv(os.path.join(parsed_dir, csv_name), nrows=0).columns

    possible_rewards = {'oasst-rm', 'rm-gemma-2b', 'grm-llama-3b', 'armo-rm'}
    return [col for col in columns if col in possible_rewards]



# Checkout root for each accepted ``--user`` value.
_USER_ROOTS = {
    'ab': '/home/blockadam/InferencePessimisim',
    'ak': '/home/anonymouskr/inference_rlhf',
}


def _is_truthy(flag):
    """Interpret a string CLI switch: truthy iff it contains 'true' (any case)."""
    return 'true' in flag.lower()


def _rlhf_argv(root, user, task, policy, reward, method, repeats, range_normalize):
    """Build the argv list for one ``code/rlhf.py`` run.

    ``rejection_unlimited_k`` and ``rejection_range_normalized`` are pseudo
    methods. Both run ``method=rejection`` with extra overrides.
    """
    argv = [
        'python', os.path.join(root, 'code', 'rlhf.py'),
        f'user={user}', f'task={task}', f'policy={policy}', f'reward={reward}',
    ]
    if method == 'rejection_unlimited_k':
        argv += ['method=rejection', f'repeats={repeats}', 'ks.kmax=-1']
    elif method == 'rejection_range_normalized':
        argv += ['method=rejection', f'repeats={repeats}',
                 'method.batch_rmax=False', 'reward_normalization=range']
    else:
        argv += [f'method={method}', f'repeats={repeats}']
    if range_normalize:
        argv.append('io.save_prefix=range-')
    return argv


def _plot_argv(root, user, task, policy, reward, range_normalize):
    """Build the argv list for one ``code/plot.py`` run."""
    argv = [
        'python', os.path.join(root, 'code', 'plot.py'),
        f'user={user}', f'task={task}', f'policy={policy}', f'reward={reward}',
        'refresh_data=True',
    ]
    if range_normalize:
        argv.append('io.save_prefix=range-')
    return argv


def _run_command(argv, prefix=''):
    """Echo and run *argv* without a shell; report failures but keep going.

    KeyboardInterrupt is intentionally not caught here, so the caller can
    abort the whole sweep.
    """
    import subprocess

    # Flush so the echo is not reordered after the child's own output when
    # stdout is redirected to a file or pipe.
    print(f"{prefix}Running command: {' '.join(argv)}", flush=True)
    try:
        # argv list + no shell: policy/reward names are never shell-interpreted.
        completed = subprocess.run(argv, check=False)
    except OSError as err:  # e.g. the interpreter is not on PATH
        print(err)
        return
    if completed.returncode != 0:
        print(f"Command exited with status {completed.returncode}")


def main():
    """Run and/or plot the RLHF sweep for every (policy, reward) pair of a task.

    Policies come from ``get_policies(<root>/data/<task>)``. Rewards come from
    ``get_rewards`` for each policy, or only ``--reward-model`` when it is
    given. For each pair:

    * when ``--run-data`` is truthy, ``code/rlhf.py`` runs once per method;
    * when ``--do-plotting`` is truthy, ``code/plot.py`` runs once.

    String switches count as true when they contain ``'true'``
    (case-insensitive). Ctrl-C stops the sweep.

    Raises:
        SystemExit: If ``--user`` is not a known user.
    """
    args = parse_args()
    task = args.task
    repeats = args.repeats
    user = args.user
    single_reward_model = args.reward_model
    do_plotting = _is_truthy(args.do_plotting)
    run_data = _is_truthy(args.run_data)
    unlimited_k = _is_truthy(args.unlimited_k)
    range_normalize = _is_truthy(args.range_normalize)

    methods = ['bon', 'rejection']
    if unlimited_k:
        methods = ['rejection_unlimited_k']
    elif range_normalize:
        methods[-1] = 'rejection_range_normalized'

    try:
        root = _USER_ROOTS[user]
    except KeyError:
        # Previously an unknown user led to an UnboundLocalError on `root`.
        raise SystemExit(
            f"Unknown --user {user!r}; expected one of {sorted(_USER_ROOTS)}"
        ) from None

    task_root = os.path.join(root, 'data', task)
    try:
        for policy in get_policies(task_root):
            if single_reward_model:  # only run a single reward model
                rewards = [single_reward_model]
            else:
                rewards = get_rewards(task_root, policy)
            for reward in rewards:
                if run_data:
                    print(f"\nRunning policy: {policy} with reward {reward}\t" + "#" * 50 + "\n")
                    for method in methods:
                        _run_command(_rlhf_argv(root, user, task, policy, reward,
                                                method, repeats, range_normalize))
                if do_plotting:
                    _run_command(_plot_argv(root, user, task, policy, reward,
                                            range_normalize),
                                 prefix='\n')
    except KeyboardInterrupt:
        print("KeyboardInterrupt caught. Exiting")









if __name__ == '__main__':
    # perf_counter is monotonic, so the elapsed time is unaffected by
    # wall-clock adjustments during a long sweep (time.time() is not).
    master_start = time.perf_counter()
    main()
    master_elapsed = time.perf_counter() - master_start
    print(f"\nMaster time: {master_elapsed:.0f} seconds\n")