import click
import logging
import ujson as json
from collections import namedtuple
from fastchat.train.train import preprocess
from tqdm import tqdm
import torch
import transformers
import os,sys
# Make the Open-Assistant checkout importable (presumably cloned next to this script);
# os.path.join supplies the separator that plain string concatenation was missing.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "Open-Assistant"))
from  model.model_training.models.reward_model import GPTNeoXRewardModel, GPTNeoXRewardModelConfig


# Module-level logger named after the module (stdlib convention); its records
# propagate to the root handlers configured in the ``__main__`` block.
logger = logging.getLogger(__name__)

def extract_first_conv(conv):
    """Format the first exchange of a conversation as an OpenAssistant reward-model input.

    The result uses the ``<|prompter|>...<|endoftext|><|assistant|>...`` layout.

    * Conversation opening with ``human``: turn 0 is the prompt and turn 1 the reply.
    * Conversation opening with ``gpt`` and at least three turns: turns 0 and 1 are
      joined with a newline to form the prompt, and turn 2 is the reply.
    * Conversation opening with ``gpt`` and fewer than three turns: the opening
      ``gpt`` turn is scored as a reply to an empty prompt.

    Args:
        conv: list of turn dicts, each with a ``"from"`` role (``"human"`` or
            ``"gpt"``) and a ``"value"`` text.

    Returns:
        The formatted prompt/reply string.

    Raises:
        ValueError: if ``conv`` is empty, its first role is neither ``"human"`` nor
            ``"gpt"``, or a human opening turn has no reply.
    """
    if not conv:
        raise ValueError("conversation is empty")

    first_role = conv[0]["from"]
    if first_role == "gpt":
        if len(conv) >= 3:
            # Fold the leading gpt turn into the prompt so that conv[2] is the scored reply.
            return (
                f"<|prompter|>{conv[0]['value']}\n{conv[1]['value']}"
                f"<|endoftext|><|assistant|>{conv[2]['value']}"
            )
        return f"<|prompter|><|endoftext|><|assistant|>{conv[0]['value']}"

    if first_role == "human":
        if len(conv) < 2:
            raise ValueError("conversation has no reply to the first human turn")
        return f"<|prompter|>{conv[0]['value']}<|endoftext|><|assistant|>{conv[1]['value']}"

    # Previously fell through to an unbound local (UnboundLocalError).
    raise ValueError(f"unexpected role in first turn: {first_role!r}")

_DEFAULT_REWARD_MODEL = "OpenAssistant/oasst-rm-2.1-pythia-1.4b-epoch-2.5"
_ALLOWED_ROLES = ("human", "gpt")


def _load_reward_model(model_name, device):
    """Load the tokenizer and sequence-classification model ``model_name`` onto ``device``."""
    # NOTE(review): the Open-Assistant GPTNeoXRewardModel import at the top of the file
    # presumably registers that architecture with the Auto* factories — confirm.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name,
        model_max_length=2048,
    )
    reward_model = transformers.AutoModelForSequenceClassification.from_pretrained(
        model_name,
    ).to(device)
    return tokenizer, reward_model


def _is_valid_conversation(conv):
    """Return True if ``conv`` has >= 2 turns, only human/gpt roles, and no repeated speaker."""
    if not conv or len(conv) < 2:
        return False
    last_role = None
    for turn in conv:
        role = turn.get("from")
        if role not in _ALLOWED_ROLES or role == last_role:
            return False
        last_role = role
    return True


def _score_conversation(conv_str, tokenizer, reward_model):
    """Return the reward model's scalar score for one formatted conversation string."""
    input_ids = tokenizer(conv_str, return_tensors="pt", truncation=True).input_ids
    # Pure inference: do not build an autograd graph for every sample.
    with torch.no_grad():
        logits = reward_model(input_ids.to(reward_model.device)).logits
    # A single string is tokenized, so row 0 is the only sequence; column 0 is the reward.
    return float(logits[0, 0].item())


def run(args):
    """Score conversations from ``args.input`` with a reward model and write them to ``args.output``.

    A record is dropped if its ``"conversations"`` list is missing, has fewer than
    two turns, uses a role other than ``human``/``gpt``, or repeats a role
    back-to-back. Each remaining record gets a float ``"reward"`` computed on its
    first exchange (see ``extract_first_conv``). The kept records are written out
    as JSON.

    Args:
        args: object with ``input`` and ``output`` path attributes. An optional
            ``reward_model`` attribute overrides the default model id.
    """
    with open(args.input, encoding="utf-8") as fd:
        raw_data = json.load(fd)

    model_name = getattr(args, "reward_model", None) or _DEFAULT_REWARD_MODEL
    # Fall back to CPU instead of crashing on hosts without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer, reward_model = _load_reward_model(model_name, device)

    validated_data = [
        one_data
        for one_data in raw_data
        if _is_valid_conversation(one_data.get("conversations"))
    ]

    for one_data in tqdm(validated_data):
        first_conv = extract_first_conv(one_data["conversations"])
        one_data["reward"] = _score_conversation(first_conv, tokenizer, reward_model)

    with open(args.output, "w", encoding="utf-8") as fd:
        json.dump(validated_data, fd, indent=1)


@click.command()
@click.option('-i', '--input', required=True, type=click.Path(exists=True, dir_okay=False),
              help="JSON file holding a list of records, each with a 'conversations' list.")
@click.option('-o', '--output', required=True, type=click.Path(dir_okay=False),
              help="Destination JSON file for the kept records, each annotated with 'reward'.")
def main(**kwargs):
    """Score the first exchange of every valid conversation in INPUT with a reward model."""
    # run() only needs attribute access, so wrap the click kwargs in a lightweight record.
    Arg = namedtuple('Arg', kwargs.keys())
    args = Arg(**kwargs)
    logger.info(kwargs)
    run(args)


if __name__ == '__main__':
    # Append DEBUG-and-above records to "<this script>.log" ...
    _log_format = '[%(levelname)-5.5s][%(asctime)s][%(filename)s %(lineno)d]: %(message)s'
    logging.basicConfig(
        filename=__file__ + '.log',
        filemode='a',
        format=_log_format,
        datefmt='%d-%m-%Y %H:%M:%S',
        level=logging.DEBUG,
    )
    # ... and mirror them to stderr through an extra root-logger handler.
    root_logger = logging.getLogger()
    root_logger.addHandler(logging.StreamHandler())
    main()
