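"""Informativeness scoring of backoff answers on the CLC dataset.

Reads a JSON Lines dataset and, for every item, pairs its ``bleached_premise``
with each backoff's ``short_answer``. Each pair is sent to a vLLM-served model
through the UNLI chat template, which generates a single label token. Each
backoff is annotated with three values:

- a fine-grained score (the expected rank score under the model's
  renormalized label-token probabilities),
- a coarse score (the rank of the generated label),
- the raw completion.

The annotated dataset is written as JSON Lines under
``data/conformal-backoff/scored/``.
"""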

import click
import os
from scipy.special import softmax
import numpy as np
from vllm import LLM, SamplingParams
from vllm import RequestOutput
from transformers import AutoTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
try:
    import ujson as json
except ImportError:
    import json
from src.rank_dicts import (
    SingleLabelRankDict,
    BaseRankDict
)
from src.chat_templates import UNLITemplate


def _fine_grained_score(
    rank_dict: BaseRankDict,
    tokenizer: PreTrainedTokenizer,
    output: RequestOutput
) -> float:
    """ """
    logprobs = output.outputs[0].logprobs[0]
    rd = rank_dict.get_rank_dict(tokenizer)
    
    # sub-select the logprobs
    lk = list(rd.keys())
    selected_logprobs = np.array([logprobs[k].logprob for k in lk if k in logprobs], dtype=np.float32)
    probs = softmax(selected_logprobs)
    es = np.array([rd[k] for k in lk if k in logprobs], dtype=np.float32)

    # calculate the expected score
    return np.dot(probs, es).item()


# Defaults reproduce the paths that were previously hard-coded in ``main``.
_DEFAULT_MODEL_PATH = "./ckpt/merged-dp3"
_DEFAULT_OUTPUT_DIR = os.path.join("data", "conformal-backoff", "scored")


def _read_jsonl(path):
    """Return the JSON objects stored one per line in ``path``, skipping blank lines."""
    with open(path, 'r', encoding='utf-8') as file_:
        return [json.loads(line) for line in file_ if line.strip()]


def _write_jsonl(path, items):
    """Write ``items`` to ``path`` as JSON Lines, one object per line."""
    with open(path, 'w', encoding='utf-8') as file_:
        for item in items:
            file_.write(json.dumps(item) + "\n")


@click.command()
@click.option(
    "--input-path",
    type=click.Path(exists=True),
    required=True,
    help="Path to the dataset.",
)
@click.option(
    "--model-path",
    default=_DEFAULT_MODEL_PATH,
    show_default=True,
    help="Checkpoint used for both the vLLM engine and the tokenizer.",
)
@click.option(
    "--output-dir",
    default=_DEFAULT_OUTPUT_DIR,
    show_default=True,
    help="Directory for the scored JSONL file (named after the input file).",
)
def main(
    input_path,
    model_path=_DEFAULT_MODEL_PATH,
    output_dir=_DEFAULT_OUTPUT_DIR,
):
    """Score every backoff answer in a JSONL dataset with the LLM.

    Each backoff's short answer is paired with its item's bleached premise and
    the model generates a single label token. The backoff is then annotated
    with "fine" (expected score over the label-token probabilities), "coarse"
    (rank-dict value of the generated label) and "completion" (raw generated
    text). The annotated dataset is written to OUTPUT_DIR under the input
    file's name.
    """
    # Resolve and create the destination before the expensive model load and
    # generation, so a missing directory cannot discard finished results.
    output_path = os.path.join(output_dir, os.path.basename(input_path))
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    data = _read_jsonl(input_path)

    template = UNLITemplate()

    llm = LLM(
        model=model_path,
        max_num_batched_tokens=4096,
        tensor_parallel_size=2,
    )

    # Same checkpoint as the engine so token ids agree with the rank dict.
    tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model_path)

    rank_dict = SingleLabelRankDict.from_tokenizer(tokenizer)

    # One sampled token restricted to the rank-dict label tokens. With
    # temperature=0.7 the coarse label is sampled, not the argmax.
    # NOTE(review): logprobs=20 caps how many candidates vLLM reports; any
    # rank-dict token outside that top-20 is dropped from the fine score.
    # Depending on the vLLM version, the reported logprobs may also reflect
    # temperature/top-p processing. Confirm both for the installed version.
    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.9,
        max_tokens=1,
        logprobs=20,
        allowed_token_ids=list(rank_dict.get_rank_dict(tokenizer).keys()),
    )

    # Flatten (item, backoff) pairs; idx/bidx let results be written back.
    all_answers = [
        {
            "idx": idx,
            "bidx": bidx,
            "premise": item['bleached_premise'],
            "hypothesis": backoff['short_answer'],
        } for idx, item in enumerate(data) for bidx, backoff in enumerate(item['backoffs'])
    ]

    # NOTE(review): **answer also forwards idx/bidx; presumably
    # get_prompt_template tolerates extra kwargs. Confirm.
    # continue_final_message=True renders the prompt so the model continues
    # the completion template rather than opening a new turn.
    inputs = [
        tokenizer.apply_chat_template(
            template.get_prompt_template(
                **answer,
            ) + template.get_completion_template(
                is_completion=True,
            ),
            tokenize=False,
            continue_final_message=True,
        )
        for answer in all_answers
    ]

    results = llm.generate(
        inputs,
        sampling_params=sampling_params,
    )
    # Explicit check (not assert) so it survives `python -O`; the zip below
    # relies on a one-to-one, same-order correspondence.
    if len(results) != len(inputs):
        raise RuntimeError(
            f"Results and inputs length mismatch: {len(results)} != {len(inputs)}."
        )

    # Match results back to their originating backoff entries.
    for answer, result in zip(all_answers, results):
        backoff = data[answer['idx']]['backoffs'][answer['bidx']]
        completion = result.outputs[0].text
        backoff['fine'] = _fine_grained_score(rank_dict, tokenizer, result)
        backoff['coarse'] = rank_dict[completion.strip()]
        backoff['completion'] = completion

    _write_jsonl(output_path, data)
            
            
if __name__ == "__main__":
    # click parses sys.argv and supplies the options, so no arguments here.
    main()  # pylint: disable=no-value-for-parameter