import os
import shutil



# Name of the amlt experiment directory whose generations should be moved.
# Previously used: '01-13-alpaca-generate', '01-15-tldr-generate',
# '12-30-gsm8k-generate', '01-02-tldr-pythia-generate',
# '12-31-tldr-pythia-generate'.
EXPERIMENT_NAME = '01-23-math-generate'

# Substrings checked against directory paths by move_generations_fs; files in
# matching directories are left in place.
BLACKLIST = [
    'armo-rm',
    # Trailing comma matters: without it, uncommenting the next entry would
    # silently concatenate the two strings into one bogus entry.
    'rm-gemma-2b',
    # 'mistral-7b',
    # 'oasst-rm',
]

def get_new_path_root(old_root, parent):
    """Map a directory found under *parent* to its path below the destination.

    Strips the *parent* prefix (plus the one separator character after it)
    from *old_root*, then discards the first two remaining path components.
    Returns ``''`` when no more than two components remain.

    NOTE(review): assumes *old_root* starts with ``parent + '/'`` and uses
    POSIX ``'/'`` separators; the two dropped levels are presumably the amlt
    job/run directories — confirm against the amlt output layout.
    """
    skip = len(parent) + 1  # prefix length, including the separator after it
    parts = old_root[skip:].split('/')
    kept = parts[2:]
    return '/'.join(kept)



def move_generations_fs(parent, destination, blacklist=None,
                        extensions=('.json', '.npy')):
    """Move generation artifacts found under *parent* into *destination*.

    Walks *parent* recursively and moves every file whose name ends with one
    of *extensions* into ``os.path.join(destination, get_new_path_root(root,
    parent))``, creating target directories as needed.

    Args:
        parent: Root directory to scan.
        destination: Root directory that receives the moved files.
        blacklist: Optional iterable of substrings. Directories whose path
            relative to *parent* contains any of them are skipped.
        extensions: File-name suffix (or iterable of suffixes) to move.
            Defaults to ``.json`` and ``.npy``.

    Filesystem errors for an individual file (e.g. the target already exists)
    are printed and that file is skipped, so one bad file does not abort the
    whole run.
    """
    blacklist = [] if blacklist is None else list(blacklist)
    if isinstance(extensions, str):
        extensions = (extensions,)
    extensions = tuple(extensions)

    # os.walk yields (dirpath, dirnames, filenames); only the files matter here.
    for root, _, filenames in os.walk(parent):
        # Match the blacklist against the path *below* parent, so a term that
        # happens to appear in parent itself (or its ancestors) does not
        # silently disable every move.
        rel_root = os.path.relpath(root, parent)
        if any(no_go in rel_root for no_go in blacklist):
            continue

        targets = [name for name in filenames if name.endswith(extensions)]
        if not targets:
            continue

        # Same target directory for every file in this root; compute it once.
        new_path = os.path.join(destination, get_new_path_root(root, parent))
        for filename in targets:
            src = os.path.join(root, filename)
            try:
                if not os.path.exists(new_path):
                    print(f'Creating {new_path}')
                    os.makedirs(new_path, exist_ok=True)
                print(f'Moving {src} to {new_path}')
                shutil.move(src, new_path)
            except OSError as e:  # shutil.Error is an OSError subclass
                print(f'Error moving {src}: {e}')




if __name__ == '__main__':
    # Local import: only needed when run as a script.
    import argparse

    # Defaults reproduce the previously hard-coded setup. Alternate setup used
    # before: /home/blockadam/InferencePessimisim/{amlt,data}.
    parser = argparse.ArgumentParser(
        description='Move generation files from an amlt experiment directory '
                    'into the data directory.')
    parser.add_argument('--experiment', default=EXPERIMENT_NAME,
                        help='amlt experiment directory name '
                             '(default: %(default)s)')
    parser.add_argument('--amlt-parent',
                        default='/home/anonymouskr/inference_rlhf/amlt',
                        help='directory containing amlt experiments '
                             '(default: %(default)s)')
    parser.add_argument('--destination',
                        default='/home/anonymouskr/inference_rlhf/data',
                        help='directory that receives the moved files '
                             '(default: %(default)s)')
    parser.add_argument('--blacklist', nargs='*', default=BLACKLIST,
                        help='path substrings to skip (default: %(default)s)')
    args = parser.parse_args()

    move_generations_fs(os.path.join(args.amlt_parent, args.experiment),
                        args.destination, blacklist=args.blacklist)
