from disco.distributions.lm_distribution import LMDistribution, TextSample
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM
import argparse


def reformat(example):
    """Split a chosen/rejected transcript pair into separate fields.

    Reads ``example['chosen']`` and ``example['rejected']`` and writes three
    new keys onto the same mapping:

    * ``'human'`` -- the human context extracted from the chosen transcript,
    * ``'assistant_chosen'`` -- the assistant text of the chosen transcript,
    * ``'assistant_rejected'`` -- the assistant text of the rejected one.

    The two transcripts must yield the same human context; this is checked
    with an ``assert`` before anything is written. The (mutated) example is
    returned.
    """
    chosen_text = example['chosen']
    rejected_text = example['rejected']

    # Both options are expected to answer the same prompt.
    shared_context = human_context(chosen_text)
    assert shared_context == human_context(rejected_text)

    example['human'] = shared_context
    example['assistant_chosen'] = assistant_response(chosen_text)
    example['assistant_rejected'] = assistant_response(rejected_text)
    return example

def human_context(text):
    """Return the human part of a transcript.

    The human part is everything before the first ``"Assistant:"`` marker,
    with a leading ``"Human:"`` label (if present) and surrounding whitespace
    removed.

    Args:
        text: Transcript string, e.g. ``"\\n\\nHuman: hi\\n\\nAssistant: hello"``.

    Returns:
        The stripped human text. If the transcript contains no
        ``"Assistant:"`` marker, the whole text is treated as human context.
    """
    human_label = "Human:"
    # partition() returns the full text as the head when the marker is
    # missing, whereas slicing with find() == -1 would drop the last character.
    head, _, _ = text.partition("Assistant:")
    head = head.strip()
    # Only remove the label when it is actually there, instead of cutting a
    # fixed number of characters regardless of how the text starts.
    if head.startswith(human_label):
        head = head[len(human_label):]
    return head.strip()

def assistant_response(text):
    """Return everything after the first ``"Assistant:"`` marker, stripped.

    The marker is the same one ``human_context`` splits on, so the two
    functions always cut a transcript at the same position.

    Args:
        text: Transcript string, e.g. ``"\\n\\nHuman: hi\\n\\nAssistant: hello"``.

    Returns:
        The text following the first ``"Assistant:"`` with surrounding
        whitespace removed (this includes any later turns), or ``""`` when
        the transcript has no such marker.
    """
    _, marker, tail = text.partition("Assistant:")
    if not marker:
        # No assistant turn at all; slicing from find() == -1 would return
        # an arbitrary chunk of the text instead.
        return ""
    return tail.strip()

def main():
    """Reformat a preference dataset and save it locally or push it to the Hub.

    Steps:
      1. Parse command-line arguments.
      2. Load the dataset named by ``--dataset`` and apply ``reformat`` to
         every example.
      3. Add a constant ``'labels'`` column (value ``'chosen'``) to the
         ``'train'`` split.
      4. Split ``'train'`` in half, without shuffling, into train/test.
      5. Save under ``out/<name>`` when ``--save_local`` is given, otherwise
         push to ``anonymous/<name>``.

    ``--options``, ``--option-prefix`` and ``--context`` are accepted but not
    read anywhere in this function.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--dataset', default='anonymous/hh-rlhf_with_features')
    ap.add_argument('--options', nargs='+', default=['chosen', 'rejected'])
    ap.add_argument('--option-prefix', default='human_ref')
    ap.add_argument('--context', default='history')
    ap.add_argument('--save_local', action='store_true', help='Option to save data in local')

    args = ap.parse_args()

    ds = load_dataset(args.dataset)
    ds = ds.map(reformat)

    # add_column returns a new dataset instead of modifying in place, so the
    # result has to be kept or the 'labels' column is silently lost.
    train = ds['train']
    train = train.add_column('labels', ['chosen'] * len(train))
    ds = train.train_test_split(test_size=0.5, shuffle=False)

    # "<namespace>/<name>" -> "<name>"; a bare "<name>" is used as-is rather
    # than raising IndexError.
    name_parts = args.dataset.split('/')
    dataset_name = name_parts[1] if len(name_parts) > 1 else name_parts[0]

    if args.save_local:
        fname = f'out/{dataset_name}'
        ds.save_to_disk(fname)
        print(f'Saved to {fname}')
    else:
        ds.push_to_hub(f'anonymous/{dataset_name}')

# Run only when executed as a script, so importing this module (e.g. to reuse
# the parsing helpers) does not trigger a dataset download and upload.
if __name__ == "__main__":
    main()

