""" Take a trained peft-model and run some sampled inference. """

import click
import re
import torch
from typing import (
    Text,
    Optional
)
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig


@click.command()
@click.option(
    '--model-path',
    type=click.Path(exists=True),
    required=True,
    help='Path to the model.',
)
@click.option(
    '--use-adapter/--base-only',
    'use_adapter',
    default=False,
    help='Apply the PEFT adapter stored at --model-path on top of the base model.',
)
def main(
    model_path: str,
    use_adapter: bool = False,
):
    """Rewrite a sample claim at several confidence levels and print the results.

    The base model named in the PEFT config found at --model-path is loaded,
    prompted once per confidence level, and the decoded generation (special
    tokens included) is printed to stdout.
    """
    config = PeftConfig.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # The tokenizer attribute is `padding_side`; assigning to a misspelled name
    # only creates an unused attribute and leaves the padding side unchanged.
    tokenizer.padding_side = 'left'

    # Adapter loading is opt-in so that the default invocation keeps generating
    # with the bare base model.
    if use_adapter:
        peft_model = PeftModel.from_pretrained(model, model_path)
    else:
        peft_model = model
    peft_model.eval()

    # Fall back to CPU instead of failing outright on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    peft_model.to(device)

    claim = "The united states was founded on July 4th, 1776."

    confidence_levels = (
        "would guess",
        "somewhat confident",
        "confident",
        "certain",
    )

    with torch.no_grad():
        for confidence_level in confidence_levels:
            # A batch holding a single one-turn conversation.
            conversations = [
                [
                    {
                        "role": "user",
                        "content": (
                            f"**Task**: Rewrite the following claim to be less specific until you {confidence_level} it is true: {claim}\n\n"
                            "Your response should only contain the claim itself, without any additional context."
                        ),
                    },
                ]
            ]
            tokenized = tokenizer.apply_chat_template(
                conversations,
                padding=True,
                truncation=True,
                add_generation_prompt=True,
                max_length=256,
                return_tensors='pt',
            )

            outputs = peft_model.generate(
                tokenized.to(device),
                max_new_tokens=256,
            )

            # Special tokens are kept so the chat-template markers stay visible.
            processed = tokenizer.batch_decode(
                outputs.detach().cpu().numpy(),
                skip_special_tokens=False,
            )
            print(f"{confidence_level}: {processed}")

            
# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()