"""
"""

import click
import os
try:
    import ujson as json
except ImportError:
    import json
from transformers import AutoTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
from vllm import (
    LLM,
    SamplingParams,
)
from src.chat_templates import (
    AnswerShorteningTemplate,
    BleachedAnswerTemplate
)


# Hugging Face model id shared by the vLLM engine and the chat-template
# tokenizer; keeping one constant guarantees prompts are rendered with the
# template of the model that actually serves them.
_MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct-AWQ"


def _read_jsonl(path: str) -> list:
    """Return one parsed JSON object per non-blank line of *path*."""
    with open(path, "r", encoding="utf-8") as file_:
        # Skip blank lines (e.g. a stray empty line) instead of failing on
        # json.loads("").
        return [json.loads(line) for line in file_ if line.strip()]


def _write_jsonl(path: str, records: list) -> None:
    """Write *records* to *path* as JSON Lines (one object per line)."""
    with open(path, "w", encoding="utf-8") as file_:
        file_.writelines(json.dumps(record) + "\n" for record in records)


def _generate_texts(llm, tokenizer, conversations: list, sampling_params) -> list:
    """Render *conversations* with the chat template and generate one reply each.

    Returns the stripped text of the first completion for every conversation,
    in input order. An empty *conversations* list short-circuits to ``[]``:
    ``apply_chat_template`` may fail on an empty batch, and there is nothing
    to generate anyway.
    """
    if not conversations:
        return []
    prompts = tokenizer.apply_chat_template(
        conversations,
        tokenize=False,
        add_generation_prompt=True,
    )
    request_outputs = llm.generate(prompts, sampling_params=sampling_params)
    return [output.outputs[0].text.strip() for output in request_outputs]


@click.command()
@click.option(
    "--input-path",
    required=True,
    type=click.Path(exists=True, dir_okay=False),
    help="Path to the input file containing claims to be scored.",
)
@click.option(
    "--output-path",
    required=True,
    type=click.Path(dir_okay=False),
    help="Path to the output file to save the results.",
)
def main(
    input_path,
    output_path
):
    """Bleach premises and shorten back-off answers in a JSON Lines file.

    Each record of INPUT_PATH is rewritten with BleachedAnswerTemplate and
    the result stored under "bleached_premise". Its "claims" list is then
    renamed to "backoffs", and every back-off's "backoff" text (together
    with the record's "question") is shortened with AnswerShorteningTemplate
    and stored under "short_answer". All records are written to OUTPUT_PATH.
    """
    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.8,
        max_tokens=256,
        repetition_penalty=1.05,
    )

    llm = LLM(
        model=_MODEL_NAME,
        tensor_parallel_size=2,
        quantization="awq_marlin",
        max_num_batched_tokens=2048,
    )

    tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)

    bleached_template = BleachedAnswerTemplate()
    answer_shortening_template = AnswerShorteningTemplate()

    data = _read_jsonl(input_path)

    # Step 1: bleach each record. Every record field is forwarded as a
    # keyword argument, so this must run before "claims" is renamed below.
    bleached_premises = _generate_texts(
        llm,
        tokenizer,
        [bleached_template.get_prompt_template(**item) for item in data],
        sampling_params,
    )
    for item, premise in zip(data, bleached_premises):
        item["bleached_premise"] = premise

    # Step 2: rename "claims" to "backoffs"; records without "claims" are
    # left untouched.
    for item in data:
        if "claims" in item:
            item["backoffs"] = item.pop("claims")

    # Step 3: shorten every back-off answer. Records without back-offs are
    # skipped instead of raising KeyError. idx/bidx locate the back-off to
    # update and are forwarded to the template as before.
    all_answers = [
        {
            "idx": idx,
            "bidx": bidx,
            "question": item["question"],
            "general_answer": backoff["backoff"],
        }
        for idx, item in enumerate(data)
        for bidx, backoff in enumerate(item.get("backoffs") or [])
    ]
    short_answers = _generate_texts(
        llm,
        tokenizer,
        [
            answer_shortening_template.get_prompt_template(**answer)
            for answer in all_answers
        ],
        sampling_params,
    )
    for answer, short_answer in zip(all_answers, short_answers):
        data[answer["idx"]]["backoffs"][answer["bidx"]]["short_answer"] = short_answer

    _write_jsonl(output_path, data)


if __name__ == "__main__":
    main()