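"""Generate multi-turn answers to benchmark questions with a local causal language model.

Usage:
python gen_model_answer.py --model-path <model-path> --model-name <model-name> --model-tag <model-tag>

Answers are written to data/model_answer/<model-name>.jsonl (one JSON object per
line, sorted by question id) unless --answer-file is given.
"""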
import argparse
import json
import os
from aisafetylab.models import load_model

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM


# Sampling temperature per question category. Zero means greedy decoding;
# categories missing from this table fall back to 0.7 in run_eval.
temperature_config = dict(
    [
        # open-ended, creative categories
        ("writing", 0.7),
        ("roleplay", 0.7),
        # categories with a single correct answer
        ("extraction", 0.0),
        ("math", 0.0),
        ("coding", 0.0),
        ("reasoning", 0.0),
        # knowledge categories: nearly deterministic
        ("stem", 0.1),
        ("humanities", 0.1),
        ("arena-hard-200", 0.0),
    ]
)

def init_model(model_path, tokenizer_path):
    """Load a causal LM and its tokenizer, ready for batched generation.

    Args:
        model_path: Hugging Face hub id or local directory of the model weights.
        tokenizer_path: Hub id or local directory of the tokenizer.

    Returns:
        ``(model, tokenizer, device)``. ``model`` is in eval mode on ``device``
        (always a ``torch.device``). The tokenizer pads on the left and uses
        EOS as its pad token when it has none.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    # Half precision only on GPU: float16 matmuls are unsupported or very slow
    # on CPU in many PyTorch versions, so use float32 there.
    dtype = torch.float16 if use_cuda else torch.float32
    print('loading model...')
    # NOTE(review): trust_remote_code=True executes Python shipped with the
    # model repo; only point this at trusted checkpoints.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True, torch_dtype=dtype
    ).to(device).eval()
    # Left padding keeps every prompt adjacent to its generated tokens when
    # a causal LM generates for a padded batch.
    tokenizer.padding_side = 'left'
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    print('finish loading')

    return model, tokenizer, device

def run_eval(
    model_path,
    tokenizer_path,
    question_file,
    answer_file,
    model_tag
):
    """Generate multi-turn answers for every question and append them as JSONL.

    Each non-blank line of ``question_file`` is a JSON object with
    ``question_id``, ``turns`` (the list of user messages) and optionally
    ``category``. The category selects the sampling temperature from
    ``temperature_config``; unknown or missing categories use 0.7.

    User turns are fed one at a time. After each turn the model's reply is
    appended as an assistant message, so later turns see the whole
    conversation. One line ``{"question_id": ..., "turns": [...]}`` is
    appended to ``answer_file`` per question, where ``turns`` is the full
    list of user/assistant messages.

    Args:
        model_path: Model weights passed to ``init_model``.
        tokenizer_path: Tokenizer passed to ``init_model``.
        question_file: Path of the input JSONL question file.
        answer_file: Path of the output JSONL file; ``~`` is expanded and
            missing parent directories are created.
        model_tag: Model name forwarded to aisafetylab's ``load_model``.
    """
    questions = []
    with open(question_file, "r", encoding="utf-8") as ques_file:
        for line in ques_file:
            # Every line read from a file is truthy (it keeps its "\n"), so
            # strip first; json.loads rejects blank lines.
            if line.strip():
                questions.append(json.loads(line))

    # Resolve the output location once. Expand "~" before creating
    # directories, and skip makedirs for a bare filename: dirname() is ""
    # there and os.makedirs("") raises.
    answer_path = os.path.expanduser(answer_file)
    answer_dir = os.path.dirname(answer_path)
    if answer_dir:
        os.makedirs(answer_dir, exist_ok=True)

    _model, tokenizer, _device = init_model(model_path, tokenizer_path)

    for question in tqdm(questions):
        temperature = temperature_config.get(question.get("category"), 0.7)
        # A (near-)zero temperature means greedy decoding.
        do_sample = temperature >= 1e-4
        generation_config = {"temperature": temperature, "do_sample": do_sample, "max_new_tokens": 2048}
        # The generation config is fixed for the whole question, so wrap the
        # model once here rather than once per turn.
        model = load_model(_model, tokenizer, model_name=model_tag, generation_config=generation_config)

        turns = []
        for user_message in question["turns"]:
            turns.append({'role': 'user', 'content': user_message})
            # Pass a shallow copy so the wrapper cannot alias the running
            # conversation list that is extended below.
            response = model.batch_chat([turns.copy()], batch_size=1)[0]
            turns.append({'role': 'assistant', 'content': response})

        # Append after every question so an interrupted run keeps its progress.
        with open(answer_path, "a") as fout:
            ans_json = {
                "question_id": question["question_id"],
                "turns": turns,
            }
            fout.write(json.dumps(ans_json) + "\n")


def reorg_answer_file(answer_file):
    """Sort an answer JSONL file by question id and de-duplicate it, in place.

    When a ``question_id`` appears on several lines, the last one wins. The
    writer appends to the same file, so this keeps the newest answer. Blank
    lines are dropped, and every kept record ends with a newline.

    Args:
        answer_file: Path to the JSONL file. A leading ``~`` is expanded, the
            same way the answer writer resolves the path it appends to.
    """
    path = os.path.expanduser(answer_file)

    answers = {}
    with open(path, "r", encoding="utf-8") as fin:
        for line in fin:
            if not line.strip():
                continue
            # Normalise the terminator: the last line may lack "\n". Once
            # sorted, it could otherwise be glued onto the following record.
            answers[json.loads(line)["question_id"]] = line.rstrip("\n") + "\n"

    with open(path, "w", encoding="utf-8") as fout:
        fout.writelines(answers[qid] for qid in sorted(answers))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate model answers for a JSONL question set.")
    parser.add_argument("--model-name", type=str, required=True,
                        help="Name used for the default answer file data/model_answer/<model-name>.jsonl.")
    parser.add_argument("--model-path", type=str, required=True,
                        help="Hugging Face hub id or local directory of the model weights.")
    parser.add_argument("--model-tag", type=str, required=True,
                        help="Model name forwarded to aisafetylab's load_model.")
    parser.add_argument("--tokenizer-path", type=str, default=None,
                        help="Tokenizer location; defaults to --model-path.")
    parser.add_argument("--question-file", type=str, default="datasets/question.jsonl",
                        help="Input JSONL file of questions.")
    parser.add_argument("--answer-file", type=str, default=None,
                        help="Output JSONL file; defaults to data/model_answer/<model-name>.jsonl.")
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()

    torch.manual_seed(args.seed)

    # Expand "~" once so the writer (run_eval) and the reorganizer
    # (reorg_answer_file) operate on the same file.
    answer_file = os.path.expanduser(
        args.answer_file or f"data/model_answer/{args.model_name}.jsonl"
    )
    tokenizer_path = args.tokenizer_path or args.model_path

    print(f"Output to {answer_file}")

    run_eval(
        model_path=args.model_path,
        tokenizer_path=tokenizer_path,
        question_file=args.question_file,
        answer_file=answer_file,
        model_tag=args.model_tag
    )

    reorg_answer_file(answer_file)
