from datasets import load_dataset
from langchain.llms import OpenAIChat
from langchain import PromptTemplate
from langchain.chains import LLMChain
from functools import partial
import argparse
import sys
sys.path.append('../')
from utils.constant import FEATURES
from utils.common import get_lm, feature_score

def is_single_round(element: dict[str, str]) -> bool:
    """Tell whether both transcripts of a preference pair are single-round.

    A transcript counts as single-round when it contains exactly one
    "\\n\\nHuman:" marker and exactly one "\\n\\nAssistant:" marker.

    Args:
        element: A record with "chosen" and "rejected" transcript strings.

    Returns:
        True only if every marker occurs exactly once in both transcripts.
    """
    turn_markers = ("\n\nHuman:", "\n\nAssistant:")
    # Checks run chosen-first, Human-marker-first, and stop at the first miss.
    return all(
        element[side].count(marker) == 1
        for side in ("chosen", "rejected")
        for marker in turn_markers
    )

def main():
    """Annotate a sample of single-round preference pairs with feature scores.

    Command-line options:
        --dataset-name: dataset identifier passed to ``load_dataset``
            (default 'Anthropic/hh-rlhf'); its "train" split is used.
        --model-name: model identifier passed to ``get_lm``
            (default 'google/flan-t5-xl').
        --model-type: 'openai' or 'huggingface' (default 'huggingface').

    The train split is filtered to single-round pairs whose "chosen" and
    "rejected" texts are each shorter than 1000 characters, shuffled with a
    fixed seed, and truncated to at most 20,000 rows. Every feature in
    FEATURES is then scored for both texts, rows with any missing (None)
    score are dropped, and the result is written as JSON lines to
    ``hh-rlhf_with_features_<model-name>.jsonl`` in the working directory
    (with "/" in the model name replaced by "_").
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--dataset-name', default='Anthropic/hh-rlhf')
    ap.add_argument('--model-name', default='google/flan-t5-xl')
    ap.add_argument('--model-type', default='huggingface', choices=['openai', 'huggingface'])
    args = ap.parse_args()

    dataset = load_dataset(args.dataset_name, split="train")
    dataset = dataset.filter(is_single_round).filter(lambda x: len(x["chosen"]) < 1000 and len(x["rejected"]) < 1000)
    # Clamp the sample size: selecting indices past the end of the filtered
    # dataset raises, which would happen for any dataset that leaves fewer
    # than 20,000 rows after filtering.
    sample_size = min(20_000, len(dataset))
    dataset = dataset.shuffle(seed=2137).select(range(sample_size))
    print(len(dataset))

    lm = get_lm(args.model_name, args.model_type)
    feature_extractor = get_feature_extractor(lm)
    annotate = partial(all_features_scores, feature_extractor=feature_extractor, features=FEATURES)
    dataset = dataset.map(annotate)
    # Keep only rows where every feature has a score for both texts.
    dataset = dataset.filter(
        lambda x: all(
            x[f'{feature}_chosen'] is not None and
            x[f'{feature}_rejected'] is not None
            for feature in FEATURES
        )
    )
    dataset.to_json(f'hh-rlhf_with_features_{args.model_name.replace("/","_")}.jsonl', orient='records', lines=True)
    # dataset.push_to_hub('anonymous/hh-rlhf_with_features')
    
def get_feature_extractor(lm):
    """Build the LLM chain used to score one attribute of a conversation.

    The chain's prompt takes four variables: ``conversation``,
    ``attribute_desc``, ``attr_min`` and ``attr_max``. It asks the model for
    a single number from 1 to 10, and the chain exposes the model output
    under the key "score".

    Args:
        lm: The language model object handed to ``LLMChain`` as ``llm``.

    Returns:
        The configured ``LLMChain``.
    """
    # NOTE: the literal's leading indentation and blank lines are part of the
    # prompt text sent to the model; do not reformat it.
    prompt_text = """\
    You will be given a conversation between a human and an AI assistant. Your job is to judge whether assistant's reply {attribute_desc}. Score that on a scale from 1 to 10 where 1 means {attr_min} and 10 means {attr_max}. Here's the conversation:

    ====================
    {conversation}
    ====================

    Now answer by outputting a number from 1 to 10 (and nothing else).

    Answer:"""
    prompt = PromptTemplate(
        template=prompt_text,
        input_variables=["conversation", "attribute_desc", "attr_min", "attr_max"],
    )
    return LLMChain(llm=lm, prompt=prompt, output_key="score")

def all_features_scores(element, feature_extractor, features):
    """Add a score for every feature to both texts of a preference pair.

    For each feature name in ``features``, two keys are written into
    ``element``: ``<feature>_chosen`` and ``<feature>_rejected``, holding the
    result of ``feature_score`` on ``element["chosen"]`` and
    ``element["rejected"]`` respectively.

    Args:
        element: Mutable record with "chosen" and "rejected" entries; it is
            modified in place.
        feature_extractor: Scoring chain forwarded to ``feature_score``.
        features: Mapping from feature name to the extra keyword arguments
            forwarded to ``feature_score`` for that feature.

    Returns:
        The same ``element`` object, now carrying the score keys.
    """
    for name, prompt_kwargs in features.items():
        # Score the chosen text first, then the rejected one.
        for side in ("chosen", "rejected"):
            element[f'{name}_{side}'] = feature_score(
                feature_extractor,
                conversation=element[side],
                **prompt_kwargs,
            )
    return element

# Run the annotation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
