import datasets
import json
import os
import hydra
import time


# Directory scanned by `main`: each entry is treated as a policy name, and the
# generations file is read from `<LOAD_ROOT>/<policy>/generations/`.
LOAD_ROOT = "/home/blockadam/InferencePessimisim/base_generations/data/alpaca"
# Directory under which `main` writes one dataset per policy.
SAVE_ROOT = "/home/blockadam/InferencePessimisim/comparators/alpaca"

# Embedded in the output directory name (`Base_<GENERATION_SEED>`).
# NOTE(review): presumably the seed the base generations were sampled with — confirm.
GENERATION_SEED = 133742

def get_alpaca_dataset():
    """Load and return the ``eval`` split of ``tatsu-lab/alpaca_eval``."""
    return datasets.load_dataset("tatsu-lab/alpaca_eval", split="eval")


def add_prompts(data, base_data, question_field, comparator_field='response'):
    """Pair each generated response with the prompt it was generated for.

    Args:
        data: Iterable of records, each providing ``'prompt_idx'`` and
            ``'response'`` keys.
        base_data: Indexable collection of prompt records; it is indexed with
            each record's ``'prompt_idx'`` and the result is read at
            ``question_field``.
        question_field: Key used to read the prompt from ``base_data`` and the
            name of the prompt column in the result.
        comparator_field: Name of the column that holds the responses.

    Returns:
        A ``datasets.Dataset`` with columns ``'prompt_idx'``,
        ``question_field`` and ``comparator_field``, one row per record.
    """
    # Gather complete rows in a single pass so lookups happen in a fixed order
    # per record, then split the rows into columns.
    rows = []
    for record in data:
        idx = record['prompt_idx']
        source = base_data[idx]
        rows.append((idx, source[question_field], record['response']))

    columns = {
        'prompt_idx': [row[0] for row in rows],
        question_field: [row[1] for row in rows],
        comparator_field: [row[2] for row in rows],
    }
    return datasets.Dataset.from_dict(columns)


def do_alpaca_generations(load_path, save_path, base_data, question_field, comparator_field='response'):
    """Convert one JSON generations file into a dataset saved on disk.

    Args:
        load_path: Path of the JSON file holding the generation records.
        save_path: Destination passed to ``Dataset.save_to_disk``.
        base_data: Prompt collection forwarded to ``add_prompts``.
        question_field: Prompt key/column name forwarded to ``add_prompts``.
        comparator_field: Response column name forwarded to ``add_prompts``.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(load_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    new_dataset = add_prompts(data, base_data, question_field, comparator_field)
    print(f"Saving to {save_path}")
    # `dirname` is '' for a bare relative name, and `os.makedirs('')` raises.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    new_dataset.save_to_disk(save_path)






@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg):
    """Build one comparator dataset per policy directory found in ``LOAD_ROOT``.

    For every policy, the first file (in sorted order) inside
    ``<LOAD_ROOT>/<policy>/generations`` is converted and saved to
    ``<SAVE_ROOT>/<policy>/Base_<GENERATION_SEED>``.

    Args:
        cfg: Hydra config. Reads ``cfg.task.name``,
            ``cfg.task.data.question_field`` and
            ``cfg.task.data.comparator_field``.

    Raises:
        ValueError: If ``cfg.task.name`` is not ``'alpaca'``.
    """
    if cfg.task.name != 'alpaca':
        raise ValueError("This script is only for the Alpaca task; rerun setting `task=alpaca`")

    os.makedirs(SAVE_ROOT, exist_ok=True)

    alpaca_dataset = get_alpaca_dataset()

    # `os.listdir` returns entries in arbitrary order; sort so runs are reproducible.
    for policy in sorted(os.listdir(LOAD_ROOT)):
        load_parent = os.path.join(LOAD_ROOT, policy, 'generations')
        # Stray files or incomplete policy directories should not abort the whole run.
        if not os.path.isdir(load_parent):
            print(f"Skipping {policy}: no generations directory at {load_parent}")
            continue
        generation_files = sorted(os.listdir(load_parent))
        if not generation_files:
            print(f"Skipping {policy}: {load_parent} is empty")
            continue
        load_path = os.path.join(load_parent, generation_files[0])
        save_path = os.path.join(SAVE_ROOT, policy, f"Base_{GENERATION_SEED}")
        do_alpaca_generations(
            load_path,
            save_path,
            alpaca_dataset,
            cfg.task.data.question_field,
            cfg.task.data.comparator_field,
        )








# Run the Hydra-decorated entry point only when executed as a script.
if __name__ == "__main__":
    main()