# This is a fork of shp4.py which focuses only on annotating samples with LM features
from tqdm import tqdm
import dill as pickle
import pandas as pd
from langchain.chat_models import ChatOpenAI
from langchain import PromptTemplate, HuggingFacePipeline
from langchain.chains import LLMChain
import numpy as np
from pathlib import Path
import argparse
import re
import sys
sys.path.append('../')
from utils.constant import FEATURES
from utils.common import get_lm, feature_score


def main():
    """Command-line entry point: annotate a samples file with LM feature scores.

    Parses ``--samples-path``, ``--model-name`` and ``--model-type``, builds the
    language model and the feature-extraction chain, scores the samples on
    every feature in ``FEATURES`` and writes the resulting table as JSON.

    The output file name is ``<input stem>_annotated_<model name><input suffix>``
    with every ``/`` in the model name replaced by ``_``. It carries no
    directory component, so it is resolved against the current working
    directory rather than the directory of the input file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--samples-path', type=Path, default='./generated_flan_t5_large.json')
    parser.add_argument('--model-name', default='google/flan-t5-xl')
    parser.add_argument('--model-type', default='huggingface', choices=['openai', 'huggingface'])
    cli = parser.parse_args()

    extractor = get_feature_extractor(get_lm(cli.model_name, cli.model_type))
    annotated = annotate_samples(cli.samples_path, extractor, FEATURES)

    # Model names such as "google/flan-t5-xl" contain a slash, which must not
    # end up in a file name.
    model_tag = cli.model_name.replace("/", "_")
    source = cli.samples_path
    out_fn = f'{source.stem}_annotated_{model_tag}{source.suffix}'

    annotated.to_json(out_fn)


def annotate_samples(samples_path, feature_extractor, features, max_samples=100, max_responses=16):
    """Score sampled responses on every feature and collect the results in a table.

    The JSON file at ``samples_path`` is loaded with ``pd.read_json`` and turned
    into a dict with ``DataFrame.to_dict()``. Every value of that dict is
    indexed with ``'prompt'`` and ``'response'``; the latter is sliced, so it
    has to be a sequence of responses.

    Args:
        samples_path: Location of the JSON file holding the samples.
        feature_extractor: Chain forwarded unchanged to ``all_features_scores``.
        features: Mapping from feature name to the keyword arguments describing
            that feature. Its keys become the score columns of the result.
        max_samples: Only the first ``max_samples`` entries are processed.
        max_responses: Only the first ``max_responses`` responses of each entry
            are scored.

    Returns:
        A ``pd.DataFrame`` with one row per (prompt, response) pair for which
        every feature produced a score, with columns
        ``['response', 'prompt', *feature names]``.
    """
    samples = pd.read_json(samples_path).to_dict()
    feature_names = list(features)
    rows = []
    for sample in tqdm(list(samples.values())[:max_samples]):
        prompt = sample['prompt']
        for response in tqdm(sample['response'][:max_responses], leave=False):
            scores = all_features_scores(feature_extractor, prompt, response, features)
            # A single missing (None) score disqualifies the whole response, so
            # the table never contains partially scored rows.
            if any(score is None for score in scores.values()):
                continue
            # Build a fresh row instead of rebinding the loop variable, which
            # would shadow the sample currently being iterated.
            rows.append({**scores, 'prompt': prompt, 'response': response})
    return pd.DataFrame(rows, columns=['response', 'prompt'] + feature_names)

def get_feature_extractor(lm):
    """Build the chain that asks ``lm`` to rate a reply on one attribute.

    The prompt shows a Reddit post (``history``) and a reply (``reply``) and
    asks for a single number from 1 to 10 judging whether the reply
    ``attribute_desc``, where 1 means ``attr_min`` and 10 means ``attr_max``.

    Args:
        lm: Language model object handed to ``LLMChain`` as ``llm``.

    Returns:
        An ``LLMChain`` whose output is stored under the key ``"score"``.
    """
    template_text = """
    You will be given a Reddit post and a reply. Your job is to judge whether the reply {attribute_desc}. Score that on a scale from 1 to 10 where 1 means {attr_min} and 10 means {attr_max}.

    POST:
    {history}

    Reply:
    {reply}

    Answer by outputting a number from 1 to 10 (and nothing else).

    Answer:"""
    scoring_prompt = PromptTemplate(
        template=template_text,
        input_variables=["history", "reply", "attribute_desc", "attr_min", "attr_max"],
    )
    return LLMChain(llm=lm, prompt=scoring_prompt, output_key="score")



def all_features_scores(feature_extractor, prompt, response, features):
    """Return a ``{feature name: score}`` dict for one prompt/response pair.

    Each value of ``features`` is unpacked as extra keyword arguments to
    ``feature_score``, next to ``history=prompt`` and ``reply=response``.
    An empty ``features`` mapping yields an empty dict.
    """
    return {
        name: feature_score(feature_extractor, history=prompt, reply=response, **kwargs)
        for name, kwargs in features.items()
    }

# Run the annotation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
